Skip to content

Commit

Permalink
RCAL-878: TVAC Node Saving (#369)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Brett <brettgraham@gmail.com>
  • Loading branch information
3 people authored Aug 2, 2024
1 parent 911485a commit ea915f8
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 28 deletions.
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

- Enable asdf "lazy_tree" mode for all roman datamodels files [#358]

- Fix to preserve extra TVAC specific data when processed through DQ Init. [#369]

0.20.0 (2024-05-15)
===================

Expand Down
24 changes: 12 additions & 12 deletions src/roman_datamodels/datamodels/_datamodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,22 +174,22 @@ def from_science_raw(cls, model):
ramp.groupdq = model.resultantdq.copy()

# Define how to recursively copy all attributes.
def node_update(self, other):
def node_update(ramp, other):
"""Implement update to directly access each value"""
for key in other.keys():
if key == "resultantdq":
continue
if key in self:
if isinstance(self[key], Mapping):
node_update(self[key], other.__getattr__(key))
continue
if isinstance(self[key], list):
self[key] = other.__getattr__(key).data
continue
if isinstance(self[key], np.ndarray):
self[key] = other.__getattr__(key).astype(self[key].dtype)
continue
self[key] = other.__getattr__(key)
if key in ramp:
if isinstance(ramp[key], Mapping):
node_update(getattr(ramp, key), getattr(other, key))
elif isinstance(ramp[key], list):
setattr(ramp, key, getattr(other, key).data)
elif isinstance(ramp[key], np.ndarray):
setattr(ramp, key, getattr(other, key).astype(ramp[key].dtype))
else:
setattr(ramp, key, getattr(other, key))
else:
ramp[key] = other[key]

node_update(ramp, model)

Expand Down
28 changes: 14 additions & 14 deletions src/roman_datamodels/stnode/_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,18 @@ def ctx(self):
return self._ctx

@staticmethod
def _convert_to_scalar(key, value):
def _convert_to_scalar(key, value, ref=None):
"""Find and wrap scalars in the appropriate class, if its a tagged one."""
from ._tagged import TaggedScalarNode

if isinstance(ref, TaggedScalarNode):
# we want the exact class (not possible subclasses)
if type(value) == type(ref): # noqa: E721
return value
return type(ref)(value)

if isinstance(value, TaggedScalarNode):
return value

if key in SCALAR_NODE_CLASSES_BY_KEY:
value = SCALAR_NODE_CLASSES_BY_KEY[key](value)
Expand Down Expand Up @@ -231,7 +241,7 @@ def __setattr__(self, key, value):
if key[0] != "_":

# Wrap things in the tagged scalar classes if necessary
value = self._convert_to_scalar(key, value)
value = self._convert_to_scalar(key, value, self._data.get(key))

if key in self._data or key in self._schema_attributes:
# Perform validation if enabled
Expand Down Expand Up @@ -316,22 +326,12 @@ def __setitem__(self, key, value):
"""Dictionary style access set data"""

# Convert the value to a tagged scalar if necessary
if self._tag and "/tvac" in self._tag:
value = self._convert_to_scalar("tvac_" + key, value)
elif self._tag and "/fps" in self._tag:
value = self._convert_to_scalar("fps_" + key, value)
else:
value = self._convert_to_scalar(key, value)
value = self._convert_to_scalar(key, value, self._data.get(key))

# If the value is a dictionary, loop over its keys and convert them to tagged scalars
if isinstance(value, (dict, asdf.lazy_nodes.AsdfDictNode)):
for sub_key, sub_value in value.items():
if self._tag and "/tvac" in self._tag:
value[sub_key] = self._convert_to_scalar("tvac_" + sub_key, sub_value)
elif self._tag and "/fps" in self._tag:
value[sub_key] = self._convert_to_scalar("fps_" + sub_key, sub_value)
else:
value[sub_key] = self._convert_to_scalar(sub_key, sub_value)
value[sub_key] = self._convert_to_scalar(sub_key, sub_value)

self._data[key] = value

Expand Down
41 changes: 40 additions & 1 deletion tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,17 +970,56 @@ def test_datamodel_save_filename(tmp_path):
(datamodels.MosaicModel, False),
],
)
def test_rampmodel_from_science_raw(model_class, expect_success):
def test_rampmodel_from_science_raw(tmp_path, model_class, expect_success):
"""Test creation of RampModel from raw science/tvac"""
model = utils.mk_datamodel(
model_class, meta={"calibration_software_version": "1.2.3", "exposure": {"read_pattern": [[1], [2], [3]]}}
)
if expect_success:
filename = tmp_path / "fancy_filename.asdf"
ramp = datamodels.RampModel.from_science_raw(model)

assert ramp.meta.calibration_software_version == model.meta.calibration_software_version
assert ramp.meta.exposure.read_pattern == model.meta.exposure.read_pattern
assert ramp.validate() is None

ramp.save(filename)
with datamodels.open(filename) as new_ramp:
assert new_ramp.meta.calibration_software_version == model.meta.calibration_software_version

else:
with pytest.raises(ValueError):
datamodels.RampModel.from_science_raw(model)


@pytest.mark.parametrize(
"model_class",
[datamodels.FpsModel, datamodels.RampModel, datamodels.ScienceRawModel, datamodels.TvacModel, datamodels.MosaicModel],
)
def test_model_assignment_access_types(model_class):
"""Test assignment and access of model keyword value via keys and dot notation"""
# Test creation
model = utils.mk_datamodel(
model_class, meta={"calibration_software_version": "1.2.3", "exposure": {"read_pattern": [[1], [2], [3]]}}
)

assert model["meta"]["filename"] == model.meta["filename"]
assert model["meta"]["filename"] == model.meta.filename
assert model.meta.filename == model.meta["filename"]
assert type(model["meta"]["filename"]) == type(model.meta["filename"]) # noqa: E721
assert type(model["meta"]["filename"]) == type(model.meta.filename) # noqa: E721
assert type(model.meta.filename) == type(model.meta["filename"]) # noqa: E721

# Test assignment
model2 = utils.mk_datamodel(model_class, meta={"calibration_software_version": "4.5.6"})

model.meta["filename"] = "Roman_keys_test.asdf"
model2.meta.filename = "Roman_dot_test.asdf"

assert model.validate() is None
assert model2.validate() is None

# Test assignment types
assert type(model["meta"]["filename"]) == type(model2.meta["filename"]) # noqa: E721
assert type(model["meta"]["filename"]) == type(model2.meta.filename) # noqa: E721
assert type(model.meta.filename) == type(model2.meta["filename"]) # noqa: E721
2 changes: 1 addition & 1 deletion tests/test_stnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def test_set_pattern_properties():
with pytest.raises(asdf.ValidationError):
mdl.phot_table.F062.pixelareasr = 3.14

# This is invalid be cause it is not a scalar
# This is invalid because it is not a scalar
with pytest.raises(asdf.ValidationError):
mdl.phot_table.F062.photmjsr = [37.0] * (u.MJy / u.sr)
with pytest.raises(asdf.ValidationError):
Expand Down

0 comments on commit ea915f8

Please sign in to comment.