
Add 0-dim Tensor overload to _foreach_mul #106677

Closed · 9 commits

Conversation

@crcrpar (Collaborator) commented Aug 6, 2023

@pytorch-bot bot commented Aug 6, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/106677

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 355b464:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: foreach_frontend release notes category label Aug 6, 2023
@crcrpar (Collaborator, Author) commented Aug 6, 2023

@pytorchbot label "module: mta"

@pytorch-bot pytorch-bot bot added the module: mta Issues related to multi-tensor apply kernels and foreach functions label Aug 6, 2023
@albanD albanD requested a review from janeyx99 August 6, 2023 23:22
@albanD albanD added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Aug 6, 2023
@janeyx99 (Contributor) left a comment

Thanks @crcrpar for the speedy PR! :D

Looks good overall; I have some minor comments on error messages and testing. I also have a broader question: why does this overload only handle scalar (0-dim) tensors rather than any tensor that could be broadcast against an element of the tensorlist? What is holding us back from adding support for the general case?
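
As a rough illustration of the distinction being discussed (a hypothetical snippet, not taken from the PR): the overload added here handles a single 0-dim tensor that scales every element of the tensorlist, while a tensor that would need to broadcast against each element is not covered.

```python
import torch

params = [torch.randn(4, 4, device="cuda") for _ in range(3)]

# Covered by this PR: one 0-dim tensor multiplied into every element of the list,
# without calling .item() on it (so no host/device sync is forced).
scale = torch.tensor(0.5, device="cuda")
torch._foreach_mul_(params, scale)

# Not covered: a tensor that would have to broadcast against each list element;
# its semantics are the ambiguity discussed below.
# row_scale = torch.randn(4, 1, device="cuda")
# torch._foreach_mul_(params, row_scale)
```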

(Resolved review comments on aten/src/ATen/native/ForeachOpsKernels.cpp and aten/src/ATen/native/cuda/ForeachBinaryOpScalarTensor.cu.)
#define FOREACH_BINARY_OP_SCALAR_TENSOR(FUNCTION, NAME, OP, DIVISION_OP) \
void foreach_tensor_##NAME##_tensor_kernel_cuda_( \
TensorList tensors, const Tensor& scalar) { \
check_foreach_api_restrictions(tensors); \
Contributor commented:
Why do we have to keep making this check even in the functions we dispatch into? Like in foreach_tensor_##OP##tensor_kernel_slow?

@crcrpar (Collaborator, Author) replied Aug 7, 2023:
We need this check in the ones defined in ForeachOpsKernels.cpp, since CPU TensorList inputs don't come here. Does that answer your question?

Contributor replied:
Yeah, I realized that. However, could we defer this check to the slow path (since that will check this anyway)? Or is the runtime of check_foreach_api_restrictions negligible? This doesn't need to happen in this PR; I'm just curious why the precedent was set this way.

@crcrpar (Collaborator, Author) replied:
Hmm, I'm not sure about the benefit of that, because we call this check in the fast path at some point anyway, and all it does is basically compare the sizes of the tensorlists and scalarlists.
I once thought of merging those two checks into one for the fast path while keeping the first check for the slow path.
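
For readers unfamiliar with the helper under discussion: conceptually it is a cheap length check. A rough Python rendering of its intent (the real helper is a C++ function, so this is only an approximation):

```python
def check_foreach_api_restrictions(tensors, scalars=None):
    # Approximate, Python-level rendering of what the C++ helper verifies:
    # the tensor list must be non-empty, and when a scalar list is supplied
    # it must have the same length as the tensor list.
    if len(tensors) == 0:
        raise ValueError("Tensor list must have at least one tensor.")
    if scalars is not None and len(scalars) != len(tensors):
        raise ValueError("Tensor list must have same number of elements as scalar list.")
```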

(Resolved review comment on test/test_foreach.py.)
@bdhirsh (Contributor) commented Aug 7, 2023

cc @janeyx99 there was some discussion here, where it seems like defining the behavior of "what does the user expect when the other tensor has multiple elements" is a bit ambiguous. Since the most concrete benefit of this overload today is preventing an H2D sync in clip_grad_norm, where we know that the second argument is a 0-dim scalar tensor, we tentatively agreed to put off resolving that ambiguity until there's a clearer use case for it.
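
To make the H2D-sync point concrete, a hypothetical gradient-clipping sketch (not the actual torch.nn.utils.clip_grad_norm_ implementation): keeping the clip coefficient as a 0-dim device tensor and passing it straight to _foreach_mul_ avoids the .item() call that would otherwise force a host/device synchronization.

```python
import torch

def clip_grads_(grads, max_norm):
    # Total norm computed on device; no value ever leaves the GPU.
    total_norm = torch.linalg.vector_norm(
        torch.stack([torch.linalg.vector_norm(g) for g in grads])
    )
    clip_coef = (max_norm / (total_norm + 1e-6)).clamp(max=1.0)  # 0-dim tensor
    # The new Tensor overload lets the multi-tensor kernel read clip_coef directly.
    torch._foreach_mul_(grads, clip_coef)
```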

crcrpar and others added 3 commits August 8, 2023 02:19
Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
@janeyx99 (Contributor) commented Aug 7, 2023

CI failures look to just be error-message string matching; should be straightforward to fix.

@janeyx99 (Contributor) left a comment

Looks good conditioned on green CI. Thanks for the change!

@crcrpar (Collaborator, Author) commented Aug 8, 2023

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Aug 8, 2023
@pytorchmergebot (Collaborator) commented:
Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@crcrpar crcrpar deleted the foreach_mul_tensor_overload branch August 8, 2023 04:45
Cyril-Anto pushed a commit to Cyril-Anto/pytorch that referenced this pull request Aug 17, 2023
janeyx99 added a commit that referenced this pull request Nov 15, 2023
pytorchmergebot pushed a commit that referenced this pull request Nov 15, 2023
This PR mostly follows the steps from #106677, with one added feature. Similar to fused_adam(w), for the CUDA dispatches: when the scalar tensor is on CPU, we call .item() and redispatch to the normal Scalar overload. Otherwise, the CUDA kernel would complain about a device mismatch between the scalar tensor and the tensors.

Why add this feature? Our optimizers want to allow lr as a tensor, and that lr could be a CPU tensor. lr is used with foreach_div_ in Adam, so our CI would break otherwise.

After this PR, `_foreach_mul` and `_foreach_div` will accept either a CPU or a GPU tensor for the scalar tensor (vs only a GPU tensor). They join the ranks of `fused_adam(w)` in this characteristic. I did not yet do the same for foreach_add (the only other foreach op with a .Tensor overload) because there is no use case and it would be more involved.

Pull Request resolved: #113688
Approved by: https://github.com/mlazos, https://github.com/albanD
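
As a rough usage sketch of the behavior the follow-up describes (illustrative only, not code from either PR): after #113688 the scalar tensor may live on the CPU, in which case the CUDA dispatch reads it via .item() and redispatches to the plain Scalar overload instead of raising a device-mismatch error.

```python
import torch

params = [torch.randn(8, device="cuda") for _ in range(4)]

# 0-dim scalar tensor on the same device: handled by the multi-tensor CUDA kernel.
lr_gpu = torch.tensor(0.01, device="cuda")
torch._foreach_div_(params, lr_gpu)

# 0-dim scalar tensor on CPU (e.g. a tensor lr in an optimizer): its value is read
# with .item() and the Scalar overload is used, rather than erroring on the mismatch.
lr_cpu = torch.tensor(0.01)
torch._foreach_div_(params, lr_cpu)
```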
Labels
ciflow/trunk (Trigger trunk jobs on your pull request) · Merged · module: mta (Issues related to multi-tensor apply kernels and foreach functions) · open source · release notes: foreach_frontend (release notes category) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)