
[Profiler] Unify the device (CUDA, XPU, PrivateUse1) in torch profiler post processing #123247

Conversation

@zejun-chen (Contributor) commented Apr 3, 2024

This PR unifies CUDA, XPU, and PrivateUse1 in the torch profiler. CUDA, XPU, and PrivateUse1 now use the string attribute `use_device` to distinguish one another and share a single device path for computing Kineto time durations and memory statistics during post processing.

#suppress-api-compatibility-check
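For readers following along, here is a minimal sketch of how the unified accessors could be consumed, assuming the post-PR `profile(use_device=...)` signature and the renamed `device_time_total`/`self_device_time_total` event properties (both taken from this PR's description, not as settled API documentation):

```python
import torch
from torch.autograd import profiler

# One profiling path regardless of backend; use_device may be "cuda", "xpu",
# "privateuseone", or None for CPU-only profiling.
use_device = "cuda" if torch.cuda.is_available() else None

x = torch.randn(128, 128)
with profiler.profile(use_device=use_device) as prof:
    y = x @ x

for evt in prof.key_averages():
    # The same unified properties now cover CUDA, XPU, and PrivateUse1.
    print(evt.key, evt.self_device_time_total, evt.device_time_total)
```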

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k @ezyang @gchanan @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @rohan-varma @aakhundov

…processing

Use use_device to identify the required device

Signed-off-by: Chen, Zejun <zejun.chen@intel.com>
@zejun-chen requested a review from aaronenyeshi as a code owner on April 3, 2024 09:25
pytorch-bot (bot) commented Apr 3, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/123247

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit 2509aa8 with merge base 58e403c:

UNSTABLE - The following jobs failed, but the failures were likely due to flakiness present on trunk and have been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@zejun-chen (Contributor, Author) commented Apr 3, 2024

Hi @ezyang @aaronenyeshi @valentinandrei

Could you help review this PR?
It unifies the post-processing functions for the different devices in the torch profiler and enables XPU there.
Any comments are welcome.

Thank you :)

@bdhirsh added the triaged label Apr 3, 2024
Signed-off-by: Chen, Zejun <zejun.chen@intel.com>
@aaronenyeshi (Member) left a comment

Overall, this looks great! Thank you for the change to consolidate the devices; it improves code readability and maintainability going forward.

A few comments:

  1. Could we keep use_cuda as an argument for a while? We need to mark it for deprecation in case anyone is currently using it (a possible shim is sketched after this list).
  2. I wonder why self.use_device is None for privateuseone. Why not support all valid types: cuda, xpu, privateuseone? cc @fwenguang, @NmomoN
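Regarding the first point, a purely illustrative sketch of what such a deprecation path could look like (this is not the code in this PR; the helper name and warning text are hypothetical):

```python
import warnings

def _resolve_use_device(use_cuda=None, use_device=None):
    # Hypothetical shim: keep accepting the legacy use_cuda flag, translate it
    # to the unified use_device string, and warn that it will be removed.
    if use_cuda is not None:
        warnings.warn(
            "The `use_cuda` argument is deprecated; pass use_device='cuda' instead.",
            FutureWarning,
        )
        return "cuda" if use_cuda else use_device
    return use_device
```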

Review threads on torch/autograd/profiler.py (resolved; some marked outdated)
@briancoutinho (Contributor) left a comment

Thank you! This significantly improves the readability of this code, and I appreciate the effort that went into refactoring it. Overall it looks good to me. A few minor comments and we can approve.

Review threads on torch/autograd/profiler.py, torch/utils/bottleneck/__main__.py, and torch/autograd/profiler_util.py (resolved; some marked outdated)
Signed-off-by: Chen, Zejun <zejun.chen@intel.com>
Signed-off-by: Chen, Zejun <zejun.chen@intel.com>
pytorch-bot (bot) added the module: inductor and oncall: distributed labels Apr 10, 2024
Signed-off-by: Chen, Zejun <zejun.chen@intel.com>
Signed-off-by: Chen, Zejun <zejun.chen@intel.com>
Signed-off-by: Chen, Zejun <zejun.chen@intel.com>
@aaronenyeshi (Member) left a comment

This looks great! Just a few small changes and it should be good to ship. Please check the lint failures too.

Review threads on torch/autograd/profiler.py (resolved; some marked outdated)
@aaronenyeshi removed the oncall: distributed and module: inductor labels Apr 10, 2024
pytorch-bot (bot) added the module: inductor and oncall: distributed labels Apr 10, 2024
Signed-off-by: Chen, Zejun <zejun.chen@intel.com>
@zejun-chen (Contributor, Author) commented Apr 20, 2024

It looks like we're on the final step; please take a look at this failing test:

2024-04-19T10:51:41.8146877Z FAILED [4.4162s] distributed/rpc/cuda/test_tensorpipe_agent.py::TensorPipeCudaRpcTest::test_profiler_remote_cuda - RuntimeError: Process 1 exited with error code 10 and exception:
2024-04-19T10:51:41.8147068Z Traceback (most recent call last):
2024-04-19T10:51:41.8148024Z   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/common_distributed.py", line 653, in run_test
2024-04-19T10:51:41.8148351Z     getattr(self, test_name)()
2024-04-19T10:51:41.8149282Z   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/common_distributed.py", line 539, in wrapper
2024-04-19T10:51:41.8149437Z     fn()
2024-04-19T10:51:41.8150314Z   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/common_utils.py", line 2755, in wrapper
2024-04-19T10:51:41.8150493Z     method(*args, **kwargs)
2024-04-19T10:51:41.8151519Z   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/common_distributed.py", line 178, in wrapper
2024-04-19T10:51:41.8151705Z     return func(*args, **kwargs)
2024-04-19T10:51:41.8152625Z   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/dist_utils.py", line 81, in new_test_method
2024-04-19T10:51:41.8152897Z     return_value = old_test_method(self, *arg, **kwargs)
2024-04-19T10:51:41.8154009Z   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/distributed/rpc/rpc_test.py", line 4617, in test_profiler_remote_cuda
2024-04-19T10:51:41.8154267Z     self.assertGreater(event.device_time_total, 0)
2024-04-19T10:51:41.8154798Z   File "/opt/conda/envs/py_3.10/lib/python3.10/unittest/case.py", line 1244, in assertGreater
2024-04-19T10:51:41.8155075Z     self.fail(self._formatMessage(msg, standardMsg))
2024-04-19T10:51:41.8155533Z   File "/opt/conda/envs/py_3.10/lib/python3.10/unittest/case.py", line 675, in fail
2024-04-19T10:51:41.8155715Z     raise self.failureException(msg)
2024-04-19T10:51:41.8155917Z AssertionError: 0 not greater than 0
2024-04-19T10:51:41.8155938Z 
2024-04-19T10:51:41.8156276Z To execute this test, run the following from the base repo dir:
2024-04-19T10:51:41.8156879Z      python test/distributed/rpc/cuda/test_tensorpipe_agent.py -k test_profiler_remote_cuda

I've merged the stable branch, so let's see if the failure reproduces. Dr. CI said it was unrelated, but let's see.

Hi, @aaronenyeshi @DanilBaibak

Thank you for the help. It is indeed a bug from my code change here:
https://github.com/pytorch/pytorch/pull/123247/files#diff-25291b895a9131f22bf0d5c77ad3403629990bbe9c410256f571f26ce126f461R561
I merged the two functions cuda_time_total and privateuse1_time_total into a single function, device_time_total. The use_device field should be set (rather than left as None) when creating FunctionEvent in torch.autograd.profiler_legacy.profile(); otherwise the conditional at the link above is always True, which is why CI hits the error self.assertGreater(event.device_time_total, 0).

Suggested code change:
https://github.com/pytorch/pytorch/blob/main/torch/autograd/profiler_legacy.py#L255
```
+ use_device = 'cuda' if start.has_cuda() else None,
```
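To illustrate the failure mode (a simplified sketch, not the real FunctionEvent implementation): with use_device left as None by the legacy profiler, the unified accessor short-circuits to 0, which is exactly what the RPC test's assertion catches.

```python
class _EventSketch:
    """Toy stand-in for FunctionEvent, used only to illustrate the guard."""

    def __init__(self, use_device=None, device_time_us=0.0):
        self.use_device = use_device          # "cuda", "xpu", "privateuseone", or None
        self._device_time_us = device_time_us

    @property
    def device_time_total(self):
        if self.use_device is None:           # the always-taken branch described above
            return 0.0
        return self._device_time_us

assert _EventSketch(None, 42.0).device_time_total == 0.0        # the test sees 0
assert _EventSketch("cuda", 42.0).device_time_total == 42.0     # fixed behaviour
```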

Thank you.

pytorch-bot bot pushed a commit that referenced this pull request Apr 22, 2024
… post processing (#123247)

This PR unifies the CUDA, XPU and PrivateUse1 in the torch profiler. Now CUDA, XPU and PrivateUse1 can together use string object `use_device` to distinguish each other and share one device path for calculating kineto time durations and memory statistics for post processing.

#suppress-api-compatibility-check

Co-authored-by: Aaron Enye Shi <enye.shi@gmail.com>
Pull Request resolved: #123247
Approved by: https://github.com/aaronenyeshi, https://github.com/gujinghui
sanketpurandare pushed a commit to sanketpurandare/pytorch that referenced this pull request Apr 22, 2024
…profiler post processing (pytorch#123247)"

This reverts commit 768ce2c.

Reverted pytorch#123247 on behalf of https://github.com/DanilBaibak due to Broken trunk ([comment](pytorch#123247 (comment)))
@zejun-chen (Contributor, Author) commented:

@pytorchbot merge

@pytorchmergebot (Collaborator) commented:
Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).


xwang233 added a commit to NVIDIA/Fuser that referenced this pull request Apr 24, 2024
Re: pytorch/pytorch#123247

This PR changed `self_cuda_time_total` to `self_device_time_total` in API calls.
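For downstream projects that need to run against PyTorch builds from both before and after the rename, a small compatibility helper along these lines may be enough (hypothetical helper, not part of either project):

```python
def get_self_device_time_total(event):
    # Prefer the unified attribute introduced by pytorch/pytorch#123247 and
    # fall back to the older CUDA-specific name on pre-rename builds.
    return getattr(
        event,
        "self_device_time_total",
        getattr(event, "self_cuda_time_total", 0),
    )
```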
@kit1980 removed the Reverted label Apr 29, 2024
petrex pushed a commit to petrex/pytorch that referenced this pull request May 3, 2024
… post processing (pytorch#123247)

This PR unifies the CUDA, XPU and PrivateUse1 in the torch profiler. Now CUDA, XPU and PrivateUse1 can together use string object `use_device` to distinguish each other and share one device path for calculating kineto time durations and memory statistics for post processing.

#suppress-api-compatibility-check

Co-authored-by: Aaron Enye Shi <enye.shi@gmail.com>
Pull Request resolved: pytorch#123247
Approved by: https://github.com/aaronenyeshi, https://github.com/gujinghui
pytorch-bot bot pushed a commit that referenced this pull request May 3, 2024
…profiler post processing (#123247)"

This reverts commit 768ce2c.

Reverted #123247 on behalf of https://github.com/DanilBaibak due to Broken trunk ([comment](#123247 (comment)))
pytorch-bot bot pushed a commit that referenced this pull request May 3, 2024
… post processing (#123247)

This PR unifies the CUDA, XPU and PrivateUse1 in the torch profiler. Now CUDA, XPU and PrivateUse1 can together use string object `use_device` to distinguish each other and share one device path for calculating kineto time durations and memory statistics for post processing.

#suppress-api-compatibility-check

Co-authored-by: Aaron Enye Shi <enye.shi@gmail.com>
Pull Request resolved: #123247
Approved by: https://github.com/aaronenyeshi
```python
if self.use_device and self.use_device != _get_privateuse1_backend_name():
    warn(f"{self.use_device} doesn't support profile.")
VALID_DEVICE_OPTIONS = ["cuda", "xpu", "privateuseone"]
if self.use_device not in VALID_DEVICE_OPTIONS:
```
A contributor commented:
This results in an annoying `UserWarning: The None is not a valid device option.` warning whenever the CPU profiler is used :( A PR is coming.

@zejun-chen (Contributor, Author) replied:

Thank you @malfet

malfet added a commit that referenced this pull request May 7, 2024
This fixes a logic regression introduced by #123247 where 
```python
if self.use_device and self.use_device != _get_privateuse1_backend_name():
``` 
was replaced with
```python
        VALID_DEVICE_OPTIONS = ["cuda", "xpu", "privateuseone"]
        if self.use_device not in VALID_DEVICE_OPTIONS:
```

That triggers a warning every time the code is invoked with `self.use_device` set to None.
@malfet added the module: bc-breaking label May 7, 2024
@malfet (Contributor) commented May 7, 2024

Hmm, why did this PR add #suppress-api-compatibility-check?

@zejun-chen (Contributor, Author) replied:

> Hmm, why did this PR add #suppress-api-compatibility-check?

In this PR, we removed cuda_time from FormattedTimesMixin, and the lint check then failed because it checks for the deleted functions. Here is the message we got previously, which is why we added #suppress-api-compatibility-check:

2024-04-11T18:09:20.0642879Z ##[warning]Function FormattedTimesMixin.cuda_time: function deleted
2024-04-11T18:09:20.0645300Z ##[warning]Function FormattedTimesMixin.privateuse1_time: function deleted
2024-04-11T18:09:20.0646984Z ##[warning]Function FunctionEvent.self_cuda_memory_usage: function deleted
2024-04-11T18:09:20.0648547Z ##[warning]Function FunctionEvent.self_privateuse1_memory_usage: function deleted
2024-04-11T18:09:20.0649998Z ##[warning]Function FunctionEvent.cuda_time_total: function deleted
2024-04-11T18:09:20.0651407Z ##[warning]Function FunctionEvent.self_cuda_time_total: function deleted
2024-04-11T18:09:20.0652911Z ##[warning]Function FunctionEvent.self_privateuse1_time_total: function deleted
2024-04-11T18:09:20.0654335Z ##[warning]Function FunctionEvent.privateuse1_time_total: function deleted
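For reference, the old-to-new mapping implied by those deletions looks roughly like the sketch below (inferred from this PR's description; the entries, in particular the memory-usage ones, are an assumption and should be checked against the actual diff):

```python
# Assumed mapping of removed accessors to their unified replacements.
RENAMED_PROFILER_ACCESSORS = {
    "cuda_time": "device_time",
    "privateuse1_time": "device_time",
    "cuda_time_total": "device_time_total",
    "privateuse1_time_total": "device_time_total",
    "self_cuda_time_total": "self_device_time_total",
    "self_privateuse1_time_total": "self_device_time_total",
    "self_cuda_memory_usage": "self_device_memory_usage",
    "self_privateuse1_memory_usage": "self_device_memory_usage",
}
```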

pytorchmergebot pushed a commit that referenced this pull request May 7, 2024
This fixes a logic regression introduced by #123247 where
```python
if self.use_device and self.use_device != _get_privateuse1_backend_name():
```
was replaced with
```python
        VALID_DEVICE_OPTIONS = ["cuda", "xpu", "privateuseone"]
        if self.use_device not in VALID_DEVICE_OPTIONS:
```

That triggers a warning every time the code is invoked with `self.use_device` set to None.

This change also skips all the checks, which are useless if `use_device` is None to begin with.
Pull Request resolved: #125654
Approved by: https://github.com/aaronenyeshi
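A minimal standalone sketch of the corrected behaviour described above (the helper name is hypothetical; this is not the literal #125654 diff): validation only runs when use_device is actually set.

```python
import warnings

VALID_DEVICE_OPTIONS = ["cuda", "xpu", "privateuseone"]

def _check_use_device(use_device):
    # CPU-only profiling (use_device=None) skips validation entirely instead of
    # emitting a spurious "The None is not a valid device option." warning.
    if use_device is None:
        return
    if use_device not in VALID_DEVICE_OPTIONS:
        warnings.warn(f"The {use_device} is not a valid device option.")

_check_use_device(None)    # no warning for the plain CPU profiler
_check_use_device("xpu")   # valid device, no warning
```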
Labels: ciflow/inductor, ciflow/rocm, ciflow/trunk, Merged, module: bc-breaking, module: inductor, oncall: distributed, open source, release notes: profiler, suppress-api-compatibility-check, suppress-bc-linter, topic: improvements, triaged