[AIR/Data] Add collate_fn to iter_torch_batches #32412
Conversation
LGTM
python/ray/data/dataset.py (Outdated)

@@ -2858,6 +2859,7 @@ def iter_batches(
             to select ``numpy.ndarray`` for tensor datasets and
             ``Dict[str, numpy.ndarray]`` for tabular datasets. Default is "default".
         drop_last: Whether to drop the last batch if it's incomplete.
+        collate_fn: A function to apply to each data batch before returning it.
Hmm, I don't know if it makes sense to have this as part of the `ds.iter_batches()` API, since the primary motivation behind this feature is to (1) make the Torch DataLoader port easier, and (2) provide a hook to override how we convert NumPy ndarrays to Torch Tensors. `collate_fn` doesn't provide much critical value for the core `ds.iter_batches()` API: it's functionally exactly equivalent to

```python
for batch in ds.iter_batches():
    batch = collate_fn(batch)
```

Is the primary motivation for this to be able to include it in the local pipelining operation chain, to improve collation performance? If that's the case, I'm wondering whether `ds.iter_torch_batches()` should reuse `ds.iter_batches()` at all; instead, we could have `ds.iter_torch_batches()` use the internal `batch_block_refs` API directly, which can expose `collate_fn`. Most of the batching logic is captured in `batch_block_refs` anyway.
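For concreteness, a minimal sketch of the conversion-override hook this thread is about; the dataset contents and the `collate_fn` body are illustrative, not taken from the PR:

```python
import numpy as np
import ray
import torch

# Illustrative tabular dataset; its batches arrive as Dict[str, np.ndarray].
ds = ray.data.from_items([{"x": i, "y": 2 * i} for i in range(8)])

def collate_fn(batch):
    # Override the default NumPy -> Torch conversion: stack both columns
    # into a single float32 tensor instead of one tensor per column.
    return torch.as_tensor(
        np.stack([batch["x"], batch["y"]], axis=1), dtype=torch.float32
    )

for batch in ds.iter_torch_batches(batch_size=4, collate_fn=collate_fn):
    print(batch.shape)  # torch.Size([4, 2])
```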
Yeah I was also thinking the same...will make the change.
Easiest thing is just to have a private arg in `iter_batches`. IMO that would be preferable for future maintenance to having both APIs directly use `batch_block_refs`.
Updated to make `_collate_fn` private for `iter_batches`.
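For illustration, a toy sketch of the private-arg pattern being discussed; the classes and helpers are simplified stand-ins, not Ray's actual implementation:

```python
from typing import Callable, Iterator, List, Optional

import torch

class ToyDataset:
    """Toy stand-in for Dataset, just to show the forwarding pattern."""

    def __init__(self, rows: List[int]):
        self._rows = rows

    def iter_batches(
        self,
        batch_size: int = 2,
        _collate_fn: Optional[Callable] = None,  # private, undocumented arg
    ) -> Iterator:
        for i in range(0, len(self._rows), batch_size):
            batch = self._rows[i : i + batch_size]
            yield _collate_fn(batch) if _collate_fn else batch

    def iter_torch_batches(
        self, batch_size: int = 2, collate_fn: Optional[Callable] = None
    ) -> Iterator:
        # The public Torch API forwards its collate_fn through the private
        # arg, so the documented iter_batches() signature stays narrow.
        yield from self.iter_batches(
            batch_size=batch_size, _collate_fn=collate_fn or torch.tensor
        )

for batch in ToyDataset(list(range(6))).iter_torch_batches():
    print(batch)  # tensor([0, 1]), then tensor([2, 3]), then tensor([4, 5])
```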
Hmm, I'd rather not muddy the user-facing `.iter_batches()` API even with a private `_collate_fn` arg, if possible, since it's already a pretty wide interface. And since `batch_block_refs` contains most of the batching logic, I don't think that there's much maintenance or drift risk here.

What are your primary concerns for maintenance here?
It's mostly about the timing logic, but it's not a big deal. Updated to use `batch_block_refs` directly.
If stuff like execution triggering and stats collection becomes involved enough around `batch_block_refs`, we can always introduce a private `Dataset._iter_batches()` method that can support all of the other batch-based iterator APIs without making the user-facing `Dataset.iter_batches()` API any wider.

But for the sake of this PR, duplicating that thin logic around `batch_block_refs` seems fine to me!
This actually doesn't work for the pipelined case... for some reason that calls into `Dataset.iter_torch_batches`.
DatasetPipeline makes an implicit assumption that `Dataset.iter_torch_batches` calls into `self.iter_batches`, which no longer holds true if we change this implementation.
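To spell that assumption out, a toy sketch; the class structure here is heavily simplified for illustration and is not Ray's actual code:

```python
from typing import Iterator, List

class ToyDataset:
    def __init__(self, rows: List[int]):
        self._rows = rows

    def iter_batches(self, batch_size: int = 2) -> Iterator:
        for i in range(0, len(self._rows), batch_size):
            yield self._rows[i : i + batch_size]

    def iter_torch_batches(self, batch_size: int = 2) -> Iterator:
        # The delegation the pipeline silently depends on:
        for batch in self.iter_batches(batch_size):
            yield [float(x) for x in batch]  # stand-in for tensor conversion

class ToyPipeline(ToyDataset):
    """Overrides only iter_batches() to chain windows; it inherits
    iter_torch_batches() and assumes it routes through iter_batches()."""

    def __init__(self, windows: List[ToyDataset]):
        self._windows = windows

    def iter_batches(self, batch_size: int = 2) -> Iterator:
        for window in self._windows:
            yield from window.iter_batches(batch_size)

pipe = ToyPipeline([ToyDataset([1, 2]), ToyDataset([3, 4])])
print(list(pipe.iter_torch_batches()))  # [[1.0, 2.0], [3.0, 4.0]]
# If iter_torch_batches() instead bypassed iter_batches() and hit an
# internal batching API directly, ToyPipeline's window-chaining override
# would be skipped -- exactly the pipelined-case breakage described above.
```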
I think the private arg is the way to go, because the abstractions here are very messy. We can remove this later once we deprecate `DatasetPipeline`.
LGTM, nice work on this one!
Why are these changes needed?

Adds a `collate_fn` argument to `iter_batches` and `iter_torch_batches`. This is useful for any last-mile preprocessing that's done directly on the batch to be used for training.
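A hedged usage sketch of that last-mile case; the dataset and the normalization statistics below are made up for illustration:

```python
import ray
import torch

ds = ray.data.from_items(
    [{"feature": float(i), "label": i % 2} for i in range(96)]
)

# Made-up dataset statistics for the illustration.
FEATURE_MEAN, FEATURE_STD = 47.5, 27.7

def collate_fn(batch):
    # Last-mile preprocessing applied to each batch right before training:
    # normalize the feature column and emit (features, labels) tensors.
    features = (batch["feature"] - FEATURE_MEAN) / FEATURE_STD
    return (
        torch.as_tensor(features, dtype=torch.float32).unsqueeze(1),
        torch.as_tensor(batch["label"], dtype=torch.long),
    )

for features, labels in ds.iter_torch_batches(batch_size=32, collate_fn=collate_fn):
    print(features.shape, labels.shape)  # torch.Size([32, 1]) torch.Size([32])
```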
Related issue number

Closes #32224

Checks

- I've signed off every commit (by using the -s flag, i.e., `git commit -s`) in this PR.
- I've run `scripts/format.sh` to lint the changes in this PR.