Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Display units: model fitting #2216

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion jdaviz/components/tooltip.vue
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ const tooltips = {
'plugin-plot-options-mixed-state': 'Current values are mixed, click to sync at shown value',
'plugin-model-fitting-add-model': 'Create model component',
'plugin-model-fitting-param-fixed': 'Check the box to freeze parameter value',
'plugin-model-fitting-reestimate-all': 'Re-estimate initial values based on the current data/subset selection for all free parameters',
'plugin-model-fitting-reestimate-all': 'Re-estimate initial values based on the current data/subset selection for all free parameters based on current display units',
'plugin-model-fitting-reestimate': 'Re-estimate initial values based on the current data/subset selection for all free parameters in this component',
'plugin-unit-conversion-apply': 'Apply unit conversion',
'plugin-line-lists-load': 'Load list into "Loaded Lines" section of plugin',
Expand Down
89 changes: 84 additions & 5 deletions jdaviz/configs/default/plugins/model_fitting/model_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from traitlets import Bool, List, Unicode, observe
from glue.core.data import Data

from jdaviz.core.events import SnackbarMessage
from jdaviz.core.events import SnackbarMessage, GlobalDisplayUnitChanged
from jdaviz.core.registries import tray_registry
from jdaviz.core.template_mixin import (PluginTemplateMixin,
SelectPluginComponent,
Expand Down Expand Up @@ -60,6 +60,7 @@ class ModelFitting(PluginTemplateMixin, DatasetSelectMixin,
* :meth:`create_model_component`
* :meth:`remove_model_component`
* :meth:`model_components`
* :meth:`valid_model_components`
* :meth:`get_model_component`
* :meth:`set_model_component`
* :meth:`reestimate_model_parameters`
Expand Down Expand Up @@ -169,13 +170,17 @@ def __init__(self, *args, **kwargs):
# set the filter on the viewer options
self._update_viewer_filters()

self.hub.subscribe(self, GlobalDisplayUnitChanged,
handler=self._on_global_display_unit_changed)

@property
def user_api(self):
expose = ['dataset']
if self.config == "cubeviz":
expose += ['spatial_subset']
expose += ['spectral_subset', 'model_component', 'poly_order', 'model_component_label',
'model_components', 'create_model_component', 'remove_model_component',
'model_components', 'valid_model_components',
'create_model_component', 'remove_model_component',
'get_model_component', 'set_model_component', 'reestimate_model_parameters',
'equation', 'equation_components',
'add_results', 'residuals_calculate', 'residuals']
Expand Down Expand Up @@ -336,6 +341,7 @@ def _dataset_selected_changed(self, event=None):
# (won't affect calculations because these locations are masked)
selected_spec.flux[np.isnan(selected_spec.flux)] = 0.0

# TODO: can we simplify this logic?
self._units["x"] = str(
selected_spec.spectral_axis.unit)
self._units["y"] = str(
Expand Down Expand Up @@ -503,8 +509,38 @@ def _initialize_model_component(self, model_comp, comp_label, poly_order=None):
self._initialized_models[comp_label] = initialized_model

new_model["Initialized"] = True
new_model["initialized_display_units"] = self._units.copy()

new_model["compat_display_units"] = True # always compatible at time of creation
return new_model

def _check_model_component_compat(self, axes=['x', 'y'], display_units=None):
if display_units is None:
display_units = [u.Unit(self._units[ax]) for ax in axes]

disp_physical_types = [unit.physical_type for unit in display_units]

for model_index, comp_model in enumerate(self.component_models):
compat = True
for ax, ax_physical_type in zip(axes, disp_physical_types):
comp_unit = u.Unit(comp_model["initialized_display_units"][ax])
compat = comp_unit.physical_type == ax_physical_type
if not compat:
break
self.component_models[model_index]["compat_display_units"] = compat

# length hasn't changed, so we need to force the traitlet to update
self.send_state("component_models")
self._check_model_equation_invalid()

def _on_global_display_unit_changed(self, msg):
axis = {'spectral': 'x', 'flux': 'y'}.get(msg.axis)

# update internal tracking of current units
self._units[axis] = str(msg.unit)

self._check_model_component_compat([axis], [msg.unit])

def remove_model_component(self, model_component_label):
"""
Remove an existing model component.
Expand Down Expand Up @@ -634,6 +670,9 @@ def reestimate_model_parameters(self, model_component_label=None):
# length hasn't changed, so we need to force the traitlet to update
self.send_state("component_models")

# model units may have changed, need to re-check their compatibility with display units
self._check_model_component_compat()

# return user-friendly info on revised model
return self.get_model_component(model_component_label)

Expand All @@ -644,12 +683,19 @@ def model_components(self):
"""
return [x["id"] for x in self.component_models]

@property
def valid_model_components(self):
"""
List of the labels of existing valid (due to display units) model components
"""
return [x["id"] for x in self.component_models if x["compat_display_units"]]

@property
def equation_components(self):
"""
List of the labels of model components in the current equation
"""
return re.split('[+*/-]', self.equation.value)
return re.split(r'[+*/-]', self.equation.value.replace(' ', ''))

def vue_add_model(self, event):
self.create_model_component()
Expand All @@ -658,10 +704,41 @@ def vue_remove_model(self, event):
self.remove_model_component(event)

@observe('model_equation')
def _model_equation_changed(self, event):
def _check_model_equation_invalid(self, event=None):
# Length is a dummy check to test the infrastructure
if len(self.model_equation) == 0:
self.model_equation_invalid_msg = 'model equation is required'
self.model_equation_invalid_msg = 'model equation is required.'
return
if '' in self.equation_components:
# includes an operator without a variable (ex: 'C+')
self.model_equation_invalid_msg = 'incomplete equation.'
return

components_not_existing = [comp for comp in self.equation_components
if comp not in self.model_components]
if len(components_not_existing):
if len(components_not_existing) == 1:
msg = "is not an existing model component."
else:
msg = "are not existing model components."
self.model_equation_invalid_msg = f'{", ".join(components_not_existing)} {msg}'
return
components_not_valid = [comp for comp in self.equation_components
if comp not in self.valid_model_components]
if len(components_not_valid):
if len(components_not_valid) == 1:
msg = ("is currently disabled because it has"
" incompatible units with the current display units."
" Remove the component from the equation,"
" re-estimate its free parameters to use the new units"
" or revert the display units.")
else:
msg = ("are currently disabled because they have"
" incompatible units with the current display units."
" Remove the components from the equation,"
" re-estimate their free parameters to use the new units"
" or revert the display units.")
self.model_equation_invalid_msg = f'{", ".join(components_not_valid)} {msg}'
return
self.model_equation_invalid_msg = ''

Expand Down Expand Up @@ -707,6 +784,8 @@ def calculate_fit(self, add_data=True):
if not self.spectral_subset_valid:
valid, spec_range, subset_range = self._check_dataset_spectral_subset_valid(return_ranges=True) # noqa
raise ValueError(f"spectral subset '{self.spectral_subset.selected}' {subset_range} is outside data range of '{self.dataset.selected}' {spec_range}") # noqa
if len(self.model_equation_invalid_msg):
raise ValueError(f"model equation is invalid: {self.model_equation_invalid_msg}")

if self.cube_fit:
ret = self._fit_model_to_cube(add_data=add_data)
Expand Down
40 changes: 32 additions & 8 deletions jdaviz/configs/default/plugins/model_fitting/model_fitting.vue
Original file line number Diff line number Diff line change
Expand Up @@ -128,14 +128,38 @@
</v-row>
</v-expansion-panel-header>
<v-expansion-panel-content>
<v-row
v-if="!componentInEquation(item.id)"
class="v-messages v-messages__message text--secondary"
style="padding-top: 12px"
>
<span><b>{{ item.id }}</b> model component not in equation</span>
<v-row v-if="!item.compat_display_units">
<v-alert :type="componentInEquation(item.id) ? 'error' : 'warning'">
<b>{{ item.id }}</b> is inconsistent with the current display units so cannot be used in the model equation.
Create a new model component or re-estimate the free parameters based on the current display units.
<v-row
justify="end"
style="padding-top: 12px; padding-right: 2px"
>
<j-tooltip tipid='plugin-model-fitting-reestimate'>
<v-btn
tile
:elevation=0
x-small
dense
color="turquoise"
dark
style="padding-left: 8px; padding-right: 6px;"
@click="reestimate_model_parameters(item.id)">
<v-icon left small dense style="margin-right: 2px">mdi-restart</v-icon>
Re-estimate free parameters
</v-btn>
</j-tooltip>
</v-row>
</v-alert>
</v-row>
<v-row v-if="item.compat_display_units && !componentInEquation(item.id)">
<v-alert type="info">
<b>{{ item.id }}</b> model component not in equation
</v-alert>
</v-row>
<v-row justify="end"
<v-row v-if="item.compat_display_units"
justify="end"
style="padding-top: 12px; padding-right: 2px"
>
<j-tooltip tipid='plugin-model-fitting-reestimate'>
Expand Down Expand Up @@ -290,7 +314,7 @@
},
methods: {
componentInEquation(componentId) {
return this.model_equation.split(/[+*\/-]/).indexOf(componentId) !== -1
return this.model_equation.replace(/\s/g, '').split(/[+*\/-]/).indexOf(componentId) !== -1
},
roundUncertainty(uncertainty) {
return uncertainty.toPrecision(2)
Expand Down
53 changes: 53 additions & 0 deletions jdaviz/configs/default/plugins/model_fitting/tests/test_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,3 +327,56 @@ def test_results_table(specviz_helper, spectrum1d):
'G:mean_1:fixed', 'G:mean_1:std',
'G:stddev_1', 'G:stddev_1:unit',
'G:stddev_1:fixed', 'G:stddev_1:std']


def test_equation_validation(specviz_helper, spectrum1d):
data_label = 'test'
specviz_helper.load_data(spectrum1d, data_label=data_label)

mf = specviz_helper.plugins['Model Fitting']
mf.create_model_component('Const1D')
mf.create_model_component('Linear1D')

assert mf.equation == 'C+L'
assert mf._obj.model_equation_invalid_msg == ''

mf.equation = 'L+'
assert mf._obj.model_equation_invalid_msg == 'incomplete equation.'

mf.equation = 'L+C'
assert mf._obj.model_equation_invalid_msg == ''

mf.equation = 'L+CC'
assert mf._obj.model_equation_invalid_msg == 'CC is not an existing model component.'

mf.equation = 'L+CC+DD'
assert mf._obj.model_equation_invalid_msg == 'CC, DD are not existing model components.'

mf.equation = ''
assert mf._obj.model_equation_invalid_msg == 'model equation is required.'


@pytest.mark.filterwarnings(r"ignore:Model is linear in parameters.*")
@pytest.mark.filterwarnings(r"ignore:The fit may be unsuccessful.*")
def test_incompatible_units(specviz_helper, spectrum1d):
data_label = 'test'
specviz_helper.load_data(spectrum1d, data_label=data_label)

mf = specviz_helper.plugins['Model Fitting']
mf.create_model_component('Linear1D')

mf.add_results.label = 'model native units'
mf.calculate_fit(add_data=True)

uc = specviz_helper.plugins['Unit Conversion']
assert uc.spectral_unit.selected == "Angstrom"
uc.spectral_unit = u.Hz

assert 'L is currently disabled' in mf._obj.model_equation_invalid_msg
mf.add_results.label = 'frequency units'
with pytest.raises(ValueError, match=r"model equation is invalid.*"):
mf.calculate_fit(add_data=True)

mf.reestimate_model_parameters()
assert mf._obj.model_equation_invalid_msg == ''
mf.calculate_fit(add_data=True)
3 changes: 2 additions & 1 deletion jdaviz/core/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,8 @@ def _handle_display_units(data, use_display_units):
new_uncert = data.uncertainty.__class__(data.uncertainty.quantity.to(flux_unit)) if data.uncertainty is not None else None # noqa
data = Spectrum1D(spectral_axis=data.spectral_axis.to(spectral_unit,
u.spectral()),
flux=data.flux.to(flux_unit),
flux=data.flux.to(flux_unit,
u.spectral_density(data.spectral_axis)),
uncertainty=new_uncert)
else: # pragma: nocover
raise NotImplementedError(f"converting {data.__class__.__name__} to display units is not supported") # noqa
Expand Down