Source code for jdaviz.core.marks

import numpy as np

from astropy import units as u
from bqplot import LinearScale
from bqplot.marks import Lines, Label, Scatter
from glue.core import HubListener
from specutils import Spectrum1D

from jdaviz.core.events import GlobalDisplayUnitChanged
from jdaviz.core.events import (SliceToolStateMessage, LineIdentifyMessage,
                                SpectralMarksChangedMessage,
                                RedshiftMessage)

# Public names of this module, i.e. what ``from jdaviz.core.marks import *`` exposes.
__all__ = ['OffscreenLinesMarks', 'BaseSpectrumVerticalLine', 'SpectralLine',
           'SliceIndicatorMarks', 'ShadowMixin', 'ShadowLine', 'ShadowLabelFixedY',
           'PluginMark', 'LinesAutoUnit', 'PluginLine', 'PluginScatter',
           'LineAnalysisContinuum', 'LineAnalysisContinuumCenter',
           'LineAnalysisContinuumLeft', 'LineAnalysisContinuumRight',
           'LineUncertainties', 'ScatterMask', 'SelectedSpaxel', 'MarkersMark', 'FootprintOverlay',
           'ApertureMark']

# Default entry of the ``colors`` trait for PluginLine, PluginScatter and HistogramMark
# in this module (used when the caller does not supply ``colors``).
accent_color = "#c75d2c"


class OffscreenLinesMarks(HubListener):
    """Pair of labels counting the spectral lines that fall outside the viewer's x-limits.

    ``left`` and ``right`` are bqplot ``Label`` marks placed in axes units (0-1) near the
    left and right edges of the figure.  Each shows an arrow and the number of
    `SpectralLine` marks located beyond the corresponding limit, or an empty string
    when there are none.

    Parameters
    ----------
    viewer
        Viewer whose ``state.x_min``/``state.x_max`` and ``figure.marks`` are inspected.
        The counts are refreshed whenever either limit changes or a ``RedshiftMessage``
        or ``SpectralMarksChangedMessage`` is received on the viewer's session hub.
    """

    def __init__(self, viewer):
        self.viewer = viewer
        viewer.state.add_callback("x_min", lambda x_min: self._update_counts())
        viewer.state.add_callback("x_max", lambda x_max: self._update_counts())

        viewer.session.hub.subscribe(self, RedshiftMessage,
                                     handler=self._update_counts)
        viewer.session.hub.subscribe(self, SpectralMarksChangedMessage,
                                     handler=self._update_counts)

        self.left = Label(text=[''], x=[0.02], y=[0.8],
                          scales={'x': LinearScale(min=0, max=1),
                                  'y': LinearScale(min=0, max=1)},
                          colors=['gray'], default_size=12,
                          align='start')
        self.right = Label(text=[''], x=[0.98], y=[0.8],
                           scales={'x': LinearScale(min=0, max=1),
                                   'y': LinearScale(min=0, max=1)},
                           colors=['gray'], default_size=12,
                           align='end')

        self._update_counts()

    @property
    def marks(self):
        """The two label marks, ``[left, right]``."""
        return [self.left, self.right]

    def _update_counts(self, *args):
        """Recount the out-of-bounds spectral lines and refresh both labels.

        Accepts (and ignores) positional arguments so that it can be used directly as a
        hub message handler.
        """
        x_min = self.viewer.state.x_min
        x_max = self.viewer.state.x_max

        oob_left, oob_right = 0, 0
        for m in self.viewer.figure.marks:
            if not isinstance(m, SpectralLine):
                continue
            position = m.x[0]
            # a limit that has not been set yet (None) cannot be exceeded; comparing
            # against it would raise a TypeError
            if x_min is not None and position < x_min:
                oob_left += 1
            elif x_max is not None and position > x_max:
                oob_right += 1

        self.left.text = [f'\u25c0 {oob_left}' if oob_left > 0 else '']
        self.right.text = [f'{oob_right} \u25b6' if oob_right > 0 else '']
class PluginMark:
    """Mixin that tracks the units of a mark's ``x``/``y`` arrays.

    The concrete class is expected to assign ``self.viewer`` before this ``__init__``
    runs, since the hub is looked up through the viewer.  On construction the mark
    subscribes to ``GlobalDisplayUnitChanged`` and adopts the viewer's current display
    units (when the viewer state exposes them).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.xunit = None
        self.yunit = None
        # whether to update existing marks when global display units are changed
        self.auto_update_units = True
        self.hub.subscribe(self, GlobalDisplayUnitChanged,
                           handler=self._on_global_display_unit_changed)

        if self.xunit is None:
            self.set_x_unit()
        if self.yunit is None:
            self.set_y_unit()

    @property
    def hub(self):
        """Hub of the viewer this mark belongs to."""
        return self.viewer.hub

    def update_xy(self, x, y):
        """Replace both coordinate arrays (inputs are passed through ``np.asarray``)."""
        self.x = np.asarray(x)
        self.y = np.asarray(y)

    def append_xy(self, x, y):
        """Append ``x`` and ``y`` to the end of the existing coordinate arrays."""
        self.x = np.append(self.x, x)
        self.y = np.append(self.y, y)

    def set_x_unit(self, unit=None):
        """Set the unit of ``self.x``, converting existing values when possible.

        With ``unit=None`` the viewer state's ``x_display_unit`` is used; if the state
        has no such attribute nothing happens.  Existing values are converted (with
        spectral equivalencies) only when a previous unit is known and the array has
        at least one non-zero dimension.
        """
        if unit is None:
            state = self.viewer.state
            if not hasattr(state, 'x_display_unit'):
                return
            unit = state.x_display_unit
        unit = u.Unit(unit)

        if self.xunit is not None:
            if not np.all([dim == 0 for dim in self.x.shape]):
                converted = (self.x * self.xunit).to_value(unit, u.spectral())
                # record the unit before assigning the array, as in the original order
                self.xunit = unit
                self.x = converted
        self.xunit = unit

    def set_y_unit(self, unit=None):
        """Set the unit of ``self.y``, converting existing values when possible.

        With ``unit=None`` the viewer state's ``y_display_unit`` is used; if the state
        has no such attribute nothing happens.  When the viewer's ``default_class`` is
        ``Spectrum1D`` the conversion uses spectral-density equivalencies built from
        the spectral axis of the viewer's reference data.
        """
        if unit is None:
            state = self.viewer.state
            if not hasattr(state, 'y_display_unit'):
                return
            unit = state.y_display_unit
        unit = u.Unit(unit)

        if self.yunit is not None:
            if not np.all([dim == 0 for dim in self.y.shape]):
                conversion_kwargs = {}
                if self.viewer.default_class is Spectrum1D:
                    spec = self.viewer.state.reference_data.get_object(cls=Spectrum1D)
                    conversion_kwargs['equivalencies'] = u.spectral_density(spec.spectral_axis)
                converted = (self.y * self.yunit).to_value(unit, **conversion_kwargs)
                self.yunit = unit
                self.y = converted
        self.yunit = unit

    def _on_global_display_unit_changed(self, msg):
        """Apply a display-unit change to this mark, if it opted in and the axis applies."""
        if not self.auto_update_units:
            return

        viewer_cls_name = self.viewer.__class__.__name__
        if viewer_cls_name in ['SpecvizProfileView', 'CubevizProfileView']:
            axis_map = {'spectral': 'x', 'flux': 'y'}
        elif viewer_cls_name == 'MosvizProfile2DView':
            axis_map = {'spectral': 'x'}
        else:
            return

        axis = axis_map.get(msg.axis, None)
        if axis is None:
            return
        setter = getattr(self, f'set_{axis}_unit')
        setter(msg.unit)

    def clear(self):
        """Remove all points from the mark."""
        self.update_xy([], [])
class BaseSpectrumVerticalLine(Lines, PluginMark, HubListener):
    """Vertical line at a single x position, spanning the full height of the axes.

    The x coordinate uses the viewer's x scale, while y runs from 0 to 1 on its own
    linear scale.
    """

    def __init__(self, viewer, x, **kwargs):
        self.viewer = viewer

        # the location of the marker will need to update automatically if the
        # underlying data changes (through a unit conversion, for example)
        if hasattr(viewer.state, 'reference_data'):
            viewer.state.add_callback("reference_data",
                                      self._update_reference_data)

        line_scales = {'x': viewer.scales['x'],
                       'y': LinearScale(min=0, max=1)}

        # Lines.__init__ will set self.x
        super().__init__(x=[x, x], y=[0, 1], scales=line_scales, **kwargs)

    def _update_reference_data(self, reference_data):
        """Follow the spectral-axis unit of newly assigned reference data."""
        if reference_data is None:
            return
        spectrum = reference_data.get_object(cls=Spectrum1D)
        self._update_unit(spectrum.spectral_axis.unit)

    def _update_unit(self, new_unit):
        """Convert the stored x position from ``self.xunit`` to ``new_unit``."""
        # the x-units may have changed.  We want to convert the internal self.x
        # from self.xunit to the new units (x_all.unit)
        if self.xunit is None:
            # nothing to convert from yet: just remember the unit
            self.xunit = new_unit
            return
        if new_unit == self.xunit:
            return

        position = self.x[0] * self.xunit
        converted = position.to_value(new_unit, equivalencies=u.spectral())
        self.x = [converted, converted]
        self.xunit = new_unit
class SpectralLine(BaseSpectrumVerticalLine):
    """
    Subclass on bqplot Lines, mostly so that we can erase spectral lines
    by eliminating any SpectralLines objects from a figures marks list. Also
    lets us do wavelength redshifting here on mark creation.

    Parameters
    ----------
    viewer
        Viewer the line is drawn in; its ``state.x_display_unit`` defines the unit in
        which ``rest_value`` is interpreted.
    rest_value
        Rest-frame position of the line, in the viewer's x display unit.
    redshift
        Redshift applied to ``rest_value`` to obtain the plotted (observed) position.
    name
        Name of the line.
    **kwargs
        ``table_index`` is popped and stored; everything else is forwarded to the
        base class.
    """

    def __init__(self, viewer, rest_value, redshift=0, name=None, **kwargs):
        self._rest_value = rest_value
        self._identify = False
        self.name = name
        # table_index is same as name_rest elsewhere
        self.table_index = kwargs.pop("table_index", None)

        # setting redshift will set self.x and enable the obs_value property,
        # but to do that we need x_unit set first (would normally be assigned
        # in the super init)
        self.xunit = u.Unit(viewer.state.x_display_unit)
        self.redshift = redshift

        viewer.session.hub.subscribe(self, LineIdentifyMessage,
                                     handler=self._process_identify_change)

        super().__init__(viewer=viewer, x=self.obs_value, stroke_width=1,
                         fill='none', close_path=False, **kwargs)

    @property
    def name_rest(self):
        """Alias of ``table_index``."""
        return self.table_index

    @property
    def rest_value(self):
        """Rest-frame position of the line, in units of ``self.xunit``."""
        return self._rest_value

    @property
    def obs_value(self):
        """Plotted (redshifted) position of the line, in units of ``self.xunit``."""
        return self.x[0]

    def set_x_unit(self, unit=None):
        """Change the x unit, converting both the plotted and the rest-frame value.

        With ``unit=None`` the base class resolves the unit from the viewer's display
        unit, so the rest value is converted to whatever unit the base class ended up
        with (``self.xunit``) rather than to the raw ``unit`` argument: converting to
        ``None`` would silently leave the rest value in the old unit while ``self.x``
        had already been converted.
        """
        prev_unit = self.xunit
        super().set_x_unit(unit=unit)
        if prev_unit is None or self.xunit is None:
            # no previous (or no resolved) unit: the rest value cannot be converted
            return
        self._rest_value = (self._rest_value * prev_unit).to_value(self.xunit, u.spectral())

    @property
    def redshift(self):
        """Redshift currently applied to the rest value."""
        return self._redshift

    @redshift.setter
    def redshift(self, redshift):
        self._redshift = redshift
        if str(self.xunit.physical_type) == 'length':
            obs_value = self._rest_value*(1+redshift)
        elif str(self.xunit.physical_type) == 'frequency':
            obs_value = self._rest_value/(1+redshift)
        else:
            # catch all for anything else (wavenumber, energy, etc)
            rest_angstrom = (self._rest_value*self.xunit).to_value(u.Angstrom,
                                                                   equivalencies=u.spectral())
            obs_angstrom = rest_angstrom*(1+redshift)
            obs_value = (obs_angstrom*u.Angstrom).to_value(self.xunit,
                                                           equivalencies=u.spectral())
        self.x = [obs_value, obs_value]

    @property
    def identify(self):
        """Whether this line is currently highlighted (drawn with a thicker stroke)."""
        return self._identify

    @identify.setter
    def identify(self, identify):
        if not isinstance(identify, bool):  # pragma: no cover
            raise TypeError("identify must be of type bool")

        self._identify = identify
        self.stroke_width = 3 if identify else 1

    def _process_identify_change(self, msg):
        """Highlight this line only if the message refers to its ``table_index``."""
        self.identify = msg.name_rest == self.table_index

    def _update_unit(self, new_unit):
        """Convert the rest value to ``new_unit`` and recompute the plotted position."""
        if self.xunit is None:
            self.xunit = new_unit
            return
        if new_unit == self.xunit:
            return

        old_quant = self._rest_value*self.xunit
        self._rest_value = old_quant.to_value(new_unit, equivalencies=u.spectral())
        # The redshift setter chooses its formula from the physical type of self.xunit,
        # so the new unit has to be in place before it runs; otherwise a rest value
        # already expressed in the new unit is redshifted with the rule for the old one.
        self.xunit = new_unit
        # re-compute self.x from current redshift (instead of converting that as well)
        self.redshift = self._redshift
class SliceIndicatorMarks(BaseSpectrumVerticalLine, HubListener):
    """Subclass on bqplot Lines to handle slice/wavelength indicator.

    The indicator is a vertical line at ``value``.  When ``value`` lies outside the
    (slightly padded) x-limits of the viewer, the line is instead pinned to the
    corresponding edge in axes units and drawn dashed.  A companion
    `ShadowLabelFixedY` (``self.label``) displays the value and unit.

    Parameters
    ----------
    viewer
        Viewer the indicator is drawn in.
    value
        Initial position of the indicator, in the x units of the viewer.
    **kwargs
        Forwarded to the base class.
    """

    def __init__(self, viewer, value=0, **kwargs):
        self._viewer = viewer
        self._value = None
        self._oob = False  # out-of-bounds, either False, 'left', or 'right'
        self._active = False

        # TODO: new viewers need to respect plugin settings
        self._show_if_inactive = True
        self._show_value = True

        viewer.state.add_callback("x_min", lambda x_min: self._value_handle_oob(update_label=True))
        viewer.state.add_callback("x_max", lambda x_max: self._value_handle_oob(update_label=True))
        viewer.session.hub.subscribe(self, SliceToolStateMessage,
                                     handler=self._on_change_state)

        # the base class duplicates the scalar into [x, x] itself; passing a list here
        # would create a nested (2-D) x array until ``self.value`` is assigned below
        super().__init__(viewer=viewer, x=value, stroke_width=2,
                         marker='diamond', fill='none', close_path=False,
                         labels=['slice'], labels_visibility='none', **kwargs)

        self.value = value

        # instead of using the Lines label which is limited, we'll use a Label object which
        # will follow the x-coordinate of the slice indicator line, with a fixed y-value
        # (in axes-units) and will flip its alignment depending on whether the line is on the
        # left or right side of the axes.
        self.label = ShadowLabelFixedY(viewer, self, shadow_traits=[], default_size=12, y=0.95)

        # default to the initial state of the tool since we can't control if this will
        # happen before or after the initialization of the tool
        tool_active = self.viewer.toolbar.active_tool_id == 'jdaviz:selectslice'
        self._on_change_state({'active': tool_active})

    @property
    def marks(self):
        """The indicator line and its label, ``[self, self.label]``."""
        return [self, self.label]

    def _on_global_display_unit_changed(self, msg):
        # Updating the value is handled by the plugin itself, need to update unit string.
        if msg.axis in ["spectral", "x"]:
            self.xunit = msg.unit
            self._update_label()

    def _value_handle_oob(self, x=None, update_label=False):
        """Place the line at ``x``, or pin it to an edge if ``x`` is out of bounds.

        Parameters
        ----------
        x
            New value of the indicator.  If ``None``, the stored value is re-evaluated
            against the current viewer limits.
        update_label : bool
            Whether to refresh the label text afterwards.  Not honored when the viewer
            limits are not set yet, in which case the method returns early.
        """
        if x is None:
            x = self.value
        else:
            self._value = x

        x_min, x_max = self._viewer.state.x_min, self._viewer.state.x_max
        if x_min is None or x_max is None:
            self.x = [x, x]
            return

        x_range = x_max - x_min
        padding_fig = 0.01
        padding = padding_fig * x_range
        x_min += padding
        x_max -= padding

        # ensure y-scale has been set (we'll only be overriding x, but scatter viewers complain
        # if y-scale is not set)
        self.scales.setdefault('y', LinearScale(min=0, max=1))

        if x < x_min:
            self.x = [padding_fig, padding_fig]
            self.scales = {**self.scales, 'x': LinearScale(min=0, max=1)}
            self.line_style = 'dashed'
            self._oob = 'left'
        elif x > x_max:
            self.x = [1-padding_fig, 1-padding_fig]
            self.scales = {**self.scales, 'x': LinearScale(min=0, max=1)}
            self.line_style = 'dashed'
            self._oob = 'right'
        else:
            self.x = [x, x]
            self.scales = {**self.scales, 'x': self._viewer.scales['x']}
            self.line_style = 'solid'
            self._oob = False

        if update_label:
            self._update_label()

    def _update_colors_opacities(self):
        # orange (accent) if active, import button blue otherwise (see css in main_styles.vue)
        if not self._show_if_inactive and not self._active:
            self.label.visible = False
            self.visible = False
            return

        self.visible = True
        self.label.visible = self._show_value
        self.colors = ["#c75109" if self._active else "#007BA1"]
        self.opacities = [1.0 if self._active else 0.9]

    def _on_change_state(self, msg=None):
        """Apply tool-state changes (``active``, ``show_indicator``, ``show_value``).

        ``msg`` is either a plain dict of changes or a message object carrying them in
        ``msg.change``; messages addressed to a different viewer are ignored.
        """
        if msg is None:
            msg = {}

        if isinstance(msg, dict):
            changes = msg
        else:
            if msg.viewer is not None and msg.viewer != self.viewer:
                return
            changes = msg.change

        for k, v in changes.items():
            if k == 'active':
                self._active = v
            elif k == 'show_indicator':
                self._show_if_inactive = v
            elif k == 'show_value':
                self._show_value = v

        self._update_colors_opacities()

    def _update_label(self):
        """Rebuild the label text from the current value, unit and out-of-bounds state."""
        def _formatted_value(value):
            # log10 is undefined for zero and for negative numbers, so decide the
            # notation from the magnitude and treat zero as having no exponent
            magnitude = abs(value)
            power = abs(np.log10(magnitude)) if magnitude > 0 else 0
            if power >= 3:
                # use scientific notation
                return f'{value:0.4e}'
            return f'{value:0.4f}'

        valuestr = _formatted_value(self.value)
        xunit = str(self.xunit) if self.xunit is not None else ''
        # U+00A0 is a blank space, U+25C0 a left arrow triangle, and U+25B6 a right arrow triangle
        if self._oob == 'left':
            self.labels = [f'\u00A0 \u25c0 {valuestr} {xunit} \u00A0']  # noqa
        elif self._oob == 'right':
            self.labels = [f'{valuestr} {xunit} \u25b6 \u00A0']
        else:
            self.labels = [f'\u00A0 {valuestr} {xunit} \u00A0']

    @property
    def value(self):
        """Position of the indicator (independent of where the line is drawn)."""
        return self._value

    @value.setter
    def value(self, value):
        self._value_handle_oob(value, update_label=True)
class ShadowMixin:
    """Mixin class to propagate traits from one mark object to another.

    Any trait listed in ``sync_traits`` will be mirrored directly from ``shadowing`` to
    the shadowed object.  Can manually override ``_on_shadowing_changed`` for more
    advanced logic cases.
    """

    def _get_id(self, mark):
        """Return the ``_model_id`` of ``mark``, or ``None`` if it has none."""
        return getattr(mark, '_model_id', None)

    def _setup_shadowing(self, shadowing, sync_traits=None, other_traits=None):
        """Start shadowing ``shadowing``.

        Parameters
        ----------
        shadowing
            Object to shadow.  Must provide ``observe(callback)`` and every attribute
            named in ``sync_traits`` and ``other_traits``.
        sync_traits : iterable of str, optional
            Traits to set now, and to mirror whenever ``shadowing`` changes them.
        other_traits : iterable of str, optional
            Traits forwarded once, now, to ``_on_shadowing_changed``.  The handler of
            this mixin only applies traits listed in ``sync_traits``, so these take
            effect only in subclasses that override ``_on_shadowing_changed``.
        """
        # Copy into fresh lists: nothing stored on the instance is shared with other
        # instances or with the caller, and tuples are accepted as well as lists.
        sync_traits = [] if sync_traits is None else list(sync_traits)
        other_traits = [] if other_traits is None else list(other_traits)

        if not hasattr(self, '_shadowing'):
            self._shadowing = {}
            self._sync_traits = {}

        shadowing_id = self._get_id(shadowing)
        self._shadowing[shadowing_id] = shadowing
        self._sync_traits[shadowing_id] = sync_traits

        # sync initial values
        for attr in sync_traits + other_traits:
            self._on_shadowing_changed({'name': attr,
                                        'new': getattr(shadowing, attr),
                                        'owner': shadowing})

        # subscribe to future changes
        shadowing.observe(self._on_shadowing_changed)

    def _on_shadowing_changed(self, change):
        """Copy the changed value onto ``self`` if it is a synced trait of its owner."""
        owner_id = self._get_id(change.get('owner'))
        if change['name'] in self._sync_traits.get(owner_id, []):
            setattr(self, change['name'], change['new'])
class ShadowLine(Lines, HubListener, ShadowMixin):
    """Create a white shadow line around another line to help make it standout on top
    of other lines.

    The shadow shares the scales of ``shadowing`` and is drawn ``shadow_width`` wider
    (stroke and marker) than it; a stroke width or marker size that is falsy on
    ``shadowing`` stays 0.  The color defaults to white and can be overridden through
    the ``color`` keyword argument.
    """

    def __init__(self, shadowing, shadow_width=1, **kwargs):
        self._shadow_width = shadow_width

        shadow_color = kwargs.pop('color', 'white')

        if shadowing.stroke_width:
            stroke_width = shadowing.stroke_width + shadow_width
        else:
            stroke_width = 0

        if shadowing.marker_size:
            marker_size = shadowing.marker_size + shadow_width
        else:
            marker_size = 0

        super().__init__(scales=shadowing.scales,
                         stroke_width=stroke_width,
                         marker_size=marker_size,
                         colors=[shadow_color],
                         **kwargs)

        mirrored = ['scales', 'x', 'y', 'visible', 'line_style', 'marker']
        initial_only = ['stroke_width', 'marker_size']
        self._setup_shadowing(shadowing, mirrored, initial_only)
class ShadowLabelFixedY(Label, ShadowMixin):
    """Label whose position shadows that of a parent ``shadowing`` line and will flip
    alignment based on whether it is left or right of the center of the viewer.

    Parameters
    ----------
    viewer
        Viewer whose x-limits decide the alignment when the label is in data units.
    shadowing
        Mark to follow.  Its ``x``, ``scales``, ``labels`` and ``colors`` are always
        tracked; ``labels`` provides the text of this label.
    shadow_traits : list of str, optional
        Additional traits mirrored verbatim from ``shadowing``.  Defaults to
        ``['visible']``.
    y : float
        Fixed vertical position, in axes units (0-1).
    point_index : int
        Index into the ``x``/``labels``/``colors`` of ``shadowing`` that this label uses.
    **kwargs
        Forwarded to ``Label``.
    """

    def __init__(self, viewer, shadowing, shadow_traits=None, y=0.95, point_index=0, **kwargs):
        # build a fresh list per instance: the mixin keeps a reference to the list it
        # is given, so a shared default would be shared between all labels
        shadow_traits = ['visible'] if shadow_traits is None else list(shadow_traits)

        super().__init__(**kwargs)
        self._viewer = viewer

        self.y = [y]
        self.scales['y'] = LinearScale(min=0, max=1)

        self._point_index = point_index
        self._setup_shadowing(shadowing, shadow_traits, ['x', 'scales', 'labels', 'colors'])
        viewer.state.add_callback("x_min", lambda x_min: self._update_align())
        viewer.state.add_callback("x_max", lambda x_max: self._update_align())

    def _force_redraw(self):
        # TODO: bug in bqplot that change in align/colors traitlet doesn't update immediately,
        # we'll get around it in the meantime by just forcing the Label to see a change to the
        # text traitlet
        text = self.text
        self.text = ['']
        self.text = text

    def _update_align(self):
        """Align the label towards the center of the viewer, based on its x position."""
        if not isinstance(self.scales.get('x'), LinearScale):
            return

        # determine alignment automatically
        if self.scales['x'].min == 0 and self.scales['x'].max == 1:
            # then we're in axes units, so just check position compared to 0.5
            is_to_right = self.x[0] > 0.5
        else:
            # then we're in data units, so check position compared to the median of the
            # axes limits
            x_min, x_max = self._viewer.state.x_min, self._viewer.state.x_max
            if x_min is None or x_max is None:
                # limits not set yet: there is no center to compare against, so keep
                # the current alignment (the limit callbacks will call this again)
                return
            is_to_right = self.x[0] > (x_min + x_max) / 2.

        if is_to_right and self.align != 'end':
            self.align = 'end'
            # force redraw by re-updating label
            self._force_redraw()
        if not is_to_right and self.align != 'start':
            self.align = 'start'
            # force redraw by re-updating label
            self._force_redraw()

    def _on_shadowing_changed(self, change):
        """Mirror synced traits, then translate position/text/color changes."""
        super()._on_shadowing_changed(change)
        if change['name'] == 'labels':
            self.text = [change['new'][self._point_index]]
        elif change['name'] in ('x', 'colors'):
            setattr(self, change['name'], [change['new'][self._point_index]])
            if change['name'] == 'colors':
                # bqplot bug that won't notice change to colors, manually force re-draw
                self._force_redraw()
        elif change['name'] == 'scales':
            self.scales = {**self.scales, 'x': change['new']['x']}

        if change['name'] in ('x', 'scales'):
            # then the position of the label on the plot has changed, so re-determine whether
            # it should be aligned to the left or right
            self._update_align()
class LinesAutoUnit(PluginMark, Lines, HubListener):
    """bqplot ``Lines`` mark with the unit handling of `PluginMark`.

    Only the viewer is handled here; all other arguments are forwarded unchanged.
    """

    def __init__(self, viewer, *args, **kwargs):
        # PluginMark.__init__ looks up the hub through self.viewer, so the viewer has
        # to be stored before the initializer chain runs
        self.viewer = viewer
        super().__init__(*args, **kwargs)
class PluginLine(Lines, PluginMark, HubListener):
    """Line mark owned by a plugin, with the unit handling of `PluginMark`.

    Parameters
    ----------
    viewer
        Viewer the mark is drawn in.  Its ``scales`` are used unless ``scales`` is
        given as a keyword argument.
    x, y : array-like, optional
        Coordinates of the line.  Default to empty.
    **kwargs
        Forwarded to ``Lines``.  ``colors`` defaults to ``[accent_color]``.
    """

    def __init__(self, viewer, x=None, y=None, **kwargs):
        self.viewer = viewer
        # None stands in for "empty" so that no list object is shared between calls
        x = [] if x is None else x
        y = [] if y is None else y
        # default color is the module-wide accent color
        kwargs.setdefault('colors', [accent_color])
        super().__init__(x=x, y=y, scales=kwargs.pop('scales', viewer.scales), **kwargs)
class PluginScatter(Scatter, PluginMark, HubListener):
    """Scatter mark owned by a plugin, with the unit handling of `PluginMark`.

    Parameters
    ----------
    viewer
        Viewer the mark is drawn in.  Its ``scales`` are used unless ``scales`` is
        given as a keyword argument.
    x, y : array-like, optional
        Coordinates of the points.  Default to empty.
    **kwargs
        Forwarded to ``Scatter``.  ``colors`` defaults to ``[accent_color]``.
    """

    def __init__(self, viewer, x=None, y=None, **kwargs):
        self.viewer = viewer
        # None stands in for "empty" so that no list object is shared between calls
        x = [] if x is None else x
        y = [] if y is None else y
        # default color is the module-wide accent color
        kwargs.setdefault('colors', [accent_color])
        super().__init__(x=x, y=y, scales=kwargs.pop('scales', viewer.scales), **kwargs)
class LineAnalysisContinuum(PluginLine):
    """`PluginLine` that opts out of automatic display-unit conversion."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # The plugin itself reruns the computation and replaces the arrays when units
        # change, so the mark must not convert them a second time.
        self.auto_update_units = False
class LineAnalysisContinuumCenter(LineAnalysisContinuum):
    """Continuum mark drawn with a stroke width of 1."""

    def __init__(self, viewer, x=[], y=[], **kwargs):
        super().__init__(viewer, x, y, **kwargs)
        # assigned after construction, so it takes precedence over any stroke_width
        # passed through kwargs
        self.stroke_width = 1
class LineAnalysisContinuumLeft(LineAnalysisContinuum):
    """Continuum mark drawn with a stroke width of 5."""

    def __init__(self, viewer, x=[], y=[], **kwargs):
        super().__init__(viewer, x, y, **kwargs)
        # assigned after construction, so it takes precedence over any stroke_width
        # passed through kwargs
        self.stroke_width = 5
class LineAnalysisContinuumRight(LineAnalysisContinuumLeft):
    """Same as `LineAnalysisContinuumLeft`; a separate type with no added behavior."""
class LineUncertainties(LinesAutoUnit):
    """`LinesAutoUnit` under a distinct type; adds no behavior of its own."""

    def __init__(self, viewer, *args, **kwargs):
        # plain pass-through to LinesAutoUnit
        super().__init__(viewer, *args, **kwargs)
class ScatterMask(Scatter):
    """bqplot ``Scatter`` under a distinct type; accepts keyword arguments only."""

    def __init__(self, **kwargs):
        # plain pass-through to Scatter
        super().__init__(**kwargs)
class SelectedSpaxel(Lines):
    """bqplot ``Lines`` under a distinct type; accepts keyword arguments only."""

    def __init__(self, **kwargs):
        # plain pass-through to Lines
        super().__init__(**kwargs)
class MarkersMark(PluginScatter):
    """`PluginScatter` whose ``marker`` defaults to ``'circle'``."""

    def __init__(self, viewer, **kwargs):
        # only fill in the marker when the caller did not choose one
        if 'marker' not in kwargs:
            kwargs['marker'] = 'circle'
        super().__init__(viewer, **kwargs)
class FootprintOverlay(PluginLine):
    """Closed, semi-transparent filled outline associated with an ``overlay`` object.

    The styling below is applied only for keyword arguments the caller did not pass.
    """

    def __init__(self, viewer, overlay, **kwargs):
        self._overlay = overlay

        style_defaults = (
            ('stroke_width', 2),
            ('close_path', True),
            ('opacities', [0.8]),
            ('fill', 'inside'),
            ('fill_opacities', [0.2]),
        )
        for option, default in style_defaults:
            kwargs.setdefault(option, default)

        super().__init__(viewer, **kwargs)

    @property
    def overlay(self):
        """The ``overlay`` object given at construction (read-only)."""
        return self._overlay
class ApertureMark(PluginLine):
    """`PluginLine` that remembers the identifier it was created with.

    The identifier is stored on ``self._id``; everything else is forwarded to
    `PluginLine`.
    """

    def __init__(self, viewer, id, **kwargs):
        # stored before the initializer chain runs, as in the other marks of this module
        self._id = id
        super().__init__(viewer, **kwargs)
class HistogramMark(Lines):
    """Solid line in the accent color, spanning y = 0 to 1.

    Parameters
    ----------
    min_max_value
        Passed as the ``x`` coordinates of the line.
    scales
        Scales passed to ``Lines``.
    **kwargs
        Forwarded to ``Lines``.
    """

    def __init__(self, min_max_value, scales, **kwargs):
        # Vertical line in LinearScale
        super().__init__(x=min_max_value,
                         y=[0, 1],
                         scales=scales,
                         colors=[accent_color],
                         line_style="solid",
                         **kwargs)