"""
Plot a spectrum using matplotlib
"""
from copy import deepcopy
import astropy.units as u
import matplotlib as mpl
import numpy as np
from astropy.utils.masked import Masked
from matplotlib.widgets import SpanSelector
from dysh.log import logger
from ..coordinates import (
decode_veldef,
frame_to_label,
)
from ..util.docstring_manip import docstring_parameter
from . import check_kwargs, parse_html
from .plotbase import PlotBase
_KMS = u.km / u.s
kwargs_docstring = """xaxis_unit : str or `~astropy.unit.Unit`
The units to use on the x-axis, e.g. "km/s" to plot velocity.
yaxis_unit : str or `~astropy.unit.Unit`
The units to use on the y-axis.
xmin : float
Minimum x-axis value, in `xaxis_unit`.
xmax : float
Maximum x-axis value, in `yaxis_unit`.
ymin : float
Minimum y-axis value, in `xaxis_unit`.
ymax : float
Maximum y-axis value, in `yaxis_unit`.
xlabel : str
x-axis label.
ylabel : str
y-axis label.
label : str
Label for legend.
alpha : float
Alpha value for the plot. Between 0 and 1.
grid : bool
Show a plot grid or not.
figsize : tuple
Figure size (see matplotlib).
linewidth : float
Line width, default: 2.0.
drawstyle : str
Line style, default 'default'.
color : str
Line color, c also works.
title : str
Plot title.
vel_frame : str
The velocity frame (see VELDEF FITS Keyword).
doppler_convention: str
The velocity convention (see VELDEF FITS Keyword).
"""
[docs]
@docstring_parameter(kwargs_docstring)
class SpectrumPlot(PlotBase):
r"""
The SpectrumPlot class is for simple plotting of a `~dysh.spectra.spectrum.Spectrum`
using matplotlib functions. Plots attributes are modified using keywords
(\*\*kwargs) described below. SpectrumPlot will attempt to make smart default
choices for the plot if no additional keywords are given.
Parameters
----------
spectrum : `~dysh.spectra.spectrum.Spectrum`
The spectrum to plot
Other Parameters
----------------
{0}
"""
def __init__(self, spectrum, **kwargs):
super().__init__()
self.reset()
self._spectrum = spectrum
self._sa = spectrum._spectral_axis
self._set_xaxis_info()
self._plot_kwargs.update(kwargs)
self._title = self._plot_kwargs["title"]
self._selector: MultiSpanSelector = None
self._freezey = (self._plot_kwargs["ymin"] is not None) or (self._plot_kwargs["ymax"] is not None)
self._freezex = (self._plot_kwargs["xmin"] is not None) or (self._plot_kwargs["xmax"] is not None)
self._scan_numbers = np.array([self._spectrum.meta["SCAN"]])
def _set_xaxis_info(self):
"""Ensure the xaxis info is up to date if say, the spectrum frame has changed."""
self._plot_kwargs["doppler_convention"] = self._spectrum.doppler_convention
self._plot_kwargs["vel_frame"] = self._spectrum.velocity_frame
self._plot_kwargs["xaxis_unit"] = self._spectrum.spectral_axis.unit
self._plot_kwargs["yaxis_unit"] = self._spectrum.unit
@property
def spectrum(self):
"""The underlying `~dysh.spectra.spectrum.Spectrum`"""
return self._spectrum
[docs]
def default_plot_kwargs(self):
return {
"xmin": None,
"xmax": None,
"ymin": None,
"ymax": None,
"xlabel": None,
"ylabel": None,
"xaxis_unit": None,
"yaxis_unit": None,
"grid": False,
"label": None,
"alpha": 1.0,
"figsize": None,
"linewidth": 2.0,
"drawstyle": "default",
"color": None,
"title": None,
"doppler_convention": None,
"vel_frame": None,
}
[docs]
def reset(self):
"""Reset the plot keyword arguments to their defaults."""
self._plot_kwargs = self.default_plot_kwargs()
[docs]
@docstring_parameter(kwargs_docstring)
def plot(self, show_header=True, select=True, oshow=None, oshow_kwargs=None, **kwargs):
"""
Plot the spectrum.
Parameters
----------
show_header : bool
Show informational header.
select : bool
Allow selecting regions via click and drag.
oshow : list or `~dysh.spectra.spectrum.Spectrum`
Spectra to overlay in the plot.
oshow_kwargs : dict
Dictionary with parameters for `SpectrumPlot.oshow`.
These include color, linestyle, label, and alpha.
Other Parameters
----------------
{0}
"""
check_kwargs(self.default_plot_kwargs(), kwargs)
self._set_xaxis_info()
# Plot arguments for this call of plot(). i.e. non-sticky plot attributes
this_plot_kwargs = deepcopy(self._plot_kwargs)
this_plot_kwargs.update(kwargs)
# Clean up old resources before creating new figure/selector
if self._selector is not None:
self._selector.disconnect()
self._selector = None
if self.figure is not None:
self.figure = mpl.figure.Figure(figsize=(10, 6))
self.axes = self.figure.subplots(nrows=1, ncols=1)
# TODO: procedurally generate subplot params based on show header/buttons args.
# ideally place left/right params right here, then top gets determined below.
s = self._spectrum
lw = this_plot_kwargs["linewidth"]
self._xunit = this_plot_kwargs["xaxis_unit"] # need to kick back a ref to xunit for baseline overlays
self._yunit = this_plot_kwargs["yaxis_unit"]
if self._xunit is None:
self._xunit = str(sa.unit) # noqa: F821
if "vel_frame" not in this_plot_kwargs:
if u.Unit(self._xunit).is_equivalent("km/s") and "VELDEF" in s.meta:
# If the user specified velocity units, default to
# the velframe the data were taken in. This we can
# get from VELDEF keyword. See issue #303
this_plot_kwargs["vel_frame"] = decode_veldef(s.meta["VELDEF"])[1].lower()
else:
this_plot_kwargs["vel_frame"] = s.velocity_frame
if "chan" in str(self._xunit).lower():
self._sa = u.Quantity(np.arange(len(self._sa)))
this_plot_kwargs["xlabel"] = "Channel"
else:
# convert the x axis to the requested velocity frame and Doppler convention.
self._sa = s.velocity_axis_to(
unit=self._xunit,
toframe=this_plot_kwargs["vel_frame"],
doppler_convention=this_plot_kwargs["doppler_convention"],
)
sf = s.flux
if self._yunit is not None:
sf = s.flux.to(self._yunit)
sf = Masked(sf, s.mask)
lines = self.axes.plot(
self._sa,
sf,
color=this_plot_kwargs["color"],
lw=lw,
drawstyle=this_plot_kwargs["drawstyle"],
label=this_plot_kwargs["label"],
alpha=this_plot_kwargs["alpha"],
)
self._line = lines[0]
if this_plot_kwargs["label"] is not None:
self.axes.legend()
if not this_plot_kwargs["xmin"] and not this_plot_kwargs["xmax"]:
self.axes.set_xlim(np.min(self._sa).value, np.max(self._sa).value)
else:
self.axes.set_xlim(this_plot_kwargs["xmin"], this_plot_kwargs["xmax"])
if self._freezey:
self.axes.autoscale(enable=False)
else:
self.axes.autoscale(axis="y", enable=True)
self.axes.set_ylim(this_plot_kwargs["ymin"], this_plot_kwargs["ymax"])
self.axes.tick_params(axis="both", which="both", bottom=True, top=True, left=True, right=True, direction="in")
if this_plot_kwargs["grid"]:
self.axes.grid(visible=True, which="major", axis="both", lw=lw / 2, color="k", alpha=0.33)
self.axes.grid(visible=True, which="minor", axis="both", lw=lw / 2, color="k", alpha=0.22, linestyle="--")
self._set_labels(**this_plot_kwargs)
if self._title is not None:
self.axes.set_title(self._title)
if show_header:
self.figure.subplots_adjust(top=0.79, left=0.09, right=0.95)
self._set_header(s)
if select:
self._selector = MultiSpanSelector(self.axes, minspan=abs(self._sa[0].value - self._sa[1].value))
self._spectrum._selection = self._selector.get_selected_regions()
if oshow is not None:
if isinstance(oshow, type(self._spectrum)):
oshow = [oshow]
if type(oshow) is not list:
raise TypeError(f"oshow ({oshow}) must be a list or Spectrum")
for i, sp in enumerate(oshow):
if not isinstance(sp, type(self._spectrum)):
raise TypeError(f"Element {i} of oshow ({oshow}) is not a Spectrum")
if oshow_kwargs is None:
oshow_kwargs = {}
self.oshow(oshow, **oshow_kwargs)
def _compose_xlabel(self, **kwargs):
"""Create a sensible spectral axis label given units, velframe, and doppler convention"""
xlabel = kwargs.get("xlabel", None)
if xlabel is not None:
return xlabel
if kwargs["doppler_convention"] == "radio":
subscript = "_{rad}"
elif kwargs["doppler_convention"] == "optical":
subscript = "_{opt}"
elif kwargs["doppler_convention"] == "relativistic":
subscript = "_{rel}"
else: # should never happen
subscript = ""
if kwargs.get("xaxis_unit", None) is not None:
xunit = u.Unit(kwargs["xaxis_unit"])
else:
xunit = self.spectrum.spectral_axis.unit
if xunit.is_equivalent(u.Hz):
xname = r"\nu"
elif xunit.is_equivalent(_KMS):
xname = r"V" + subscript
elif xunit.is_equivalent(u.angstrom):
xname = r"\lambda"
# Channel is handled in plot() with kwargs['xlabel']
else:
raise ValueError(f"Unrecognized spectral axis unit: {xunit}")
_xunit = xunit.to_string(format="latex_inline")
xlabel = f"{frame_to_label[kwargs['vel_frame']]} ${xname}$ ({_xunit})"
return xlabel
def _set_labels(self, **kwargs):
r"""Set x and y labels according to spectral units
Parameters
----------
title : str
Plot title.
xlabel : str
x-axis label.
ylabel : str
x-axis label.
doppler_convention : str
Doppler convention for x-axis.
xaxis_unit : str
Units for x-axis.
yaxis_unit : str
Units for y-axis.
"""
title = kwargs.get("title", None)
ylabel = kwargs.get("ylabel", None)
if title is not None:
self._title = title
if kwargs.get("yaxis_unit", None) is not None:
yunit = u.Unit(kwargs["yaxis_unit"])
else:
yunit = self.spectrum.unit
self.axes.set_xlabel(self._compose_xlabel(**kwargs))
if ylabel is not None:
self.axes.set_ylabel(ylabel)
else:
if "TSCALE" in self.spectrum.meta:
ylabel = self.spectrum.meta["TSCALE"]
elif "TUNIT7" in self.spectrum.meta:
tunit7 = self.spectrum.meta["TUNIT7"]
if tunit7 == "Ta": # what about Ta*
ylabel = "Ta"
yunit = "K"
elif tunit7 == "Ta*":
ylabel = "Ta*"
yunit = "K"
elif tunit7 == "Jy":
ylabel = "Flux"
yunit = "Jy"
else:
ylabel = "Unknown"
yunit = "()"
logger.info(f"Missing TSCALE: patching Y-axis as '{ylabel} ({yunit})'")
self.axes.set_ylabel(f"{ylabel} ({yunit})")
def _show_exclude(self, **kwargs):
"""TODO: Method to show the exclude array on the plot"""
kwargs_opts = {
"loc": "bottom", # top,bottom ?
"color": "silver",
}
kwargs_opts.update(kwargs)
[docs]
def get_selected_regions(self):
""" """
regions = self._selector.get_selected_regions()
return [tuple(np.sort([np.argmin(abs(p - self._sa.value)) for p in r])) for r in regions]
[docs]
def freex(self):
"""Free the X-axis if limits have been set. Resets the limits to be the span of the spectrum."""
self._freezex = False
mins = []
maxs = []
for line in self.axes.lines:
mins.append(line._x.min())
maxs.append(line._x.max())
self.axes.set_xlim((min(mins), max(maxs)))
[docs]
def freey(self):
"""Free the Y-axis if limits have been set. Autoscales the Y-axis according to your matplotlib configuration."""
self._freezey = False
self.axes.relim()
self.axes.autoscale(axis="y", enable=True)
self.axes.autoscale_view()
[docs]
def freexy(self):
r"""Free the X and Y axes simultaneously. See `freex` and `freey` for more details."""
self.freex()
self.freey()
[docs]
def clear_overlays(self, blines=True, oshows=True, catalog=True):
"""Clear Overlays from the plot.
Parameters
----------
blines : bool
Remove baseline models overlaid on the plot. Default: True
oshows : bool
Remove other spectra overlaid on the plot. Default: True
catalog : bool
Remove catalog spectral lines overlaid on the plot. Default: True
"""
if blines:
self._clear_overlay_objects("lines", "baseline")
if oshows:
self._clear_overlay_objects("lines", "oshow")
if catalog:
self._clear_overlay_objects("lines", "catalogline")
self._clear_overlay_objects("texts", "catalogtext")
[docs]
def clear_lines(self, gid):
self._clear_overlay_objects("lines", gid)
def _clear_overlay_objects(self, otype, gid):
"""
Clears lines with `gid` from the plot.
Parameters
----------
otype : str
Type of overlay. Can be "lines" or "texts".
gid : str
Group id for the lines to be cleared.
"""
if otype == "lines":
tgt_list = self.axes.lines
elif otype == "texts":
tgt_list = self.axes.texts
for b in tgt_list:
if b.get_gid() == gid:
b.remove()
self.figure.canvas.draw_idle()
[docs]
def oshow(self, spectra, color=None, linestyle=None, label=None, alpha=None):
"""
Add `spectra` to the current plot.
Parameters
----------
spectra : list of `dysh.spectra.spectrum.Spectrum` or `dysh.spectra.spectrum.Spectrum`
Spectra to add to the plot.
color : list of valid `matplotlib` colors or `matplotlib` color
Colors for the spectra. There must be one element per spectra.
linestyle : list of valid `matplotlib` linestyles or `matplotlib` linestyle
Linestyles for the spectra. There must be one element per spectra.
label : list of str
Labels for the spectra. There must be one element per spectra.
alpha : list of float
Alpha values for the spectra, between 0 and 1. There must be one element per spectra.
"""
# If a single Spectrum is the input, make everything a list.
if isinstance(spectra, type(self._spectrum)):
spectra = [spectra]
if color is not None:
color = [color]
if linestyle is not None:
linestyle = [linestyle]
if label is not None:
label = [label]
if alpha is not None:
alpha = [alpha]
for i, s in enumerate(spectra):
if not isinstance(s, type(self._spectrum)):
raise TypeError(f"Element {i} of spectra ({s}) is not a Spectrum.")
# Pack args together, and check that we have enough of each.
zargs = (spectra,)
if color is not None:
# Check that we have enough colors.
if len(color) != len(spectra):
raise ValueError(f"How do I color {len(spectra)} spectra with {len(color)} colors?")
zargs += (color,)
else:
zargs += ([None] * len(spectra),)
if linestyle is not None:
# Check that we have enough linestyles.
if len(linestyle) != len(spectra):
raise ValueError(f"How do I style {len(spectra)} spectra with {len(linestyle)} linestyles?")
zargs += (linestyle,)
else:
zargs += ([None] * len(spectra),)
if label is not None:
if len(label) != len(spectra):
raise ValueError(f"How do I label {len(spectra)} spectra with {len(label)} labels?")
zargs += (label,)
else:
zargs += ([None] * len(spectra),)
if alpha is not None:
if len(alpha) != len(spectra):
raise ValueError(f"How do I set alpha for {len(spectra)} spectra with {len(label)} alpha values?")
zargs += (alpha,)
else:
zargs += ([None] * len(spectra),)
for s, c, ls, l, a in zip(*zargs, strict=True):
self._oshow(s, color=c, linestyle=ls, label=l, alpha=a)
def _oshow(self, oshow_spectrum, color=None, linestyle=None, label=None, alpha=None):
this_plot_kwargs = deepcopy(self._plot_kwargs)
sf = oshow_spectrum.flux.to(self._spectrum.unit)
sa = oshow_spectrum.velocity_axis_to(
unit=self._xunit,
toframe=this_plot_kwargs["vel_frame"],
doppler_convention=this_plot_kwargs["doppler_convention"],
)
self.axes.plot(sa, sf, color=color, linestyle=linestyle, label=label, alpha=alpha, gid="oshow")
if label is not None:
self.axes.legend()
self.freexy()
self.figure.canvas.draw_idle()
[docs]
def show_catalog_lines(self, rotation=0, **kwargs):
"""
Overlay spectral lines from various catalogs on the plot, with annotations.
Parameters
----------
rotation : float, degrees
Rotate the annotation text CCW to aid in readability. Default 0.
**kwargs
All other kwargs get passed to `dysh.line.query_lines`.
"""
self.sl_tbl = self._spectrum.query_lines(**kwargs)
fsize = 9 # font size
num_vsteps = 7 # number of vertical steps of annotations
rot_factor = (rotation / 90) * 0.3 / num_vsteps # adjust ylocs to avoid rotated text running into each other
fracstep = 0.04 + rot_factor
ystart = 0.86 - (num_vsteps * fracstep)
for i, line in enumerate(self.sl_tbl):
line_name = parse_html(line["name"])
line_freq = (line["obs_frequency"] * u.MHz).to(self._xunit, equivalencies=self.spectrum.equivalencies).value
vloc = ystart + (i % num_vsteps) * fracstep
self.axes.axvline(line_freq, c="k", linewidth=1, gid="catalogline")
self.axes.annotate(
line_name,
(line_freq, vloc),
xycoords=("data", "axes fraction"),
size=fsize,
gid="catalogtext",
rotation=rotation,
)
self.figure.canvas.draw_idle()
[docs]
def annotate_vline(self, xval, text="", rotation=0):
"""
Add a single annotated vline to the plot. Can be cleared with the "catalog" gid.
Parameters
----------
xval : float
X value of the line, in the same units as the plot.
text : str
Associated text for the vline. Defaults to an empty string.
rotation : float
Rotate the text CCW degrees. Default 0.
"""
fsize = 9
self.axes.axvline(xval, c="k", linewidth=1, gid="catalogline")
self.axes.annotate(
text,
(xval, 0.7),
xycoords=("data", "axes fraction"),
size=fsize,
gid="catalogtext",
rotation=rotation,
)
self.figure.canvas.draw_idle()
[docs]
class MultiSpanSelector:
def __init__(self, ax, minspan):
self.ax = ax
self.canvas = ax.figure.canvas
self.spans = []
self.minspan = minspan
self.colors = {
"edge": (0, 0, 0, 1),
"face": (0, 0, 0, 0.3),
"edge_selected": (*mpl.colors.to_rgb("#6c3483"), 1.0),
}
# Register callbacks before creating any spans.
self.cid_press = self.canvas.mpl_connect("button_press_event", self.on_press)
self.spans.extend(self.init_selector())
self.active_span = self.spans[0]
self.selected_span = None
[docs]
def init_selector(self):
span = SpanSelector(
self.ax,
self.on_select,
direction="horizontal",
useblit=False, # blitting causes spans to blink on press.
interactive=True,
drag_from_anywhere=True,
ignore_event_outside=True,
props=dict(facecolor=self.colors["face"], alpha=0.3),
minspan=self.minspan,
)
return [span]
[docs]
def on_select(self, vmin, vmax):
span = vmax - vmin
if span > self.minspan and np.all(np.diff(self.get_selected_regions(False)) > self.minspan):
if self.active_span is not None:
self.active_span.set_active(False)
self.active_span = None
self.spans.extend(self.init_selector())
self.active_span = self.spans[-1]
elif self.active_span is not None:
self.active_span.set_active(False)
self.active_span = None
return
[docs]
def on_press(self, event):
# Do nothing if outside the axes.
if event.inaxes != self.ax:
return
# Do nothing if another widget is enabled.
if self.ax.get_navigate_mode() is not None:
return
# Determine if the event is in a span.
if len(self.spans) > 1:
got_one = False
for span in self.spans:
# Select only a single span at a time.
if span._contains(event) and not got_one and np.diff(span.extents) > self.minspan:
got_one = True
self.active_span = span
self.active_span.set_active(True)
self.active_span.set_props(**{"linewidth": 20, "edgecolor": self.colors["edge_selected"]})
self.selected_span = span
else:
span.set_active(False)
props = {"linewidth": 1, "edgecolor": self.colors["edge"]}
span.set_props(**props)
if not got_one:
self.active_span = None
self.selected_span = None
for span in self.spans:
# Determine if there's a span that needs to be completed
# and activate it.
if np.diff(span.extents) <= self.minspan:
self.active_span = span
self.active_span.set_active(True)
[docs]
def disconnect(self):
"""Disconnect all event handlers to prevent memory leaks and dangling references."""
if hasattr(self, "cid_press") and self.cid_press is not None:
self.canvas.mpl_disconnect(self.cid_press)
self.cid_press = None
[docs]
def clear_region(self, event=None):
if not self.selected_span:
return
self.selected_span.clear()
self.selected_span.disconnect_events()
self.spans.remove(self.selected_span)
del self.selected_span
self.selected_span = None
self.active_span = None
# If there's only one more span
# activate it.
if len(self.spans) == 1:
self.spans[0].set_active(True)
self.active_span = self.spans[0]
[docs]
def clear_regions(self, event=None):
for span in self.spans:
span.clear()
span.disconnect_events()
del span
self.spans.clear()
self.selected_span = None
self.active_span = None
# Add a new span.
self.spans.extend(self.init_selector())
self.active_span = self.spans[-1]
[docs]
def get_selected_regions(self, ignore_incomplete=True):
"""
Parameters
----------
ignore_complete : bool
If True ignore spans that are smaller than `self.minspan`.
Returns
-------
regions : list of tuples
List with edges of the spans as tuples.
"""
if ignore_incomplete:
regions = [span.extents for span in self.spans if np.diff(span.extents) > self.minspan]
else:
regions = [span.extents for span in self.spans]
return regions