From 986385c470e2647e2786c625e5b48a2d378e1564 Mon Sep 17 00:00:00 2001
From: Jiting Xu <126802425+jitingxu1@users.noreply.github.com>
Date: Tue, 23 Apr 2024 16:12:55 -0700
Subject: [PATCH] fix(steps): avoid bool-to-int cast and handle NULL (#71)

Co-authored-by: Deepyaman Datta <deepyaman.datta@utexas.edu>
---
 ibisml/steps/encode.py |  6 ++++--
 tests/test_encode.py   | 38 +++++++++++++++++++++++++++++++++++---
 2 files changed, 39 insertions(+), 5 deletions(-)

diff --git a/ibisml/steps/encode.py b/ibisml/steps/encode.py
index abcf1e9..ace4d39 100644
--- a/ibisml/steps/encode.py
+++ b/ibisml/steps/encode.py
@@ -72,7 +72,9 @@ class OneHotEncode(Step):
     """A step for one-hot encoding select columns.
 
     The original input column is dropped, and N-category new columns are
-    created with names like ``{input_column}_{category}``.
+    created with names like ``{input_column}_{category}``. Unknown categories
+    will be ignored during transformation; the resulting one-hot encoded
+    columns for this feature will be all zeros.
 
     Parameters
     ----------
@@ -152,7 +154,7 @@ def transform_table(self, table: ir.Table) -> ir.Table:
 
         return table.mutate(
             [
-                (table[col] == cat).cast("int8").name(f"{col}_{cat}")
+                ibis.ifelse((table[col] == cat), 1, 0).name(f"{col}_{cat}")
                 for col, cats in self.categories_.items()
                 for cat in cats
             ]
diff --git a/tests/test_encode.py b/tests/test_encode.py
index 9dff7d8..756352b 100644
--- a/tests/test_encode.py
+++ b/tests/test_encode.py
@@ -1,11 +1,14 @@
 import ibis
 import pandas as pd
+import pandas.testing as tm
+import pytest
 
 import ibisml as ml
 
 
-def test_count_encode():
-    t_train = ibis.memtable(
+@pytest.fixture()
+def t_train():
+    return ibis.memtable(
         {
             "time": [
                 pd.Timestamp("2016-05-25 13:30:00.023"),
@@ -20,7 +23,11 @@ def test_count_encode():
             "ticker": ["GOOG", "MSFT", "MSFT", "MSFT", None, "AAPL", "GOOG", "MSFT"],
         }
     )
-    t_test = ibis.memtable(
+
+
+@pytest.fixture()
+def t_test():
+    return ibis.memtable(
         {
             "time": [
                 pd.Timestamp("2016-05-25 13:30:00.023"),
@@ -34,7 +41,32 @@ def test_count_encode():
         }
     )
 
+
+def test_count_encode(t_train, t_test):
     step = ml.CountEncode("ticker")
     step.fit_table(t_train, ml.core.Metadata())
     res = step.transform_table(t_test)
     assert res.to_pandas().sort_values(by="time").ticker.to_list() == [4, 4, 2, 2, 0, 0]
+
+
+def test_one_hot_encode(t_train, t_test):
+    step = ml.OneHotEncode("ticker")
+    step.fit_table(t_train, ml.core.Metadata())
+    result = step.transform_table(t_test)
+    expected = pd.DataFrame(
+        {
+            "time": [
+                pd.Timestamp("2016-05-25 13:30:00.023"),
+                pd.Timestamp("2016-05-25 13:30:00.038"),
+                pd.Timestamp("2016-05-25 13:30:00.048"),
+                pd.Timestamp("2016-05-25 13:30:00.049"),
+                pd.Timestamp("2016-05-25 13:30:00.050"),
+                pd.Timestamp("2016-05-25 13:30:00.051"),
+            ],
+            "ticker_AAPL": [0, 0, 0, 0, 0, 0],
+            "ticker_GOOG": [0, 0, 1, 1, 0, 0],
+            "ticker_MSFT": [1, 1, 0, 0, 0, 0],
+            "ticker_None": [0, 0, 0, 0, 0, 1],
+        }
+    )
+    tm.assert_frame_equal(result.execute(), expected, check_dtype=False)