diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
index 57383c007d..7fb7770778 100644
--- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
+++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -852,6 +852,7 @@ def TTNN_AllGatherOp: TTNN_Op<"all_gather"> {
   }];
 
   let arguments = (ins AnyRankedTensor:$input,
+                       TT_Device:$device,
                        SI32Attr:$dim,
                        DefaultValuedAttr<SI32Attr, "1">:$num_links);
 
@@ -874,12 +875,31 @@ def TTNN_ReduceScatterOp: TTNN_Op<"reduce_scatter"> {
   }];
 
   let arguments = (ins AnyRankedTensor:$input,
+                       TT_Device:$device,
                        SI32Attr:$scatter_split_dim,
-                       TTNN_ReduceType:$math_op,
+                       TT_ReduceTypeAttr:$math_op,
                        DefaultValuedAttr<SI32Attr, "1">:$num_links);
 
   let results = (outs AnyRankedTensor:$result);
 
   let hasVerifier = 1;
 }
 
+def TTNN_MeshShardOp: TTNN_Op<"mesh_shard"> {
+  let summary = "Mesh shard op.";
+  let description = [{
+    Tensor Mesh Shard operation
+  }];
+
+  let arguments = (ins
+    AnyRankedTensor:$input,
+    TT_Device:$device,
+    TT_MeshShardDirectionAttr:$shard_direction,
+    TT_MeshShardTypeAttr:$shard_type,
+    TT_GridAttr:$shard_shape);
+
+  let results = (outs AnyRankedTensor:$result);
+
+  let hasVerifier = 1;
+}
+
 #endif
diff --git a/include/ttmlir/Target/TTNN/program.fbs b/include/ttmlir/Target/TTNN/program.fbs
index 19b1dbc92a..332ebd0391 100644
--- a/include/ttmlir/Target/TTNN/program.fbs
+++ b/include/ttmlir/Target/TTNN/program.fbs
@@ -253,10 +253,29 @@ table DeallocateOp {
 table AllGatherOp {
   in: tt.target.TensorRef;
   out: tt.target.TensorRef;
+  device: tt.target.DeviceRef;
   dim: uint32;
   num_links: uint32;
 }
 
+table ReduceScatterOp {
+  in: tt.target.TensorRef;
+  out: tt.target.TensorRef;
+  device: tt.target.DeviceRef;
+  scatter_split_dim: uint32;
+  math_op: uint32;
+  num_links: uint32;
+}
+
+table MeshShardOp {
+  in: tt.target.TensorRef;
+  out: tt.target.TensorRef;
+  device: tt.target.DeviceRef;
+  shard_direction: uint32;
+  shard_type: uint32;
+  shard_shape: [int64];
+}
+
 union OpType {
   GetDeviceOp,
   ToMemoryConfigOp,
@@ -280,6 +299,8 @@ union OpType {
   MaxPool2dOp,
   DeallocateOp,
   AllGatherOp,
+  ReduceScatterOp,
+  MeshShardOp,
   ArangeOp,
 }
diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
index 789485eac3..b7110712ec 100644
--- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
+++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -42,9 +42,13 @@ static Value getOrInsertDevice(ConversionPatternRewriter &rewriter,
   DeviceAttr deviceAttr = getCurrentScopeDevice(op);
   auto currentInsertionPoint = rewriter.saveInsertionPoint();
   rewriter.setInsertionPoint(block, block->begin());
+  auto meshShape = deviceAttr.getMeshShape();
+  if (meshShape.empty()) {
+    meshShape = llvm::SmallVector<int64_t>{1, 1};
+  }
   auto deviceOp = rewriter.create<ttnn::GetDeviceOp>(
       op->getLoc(), rewriter.getType<DeviceType>(deviceAttr),
-      ttnn::MeshShapeAttr::get(op->getContext(), 1, 1));
+      ttnn::MeshShapeAttr::get(op->getContext(), meshShape[0], meshShape[1]));
   rewriter.restoreInsertionPoint(currentInsertionPoint);
   return deviceOp.getResult();
 }
@@ -852,6 +856,129 @@ class SubtractOpConversionPattern
   }
 };
 
+class AllReduceOpConversionPattern
+    : public OpConversionPattern<ttir::AllReduceOp> {
+public:
+  using OpConversionPattern<ttir::AllReduceOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ttir::AllReduceOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    RankedTensorType inputType =
+        mlir::cast<RankedTensorType>(adaptor.getInputs().front().getType());
+    SmallVector<int64_t> inputTypeShape(inputType.getShape());
+    auto replicaGroupsShape = adaptor.getReplicaGroups().getType().getShape();
+    size_t scatter_dim = adaptor.getDim();
+    int32_t scatter_num =
+        replicaGroupsShape[scatter_dim % replicaGroupsShape.size()];
+    auto device = getOrInsertDevice(rewriter, op);
+
+    // This lowering is intended to be a one-to-one mapping from
+    // ttir.all_reduce to ttnn.all_reduce. However, all_reduce is currently
+    // decomposed into reduce_scatter and all_gather ops because all_reduce
+    // support in TTNN is not yet stable.
+
+    // The reduce_scatter op in TTNN currently does not handle two-dimensional
+    // tensors correctly. As a temporary workaround, we insert reshape ops
+    // before and after it so the tensor is four-dimensional.
+
+    // TODO(wooseoklee): Once TTNN supports two-dimensional tensors
+    // (https://github.com/tenstorrent/tt-metal/issues/15010), we can remove
+    // this workaround.
+    if (inputTypeShape.size() < 4) {
+      std::vector<int64_t> reshapedInputShape(4, 1);
+      for (size_t i = 0; i < inputTypeShape.size(); ++i) {
+        reshapedInputShape[i + (4 - inputTypeShape.size())] = inputTypeShape[i];
+      }
+
+      ArrayAttr reshapedInputShapeAttr =
+          rewriter.getI32ArrayAttr(std::vector<int32_t>(
+              reshapedInputShape.begin(), reshapedInputShape.end()));
+
+      auto reshapedInputType =
+          RankedTensorType::Builder(inputType).setShape(reshapedInputShape);
+
+      ttnn::ReshapeOp preReshapeOp = rewriter.create<ttnn::ReshapeOp>(
+          op.getLoc(),
+          this->getTypeConverter()->convertType(reshapedInputType),
+          adaptor.getInputs().front(), reshapedInputShapeAttr);
+
+      scatter_dim = scatter_dim + (4 - inputTypeShape.size());
+
+      reshapedInputShape[scatter_dim] =
+          static_cast<int32_t>(reshapedInputShape[scatter_dim] / scatter_num);
+
+      auto scatteredInputType =
+          RankedTensorType::Builder(inputType).setShape(reshapedInputShape);
+
+      ttnn::ReduceScatterOp reduceScatterOp =
+          rewriter.create<ttnn::ReduceScatterOp>(
+              op.getLoc(),
+              this->getTypeConverter()->convertType(scatteredInputType),
+              preReshapeOp.getResult(), device, scatter_dim,
+              adaptor.getReduceType());
+
+      RankedTensorType outputType =
+          mlir::cast<RankedTensorType>(op.getType(0));
+      SmallVector<int64_t> outputTypeShape(outputType.getShape());
+
+      std::vector<int64_t> reshapedOutputShape(4, 1);
+      for (size_t i = 0; i < outputTypeShape.size(); ++i) {
+        reshapedOutputShape[i + (4 - outputTypeShape.size())] =
+            outputTypeShape[i];
+      }
+
+      auto reshapedOutputType =
+          RankedTensorType::Builder(outputType).setShape(reshapedOutputShape);
+
+      ttnn::AllGatherOp allGatherOp = rewriter.create<ttnn::AllGatherOp>(
+          op.getLoc(),
+          this->getTypeConverter()->convertType(reshapedOutputType),
+          reduceScatterOp.getResult(), device, scatter_dim);
+
+      ArrayAttr reshapedOutputShapeAttr = rewriter.getI32ArrayAttr(
+          std::vector<int32_t>(outputTypeShape.begin(), outputTypeShape.end()));
+
+      rewriter.replaceOpWithNewOp<ttnn::ReshapeOp>(
+          op, this->getTypeConverter()->convertType(outputType),
+          allGatherOp.getResult(), reshapedOutputShapeAttr);
+    } else {
+      // TODO(wooseoklee): Once all_reduce support in TTNN is stable, we can
+      // convert directly to ttnn.all_reduce.
+      inputTypeShape[scatter_dim] = inputTypeShape[scatter_dim] / scatter_num;
+      auto scatteredInputType =
+          RankedTensorType::Builder(inputType).setShape(inputTypeShape);
+
+      ttnn::ReduceScatterOp reduceScatterOp =
+          rewriter.create<ttnn::ReduceScatterOp>(
+              op.getLoc(),
+              this->getTypeConverter()->convertType(scatteredInputType),
+              adaptor.getInputs().front(), device, scatter_dim,
+              adaptor.getReduceType());
+
+      rewriter.replaceOpWithNewOp<ttnn::AllGatherOp>(
+          op, this->getTypeConverter()->convertType(op.getType(0)),
+          reduceScatterOp.getResult(), device, scatter_dim);
+    }
+    return success();
+  }
+};
+
+class MeshShardOpConversionPattern
+    : public OpConversionPattern<ttir::MeshShardOp> {
+public:
+  using OpConversionPattern<ttir::MeshShardOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ttir::MeshShardOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto device = getOrInsertDevice(rewriter, op);
+    rewriter.replaceOpWithNewOp<ttnn::MeshShardOp>(
+        op, this->getTypeConverter()->convertType(op.getType()),
+        adaptor.getInput(), device, adaptor.getShardDirection(),
+        adaptor.getShardType(), adaptor.getShardShape());
+
+    return success();
+  }
+};
+
 class AllGatherOpConversionPattern
     : public OpConversionPattern<ttir::AllGatherOp> {
 public:
@@ -860,15 +987,11 @@ class AllGatherOpConversionPattern
   LogicalResult
   matchAndRewrite(ttir::AllGatherOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    RankedTensorType type =
-        mlir::cast<RankedTensorType>(adaptor.getInput().getType());
-    Value device = getOrInsertDevice(rewriter, op);
-    tensor::EmptyOp emptyOp = rewriter.create<tensor::EmptyOp>(
-        op.getLoc(), this->getTypeConverter()->convertType(type), device);
+    auto device = getOrInsertDevice(rewriter, op);
     rewriter.replaceOpWithNewOp<ttnn::AllGatherOp>(
-        op, this->getTypeConverter()->convertType(op.getType()), emptyOp,
-        adaptor.getDim());
+        op, this->getTypeConverter()->convertType(op.getType()),
+        adaptor.getInput(), device, adaptor.getDim());
     return success();
   }
 };
@@ -995,6 +1118,8 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
            Conv2dOpConversionPattern,
            MaxPool2dOpConversionPattern,
            SubtractOpConversionPattern,
+           MeshShardOpConversionPattern,
+           AllReduceOpConversionPattern,
            AllGatherOpConversionPattern,
            ArangeOpConversionPattern,
            ScatterOpConversionPattern
diff --git a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
index f04d5566b9..9c12230cae 100644
--- a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
+++ b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
@@ -751,6 +751,10 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
   //
   patterns.add<DefaultOpConversionPattern<ttnn::AllGatherOp>>(typeConverter, ctx);
+  patterns.add<DefaultOpConversionPattern<ttnn::ReduceScatterOp>>(typeConverter,
+                                                                  ctx);
+  patterns.add<DefaultOpConversionPattern<ttnn::MeshShardOp>>(typeConverter,
+                                                              ctx);
 
   // Module op
   //
diff --git a/lib/Dialect/TTNN/IR/TTNNOps.cpp b/lib/Dialect/TTNN/IR/TTNNOps.cpp
index 8e41368cbb..eefdd9ce15 100644
--- a/lib/Dialect/TTNN/IR/TTNNOps.cpp
+++ b/lib/Dialect/TTNN/IR/TTNNOps.cpp
@@ -970,7 +970,34 @@ ::mlir::LogicalResult AllGatherOp::verify() {
 //===----------------------------------------------------------------------===//
 
 ::mlir::LogicalResult ReduceScatterOp::verify() {
-  // TODO(gfengTT)
+  ::mlir::RankedTensorType inputType = getInput().getType();
+  int32_t scatterSplitDim = getScatterSplitDim();
+  auto mathOp = getMathOp();
+
+  if (scatterSplitDim >= inputType.getRank() ||
+      scatterSplitDim < -inputType.getRank()) {
+    return emitOpError("Invalid dimension for reduce scatter op.");
+  }
+
+  // Check that the reduction op is one we currently support in TTNN.
+  if (mathOp != ::mlir::tt::ReduceType::Sum &&
+      mathOp != ::mlir::tt::ReduceType::Max &&
+      mathOp != ::mlir::tt::ReduceType::Min) {
+    return emitOpError("Invalid reduction op for reduce scatter op.");
+  }
+
+  return success();
+}
+
+::mlir::LogicalResult MeshShardOp::verify() {
+  auto shardType = getShardType();
+
+  // Check that the shard type is one of replicate or devices.
+  if (shardType != ::mlir::tt::MeshShardType::Replicate &&
+      shardType != ::mlir::tt::MeshShardType::Devices) {
+    return emitOpError("Invalid shard_type for mesh_shard op.");
+  }
+
   return success();
 }
diff --git a/lib/Target/TTNN/TTNNToFlatbuffer.cpp b/lib/Target/TTNN/TTNNToFlatbuffer.cpp
index 9706880e38..19b4d542da 100644
--- a/lib/Target/TTNN/TTNNToFlatbuffer.cpp
+++ b/lib/Target/TTNN/TTNNToFlatbuffer.cpp
@@ -156,7 +156,6 @@ ::flatbuffers::Offset<::tt::target::DeviceRef>
 createDeviceRef(FlatbufferObjectCache &cache, Value device) {
   auto deviceType = mlir::cast<DeviceType>(device.getType());
   auto chipIds = deviceType.getDesc().getChipIds();
-  assert(chipIds.size() == 1 && "expected single chip");
   return ::tt::target::CreateDeviceRef(*cache.fbb, chipIds[0]);
 }
 
@@ -175,13 +174,13 @@ createOp(FlatbufferObjectCache &cache, GetDeviceOp op) {
   auto resultType = mlir::cast<DeviceType>(result.getType());
   auto meshShape = resultType.getDesc().getMeshShape();
   auto meshVolume = ttmlir::utils::volume(meshShape);
+  ::tt::target::Dim2d mesh;
   if (meshVolume > 1) {
-    // Only support creating meshes along batch dim for now
-    assert(meshShape.size() == 3 && "expected 3D mesh shape");
-    assert(meshShape[1] == 1 && "expected non-batch dim to be 1");
-    assert(meshShape[2] == 1 && "expected non-batch dim to be 1");
+    mesh = ::tt::target::Dim2d(meshShape[0], meshShape[1]);
+  } else {
+    mesh = ::tt::target::Dim2d(1, 1);
   }
-  ::tt::target::Dim2d mesh(1, meshVolume);
+
   auto chipIds = toFlatbuffer(cache, resultType.getDesc().getChipIds());
   auto out = cache.getOrCreate(result, createDeviceRef);
   return ::tt::target::ttnn::CreateGetDeviceOp(*cache.fbb, &mesh, chipIds, out);
@@ -275,6 +274,43 @@ createOp(FlatbufferObjectCache &cache, FromDeviceOp op) {
   return ::tt::target::ttnn::CreateFromDeviceOp(*cache.fbb, input, output);
 }
 
+::flatbuffers::Offset<::tt::target::DistributionStrategy>
+createDistributionStrategy(FlatbufferObjectCache &cache, const Value &dev,
+                           const RankedTensorType &type, uint32_t &numShards) {
+  if (dev) {
+    auto device = getOperandThroughDPSOps(dev);
+    auto devOp = dyn_cast<GetDeviceOp>(device.getDefiningOp());
+    auto resultType = mlir::cast<DeviceType>(devOp.getResult().getType());
+    auto meshShape = resultType.getDesc().getMeshShape();
+    numShards = ttmlir::utils::volume(meshShape);
+    if (numShards > 1) {
+      assert(meshShape.size() <= 2 && "expected 2D mesh shape");
+
+      if (meshShape[0] == 1 || meshShape[1] == 1) {
+        // The tensor is sliced by the number of devices along one dimension.
+        // For EmptyOp and FullOp, we assume the tensor is sliced along its
+        // fastest dimension.
+        assert(type.getShape().size() > 0 && "expected non-zero tensor shape");
+        uint32_t target_dim = type.getShape().size() - 1;
+        auto strategy =
+            ::tt::target::CreateShardTensor(*cache.fbb, target_dim);
+        return ::tt::target::CreateDistributionStrategy(
+            *cache.fbb, ::tt::target::DistributedTensorConfig::ShardTensor,
+            strategy.Union());
+      }
+
+      const ::tt::target::Dim2d shard_mesh(meshShape[0], meshShape[1]);
+      auto strategy =
+          ::tt::target::CreateShardTensor2D(*cache.fbb, &shard_mesh);
+      return ::tt::target::CreateDistributionStrategy(
+          *cache.fbb, ::tt::target::DistributedTensorConfig::ShardTensor2D,
+          strategy.Union());
+    }
+  }
+  ::flatbuffers::Offset<void> distribution = 0;
+  return ::tt::target::CreateDistributionStrategy(
+      *cache.fbb, ::tt::target::DistributedTensorConfig::NONE, distribution);
+}
+
 ::flatbuffers::Offset<::tt::target::ttnn::EmptyOp>
 createOp(FlatbufferObjectCache &cache, EmptyOp op) {
   ::llvm::ArrayRef<int64_t> shape = op.getShape().getShape();
@@ -284,16 +320,12 @@ createOp(FlatbufferObjectCache &cache, EmptyOp op) {
       ::tt::mlir::ttnn::utils::toTargetTensorLayout(op.getLayout().value());
 
   uint32_t numShards = 1;
-  ::tt::target::DistributedTensorConfig distributionType =
-      ::tt::target::DistributedTensorConfig::NONE;
-  ::flatbuffers::Offset<void> distribution = 0;
-  flatbuffers::Offset<::tt::target::DistributionStrategy> strategy =
-      ::tt::target::CreateDistributionStrategy(*cache.fbb, distributionType,
-                                               distribution);
+  auto strategy = createDistributionStrategy(
+      cache, op.getDevice(), mlir::cast<RankedTensorType>(op.getType()),
+      numShards);
   auto output = getOperandThroughDPSOps(op.getResult());
 
   // If the device is not set, we create on host
-  //
   if (!op.getDevice()) {
     return ::tt::target::ttnn::CreateEmptyOp(
         *cache.fbb, cache.fbb->CreateVector<int64_t>(shape), dtype, layout,
@@ -321,12 +353,9 @@ createOp(FlatbufferObjectCache &cache, FullOp op) {
   auto fillValue = op.getFillValue().convertToFloat();
   auto output = getOperandThroughDPSOps(op.getResult());
   uint32_t numShards = 1;
-  ::tt::target::DistributedTensorConfig distributionType =
-      ::tt::target::DistributedTensorConfig::NONE;
-  ::flatbuffers::Offset<void> distribution = 0;
-  flatbuffers::Offset<::tt::target::DistributionStrategy> strategy =
-      ::tt::target::CreateDistributionStrategy(*cache.fbb, distributionType,
-                                               distribution);
+  auto strategy = createDistributionStrategy(
+      cache, op.getDevice(), mlir::cast<RankedTensorType>(op.getType()),
+      numShards);
   return ::tt::target::ttnn::CreateFullOp(
       *cache.fbb, cache.at<::tt::target::DeviceRef>(device), fillValue,
       numShards, strategy,
@@ -417,8 +446,38 @@ createOp(FlatbufferObjectCache &cache, AllGatherOp op) {
       cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput()));
   auto output = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer,
                                   kHostAllocatedAddress, kHostAllocatedSize);
-  return ::tt::target::ttnn::CreateAllGatherOp(*cache.fbb, input, output,
-                                               op.getDim(), op.getNumLinks());
+  auto device = getOperandThroughDPSOps(op.getDevice());
+  return ::tt::target::ttnn::CreateAllGatherOp(
+      *cache.fbb, input, output, cache.at<::tt::target::DeviceRef>(device),
+      op.getDim(), op.getNumLinks());
+}
+
+::flatbuffers::Offset<::tt::target::ttnn::ReduceScatterOp>
+createOp(FlatbufferObjectCache &cache, ReduceScatterOp op) {
+  auto input =
+      cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput()));
+  auto output = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer,
+                                  kHostAllocatedAddress, kHostAllocatedSize);
+  auto device = getOperandThroughDPSOps(op.getDevice());
+  return ::tt::target::ttnn::CreateReduceScatterOp(
+      *cache.fbb, input, output, cache.at<::tt::target::DeviceRef>(device),
+      op.getScatterSplitDim(), static_cast<uint32_t>(op.getMathOp()),
+      op.getNumLinks());
+}
+
+::flatbuffers::Offset<::tt::target::ttnn::MeshShardOp>
+createOp(FlatbufferObjectCache &cache, MeshShardOp op) {
+  auto input =
+      cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput()));
+  auto output = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer,
+                                  kHostAllocatedAddress, kHostAllocatedSize);
+  auto device = getOperandThroughDPSOps(op.getDevice());
+  llvm::ArrayRef<int64_t> shardShape = op.getShardShape().getShape();
+  return ::tt::target::ttnn::CreateMeshShardOp(
+      *cache.fbb, input, output, cache.at<::tt::target::DeviceRef>(device),
+      static_cast<uint32_t>(op.getShardDirection()),
+      static_cast<uint32_t>(op.getShardType()),
+      cache.fbb->CreateVector<int64_t>(shardShape));
 }
 
 template
@@ -915,6 +974,13 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
     return createOperation(cache, createOp(cache, allGatherOp), debugString,
                            locInfo);
   }
+  if (auto reduceScatterOp = dyn_cast<ReduceScatterOp>(op); reduceScatterOp) {
+    return createOperation(cache, createOp(cache, reduceScatterOp),
+                           debugString);
+  }
+  if (auto meshShardOp = dyn_cast<MeshShardOp>(op); meshShardOp) {
+    return createOperation(cache, createOp(cache, meshShardOp), debugString);
+  }
   if (auto concatOp = dyn_cast<ConcatOp>(op); concatOp) {
     return createOperation(cache, createConcatOp(cache, concatOp), debugString,
                            locInfo);
diff --git a/test/ttmlir/Dialect/TTNN/ccl/all_gather.mlir b/test/ttmlir/Dialect/TTNN/ccl/all_gather.mlir
index f1f5a5965c..a120347943 100644
--- a/test/ttmlir/Dialect/TTNN/ccl/all_gather.mlir
+++ b/test/ttmlir/Dialect/TTNN/ccl/all_gather.mlir
@@ -2,10 +2,9 @@
 #any_device = #tt.operand_constraint<any_device>
 module attributes {} {
   func.func @forward(%arg0: tensor<1x1x32x32xbf16>) -> tensor<1x1x32x128xbf16> {
-    // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
     %0 = tensor.empty() : tensor<1x1x32x128xbf16>
-    // CHECK: %[[C:.*]] = "ttnn.all_gather"[[C:.*]]
     %1 = "ttir.all_gather"(%arg0, %0) <{dim = 3 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<1x1x32x32xbf16>, tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16>
+    // CHECK: %[[C:.*]] = "ttnn.all_gather"[[C:.*]]
     return %1 : tensor<1x1x32x128xbf16>
   }
 }
diff --git a/test/ttmlir/Dialect/TTNN/ccl/all_reduce.mlir b/test/ttmlir/Dialect/TTNN/ccl/all_reduce.mlir
new file mode 100644
index 0000000000..4ee4797f2c
--- /dev/null
+++ b/test/ttmlir/Dialect/TTNN/ccl/all_reduce.mlir
@@ -0,0 +1,11 @@
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
+#any_device_tile = #tt.operand_constraint<any_device_tile>
+module attributes {} {
+  func.func @all_reduce(%arg0: tensor<4096x16384xf32>) -> tensor<4096x16384xf32> {
+    %0 = tensor.empty() : tensor<4096x16384xf32>
+    %1 = "ttir.all_reduce"(%arg0, %0) <{channel_handle = 1 : si32, dim = 0 : si32, operand_constraints = [#any_device_tile, #any_device_tile], reduce_type = #tt.reduce_type<sum>, replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : tensor<2x4xi64>, use_global_device_ids}> : (tensor<4096x16384xf32>, tensor<4096x16384xf32>) -> tensor<4096x16384xf32>
+    // CHECK: %[[C:.*]] = "ttnn.reduce_scatter"[[C:.*]]
+    // CHECK: %[[C:.*]] = "ttnn.all_gather"[[C:.*]]
+    return %1 : tensor<4096x16384xf32>
+  }
+}
diff --git a/test/ttmlir/Dialect/TTNN/ccl/mesh_shard.mlir b/test/ttmlir/Dialect/TTNN/ccl/mesh_shard.mlir
new file mode 100644
index 0000000000..e7d20cfa8f
--- /dev/null
+++ b/test/ttmlir/Dialect/TTNN/ccl/mesh_shard.mlir
@@ -0,0 +1,10 @@
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
+#operand_constraint = #tt.operand_constraint
+module attributes {} {
+  func.func @forward(%arg0: tensor<8192x784xf32>) -> tensor<4096x196xf32> {
+    %0 = tensor.empty() : tensor<4096x196xf32>
+    %1 = "ttir.mesh_shard"(%arg0, %0) <{operand_constraints = [#operand_constraint, #operand_constraint], shard_direction = #tt.shard_direction<full_to_shard>, shard_shape = #tt.grid<2x4>, shard_type = #tt.shard_type<devices>}> : (tensor<8192x784xf32>, tensor<4096x196xf32>) -> tensor<4096x196xf32>
+    // CHECK: %[[C:.*]] = "ttnn.mesh_shard"[[C:.*]]
+    return %1 : tensor<4096x196xf32>
+  }
+}
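
For reference, the AllReduceOpConversionPattern above is expected to produce TTNN IR of roughly the following shape for the new all_reduce.mlir test (replica groups of shape 2x4 and dim = 0, so the 4096x16384 input is padded to 4-D and halved along the adjusted scatter dimension). This is an illustrative sketch only, not generated output: tensor layout encodings are omitted, the !tt.device type parameters are elided, the attribute syntax is approximate, and %device stands for the result of the ttnn.get_device op inserted by getOrInsertDevice.

  // Pad the 2-D tensor to 4-D so ttnn.reduce_scatter can handle it.
  %reshaped = "ttnn.reshape"(%arg0) <{shape = [1 : i32, 1 : i32, 4096 : i32, 16384 : i32]}>
      : (tensor<4096x16384xf32>) -> tensor<1x1x4096x16384xf32>
  // reduce_scatter splits the adjusted dim (0 + 2 = 2) across the group (sum).
  %scattered = "ttnn.reduce_scatter"(%reshaped, %device) <{scatter_split_dim = 2 : si32, math_op = #tt.reduce_type<sum>, num_links = 1 : si32}>
      : (tensor<1x1x4096x16384xf32>, !tt.device<...>) -> tensor<1x1x2048x16384xf32>
  // all_gather restores the full extent along the same dimension.
  %gathered = "ttnn.all_gather"(%scattered, %device) <{dim = 2 : si32, num_links = 1 : si32}>
      : (tensor<1x1x2048x16384xf32>, !tt.device<...>) -> tensor<1x1x4096x16384xf32>
  // Reshape back to the original 2-D output shape.
  %result = "ttnn.reshape"(%gathered) <{shape = [4096 : i32, 16384 : i32]}>
      : (tensor<1x1x4096x16384xf32>) -> tensor<4096x16384xf32>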