From cbf0b3ad09801baca1f1e5a643a7c6a7e576de0e Mon Sep 17 00:00:00 2001 From: Guangyu Feng <157328249+gfengTT@users.noreply.github.com> Date: Fri, 6 Dec 2024 11:10:37 -0500 Subject: [PATCH] Fix TTIR to TTNN conversion for all gather (#1182) --- lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp | 10 ++-------- test/ttmlir/Dialect/TTNN/ccl/all_gather.mlir | 1 - 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp index d32c505f7e..91bfce7f67 100644 --- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp +++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp @@ -905,15 +905,9 @@ class AllGatherOpConversionPattern LogicalResult matchAndRewrite(ttir::AllGatherOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - RankedTensorType type = - mlir::cast(adaptor.getInput().getType()); - Value device = ::ttnn::utils::getOrInsertDevice(rewriter, op); - tensor::EmptyOp emptyOp = rewriter.create( - op.getLoc(), this->getTypeConverter()->convertType(type), device); - rewriter.replaceOpWithNewOp( - op, this->getTypeConverter()->convertType(op.getType()), emptyOp, - adaptor.getDim()); + op, this->getTypeConverter()->convertType(op.getType()), + adaptor.getInput(), adaptor.getDim()); return success(); } }; diff --git a/test/ttmlir/Dialect/TTNN/ccl/all_gather.mlir b/test/ttmlir/Dialect/TTNN/ccl/all_gather.mlir index f1f5a5965c..cb2a7ad2b3 100644 --- a/test/ttmlir/Dialect/TTNN/ccl/all_gather.mlir +++ b/test/ttmlir/Dialect/TTNN/ccl/all_gather.mlir @@ -2,7 +2,6 @@ #any_device = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<1x1x32x32xbf16>) -> tensor<1x1x32x128xbf16> { - // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<1x1x32x128xbf16> // CHECK: %[[C:.*]] = "ttnn.all_gather"[[C:.*]] %1 = "ttir.all_gather"(%arg0, %0) <{dim = 3 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<1x1x32x32xbf16>, tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16>