Skip to content

Commit

Permalink
fix(utils): remove redundant column in train_test_split() (#131)
Browse files Browse the repository at this point in the history
  • Loading branch information
jitingxu1 authored Aug 23, 2024
1 parent f0263b9 commit 80881f8
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
4 changes: 2 additions & 2 deletions ibis_ml/utils/_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,6 @@ def train_test_split(
< int((1 - test_size) * num_buckets)
)

return table[table.train].drop(["combined_key"]), table[~table.train].drop(
["combined_key"]
return table[table.train].drop(["combined_key", "train"]), table[~table.train].drop(
["combined_key", "train"]
)
2 changes: 2 additions & 0 deletions tests/test_train_test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ def test_train_test_split():
# Check counts and overlaps in train and test dataset
assert train_table.count().execute() + test_table.count().execute() == N
assert train_table.intersect(test_table).count().execute() == 0
assert set(train_table.columns) == set(table.columns)
assert set(test_table.columns) == set(table.columns)

# Check reproducibility
reproduced_train_table, reproduced_test_table = ml.train_test_split(
Expand Down

0 comments on commit 80881f8

Please sign in to comment.