
[AIR/Data] Add collate_fn to iter_torch_batches #32412

Merged
merged 29 commits into ray-project:master on Feb 21, 2023

Conversation

amogkam
Contributor

@amogkam amogkam commented Feb 10, 2023

Signed-off-by: amogkam amogkamsetty@yahoo.com

Adds a collate_fn argument to iter_batches and iter_torch_batches. This is useful for any last-mile preprocessing that is applied directly to the batch before it is used for training.

Closes #32224.
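For illustration, a collate_fn receives each materialized batch (for tabular data, a dict mapping column names to column values) and returns the transformed batch. The sketch below mimics that hook with plain Python lists rather than real Ray Data batches; the toy `iter_batches` helper and the column names are hypothetical, not the actual Ray API.

```python
# Hypothetical sketch of collate_fn semantics: the function receives the
# materialized batch and returns the (possibly restructured) batch.
def collate_fn(batch):
    # Last-mile preprocessing: scale a feature column, pair it with labels.
    features = [x / 255.0 for x in batch["image"]]
    labels = batch["label"]
    return list(zip(features, labels))

def iter_batches(rows, batch_size, collate_fn=None):
    """Toy stand-in for an iter_batches-style API with a collate_fn hook."""
    batch = {"image": [], "label": []}
    for image, label in rows:
        batch["image"].append(image)
        batch["label"].append(label)
        if len(batch["image"]) == batch_size:
            yield collate_fn(batch) if collate_fn else batch
            batch = {"image": [], "label": []}

rows = [(255, 1), (0, 0), (127.5, 1), (51, 0)]
batches = list(iter_batches(rows, batch_size=2, collate_fn=collate_fn))
print(batches)  # [[(1.0, 1), (0.0, 0)], [(0.5, 1), (0.2, 0)]]
```

The point of the hook is that the collation runs inside the iterator, where it can be overlapped with batch fetching, rather than as a separate loop in user code.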

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

Signed-off-by: amogkam <amogkamsetty@yahoo.com>
@amogkam amogkam marked this pull request as ready for review February 11, 2023 01:35
python/ray/data/dataset.py
python/ray/data/dataset.py
python/ray/air/_internal/torch_utils.py
python/ray/data/dataset.py
Signed-off-by: amogkam <amogkamsetty@yahoo.com>
Member

@bveeramani bveeramani left a comment

LGTM

python/ray/air/_internal/torch_utils.py
python/ray/data/_internal/block_batching.py
python/ray/data/_internal/bulk_dataset_iterator.py
python/ray/data/_internal/pipelined_dataset_iterator.py
python/ray/data/dataset.py
python/ray/data/dataset.py
Signed-off-by: Balaji Veeramani <balaji@anyscale.com>
python/ray/air/_internal/torch_utils.py
python/ray/data/dataset.py
@@ -2858,6 +2859,7 @@ def iter_batches(
to select ``numpy.ndarray`` for tensor datasets and
``Dict[str, numpy.ndarray]`` for tabular datasets. Default is "default".
drop_last: Whether to drop the last batch if it's incomplete.
collate_fn: A function to apply to each data batch before returning it.
Contributor

@clarkzinzow clarkzinzow Feb 14, 2023

Hmm I don't know if it makes sense to have this as part of the ds.iter_batches() API, since the primary motivation behind this feature is to (1) make the Torch DataLoader port easier, and (2) provide a hook to override how we convert NumPy ndarrays to Torch Tensors. collate_fn doesn't provide much critical value for the core ds.iter_batches() API: it's functionally exactly equivalent to

for batch in ds.iter_batches():
    batch = collate_fn(batch)

Is the primary motivation for this to be able to include it in the local pipelining operation chain, to improve collation performance? If that's the case, I'm wondering if ds.iter_torch_batches() shouldn't reuse ds.iter_batches(), and we should instead have ds.iter_torch_batches() use the internal batch_block_refs API directly, which can expose collate_fn. Most of the batching logic is captured in batch_block_refs anyway.
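The functional-equivalence claim above can be checked with a toy iterator (plain Python, no Ray; the names are hypothetical stand-ins): applying collate_fn inside the iterator yields exactly what you would get by collating each yielded batch yourself.

```python
def collate_fn(batch):
    # Arbitrary per-batch transform for the demonstration.
    return [x * 2 for x in batch]

def iter_batches(data, batch_size, collate_fn=None):
    # Toy stand-in: yields fixed-size batches, optionally collated inside.
    for i in range(0, len(data), batch_size):
        batch = data[i:i + batch_size]
        yield collate_fn(batch) if collate_fn else batch

data = list(range(6))

# collate_fn applied inside the iterator...
inside = list(iter_batches(data, 2, collate_fn))

# ...matches collating each yielded batch in user code.
outside = [collate_fn(batch) for batch in iter_batches(data, 2)]

assert inside == outside == [[0, 2], [4, 6], [8, 10]]
```

The difference the PR cares about is not the result but where the work happens: hooking collation into the iterator lets it participate in the internal pipelining.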

Contributor Author

Yeah I was also thinking the same...will make the change.

The easiest thing is just to add a private arg to iter_batches. IMO that would be preferable for future maintenance to having both APIs use batch_block_refs directly.

Contributor Author

Updated to make _collate_fn private for iter_batches

Contributor

Hmm I'd rather not muddy the user-facing .iter_batches() API even with a private _collate_fn arg, if possible, since it's already a pretty wide interface. And since batch_block_refs contains most of the batching logic, I don't think that there's much maintenance or drift risk here.

What are your primary concerns for maintenance here?

Contributor Author

It's mostly about the timing logic, but it's not a big deal. Updated to use batch_block_refs directly.

Contributor

@clarkzinzow clarkzinzow Feb 14, 2023

If stuff like execution triggering and stats collection becomes involved enough around batch_block_refs, we can always introduce a private Dataset._iter_batches() method that can support all of the other batch-based iterator APIs without making the user-facing Dataset.iter_batches() API any wider.

But for the sake of this PR, duplicating that thin logic around batch_block_refs seems fine to me!

Contributor Author

This actually doesn't work for the pipelined case...for some reason that path calls into Dataset.iter_torch_batches.

Contributor Author

@amogkam amogkam Feb 15, 2023

DatasetPipeline makes an implicit assumption that Dataset.iter_torch_batches calls into self.iter_batches, which no longer holds true if we change this implementation.

Contributor Author

I think a private arg is the way to go, because the abstractions here are very messy. We can remove it later once we deprecate DatasetPipeline.
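A minimal sketch of the private-argument approach settled on here (hypothetical names; not the actual Ray implementation): the public iter_batches keeps its documented signature, while iter_torch_batches forwards its collation hook through a leading-underscore keyword, so pipelined wrappers that delegate through iter_batches keep working.

```python
from typing import Callable, Iterator, List, Optional

class Dataset:
    """Toy dataset; only sketches how a private _collate_fn arg is plumbed."""

    def __init__(self, rows: List[int]):
        self._rows = rows

    def iter_batches(
        self,
        batch_size: int = 2,
        _collate_fn: Optional[Callable] = None,  # private: kept out of the docs
    ) -> Iterator[list]:
        for i in range(0, len(self._rows), batch_size):
            batch = self._rows[i:i + batch_size]
            yield _collate_fn(batch) if _collate_fn else batch

    def iter_torch_batches(
        self,
        batch_size: int = 2,
        collate_fn: Optional[Callable] = None,
    ) -> Iterator[list]:
        # The torch-facing API forwards its public collate_fn via the private
        # arg, defaulting to the built-in conversion (floats stand in for
        # the real NumPy-to-Tensor conversion here).
        def default_collate(batch):
            return [float(x) for x in batch]

        return self.iter_batches(batch_size, _collate_fn=collate_fn or default_collate)

ds = Dataset([1, 2, 3, 4])
print(list(ds.iter_torch_batches()))  # [[1.0, 2.0], [3.0, 4.0]]
```

Because iter_torch_batches still calls into iter_batches, a pipeline wrapper that overrides iter_batches is automatically picked up, which is the assumption DatasetPipeline relies on.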

Signed-off-by: amogkam <amogkamsetty@yahoo.com>
amogkam and others added 12 commits February 14, 2023 16:45
Signed-off-by: amogkam <amogkamsetty@yahoo.com>
Signed-off-by: Balaji Veeramani <balaji@anyscale.com>
Contributor

@clarkzinzow clarkzinzow left a comment

LGTM, nice work on this one!

python/ray/data/_internal/block_batching.py
Signed-off-by: amogkam <amogkamsetty@yahoo.com>
@amogkam amogkam merged commit 8c054fc into ray-project:master Feb 21, 2023
@amogkam amogkam deleted the collate_fn branch February 21, 2023 18:19
edoakes pushed a commit to edoakes/ray that referenced this pull request Mar 22, 2023
Adds a collate_fn argument to iter_batches and iter_torch_batches. This is useful for any last-mile preprocessing that is applied directly to the batch before it is used for training.

Closes ray-project#32224.

Signed-off-by: amogkam <amogkamsetty@yahoo.com>
Signed-off-by: Balaji Veeramani <balaji@anyscale.com>
Co-authored-by: Balaji Veeramani <balaji@anyscale.com>
Signed-off-by: Edward Oakes <ed.nmi.oakes@gmail.com>
peytondmurray pushed a commit to peytondmurray/ray that referenced this pull request Mar 22, 2023
Adds a collate_fn argument to iter_batches and iter_torch_batches. This is useful for any last-mile preprocessing that is applied directly to the batch before it is used for training.

Closes ray-project#32224.

Signed-off-by: amogkam <amogkamsetty@yahoo.com>
Signed-off-by: Balaji Veeramani <balaji@anyscale.com>
Co-authored-by: Balaji Veeramani <balaji@anyscale.com>
elliottower pushed a commit to elliottower/ray that referenced this pull request Apr 22, 2023
Adds a collate_fn argument to iter_batches and iter_torch_batches. This is useful for any last-mile preprocessing that is applied directly to the batch before it is used for training.

Closes ray-project#32224.

Signed-off-by: amogkam <amogkamsetty@yahoo.com>
Signed-off-by: Balaji Veeramani <balaji@anyscale.com>
Co-authored-by: Balaji Veeramani <balaji@anyscale.com>
Signed-off-by: elliottower <elliot@elliottower.com>
Successfully merging this pull request may close these issues.

[Ray Data] Introduce collate_fn argument in Dataset.iter_torch_batches
9 participants