diff --git a/changes/1438.flatfield.rst b/changes/1438.flatfield.rst new file mode 100644 index 000000000..d98a4f9ef --- /dev/null +++ b/changes/1438.flatfield.rst @@ -0,0 +1 @@ +Make var_flat optional. diff --git a/docs/roman/flatfield/main.rst b/docs/roman/flatfield/main.rst index 1a8ca75df..ec01cc1ae 100644 --- a/docs/roman/flatfield/main.rst +++ b/docs/roman/flatfield/main.rst @@ -59,3 +59,10 @@ and finally the error that is associated with the science data is given by, The total ERR array in the science exposure is updated as the square root of the quadratic sum of VAR_POISSON, VAR_RNOISE, and VAR_FLAT. + +Note that by default we do not compute VAR_FLAT nor include its +contribution to ERR, unless the "include_var_flat" is specified. This +means that the uncertainties on very bright pixels are +underestimated. However, other effects like charge migration, +saturation, and non-linearity can be important at these flux levels, +and their contributions to the uncertainty are never included. diff --git a/pyproject.toml b/pyproject.toml index 16fd4d7ae..d9089ef81 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,8 @@ dependencies = [ "photutils >=1.13.0", "pyparsing >=2.4.7", "requests >=2.26", - # "roman_datamodels>=0.22.0,<0.23.0", + "roman_datamodels>=0.22.0,<0.23.0", + "rad @ git+https://github.com/schlafly/rad.git@remove-var-flat", "roman_datamodels @ git+https://github.com/spacetelescope/roman_datamodels.git", "scipy >=1.11", # "stcal>=1.10.0,<1.11.0", diff --git a/romancal/flatfield/flat_field.py b/romancal/flatfield/flat_field.py index b74475274..44c899c65 100644 --- a/romancal/flatfield/flat_field.py +++ b/romancal/flatfield/flat_field.py @@ -13,7 +13,7 @@ MICRONS_100 = 1.0e-4 # 100 microns, in meters -def do_correction(input_model, flat=None): +def do_correction(input_model, flat=None, include_var_flat=False): """Flat-field a Roman data model using a flat-field model Parameters @@ -24,6 +24,9 @@ def do_correction(input_model, flat=None): flat : Roman data model, or None Data model containing flat-field for all instruments + include_var_flat : bool + compute & store the flat field variance? + Returns ------- output_model : data model @@ -31,12 +34,12 @@ def do_correction(input_model, flat=None): The data is modified in place. """ - do_flat_field(input_model, flat) + do_flat_field(input_model, flat, include_var_flat=include_var_flat) return input_model -def do_flat_field(output_model, flat_model): +def do_flat_field(output_model, flat_model, include_var_flat=False): """Apply flat-fielding, and update the output model. Parameters @@ -46,6 +49,9 @@ def do_flat_field(output_model, flat_model): flat_model : Roman data model data model containing flat-field + + include_var_flat : bool + compute & store the flat field variance? """ if flat_model is not None and output_model.data.shape != flat_model.data.shape: # Check to see if flat data array is smaller than science data @@ -61,11 +67,11 @@ def do_flat_field(output_model, flat_model): log.info("Skipping flat field - no flat reference file.") output_model.meta.cal_step.flat_field = "SKIPPED" else: - apply_flat_field(output_model, flat_model) + apply_flat_field(output_model, flat_model, include_var_flat=include_var_flat) output_model.meta.cal_step.flat_field = "COMPLETE" -def apply_flat_field(science, flat): +def apply_flat_field(science, flat, include_var_flat=False): """Flat field the data and error arrays. Extended summary @@ -82,6 +88,9 @@ def apply_flat_field(science, flat): flat : Roman data model flat field data model + + include_var_flat : bool + compute & store the flat vield variance? """ flat_data = flat.data.copy() flat_dq = flat.dq.copy() @@ -111,13 +120,18 @@ def apply_flat_field(science, flat): flat_data_squared = flat_data**2 science.var_poisson /= flat_data_squared science.var_rnoise /= flat_data_squared - try: - science.var_flat = science.data**2 / flat_data_squared * flat_err**2 - except AttributeError: - science["var_flat"] = np.zeros(shape=science.data.shape, dtype=np.float32) - science.var_flat = science.data**2 / flat_data_squared * flat_err**2 - science.err = np.sqrt(science.var_poisson + science.var_rnoise + science.var_flat) + total_var = science.var_poisson + science.var_rnoise + if include_var_flat: + var_flat = science.data**2 / flat_data_squared * flat_err**2 + try: + science.var_flat = var_flat + except AttributeError: + science["var_flat"] = np.zeros(shape=science.data.shape, dtype=np.float32) + science.var_flat = var_flat + total_var += science.var_flat + + science.err = np.sqrt(total_var) # Combine the science and flat DQ arrays science.dq = np.bitwise_or(science.dq, flat_dq) diff --git a/romancal/flatfield/flat_field_step.py b/romancal/flatfield/flat_field_step.py index 09b981770..54588a600 100644 --- a/romancal/flatfield/flat_field_step.py +++ b/romancal/flatfield/flat_field_step.py @@ -14,6 +14,9 @@ class FlatFieldStep(RomanStep): """Flat-field a science image using a flatfield reference image.""" class_alias = "flat_field" + spec = """ + include_var_flat = boolean(default=False) # include flat field variance + """ # noqa: E501 reference_file_types = ["flat"] @@ -38,8 +41,7 @@ def process(self, input_model): # Do the flat-field correction output_model = flat_field.do_correction( - input_model, - reference_file_model, + input_model, reference_file_model, include_var_flat=self.include_var_flat ) # Close reference file diff --git a/romancal/flatfield/tests/test_flatfield.py b/romancal/flatfield/tests/test_flatfield.py index f246fae2e..c9045a4eb 100644 --- a/romancal/flatfield/tests/test_flatfield.py +++ b/romancal/flatfield/tests/test_flatfield.py @@ -86,6 +86,21 @@ def test_crds_temporal_match(instrument, exptype): ) +def test_skip_var_flat(): + """Test that we don't populate var_flat if requested.""" + + wfi_image1 = maker_utils.mk_level2_image() + wfi_image2 = maker_utils.mk_level2_image() + del wfi_image1["var_flat"] + del wfi_image2["var_flat"] + wfi_image_model1 = ImageModel(wfi_image1) + wfi_image_model2 = ImageModel(wfi_image2) + result1 = FlatFieldStep.call(wfi_image_model1, include_var_flat=False) + result2 = FlatFieldStep.call(wfi_image_model2, include_var_flat=True) + assert not hasattr(result1, "var_flat") + assert hasattr(result2, "var_flat") + + @pytest.mark.parametrize( "instrument", [ diff --git a/romancal/flux/flux_step.py b/romancal/flux/flux_step.py index 6e793db32..e4f08b5d0 100644 --- a/romancal/flux/flux_step.py +++ b/romancal/flux/flux_step.py @@ -99,7 +99,9 @@ def apply_flux_correction(model): """ # Define the various arrays to be converted. DATA = ("data", "err") - VARIANCES = ("var_rnoise", "var_poisson", "var_flat") + VARIANCES = ("var_rnoise", "var_poisson") + if hasattr(model, "var_flat"): + VARIANCES = VARIANCES + ("var_flat",) if model.meta.cal_step["flux"] == "COMPLETE": message = ( diff --git a/romancal/regtest/test_resample.py b/romancal/regtest/test_resample.py index d669eecc6..b6b2b79c7 100644 --- a/romancal/regtest/test_resample.py +++ b/romancal/regtest/test_resample.py @@ -59,14 +59,12 @@ def test_resample_single_file(rtdata, ignore_asdf_paths): "err", "var_poisson", "var_rnoise", - "var_flat", ] ) }""" ) assert all( - hasattr(resample_out, x) - for x in ["data", "err", "var_poisson", "var_rnoise", "var_flat"] + hasattr(resample_out, x) for x in ["data", "err", "var_poisson", "var_rnoise"] ) step.log.info( @@ -94,14 +92,14 @@ def test_resample_single_file(rtdata, ignore_asdf_paths): np.isnan(getattr(resample_out, x)), np.equal(getattr(resample_out, x), 0) ) - ) > 0 for x in ["var_poisson", "var_rnoise", "var_flat"] + ) > 0 for x in ["var_poisson", "var_rnoise"] ) }""" ) assert all( np.sum(np.isnan(getattr(resample_out, x))) - for x in ["var_poisson", "var_rnoise", "var_flat"] + for x in ["var_poisson", "var_rnoise"] ) step.log.info( diff --git a/romancal/resample/resample.py b/romancal/resample/resample.py index aa01c1b11..754f8b938 100644 --- a/romancal/resample/resample.py +++ b/romancal/resample/resample.py @@ -169,6 +169,9 @@ def __init__( with self.input_models: models = list(self.input_models) + self.all_have_var_flat = np.all( + [hasattr(model, "var_flat") for model in models] + ) # update meta.basic populate_mosaic_basic(self.blank_output, models) @@ -201,6 +204,9 @@ def __init__( for i, m in enumerate(models): self.input_models.shelve(m, i, modify=False) + if not self.all_have_var_flat: + del self.blank_output._instance["var_flat"] + def do_drizzle(self): """Pick the correct drizzling mode based on ``self.single``.""" if self.single: @@ -355,6 +361,9 @@ def resample_many_to_one(self): ) log.info("Resampling science data") + + all_have_var_flat = True + with self.input_models: for i, img in enumerate(self.input_models): inwht = resample_utils.build_driz_weight( @@ -396,22 +405,17 @@ def resample_many_to_one(self): # Resample variances array in self.input_models to output_model self.resample_variance_array("var_rnoise", output_model) self.resample_variance_array("var_poisson", output_model) - self.resample_variance_array("var_flat", output_model) + if self.all_have_var_flat: + self.resample_variance_array("var_flat", output_model) # Make exposure time image exptime_tot = self.resample_exposure_time(output_model) - # TODO: fix unit here - output_model.err = np.sqrt( - np.nansum( - [ - output_model.var_rnoise, - output_model.var_poisson, - output_model.var_flat, - ], - axis=0, - ) - ) + all_vars = [output_model.var_rnoise, output_model.var_poisson] + if self.all_have_var_flat: + all_vars = all_vars + [output_model.var_flat] + + output_model.err = np.sqrt(np.nansum(all_vars, axis=0)) self.update_exposure_times(output_model, exptime_tot) diff --git a/romancal/resample/tests/test_resample.py b/romancal/resample/tests/test_resample.py index ec8a27b69..ca29c47ea 100644 --- a/romancal/resample/tests/test_resample.py +++ b/romancal/resample/tests/test_resample.py @@ -585,6 +585,28 @@ def test_update_exposure_times_same_sca_different_exposures(exposure_1, exposure output_models.shelve(output_model, 0, modify=False) +@pytest.mark.parametrize("include_var_flat", [False, True]) +def test_var_flat_presence(exposure_1, include_var_flat): + """Test that var_flat is included or excluded depending on its presence in the underlying exposures.""" + if not include_var_flat: + exposure_1 = [e.copy() for e in exposure_1] + for e in exposure_1: + del e._instance["var_flat"] + input_models = ModelLibrary(exposure_1) + resample_data = ResampleData(input_models) + + output_models = resample_data.resample_many_to_one() + with output_models: + output_model = output_models.borrow(0) + + if not include_var_flat: + assert not hasattr(output_model, "var_flat") + else: + assert hasattr(output_model, "var_flat") + + output_models.shelve(output_model, 0, modify=False) + + @pytest.mark.parametrize( "name", ["var_rnoise", "var_poisson", "var_flat"],