Skip to content

Commit

Permalink
Fixes to to_tf_dataset (#3085)
Browse files Browse the repository at this point in the history
* Fix for columns added by the collation function

* More special-casing around labels

* Style pass

* Tweak to handling of column names

* Adding TODO with the roadmap
  • Loading branch information
Rocketknight1 authored Oct 21, 2021
1 parent 2f311fb commit a1c8b49
Showing 1 changed file with 20 additions and 2 deletions.
22 changes: 20 additions & 2 deletions src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,11 @@ def to_tf_dataset(
a small buffer of batches for training. Improves performance by allowing data to be loaded in the
background while the model is training.
"""

# TODO There is some hacky hardcoding in this function that needs to be fixed.
# We're planning to rework it so less code is needed at the start to remove columns before
# we know the final list of fields (post-data collator). This should clean up most of the special
# casing while retaining the API.
if config.TF_AVAILABLE:
import tensorflow as tf
else:
Expand Down Expand Up @@ -328,13 +333,14 @@ def to_tf_dataset(
# Special casing when the dataset has 'label' and the model expects 'labels' and the collator fixes it up for us
if "labels" in cols_to_retain and "labels" not in self.features and "label" in self.features:
cols_to_retain[cols_to_retain.index("labels")] = "label"
# Watch for nonexistent columns, except those that the data collators add for us
for col in cols_to_retain:
if col not in self.features:
if col not in self.features and not (col in ("attention_mask", "labels") and collate_fn is not None):
raise ValueError(f"Couldn't find column {col} in dataset.")
if drop_remainder is None:
# We assume that if you're shuffling it's the train set, so we drop the remainder unless told not to
drop_remainder = shuffle
dataset = self.with_format("numpy", columns=cols_to_retain)
dataset = self.with_format("python", columns=[col for col in cols_to_retain if col in self.features])

def numpy_pad(data):
try:
Expand Down Expand Up @@ -432,6 +438,18 @@ def add_dummy_labels(input_batch):

tf_dataset = tf_dataset.map(add_dummy_labels)

def rename_label_col(inputs, labels=None):
    """Rename the 'label' key to 'labels' in dict-style batches.

    Plain tensor inputs are passed through untouched; an optional
    separate labels tensor is forwarded unchanged alongside the inputs.
    """
    # Only dict-like batches can carry a 'label' key; tensors are left as-is.
    if not isinstance(inputs, tf.Tensor) and "label" in inputs:
        inputs["labels"] = inputs.pop("label")
    return inputs if labels is None else (inputs, labels)

tf_dataset = tf_dataset.map(rename_label_col)

if prefetch:
tf_dataset = tf_dataset.prefetch(tf.data.experimental.AUTOTUNE)

Expand Down

1 comment on commit a1c8b49

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Show benchmarks

PyArrow==3.0.0

Show updated benchmarks!

Benchmark: benchmark_array_xd.json

metric read_batch_formatted_as_numpy after write_array2d read_batch_formatted_as_numpy after write_flattened_sequence read_batch_formatted_as_numpy after write_nested_sequence read_batch_unformated after write_array2d read_batch_unformated after write_flattened_sequence read_batch_unformated after write_nested_sequence read_col_formatted_as_numpy after write_array2d read_col_formatted_as_numpy after write_flattened_sequence read_col_formatted_as_numpy after write_nested_sequence read_col_unformated after write_array2d read_col_unformated after write_flattened_sequence read_col_unformated after write_nested_sequence read_formatted_as_numpy after write_array2d read_formatted_as_numpy after write_flattened_sequence read_formatted_as_numpy after write_nested_sequence read_unformated after write_array2d read_unformated after write_flattened_sequence read_unformated after write_nested_sequence write_array2d write_flattened_sequence write_nested_sequence
new / old (diff) 0.012804 / 0.011353 (0.001451) 0.004307 / 0.011008 (-0.006701) 0.040891 / 0.038508 (0.002383) 0.040819 / 0.023109 (0.017710) 0.359378 / 0.275898 (0.083479) 0.492523 / 0.323480 (0.169044) 0.008702 / 0.007986 (0.000716) 0.005019 / 0.004328 (0.000691) 0.010792 / 0.004250 (0.006542) 0.045095 / 0.037052 (0.008043) 0.393801 / 0.258489 (0.135312) 0.393365 / 0.293841 (0.099524) 0.041128 / 0.128546 (-0.087418) 0.012062 / 0.075646 (-0.063584) 0.319988 / 0.419271 (-0.099283) 0.060594 / 0.043533 (0.017061) 0.370510 / 0.255139 (0.115371) 0.394873 / 0.283200 (0.111673) 0.100307 / 0.141683 (-0.041376) 2.059412 / 1.452155 (0.607258) 2.131114 / 1.492716 (0.638398)

Benchmark: benchmark_getitem_100B.json

metric get_batch_of_1024_random_rows get_batch_of_1024_rows get_first_row get_last_row
new / old (diff) 0.262238 / 0.018006 (0.244232) 0.590124 / 0.000490 (0.589634) 0.017545 / 0.000200 (0.017346) 0.000387 / 0.000054 (0.000333)

Benchmark: benchmark_indices_mapping.json

metric select shard shuffle sort train_test_split
new / old (diff) 0.043307 / 0.037411 (0.005896) 0.027846 / 0.014526 (0.013320) 0.030782 / 0.176557 (-0.145775) 0.154794 / 0.737135 (-0.582341) 0.034248 / 0.296338 (-0.262091)

Benchmark: benchmark_iterating.json

metric read 5000 read 50000 read_batch 50000 10 read_batch 50000 100 read_batch 50000 1000 read_formatted numpy 5000 read_formatted pandas 5000 read_formatted tensorflow 5000 read_formatted torch 5000 read_formatted_batch numpy 5000 10 read_formatted_batch numpy 5000 1000 shuffled read 5000 shuffled read 50000 shuffled read_batch 50000 10 shuffled read_batch 50000 100 shuffled read_batch 50000 1000 shuffled read_formatted numpy 5000 shuffled read_formatted_batch numpy 5000 10 shuffled read_formatted_batch numpy 5000 1000
new / old (diff) 0.650631 / 0.215209 (0.435422) 6.304625 / 2.077655 (4.226971) 2.497522 / 1.504120 (0.993402) 2.079268 / 1.541195 (0.538073) 1.988391 / 1.468490 (0.519901) 0.653644 / 4.584777 (-3.931133) 7.157145 / 3.745712 (3.411433) 1.473428 / 5.269862 (-3.796434) 1.399669 / 4.565676 (-3.166008) 0.065894 / 0.424275 (-0.358381) 0.005759 / 0.007607 (-0.001848) 0.760635 / 0.226044 (0.534590) 7.579795 / 2.268929 (5.310866) 3.168011 / 55.444624 (-52.276614) 2.448392 / 6.876477 (-4.428085) 2.499644 / 2.142072 (0.357571) 0.864347 / 4.805227 (-3.940880) 0.172593 / 6.500664 (-6.328071) 0.069830 / 0.075469 (-0.005639)

Benchmark: benchmark_map_filter.json

metric filter map fast-tokenizer batched map identity map identity batched map no-op batched map no-op batched numpy map no-op batched pandas map no-op batched pytorch map no-op batched tensorflow
new / old (diff) 1.977369 / 1.841788 (0.135581) 15.330452 / 8.074308 (7.256144) 42.059023 / 10.191392 (31.867631) 0.925406 / 0.680424 (0.244982) 0.651387 / 0.534201 (0.117186) 0.281232 / 0.579283 (-0.298051) 0.703886 / 0.434364 (0.269522) 0.234929 / 0.540337 (-0.305409) 0.247352 / 1.386936 (-1.139584)
PyArrow==latest
Show updated benchmarks!

Benchmark: benchmark_array_xd.json

metric read_batch_formatted_as_numpy after write_array2d read_batch_formatted_as_numpy after write_flattened_sequence read_batch_formatted_as_numpy after write_nested_sequence read_batch_unformated after write_array2d read_batch_unformated after write_flattened_sequence read_batch_unformated after write_nested_sequence read_col_formatted_as_numpy after write_array2d read_col_formatted_as_numpy after write_flattened_sequence read_col_formatted_as_numpy after write_nested_sequence read_col_unformated after write_array2d read_col_unformated after write_flattened_sequence read_col_unformated after write_nested_sequence read_formatted_as_numpy after write_array2d read_formatted_as_numpy after write_flattened_sequence read_formatted_as_numpy after write_nested_sequence read_unformated after write_array2d read_unformated after write_flattened_sequence read_unformated after write_nested_sequence write_array2d write_flattened_sequence write_nested_sequence
new / old (diff) 0.010681 / 0.011353 (-0.000672) 0.004553 / 0.011008 (-0.006456) 0.036713 / 0.038508 (-0.001796) 0.037063 / 0.023109 (0.013954) 0.373386 / 0.275898 (0.097488) 0.425040 / 0.323480 (0.101560) 0.009738 / 0.007986 (0.001753) 0.005339 / 0.004328 (0.001011) 0.009843 / 0.004250 (0.005593) 0.047626 / 0.037052 (0.010574) 0.374780 / 0.258489 (0.116290) 0.431167 / 0.293841 (0.137326) 0.040272 / 0.128546 (-0.088274) 0.012509 / 0.075646 (-0.063138) 0.325150 / 0.419271 (-0.094121) 0.059581 / 0.043533 (0.016048) 0.395511 / 0.255139 (0.140372) 0.426031 / 0.283200 (0.142831) 0.094057 / 0.141683 (-0.047626) 2.107438 / 1.452155 (0.655283) 2.265468 / 1.492716 (0.772752)

Benchmark: benchmark_getitem_100B.json

metric get_batch_of_1024_random_rows get_batch_of_1024_rows get_first_row get_last_row
new / old (diff) 0.425804 / 0.018006 (0.407798) 0.580608 / 0.000490 (0.580119) 0.080336 / 0.000200 (0.080136) 0.000373 / 0.000054 (0.000318)

Benchmark: benchmark_indices_mapping.json

metric select shard shuffle sort train_test_split
new / old (diff) 0.040243 / 0.037411 (0.002832) 0.024764 / 0.014526 (0.010238) 0.031440 / 0.176557 (-0.145117) 0.142253 / 0.737135 (-0.594882) 0.030270 / 0.296338 (-0.266069)

Benchmark: benchmark_iterating.json

metric read 5000 read 50000 read_batch 50000 10 read_batch 50000 100 read_batch 50000 1000 read_formatted numpy 5000 read_formatted pandas 5000 read_formatted tensorflow 5000 read_formatted torch 5000 read_formatted_batch numpy 5000 10 read_formatted_batch numpy 5000 1000 shuffled read 5000 shuffled read 50000 shuffled read_batch 50000 10 shuffled read_batch 50000 100 shuffled read_batch 50000 1000 shuffled read_formatted numpy 5000 shuffled read_formatted_batch numpy 5000 10 shuffled read_formatted_batch numpy 5000 1000
new / old (diff) 0.606332 / 0.215209 (0.391123) 6.280219 / 2.077655 (4.202565) 2.354131 / 1.504120 (0.850012) 1.969234 / 1.541195 (0.428039) 1.952314 / 1.468490 (0.483824) 0.671691 / 4.584777 (-3.913086) 6.956387 / 3.745712 (3.210675) 1.620316 / 5.269862 (-3.649545) 1.461920 / 4.565676 (-3.103756) 0.072424 / 0.424275 (-0.351851) 0.005960 / 0.007607 (-0.001647) 0.757513 / 0.226044 (0.531469) 7.528449 / 2.268929 (5.259521) 3.068972 / 55.444624 (-52.375652) 2.404942 / 6.876477 (-4.471534) 2.422613 / 2.142072 (0.280541) 0.821996 / 4.805227 (-3.983231) 0.158996 / 6.500664 (-6.341668) 0.065984 / 0.075469 (-0.009485)

Benchmark: benchmark_map_filter.json

metric filter map fast-tokenizer batched map identity map identity batched map no-op batched map no-op batched numpy map no-op batched pandas map no-op batched pytorch map no-op batched tensorflow
new / old (diff) 1.998908 / 1.841788 (0.157120) 15.861052 / 8.074308 (7.786744) 42.853785 / 10.191392 (32.662393) 0.942990 / 0.680424 (0.262566) 0.670344 / 0.534201 (0.136143) 0.288479 / 0.579283 (-0.290804) 0.734256 / 0.434364 (0.299892) 0.249351 / 0.540337 (-0.290987) 0.262006 / 1.386936 (-1.124930)

CML watermark

Please sign in to comment.