Skip to content

Commit

Permalink
Add residual image as additional optional output
Browse files Browse the repository at this point in the history
  • Loading branch information
melanieclarke committed Jan 15, 2025
1 parent c5c8501 commit 45a85a1
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 62 deletions.
4 changes: 4 additions & 0 deletions docs/jwst/extract_1d/arguments.rst
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,10 @@ Step Arguments for Slit and Slitless Spectroscopic Data
Flag to enable saving a model of the 2D flux as defined by the extraction aperture or PSF model.
If True, the model is saved to disk with suffix "scene_model".

``--save_residual_image``
Flag to enable saving the residual image (from the input minus the scene model)
If True, the model is saved to disk with suffix "residual".

Step Arguments for IFU Data
---------------------------

Expand Down
78 changes: 53 additions & 25 deletions jwst/extract_1d/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -1212,30 +1212,25 @@ def extract_one_slit(data_model, integration, profile, bg_profile,
The input science model. May be a single slit from a MultiSlitModel
(or similar), or a single data type, like an ImageModel, SlitModel,
or CubeModel.
integration : int
For the case that data_model is a SlitModel or a CubeModel,
`integration` is the integration number. If the integration number is
not relevant (i.e. the data array is 2-D), `integration` should be -1.
profile : ndarray of float
Spatial profile indicating the aperture location. Must be a
2D image matching the input, with floating point values between 0
and 1 assigning a weight to each pixel. 0 means the pixel is not used,
1 means the pixel is fully included in the aperture.
bg_profile : ndarray of float or None
Background profile indicating any background regions to use, following
the same format as the spatial profile. Ignored if
extract_params['subtract_background'] is False.
nod_profile : ndarray of float or None
For optimal extraction, if nod subtraction was performed, a
second spatial profile is generated, modeling the negative source
in the slit. This second spatial profile may be passed in `nod_profile`
for simultaneous fitting with the primary source in `profile`.
Otherwise, `nod_profile` should be None.
extract_params : dict
Parameters read from the extract1d reference file, as returned by
`get_extract_parameters`.
Expand All @@ -1252,44 +1247,36 @@ def extract_one_slit(data_model, integration, profile, bg_profile,
point source (column "flux"). Divide `sum_flux` by `npixels` (to
compute the average) to get the array for the "surf_bright"
(surface brightness) output column.
f_var_rnoise : ndarray, 1-D
The extracted read noise variance values to go along with the
sum_flux array.
f_var_poisson : ndarray, 1-D
The extracted poisson variance values to go along with the
sum_flux array.
f_var_flat : ndarray, 1-D
The extracted flat field variance values to go along with the
sum_flux array.
background : ndarray, 1-D
The background count rate that was subtracted from the sum of
the source data values to get `sum_flux`.
b_var_rnoise : ndarray, 1-D
The extracted read noise variance values to go along with the
background array.
b_var_poisson : ndarray, 1-D
The extracted poisson variance values to go along with the
background array.
b_var_flat : ndarray, 1-D
The extracted flat field variance values to go along with the
background array.
npixels : ndarray, 1-D, float64
The number of pixels that were added together to get `sum_flux`,
including any fractional pixels included via non-integer weights
in the input profile.
scene_model : ndarray, 2-D, float64
A 2D model of the flux in the spectral image, corresponding to
the extracted aperture.
residual : ndarray, 2-D, float64
Residual image from the input minus the scene model.
"""
# Get the data and variance arrays
if integration > -1:
Expand Down Expand Up @@ -1352,17 +1339,21 @@ def extract_one_slit(data_model, integration, profile, bg_profile,
# of the number of input profiles. It may need to be transposed to match
# the input data.
scene_model = result[-1]
residual = data - scene_model
if extract_params['dispaxis'] == HORIZONTAL:
first_result.append(scene_model)
first_result.append(residual)
else:
first_result.append(scene_model.T)
first_result.append(residual.T)
return first_result


def create_extraction(input_model, slit, output_model,
extract_ref_dict, slitname, sp_order, exp_type,
apcorr_ref_model=None, log_increment=50,
save_profile=False, save_scene_model=False, **kwargs):
save_profile=False, save_scene_model=False,
save_residual_image=False, **kwargs):
"""Extract spectra from an input model and append to an output model.
Input data, specified in the `slit` or `input_model`, should contain data
Expand Down Expand Up @@ -1436,6 +1427,9 @@ def create_extraction(input_model, slit, output_model,
save_scene_model : bool, optional
If True, the flux model created during extraction will be returned
as an ImageModel or CubeModel. If False, the return value is None.
save_residual_image : bool, optional
If True, the residual image (from input minus scene model) will be returned
as an ImageModel or CubeModel. If False, the return value is None.
kwargs : dict, optional
Additional options to pass to `get_extract_parameters`.
Expand All @@ -1449,6 +1443,9 @@ def create_extraction(input_model, slit, output_model,
If `save_scene_model` is True, the return value is an ImageModel or CubeModel
matching the input data, containing the flux model generated during
extraction.
residual : ImageModel, CubeModel, or None
If `save_residual_image` is True, the return value is an ImageModel or CubeModel
matching the input data, containing the residual image.
"""

if slit is None:
Expand Down Expand Up @@ -1587,7 +1584,7 @@ def create_extraction(input_model, slit, output_model,
integrations = range(shape[0])
progress_msg_printed = False

# Set up a flux model to update if desired
# Set up a scene model and residual image to update if desired
if save_scene_model:
if len(integrations) > 1:
scene_model = datamodels.CubeModel(shape)
Expand All @@ -1597,20 +1594,34 @@ def create_extraction(input_model, slit, output_model,
scene_model.name = slitname
else:
scene_model = None
if save_residual_image:
if len(integrations) > 1:
residual = datamodels.CubeModel(shape)
else:
residual = datamodels.ImageModel()
residual.update(input_model, only='PRIMARY')
residual.name = slitname
else:
residual = None

# Extract each integration
for integ in integrations:
(sum_flux, f_var_rnoise, f_var_poisson,
f_var_flat, background, b_var_rnoise, b_var_poisson,
b_var_flat, npixels, scene_model_2d) = extract_one_slit(
b_var_flat, npixels, scene_model_2d, residual_2d) = extract_one_slit(
data_model, integ, profile, bg_profile, nod_profile, extract_params)

# Save the flux model
# Save the scene model and residual
if save_scene_model:
if isinstance(scene_model, datamodels.CubeModel):
scene_model.data[integ] = scene_model_2d
else:
scene_model.data = scene_model_2d
if save_residual_image:
if isinstance(residual, datamodels.CubeModel):
residual.data[integ] = residual_2d
else:
residual.data = residual_2d

# Convert the sum to an average, for surface brightness.
npixels_temp = np.where(npixels > 0., npixels, 1.)
Expand Down Expand Up @@ -1768,7 +1779,7 @@ def create_extraction(input_model, slit, output_model,
if not progress_msg_printed:
log.info(f"All {input_model.data.shape[0]} integrations done")

return profile_model, scene_model
return profile_model, scene_model, residual


def run_extract1d(input_model, extract_ref_name="N/A", apcorr_ref_name=None,
Expand All @@ -1777,7 +1788,8 @@ def run_extract1d(input_model, extract_ref_name="N/A", apcorr_ref_name=None,
log_increment=50, subtract_background=None,
use_source_posn=None, position_offset=0.0,
model_nod_pair=False, optimize_psf_location=True,
save_profile=False, save_scene_model=False):
save_profile=False, save_scene_model=False,
save_residual_image=False):
"""Extract all 1-D spectra from an input model.
Parameters
Expand Down Expand Up @@ -1837,6 +1849,10 @@ def run_extract1d(input_model, extract_ref_name="N/A", apcorr_ref_name=None,
If True, a model of the 2D flux as defined by the extraction aperture
is returned as an ImageModel or CubeModel. If False, the return value
is None.
save_residual_image : bool
If True, the residual image (from the input minus the scene model)
is returned as an ImageModel or CubeModel. If False, the return value
is None.
Returns
-------
Expand All @@ -1851,6 +1867,10 @@ def run_extract1d(input_model, extract_ref_name="N/A", apcorr_ref_name=None,
If `save_scene_model` is True, the return value is an ImageModel or CubeModel
matching the input data, containing a model of the flux as defined by the
aperture, created during extraction. Otherwise, the return value is None.
residual : ModelContainer, ImageModel, CubeModel, or None
If `save_residual_image` is True, the return value is an ImageModel or CubeModel
matching the input data, containing the residual image (from the input minus
the scene model). Otherwise, the return value is None.
"""
# Set "meta_source" to either the first model in a container,
# or the individual input model, for convenience
Expand Down Expand Up @@ -1895,6 +1915,7 @@ def run_extract1d(input_model, extract_ref_name="N/A", apcorr_ref_name=None,
# Handle inputs that contain one or more slit models
profile_model = None
scene_model = None
residual = None
if isinstance(input_model, (ModelContainer, datamodels.MultiSlitModel)):
if isinstance(input_model, ModelContainer):
slits = input_model
Expand All @@ -1910,6 +1931,8 @@ def run_extract1d(input_model, extract_ref_name="N/A", apcorr_ref_name=None,
profile_model = ModelContainer()
if save_scene_model:
scene_model = ModelContainer()
if save_residual_image:
residual = ModelContainer()

for slit in slits: # Loop over the slits in the input model
log.info(f'Working on slit {slit.name}')
Expand All @@ -1928,11 +1951,12 @@ def run_extract1d(input_model, extract_ref_name="N/A", apcorr_ref_name=None,
continue

try:
profile, slit_scene_model = create_extraction(
profile, slit_scene_model, slit_residual = create_extraction(
meta_source, slit, output_model,
extract_ref_dict, slitname, sp_order, exp_type,
apcorr_ref_model=apcorr_ref_model, log_increment=log_increment,
save_profile=save_profile, save_scene_model=save_scene_model,
save_residual_image=save_residual_image,
psf_ref_name=psf_ref_name,
extraction_type=extraction_type,
smoothing_length=smoothing_length,
Expand All @@ -1949,6 +1973,8 @@ def run_extract1d(input_model, extract_ref_name="N/A", apcorr_ref_name=None,
profile_model.append(profile)
if save_scene_model:
scene_model.append(slit_scene_model)
if save_residual_image:
residual.append(slit_residual)

else:
# Define source of metadata
Expand All @@ -1969,11 +1995,12 @@ def run_extract1d(input_model, extract_ref_name="N/A", apcorr_ref_name=None,
else:
log.info(f'Processing spectral order {sp_order}')
try:
profile_model, scene_model = create_extraction(
profile_model, scene_model, residual = create_extraction(
input_model, slit, output_model,
extract_ref_dict, slitname, sp_order, exp_type,
apcorr_ref_model=apcorr_ref_model, log_increment=log_increment,
save_profile=save_profile, save_scene_model=save_scene_model,
save_residual_image=save_residual_image,
psf_ref_name=psf_ref_name,
extraction_type=extraction_type,
smoothing_length=smoothing_length,
Expand Down Expand Up @@ -2012,11 +2039,12 @@ def run_extract1d(input_model, extract_ref_name="N/A", apcorr_ref_name=None,
log.info(f'Processing spectral order {sp_order}')

try:
profile_model, scene_model = create_extraction(
profile_model, scene_model, residual = create_extraction(
input_model, slit, output_model,
extract_ref_dict, slitname, sp_order, exp_type,
apcorr_ref_model=apcorr_ref_model, log_increment=log_increment,
save_profile=save_profile, save_scene_model=save_scene_model,
save_residual_image=save_residual_image,
psf_ref_name=psf_ref_name,
extraction_type=extraction_type,
smoothing_length=smoothing_length,
Expand Down Expand Up @@ -2051,4 +2079,4 @@ def run_extract1d(input_model, extract_ref_name="N/A", apcorr_ref_name=None,
# x1d product just to hold this keyword.
output_model.meta.target.source_type = None

return output_model, profile_model, scene_model
return output_model, profile_model, scene_model, residual
13 changes: 12 additions & 1 deletion jwst/extract_1d/extract_1d_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ class Extract1dStep(Step):
If True, a model of the 2D flux as defined by the extraction aperture
is saved to disk. Ignored for IFU and NIRISS SOSS extractions.
save_residual_image : bool
If True, the residual image (from the input minus the scene model)
is saved to disk. Ignored for IFU and NIRISS SOSS extractions.
center_xy : int or None
A list of 2 pixel coordinate values at which to place the center
of the IFU extraction aperture, overriding any centering done by the step.
Expand Down Expand Up @@ -193,6 +197,7 @@ class Extract1dStep(Step):
log_increment = integer(default=50) # increment for multi-integration log messages
save_profile = boolean(default=False) # save spatial profile to disk
save_scene_model = boolean(default=False) # save flux model to disk
save_residual_image = boolean(default=False) # save residual image to disk
center_xy = float_list(min=2, max=2, default=None) # IFU extraction x/y center
ifu_autocen = boolean(default=False) # Auto source centering for IFU point source data.
Expand Down Expand Up @@ -448,12 +453,13 @@ def process(self, input):

profile = None
scene_model = None
residual = None
if isinstance(model, datamodels.IFUCubeModel):
# Call the IFU specific extraction routine
extracted = self._extract_ifu(model, exp_type, extract_ref, apcorr_ref)
else:
# Call the general extraction routine
extracted, profile, scene_model = extract.run_extract1d(
extracted, profile, scene_model, residual = extract.run_extract1d(
model,
extract_ref,
apcorr_ref,
Expand All @@ -470,6 +476,7 @@ def process(self, input):
self.optimize_psf_location,
self.save_profile,
self.save_scene_model,
self.save_residual_image,
)

# Set the step flag to complete in each model
Expand All @@ -489,6 +496,10 @@ def process(self, input):
if self.save_scene_model and scene_model is not None:
self._save_intermediate(scene_model, 'scene_model', idx)

# Save residual if needed
if self.save_residual_image and residual is not None:
self._save_intermediate(residual, 'residual', idx)

# If only one result, return the model instead of the container
if len(result) == 1:
result = result[0]
Expand Down
Loading

0 comments on commit 45a85a1

Please sign in to comment.