From c19cde4d7262f389396f9f989a1d61a5c97ed910 Mon Sep 17 00:00:00 2001 From: William Jamieson Date: Fri, 7 Jul 2023 17:01:10 -0400 Subject: [PATCH] Update maker_utils so that saved nodes are still returned (#218) --- CHANGES.rst | 2 + src/roman_datamodels/maker_utils/_base.py | 22 ++++ .../maker_utils/_datamodels.py | 59 ++-------- .../maker_utils/_ref_files.py | 101 +++--------------- tests/test_open.py | 19 ++++ 5 files changed, 67 insertions(+), 136 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index 7edf4430..496e4c71 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -5,6 +5,8 @@ - Add tests to ensure consistency between file-level schemas in RAD and the corresponding datamodels in ``roman_datamodels``. [#214] +- Make ``maker_utils`` return the node when writing the node to a file. [#218] + 0.16.1 (2023-06-27) =================== diff --git a/src/roman_datamodels/maker_utils/_base.py b/src/roman_datamodels/maker_utils/_base.py index 3a38cb14..822118bf 100644 --- a/src/roman_datamodels/maker_utils/_base.py +++ b/src/roman_datamodels/maker_utils/_base.py @@ -1,4 +1,26 @@ +import asdf + NONUM = -999999 NOSTR = "dummy value" MESSAGE = "This function assumes shape is 2D, but it was given at least 3 dimensions" + + +def save_node(node, filepath=None): + """ + Save the node to a file if there is a file path given, and return the node. + + Parameters + ---------- + node: DNode + The node to save. + filepath: str + (optional) File name and path to write model to. + """ + + if filepath: + af = asdf.AsdfFile() + af.tree = {"roman": node} + af.write_to(filepath) + + return node diff --git a/src/roman_datamodels/maker_utils/_datamodels.py b/src/roman_datamodels/maker_utils/_datamodels.py index 5737da7a..bdd94bec 100644 --- a/src/roman_datamodels/maker_utils/_datamodels.py +++ b/src/roman_datamodels/maker_utils/_datamodels.py @@ -1,12 +1,11 @@ import warnings -import asdf import numpy as np from astropy import units as u from roman_datamodels import stnode -from ._base import MESSAGE +from ._base import MESSAGE, save_node from ._common_meta import mk_common_meta, mk_guidewindow_meta, mk_msos_stack_meta, mk_photometry_meta, mk_resample_meta from ._tagged_nodes import mk_cal_logs @@ -48,12 +47,7 @@ def mk_level1_science_raw(*, shape=(8, 4096, 4096), filepath=None, **kwargs): "amp33", u.Quantity(np.zeros((n_groups, 4096, 128), dtype=np.uint16), u.DN, dtype=np.uint16) ) - if filepath: - af = asdf.AsdfFile() - af.tree = {"roman": wfi_science_raw} - af.write_to(filepath) - else: - return wfi_science_raw + return save_node(wfi_science_raw, filepath=filepath) def mk_level2_image(*, shape=(4088, 4088), n_groups=8, filepath=None, **kwargs): @@ -135,12 +129,7 @@ def mk_level2_image(*, shape=(4088, 4088), n_groups=8, filepath=None, **kwargs): ) wfi_image["cal_logs"] = mk_cal_logs(**kwargs) - if filepath: - af = asdf.AsdfFile() - af.tree = {"roman": wfi_image} - af.write_to(filepath) - else: - return wfi_image + return save_node(wfi_image, filepath=filepath) def mk_level3_mosaic(*, shape=(4088, 4088), n_images=2, filepath=None, **kwargs): @@ -193,12 +182,7 @@ def mk_level3_mosaic(*, shape=(4088, 4088), n_images=2, filepath=None, **kwargs) ) wfi_mosaic["cal_logs"] = mk_cal_logs(**kwargs) - if filepath: - af = asdf.AsdfFile() - af.tree = {"roman": wfi_mosaic} - af.write_to(filepath) - else: - return wfi_mosaic + return save_node(wfi_mosaic, filepath=filepath) def mk_msos_stack(*, shape=(4096, 4096), filepath=None, **kwargs): @@ -232,12 +216,7 @@ def mk_msos_stack(*, shape=(4096, 4096), filepath=None, **kwargs): msos_stack["mask"] = kwargs.get("mask", np.zeros(shape, dtype=np.uint8)) msos_stack["coverage"] = kwargs.get("coverage", np.zeros(shape, dtype=np.uint8)) - if filepath: - af = asdf.AsdfFile() - af.tree = {"roman": msos_stack} - af.write_to(filepath) - else: - return msos_stack + return save_node(msos_stack, filepath=filepath) def mk_ramp(*, shape=(8, 4096, 4096), filepath=None, **kwargs): @@ -294,12 +273,7 @@ def mk_ramp(*, shape=(8, 4096, 4096), filepath=None, **kwargs): ramp["groupdq"] = kwargs.get("groupdq", np.zeros(shape, dtype=np.uint8)) ramp["err"] = kwargs.get("err", u.Quantity(np.zeros(shape, dtype=np.float32), u.DN, dtype=np.float32)) - if filepath: - af = asdf.AsdfFile() - af.tree = {"roman": ramp} - af.write_to(filepath) - else: - return ramp + return save_node(ramp, filepath=filepath) def mk_ramp_fit_output(*, shape=(8, 4096, 4096), filepath=None, **kwargs): @@ -346,12 +320,7 @@ def mk_ramp_fit_output(*, shape=(8, 4096, 4096), filepath=None, **kwargs): "var_rnoise", u.Quantity(np.zeros(shape, dtype=np.float32), u.electron**2 / u.s**2, dtype=np.float32) ) - if filepath: - af = asdf.AsdfFile() - af.tree = {"roman": rampfitoutput} - af.write_to(filepath) - else: - return rampfitoutput + return save_node(rampfitoutput, filepath=filepath) def mk_rampfitoutput(**kwargs): @@ -414,12 +383,7 @@ def mk_associations(*, shape=(2, 3, 1), filepath=None, **kwargs): file_idx += 1 associations["products"].append({"name": f"product{product_idx}", "members": members_lst}) - if filepath: - af = asdf.AsdfFile() - af.tree = {"roman": associations} - af.write_to(filepath) - else: - return associations + return save_node(associations, filepath=filepath) def mk_guidewindow(*, shape=(2, 8, 16, 32, 32), filepath=None, **kwargs): @@ -454,9 +418,4 @@ def mk_guidewindow(*, shape=(2, 8, 16, 32, 32), filepath=None, **kwargs): ) guidewindow["amp33"] = kwargs.get("amp33", u.Quantity(np.zeros(shape, dtype=np.uint16), u.DN, dtype=np.uint16)) - if filepath: - af = asdf.AsdfFile() - af.tree = {"roman": guidewindow} - af.write_to(filepath) - else: - return guidewindow + return save_node(guidewindow, filepath=filepath) diff --git a/src/roman_datamodels/maker_utils/_ref_files.py b/src/roman_datamodels/maker_utils/_ref_files.py index a3bfd958..df8effca 100644 --- a/src/roman_datamodels/maker_utils/_ref_files.py +++ b/src/roman_datamodels/maker_utils/_ref_files.py @@ -1,13 +1,12 @@ import warnings -import asdf import numpy as np from astropy import units as u from astropy.modeling import models from roman_datamodels import stnode -from ._base import MESSAGE +from ._base import MESSAGE, save_node from ._common_meta import ( mk_ref_common, mk_ref_dark_meta, @@ -65,12 +64,7 @@ def mk_flat(*, shape=(4096, 4096), filepath=None, **kwargs): flatref["dq"] = kwargs.get("dq", np.zeros(shape, dtype=np.uint32)) flatref["err"] = kwargs.get("err", np.zeros(shape, dtype=np.float32)) - if filepath: - af = asdf.AsdfFile() - af.tree = {"roman": flatref} - af.write_to(filepath) - else: - return flatref + return save_node(flatref, filepath=filepath) def mk_dark(*, shape=(2, 4096, 4096), filepath=None, **kwargs): @@ -101,12 +95,7 @@ def mk_dark(*, shape=(2, 4096, 4096), filepath=None, **kwargs): darkref["dq"] = kwargs.get("dq", np.zeros(shape[1:], dtype=np.uint32)) darkref["err"] = kwargs.get("err", u.Quantity(np.zeros(shape, dtype=np.float32), u.DN, dtype=np.float32)) - if filepath: - af = asdf.AsdfFile() - af.tree = {"roman": darkref} - af.write_to(filepath) - else: - return darkref + return save_node(darkref, filepath=filepath) def mk_distortion(*, filepath=None, **kwargs): @@ -131,12 +120,7 @@ def mk_distortion(*, filepath=None, **kwargs): "coordinate_distortion_transform", models.Shift(1) & models.Shift(2) ) - if filepath: - af = asdf.AsdfFile() - af.tree = {"roman": distortionref} - af.write_to(filepath) - else: - return distortionref + return save_node(distortionref, filepath=filepath) def mk_gain(*, shape=(4096, 4096), filepath=None, **kwargs): @@ -167,12 +151,7 @@ def mk_gain(*, shape=(4096, 4096), filepath=None, **kwargs): gainref["data"] = kwargs.get("data", u.Quantity(np.zeros(shape, dtype=np.float32), u.electron / u.DN, dtype=np.float32)) - if filepath: - af = asdf.AsdfFile() - af.tree = {"roman": gainref} - af.write_to(filepath) - else: - return gainref + return save_node(gainref, filepath=filepath) def mk_ipc(*, shape=(3, 3), filepath=None, **kwargs): @@ -207,12 +186,7 @@ def mk_ipc(*, shape=(3, 3), filepath=None, **kwargs): ipcref["data"] = np.zeros(shape, dtype=np.float32) ipcref["data"][int(np.floor(shape[0] / 2))][int(np.floor(shape[1] / 2))] = 1.0 - if filepath: - af = asdf.AsdfFile() - af.tree = {"roman": ipcref} - af.write_to(filepath) - else: - return ipcref + return save_node(ipcref, filepath=filepath) def mk_linearity(*, shape=(2, 4096, 4096), filepath=None, **kwargs): @@ -242,12 +216,7 @@ def mk_linearity(*, shape=(2, 4096, 4096), filepath=None, **kwargs): linearityref["dq"] = kwargs.get("dq", np.zeros(shape[1:], dtype=np.uint32)) linearityref["coeffs"] = kwargs.get("coeffs", np.zeros(shape, dtype=np.float32)) - if filepath: - af = asdf.AsdfFile() - af.tree = {"roman": linearityref} - af.write_to(filepath) - else: - return linearityref + return save_node(linearityref, filepath=filepath) def mk_inverse_linearity(*, shape=(2, 4096, 4096), filepath=None, **kwargs): @@ -277,12 +246,7 @@ def mk_inverse_linearity(*, shape=(2, 4096, 4096), filepath=None, **kwargs): inverselinearityref["dq"] = kwargs.get("dq", np.zeros(shape[1:], dtype=np.uint32)) inverselinearityref["coeffs"] = kwargs.get("coeffs", np.zeros(shape, dtype=np.float32)) - if filepath: - af = asdf.AsdfFile() - af.tree = {"roman": inverselinearityref} - af.write_to(filepath) - else: - return inverselinearityref + return save_node(inverselinearityref, filepath=filepath) def mk_mask(*, shape=(4096, 4096), filepath=None, **kwargs): @@ -313,12 +277,7 @@ def mk_mask(*, shape=(4096, 4096), filepath=None, **kwargs): maskref["dq"] = kwargs.get("dq", np.zeros(shape, dtype=np.uint32)) - if filepath: - af = asdf.AsdfFile() - af.tree = {"roman": maskref} - af.write_to(filepath) - else: - return maskref + return save_node(maskref, filepath=filepath) def mk_pixelarea(*, shape=(4096, 4096), filepath=None, **kwargs): @@ -349,12 +308,7 @@ def mk_pixelarea(*, shape=(4096, 4096), filepath=None, **kwargs): pixelarearef["data"] = kwargs.get("data", np.zeros(shape, dtype=np.float32)) - if filepath: - af = asdf.AsdfFile() - af.tree = {"roman": pixelarearef} - af.write_to(filepath) - else: - return pixelarearef + return save_node(pixelarearef, filepath=filepath) def _mk_phot_table_entry(key, **kwargs): @@ -405,12 +359,7 @@ def mk_wfi_img_photom(*, filepath=None, **kwargs): wfi_img_photomref["phot_table"] = _mk_phot_table(**kwargs.get("phot_table", {})) - if filepath: - af = asdf.AsdfFile() - af.tree = {"roman": wfi_img_photomref} - af.write_to(filepath) - else: - return wfi_img_photomref + return save_node(wfi_img_photomref, filepath=filepath) def mk_readnoise(*, shape=(4096, 4096), filepath=None, **kwargs): @@ -441,12 +390,7 @@ def mk_readnoise(*, shape=(4096, 4096), filepath=None, **kwargs): readnoiseref["data"] = kwargs.get("data", u.Quantity(np.zeros(shape, dtype=np.float32), u.DN, dtype=np.float32)) - if filepath: - af = asdf.AsdfFile() - af.tree = {"roman": readnoiseref} - af.write_to(filepath) - else: - return readnoiseref + return save_node(readnoiseref, filepath=filepath) def mk_saturation(*, shape=(4096, 4096), filepath=None, **kwargs): @@ -478,12 +422,7 @@ def mk_saturation(*, shape=(4096, 4096), filepath=None, **kwargs): saturationref["dq"] = kwargs.get("dq", np.zeros(shape, dtype=np.uint32)) saturationref["data"] = kwargs.get("data", u.Quantity(np.zeros(shape, dtype=np.float32), u.DN, dtype=np.float32)) - if filepath: - af = asdf.AsdfFile() - af.tree = {"roman": saturationref} - af.write_to(filepath) - else: - return saturationref + return save_node(saturationref, filepath=filepath) def mk_superbias(*, shape=(4096, 4096), filepath=None, **kwargs): @@ -516,12 +455,7 @@ def mk_superbias(*, shape=(4096, 4096), filepath=None, **kwargs): superbiasref["dq"] = kwargs.get("dq", np.zeros(shape, dtype=np.uint32)) superbiasref["err"] = kwargs.get("err", np.zeros(shape, dtype=np.float32)) - if filepath: - af = asdf.AsdfFile() - af.tree = {"roman": superbiasref} - af.write_to(filepath) - else: - return superbiasref + return save_node(superbiasref, filepath=filepath) def mk_refpix(*, shape=(32, 286721), filepath=None, **kwargs): @@ -568,9 +502,4 @@ def mk_refpix(*, shape=(32, 286721), filepath=None, **kwargs): refpix["zeta"] = kwargs.get("zeta", np.zeros(shape, dtype=np.complex128)) refpix["alpha"] = kwargs.get("alpha", np.zeros(shape, dtype=np.complex128)) - if filepath: - af = asdf.AsdfFile() - af.tree = {"roman": refpix} - af.write_to(filepath) - else: - return refpix + return save_node(refpix, filepath=filepath) diff --git a/tests/test_open.py b/tests/test_open.py index 5224134a..91c5af76 100644 --- a/tests/test_open.py +++ b/tests/test_open.py @@ -12,6 +12,7 @@ from roman_datamodels import datamodels from roman_datamodels import maker_utils as utils from roman_datamodels import stnode +from roman_datamodels.testing import assert_node_equal def test_asdf_file_input(): @@ -209,18 +210,36 @@ def test_no_memmap(tmp_path, kwargs): assert (model.data == data).all() +@pytest.mark.parametrize("node_class", [node for node in datamodels.MODEL_REGISTRY]) +@pytest.mark.filterwarnings("ignore:This function assumes shape is 2D") +@pytest.mark.filterwarnings("ignore:Input shape must be 5D") +def test_node_round_trip(tmp_path, node_class): + file_path = tmp_path / "test.asdf" + + # Create/return a node and write it to disk, then check if the node round trips + node = utils.mk_node(node_class, filepath=file_path, shape=(2, 8, 8)) + with asdf.open(file_path) as af: + assert_node_equal(af.tree["roman"], node) + + @pytest.mark.parametrize("node_class", [node for node in datamodels.MODEL_REGISTRY]) @pytest.mark.filterwarnings("ignore:This function assumes shape is 2D") @pytest.mark.filterwarnings("ignore:Input shape must be 5D") def test_opening_model(tmp_path, node_class): file_path = tmp_path / "test.asdf" + # Create a node and write it to disk utils.mk_node(node_class, filepath=file_path, shape=(2, 8, 8)) + + # Opened saved file as a datamodel with datamodels.open(file_path) as model: + # Check that some of read data is correct if node_class == stnode.Associations: assert model.asn_type == "image" else: assert model.meta.instrument.optical_element == "F158" + + # Check that the model is the correct type assert isinstance(model, datamodels.MODEL_REGISTRY[node_class])