diff --git a/jdaviz/configs/default/plugins/markers/markers.py b/jdaviz/configs/default/plugins/markers/markers.py index ed06f00910..754eee5d8c 100644 --- a/jdaviz/configs/default/plugins/markers/markers.py +++ b/jdaviz/configs/default/plugins/markers/markers.py @@ -139,7 +139,6 @@ def _on_viewer_key_event(self, viewer, data): raise ValueError(f'failed to add {row_info} to table: {repr(err)}') x, y = row_info['axes_x'], row_info['axes_y'] - # TODO: will need to test/update when adding support for display units self._get_mark(viewer).append_xy(getattr(x, 'value', x), getattr(y, 'value', y)) def clear_table(self): diff --git a/jdaviz/configs/default/plugins/viewers.py b/jdaviz/configs/default/plugins/viewers.py index 7023f19671..70f33aa657 100644 --- a/jdaviz/configs/default/plugins/viewers.py +++ b/jdaviz/configs/default/plugins/viewers.py @@ -216,6 +216,10 @@ def jdaviz_helper(self): """The Jdaviz configuration helper tied to the viewer.""" return self.jdaviz_app._jdaviz_helper + @property + def hub(self): + return self.session.hub + @property def reference_id(self): return self._reference_id diff --git a/jdaviz/core/marks.py b/jdaviz/core/marks.py index 7c2cc5a223..08ffba76b9 100644 --- a/jdaviz/core/marks.py +++ b/jdaviz/core/marks.py @@ -7,6 +7,7 @@ from glue.core import HubListener from specutils import Spectrum1D +from jdaviz.core.events import GlobalDisplayUnitChanged from jdaviz.core.events import (SliceToolStateMessage, LineIdentifyMessage, SpectralMarksChangedMessage, RedshiftMessage) @@ -485,6 +486,20 @@ def _on_shadowing_changed(self, change): class PluginMark(): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.xunit = None + self.yunit = None + # whether to update existing marks when global display units are changed + self.auto_update_units = True + self.hub.subscribe(self, GlobalDisplayUnitChanged, + handler=self._on_global_display_unit_changed) + self._update_units() + + @property + def hub(self): + return self.viewer.hub + def update_xy(self, x, y): self.x = np.asarray(x) self.y = np.asarray(y) @@ -493,24 +508,79 @@ def append_xy(self, x, y): self.x = np.append(self.x, x) self.y = np.append(self.y, y) + def _update_units(self): + if not self.auto_update_units: + return + if self.xunit is None: + self.set_x_unit() + if self.yunit is None: + self.set_y_unit() + + def set_x_unit(self, unit=None): + if unit is None: + if not hasattr(self.viewer.state, 'x_display_unit'): + return + unit = self.viewer.state.x_display_unit + unit = u.Unit(unit) + if self.xunit is not None: + x = (self.x * self.xunit).to_value(unit, u.spectral()) + self.xunit = unit + self.x = x + self.xunit = unit + + def set_y_unit(self, unit=None): + if unit is None: + if not hasattr(self.viewer.state, 'y_display_unit'): + return + unit = self.viewer.state.y_display_unit + unit = u.Unit(unit) + if self.yunit is not None: + self.y = (self.y * self.yunit).to_value(unit) + self.yunit = unit + + def _on_global_display_unit_changed(self, msg): + if not self.auto_update_units: + return + if self.viewer.__class__.__name__ in ['SpecvizProfileView', 'CubevizProfileView']: + axis_map = {'spectral': 'x', 'flux': 'y'} + elif self.viewer.__class__.__name__ == 'MosvizProfile2DView': + axis_map = {'spectral': 'x'} + else: + return + axis = axis_map.get(msg.axis, None) + if axis is not None: + getattr(self, f'set_{axis}_unit')(msg.unit) + def clear(self): self.update_xy([], []) +class LinesAutoUnit(PluginMark, Lines, HubListener): + def __init__(self, viewer, *args, **kwargs): + self.viewer = viewer + super().__init__(*args, **kwargs) + + class PluginLine(Lines, PluginMark, HubListener): def __init__(self, viewer, x=[], y=[], **kwargs): + self.viewer = viewer # color is same blue as import button super().__init__(x=x, y=y, colors=["#007BA1"], scales=viewer.scales, **kwargs) class PluginScatter(Scatter, PluginMark, HubListener): def __init__(self, viewer, x=[], y=[], **kwargs): + self.viewer = viewer # color is same blue as import button super().__init__(x=x, y=y, colors=["#007BA1"], scales=viewer.scales, **kwargs) class LineAnalysisContinuum(PluginLine): - pass + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # units do not need to be updated because the plugin itself reruns + # the computation and automatically changes the arrays themselves + self.auto_update_units = False class LineAnalysisContinuumCenter(LineAnalysisContinuum): @@ -529,9 +599,9 @@ class LineAnalysisContinuumRight(LineAnalysisContinuumLeft): pass -class LineUncertainties(Lines): - def __init__(self, **kwargs): - super().__init__(**kwargs) +class LineUncertainties(LinesAutoUnit): + def __init__(self, viewer, *args, **kwargs): + super().__init__(viewer, *args, **kwargs) class ScatterMask(Scatter):