From 986385c470e2647e2786c625e5b48a2d378e1564 Mon Sep 17 00:00:00 2001 From: Jiting Xu <126802425+jitingxu1@users.noreply.github.com> Date: Tue, 23 Apr 2024 16:12:55 -0700 Subject: [PATCH] fix(steps): avoid bool-to-int cast and handle NULL (#71) Co-authored-by: Deepyaman Datta <deepyaman.datta@utexas.edu> --- ibisml/steps/encode.py | 6 ++++-- tests/test_encode.py | 38 +++++++++++++++++++++++++++++++++++--- 2 files changed, 39 insertions(+), 5 deletions(-) diff --git a/ibisml/steps/encode.py b/ibisml/steps/encode.py index abcf1e9..ace4d39 100644 --- a/ibisml/steps/encode.py +++ b/ibisml/steps/encode.py @@ -72,7 +72,9 @@ class OneHotEncode(Step): """A step for one-hot encoding select columns. The original input column is dropped, and N-category new columns are - created with names like ``{input_column}_{category}``. + created with names like ``{input_column}_{category}``. Unknown categories + will be ignored during transformation; the resulting one-hot encoded + columns for this feature will be all zeros. Parameters ---------- @@ -152,7 +154,7 @@ def transform_table(self, table: ir.Table) -> ir.Table: return table.mutate( [ - (table[col] == cat).cast("int8").name(f"{col}_{cat}") + ibis.ifelse((table[col] == cat), 1, 0).name(f"{col}_{cat}") for col, cats in self.categories_.items() for cat in cats ] diff --git a/tests/test_encode.py b/tests/test_encode.py index 9dff7d8..756352b 100644 --- a/tests/test_encode.py +++ b/tests/test_encode.py @@ -1,11 +1,14 @@ import ibis import pandas as pd +import pandas.testing as tm +import pytest import ibisml as ml -def test_count_encode(): - t_train = ibis.memtable( +@pytest.fixture() +def t_train(): + return ibis.memtable( { "time": [ pd.Timestamp("2016-05-25 13:30:00.023"), @@ -20,7 +23,11 @@ def test_count_encode(): "ticker": ["GOOG", "MSFT", "MSFT", "MSFT", None, "AAPL", "GOOG", "MSFT"], } ) - t_test = ibis.memtable( + + +@pytest.fixture() +def t_test(): + return ibis.memtable( { "time": [ pd.Timestamp("2016-05-25 13:30:00.023"), @@ -34,7 +41,32 @@ def test_count_encode(): } ) + +def test_count_encode(t_train, t_test): step = ml.CountEncode("ticker") step.fit_table(t_train, ml.core.Metadata()) res = step.transform_table(t_test) assert res.to_pandas().sort_values(by="time").ticker.to_list() == [4, 4, 2, 2, 0, 0] + + +def test_one_hot_encode(t_train, t_test): + step = ml.OneHotEncode("ticker") + step.fit_table(t_train, ml.core.Metadata()) + result = step.transform_table(t_test) + expected = pd.DataFrame( + { + "time": [ + pd.Timestamp("2016-05-25 13:30:00.023"), + pd.Timestamp("2016-05-25 13:30:00.038"), + pd.Timestamp("2016-05-25 13:30:00.048"), + pd.Timestamp("2016-05-25 13:30:00.049"), + pd.Timestamp("2016-05-25 13:30:00.050"), + pd.Timestamp("2016-05-25 13:30:00.051"), + ], + "ticker_AAPL": [0, 0, 0, 0, 0, 0], + "ticker_GOOG": [0, 0, 1, 1, 0, 0], + "ticker_MSFT": [1, 1, 0, 0, 0, 0], + "ticker_None": [0, 0, 0, 0, 0, 1], + } + ) + tm.assert_frame_equal(result.execute(), expected, check_dtype=False)