From 2a61f70194000b8290476cd135cb634a13706c1b Mon Sep 17 00:00:00 2001 From: Guangyu Feng Date: Mon, 4 Nov 2024 10:46:27 -0600 Subject: [PATCH] Fix TTIR to TTNN conversion for all gather --- lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp | 10 ++-------- test/ttmlir/Dialect/TTNN/ccl/all_gather.mlir | 1 - 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp index bf216d362..d77d095ac 100644 --- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp +++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp @@ -905,15 +905,9 @@ class AllGatherOpConversionPattern LogicalResult matchAndRewrite(ttir::AllGatherOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - RankedTensorType type = - mlir::cast(adaptor.getInput().getType()); - Value device = ::ttnn::utils::getOrInsertDevice(rewriter, op); - tensor::EmptyOp emptyOp = rewriter.create( - op.getLoc(), this->getTypeConverter()->convertType(type), device); - rewriter.replaceOpWithNewOp( - op, this->getTypeConverter()->convertType(op.getType()), emptyOp, - adaptor.getDim()); + op, this->getTypeConverter()->convertType(op.getType()), + adaptor.getInput(), adaptor.getDim()); return success(); } }; diff --git a/test/ttmlir/Dialect/TTNN/ccl/all_gather.mlir b/test/ttmlir/Dialect/TTNN/ccl/all_gather.mlir index f1f5a5965..cb2a7ad2b 100644 --- a/test/ttmlir/Dialect/TTNN/ccl/all_gather.mlir +++ b/test/ttmlir/Dialect/TTNN/ccl/all_gather.mlir @@ -2,7 +2,6 @@ #any_device = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<1x1x32x32xbf16>) -> tensor<1x1x32x128xbf16> { - // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<1x1x32x128xbf16> // CHECK: %[[C:.*]] = "ttnn.all_gather"[[C:.*]] %1 = "ttir.all_gather"(%arg0, %0) <{dim = 3 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<1x1x32x32xbf16>, tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16>