Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RCAL-878: TVAC Node Saving #369

Merged
merged 20 commits into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

- Enable asdf "lazy_tree" mode for all roman datamodels files [#358]

- Fix to preserve extra TVAC specific data when processed through DQ Init. [#369]

0.20.0 (2024-05-15)
===================

Expand Down
24 changes: 12 additions & 12 deletions src/roman_datamodels/datamodels/_datamodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,22 +174,22 @@ def from_science_raw(cls, model):
ramp.groupdq = model.resultantdq.copy()

# Define how to recursively copy all attributes.
def node_update(self, other):
def node_update(ramp, other):
"""Implement update to directly access each value"""
for key in other.keys():
if key == "resultantdq":
continue
if key in self:
if isinstance(self[key], Mapping):
node_update(self[key], other.__getattr__(key))
continue
if isinstance(self[key], list):
self[key] = other.__getattr__(key).data
continue
if isinstance(self[key], np.ndarray):
self[key] = other.__getattr__(key).astype(self[key].dtype)
continue
self[key] = other.__getattr__(key)
if key in ramp:
if isinstance(ramp[key], Mapping):
node_update(getattr(ramp, key), getattr(other, key))
elif isinstance(ramp[key], list):
setattr(ramp, key, getattr(other, key).data)
elif isinstance(ramp[key], np.ndarray):
setattr(ramp, key, getattr(other, key).astype(ramp[key].dtype))
else:
setattr(ramp, key, getattr(other, key))
else:
ramp[key] = other[key]

node_update(ramp, model)

Expand Down
28 changes: 14 additions & 14 deletions src/roman_datamodels/stnode/_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,18 @@ def ctx(self):
return self._ctx

@staticmethod
def _convert_to_scalar(key, value):
def _convert_to_scalar(key, value, ref=None):
"""Find and wrap scalars in the appropriate class, if its a tagged one."""
from ._tagged import TaggedScalarNode

if isinstance(ref, TaggedScalarNode):
# we want the exact class (not possible subclasses)
if type(value) == type(ref): # noqa: E721
return value
return type(ref)(value)

if isinstance(value, TaggedScalarNode):
return value

if key in SCALAR_NODE_CLASSES_BY_KEY:
value = SCALAR_NODE_CLASSES_BY_KEY[key](value)
Expand Down Expand Up @@ -231,7 +241,7 @@ def __setattr__(self, key, value):
if key[0] != "_":

# Wrap things in the tagged scalar classes if necessary
value = self._convert_to_scalar(key, value)
value = self._convert_to_scalar(key, value, self._data.get(key))

if key in self._data or key in self._schema_attributes:
# Perform validation if enabled
Expand Down Expand Up @@ -316,22 +326,12 @@ def __setitem__(self, key, value):
"""Dictionary style access set data"""

# Convert the value to a tagged scalar if necessary
if self._tag and "/tvac" in self._tag:
value = self._convert_to_scalar("tvac_" + key, value)
elif self._tag and "/fps" in self._tag:
value = self._convert_to_scalar("fps_" + key, value)
else:
value = self._convert_to_scalar(key, value)
value = self._convert_to_scalar(key, value, self._data.get(key))

# If the value is a dictionary, loop over its keys and convert them to tagged scalars
if isinstance(value, (dict, asdf.lazy_nodes.AsdfDictNode)):
for sub_key, sub_value in value.items():
if self._tag and "/tvac" in self._tag:
value[sub_key] = self._convert_to_scalar("tvac_" + sub_key, sub_value)
elif self._tag and "/fps" in self._tag:
value[sub_key] = self._convert_to_scalar("fps_" + sub_key, sub_value)
else:
value[sub_key] = self._convert_to_scalar(sub_key, sub_value)
value[sub_key] = self._convert_to_scalar(sub_key, sub_value)

self._data[key] = value

Expand Down
41 changes: 40 additions & 1 deletion tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,17 +970,56 @@ def test_datamodel_save_filename(tmp_path):
(datamodels.MosaicModel, False),
],
)
def test_rampmodel_from_science_raw(model_class, expect_success):
def test_rampmodel_from_science_raw(tmp_path, model_class, expect_success):
"""Test creation of RampModel from raw science/tvac"""
model = utils.mk_datamodel(
model_class, meta={"calibration_software_version": "1.2.3", "exposure": {"read_pattern": [[1], [2], [3]]}}
)
if expect_success:
filename = tmp_path / "fancy_filename.asdf"
ramp = datamodels.RampModel.from_science_raw(model)

assert ramp.meta.calibration_software_version == model.meta.calibration_software_version
assert ramp.meta.exposure.read_pattern == model.meta.exposure.read_pattern
assert ramp.validate() is None

ramp.save(filename)
with datamodels.open(filename) as new_ramp:
assert new_ramp.meta.calibration_software_version == model.meta.calibration_software_version

else:
with pytest.raises(ValueError):
datamodels.RampModel.from_science_raw(model)


@pytest.mark.parametrize(
"model_class",
[datamodels.FpsModel, datamodels.RampModel, datamodels.ScienceRawModel, datamodels.TvacModel, datamodels.MosaicModel],
)
def test_model_assignment_access_types(model_class):
"""Test assignment and access of model keyword value via keys and dot notation"""
# Test creation
model = utils.mk_datamodel(
model_class, meta={"calibration_software_version": "1.2.3", "exposure": {"read_pattern": [[1], [2], [3]]}}
)

assert model["meta"]["filename"] == model.meta["filename"]
assert model["meta"]["filename"] == model.meta.filename
assert model.meta.filename == model.meta["filename"]
assert type(model["meta"]["filename"]) == type(model.meta["filename"]) # noqa: E721
assert type(model["meta"]["filename"]) == type(model.meta.filename) # noqa: E721
assert type(model.meta.filename) == type(model.meta["filename"]) # noqa: E721

# Test assignment
model2 = utils.mk_datamodel(model_class, meta={"calibration_software_version": "4.5.6"})

model.meta["filename"] = "Roman_keys_test.asdf"
model2.meta.filename = "Roman_dot_test.asdf"

assert model.validate() is None
assert model2.validate() is None

# Test assignment types
assert type(model["meta"]["filename"]) == type(model2.meta["filename"]) # noqa: E721
assert type(model["meta"]["filename"]) == type(model2.meta.filename) # noqa: E721
assert type(model.meta.filename) == type(model2.meta["filename"]) # noqa: E721
2 changes: 1 addition & 1 deletion tests/test_stnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def test_set_pattern_properties():
with pytest.raises(asdf.ValidationError):
mdl.phot_table.F062.pixelareasr = 3.14

# This is invalid be cause it is not a scalar
# This is invalid because it is not a scalar
with pytest.raises(asdf.ValidationError):
mdl.phot_table.F062.photmjsr = [37.0] * (u.MJy / u.sr)
with pytest.raises(asdf.ValidationError):
Expand Down