Skip to content

Commit

Permalink
Expose memory loader (#34)
Browse files Browse the repository at this point in the history
  • Loading branch information
diehlbw authored Jul 2, 2024
1 parent 1b2867a commit 2ea1219
Show file tree
Hide file tree
Showing 16 changed files with 135 additions and 44 deletions.
1 change: 1 addition & 0 deletions changelog/34.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
* seismometer.run_startup() can now accept preloaded prediction and event dataframes that take precedence over loading from configuration
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[metadata]
name = seismometer
version = attr: seismometer.__version__
version = attr: seismometer._version.__version__
description = seismometer: Data Science visualization and investigation tools for AI Trust & Assurance
author = Epic
author_email = OpenSourceContributions-Python@epic.com
Expand Down
31 changes: 28 additions & 3 deletions src/seismometer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,23 @@
import logging
from pathlib import Path
from typing import Optional

import pandas as pd

from seismometer._version import __version__
from seismometer.core.logger import add_log_formatter, set_default_logger_config


def run_startup(*, config_path: str | Path = None, output_path: str | Path = None, log_level: int = logging.WARN):
def run_startup(
*,
config_path: str | Path = None,
output_path: str | Path = None,
predictions_frame: Optional[pd.DataFrame] = None,
events_frame: Optional[pd.DataFrame] = None,
definitions: Optional[dict] = None,
log_level: int = logging.WARN,
reset: bool = False,
):
"""
Runs the required startup for instantiating seismometer.
Expand All @@ -16,8 +28,19 @@ def run_startup(*, config_path: str | Path = None, output_path: str | Path = Non
output_path : Optional[str | Path], optional
An output path to write data to, overwriting the default path specified by info_dir in config.yml,
by default None.
predictions_frame : Optional[pd.DataFrame], optional
An optional DataFrame containing the fully loaded predictions data, by default None.
By default, when not specified here, these data will be loaded based on configuration.
events_frame : Optional[pd.DataFrame], optional
An optional DataFrame containing the fully loaded events data, by default None.
By default, when not specified here, these data will be loaded based on configuration.
definitions : Optional[dict], optional
A dictionary of definitions to use instead of loading those specified by configuration, by default None.
By default, when not specified here, these data will be loaded based on configuration.
log_level : logging._Level, optional
The log level to set. by default, logging.WARN.
reset : bool, optional
A flag that, when True, will reset the Seismogram instance before loading configuration and data, by default False.
"""
import importlib

Expand All @@ -31,8 +54,10 @@ def run_startup(*, config_path: str | Path = None, output_path: str | Path = Non
logger.setLevel(log_level)
logger.info(f"seismometer version {__version__} starting")

sg = Seismogram(config_path, output_path)
sg.load_data()
if reset:
Seismogram.kill()
sg = Seismogram(config_path, output_path, definitions=definitions)
sg.load_data(predictions=predictions_frame, events=events_frame)

# Surface api into namespace
s_module = importlib.import_module("seismometer._api")
Expand Down
7 changes: 7 additions & 0 deletions src/seismometer/configuration/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class ConfigProvider:
Specifies the template notebook name to use during building, by default None; it uses "template" from the
primary config file.
This is the template that will be used as a base for building the final notebook.
definitions : Optional[dict], optional
A dictionary of definitions to use instead of loading those specified by configuration, by default None.
"""

Expand All @@ -47,6 +49,7 @@ def __init__(
info_dir: str | Path = None,
data_dir: str | Path = None,
template_notebook: Option = None,
definitions: dict = None,
):
self._config: OtherInfo = None
self._usage: DataUsage = None
Expand All @@ -55,6 +58,10 @@ def __init__(
self._output_dir: Path = None
self._output_notebook: str = ""

if definitions is not None:
self._prediction_defs = PredictionDictionary(predictions=definitions.pop("predictions", []))
self._event_defs = EventDictionary(events=definitions.pop("events", None))

self._load_config_config(config_config)
self._resolve_other_paths(usage_config, info_dir, data_dir)
self._override_template(template_notebook)
Expand Down
6 changes: 4 additions & 2 deletions src/seismometer/core/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

Pathlike = str | Path

logger = logging.getLogger("seismometer")


def slugify(value: str) -> str:
"""
Expand Down Expand Up @@ -108,7 +110,7 @@ def resolve_filename(
# Create pre-emptively
if not basedir.is_dir():
if not create:
logging.warning(f"No directory found for group: {basedir}")
logger.warning(f"No directory found for group: {basedir}")
else:
basedir.mkdir(parents=True, exist_ok=True)
return basedir / filename
Expand Down Expand Up @@ -317,4 +319,4 @@ def _write(writer: Callable[[Any, "fileobject"], None], content: Any, file: Path

with open(file, "w") as fo:
writer(content, fo)
logging.info(f"File written: {file.resolve()}")
logger.info(f"File written: {file.resolve()}")
22 changes: 16 additions & 6 deletions src/seismometer/core/logger.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime
import logging
from typing import Optional


def set_default_logger_config() -> None:
Expand All @@ -9,13 +10,19 @@ def set_default_logger_config() -> None:
logging.basicConfig()


def remove_default_log_handler() -> None:
def remove_default_handler(logger: Optional[logging.Logger] = None) -> None:
"""
Removes the default logging handlers.
Removes the default logging handler.
Parameters
----------
logger : Optional[Logger], optional
Descriptor of the logger to modify, by default None.
When None, the root logger is modified.
"""
root_log = logging.getLogger()
while root_log.hasHandlers():
root_log.removeHandler(root_log.handlers[0])
logger = logger or logging.getLogger()
while logger.hasHandlers():
logger.removeHandler(logger.handlers[0])


def add_log_formatter(logger: logging.Logger):
Expand All @@ -27,7 +34,10 @@ def add_log_formatter(logger: logging.Logger):
logger : logging.Logger
The logger to add formatting to.
"""
remove_default_log_handler()
# Remove root-handler / default
remove_default_handler()
# Remove default handler for seismometer - make safe to call multiple times
remove_default_handler(logger)

handler = logging.StreamHandler()
formatter = TimeFormatter()
Expand Down
1 change: 1 addition & 0 deletions src/seismometer/data/loader/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def _log_column_mismatch(actual_columns: list[str], desired_columns: list[str],
"""Logs warnings if the actual columns and desired columns are a mismatch."""
if len(actual_columns) == len(desired_columns):
return

logger.warning(
"Not all requested columns are present. " + f"Missing columns are {', '.join(desired_columns-present_columns)}"
)
Expand Down
54 changes: 47 additions & 7 deletions src/seismometer/seismogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@ class Seismogram(object, metaclass=Singleton):
"""

# entity_keys: list[str] = ['Entity Id', 'Entity Dat']

entity_keys: list[str]
""" The one or two columns used as identifiers for data. """
predict_time: str
Expand All @@ -50,7 +48,12 @@ class Seismogram(object, metaclass=Singleton):
output_list: list[str]
""" The list of columns representing model outputs."""

def __init__(self, config_path: Optional[str | Path] = None, output_path: Optional[str | Path] = None):
def __init__(
self,
config_path: Optional[str | Path] = None,
output_path: Optional[str | Path] = None,
definitions: Optional[dict] = None,
):
"""
Constructor for Seismogram, which can only be instantiated once.
Expand All @@ -63,23 +66,49 @@ def __init__(self, config_path: Optional[str | Path] = None, output_path: Option
output_path : str or Path, optional
Override location to place resulting data and report files.
Defaults to the config.yml info_dir, and then the notebook's output directory.
definitions : dict, optional
Additional definitions to be used instead of loading based on configuration, by default None.
"""
if config_path is None:
config_path = Path.cwd() / "data"
else:
config_path = Path(config_path)

self.dataframe: pd.DataFrame = None
self.cohort_cols: list[str] = []
self.config_path = config_path

self.load_config(config_path)
self.load_config(config_path, definitions=definitions)

self.config.set_output(output_path)
self.config.output_dir.mkdir(parents=True, exist_ok=True)
self.dataloader = loader_factory(self.config)

def load_data(self, predictions=None, events=None):
def load_data(
self, *, predictions: Optional[pd.DataFrame] = None, events: Optional[pd.DataFrame] = None, reset: bool = False
):
"""
Loads the seismogram data.
Uses the passed in frames if they are specified, otherwise uses configuration to load data.
If data is already loaded, does not change state unless reset is true.
Parameters
----------
predictions : pd.DataFrame, optional
The fully prepared predictions dataframe, by default None.
Uses this when specified, otherwise loads based on configuration.
events : pd.DataFrame, optional
The pre-loaded events dataframe, by default None.
Uses this when specified, otherwise loads based on configuration.
reset : bool, optional
Flag when set to true will overwrite existing dataframe, by default False
"""
if self.dataframe is not None and not reset:
logger.debug("Data already loaded; pass reset=True to clear data and re-evaluate.")
return

self._load_metadata()

self.dataframe = self.dataloader.load_data(predictions, events)
Expand Down Expand Up @@ -273,8 +302,19 @@ def score_bins(self):
# endregion

# region initialization and preprocessing (this region knows about config)
def load_config(self, config_path: Path):
self.config = ConfigProvider(config_path)
def load_config(self, config_path: Path, definitions: Optional[dict] = None):
"""
Loads the base configuration and alerting configuration.
Parameters
----------
config_path : Path
The location of the main configuration file.
definitions : Optional[dict], optional
An optional dictionary containing both events and predictions lists, by default None.
If not passed, these will be loaded based on configuration.
"""
self.config = ConfigProvider(config_path, definitions=definitions)
self.alert_config = AlertConfigProvider(config_path)

if len(self.config.cohorts) == 0:
Expand Down
8 changes: 4 additions & 4 deletions tests/configuration/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def test_reduce_events_to_unique_names(self, caplog):
undertest.Event(source="event2", display_name="Event 2"),
undertest.Event(source="different source", display_name="Event 1"),
]
with caplog.at_level(logging.WARNING):
with caplog.at_level(logging.WARNING, logger="seismometer"):
data_usage = undertest.DataUsage(events=events)
assert "Duplicate" in caplog.text

Expand All @@ -93,7 +93,7 @@ def test_reduce_events_eliminates_source_display_collision(self, caplog):
undertest.Event(source="event1"),
undertest.Event(source="event2", display_name="event1"),
]
with caplog.at_level(logging.WARNING):
with caplog.at_level(logging.WARNING, logger="seismometer"):
data_usage = undertest.DataUsage(events=events)
assert "Duplicate" in caplog.text

Expand All @@ -107,7 +107,7 @@ def test_reduce_cohorts_to_unique_names(self, caplog):
undertest.Cohort(source="cohort2", display_name="Cohort 2"),
undertest.Cohort(source="different source", display_name="Cohort 1"),
]
with caplog.at_level(logging.WARNING):
with caplog.at_level(logging.WARNING, logger="seismometer"):
data_usage = undertest.DataUsage(cohorts=cohorts)
assert "Duplicate" in caplog.text

Expand All @@ -123,7 +123,7 @@ def test_reduce_cohorts_eliminates_source_display_collision(self, caplog):
undertest.Cohort(source="cohort2", display_name="cohort1"),
]

with caplog.at_level(logging.WARNING):
with caplog.at_level(logging.WARNING, logger="seismometer"):
data_usage = undertest.DataUsage(cohorts=cohorts)
assert "Duplicate" in caplog.text

Expand Down
23 changes: 14 additions & 9 deletions tests/core/test_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,14 @@
from seismometer.core.decorators import DiskCachedFunction, export


def foo(arg1, kwarg1=None):
if kwarg1 is None:
return arg1 + 1
else:
return kwarg1 + 1
def get_test_function():
    """Return a freshly created helper function for export tests.

    Building the function inside this factory means each call yields a
    distinct function object (so module-level mutation, e.g. by ``export``,
    does not leak between tests). The inner name ``foo`` is intentional:
    callers assert on the function's ``__name__``.
    """

    def foo(arg1, kwarg1=None):
        # Prefer the keyword value when supplied; fall back to the positional.
        base = arg1 if kwarg1 is None else kwarg1
        return base + 1

    return foo


class Test_Export:
Expand All @@ -26,14 +29,15 @@ def test_export(self):
global __all__
__all__ = []

new_fn = export(foo)
test_fn = get_test_function()
new_fn = export(test_fn)

assert new_fn(1) == 2
assert new_fn(1, 5) == 6
assert __all__ == ["foo"]

with pytest.raises(ImportError):
export(foo)
export(test_fn)

def test_mod_none(self):
"""
Expand All @@ -45,8 +49,9 @@ def test_mod_none(self):
global __all__
__all__ = []

foo.__module__ = None
new_fn = export(foo)
test_fn = get_test_function()
test_fn.__module__ = None
new_fn = export(test_fn)

assert new_fn(1) == 2
assert new_fn(1, 5) == 6
Expand Down
8 changes: 4 additions & 4 deletions tests/core/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def test_write_new_file_in_nonexistent_directory(tmp_as_current):
def test_write_logs_file_written(tmp_as_current, caplog):
file = Path("new_file.txt")

with caplog.at_level(logging.INFO):
with caplog.at_level(logging.INFO, logger="seismometer"):
undertest._write(lambda content, fo: fo.write(content), "test content", file, overwrite=False)

assert f"File written: {file.resolve()}" in caplog.text
Expand Down Expand Up @@ -227,7 +227,7 @@ def test_create_makesdir_on_none(self, tmp_path, caplog):
filename = "new_file"
expected = Path("output/attr") / "age1_0-3_0-self+strip" / filename

with caplog.at_level(30):
with caplog.at_level(30, logger="seismometer"):
_ = undertest.resolve_filename("new_file", attribute, subgroups)

assert not caplog.text
Expand All @@ -239,7 +239,7 @@ def test_no_create_warns_on_nonexistent(self, caplog):
filename = "new_file"
expected = Path("output/attr") / "age1_0-3_0-self+strip" / filename

with caplog.at_level(30):
with caplog.at_level(30, logger="seismometer"):
_ = undertest.resolve_filename("new_file", attribute, subgroups, create=False)

assert "No directory" in caplog.text
Expand All @@ -252,7 +252,7 @@ def test_no_create_existent_does_not_warn(self, caplog):
expected = Path("output/attr") / "gg" / filename

expected.parent.mkdir(parents=True, exist_ok=False)
with caplog.at_level(30):
with caplog.at_level(30, logger="seismometer"):
_ = undertest.resolve_filename("new_file", attribute, subgroups, create=False)

assert not caplog.text
Expand Down
Loading

0 comments on commit 2ea1219

Please sign in to comment.