#2004: Added shlo to ttir conversion pass for all_gather and updated shlo to ttir conversion test cases for all gather #2018

Merged 1 commit on Jan 31, 2025
63 changes: 63 additions & 0 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
@@ -1347,6 +1347,68 @@ class StableHLOToTTIRAllReduceOpConversionPattern
};
} // namespace

namespace {
class StableHLOToTTIRAllGatherOpConversionPattern
: public OpConversionPattern<mlir::stablehlo::AllGatherOp> {
using OpConversionPattern<mlir::stablehlo::AllGatherOp>::OpConversionPattern;

public:
LogicalResult
matchAndRewrite(mlir::stablehlo::AllGatherOp srcOp,
mlir::stablehlo::AllGatherOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Check legality of the operation
LogicalResult err = checkBasicLegality(srcOp, adaptor, rewriter);
if (failed(err)) {
return err;
}

// Create the output tensor type based on inputs
auto outputType = mlir::cast<RankedTensorType>(
getTypeConverter()->convertType(srcOp.getResult(0).getType()));

// Create an empty output tensor with the computed shape
tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>(
srcOp.getLoc(), outputType.getShape(), outputType.getElementType());

SmallVector<Type> ttirTypes;
if (failed(this->getTypeConverter()->convertTypes(srcOp->getResultTypes(),
ttirTypes))) {
return failure();
}

auto ttirOperands = srcOp.getOperandsMutable();
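    // Append the empty output tensor created above so the TTIR op receives
    // its destination as the trailing operand, after the gathered input.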
ttirOperands.append(ValueRange(outputTensor));

SmallVector<NamedAttribute> srcAttrs = to_vector(srcOp->getAttrs());
SmallVector<NamedAttribute> ttirAttrs;
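    // StableHLO carries the gather dimension as an i64 all_gather_dim; the
    // TTIR op's `dim` attribute is built as a signed 32-bit integer, so the
    // value is narrowed and re-wrapped here.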
StringAttr dimAttrName = StringAttr::get(this->getContext(), "dim");
IntegerAttr allGatherDimAttr = rewriter.getSI32IntegerAttr(
static_cast<int32_t>(adaptor.getAllGatherDim()));
ttirAttrs.push_back({dimAttrName, allGatherDimAttr});

auto ttirAllGatherOp = rewriter.create<mlir::tt::ttir::AllGatherOp>(
srcOp.getLoc(), ttirTypes, ValueRange(ttirOperands.getAsOperandRange()),
ttirAttrs);
rewriter.replaceOp(srcOp, ttirAllGatherOp);
return success();
}

private:
LogicalResult
checkBasicLegality(mlir::stablehlo::AllGatherOp &srcOp,
mlir::stablehlo::AllGatherOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (srcOp.getOperands().empty() || srcOp.getOperands().size() > 1) {
return rewriter.notifyMatchFailure(
srcOp, "AllGatherOp must have one input/output for now.");
}

return success();
}
};
} // namespace

namespace {
class StableHLOToTTIRCustomCallOpConversionPattern
: public OpConversionPattern<mlir::stablehlo::CustomCallOp> {
@@ -2080,6 +2142,7 @@ static void addCCLOpsConversionPattern(MLIRContext *ctx,
RewritePatternSet &patterns,
TypeConverter &typeConverter) {
patterns.add<StableHLOToTTIRAllReduceOpConversionPattern>(typeConverter, ctx);
patterns.add<StableHLOToTTIRAllGatherOpConversionPattern>(typeConverter, ctx);
patterns.add<StableHLOToTTIRCustomCallOpConversionPattern>(typeConverter,
ctx);
}
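
For reference, the new pattern rewrites a StableHLO all_gather into a ttir.all_gather that takes an explicit empty output tensor as its last operand, with the `dim` attribute taken from all_gather_dim. The sketch below is illustrative only: the SSA names (%shard, %dest, %gathered) are made up, and the exact printed form of ttir.all_gather (attribute spelling and any additional attributes) is assumed from the conversion code above and the FileCheck lines in the test file, not copied from compiler output.

// Before: StableHLO all_gather across two devices along dimension 0.
%gathered = "stablehlo.all_gather"(%shard) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle<handle = 1, type = 1>, replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, use_global_device_ids}> : (tensor<4096x800xf32>) -> tensor<8192x800xf32>

// After (sketch): an empty destination tensor is created and appended as the
// last operand, and all_gather_dim becomes the si32 `dim` attribute.
%dest = tensor.empty() : tensor<8192x800xf32>
%gathered_ttir = "ttir.all_gather"(%shard, %dest) {dim = 0 : si32} : (tensor<4096x800xf32>, tensor<8192x800xf32>) -> tensor<8192x800xf32>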
161 changes: 155 additions & 6 deletions test/ttmlir/Conversion/StableHLOToTTIR/ccl_ops.mlir
@@ -1,8 +1,8 @@
// REQUIRES: stablehlo
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s

// jax/pjrt sharding target 1x2 for n300
module @jit_matmul_basic attributes {mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 1 : i32} {
// jax/pjrt sharding target 1x2 for n300 all_reduce
module @all_reduce_1x2 attributes {mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<8192x784xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<784x16384xf32> {mhlo.layout_mode = "default"}) -> (tensor<8192x16384xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
%0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[1,2]<=[2]}"} : (tensor<8192x784xf32>) -> tensor<8192x784xf32>
%1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<8192x784xf32>) -> tensor<8192x392xf32>
@@ -28,8 +28,8 @@ module @jit_matmul_basic attributes {mhlo.num_partitions = 2 : i32, mhlo.num_rep
}
}

// jax/pjrt sharding target 2x4 for t3k
module @jit_matmul_basic2 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
// jax/pjrt sharding target 2x4 for t3k all_reduce
module @all_reduce_2x4 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<8192x784xf32>, %arg1: tensor<784x16384xf32>) -> (tensor<8192x16384xf32> {jax.result_info = ""}) {
%0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[2,4]<=[8]}"} : (tensor<8192x784xf32>) -> tensor<8192x784xf32>
%1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<8192x784xf32>) -> tensor<4096x196xf32>
@@ -55,8 +55,8 @@ module @jit_matmul_basic2 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_re
}
}

// jax/pjrt sharding target 1x8 for t3k
module @jit_matmul_basic3 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
// jax/pjrt sharding target 1x8 for t3k all_reduce
module @all_reduce_1x8 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<8192x784xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<784x16384xf32> {mhlo.layout_mode = "default"}) -> (tensor<8192x16384xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
%0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[1,8]<=[8]}"} : (tensor<8192x784xf32>) -> tensor<8192x784xf32>
%1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<8192x784xf32>) -> tensor<8192x98xf32>
@@ -81,3 +81,152 @@ module @jit_matmul_basic3 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_re
return %1 : tensor<8192x16384xf32>
}
}

// jax/pjrt sharding target 8x4 for tg all_reduce
module @all_reduce_8x4 attributes {mhlo.num_partitions = 32 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<8192x784xf32>, %arg1: tensor<784x16384xf32>) -> (tensor<8192x16384xf32> {jax.result_info = ""}) {
%0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[8,4]<=[32]}"} : (tensor<8192x784xf32>) -> tensor<8192x784xf32>
%1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<8192x784xf32>) -> tensor<1024x196xf32>
// CHECK: %[[C:.*]] = "ttir.mesh_shard"[[C:.*]]
%2 = stablehlo.custom_call @Sharding(%arg1) {backend_config = "", mhlo.sharding = "{devices=[4,1,8]<=[8,4]T(1,0) last_tile_dim_replicate}"} : (tensor<784x16384xf32>) -> tensor<784x16384xf32>
%3 = stablehlo.custom_call @SPMDFullToShardShape(%2) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<784x16384xf32>) -> tensor<196x16384xf32>
// CHECK: %[[C:.*]] = "ttir.mesh_shard"[[C:.*]]
%4 = call @shmap_body(%1, %3) : (tensor<1024x196xf32>, tensor<196x16384xf32>) -> tensor<1024x16384xf32>
%5 = stablehlo.custom_call @Sharding(%4) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1024x16384xf32>) -> tensor<1024x16384xf32>
%6 = stablehlo.custom_call @SPMDShardToFullShape(%5) {backend_config = "", mhlo.sharding = "{devices=[8,1,4]<=[32] last_tile_dim_replicate}"} : (tensor<1024x16384xf32>) -> tensor<8192x16384xf32>
// CHECK: %[[C:.*]] = "ttir.mesh_shard"[[C:.*]]
return %6 : tensor<8192x16384xf32>
}
func.func private @shmap_body(%arg0: tensor<1024x196xf32>, %arg1: tensor<196x16384xf32>) -> (tensor<1024x16384xf32> {jax.result_info = "[('batch',), None]"}) {
%0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1024x196xf32>, tensor<196x16384xf32>) -> tensor<1024x16384xf32>
%1 = "stablehlo.all_reduce"(%0) <{channel_handle = #stablehlo.channel_handle<handle = 1, type = 1>, replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23], [24, 25, 26, 27], [28, 29, 30, 31]]> : tensor<8x4xi64>, use_global_device_ids}> ({
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
%2 = stablehlo.add %arg2, %arg3 : tensor<f32>
stablehlo.return %2 : tensor<f32>
}) : (tensor<1024x16384xf32>) -> tensor<1024x16384xf32>
// CHECK: %[[C:.*]] = "ttir.all_reduce"[[C:.*]]
return %1 : tensor<1024x16384xf32>
}
}

// jax/pjrt sharding target 1x32 for tg all_reduce
module @all_reduce_1x32 attributes {mhlo.num_partitions = 32 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<8192x800xf32>, %arg1: tensor<800x16384xf32>) -> (tensor<8192x16384xf32> {jax.result_info = ""}) {
%0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[1,32]<=[32]}"} : (tensor<8192x800xf32>) -> tensor<8192x800xf32>
%1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<8192x800xf32>) -> tensor<8192x25xf32>
// CHECK: %[[C:.*]] = "ttir.mesh_shard"[[C:.*]]
%2 = stablehlo.custom_call @Sharding(%arg1) {backend_config = "", mhlo.sharding = "{devices=[32,1]<=[32]}"} : (tensor<800x16384xf32>) -> tensor<800x16384xf32>
%3 = stablehlo.custom_call @SPMDFullToShardShape(%2) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<800x16384xf32>) -> tensor<25x16384xf32>
// CHECK: %[[C:.*]] = "ttir.mesh_shard"[[C:.*]]
%4 = call @shmap_body(%1, %3) : (tensor<8192x25xf32>, tensor<25x16384xf32>) -> tensor<8192x16384xf32>
%5 = stablehlo.custom_call @Sharding(%4) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<8192x16384xf32>) -> tensor<8192x16384xf32>
%6 = stablehlo.custom_call @SPMDShardToFullShape(%5) {backend_config = "", mhlo.sharding = "{replicated}"} : (tensor<8192x16384xf32>) -> tensor<8192x16384xf32>
// CHECK: %[[C:.*]] = "ttir.mesh_shard"[[C:.*]]
return %6 : tensor<8192x16384xf32>
}
func.func private @shmap_body(%arg0: tensor<8192x25xf32>, %arg1: tensor<25x16384xf32>) -> (tensor<8192x16384xf32> {jax.result_info = "[None, None]"}) {
%0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<8192x25xf32>, tensor<25x16384xf32>) -> tensor<8192x16384xf32>
%1 = "stablehlo.all_reduce"(%0) <{channel_handle = #stablehlo.channel_handle<handle = 1, type = 1>, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]]> : tensor<1x32xi64>, use_global_device_ids}> ({
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
%2 = stablehlo.add %arg2, %arg3 : tensor<f32>
stablehlo.return %2 : tensor<f32>
}) : (tensor<8192x16384xf32>) -> tensor<8192x16384xf32>
// CHECK: %[[C:.*]] = "ttir.all_reduce"[[C:.*]]
return %1 : tensor<8192x16384xf32>
}
}

// jax/pjrt sharding target 1x2 for n300 all_gather
module @all_gather_1x2 attributes {mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<8192x800xf32>) -> (tensor<16384x800xf32> {jax.result_info = ""}) {
%0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[2,1]<=[2]}"} : (tensor<8192x800xf32>) -> tensor<8192x800xf32>
%1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<8192x800xf32>) -> tensor<4096x800xf32>
// CHECK: %[[C:.*]] = "ttir.mesh_shard"[[C:.*]]
%2 = call @shmap_body(%1) : (tensor<4096x800xf32>) -> tensor<8192x800xf32>
%3 = stablehlo.custom_call @Sharding(%2) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<8192x800xf32>) -> tensor<8192x800xf32>
%4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {backend_config = "", mhlo.sharding = "{devices=[2,1]<=[2]}"} : (tensor<8192x800xf32>) -> tensor<16384x800xf32>
// CHECK: %[[C:.*]] = "ttir.mesh_shard"[[C:.*]]
return %4 : tensor<16384x800xf32>
}
func.func private @shmap_body(%arg0: tensor<4096x800xf32>) -> (tensor<8192x800xf32> {jax.result_info = "[('batch',), ('model',)]"}) {
%0 = "stablehlo.all_gather"(%arg0) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle<handle = 1, type = 1>, replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, use_global_device_ids}> : (tensor<4096x800xf32>) -> tensor<8192x800xf32>
// CHECK: %[[C:.*]] = "ttir.all_gather"[[C:.*]]
return %0 : tensor<8192x800xf32>
}
}

// jax/pjrt sharding target 1x8 for t3k all_gather
module @all_gather_1x8 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<8192x800xf32>) -> (tensor<65536x800xf32> {jax.result_info = ""}) {
%0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[8,1]<=[8]}"} : (tensor<8192x800xf32>) -> tensor<8192x800xf32>
%1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<8192x800xf32>) -> tensor<1024x800xf32>
// CHECK: %[[C:.*]] = "ttir.mesh_shard"[[C:.*]]
%2 = call @shmap_body(%1) : (tensor<1024x800xf32>) -> tensor<8192x800xf32>
%3 = stablehlo.custom_call @Sharding(%2) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<8192x800xf32>) -> tensor<8192x800xf32>
%4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {backend_config = "", mhlo.sharding = "{devices=[8,1]<=[8]}"} : (tensor<8192x800xf32>) -> tensor<65536x800xf32>
// CHECK: %[[C:.*]] = "ttir.mesh_shard"[[C:.*]]
return %4 : tensor<65536x800xf32>
}
func.func private @shmap_body(%arg0: tensor<1024x800xf32>) -> (tensor<8192x800xf32> {jax.result_info = "[('batch',), ('model',)]"}) {
%0 = "stablehlo.all_gather"(%arg0) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle<handle = 1, type = 1>, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, use_global_device_ids}> : (tensor<1024x800xf32>) -> tensor<8192x800xf32>
// CHECK: %[[C:.*]] = "ttir.all_gather"[[C:.*]]
return %0 : tensor<8192x800xf32>
}
}

// jax/pjrt sharding target 2x4 for t3k all_gather
module @all_gather_2x4 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<8192x800xf32>) -> (tensor<32768x800xf32> {jax.result_info = ""}) {
%0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[4,2]<=[2,4]T(1,0)}"} : (tensor<8192x800xf32>) -> tensor<8192x800xf32>
%1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<8192x800xf32>) -> tensor<2048x400xf32>
// CHECK: %[[C:.*]] = "ttir.mesh_shard"[[C:.*]]
%2 = call @shmap_body(%1) : (tensor<2048x400xf32>) -> tensor<8192x400xf32>
%3 = stablehlo.custom_call @Sharding(%2) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<8192x400xf32>) -> tensor<8192x400xf32>
%4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {backend_config = "", mhlo.sharding = "{devices=[4,2]<=[2,4]T(1,0)}"} : (tensor<8192x400xf32>) -> tensor<32768x800xf32>
// CHECK: %[[C:.*]] = "ttir.mesh_shard"[[C:.*]]
return %4 : tensor<32768x800xf32>
}
func.func private @shmap_body(%arg0: tensor<2048x400xf32>) -> (tensor<8192x400xf32> {jax.result_info = "[('batch',), ('model',)]"}) {
%0 = "stablehlo.all_gather"(%arg0) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle<handle = 1, type = 1>, replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : tensor<2x4xi64>, use_global_device_ids}> : (tensor<2048x400xf32>) -> tensor<8192x400xf32>
// CHECK: %[[C:.*]] = "ttir.all_gather"[[C:.*]]
return %0 : tensor<8192x400xf32>
}
}

// jax/pjrt sharding target 1x32 for tg all_gather
module @all_gather_1x32 attributes {mhlo.num_partitions = 32 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<8192x800xf32>) -> (tensor<262144x800xf32> {jax.result_info = ""}) {
%0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[32,1]<=[32]}"} : (tensor<8192x800xf32>) -> tensor<8192x800xf32>
%1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<8192x800xf32>) -> tensor<256x800xf32>
// CHECK: %[[C:.*]] = "ttir.mesh_shard"[[C:.*]]
%2 = call @shmap_body(%1) : (tensor<256x800xf32>) -> tensor<8192x800xf32>
%3 = stablehlo.custom_call @Sharding(%2) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<8192x800xf32>) -> tensor<8192x800xf32>
%4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {backend_config = "", mhlo.sharding = "{devices=[32,1]<=[32]}"} : (tensor<8192x800xf32>) -> tensor<262144x800xf32>
// CHECK: %[[C:.*]] = "ttir.mesh_shard"[[C:.*]]
return %4 : tensor<262144x800xf32>
}
func.func private @shmap_body(%arg0: tensor<256x800xf32>) -> (tensor<8192x800xf32> {jax.result_info = "[('batch',), ('model',)]"}) {
%0 = "stablehlo.all_gather"(%arg0) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle<handle = 1, type = 1>, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]]> : tensor<1x32xi64>, use_global_device_ids}> : (tensor<256x800xf32>) -> tensor<8192x800xf32>
// CHECK: %[[C:.*]] = "ttir.all_gather"[[C:.*]]
return %0 : tensor<8192x800xf32>
}
}

// jax/pjrt sharding target 8x4 for tg all_gather
module @all_gather_8x4 attributes {mhlo.num_partitions = 32 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<8192x800xf32>) -> (tensor<65536x800xf32> {jax.result_info = ""}) {
%0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[8,4]<=[4,8]T(1,0)}"} : (tensor<8192x800xf32>) -> tensor<8192x800xf32>
%1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<8192x800xf32>) -> tensor<1024x200xf32>
// CHECK: %[[C:.*]] = "ttir.mesh_shard"[[C:.*]]
%2 = call @shmap_body(%1) : (tensor<1024x200xf32>) -> tensor<8192x200xf32>
%3 = stablehlo.custom_call @Sharding(%2) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<8192x200xf32>) -> tensor<8192x200xf32>
%4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {backend_config = "", mhlo.sharding = "{devices=[8,4]<=[4,8]T(1,0)}"} : (tensor<8192x200xf32>) -> tensor<65536x800xf32>
// CHECK: %[[C:.*]] = "ttir.mesh_shard"[[C:.*]]
return %4 : tensor<65536x800xf32>
}
func.func private @shmap_body(%arg0: tensor<1024x200xf32>) -> (tensor<8192x200xf32> {jax.result_info = "[('batch',), ('model',)]"}) {
%0 = "stablehlo.all_gather"(%arg0) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle<handle = 1, type = 1>, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13, 14, 15], [16, 17, 18, 19, 20, 21, 22, 23], [24, 25, 26, 27, 28, 29, 30, 31]]> : tensor<4x8xi64>, use_global_device_ids}> : (tensor<1024x200xf32>) -> tensor<8192x200xf32>
// CHECK: %[[C:.*]] = "ttir.all_gather"[[C:.*]]
return %0 : tensor<8192x200xf32>
}
}