Skip to content

Commit

Permalink
fix unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
gbowlin committed Oct 8, 2024
1 parent c521e66 commit 2e20f1d
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions src/seismometer/data/performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __call__(self, dataframe: pd.DataFrame, **kwargs) -> dict[str, float]:

@export
class BinaryClassifierMetricGenerator(MetricGenerator):
def __init__(self, rho: float = DEFAULT_RHO):
def __init__(self, rho: float = None):
"""
A class that generates Binary classifier metrics from a dataframe.
Keeps track of available metric names as well as the function to call to generate them.
Expand All @@ -76,6 +76,7 @@ def __init__(self, rho: float = DEFAULT_RHO):
rho : float, optional
The relative risk reduction for NNT calculation, by default DEFAULT_RHO.
"""
rho = rho or DEFAULT_RHO
metric_names = STATNAMES + [f"NNT@{rho:0.3n}"]

metric_fn = partial(calculate_binary_stats, rho=rho)
Expand Down Expand Up @@ -147,7 +148,7 @@ def calculate_binary_stats(
target_col: str,
score_col: str,
score_threshold: float = 0.5,
rho: float = DEFAULT_RHO,
rho: float = None,
) -> dict[str, float]:
"""
Generates binary classifier metrics from a dataframe, as a specific threshold
Expand All @@ -165,6 +166,7 @@ def calculate_binary_stats(
rho : float, optional
The relative risk reduction for NNT calculation, by default DEFAULT_RHO.
"""
rho = rho or DEFAULT_RHO
score_threshold_integer = int(score_threshold * 100)
y_true = dataframe[target_col]
y_pred = dataframe[score_col]
Expand All @@ -178,7 +180,7 @@ def calculate_bin_stats(
y_pred: Optional[pd.Series] = None,
keep_score_values: bool = False,
not_point_thresholds: bool = False,
rho: float = DEFAULT_RHO,
rho: float = None,
) -> pd.DataFrame:
"""
Calculate summary statistics from y_true and y_pred (y_proba[:,1] for binary classification) arrays.
Expand All @@ -201,6 +203,7 @@ def calculate_bin_stats(
-------
pd.DataFrame of stats, rows for each threshold value between 0 and 100 with columns for basic statistics.
"""
rho = rho or DEFAULT_RHO
y_true = y_true.astype(float) # Expect numeric labels

keep = ~(np.isnan(y_true) | np.isnan(y_pred))
Expand Down Expand Up @@ -310,8 +313,7 @@ def calculate_nnt(arr: np.ndarray, rho: Optional[Number | None] = None) -> np.nd
.. [#med_metrics] eotles/med_metrics: Initial public release. Zenodo; 2024.
http://dx.doi.org/10.5281/ZENODO.10514448
"""
if rho is None:
rho = DEFAULT_RHO
rho = rho or DEFAULT_RHO

# Divide by zero is ok
with np.errstate(invalid="ignore", divide="ignore"):
Expand Down

0 comments on commit 2e20f1d

Please sign in to comment.