Fixes to to_tf_dataset #3085

Merged
merged 5 commits on Oct 21, 2021
22 changes: 20 additions & 2 deletions src/datasets/arrow_dataset.py
@@ -301,6 +301,11 @@ def to_tf_dataset(
a small buffer of batches for training. Improves performance by allowing data to be loaded in the
background while the model is training.
"""

# TODO There is some hacky hardcoding in this function that needs to be fixed.
# We're planning to rework it so less code is needed at the start to remove columns before
# we know the final list of fields (post-data collator). This should clean up most of the special
# casing while retaining the API.
if config.TF_AVAILABLE:
import tensorflow as tf
else:
@@ -328,13 +333,14 @@ def to_tf_dataset(
# Special casing when the dataset has 'label' and the model expects 'labels' and the collator fixes it up for us
if "labels" in cols_to_retain and "labels" not in self.features and "label" in self.features:
cols_to_retain[cols_to_retain.index("labels")] = "label"
# Watch for nonexistent columns, except those that the data collators add for us
for col in cols_to_retain:
-if col not in self.features:
+if col not in self.features and not (col in ("attention_mask", "labels") and collate_fn is not None):
Member
Why hardcode some column names here? It feels hacky.

Changing the collate_fn function could break this, no?

Member Author (@Rocketknight1), Oct 18, 2021
It's very hacky, yeah. I need to change this to make it work properly, but I was under time pressure to get notebooks and everything ready in time to record videos for the course.

I think a better solution would be to take a remove_columns list instead of columns, and then I wouldn't have to worry so much about new columns being added by the data collator - I assume that all of those are being kept. WDYT?
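The trade-off being discussed can be sketched in plain Python (helper names here are illustrative, not the library's API): an explicit keep-list must anticipate columns the data collator will create later, which forces the hardcoded exceptions above, whereas a remove-list keeps everything else by default.

```python
# Hypothetical sketch of the two column-selection strategies under discussion.
# select_columns mirrors the current keep-list approach; drop_columns mirrors
# the proposed remove_columns approach. Neither is the real datasets API.

features = ["input_ids", "attention_mask", "label", "idx"]

def select_columns(features, columns, collator_added=("attention_mask", "labels")):
    # Current approach: keep an explicit list, but special-case columns
    # that a data collator may add later (hence the hardcoding).
    missing = [c for c in columns if c not in features and c not in collator_added]
    if missing:
        raise ValueError(f"Couldn't find column {missing[0]} in dataset.")
    return [c for c in columns if c in features]

def drop_columns(features, remove_columns):
    # Proposed approach: drop unwanted columns and keep everything else,
    # so collator-added columns never need to be anticipated.
    return [c for c in features if c not in remove_columns]

print(select_columns(features, ["input_ids", "attention_mask", "labels"]))
# → ['input_ids', 'attention_mask']
print(drop_columns(features, ["idx"]))
# → ['input_ids', 'attention_mask', 'label']
```

With the remove-list style, a collator that adds a new column needs no change here, which is the robustness point raised above.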

raise ValueError(f"Couldn't find column {col} in dataset.")
if drop_remainder is None:
# We assume that if you're shuffling it's the train set, so we drop the remainder unless told not to
drop_remainder = shuffle
-dataset = self.with_format("numpy", columns=cols_to_retain)
+dataset = self.with_format("python", columns=[col for col in cols_to_retain if col in self.features])

def numpy_pad(data):
try:
@@ -432,6 +438,18 @@ def add_dummy_labels(input_batch):

tf_dataset = tf_dataset.map(add_dummy_labels)

def rename_label_col(inputs, labels=None):
if not isinstance(inputs, tf.Tensor):
if "label" in inputs:
inputs["labels"] = inputs["label"]
del inputs["label"]
if labels is None:
return inputs
else:
return inputs, labels

tf_dataset = tf_dataset.map(rename_label_col)
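The rename step above amounts to moving the "label" key to "labels" in each batch dict, passing non-dict inputs and any separate labels argument through unchanged. A minimal pure-Python equivalent (no TensorFlow dependency, so the tensor branch is simplified to a dict check) is:

```python
# Pure-Python sketch of rename_label_col's dict branch; the real function in
# the diff checks `not isinstance(inputs, tf.Tensor)` instead of a dict check.
def rename_label_col(inputs, labels=None):
    if isinstance(inputs, dict) and "label" in inputs:
        inputs["labels"] = inputs.pop("label")
    return inputs if labels is None else (inputs, labels)

batch = {"input_ids": [0, 1, 2], "label": 1}
print(rename_label_col(batch))
# → {'input_ids': [0, 1, 2], 'labels': 1}
```

Mapping this over every element is what lets a dataset whose feature is named "label" feed a model that expects "labels".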

if prefetch:
tf_dataset = tf_dataset.prefetch(tf.data.experimental.AUTOTUNE)
