Enable lowering ttir all_reduce and mesh_shard to ttnn and flatbuffer #1432
Conversation
The mesh shape in the output of ./build/bin/ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=ttrt-artifacts/system_desc.ttsys" test/ttmlir/Dialect/TTNN/ccl/all_gather.mlir doesn't look right:

%0 = "ttnn.get_device"() <{mesh_shape = #ttnn<mesh_shape 140733412566816x4>}> : () -> !tt.device<#device>
@gfengTT It seems like there was an issue in getOrInsertDevice() when assigning the default meshShape. Thanks for catching this. I believe I fixed it. Can you double-check that it's fine?
Yes, looking good now.
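For context on how a mesh dimension like 140733412566816 can show up: a value of that magnitude is in the range of typical x86-64 stack addresses, which usually suggests an uninitialized or stray 64-bit value being used as the default. Below is a minimal, self-contained C++ sketch of the safe fallback pattern; the MeshShape alias and getOrDefaultMeshShape are hypothetical illustrations, not the actual tt-mlir code.

#include <cstdint>
#include <iostream>
#include <utility>

// Hypothetical stand-in for a (y, x) mesh shape; not the actual tt-mlir type.
using MeshShape = std::pair<int64_t, int64_t>;

// If the caller supplies no shape, fall back to an explicit 1x1 default instead
// of leaving either dimension unset; an unset int64_t is the kind of thing that
// prints as a huge garbage value such as 140733412566816 in the IR.
MeshShape getOrDefaultMeshShape(const int64_t *userShape, size_t numDims) {
  if (userShape != nullptr && numDims == 2)
    return {userShape[0], userShape[1]};
  return {1, 1}; // sane single-device default
}

int main() {
  int64_t requested[2] = {1, 4};
  MeshShape a = getOrDefaultMeshShape(requested, 2); // uses the requested 1x4
  MeshShape b = getOrDefaultMeshShape(nullptr, 0);   // falls back to 1x1
  std::cout << a.first << "x" << a.second << " / "
            << b.first << "x" << b.second << "\n";
}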
@@ -852,6 +852,7 @@ def TTNN_AllGatherOp: TTNN_Op<"all_gather"> {
  }];

  let arguments = (ins AnyRankedTensor:$input,
                       TT_Device:$device,
Do we plan on adding a MeshDevice attr/type?

Additionally, can you please sort this param to match the ordering in the actual signature of the op? Same for reduce_scatter.
@svuckovicTT AFAIK, the TT_Device handle points to the opened device, which includes the MeshDevice at runtime. So, I believe we do not want to add a MeshDevice attr/type. Do you have a particular reason or concern for specifically adding MeshDevice here?

// Additionally, can you please sort this param to match the ordering in the actual signature of the op? Same for reduce_scatter.

Can you help me better understand your comment here? I can see almost identical ordering between all_gather and reduce_scatter except for different attribute names. Can you point me to the signature of the all_gather op to refer to?
@svuckovicTT AFAIK, the TT_Device handle points to the opened device, which includes the MeshDevice at runtime. So, I believe we do not want to add a MeshDevice attr/type. Do you have a particular reason or concern for specifically adding MeshDevice here?
I was going to say that we should differentiate between ttnn::Device and ttnn::MeshDevice, but I see that AnyDevice was added in metal as a wrapper for device-like objects, which complicates things further...

I'm afraid that if we bundle all of these into one type like TT_Device, there's going to be a situation where we'll need to know what type of device it is, and then we'll hack something up instead of properly fixing it (because it's less work to hack around it); tech debt is going to accrue over time, and it's going to come back and bite us at some point. We made these types of mistakes in pybuda, and it was unsustainable.

@nsmithtt I'd really like to hear your opinion on this as well.
Can you help me better understand your comment here? I can see almost identical ordering between all_gather and reduce_scatter except for different attribute names. Can you point me to the signature of the all_gather op to refer to?

I meant to say that the IR definition of the op should match the signature of that same op in the TTNN lib, and that the same reasoning should be applied for reduce_scatter as well.

If you look at the all_gather signature with MeshDevice in the TTNN lib, it's:
struct ExecuteAllGather {
...
static ttnn::Tensor invoke(
const ttnn::Tensor& input_tensor,
const int32_t dim,
const uint32_t cluster_axis,
const MeshDevice& mesh_device,
const uint32_t num_links = 1,
const std::optional<ttnn::MemoryConfig>& memory_config = std::nullopt,
const std::optional<size_t> num_workers = std::nullopt,
const std::optional<size_t> num_buffers_per_channel = std::nullopt,
const ttnn::ccl::Topology topology = ttnn::ccl::Topology::Ring);
};
The MeshDevice param comes after dim, whereas in the IR definition, it's before dim.
I think we should just have a single device type that optionally carries mesh information. TTNN is moving in this direction too, as you noticed with AnyDevice. FWIW, in the runtime we always open a mesh device, which is a strict superset of a single device.

reduce_scatter and other CCLs will eventually have a verify function that checks that the device argument / device in scope (for TTIR) has a mesh attribute compatible with the way the op itself and its input tensors are programmed.

I think that introducing multiple device types in the compiler will just result in confusion/bugs, and I believe this is a temporary state that TTNN is in.
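To make the intent of such a verify concrete, here is a small, self-contained C++ sketch of the kind of compatibility check it could perform. verifyCclMeshCompat, its parameters, and the specific rules are illustrative assumptions only, not the actual tt-mlir verifier.

#include <array>
#include <cstdint>
#include <iostream>
#include <optional>
#include <string>

// Hypothetical illustration of a mesh-compatibility check for a CCL op.
// Returns an error message on failure, std::nullopt on success.
std::optional<std::string> verifyCclMeshCompat(std::array<int64_t, 2> meshShape,
                                               int32_t clusterAxis,
                                               int64_t perDeviceDimSize,
                                               int64_t gatheredDimSize) {
  if (clusterAxis != 0 && clusterAxis != 1)
    return "cluster_axis must be 0 or 1";
  int64_t devicesOnAxis = meshShape[clusterAxis];
  if (devicesOnAxis < 2)
    return "CCL op expects at least 2 devices along cluster_axis";
  if (perDeviceDimSize * devicesOnAxis != gatheredDimSize)
    return "gathered dim must equal per-device dim times the mesh extent on cluster_axis";
  return std::nullopt; // compatible
}

int main() {
  // 1x8 mesh, gathering along mesh axis 1: 32 per device * 8 devices == 256 total.
  auto err = verifyCclMeshCompat({1, 8}, /*clusterAxis=*/1,
                                 /*perDeviceDimSize=*/32, /*gatheredDimSize=*/256);
  std::cout << (err ? *err : std::string("ok")) << "\n";
}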
@svuckovicTT, we decided to land this because we want to get the next round of multi-device changes going. I filed a follow-on issue for the argument reordering: #1550
                       DefaultValuedAttr<SI32Attr, "1">:$num_links);
  let results = (outs AnyRankedTensor:$result);

  let hasVerifier = 1;
}

def TTNN_MeshShardOp: TTNN_Op<"mesh_shard"> {
I couldn't find mesh_shard in metal, can you point me to it?
@svuckovicTT This is a host-side input shard / output concat operation, and the TTNN team is working on the APIs. tenstorrent/tt-metal#15061 (comment)
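To illustrate what a host-side input shard / output concat means, independent of the in-flight TTNN APIs, here is a small self-contained C++ sketch. shardRows, concatRows, and the row-major layout are illustrative assumptions only.

#include <cstddef>
#include <iostream>
#include <vector>

// Split a row-major [rows x cols] host buffer into numShards contiguous row
// blocks, one per device (assumes rows divides evenly; real APIs handle
// padding and uneven cases).
std::vector<std::vector<float>> shardRows(const std::vector<float> &data,
                                          size_t rows, size_t cols,
                                          size_t numShards) {
  std::vector<std::vector<float>> shards(numShards);
  size_t rowsPerShard = rows / numShards;
  for (size_t s = 0; s < numShards; ++s) {
    auto begin = data.begin() + s * rowsPerShard * cols;
    shards[s].assign(begin, begin + rowsPerShard * cols);
  }
  return shards;
}

// Concatenate per-device row blocks back into one host buffer.
std::vector<float> concatRows(const std::vector<std::vector<float>> &shards) {
  std::vector<float> out;
  for (const auto &shard : shards)
    out.insert(out.end(), shard.begin(), shard.end());
  return out;
}

int main() {
  std::vector<float> host(8 * 4);
  for (size_t i = 0; i < host.size(); ++i) host[i] = float(i);
  auto shards = shardRows(host, /*rows=*/8, /*cols=*/4, /*numShards=*/2);
  auto roundTrip = concatRows(shards);
  std::cout << "round-trip matches: " << (roundTrip == host) << "\n";
}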
Thanks for making these changes, Wooseok! I left a couple more comments regarding tests and the coding standards that we are trying to follow...
* all_reduce breaks down into reduce_scatter and all_gather
* mesh_shard is directly converted
* corresponding dialect unit tests are added
* fix issue in default mesh shape in getOrInsertDevice()
* enable workaround pass to avoid issues in ttnn
As the second step of the multi-device support plan, this PR enables lowering the all_reduce and mesh_shard ops to TTNN and the flatbuffer format.
The all_reduce op lowering needs special handling in this PR. Although all_reduce lowering is supposed to be a one-to-one mapping from ttir all_reduce to ttnn all_reduce, the all_reduce ops are broken down into reduce_scatter and all_gather ops because the current support for all_reduce in TTNN is not stable.
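As a sanity check of that decomposition, the following self-contained C++ toy simulates the three collectives over plain vectors (sum reduction, tensor size divisible by the device count) and shows that reduce_scatter followed by all_gather reproduces all_reduce. This is illustrative only, not the actual lowering code.

#include <cassert>
#include <iostream>
#include <vector>

using Tensor = std::vector<float>;

// Reference all_reduce: every device ends up with the elementwise sum.
std::vector<Tensor> allReduce(const std::vector<Tensor> &perDevice) {
  Tensor sum(perDevice[0].size(), 0.0f);
  for (const auto &t : perDevice)
    for (size_t i = 0; i < t.size(); ++i) sum[i] += t[i];
  return std::vector<Tensor>(perDevice.size(), sum);
}

// reduce_scatter: device d keeps only its chunk of the elementwise sum.
std::vector<Tensor> reduceScatter(const std::vector<Tensor> &perDevice) {
  size_t n = perDevice.size(), chunk = perDevice[0].size() / n;
  std::vector<Tensor> out(n, Tensor(chunk, 0.0f));
  for (size_t d = 0; d < n; ++d)
    for (size_t i = 0; i < chunk; ++i)
      for (const auto &t : perDevice) out[d][i] += t[d * chunk + i];
  return out;
}

// all_gather: every device concatenates the chunks from all devices.
std::vector<Tensor> allGather(const std::vector<Tensor> &perDevice) {
  Tensor full;
  for (const auto &t : perDevice) full.insert(full.end(), t.begin(), t.end());
  return std::vector<Tensor>(perDevice.size(), full);
}

int main() {
  std::vector<Tensor> devices = {{1, 2, 3, 4}, {10, 20, 30, 40}};
  assert(allGather(reduceScatter(devices)) == allReduce(devices));
  std::cout << "reduce_scatter + all_gather == all_reduce\n";
}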
In addition, the reduce_scatter op in TTNN currently does not handle two-dimensional tensors correctly. As a temporary workaround, we insert reshape ops before and after it to make the tensor four-dimensional.
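A minimal sketch of the shape manipulation involved in that workaround (illustrative only; padTo4D and restoreRank are hypothetical helpers, not code from this PR, and the real lowering would also have to shift the scatter dim attribute by the added rank): pad the 2D shape with leading 1s to rank 4 before reduce_scatter and drop them again afterwards.

#include <cstdint>
#include <iostream>
#include <vector>

// Pad a shape with leading 1s until it has rank 4, e.g. [32, 64] -> [1, 1, 32, 64].
std::vector<int64_t> padTo4D(std::vector<int64_t> shape) {
  while (shape.size() < 4) shape.insert(shape.begin(), 1);
  return shape;
}

// Drop the leading 1s that were added, restoring the original rank.
std::vector<int64_t> restoreRank(std::vector<int64_t> shape, size_t origRank) {
  while (shape.size() > origRank && shape.front() == 1)
    shape.erase(shape.begin());
  return shape;
}

int main() {
  std::vector<int64_t> shape = {32, 64};
  auto padded = padTo4D(shape);           // reshape inserted before reduce_scatter
  auto restored = restoreRank(padded, 2); // reshape inserted after reduce_scatter
  for (auto d : padded) std::cout << d << ' ';
  std::cout << "-> ";
  for (auto d : restored) std::cout << d << ' ';
  std::cout << '\n';
}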
Once stable TTNN support is available, this workaround can be removed and replaced with a simple one-to-one mapping.