Skip to content

Commit

Permalink
Fix TTIR to TTNN conversion for all gather (#1182)
Browse files Browse the repository at this point in the history
  • Loading branch information
gfengTT authored and azecevicTT committed Dec 16, 2024
1 parent 401a04f commit cbf0b3a
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 9 deletions.
10 changes: 2 additions & 8 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -905,15 +905,9 @@ class AllGatherOpConversionPattern
LogicalResult
matchAndRewrite(ttir::AllGatherOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
RankedTensorType type =
mlir::cast<RankedTensorType>(adaptor.getInput().getType());
Value device = ::ttnn::utils::getOrInsertDevice(rewriter, op);
tensor::EmptyOp emptyOp = rewriter.create<tensor::EmptyOp>(
op.getLoc(), this->getTypeConverter()->convertType(type), device);

rewriter.replaceOpWithNewOp<ttnn::AllGatherOp>(
op, this->getTypeConverter()->convertType(op.getType()), emptyOp,
adaptor.getDim());
op, this->getTypeConverter()->convertType(op.getType()),
adaptor.getInput(), adaptor.getDim());
return success();
}
};
Expand Down
1 change: 0 additions & 1 deletion test/ttmlir/Dialect/TTNN/ccl/all_gather.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module attributes {} {
func.func @forward(%arg0: tensor<1x1x32x32xbf16>) -> tensor<1x1x32x128xbf16> {
// CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
%0 = tensor.empty() : tensor<1x1x32x128xbf16>
// CHECK: %[[C:.*]] = "ttnn.all_gather"[[C:.*]]
%1 = "ttir.all_gather"(%arg0, %0) <{dim = 3 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<1x1x32x32xbf16>, tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16>
Expand Down

0 comments on commit cbf0b3a

Please sign in to comment.