[dtensor][1/n] refactor op dispatch logic to reduce overhead #107305
Conversation
This PR is the first of a series of refactors to the op dispatch logic to:
1. Remove redundant logic in op dispatch and simplify the error checking.
2. Reduce the number of tree_map/tree_flatten/unflatten calls, cutting the overhead coming from those operations.
3. Remove the CachedShardingPropagator by using lru_cache from functools directly.
4. Stop changing the view ops' op_schema in place (which is dangerous for sharding-prop caching) and model view ops as one type of resharding too.
5. Enrich output sharding to record whether the op needs a redistribute, so we don't need an explicit op-schema comparison to know it.

This should help with further reducing the CPU overhead; benchmark results coming soon. Refactoring the dispatch logic could also make more features possible later (e.g. adding an IrregularShard placement more easily).

[ghstack-poisoned]
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/107305
Note: Links to docs will display an error until the docs builds have been completed.
❌ 4 New Failures, 1 Unrelated Failure as of commit dc70183 with merge base 5b9b816:
NEW FAILURES - The following jobs have failed:
FLAKY - The following job failed but was likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Nice work!! I have some questions around naming and cache-miss stats; also, is it possible to change how much we cache in the lru_cache?
tree_unflatten(flat_args_schema, args_spec),
tree_unflatten(flat_kwargs_schema, kwargs_spec),
Is there any way we can avoid this unflatten? To me this seems unnecessary.
Yeah, I realized that, so I added a TODO here: https://github.com/pytorch/pytorch/pull/107305/files#diff-8aca68cfab443c93335bbe5e6a1c3c3cb34df117fc08a2330b7966752a049b47R81
We could keep the op schema flattened, but we would need to change all of our ops to behave like this first; I'll do more refactoring later to gradually move the op registration to the flattened schema.
I see, makes sense. If tree_map is unavoidable, ideally we only want to do it once.
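For context, here is a minimal sketch of the "flatten once, reuse the spec" idea using the same torch.utils._pytree helpers as the snippet above; the function and its structure are illustrative, not the actual dispatch code:

```python
from torch.utils._pytree import tree_flatten, tree_unflatten

def dispatch_sketch(op, args, kwargs):
    # Flatten args/kwargs a single time and keep the specs around, so any
    # later reconstruction reuses them instead of re-walking the pytree.
    flat_args, args_spec = tree_flatten(args)
    flat_kwargs, kwargs_spec = tree_flatten(kwargs)

    # ... sharding propagation / redistribution operates on the flat lists ...

    # Unflatten only once, right before invoking the op.
    return op(*tree_unflatten(flat_args, args_spec),
              **tree_unflatten(flat_kwargs, kwargs_spec))
```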
# compute locally with redistribute first if needed
assert output_sharding.schema_suggestions is not None
suggested_input_schema = output_sharding.schema_suggestions[0]
redistribute_local_args(op_info, suggested_input_schema)
This name is a little bit confusing to me, because to me the concept "local" suggests a plain Tensor rather than a DTensor, while redistribute is something unique to DTensor. Would it be possible to just call it `redistribute_args` or `redistribute_n_update_local_args`? I understand the logic here is to first redistribute the DTensor and then update local_args so that we pass the correct args to the final ATen op, but the naming seems to suggest we are doing a redistribute directly on top of a Tensor.
If you look at the refactor I did to this function (https://github.com/pytorch/pytorch/pull/107305/files#diff-8aca68cfab443c93335bbe5e6a1c3c3cb34df117fc08a2330b7966752a049b47R75), it is indeed directly redistributing the local tensors: when redistributing, we just have a local tensor plus the DTensor spec we want to redistribute to, so we don't need to make redistribute work on the DTensor wrapper. That's why I renamed this function to `redistribute_local_args`. Let me know if this makes sense or not.
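To illustrate the explanation above, a hedged sketch of what `redistribute_local_args` does conceptually; the attributes used here (`flat_local_args`, `flat_args_schema`, `args_schema`) are assumed names for illustration, not necessarily the exact fields in the PR:

```python
# Module path as of this PR; redistribute_local_tensor is the renamed
# function shown in the redistribute.py diff further below.
from torch.distributed._tensor.redistribute import redistribute_local_tensor

def redistribute_local_args(op_info, suggested_input_schema):
    # Walk the flattened local args; wherever the current DTensorSpec differs
    # from the spec suggested by sharding propagation, redistribute the plain
    # local torch.Tensor directly (no DTensor wrapper involved).
    for i, (local_arg, current_spec, target_spec) in enumerate(
        zip(
            op_info.flat_local_args,
            op_info.flat_args_schema,
            suggested_input_schema.args_schema,
        )
    ):
        if current_spec != target_spec:
            op_info.flat_local_args[i] = redistribute_local_tensor(
                local_arg, current_spec, target_spec
            )
```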
After reading your comment below and on second thought, I think this name is fine.
@@ -72,19 +73,25 @@ def _decompose_reshard(val: List[_PlacementItem]) -> List[_PlacementItem]:

 # Intentionally expose this API to trace ops on local tensors
-def _redistribute_with_local_tensor(
+def redistribute_local_tensor(
Ditto, like I mentioned in `dispatch.py`: maybe the old name `redistribute_with_local_tensor` or `redistribute_with_local_tensor_updated` is better? Because the name here seems to suggest that we directly redistribute a Tensor, which is somewhat confusing.
Hmm, this function/API does indeed directly redistribute a torch.Tensor (i.e. the first arg is the local_tensor) given the src/dst DTensor specs. Maybe we should add more comments on the API to clarify?
OK, so essentially there is indeed no DTensor involved; we just call collectives directly (and without autograd). I just think "redistributing a tensor" is a little confusing, but since we redistribute the local tensor of a DTensor, this makes sense to me. If we can add some comments here, that would be much appreciated.
Yep, will add in the follow-up PR.
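For illustration, the kind of clarifying comment being discussed could look like the following hypothetical docstring on a simplified signature (the actual wording and parameter names are up to the follow-up PR):

```python
def redistribute_local_tensor(local_tensor, current_spec, target_spec):
    """Redistribute the local shard of a DTensor, given as a plain torch.Tensor.

    No DTensor wrapper is involved: given the source and destination
    DTensorSpec, this calls the necessary collectives directly (and without
    autograd) and returns the new local tensor for the target placement.
    """
    ...
```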
from torch.distributed._tensor.api import DTensor


def get_sharding_prop_cache_info():
Thanks for adding this. It would be nice to add a comment or a reference on Python's LRU cache so that users know what information they can get from this API. (This can be done in a follow-up PR as BE work.)
Yeah sure, I'll add comments in a follow-up PR :)
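As a rough sketch of how the lru_cache-based caching and this cache-info helper fit together (class and method names are illustrative, not the exact DTensor internals); it also shows where the cache size could be tuned, per the earlier question:

```python
from functools import lru_cache

class ShardingPropagator:
    # maxsize controls how much is cached; None means unbounded.
    @lru_cache(maxsize=None)
    def propagate_op_sharding(self, op_schema):
        # op_schema must be hashable for lru_cache to key on it.
        ...  # compute and return the output sharding for this op schema

def get_sharding_prop_cache_info():
    # functools.lru_cache exposes CacheInfo(hits, misses, maxsize, currsize).
    return ShardingPropagator.propagate_op_sharding.cache_info()
```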
LGTM and thanks for the refactoring.
@pytorchbot merge -i
Merge started. Your change will be merged while ignoring the following 5 checks: pull / linux-jammy-py3.9-clang12-asan / test (default, 1, 6, linux.4xlarge), periodic / parallelnative-linux-jammy-py3.8-gcc11 / test (default, 1, 3, linux.2xlarge), periodic / linux-focal-cuda11.8-py3.10-gcc9-debug / test (default, 1, 5, linux.4xlarge.nvidia.gpu), periodic / win-vs2019-cuda11.8-py3 / test (default, 1, 4, windows.g5.4xlarge.nvidia.gpu), periodic / macos-12-py3-x86-64 / test (default, 1, 4, macos-12). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
This updates some comments as a follow-up to #107305. [ghstack-poisoned]
@pytorchbot drci
This updates some comments as a follow-up to #107305. Pull Request resolved: #107608. Approved by: https://github.com/fduwjj. ghstack dependencies: #107606
Stack from ghstack (oldest at bottom):
This PR is the first of a series of refactors to the op dispatch logic to:
1. Remove redundant logic in op dispatch and simplify the error checking.
2. Reduce the number of tree_map/tree_flatten/unflatten calls, cutting the overhead coming from those operations.
3. Remove the CachedShardingPropagator by using lru_cache from functools directly; this not only helps TP, general DTensor operations could be faster too!
4. Stop changing the view ops' op_schema in place (which is dangerous for sharding-prop caching) and model view ops as one type of resharding too.
5. Enrich output sharding to record whether the op needs a redistribute, so we don't need an explicit op-schema comparison to know it (see the sketch below).

This should help with further reducing the CPU overhead. Benchmark results:
before (without this change), aten.addmm latency: 0.476ms
after (with this change), aten.addmm latency: 0.341ms
overall, one MLP layer's time reduced from 13.535ms to 9.665ms

Apart from overhead reduction, this PR simplifies the op dispatching logic and the resharding logic (more refactoring is needed to make things cleaner, which will be done in later PRs).
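A hedged sketch of point 5 above: if sharding propagation records whether a redistribute is needed, dispatch can branch on a flag instead of diffing op schemas. The field names are assumptions for illustration, not the exact DTensor definitions:

```python
from dataclasses import dataclass
from typing import Any, List, Optional

@dataclass
class OutputSharding:
    # DTensorSpec (or a pytree of specs) describing the op output(s).
    output_spec: Any
    # Suggested input schemas to reshard to when the inputs, as given,
    # cannot produce the output sharding directly.
    schema_suggestions: Optional[List[Any]] = None
    # Set during sharding propagation so that dispatch can simply check this
    # flag instead of comparing op schemas.
    needs_redistribute: bool = False
```

With something like this, the dispatch path only needs to check `output_sharding.needs_redistribute` before calling `redistribute_local_args(op_info, output_sharding.schema_suggestions[0])`, matching the snippet quoted earlier in the review.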