diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index 3d9604cc15b..0f0cb7ae3c9 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -775,6 +775,7 @@ def class_encode_column(self, column: str) -> "Dataset":
         class_names = sorted(dset.unique(column))
         dst_feat = ClassLabel(names=class_names)
         dset = dset.map(lambda batch: {column: dst_feat.str2int(batch)}, input_columns=column, batched=True)
+        dset = concatenate_datasets([self.remove_columns([column]), dset], axis=1)

         new_features = copy.deepcopy(dset.features)
         new_features[column] = dst_feat
diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py
index 3380c17cf32..2b0367218e7 100644
--- a/tests/test_arrow_dataset.py
+++ b/tests/test_arrow_dataset.py
@@ -1975,7 +1975,7 @@ def test_dataset_add_item(item, in_memory, dataset_dict, arrow_path, transform):
     dataset = dataset_to_test.add_item(item)
     assert dataset.data.shape == (5, 3)
     expected_features = dataset_to_test.features
-    assert dataset.data.column_names == list(expected_features.keys())
+    assert sorted(dataset.data.column_names) == sorted(expected_features.keys())
     for feature, expected_dtype in expected_features.items():
         assert dataset.features[feature] == expected_dtype
     assert len(dataset.data.blocks) == 1 if in_memory else 2 # multiple InMemoryTables are consolidated as one