diff --git a/ibis_ml/utils/_split.py b/ibis_ml/utils/_split.py index 7edc577..6ed0a31 100644 --- a/ibis_ml/utils/_split.py +++ b/ibis_ml/utils/_split.py @@ -90,6 +90,6 @@ def train_test_split( < int((1 - test_size) * num_buckets) ) - return table[table.train].drop(["combined_key"]), table[~table.train].drop( - ["combined_key"] + return table[table.train].drop(["combined_key", "train"]), table[~table.train].drop( + ["combined_key", "train"] ) diff --git a/tests/test_train_test_split.py b/tests/test_train_test_split.py index c0a82f0..9eaa7d2 100644 --- a/tests/test_train_test_split.py +++ b/tests/test_train_test_split.py @@ -16,6 +16,8 @@ def test_train_test_split(): # Check counts and overlaps in train and test dataset assert train_table.count().execute() + test_table.count().execute() == N assert train_table.intersect(test_table).count().execute() == 0 + assert set(train_table.columns) == set(table.columns) + assert set(test_table.columns) == set(table.columns) # Check reproducibility reproduced_train_table, reproduced_test_table = ml.train_test_split(