Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes error propagation #200

Merged
merged 5 commits into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions python/lvmdrp/core/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -3126,10 +3126,6 @@ def reject_cosmics(self, sigma_det=5, rlim=1.2, iterations=5, fwhm_gauss=[2.0,2.

if inplace:
self._data = out._data
if self._error is None:
self._error = out._error
else:
self._error += out._error
if self._mask is None:
self._mask = out._mask
else:
Expand Down
69 changes: 66 additions & 3 deletions python/lvmdrp/core/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,13 +311,76 @@ def plot_detrend(ori_image, det_image, axs, mbias=None, mdark=None, labels=False

# add labels if requested
if labels:
fig = axs[0].get_figure()
fig.supxlabel(f"counts ({unit})")
fig.supylabel("#")
axs[2].set_xlabel(f"counts ({unit})")
axs[3].set_xlabel(f"Counts ({unit})")
axs[0].set_ylabel("#")
axs[2].set_ylabel("#")

return axs


def plot_error(frame, axs, counts_threshold=(1000, 20000), ref_value=1.0, labels=False):
"""Create plot to validate Poisson error propagation

It takes the given frame data and compares sqrt(data) / error to a given
reference value. Optionally a 3-tuple of quantiles can be given for the
reference value.

Parameters
----------
frame : lvmdrp.core.image.Image|lvmdrp.core.rss.RSS
2D or RSS frame containing data and error attributes
axs : plt.Axes
Axes where to make the plots
counts_threshold : tuple[int], optional
levels of counts above/below which the Poisson statistic holds, by default (1000, 20000)
ref_value : float|tuple[float], optional
Reference value(s) expected for the sqrt(data) / error ratio, by default 1.0
labels : bool, optional
Whether to add titles or not to the axes, by default False
"""

unit = frame._header["BUNIT"]

if isinstance(ref_value, (float, int)):
mu = ref_value
sig1 = sig2 = None
elif isinstance(ref_value, (tuple, list, np.ndarray)) and len(ref_value) == 3:
sig1, mu, sig2 = sorted(ref_value)
else:
raise ValueError(f"Wrong value for {ref_value = }, expected `float` or `3-tuple` for percentile levels")

data = frame._data.copy()
error = frame._error.copy()

pcut = (data >= counts_threshold[0])&(data<=counts_threshold[1])
data[~pcut] = np.nan
error[~pcut] = np.nan

n_pixels = pcut.sum()
median_ratio = np.nanmedian(np.sqrt(np.nanmedian(data, axis=0))/np.nanmedian(error, axis=0))

xs = data[pcut]
ys = np.sqrt(xs) / error[pcut]

axs[0].plot(xs, ys, ".", ms=4, color="tab:blue")
axs[0].axhline(mu, ls="--", lw=1, color="0.2")
axs[1].hist(ys, color="tab:blue", bins=500, range=(mu*0.9, mu*1.1), orientation="horizontal")
if sig1 is not None and sig2 is not None:
axs[0].axhspan(sig1, sig2, lw=0, color="0.2", alpha=0.2)
axs[1].axhspan(sig1, sig2, lw=0, color="0.2", alpha=0.2)
axs[1].axhline(mu, ls="--", lw=1, color="0.2")

axs[0].set_ylim(mu*0.9, mu*1.1)
axs[1].set_ylim(mu*0.9, mu*1.1)

if labels:
axs[0].set_title(f"{n_pixels = } | {median_ratio = :.2f} | {mu = :.2f}", loc="left")
axs[0].set_xlabel(f"Counts ({unit})")
axs[1].set_xlabel("#")
axs[0].set_ylabel(r"$\sqrt{\mathrm{Counts}} / \mathrm{Error}$")


def plot_wavesol_residuals(fiber, ref_waves, lines_pixels, poly_cls, coeffs, ax=None, labels=False):
"""Plot residuals in wavelength polynomial fitting

Expand Down
26 changes: 12 additions & 14 deletions python/lvmdrp/core/resample.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import numpy
import numpy
from scipy.signal import correlate
from scipy.signal.windows import tukey
from scipy import interpolate
Expand Down Expand Up @@ -169,9 +169,9 @@ def resample_flux_density(xout, x, flux, ivar=None, extrapolate=False):
- ivar: weights for flux; default is unweighted resampling
- extrapolate: extrapolate using edge values of input array, default is False,
in which case values outside of input array are set to zero.

Setting both ivar and extrapolate raises a ValueError because one cannot
assign an ivar outside of the input data range.
assign an ivar outside of the input data range.

Returns:
if ivar is None, returns outflux
Expand All @@ -190,10 +190,8 @@ def resample_flux_density(xout, x, flux, ivar=None, extrapolate=False):
mask = (b>0)
outflux = numpy.zeros(a.shape)
outflux[mask] = a[mask] / b[mask]
dx = numpy.gradient(x)
dxout = numpy.gradient(xout)
outivar = _unweighted_resample(xout, x, ivar/dx)*dxout

outivar = _unweighted_resample(xout, x, ivar)

return outflux, outivar


Expand All @@ -211,7 +209,7 @@ def resample_flux(xout, x, flux, extrapolate=False):
Options:
- extrapolate: extrapolate using edge values of input array, default is False,
in which case values outside of input array are set to zero.

Returns:
returns outflux

Expand All @@ -234,7 +232,7 @@ def _unweighted_resample(output_x, input_x, input_flux_density, extrapolate=Fals

both must represent the same quantity with the same unit
input_flux_density = dflux/dx sampled at input_x

Options:
extrapolate: extrapolate using edge values of input array, default is False,
in which case values outside of input array are set to zero
Expand Down Expand Up @@ -289,7 +287,7 @@ def interpolate_mask(x, y, mask, kind="linear", fill_value=0):
bins[1:-1] = (ox[:-1]+ox[1:])/2.
bins[0] = 1.5*ox[0]-0.5*ox[1] # = ox[0]-(ox[1]-ox[0])/2
bins[-1] = 1.5*ox[-1]-0.5*ox[-2] # = ox[-1]+(ox[-1]-ox[-2])/2

# make a temporary node array including input nodes and output bin bounds
# first the boundaries of output bins
tx = bins.copy()
Expand All @@ -309,18 +307,18 @@ def interpolate_mask(x, y, mask, kind="linear", fill_value=0):
# this sets values left and right of input range to first and/or last input values
# first and last values are = 0 if we are not extrapolating
ty = numpy.interp(tx,ix,iy)

# add input nodes which are inside the node array
k = numpy.where((ix >= tx[0])&(ix <= tx[-1]))[0]
if k.size :
tx = numpy.append(tx,ix[k])
ty = numpy.append(ty,iy[k])

# sort this node array
p = tx.argsort()
tx = tx[p]
ty = ty[p]

# now we do a simple integration in each bin of the piece-wise
# linear function of the temporary nodes

Expand All @@ -335,5 +333,5 @@ def interpolate_mask(x, y, mask, kind="linear", fill_value=0):

if numpy.any(binsize<=0) :
raise ValueError("Zero or negative bin size")

return numpy.histogram(trapeze_centers, bins=bins, weights=trapeze_integrals)[0] / binsize
142 changes: 95 additions & 47 deletions python/lvmdrp/functions/imageMethod.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
glueImages,
loadImage,
)
from lvmdrp.core.plot import plt, create_subplots, plot_detrend, plot_strips, plot_image_shift, plot_fiber_thermal_shift, save_fig
from lvmdrp.core.plot import plt, create_subplots, plot_detrend, plot_error, plot_strips, plot_image_shift, plot_fiber_thermal_shift, save_fig
from lvmdrp.core.rss import RSS
from lvmdrp.core.spectrum1d import Spectrum1D, _spec_from_lines, _cross_match
from lvmdrp.core.tracemask import TraceMask
Expand Down Expand Up @@ -2747,6 +2747,16 @@ def extract_spectra(
rss.add_header_comment(f"{in_model}, fiber model used for {camera}")
rss.add_header_comment(f"{in_acorr}, fiber aperture correction used for {camera}")

# create error propagation plot
fig = plt.figure(figsize=(15, 5), layout="constrained")
gs = GridSpec(1, 14, figure=fig)

ax_1 = fig.add_subplot(gs[0, :-4])
ax_2 = fig.add_subplot(gs[0, -4:])

plot_error(frame=rss, axs=[ax_1, ax_2], counts_threshold=(3000, 60000), labels=True)
save_fig(fig, product_path=out_rss, to_display=display_plots, figure_path="qa", label="extracted_error")

# save extracted RSS
log.info(f"writing extracted spectra to {os.path.basename(out_rss)}")
rss.writeFitsData(out_rss)
Expand Down Expand Up @@ -2975,52 +2985,84 @@ def reprojectRSS_drp(
rep.writeto(f"{out_path}/{out_name}_2d.fits", overwrite=True)


def testres_drp(image, trace, fwhm, flux):
"""
Historic task used for debugging of the the extraction routine...
def validate_extraction(in_image, in_cent, in_width, in_rss, plot_columns=[1000, 2000, 3000], display_plots=False):
"""Evaluates the extracted flux into the original 2D pixel grid

This routine will evaluate the extracted flux in the original
2D grid and compare the resulting 2D model against the original
2D image. Three images are stored as outputs:

- The residual 2D image: model - data
- The ratio 2D image: model / data
- The 2D model

Parameters
----------
in_image : str
Path to the original 2D image of the extracted flux
in_cent : str
Path to the fiber centroids trace
in_width : str
Path to the fiber width (FWHM) trace
in_rss : str
Path to the extracted flux in RSS format
plot_columns : array_like, optional
columns to show in plot, by default [1000, 2000, 3000]
display_plots : bool, optional
whether to display plots to screen or not, by dafult False
"""
log.info(f"loading 2D image {in_image}")
img = Image()
# t1 = time.time()
img.loadFitsData(image, extension_data=0)
trace_mask = TraceMask()
trace_mask.loadFitsData(trace, extension_data=0)
trace_fwhm = TraceMask()
# trace_fwhm.setData(data=numpy.ones(trace_mask._data.shape)*2.5)
trace_fwhm.loadFitsData(fwhm, extension_data=0)
img._data = numpy.nan_to_num(img._data)
img.loadFitsData(in_image)

log.info(f"loading fiber parameters in {in_cent} and {in_width}")
cent = TraceMask.from_file(in_cent)
width = TraceMask.from_file(in_width)
width._data /= 2.354

trace_flux = TraceMask()
trace_flux.loadFitsData(flux, extension_data=0)
log.info(f"loading extracted flux in {in_rss}")
rss = RSS.from_file(in_rss)
rss._data = numpy.nan_to_num(rss._data)

ypix_cor = rss._slitmap[["spectrographid"] == int(img._header["CCD"][1])]["ypix_z"]
ypix_ori = img._slitmap[["spectrographid"] == int(img._header["CCD"][1])]["ypix_z"]
thermal_shift = ypix_cor - ypix_ori
log.info(f"fiber thermal shift in slitmap: {thermal_shift:.4f}")
cent._data += thermal_shift

log.info(f"evaluating extracted flux into 2D pixel grid for {img._dim[1]} columns")
x = numpy.arange(img._dim[0])
out = numpy.zeros(img._dim)
fact = numpy.sqrt(2.0 * numpy.pi)

fig, axs = create_subplots(to_display=display_plots, nrows=len(plot_columns), ncols=1, figsize=(15,5), sharex=True, layout="constrained")
for i in range(img._dim[1]):
# print i
A = (
1.0
* numpy.exp(
-0.5
* (
(x[:, numpy.newaxis] - trace_mask._data[:, i][numpy.newaxis, :])
/ abs(trace_fwhm._data[:, i][numpy.newaxis, :] / 2.354)
)
** 2
)
/ (fact * abs(trace_fwhm._data[:, i][numpy.newaxis, :] / 2.354))
)
spec = numpy.dot(A, trace_flux._data[:, i])
A = (numpy.exp(-0.5 * ((x[:, None] - cent._data[:, i][None, :]) / abs(width._data[:, i][None, :])) ** 2) / (fact * abs(width._data[:, i][None, :])))
spec = numpy.dot(A, rss._data[:, i])
out[:, i] = spec
if i == 1000:
plt.plot(spec, "-r")
plt.plot(img._data[:, i], "ok")
plt.show()

if i in plot_columns:
axs[plot_columns.index(i)].step(x, img._data[:, i], color="k", lw=1, where="mid")
axs[plot_columns.index(i)].step(x, spec, color="r", lw=1, where="mid")

out_path = os.path.dirname(in_image)
out_name = os.path.basename(in_image).split(".fits")[0]
out_residual = os.path.join(out_path, f"{out_name}_residual.fits")
out_2dimage = os.path.join(out_path, f"{out_name}_2dimage.fits")
out_ratio = os.path.join(out_path, f"{out_name}_ratio.fits")
save_fig(fig, product_path=out_2dimage, to_display=display_plots, figure_path="qa", label="2D_extracted_model")

log.info(f"writing residual to {out_residual}")
hdu = pyfits.PrimaryHDU(img._data - out)
hdu.writeto("res.fits", overwrite=True)
hdu = pyfits.PrimaryHDU(out)
hdu.writeto("fit.fits", overwrite=True)
hdu.writeto(out_residual, overwrite=True)

hdu = pyfits.PrimaryHDU((img._data - out) / img._data)
hdu.writeto("res_rel.fits", overwrite=True)
log.info(f"writing ratio to {out_ratio}")
hdu = pyfits.PrimaryHDU(out / img._data)
hdu.writeto(out_ratio, overwrite=True)

log.info(f"writing 2D model {out_2dimage}")
hdu = pyfits.PrimaryHDU(out)
hdu.writeto(out_2dimage, overwrite=True)


# TODO: for arcs take short exposures for bright lines & long exposures for faint lines
Expand Down Expand Up @@ -3841,16 +3883,22 @@ def detrend_frame(
# show plots
log.info("plotting results")
# detrending process
fig, axs = create_subplots(
to_display=display_plots,
nrows=2,
ncols=2,
figsize=(15, 15),
sharex=True,
sharey=True,
)
plt.subplots_adjust(wspace=0.15, hspace=0.1)
plot_detrend(ori_image=org_img, det_image=detrended_img, axs=axs, mbias=mbias_img, mdark=mdark_img, labels=True)
fig = plt.figure(figsize=(15, 10), layout="constrained")
gs = GridSpec(3, 14, figure=fig)

ax1 = fig.add_subplot(gs[0, :7])
ax2 = fig.add_subplot(gs[0, 7:], sharex=ax1, sharey=ax1)
ax3 = fig.add_subplot(gs[1, :7], sharex=ax1, sharey=ax1)
ax4 = fig.add_subplot(gs[1, 7:], sharex=ax1, sharey=ax1)
ax1.tick_params(labelbottom=False)
ax2.tick_params(labelbottom=False)
ax2.tick_params(labelleft=False)
ax4.tick_params(labelleft=False)
ax_1 = fig.add_subplot(gs[2, :-4])
ax_2 = fig.add_subplot(gs[2, -4:], sharey=ax_1)
plot_detrend(ori_image=org_img, det_image=detrended_img, axs=[ax1, ax2, ax3, ax4], mbias=mbias_img, mdark=mdark_img, labels=True)
# Poisson error
plot_error(frame=detrended_img, axs=[ax_1, ax_2], labels=True)
save_fig(
fig,
product_path=out_image,
Expand Down
Loading
Loading