From a7c7d9d9a62ccc184405709c562440bc1a7bb4ff Mon Sep 17 00:00:00 2001
From: MaximCarbonell
Date: Tue, 26 Dec 2023 13:32:57 -0800
Subject: [PATCH] Add optional row filtering by column to compile()

compile() accepts two new optional arguments, filter_column and
filter_values. When both are given, df_current and df_baseline are
reduced to the rows whose filter_column value is listed in
filter_values before the rest of compile() runs.

---
Review notes (everything between the "---" line and the first
"diff --git" line is ignored by `git am`):

* The leftover debug statement print("FILTERING") was replaced by a
  comment; it ran on every compile() call, even without a filter.
* filter_column / filter_values are inserted before ignore_cols, so any
  caller passing ignore_cols (or later arguments) positionally is
  shifted by two. Consider moving both to the end of the signature.
* This patch adds no entries for the new parameters to the compile()
  docstring.
* The filter overwrites self.df_current / self.df_baseline, so the
  unfiltered frames are no longer reachable through those attributes.
* The line structure, indentation and blank context lines of this patch
  were reconstructed from a whitespace-collapsed copy, and the blob ids
  on the "index" lines predate the edit above. Run `git apply --check`
  before applying.

 README.md                  | 9 +++++++++
 eurybia/core/smartdrift.py | 6 ++++++
 2 files changed, 15 insertions(+)

diff --git a/README.md b/README.md
index 2ca09a5..47a1bc8 100644
--- a/README.md
+++ b/README.md
@@ -82,6 +82,15 @@ sd.compile(
     full_validation=True, # Optional: to save time, leave the default False value. If True, analyze consistency on modalities between columns.
     date_compile_auc="01/01/2022", # Optional: useful when computing the drift for a time that is not now
     datadrift_file="datadrift_auc.csv", # Optional: name of the csv file that contains the performance history of data drift
+    filter_column="name", # Optional: name of the column to filter on
+    filter_values=[
+        "France",
+        "Ottomans",
+        "Austria",
+        "Poland",
+        "Brandenburg",
+        "Bohemia",
+    ], # Optional: values of filter_column to keep; rows with any other value are dropped from both datasets
 )
 ```
 
diff --git a/eurybia/core/smartdrift.py b/eurybia/core/smartdrift.py
index ec21f75..d79ebef 100644
--- a/eurybia/core/smartdrift.py
+++ b/eurybia/core/smartdrift.py
@@ -198,6 +198,8 @@ def __init__(
     def compile(
         self,
         full_validation=False,
+        filter_column=None,
+        filter_values=None,
         ignore_cols: list = list(),
         sampling=True,
         sample_size=100000,
@@ -236,6 +238,10 @@ def compile(
 
         >>> SD.compile()
         """
+        # Keep only the rows whose filter_column value is listed in filter_values.
+        if filter_column and filter_values:
+            self.df_current = self.df_current[self.df_current[filter_column].isin(filter_values)]
+            self.df_baseline = self.df_baseline[self.df_baseline[filter_column].isin(filter_values)]
         if datadrift_file is not None:
             self.datadrift_file = datadrift_file
         if hyperparameter is not None: