Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add profiling to most reduction steps #149

Merged
merged 4 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 50 additions & 32 deletions python/lvmdrp/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
get_analog_groups, match_master_metadata, create_master_path,
update_summary_file)
from lvmdrp.utils.convert import tileid_grp
from lvmdrp.utils.timer import Timer

from lvmdrp import config, log, path, __version__ as drpver

Expand Down Expand Up @@ -1559,14 +1560,16 @@ def reduce_2d(mjd, calibrations, expnums=None, exptime=None, cameras=CAMERAS,
if skip_done and os.path.isfile(dframe_path):
log.info(f"skipping {dframe_path}, file already exist")
else:
preproc_raw_frame(in_image=frame_path, out_image=pframe_path,
in_mask=mpixmask_path, replace_with_nan=replace_with_nan, assume_imagetyp=assume_imagetyp)
detrend_frame(in_image=pframe_path, out_image=dframe_path,
in_bias=mbias_path,
in_pixelflat=mpixflat_path,
replace_with_nan=replace_with_nan,
reject_cr=reject_cr,
in_slitmap=fibermap if imagetyp in {"flat", "arc", "object"} else None)
with Timer(name='Preproc '+pframe_path, logger=log.info):
preproc_raw_frame(in_image=frame_path, out_image=pframe_path,
in_mask=mpixmask_path, replace_with_nan=replace_with_nan, assume_imagetyp=assume_imagetyp)
with Timer(name='Detrend '+dframe_path, logger=log.info):
detrend_frame(in_image=pframe_path, out_image=dframe_path,
in_bias=mbias_path,
in_pixelflat=mpixflat_path,
replace_with_nan=replace_with_nan,
reject_cr=reject_cr,
in_slitmap=fibermap if imagetyp in {"flat", "arc", "object"} else None)


def science_reduction(expnum: int, use_longterm_cals: bool = False,
Expand Down Expand Up @@ -1635,7 +1638,8 @@ def science_reduction(expnum: int, use_longterm_cals: bool = False,
if skip_2d:
log.info("skipping 2D reduction")
else:
reduce_2d(mjd=sci_mjd, calibrations=calibs, expnums=[sci_expnum], reject_cr=reject_cr, skip_done=False)
with Timer(name='Reduce2d', logger=log.info):
reduce_2d(mjd=sci_mjd, calibrations=calibs, expnums=[sci_expnum], reject_cr=reject_cr, skip_done=False)

# run reduction loop for each science camera exposure
if skip_1d:
Expand Down Expand Up @@ -1670,17 +1674,20 @@ def science_reduction(expnum: int, use_longterm_cals: bool = False,
mmodel_path = calibs["model"][sci_camera]

# add astrometry to frame
add_astrometry(in_image=dsci_path, out_image=dsci_path, in_agcsci_image=agcsci_path, in_agcskye_image=agcskye_path, in_agcskyw_image=agcskyw_path)
with Timer(name='Astrometry '+dsci_path, logger=log.info):
add_astrometry(in_image=dsci_path, out_image=dsci_path, in_agcsci_image=agcsci_path, in_agcskye_image=agcskye_path, in_agcskyw_image=agcskyw_path)

# subtract straylight
subtract_straylight(in_image=dsci_path, out_image=lsci_path, out_stray=lstr_path,
in_cent_trace=mtrace_path, select_nrows=(5,5), use_weights=True,
aperture=15, smoothing=400, median_box=101, gaussian_sigma=20.0,
parallel=parallel_run)
with Timer(name='Straylight '+lsci_path, logger=log.info):
subtract_straylight(in_image=dsci_path, out_image=lsci_path, out_stray=lstr_path,
in_cent_trace=mtrace_path, select_nrows=(5,5), use_weights=True,
aperture=15, smoothing=400, median_box=101, gaussian_sigma=20.0,
parallel=parallel_run)

# extract 1d spectra
extract_spectra(in_image=lsci_path, out_rss=xsci_path, in_trace=mtrace_path, in_fwhm=mwidth_path,
in_model=mmodel_path, method=extraction_method, parallel=parallel_run)
with Timer(name='Extract '+xsci_path, logger=log.info):
extract_spectra(in_image=lsci_path, out_rss=xsci_path, in_trace=mtrace_path, in_fwhm=mwidth_path,
in_model=mmodel_path, method=extraction_method, parallel=parallel_run)

# per channel reduction
cframe_path = path.full("lvm_frame", drpver=drpver, tileid=sci_tileid, mjd=sci_mjd, expnum=sci_expnum, kind='CFrame')
Expand Down Expand Up @@ -1708,52 +1715,63 @@ def science_reduction(expnum: int, use_longterm_cals: bool = False,
kind='h', camera=channel, imagetype=sci_imagetyp, expnum=expnum)

# stack spectrographs
stack_spectrographs(in_rsss=xsci_paths, out_rss=xsci_path)
with Timer(name='Stack Spectrographs '+xsci_path, logger=log.info):
stack_spectrographs(in_rsss=xsci_paths, out_rss=xsci_path)
if not os.path.exists(xsci_path):
log.error(f'No stacked file found: {xsci_path}. Skipping remaining pipeline.')
continue

# wavelength calibrate
create_pixel_table(in_rss=xsci_path, out_rss=wsci_path, in_waves=mwave_paths, in_lsfs=mlsf_paths)
with Timer(name='Wavelengths '+wsci_path, logger=log.info):
create_pixel_table(in_rss=xsci_path, out_rss=wsci_path, in_waves=mwave_paths, in_lsfs=mlsf_paths)

# apply fiberflat correction
apply_fiberflat(in_rss=wsci_path, out_frame=frame_path, in_flat=mflat_path)
with Timer(name='Fiberflat '+frame_path, logger=log.info):
apply_fiberflat(in_rss=wsci_path, out_frame=frame_path, in_flat=mflat_path)

# correct thermal shift in wavelength direction
shift_wave_skylines(in_frame=frame_path, out_frame=frame_path)
with Timer(name='Thermal Shifts '+frame_path, logger=log.info):
shift_wave_skylines(in_frame=frame_path, out_frame=frame_path)

# interpolate sky fibers
interpolate_sky(in_frame=frame_path, out_rss=ssci_path)
with Timer(name='Interpolate Sky '+ssci_path, logger=log.info):
interpolate_sky(in_frame=frame_path, out_rss=ssci_path)

# combine sky telescopes
combine_skies(in_rss=ssci_path, out_rss=ssci_path, sky_weights=sky_weights)
with Timer(name='Combine Sky '+ssci_path, logger=log.info):
combine_skies(in_rss=ssci_path, out_rss=ssci_path, sky_weights=sky_weights)

# resample wavelength into uniform grid along fiber IDs for science and sky fibers
resample_wavelength(in_rss=ssci_path, out_rss=hsci_path, wave_range=SPEC_CHANNELS[channel], wave_disp=0.5, convert_to_density=True)
with Timer(name='Resample '+hsci_path, logger=log.info):
resample_wavelength(in_rss=ssci_path, out_rss=hsci_path, wave_range=SPEC_CHANNELS[channel], wave_disp=0.5, convert_to_density=True)

# use resampled frames for flux calibration in each camera, using standard stars observed in the spec telescope
# and field stars found in the sci ifu
fluxcal_standard_stars(hsci_path, GAIA_CACHE_DIR=MASTERS_DIR+'/gaia_cache')
fluxcal_sci_ifu_stars(hsci_path, GAIA_CACHE_DIR=MASTERS_DIR+'/gaia_cache')
with Timer(name='Fluxcal '+hsci_path, logger=log.info):
fluxcal_standard_stars(hsci_path, GAIA_CACHE_DIR=MASTERS_DIR+'/gaia_cache')
fluxcal_sci_ifu_stars(hsci_path, GAIA_CACHE_DIR=MASTERS_DIR+'/gaia_cache')

# flux-calibrate each channel
fframe_path = path.full("lvm_frame", mjd=sci_mjd, drpver=drpver, tileid=sci_tileid, expnum=sci_expnum, kind=f'FFrame-{channel}')
apply_fluxcal(in_rss=hsci_path, out_fframe=fframe_path, method=fluxcal_method)
# flux-calibrate each channel
fframe_path = path.full("lvm_frame", mjd=sci_mjd, drpver=drpver, tileid=sci_tileid, expnum=sci_expnum, kind=f'FFrame-{channel}')
apply_fluxcal(in_rss=hsci_path, out_fframe=fframe_path, method=fluxcal_method)

# stitch channels
fframe_paths = sorted(path.expand('lvm_frame', mjd=sci_mjd, tileid=sci_tileid, drpver=drpver, kind='FFrame-?', expnum=sci_expnum))
if len(fframe_paths) == 0:
log.error('No fframe files found. Cannot join spectrograph channels. Exiting pipeline.')
return

join_spec_channels(in_fframes=fframe_paths, out_cframe=cframe_path, use_weights=True)
with Timer(name='Join Channels '+cframe_path, logger=log.info):
join_spec_channels(in_fframes=fframe_paths, out_cframe=cframe_path, use_weights=True)

# sky subtraction
quick_sky_subtraction(in_cframe=cframe_path, out_sframe=sframe_path, skip_subtraction=skip_sky_subtraction)
with Timer(name='QSky '+sframe_path, logger=log.info):
quick_sky_subtraction(in_cframe=cframe_path, out_sframe=sframe_path, skip_subtraction=skip_sky_subtraction)

# update the drpall summary file
log.info('Updating the drpall summary file')
update_summary_file(sframe_path, tileid=sci_tileid, mjd=sci_mjd, expnum=sci_expnum, master_mjd=cals_mjd)
with Timer(name='DRPAll '+sframe_path, logger=log.info):
log.info('Updating the drpall summary file')
update_summary_file(sframe_path, tileid=sci_tileid, mjd=sci_mjd, expnum=sci_expnum, master_mjd=cals_mjd)

# clean ancillary folder
if clean_ancillary:
Expand Down
75 changes: 75 additions & 0 deletions python/lvmdrp/utils/timer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import time
from contextlib import ContextDecorator
from dataclasses import dataclass, field
from typing import Any, Callable, ClassVar, Dict, Optional

class TimerError(Exception):
"""A custom exception used to report errors in use of Timer class"""

@dataclass
class Timer(ContextDecorator):
"""Time your code using a class, context manager, or decorator
Use as
Timer t
t.start()
code()
t.stop()

or
@Timer()
def func(...)

or
with Timer(name):
code()

see https://realpython.com/python-timer/
"""

timers: ClassVar[Dict[str, float]] = {}
name: Optional[str] = None
text: str = "elapsed time: {:0.4f} seconds"
logger: Optional[Callable[[str], None]] = print
_start_time: Optional[float] = field(default=None, init=False, repr=False)

def __post_init__(self) -> None:
"""Initialization: add timer to dict of timers"""
if self.name:
self.timers.setdefault(self.name, 0)

def start(self) -> None:
"""Start a new timer"""
if self._start_time is not None:
raise TimerError("Timer is running. Use .stop() to stop it")

self._start_time = time.perf_counter()

def stop(self) -> float:
"""Stop the timer, and report the elapsed time"""
if self._start_time is None:
raise TimerError("Timer is not running. Use .start() to start it")

# Calculate elapsed time
elapsed_time = time.perf_counter() - self._start_time
self._start_time = None

# Report elapsed time
if self.logger:
if self.name is not None:
self.logger(self.name + ': ' + self.text.format(elapsed_time))
else:
self.text.format(elapsed_time)
if self.name:
self.timers[self.name] += elapsed_time

return elapsed_time

def __enter__(self) -> "Timer":
"""Start a new timer as a context manager"""
self.start()
return self

def __exit__(self, *exc_info: Any) -> None:
"""Stop the context manager timer"""
self.stop()

Loading