Skip to content

Commit

Permalink
Update pipeline for post-load update hook
Browse files Browse the repository at this point in the history
  • Loading branch information
diehlbw committed Jul 24, 2024
1 parent b52256f commit eeced43
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 4 deletions.
1 change: 1 addition & 0 deletions changelog/56.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Updates the loading pipeline to have a hook for updating the loaded dataframe. Not accessible via run_startup(), requires direct initialization of Seismogram.
7 changes: 6 additions & 1 deletion src/seismometer/data/loader/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,18 @@
from .pipeline import ConfigFrameHook, ConfigOnlyHook, MergeFramesHook, SeismogramLoader


def loader_factory(config: ConfigProvider) -> SeismogramLoader:
def loader_factory(config: ConfigProvider, post_load_fn: ConfigFrameHook = None) -> SeismogramLoader:
"""
Construct a SeismogramLoader from the provided configuration.
Parameters
----------
config : ConfigProvider
The loaded configuration object
post_load_fn : ConfigFrameHook, optional
A callable taking a ConfigProvider and the fully loaded dataframe and returning a dataframe.
Used to allow any custom manipulations of the Seismogram dataframe after all other load steps complete.
WARNING: This can completly overwrite/discard the daframe that was loaded.
Returns
-------
Expand All @@ -29,4 +33,5 @@ def loader_factory(config: ConfigProvider) -> SeismogramLoader:
event_fn=event_loader,
post_event_fn=event.post_transform_fn,
merge_fn=event.merge_onto_predictions,
post_load_fn=post_load_fn,
)
20 changes: 20 additions & 0 deletions src/seismometer/data/loader/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class SeismogramLoader:
* load events [ConfigOnlyHook]
* transform (type) the loaded events [ConfigFrameHook]
* merge events onto predictions [MergeFramesHook]
* post load manipulations [ConfigFrameHook]
Each step is expected to return a dataframe, chaining the steps to get the frame driving a loaded Seismogram.
"""
Expand All @@ -65,6 +66,7 @@ def __init__(
post_predict_fn: Optional[ConfigFrameHook] = None,
post_event_fn: Optional[ConfigFrameHook] = None,
merge_fn: Optional[MergeFramesHook] = None,
post_load_fn: Optional[ConfigFrameHook] = None,
):
"""
Initialize a data loading pipeline of functions returning a dataframe for a Seismogram session.
Expand All @@ -90,6 +92,10 @@ def __init__(
A callable taking a ConfigProvider, a (events) dataframe, and a (predictions) dataframe
and returning a dataframe.
Used to merge events onto predictions based on configuration.
post_load_fn : ConfigFrameHook, optional
A callable taking a ConfigProvider and the fully loaded dataframe and returning a dataframe.
Used to allow any custom manipulations of the Seismogram dataframe during load.
WARNING: This can completly overwrite/discard the daframe that was loaded.
"""
self.config = config

Expand All @@ -103,6 +109,9 @@ def __init__(
self.prediction_from_memory: ConfigFrameHook = _passthru_framehook
self.event_from_memory: ConfigFrameHook = _passthru_framehook

# Hooks for custom transformations
self.post_load_fn: ConfigFrameHook = post_load_fn or _passthru_framehook

def load_data(self, prediction_obj: pd.DataFrame = None, event_obj: pd.DataFrame = None) -> pd.DataFrame:
"""
Entry point for loading data for a Seismogram session.
Expand All @@ -128,6 +137,9 @@ def load_data(self, prediction_obj: pd.DataFrame = None, event_obj: pd.DataFrame
dataframe = self.post_predict_fn(self.config, dataframe)
dataframe = self._add_events(dataframe, event_obj)

dataframe = self._add_custom_columns(dataframe)

dataframe = self.post_load_fn(self.config, dataframe)
return dataframe

def _load_predictions(self, prediction_obj: pd.DataFrame = None):
Expand Down Expand Up @@ -161,5 +173,13 @@ def _load_events(self, event_obj: pd.DataFrame = None) -> pd.DataFrame:
return self.event_fn(self.config)
return self.event_from_memory(self.config, event_obj)

def _add_custom_columns(self, dataframe: pd.DataFrame) -> pd.DataFrame:
"""
Add custom columns to the dataframe based on configuration.
NotImplemented -- currently a pass-through.
"""
return dataframe


__all__ = ["SeismogramLoader", "ConfigOnlyHook", "ConfigFrameHook", "MergeFramesHook"]
9 changes: 6 additions & 3 deletions src/seismometer/seismogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from seismometer.core.patterns import Singleton
from seismometer.data import pandas_helpers as pdh
from seismometer.data import resolve_cohorts
from seismometer.data.loader import loader_factory
from seismometer.data.loader import SeismogramLoader, loader_factory
from seismometer.report.alerting import AlertConfigProvider

MAXIMUM_NUM_COHORTS = 25
Expand Down Expand Up @@ -53,6 +53,7 @@ def __init__(
config_path: Optional[str | Path] = None,
output_path: Optional[str | Path] = None,
definitions: Optional[dict] = None,
loader: SeismogramLoader = None,
):
"""
Constructor for Seismogram, which can only be instantiated once.
Expand All @@ -68,7 +69,9 @@ def __init__(
Defaults to the config.yml info_dir, and then the notebook's output directory.
definitions : dict, optional
Additional definitions to be used instead of loading based on configuration, by default None.
loader : SeismogramLoader, optional
A loader instance for defining the data loading pipeline, by default None.
If not provided, uses factory to instantiate the loader based on configuration.
"""
if config_path is None:
config_path = Path.cwd() / "data"
Expand All @@ -83,7 +86,7 @@ def __init__(

self.config.set_output(output_path)
self.config.output_dir.mkdir(parents=True, exist_ok=True)
self.dataloader = loader_factory(self.config)
self.dataloader = loader or loader_factory(self.config)

def load_data(
self, *, predictions: Optional[pd.DataFrame] = None, events: Optional[pd.DataFrame] = None, reset: bool = False
Expand Down

0 comments on commit eeced43

Please sign in to comment.