diff --git a/CHANGES.rst b/CHANGES.rst index 853adc08..2c8177c1 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,7 +1,7 @@ 0.20.1 (unreleased) =================== -- +- Recursively convert all meta attributes during model casting. [#352] 0.20.0 (2024-05-15) =================== diff --git a/src/roman_datamodels/datamodels/_datamodels.py b/src/roman_datamodels/datamodels/_datamodels.py index 54cac754..5521d03f 100644 --- a/src/roman_datamodels/datamodels/_datamodels.py +++ b/src/roman_datamodels/datamodels/_datamodels.py @@ -6,6 +6,8 @@ from the schema manifest defined by RAD. """ +from collections.abc import Mapping + import asdf import numpy as np from astropy.table import QTable @@ -135,37 +137,61 @@ class RampModel(_RomanDataModel): @classmethod def from_science_raw(cls, model): """ - Construct a RampModel from a ScienceRawModel + Attempt to construct a RampModel from a DataModel + + If the model has a resultantdq attribute, this is copied into + the RampModel.groupdq attribute. Parameters ---------- - model : ScienceRawModel or RampModel - The input science raw model (a RampModel will also work) + model : ScienceRawModel, TvacModel + The input data model (a RampModel will also work). + + Returns + ------- + ramp_model : RampModel + The RampModel built from the input model. If the input is already + a RampModel, it is simply returned. """ + ALLOWED_MODELS = (FpsModel, RampModel, ScienceRawModel, TvacModel) if isinstance(model, cls): return model + if not isinstance(model, ALLOWED_MODELS): + raise ValueError(f"Input must be one of {ALLOWED_MODELS}") + + # Create base ramp node with dummy values (for validation) + from roman_datamodels.maker_utils import mk_ramp + + ramp = mk_ramp(shape=model.shape) + + # check if the input model has a resultantdq from SDF + if hasattr(model, "resultantdq"): + ramp.groupdq = model.resultantdq.copy() + + # Define how to recursively copy all attributes. + def node_update(self, other): + """Implement update to directly access each value""" + for key in other.keys(): + if key == "resultantdq": + continue + if key in self: + if isinstance(self[key], Mapping): + node_update(self[key], other.__getattr__(key)) + continue + if isinstance(self[key], list): + self[key] = other.__getattr__(key).data + continue + if isinstance(self[key], np.ndarray): + self[key] = other.__getattr__(key).astype(self[key].dtype) + continue + self[key] = other.__getattr__(key) - if isinstance(model, ScienceRawModel): - from roman_datamodels.maker_utils import mk_ramp - - instance = mk_ramp(shape=model.shape) - - # Copy input_model contents into RampModel - for key in model: - # If a dictionary (like meta), overwrite entries (but keep - # required dummy entries that may not be in input_model) - if isinstance(instance[key], dict): - instance[key].update(getattr(model, key)) - elif isinstance(instance[key], np.ndarray): - # Cast input ndarray as RampModel dtype - instance[key] = getattr(model, key).astype(instance[key].dtype) - else: - instance[key] = getattr(model, key) - - return cls(instance) + node_update(ramp, model) - raise ValueError("Input model must be a ScienceRawModel or RampModel") + # Create model from node + ramp_model = RampModel(ramp) + return ramp_model class RampFitOutputModel(_RomanDataModel): diff --git a/tests/test_models.py b/tests/test_models.py index ebaf1011..47614bef 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -958,3 +958,29 @@ def test_datamodel_save_filename(tmp_path): with datamodels.open(filename) as new_ramp: assert new_ramp.meta.filename == filename.name + + +@pytest.mark.parametrize( + "model_class, expect_success", + [ + (datamodels.FpsModel, True), + (datamodels.RampModel, True), + (datamodels.ScienceRawModel, True), + (datamodels.TvacModel, True), + (datamodels.MosaicModel, False), + ], +) +def test_rampmodel_from_science_raw(model_class, expect_success): + """Test creation of RampModel from raw science/tvac""" + model = utils.mk_datamodel( + model_class, meta={"calibration_software_version": "1.2.3", "exposure": {"read_pattern": [[1], [2], [3]]}} + ) + if expect_success: + ramp = datamodels.RampModel.from_science_raw(model) + + assert ramp.meta.calibration_software_version == model.meta.calibration_software_version + assert ramp.meta.exposure.read_pattern == model.meta.exposure.read_pattern + + else: + with pytest.raises(ValueError): + datamodels.RampModel.from_science_raw(model)