From 857807a1f4f054bbae49f8e7b00764e0e935a275 Mon Sep 17 00:00:00 2001 From: diehlbw Date: Tue, 1 Oct 2024 13:58:46 +0000 Subject: [PATCH] address comments --- src/seismometer/configuration/model.py | 12 +++++++++--- src/seismometer/data/loader/prediction.py | 6 +----- src/seismometer/data/pandas_helpers.py | 4 ++-- tests/configuration/test_model.py | 4 ++-- 4 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/seismometer/configuration/model.py b/src/seismometer/configuration/model.py index 118017d..493d4f9 100644 --- a/src/seismometer/configuration/model.py +++ b/src/seismometer/configuration/model.py @@ -113,7 +113,10 @@ def get(self, key: str, default: Optional[Any] = None) -> Union[DictionaryItem, ------- The DictionaryItem with name specified or the default value """ - return self[key] or default + try: + return self[key] + except KeyError: + return default class EventDictionary(BaseModel): @@ -160,7 +163,10 @@ def get(self, key: str, default: Optional[Any] = None) -> Union[DictionaryItem, ------- The DictionaryItem with name specified or the default value """ - return self[key] or default + try: + return self[key] + except KeyError: + return default class Cohort(BaseModel): @@ -426,4 +432,4 @@ def _search_dictionary(dictionary: list[DictionaryItem], key: str) -> Optional[D for item in dictionary: if item.name == key: return item - return + raise KeyError(f"{key} not found") diff --git a/src/seismometer/data/loader/prediction.py b/src/seismometer/data/loader/prediction.py index b4ce9f6..b746627 100644 --- a/src/seismometer/data/loader/prediction.py +++ b/src/seismometer/data/loader/prediction.py @@ -161,11 +161,7 @@ def assumed_types(config: ConfigProvider, dataframe: pd.DataFrame) -> pd.DataFra # other def _gather_defined_types(config: ConfigProvider) -> dict[str, str]: """Gathers the defined types from the configuration dictionary.""" - return { - defn.name: defn.dtype - for defn in config.prediction_defs.predictions - if getattr(defn, "dtype", None) is not None - } + return {defn.name: defn.dtype for defn in config.prediction_defs.predictions if defn.dtype is not None} def _infer_datetime(dataframe, cols=None, override_categories=None): diff --git a/src/seismometer/data/pandas_helpers.py b/src/seismometer/data/pandas_helpers.py index 4b84820..20ed66a 100644 --- a/src/seismometer/data/pandas_helpers.py +++ b/src/seismometer/data/pandas_helpers.py @@ -180,9 +180,9 @@ def _one_event( if event_base_val_dtype is not None: try: one_event[event_base_val_col] = one_event[event_base_val_col].astype(event_base_val_dtype) - except ValueError as exc: + except (ValueError, TypeError) as exc: raise ConfigurationError( - f"Cannot cast '{event_label}' values to {event_base_val_dtype}. " + f"Cannot cast '{event_label}' values to '{event_base_val_dtype}'. " + "Update dictionary config or contact the model owner." ) from exc diff --git a/tests/configuration/test_model.py b/tests/configuration/test_model.py index 9a393ed..297f043 100644 --- a/tests/configuration/test_model.py +++ b/tests/configuration/test_model.py @@ -55,7 +55,7 @@ def test_multiple_events(self): assert expected == actual @pytest.mark.parametrize("search_key,expected_key", [("evA", "filled"), ("evB", "given"), ("evC", "empty")]) - def test_search_existing_returns_item(self, search_key, expected_key): + def test_search_returns_item(self, search_key, expected_key): inputs = [ {"name": "evA", "display_name": "event_A"}, {"name": "evB", "display_name": "event_B", "dtype": "int", "definition": "a definition"}, @@ -111,7 +111,7 @@ def test_multiple_predictions(self): assert expected == actual @pytest.mark.parametrize("search_key,expected_key", [("evA", "filled"), ("evB", "given"), ("evC", "empty")]) - def test_search_existing_returns_item(self, search_key, expected_key): + def test_search_returns_item(self, search_key, expected_key): inputs = [ {"name": "evA", "display_name": "event_A"}, {"name": "evB", "display_name": "event_B", "dtype": "int", "definition": "a definition"},