From fc44950c447b623dd43c043637fc1d49999ee48a Mon Sep 17 00:00:00 2001 From: Hoang-Anh Ngo Date: Tue, 31 Oct 2023 16:21:55 +0700 Subject: [PATCH] Refactor sad.py file. --- river/anomaly/sad.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/river/anomaly/sad.py b/river/anomaly/sad.py index f85ae8463e..959719a589 100644 --- a/river/anomaly/sad.py +++ b/river/anomaly/sad.py @@ -21,7 +21,7 @@ class StandardAbsoluteDeviation(anomaly.base.AnomalyDetector): Parameters ---------- sub_stat - The statistic to be substracted, then divided by the standard deviation for scoring. + The statistic to be subtracted, then divided by the standard deviation for scoring. This parameter must be either "mean" or "median". References @@ -58,29 +58,32 @@ class StandardAbsoluteDeviation(anomaly.base.AnomalyDetector): """ - def __init__(self, subtracted_statistic: str = "mean"): - if subtracted_statistic == "mean": - self.subtracted_statistic = stats.Mean() - elif subtracted_statistic == "median": - self.subtracted_statistic = stats.Quantile(q=0.5) + def __init__(self, sub_stat: str = "mean"): + self.variance = stats.Var() + self.sub_stat = sub_stat + + if self.sub_stat == "mean": + self.subtracted_statistic_estimator = stats.Mean() + elif self.sub_stat == "median": + self.subtracted_statistic_estimator = stats.Quantile(q=0.5) else: raise ValueError( - f"Unknown subtracted statistic {subtracted_statistic}, expected one of median, mean." + f"Unknown subtracted statistic {self.sub_stat}, expected one of median, mean." ) - self.variance = stats.Var() - def learn_one(self, x): assert len(x) == 1 ((x_key, x_value),) = x.items() self.variance.update(x_value) - self.subtracted_statistic.update(x_value) + self.subtracted_statistic_estimator.update(x_value) def score_one(self, x): assert len(x) == 1 ((x_key, x_value),) = x.items() - score = (x_value - self.subtracted_statistic.get()) / (self.variance.get() ** 0.5 + 1e-10) + score = (x_value - self.subtracted_statistic_estimator.get()) / ( + self.variance.get() ** 0.5 + 1e-10 + ) return abs(score)