Skip to content

Commit

Permalink
fix(steps): avoid bool-to-int cast and handle NULL (#71)
Browse files Browse the repository at this point in the history
Co-authored-by: Deepyaman Datta <deepyaman.datta@utexas.edu>
  • Loading branch information
jitingxu1 and deepyaman authored Apr 23, 2024
1 parent 6bd2dcb commit 986385c
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 5 deletions.
6 changes: 4 additions & 2 deletions ibisml/steps/encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ class OneHotEncode(Step):
"""A step for one-hot encoding select columns.
The original input column is dropped, and N-category new columns are
created with names like ``{input_column}_{category}``.
created with names like ``{input_column}_{category}``. Unknown categories
will be ignored during transformation; the resulting one-hot encoded
columns for this feature will be all zeros.
Parameters
----------
Expand Down Expand Up @@ -152,7 +154,7 @@ def transform_table(self, table: ir.Table) -> ir.Table:

return table.mutate(
[
(table[col] == cat).cast("int8").name(f"{col}_{cat}")
ibis.ifelse((table[col] == cat), 1, 0).name(f"{col}_{cat}")
for col, cats in self.categories_.items()
for cat in cats
]
Expand Down
38 changes: 35 additions & 3 deletions tests/test_encode.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import ibis
import pandas as pd
import pandas.testing as tm
import pytest

import ibisml as ml


def test_count_encode():
t_train = ibis.memtable(
@pytest.fixture()
def t_train():
return ibis.memtable(
{
"time": [
pd.Timestamp("2016-05-25 13:30:00.023"),
Expand All @@ -20,7 +23,11 @@ def test_count_encode():
"ticker": ["GOOG", "MSFT", "MSFT", "MSFT", None, "AAPL", "GOOG", "MSFT"],
}
)
t_test = ibis.memtable(


@pytest.fixture()
def t_test():
return ibis.memtable(
{
"time": [
pd.Timestamp("2016-05-25 13:30:00.023"),
Expand All @@ -34,7 +41,32 @@ def test_count_encode():
}
)


def test_count_encode(t_train, t_test):
step = ml.CountEncode("ticker")
step.fit_table(t_train, ml.core.Metadata())
res = step.transform_table(t_test)
assert res.to_pandas().sort_values(by="time").ticker.to_list() == [4, 4, 2, 2, 0, 0]


def test_one_hot_encode(t_train, t_test):
step = ml.OneHotEncode("ticker")
step.fit_table(t_train, ml.core.Metadata())
result = step.transform_table(t_test)
expected = pd.DataFrame(
{
"time": [
pd.Timestamp("2016-05-25 13:30:00.023"),
pd.Timestamp("2016-05-25 13:30:00.038"),
pd.Timestamp("2016-05-25 13:30:00.048"),
pd.Timestamp("2016-05-25 13:30:00.049"),
pd.Timestamp("2016-05-25 13:30:00.050"),
pd.Timestamp("2016-05-25 13:30:00.051"),
],
"ticker_AAPL": [0, 0, 0, 0, 0, 0],
"ticker_GOOG": [0, 0, 1, 1, 0, 0],
"ticker_MSFT": [1, 1, 0, 0, 0, 0],
"ticker_None": [0, 0, 0, 0, 0, 1],
}
)
tm.assert_frame_equal(result.execute(), expected, check_dtype=False)

0 comments on commit 986385c

Please sign in to comment.