#16391: propagate sub_device_ids to mesh #16410
base: main
FYI @xuncaiTT
- Further update all-gather-async tests
6254b84 to 3ff6e59
Do we now also need to update the C++ APIs? See distributed_tensor.hpp
@tt-aho / @SeanNijjar - what is the high-level plan for the APIs involving subdevices? Passing subdevice IDs to mesh composers is odd, as it is completely unrelated to the mesh distribution functionality. Do we plan to plumb subdevice IDs through all of the APIs that copy tensors under the hood? The documentation says: "The sub-device IDs to wait on. Defaults to all sub-devices." What does this mean exactly: do we wait before copying a tensor, or after? If this is a synchronization primitive, can we make it an explicit API instead, like ttnn.wait_for_subdevices(...)?
This is for stalling before reading or writing buffers.
I am currently working on a new API similar to what you have proposed, but it adjusts the default list of sub-devices to stall on rather than adding an explicit stall call. Adjusting a stored/cached list of what to stall on minimizes the burden on users, who would otherwise have to inject synchronization calls themselves and track their own stall list everywhere. This should also remove the need to propagate sub_device_ids to all these APIs.
Example below.
What would be coded now:
sub_device_0 = ...
sub_device_1 = ...
manager = create_manager([sub_device_0, sub_device_1])
load_manager(manager)
run_long_running_op_on_sub_device_1()
write_buffer(sub_device_ids=[sub_device_0])
run_op_on_sub_device_0()
read_buffer(sub_device_ids=[sub_device_0])
With the new API (adjust_default_stalls is the new API; the name is not final):
sub_device_0 = ...
sub_device_1 = ...
manager = create_manager([sub_device_0, sub_device_1])
load_manager(manager)
run_long_running_op_on_sub_device_1()
adjust_default_stalls([sub_device_0])
write_buffer()
run_op_on_sub_device_0()
read_buffer()
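To make the difference between the two snippets concrete, here is a minimal, runnable sketch of the cached-default idea. It is a mock, not the ttnn API: `MockDevice`, its `log`, and the method signatures are hypothetical, and `adjust_default_stalls` is only the placeholder name used above. The assumption is that reads and writes stall on the cached list unless the caller passes `sub_device_ids` explicitly.

```python
from typing import Iterable, List, Optional


class MockDevice:
    """Hypothetical stand-in for a device that stalls before buffer I/O."""

    def __init__(self, sub_device_ids: Iterable[int]) -> None:
        self.all_sub_device_ids = list(sub_device_ids)
        # By default, stall on every sub-device.
        self._default_stalls = list(self.all_sub_device_ids)
        self.log: List[str] = []

    def adjust_default_stalls(self, sub_device_ids: Iterable[int]) -> None:
        """Replace the cached list of sub-devices to stall on (placeholder name)."""
        ids = list(sub_device_ids)
        unknown = [i for i in ids if i not in self.all_sub_device_ids]
        if unknown:
            raise ValueError(f"unknown sub-device ids: {unknown}")
        self._default_stalls = ids

    def _stall(self, sub_device_ids: Optional[Iterable[int]]) -> List[int]:
        # An explicit argument overrides the cached default for this call only.
        ids = list(self._default_stalls) if sub_device_ids is None else list(sub_device_ids)
        self.log.append(f"stall on {ids}")
        return ids

    def write_buffer(self, sub_device_ids: Optional[Iterable[int]] = None) -> List[int]:
        stalled = self._stall(sub_device_ids)  # stall happens before the write
        self.log.append("write")
        return stalled

    def read_buffer(self, sub_device_ids: Optional[Iterable[int]] = None) -> List[int]:
        stalled = self._stall(sub_device_ids)  # stall happens before the read
        self.log.append("read")
        return stalled
```

With this mock, `MockDevice([0, 1]).write_buffer()` stalls on both sub-devices, while calling `adjust_default_stalls([0])` first makes subsequent reads and writes stall only on sub-device 0 without threading `sub_device_ids` through each call.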
@tt-aho - based on the discussion above, I take it the recommendation is to abandon the part of this PR that updates the mesh composer and, once your changes are available, rebase and merge (after review, of course). Correct?
I think so. For reference, my current PR is #16473. I'm planning to add the new API first, then remove the sub_device_ids propagation from the read/write APIs in a subsequent PR.
Thanks, this is great! Makes sense. Also +1 to using the term set instead of adjust, as per #16473.
Ticket
Link to Github Issue
Problem description
All-gather v2 hangs when running with cluster axis API on persistent fabric
What's changed
In tests:
Infra:
Checklist
Closes #16391