Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
diehlbw committed Oct 1, 2024
1 parent 0379d7e commit 2a1ab77
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 12 deletions.
12 changes: 9 additions & 3 deletions src/seismometer/configuration/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,10 @@ def get(self, key: str, default: Optional[Any] = None) -> Union[DictionaryItem,
-------
The DictionaryItem with name specified or the default value
"""
return self[key] or default
try:
return self[key]
except KeyError:
return default


class EventDictionary(BaseModel):
Expand Down Expand Up @@ -160,7 +163,10 @@ def get(self, key: str, default: Optional[Any] = None) -> Union[DictionaryItem,
-------
The DictionaryItem with name specified or the default value
"""
return self[key] or default
try:
return self[key]
except KeyError:
return default


class Cohort(BaseModel):
Expand Down Expand Up @@ -426,4 +432,4 @@ def _search_dictionary(dictionary: list[DictionaryItem], key: str) -> Optional[D
for item in dictionary:
if item.name == key:
return item
return
raise KeyError(f"{key} not found")
6 changes: 1 addition & 5 deletions src/seismometer/data/loader/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,11 +161,7 @@ def assumed_types(config: ConfigProvider, dataframe: pd.DataFrame) -> pd.DataFra
# other
def _gather_defined_types(config: ConfigProvider) -> dict[str, str]:
"""Gathers the defined types from the configuration dictionary."""
return {
defn.name: defn.dtype
for defn in config.prediction_defs.predictions
if getattr(defn, "dtype", None) is not None
}
return {defn.name: defn.dtype for defn in config.prediction_defs.predictions if defn.dtype is not None}


def _infer_datetime(dataframe, cols=None, override_categories=None):
Expand Down
4 changes: 2 additions & 2 deletions src/seismometer/data/pandas_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,9 @@ def _one_event(
if event_base_val_dtype is not None:
try:
one_event[event_base_val_col] = one_event[event_base_val_col].astype(event_base_val_dtype)
except ValueError as exc:
except (ValueError, TypeError) as exc:
raise ConfigurationError(
f"Cannot cast '{event_label}' values to {event_base_val_dtype}. "
f"Cannot cast '{event_label}' values to '{event_base_val_dtype}'. "
+ "Update dictionary config or contact the model owner."
) from exc

Expand Down
4 changes: 2 additions & 2 deletions tests/configuration/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def test_multiple_events(self):
assert expected == actual

@pytest.mark.parametrize("search_key,expected_key", [("evA", "filled"), ("evB", "given"), ("evC", "empty")])
def test_search_existing_returns_item(self, search_key, expected_key):
def test_search_returns_item(self, search_key, expected_key):
inputs = [
{"name": "evA", "display_name": "event_A"},
{"name": "evB", "display_name": "event_B", "dtype": "int", "definition": "a definition"},
Expand Down Expand Up @@ -111,7 +111,7 @@ def test_multiple_predictions(self):
assert expected == actual

@pytest.mark.parametrize("search_key,expected_key", [("evA", "filled"), ("evB", "given"), ("evC", "empty")])
def test_search_existing_returns_item(self, search_key, expected_key):
def test_search_returns_item(self, search_key, expected_key):
inputs = [
{"name": "evA", "display_name": "event_A"},
{"name": "evB", "display_name": "event_B", "dtype": "int", "definition": "a definition"},
Expand Down

0 comments on commit 2a1ab77

Please sign in to comment.