Skip to content

Commit

Permalink
Merge pull request #32 from vincent-laurent/dev/data-checks
Browse files Browse the repository at this point in the history
[FIX] Dev/data checks
  • Loading branch information
vincent-laurent authored Jan 22, 2024
2 parents 5a711c8 + 72af247 commit 03e3344
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 33 deletions.
58 changes: 30 additions & 28 deletions examples/deepcheck.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
"id": "initial_id",
"metadata": {
"ExecuteTime": {
"end_time": "2024-01-17T12:14:54.684328995Z",
"start_time": "2024-01-17T12:14:54.676331202Z"
"end_time": "2024-01-22T10:17:05.334980215Z",
"start_time": "2024-01-22T10:17:05.334387459Z"
}
},
"outputs": [],
Expand All @@ -22,8 +22,8 @@
"id": "13f6945fd3b104a7",
"metadata": {
"ExecuteTime": {
"end_time": "2024-01-17T12:14:57.921177463Z",
"start_time": "2024-01-17T12:14:54.676497329Z"
"end_time": "2024-01-22T10:17:35.466606739Z",
"start_time": "2024-01-22T10:17:05.334600818Z"
}
},
"outputs": [],
Expand All @@ -45,19 +45,19 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 10,
"id": "c8753b3c",
"metadata": {
"ExecuteTime": {
"end_time": "2024-01-17T12:15:06.392451087Z",
"start_time": "2024-01-17T12:14:57.928390727Z"
"end_time": "2024-01-22T11:03:20.144847150Z",
"start_time": "2024-01-22T11:03:13.185508661Z"
}
},
"outputs": [],
"source": [
"%%capture\n",
"project = Project(problem=\"classification\", project_name=\"test\")\n",
"project.add(DeepCheck())\n",
"project.add(DeepCheck(raise_on_fail=False))\n",
"project.add(Leakage())\n",
"project.start(\n",
" X, y,\n",
Expand All @@ -68,12 +68,12 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 11,
"id": "af750aaa1ad437d9",
"metadata": {
"ExecuteTime": {
"end_time": "2024-01-17T12:15:11.408465846Z",
"start_time": "2024-01-17T12:15:06.394966585Z"
"end_time": "2024-01-22T11:03:28.787023467Z",
"start_time": "2024-01-22T11:03:23.751671881Z"
}
},
"outputs": [],
Expand All @@ -85,12 +85,12 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 12,
"id": "63f097fa57aed229",
"metadata": {
"ExecuteTime": {
"end_time": "2024-01-17T12:15:11.420072204Z",
"start_time": "2024-01-17T12:15:11.408358294Z"
"end_time": "2024-01-22T11:03:28.797016634Z",
"start_time": "2024-01-22T11:03:28.794453724Z"
}
},
"outputs": [],
Expand All @@ -100,22 +100,22 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 13,
"id": "c53c5e4b0367c945",
"metadata": {
"ExecuteTime": {
"end_time": "2024-01-17T12:15:11.655100877Z",
"start_time": "2024-01-17T12:15:11.414969677Z"
"end_time": "2024-01-22T11:03:29.031757938Z",
"start_time": "2024-01-22T11:03:28.797423401Z"
}
},
"outputs": [
{
"data": {
"text/plain": "Accordion(children=(VBox(children=(HTML(value='\\n<h1 id=\"summary_V634PMGICYE1CT27M63U6GG2S\">Checks on train an…",
"text/plain": "Accordion(children=(VBox(children=(HTML(value='\\n<h1 id=\"summary_NBF47SZ45XTV6PWJSL5E20GD1\">Checks on train an…",
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "1554ce92e7e74251bca6c0ef362cba21"
"model_id": "099c242347ca494e90bde10611f03336"
}
},
"metadata": {},
Expand All @@ -128,21 +128,21 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 14,
"id": "95998b6bac06baaf",
"metadata": {
"ExecuteTime": {
"start_time": "2024-01-17T12:15:11.809471098Z"
"start_time": "2024-01-22T11:03:29.207641352Z"
}
},
"outputs": [
{
"data": {
"text/plain": "Accordion(children=(VBox(children=(HTML(value='\\n<h1 id=\"summary_HA8QSZT05V5MASJNJBFTEYZG8\">Checks on whole da…",
"text/plain": "Accordion(children=(VBox(children=(HTML(value='\\n<h1 id=\"summary_EA2ZPBHAIM94KR15YBIKYI5ZP\">Checks on whole da…",
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "62329497e2b94a7e8d8d5c2ee2ca489d"
"model_id": "df91a2ae306e41c181242252ea47bceb"
}
},
"metadata": {},
Expand All @@ -155,25 +155,27 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 15,
"id": "bf606df8",
"metadata": {
"ExecuteTime": {
"start_time": "2024-01-17T12:15:11.952915817Z"
"end_time": "2024-01-22T11:03:29.862507025Z",
"start_time": "2024-01-22T11:03:29.835950070Z"
}
},
"outputs": [
{
"data": {
"text/plain": "{'leakage': False, 'metric': 0.5502049180327868}"
"text/plain": "{'leakage': False, 'metric': 0.5320931791520027}"
},
"execution_count": 8,
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"project.components[\"Leakage\"].metrics"
"project.components[(\"L\"\n",
" \"eakage\")].metrics"
]
}
],
Expand Down
3 changes: 2 additions & 1 deletion palma/components/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and limitations under the License.
# See the License for the specific language governing permissions and
# limitations under the License.

import pandas as pd

Expand Down
21 changes: 19 additions & 2 deletions palma/components/data_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ class DeepCheck(ProjectComponent):
train-test split, such as feature drift, detecting data leakage...
By default, use the default suites train_test_validation and
train_test_leakage
raise_on_fail: bool, optional
If True (default), raise a ValueError when at least one check does not pass
"""

def __init__(
Expand All @@ -52,7 +54,8 @@ def __init__(
List[BaseCheck], BaseSuite] = data_integrity(),
train_test_datasets_checks: Union[
List[BaseCheck], BaseSuite] = Suite(
'Checks train test', train_test_validation())
'Checks train test', train_test_validation()),
raise_on_fail=True
) -> None:

if dataset_parameters:
Expand All @@ -75,6 +78,7 @@ def __init__(
train_test_datasets_checks,
'Checks on train and test datasets'
)
self.raise_on_fail = raise_on_fail

def __call__(self, project: Project) -> None:
"""
Expand All @@ -93,10 +97,23 @@ def __call__(self, project: Project) -> None:
train_dataset=self.__train_dataset,
test_dataset=self.__test_dataset
)

for results in [self.train_test_checks_results,
self.dataset_checks_results]:
logger.logger.log_artifact(results, f'{results.name}')

list_results = [
*self.train_test_checks_results.get_not_passed_checks(),
*self.dataset_checks_results.get_not_passed_checks(),
]
if self.raise_on_fail and len(list_results):
line = "="*50
raise ValueError(
f"The following tests did not pass :"
f"{line}\n"
f"{list_results}\n"
f"{line}")

def __generate_datasets(self, project: Project, **kwargs) -> None:
"""
Generate :class:`deepchecks.Dataset`
Expand All @@ -110,14 +127,14 @@ def __generate_datasets(self, project: Project, **kwargs) -> None:

df = pd.concat([project.X, project.y], axis=1)
df.columns = [*project.X.columns.to_list(), "target"]
print(df.columns)
self.__dataset = Dataset(df, label="target", **kwargs)

self.__train_dataset = self.__dataset.copy(
df.loc[project.validation_strategy.train_index])
self.__test_dataset = self.__dataset.copy(
df.loc[project.validation_strategy.test_index])


@staticmethod
def __generate_suite(
checks: Union[List[BaseCheck], BaseSuite],
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ dependencies = [
"flaml[automl]>=2",
"matplotlib>=3.4",
"numpy >= 1",
"scikit-learn >= 1",
"scikit-learn >= 1, <1.4 ",
"pandas >= 1",
"shap",
"llvmlite >= 0.39",
Expand Down
3 changes: 2 additions & 1 deletion tests/test_component/test_data_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@


def test_deep_check(classification_project):
dc = DeepCheck(dataset_parameters={"label": classification_project.y.name})
dc = DeepCheck(dataset_parameters={"label": classification_project.y.name},
raise_on_fail=False)
dc(classification_project)


Expand Down

0 comments on commit 03e3344

Please sign in to comment.