Enable lowering ttir all_reduce and mesh_shard to ttnn and flatbuffer #1432

Merged 1 commit into main on Dec 10, 2024

Conversation

wooseokTT (Contributor)

  • Overall Plan
    As the second step of the multi-device support plan, this PR enables lowering the all_reduce and mesh_shard ops to TTNN and the flatbuffer format.
  1. Convert MLIRs from JAX/OpenXLA/PJRT to TTIR (merged)
  2. Pass converted TTIR to TTNN MLIR and Flatbuffer format (this PR)
  3. Parse TTNN flatbuffer and execute in TT Runtime
  • Implementation Details

The all_reduce op lowering needs special handling in this PR.

Although all_reduce lowering is intended to be a one-to-one mapping from ttir all_reduce to ttnn all_reduce, the all_reduce ops are broken down into reduce_scatter and all_gather ops because the current all_reduce support in TTNN is not stable.
In addition, the reduce_scatter op in TTNN currently does not handle two-dimensional tensors correctly. As a temporary workaround, we insert reshape ops before and after it to make the tensor four-dimensional.

Once stable TTNN support is available, this workaround can be removed and replaced with a simple one-to-one mapping.
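
To make the decomposition concrete, here is a minimal self-contained C++ sketch (illustration only, not tt-mlir or TTNN code) of why reduce_scatter followed by all_gather reproduces all_reduce: reduce_scatter leaves each device with one reduced shard, and all_gather then replicates every shard to all devices.

// Standalone illustration: all_reduce(sum) over N "devices" decomposed into
// reduce_scatter followed by all_gather.
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  constexpr std::size_t numDevices = 4;
  constexpr std::size_t elemsPerDevice = 8; // assumed divisible by numDevices

  // Each device starts with its own copy of an 8-element tensor.
  std::vector<std::vector<int>> local(numDevices, std::vector<int>(elemsPerDevice));
  for (std::size_t d = 0; d < numDevices; ++d)
    for (std::size_t i = 0; i < elemsPerDevice; ++i)
      local[d][i] = static_cast<int>(d + i);

  // reduce_scatter: device d ends up owning the summed shard d.
  const std::size_t shard = elemsPerDevice / numDevices;
  std::vector<std::vector<int>> scattered(numDevices, std::vector<int>(shard, 0));
  for (std::size_t d = 0; d < numDevices; ++d)
    for (std::size_t i = 0; i < shard; ++i)
      for (std::size_t src = 0; src < numDevices; ++src)
        scattered[d][i] += local[src][d * shard + i];

  // all_gather: every device concatenates all shards, giving the full
  // all_reduce result on each device.
  std::vector<int> gathered;
  for (std::size_t d = 0; d < numDevices; ++d)
    gathered.insert(gathered.end(), scattered[d].begin(), scattered[d].end());

  for (int v : gathered)
    std::cout << v << ' ';
  std::cout << '\n'; // prints 6 10 14 18 22 26 30 34, the elementwise sum over devices
  return 0;
}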

  • Changes in this PR
  1. all_reduce is broken down into reduce_scatter and all_gather
  2. mesh_shard is converted directly
  3. corresponding dialect unit tests are added

@wooseokTT wooseokTT force-pushed the wooseok/ttnn_all_reduce_mesh_shard branch from 0cec6ec to ed37e0a Compare November 27, 2024 20:23
gfengTT (Contributor) left a comment:


The mesh shape in the output of ./build/bin/ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=ttrt-artifacts/system_desc.ttsys" test/ttmlir/Dialect/TTNN/ccl/all_gather.mlir doesn't look right:

%0 = "ttnn.get_device"() <{mesh_shape = #ttnn<mesh_shape 140733412566816x4>}> : () -> !tt.device<#device>

@wooseokTT wooseokTT force-pushed the wooseok/ttnn_all_reduce_mesh_shard branch from ed37e0a to 8b5b18d Compare December 2, 2024 15:54
wooseokTT (Contributor, Author) commented Dec 2, 2024:

The mesh shape in the output of ./build/bin/ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=ttrt-artifacts/system_desc.ttsys" test/ttmlir/Dialect/TTNN/ccl/all_gather.mlir doesn't look right:

%0 = "ttnn.get_device"() <{mesh_shape = #ttnn<mesh_shape 140733412566816x4>}> : () -> !tt.device<#device>

@gfengTT It seems like there was an issue in getOrInsertDevice() when assigning the default meshShape. Thanks for catching this. I believe I fixed it. Can you double-check that it looks fine now?
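
A value like 140733412566816 is characteristic of an uninitialized default being read. Purely as an illustration of that failure mode (plain standalone C++ with hypothetical names, not the actual getOrInsertDevice() code), a minimal sketch of the bug pattern and the straightforward fix:

#include <array>
#include <cstdint>
#include <iostream>
#include <optional>

using MeshShape = std::array<int64_t, 2>;

// Buggy pattern (illustration only, not called below): on the default path the
// shape is never initialized, so stack garbage leaks into the printed attribute.
MeshShape defaultMeshShapeBuggy(std::optional<MeshShape> requested) {
  MeshShape shape; // indeterminate values when `requested` is empty
  if (requested)
    shape = *requested;
  return shape;
}

// Fixed pattern: fall back to an explicit default when none is requested.
MeshShape defaultMeshShapeFixed(std::optional<MeshShape> requested) {
  return requested.value_or(MeshShape{1, 1});
}

int main() {
  MeshShape shape = defaultMeshShapeFixed(std::nullopt);
  std::cout << shape[0] << "x" << shape[1] << "\n"; // prints "1x1"
  return 0;
}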

gfengTT (Contributor) commented Dec 2, 2024:

@gfengTT It seems like there was an issue in getOrInsertDevice() when assigning the default meshShape. Thanks for catching this. I believe I fixed it. Can you double-check that it looks fine now?

Yes, looking good now.

@wooseokTT wooseokTT force-pushed the wooseok/ttnn_all_reduce_mesh_shard branch 2 times, most recently from 7a3b7ea to 4d49cc3 Compare December 2, 2024 23:47
@wooseokTT wooseokTT linked an issue Dec 3, 2024 that may be closed by this pull request
@@ -852,6 +852,7 @@ def TTNN_AllGatherOp: TTNN_Op<"all_gather"> {
}];

let arguments = (ins AnyRankedTensor:$input,
TT_Device:$device,
svuckovicTT (Contributor):

Do we plan on adding a MeshDevice attr/type?

Additionally, can you please sort this param to match the ordering in the actual signature of the op? Same for reduce_scatter.

wooseokTT (Contributor, Author):

@svuckovicTT AFAIK, the TT_Device handle points to an opened device, which includes the MeshDevice at runtime. So I believe we do not want to add a MeshDevice attr/type. Do you have any particular reason or concern for specifically adding MeshDevice here?

Additionally, can you please sort this param to match the ordering in the actual signature of the op? Same for reduce_scatter.

Can you help me better understand your comment here? I can see almost identical ordering between all_gather and reduce_scatter except for different attribute names. Can you point me to the signature of the all_gather op to refer to?

svuckovicTT (Contributor):

@svuckovicTT AFAIK, the TT_Device handle points to an opened device, which includes the MeshDevice at runtime. So I believe we do not want to add a MeshDevice attr/type. Do you have any particular reason or concern for specifically adding MeshDevice here?

I was going to say that we should differentiate between ttnn::Device and ttnn::MeshDevice, but I see that AnyDevice was added in metal as a wrapper for device-like objects, which complicates things further.

I'm afraid that if we bundle all of these into one type like TT_Device, there will be a situation where we need to know which type of device it is, and then we'll hack something up instead of properly fixing it (because hacking around it is less work), and the tech debt will accrue over time and come back to bite us at some point. We made these kinds of mistakes in pybuda, and it was unsustainable.

@nsmithtt I'd really like to hear your opinion on this as well.

Can you help me better understand your comment here? I can see almost identical ordering between all_gather and reduce_scatter except for different attribute names. Can you point me to the signature of the all_gather op to refer to?

I meant to say that the IR definition of the op should match the signature of that same op in the TTNN lib, and that the same reasoning applies to reduce_scatter as well.

If you look at the all_gather signature with MeshDevice in the TTNN lib, it's:

struct ExecuteAllGather {
    ...

    static ttnn::Tensor invoke(
        const ttnn::Tensor& input_tensor,
        const int32_t dim,
        const uint32_t cluster_axis,
        const MeshDevice& mesh_device,
        const uint32_t num_links = 1,
        const std::optional<ttnn::MemoryConfig>& memory_config = std::nullopt,
        const std::optional<size_t> num_workers = std::nullopt,
        const std::optional<size_t> num_buffers_per_channel = std::nullopt,
        const ttnn::ccl::Topology topology = ttnn::ccl::Topology::Ring);
};

The MeshDevice param comes after dim, whereas in the IR definition, it's before dim.

A contributor replied:

I think we should just have a single device type that optionally carries mesh information. TTNN is moving in this direction too, as you noticed with AnyDevice. FWIW, in the runtime we always open a mesh device, which is a strict superset of a single device.

reduce_scatter and other CCLs will eventually have a verify function that checks that the device argument / device in scope (for TTIR) has a mesh attribute compatible with the way the op itself and its input tensors are programmed.

I think that introducing multiple device types in the compiler would just result in confusion and bugs, and I believe this is a temporary state for TTNN.
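
As a rough illustration of the kind of check such a verify function could perform (standalone C++ with hypothetical names, not the actual tt-mlir verifier): the dimension a CCL op scatters over has to be divisible by the number of devices along the chosen mesh axis.

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

struct MeshInfo {
  std::vector<int64_t> shape; // e.g. {1, 4} for a 1x4 mesh
};

// Returns an error message, or an empty string on success.
std::string verifyReduceScatter(const std::vector<int64_t> &tensorShape,
                                int64_t scatterDim, const MeshInfo &mesh,
                                int64_t clusterAxis) {
  if (clusterAxis < 0 || clusterAxis >= static_cast<int64_t>(mesh.shape.size()))
    return "cluster_axis out of range for mesh";
  const int64_t devices = mesh.shape[clusterAxis];
  if (scatterDim < 0 || scatterDim >= static_cast<int64_t>(tensorShape.size()))
    return "scatter dim out of range for input tensor";
  if (tensorShape[scatterDim] % devices != 0)
    return "scatter dim size is not divisible by the number of devices on cluster_axis";
  return {};
}

int main() {
  MeshInfo mesh{{1, 4}};
  std::string err = verifyReduceScatter({1, 1, 32, 128}, /*scatterDim=*/3, mesh,
                                        /*clusterAxis=*/1);
  std::cout << (err.empty() ? "ok" : err) << "\n"; // prints "ok" (128 % 4 == 0)
  return 0;
}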

A contributor added:

@svuckovicTT, we decided to land this because we want to get the next round of multi-device changes going. I filed a follow-on issue for the argument reordering: #1550

DefaultValuedAttr<SI32Attr, "1">:$num_links);
let results = (outs AnyRankedTensor:$result);

let hasVerifier = 1;
}

def TTNN_MeshShardOp: TTNN_Op<"mesh_shard"> {
svuckovicTT (Contributor):

I couldn't find mesh_shard in metal, can you point me to it?

wooseokTT (Contributor, Author) commented Dec 4, 2024:

@svuckovicTT This is the host-side input shard / output concat operation, and the TTNN team is working on the APIs: tenstorrent/tt-metal#15061 (comment)
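
To illustrate what a host-side shard / concat pair conceptually does (plain standalone C++, not the TTNN API): split a host buffer into per-device pieces along one dimension, and concatenate per-device outputs back into a single host buffer.

#include <cstddef>
#include <iostream>
#include <vector>

// Shard a flat 1-D host buffer into `numDevices` contiguous pieces.
std::vector<std::vector<float>> shardToMesh(const std::vector<float> &host,
                                            std::size_t numDevices) {
  const std::size_t shard = host.size() / numDevices; // assumed divisible
  std::vector<std::vector<float>> pieces;
  for (std::size_t d = 0; d < numDevices; ++d)
    pieces.emplace_back(host.begin() + d * shard, host.begin() + (d + 1) * shard);
  return pieces;
}

// Concatenate per-device outputs back into one host buffer.
std::vector<float> concatFromMesh(const std::vector<std::vector<float>> &pieces) {
  std::vector<float> host;
  for (const auto &p : pieces)
    host.insert(host.end(), p.begin(), p.end());
  return host;
}

int main() {
  std::vector<float> host(8);
  for (std::size_t i = 0; i < host.size(); ++i)
    host[i] = static_cast<float>(i);
  auto pieces = shardToMesh(host, 4);
  auto back = concatFromMesh(pieces);
  std::cout << "round-trip size: " << back.size() << "\n"; // prints 8
  return 0;
}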

@wooseokTT wooseokTT force-pushed the wooseok/ttnn_all_reduce_mesh_shard branch 2 times, most recently from 1aa9207 to cf307d7 Compare December 5, 2024 14:56
@wooseokTT wooseokTT force-pushed the wooseok/ttnn_all_reduce_mesh_shard branch from cf307d7 to a4abcc5 Compare December 5, 2024 18:05
@wooseokTT wooseokTT force-pushed the wooseok/ttnn_all_reduce_mesh_shard branch 2 times, most recently from 1f9d87d to a04957f Compare December 5, 2024 18:44
sdjordjevicTT (Contributor) left a comment:


Thanks for making these changes, Wooseok! I left a couple more comments regarding the tests and the coding standards that we are trying to follow.

Review threads:
lib/Dialect/TTNN/IR/TTNNOps.cpp (resolved)
lib/Target/TTNN/TTNNToFlatbuffer.cpp (outdated, resolved)
lib/Target/TTNN/TTNNToFlatbuffer.cpp (outdated, resolved)
lib/Target/TTNN/TTNNToFlatbuffer.cpp (outdated, resolved)
lib/Dialect/TTNN/Utils/TransformUtils.cpp (resolved)
lib/Dialect/TTNN/Transforms/TTNNWorkarounds.cpp (outdated, resolved)
* all_reduce breaks down into reduce_scatter and all_gather

* mesh_shard is directly converted

* corresponding dialect unit tests are added

* fix issue in default mesh shape in getOrInsertDevice()

* enable workaround pass to avoid issues in ttnn
@wooseokTT wooseokTT force-pushed the wooseok/ttnn_all_reduce_mesh_shard branch from a04957f to c764b41 Compare December 6, 2024 20:42
@wooseokTT wooseokTT enabled auto-merge (squash) December 10, 2024 14:25
@wooseokTT wooseokTT merged commit 2e63d09 into main Dec 10, 2024
19 checks passed
azecevicTT pushed a commit that referenced this pull request Dec 17, 2024
Successfully merging this pull request may close these issues.

Push Jax test through