From 11a890107883dbf405c4512405f09b4405030658 Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Thu, 28 Jul 2022 19:00:02 -0400 Subject: [PATCH] [MLIR][TORCH] Add support for multiple indexing tensors for aten.index.Tensor (#1097) - Includes a canonicalizer for `aten.add.t` needed for successfully lowering the shape function - Only offers support for statically sized index tensors when there is more than one - Dynamic shape support remains for single indexing tensors --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 1 + .../TorchToLinalg/IndirectDataMovement.cpp | 175 ++++++++++++++---- lib/Dialect/Torch/IR/TorchOps.cpp | 29 +++ lib/Dialect/Torch/Transforms/ShapeLibrary.cpp | 22 +-- .../Transforms/SimplifyShapeCalculations.cpp | 1 + .../jit_ir/build_tools/shape_lib_gen.py | 5 +- .../jit_ir/build_tools/torch_ods_gen.py | 2 +- .../torch_mlir_e2e_test/test_suite/basic.py | 124 +++++++++++++ test/Dialect/Torch/canonicalize.mlir | 47 +++++ 9 files changed, 355 insertions(+), 51 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index b40eaaafd88c..d505c0eec267 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -6619,6 +6619,7 @@ def Torch_AtenAddTOp : Torch_Op<"aten.add.t", [ printDefaultTorchOp(printer, *this, 2, 1); } }]; + let hasCanonicalizer = 1; } def Torch_AtenEqIntListOp : Torch_Op<"aten.eq.int_list", [ diff --git a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp index 13ae325f5e86..50bbf42efc1d 100644 --- a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp @@ -244,6 +244,21 @@ class ConvertAtenIndexSelectOp : public OpConversionPattern { }; } // namespace +// IndexTensor for multiple input tensors broadcasts their shapes to a common +// shape and then replaces the 
indexed dims with the indices given by the +// indexing tensors: +// x[i_1, i_2, ..., i_M] = result +// result[...] = x[i_1[...], i_2[...], ..., i_M[...]] +// +// where the result shape is computed as follows: +// 1. broadcast i_1, i_2, ..., i_M to a common shape +// 2. if i_1, i_2, ..., i_M is not contiguous, transpose the broadcasted +// shape to the beginning of the result shape, while removing the +// unchanged dims (marked by None) +// 3. Otherwise replace the indexed dims with the broadcasted shape +// +// e.g. x: [2, 3] +// x[[4], [6, 1]] -> x[6, 4] namespace { class ConvertAtenIndexTensorOp : public OpConversionPattern { public: @@ -251,6 +266,7 @@ class ConvertAtenIndexTensorOp : public OpConversionPattern { LogicalResult matchAndRewrite(AtenIndexTensorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); @@ -266,78 +282,165 @@ class ConvertAtenIndexTensorOp : public OpConversionPattern { SmallVector indicesVal = getTypeConvertedValues(rewriter, loc, getTypeConverter(), indicesTuple); - int indexTensorDim = -1; + // Identify the indices with non-None index tensors and determine if they + // are contiguous within the input list. 
+ SmallVector indexTensorDims; + SmallVector indexTensors; + bool contiguous = true; for (auto i : llvm::seq(0, (int)indicesVal.size())) { Value index = indicesVal[i]; if (!index || failed(checkNotNone(rewriter, op, index))) continue; - if (indexTensorDim >= 0) { - return rewriter.notifyMatchFailure( - op, "unimplemented: only one index tensor allowed"); - } - indexTensorDim = i; + if (!indexTensorDims.empty() && indexTensorDims.back() != i - 1) + contiguous = false; + indexTensorDims.push_back(i); + indexTensors.push_back(index); } - if (indexTensorDim == -1) { + if (indexTensors.empty()) { return rewriter.notifyMatchFailure( - op, "unimplemented: index tensor must not be None"); + op, "aten.index.Tensor: index tensor must not be None"); } - Value indexTensor = indicesVal[indexTensorDim]; RankedTensorType inputType = input.getType().cast(); - RankedTensorType indexTensorType = - indexTensor.getType().cast(); RankedTensorType resultType = getTypeConverter() ->convertType(op->getResult(0).getType()) .cast(); Type elementType = resultType.getElementType(); int inputRank = inputType.getRank(); - int indexTensorRank = indexTensorType.getRank(); + int resultRank = resultType.getRank(); + int firstIndexDim = indexTensorDims[0]; + int replacedIndexCount = indexTensorDims.size(); + int64_t startIndex = contiguous ? firstIndexDim : 0; + + // Currently we only support statically sized index tensors + // when there is more than one index tensor. + // TODO: Add support for dynamic size index tensors. This will probably + // require broadcasting the index tensors to a common shape. 
+ SmallVector broadcastedIndexShape; + if (indexTensors.size() > 1) { + int maxRank = -1; + for (auto indexTensor : indexTensors) { + RankedTensorType indexTensorType = + indexTensor.getType().cast(); + maxRank = std::max(maxRank, (int)indexTensorType.getRank()); + } + + // Because we are assuming static shapes, we can get the shape of the + // broadcasted index tensors from the shape refinement pass + auto refinedResultShape = resultType.getShape(); + for (auto i : llvm::seq(startIndex, startIndex + maxRank)) { + auto resultDimSize = refinedResultShape[i]; + if (ShapedType::isDynamic(resultDimSize)) { + return rewriter.notifyMatchFailure( + op, "unimplemented: index tensors must have static shape if " + "there is more than one index tensor"); + } + broadcastedIndexShape.push_back( + getConstant(rewriter, loc, resultDimSize, rewriter.getIndexType())); + } + } else { + // For a single indexing tensor we can simply use its (dynamic) sizes + broadcastedIndexShape = + getTensorSizes(rewriter, loc, indexTensors.front()); + } // This result shape calculation assumes that there is only one - // index tensor of the input tensor. The calculation for arbitrary inputs is - // much more complex. + // index tensor, or all of the index tensors are statically shaped. 
+ int broadcastRank = broadcastedIndexShape.size(); + SmallVector resultShape; - for (auto i : llvm::seq(0, indexTensorDim)) { - resultShape.push_back(getDimOp(rewriter, loc, input, i)); - } - for (auto i : llvm::seq(0, indexTensorRank)) { - resultShape.push_back(getDimOp(rewriter, loc, indexTensor, i)); - } - for (auto i : llvm::seq(indexTensorDim + 1, inputRank)) { - resultShape.push_back(getDimOp(rewriter, loc, input, i)); + if (contiguous) { + for (auto i : llvm::seq(0, firstIndexDim)) { + resultShape.push_back(getDimOp(rewriter, loc, input, i)); + } + resultShape.append(broadcastedIndexShape); + for (auto i : llvm::seq((int)resultShape.size(), resultRank)) { + resultShape.push_back(getDimOp(rewriter, loc, input, + i - broadcastRank + replacedIndexCount)); + } + } else { + resultShape.append(broadcastedIndexShape); + int j = 0; + for (auto i : llvm::seq(0, inputRank)) { + if (j < replacedIndexCount && i == indexTensorDims[j]) { + j++; + continue; + } + resultShape.push_back(getDimOp(rewriter, loc, input, i)); + } } - int resultRank = resultShape.size(); + // Initialize the indexing maps for the generic op. Because we are assuming + // static shapes for the indexing tensors when there are more than 1, we can + // safely map all size 1 dims to 0 in the corresponding affine maps. + // TODO: For dynamic shapes, we have to either broadcast the index tensors + // to a common shape or introduce some form of control flow. 
Value initTensor = rewriter.create(loc, resultShape, elementType); - SmallVector indicesExpr, resultExpr; + SmallVector indexingMaps; SmallVector iteratorTypes; - for (auto i : llvm::seq(indexTensorDim, indexTensorDim + indexTensorRank)) - indicesExpr.push_back(rewriter.getAffineDimExpr(i)); + for (auto indexTensor : indexTensors) { + RankedTensorType indexTensorType = + indexTensor.getType().cast(); + auto indexTensorShape = indexTensorType.getShape(); + int rank = indexTensorShape.size(); + SmallVector indicesExpr; + for (auto dim : llvm::seq(0, rank)) { + if (indexTensorShape[dim] == 1) { + indicesExpr.push_back(rewriter.getAffineConstantExpr(0)); + continue; + } + indicesExpr.push_back( + rewriter.getAffineDimExpr(startIndex + broadcastRank - rank + dim)); + } + indexingMaps.push_back( + AffineMap::get(resultRank, 0, indicesExpr, op->getContext())); + } + + SmallVector resultExpr; for (auto i : llvm::seq(0, resultRank)) { resultExpr.push_back(rewriter.getAffineDimExpr(i)); iteratorTypes.push_back(getParallelIteratorTypeName()); } - auto indexingMaps = AffineMap::inferFromExprList({indicesExpr, resultExpr}); + + indexingMaps.push_back( + AffineMap::get(resultRank, 0, resultExpr, op->getContext())); Value finalRes = rewriter .create( - loc, initTensor.getType(), indexTensor, initTensor, + loc, initTensor.getType(), indexTensors, initTensor, indexingMaps, iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { - Value index = castIntToIndex(b, loc, args[0]); SmallVector extractionIndices; - int extra_dims = 0; - for (auto i : llvm::seq(0, inputRank)) { - if (i == indexTensorDim) { - extractionIndices.push_back(index); - extra_dims += indexTensorRank - 1; - } else { + if (contiguous) { + for (auto i : llvm::seq(0, firstIndexDim)) { + extractionIndices.push_back( + b.create(loc, i)); + } + for (auto i : llvm::seq(0, (int)indexTensorDims.size())) { extractionIndices.push_back( - b.create(loc, i + extra_dims)); + castIntToIndex(b, loc, args[i])); + } + for 
(auto i : + llvm::seq((int)extractionIndices.size(), inputRank)) { + extractionIndices.push_back(b.create( + loc, i + broadcastRank - replacedIndexCount)); + } + } else { + int indexCount = 0, unchanged = 0; + for (auto i : llvm::seq(0, inputRank)) { + if (indexCount < replacedIndexCount && + i == indexTensorDims[indexCount]) { + extractionIndices.push_back( + castIntToIndex(b, loc, args[indexCount++])); + continue; + } + extractionIndices.push_back(b.create( + loc, broadcastRank + unchanged)); + unchanged++; } } Value extractedElement = b.create( diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 5b8c02bb4abd..b6a21b7b884d 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -1479,6 +1479,35 @@ void Aten__Getitem__TOp::getCanonicalizationPatterns( }); } +//===----------------------------------------------------------------------===// +// AtenAddTOp +//===----------------------------------------------------------------------===// + +void AtenAddTOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(+[](AtenAddTOp op, PatternRewriter &rewriter) { + auto lhsListConstruct = op.a().getDefiningOp(); + if (!lhsListConstruct || isListPotentiallyMutated(lhsListConstruct)) + return failure(); + + auto rhsListConstruct = op.b().getDefiningOp(); + if (!rhsListConstruct || isListPotentiallyMutated(rhsListConstruct)) + return failure(); + + SmallVector concatenatedList; + for (auto a : lhsListConstruct.getOperands()) { + concatenatedList.push_back(a); + } + for (auto b : rhsListConstruct.getOperands()) { + concatenatedList.push_back(b); + } + + rewriter.replaceOpWithNewOp(op, op.getType(), + concatenatedList); + return success(); + }); +} + //===----------------------------------------------------------------------===// // AtenEqIntListOp //===----------------------------------------------------------------------===// diff --git 
a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp index ad52744936c2..01febf4ed0ed 100644 --- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp @@ -6590,30 +6590,30 @@ module { %10 = torch.aten.len.t %arg1 : !torch.list>> -> !torch.int %11 = torch.prim.ListConstruct %int9223372036854775807, %10 : (!torch.int, !torch.int) -> !torch.list %12 = torch.prim.min.self_int %11 : !torch.list -> !torch.int - %13:3 = torch.prim.Loop %12, %true, init(%true, %int-1, %int-1) { - ^bb0(%arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.int, %arg5: !torch.int): + %13:2 = torch.prim.Loop %12, %true, init(%true, %int-1) { + ^bb0(%arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.int): %16 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list>>, !torch.int -> !torch.optional> %17 = torch.aten.__isnot__ %16, %none : !torch.optional>, !torch.none -> !torch.bool - %18:3 = torch.prim.If %17 -> (!torch.bool, !torch.int, !torch.int) { + %18:2 = torch.prim.If %17 -> (!torch.bool, !torch.int) { %19 = torch.aten.eq.int %arg4, %int-1 : !torch.int, !torch.int -> !torch.bool - %20:3 = torch.prim.If %19 -> (!torch.bool, !torch.int, !torch.int) { - torch.prim.If.yield %arg3, %arg2, %arg2 : !torch.bool, !torch.int, !torch.int + %20:2 = torch.prim.If %19 -> (!torch.bool, !torch.int) { + torch.prim.If.yield %arg3, %arg2 : !torch.bool, !torch.int } else { - %21 = torch.aten.sub.int %arg2, %arg5 : !torch.int, !torch.int -> !torch.int + %21 = torch.aten.sub.int %arg2, %arg4 : !torch.int, !torch.int -> !torch.int %22 = torch.aten.ne.int %21, %int1 : !torch.int, !torch.int -> !torch.bool %23 = torch.prim.If %22 -> (!torch.bool) { torch.prim.If.yield %false : !torch.bool } else { torch.prim.If.yield %arg3 : !torch.bool } - torch.prim.If.yield %23, %arg4, %arg5 : !torch.bool, !torch.int, !torch.int + torch.prim.If.yield %23, %arg4 : !torch.bool, !torch.int } - torch.prim.If.yield %20#0, %20#1, %20#2 : 
!torch.bool, !torch.int, !torch.int + torch.prim.If.yield %20#0, %20#1 : !torch.bool, !torch.int } else { - torch.prim.If.yield %arg3, %arg4, %arg5 : !torch.bool, !torch.int, !torch.int + torch.prim.If.yield %arg3, %arg4 : !torch.bool, !torch.int } - torch.prim.Loop.condition %true, iter(%18#0, %18#1, %18#2 : !torch.bool, !torch.int, !torch.int) - } : (!torch.int, !torch.bool, !torch.bool, !torch.int, !torch.int) -> (!torch.bool, !torch.int, !torch.int) + torch.prim.Loop.condition %true, iter(%18#0, %18#1 : !torch.bool, !torch.int) + } : (!torch.int, !torch.bool, !torch.bool, !torch.int) -> (!torch.bool, !torch.int) %14 = torch.aten.__not__ %13#0 : !torch.bool -> !torch.bool %15 = torch.prim.If %14 -> (!torch.list) { %16 = torch.aten.add.t %6, %4 : !torch.list, !torch.list -> !torch.list diff --git a/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp b/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp index b9a4ea29aede..f8bd58878738 100644 --- a/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp @@ -418,6 +418,7 @@ class SimplifyShapeCalculationsPass Aten__Getitem__TOp::getCanonicalizationPatterns(patterns, context); AtenSizeOp::getCanonicalizationPatterns(patterns, context); AtenLenTOp::getCanonicalizationPatterns(patterns, context); + AtenAddTOp::getCanonicalizationPatterns(patterns, context); // TODO: Debug visitation order to make this more efficient. // A single linear scan should suffice. 
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py index 1aa13ef5e45d..5b267fe5676d 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py @@ -1016,6 +1016,7 @@ def aten〇pad(self: List[int], pad: List[int], mode: str = "constant", value: O Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4), None]), # Explicit None value. Invocation(TensorOfShape(2, 3, 4, 5), [None, LongTensorOfShape(4), LongTensorOfShape(4)]), # Indexing tensors on consecutive dimensions. Invocation(TensorOfShape(2, 3, 4, 5), [None, LongTensorOfShape(4), None, LongTensorOfShape(4)]), # Indexing tensors on non-consecutive dimensions. + Invocation(TensorOfShape(2, 3, 4, 5), [LongTensorOfShape(4, 2), None, LongTensorOfShape(2)]), # Indexing tensors on non-consecutive dimensions. Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4, 5, 6), LongTensorOfShape(1, 5, 1)]), # Broadcasting of index tensors. Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4)]), # Fewer index tensors than dimensions. ErrorInvocation(TensorOfShape(2, 3), [LongTensorOfShape(4), LongTensorOfShape(4), LongTensorOfShape(4)]), # More index tensors than dimensions. 
@@ -1037,15 +1038,13 @@ def aten〇index〇Tensor(self: List[int], indices: List[Optional[List[int]]]) - if len(unused_dim_sizes) == 0: return broadcasted_shape - prev_index_tensor_location = -1 first_index_tensor_location = -1 index_tensors_are_together = True for e, index_tensor_shape in enumerate(indices): if index_tensor_shape is not None: if first_index_tensor_location == -1: first_index_tensor_location = e - prev_index_tensor_location = e - elif e - prev_index_tensor_location != 1: + elif e - first_index_tensor_location != 1: index_tensors_are_together = False if not index_tensors_are_together: diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index 48ee9722f1af..326db5b88c7c 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -489,7 +489,7 @@ def emit_with_mutating_variants(key, **kwargs): # List ops. 
emit("aten::cat : (Tensor[], int) -> (Tensor)") emit("aten::append.t : (t[], t) -> (t[])") - emit("aten::add.t : (t[], t[]) -> (t[])") + emit("aten::add.t : (t[], t[]) -> (t[])", has_canonicalizer=True) emit("aten::eq.int_list : (int[], int[]) -> (bool)", has_folder=True) emit("aten::list.t : (t[]) -> (t[])") emit("aten::slice.t : (t[], int?, int?, int) -> (t[])") diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py index 1100f2532ead..e4333d0810a0 100644 --- a/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/python/torch_mlir_e2e_test/test_suite/basic.py @@ -1700,6 +1700,130 @@ def IndexTensorSelectDimModule_basic(module, tu: TestUtils): # ============================================================================== +class IndexTensorMultiInput(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ([3, 3], torch.int64, True), + ([3], torch.int64, True), + ]) + def forward(self, x, index1, index2): + return torch.ops.aten.index(x, (index1, index2,)) + + +@register_test_case(module_factory=lambda: IndexTensorMultiInput()) +def IndexTensorMultiInput_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 4, 3), torch.randint(3, (3, 3)), torch.randint(3, (3,))) + + +# ============================================================================== + + +class IndexTensorMultiInputOneDim(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ([6, 1], torch.int64, True), + ([3], torch.int64, True), + ]) + def forward(self, x, index1, index2): + return torch.ops.aten.index(x, (index1, index2,)) + + +@register_test_case(module_factory=lambda: IndexTensorMultiInputOneDim()) +def IndexTensorMultiInputOneDim_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 4, 3), torch.randint(4, (6, 1)), torch.randint(3, (3,))) + + +# 
============================================================================== + + +class IndexTensorMultiInputNonContiguous(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([4, 2], torch.int64, True), + ([4, 2], torch.int64, True), + ]) + def forward(self, x, index1, index2): + return torch.ops.aten.index(x, (index1, None, index2)) + + +@register_test_case(module_factory=lambda: IndexTensorMultiInputNonContiguous()) +def IndexTensorMultiInputNonContiguous_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 4, 3, 2), torch.randint(3, (4, 2)), torch.randint(1, (4, 2,))) + + +# ============================================================================== + + +class IndexTensorMultiInputThreeIndexers(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1, -1, -1], torch.float32, True), + ([8, 4, 2], torch.int64, True), + ([8, 1, 1], torch.int64, True), + ([4, 2], torch.int64, True), + ]) + def forward(self, x, index1, index2, index3): + return torch.ops.aten.index(x, (None, None, index1, None, index2, index3)) + + +@register_test_case(module_factory=lambda: IndexTensorMultiInputThreeIndexers()) +def IndexTensorMultiInputThreeIndexers_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 2, 4, 4, 5, 3), + torch.randint(3, (8, 4, 2,)), + torch.randint(4, (8, 1, 1,)), + torch.randint(2, (4, 2,))) + + +# ============================================================================== + + +class IndexTensorMultiInputContiguousCenter(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([2, 2], torch.int64, True), + ([2], torch.int64, True), + ]) + def forward(self, x, index1, index2): + return torch.ops.aten.index(x, (None, index1, index2, None)) + + +@register_test_case(module_factory=lambda: 
IndexTensorMultiInputContiguousCenter()) +def IndexTensorMultiInputContiguousCenter_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 4, 3, 2), torch.randint(3, (2, 2)), torch.randint(2, [2])) + + +# ============================================================================== + + class SquareModule(torch.nn.Module): def __init__(self): diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index e4add75692d8..d5cfb0e36ca0 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -635,6 +635,53 @@ func.func @torch.aten.__getitem__.t$invalid_index() -> !torch.int { return %1 : !torch.int } +// Not canonicalized because of mutated lhs list +// CHECK-LABEL: func.func @torch.aten.add.t$no_canonicalize_lhs_mutated() +func.func @torch.aten.add.t$no_canonicalize_lhs_mutated() -> !torch.list { + %int4 = torch.constant.int 4 + %0 = torch.prim.ListConstruct : () -> !torch.list + %1 = torch.prim.ListConstruct : () -> !torch.list + %2 = torch.aten.append.t %0, %int4 : !torch.list, !torch.int -> !torch.list + // CHECK: torch.aten.add.t + %3 = torch.aten.add.t %0, %1 : !torch.list, !torch.list -> !torch.list + return %3 : !torch.list +} + +// Not canonicalized because of mutated rhs list +// CHECK-LABEL: func.func @torch.aten.add.t$no_canonicalize_rhs_mutated() +func.func @torch.aten.add.t$no_canonicalize_rhs_mutated() -> !torch.list { + %int4 = torch.constant.int 4 + %0 = torch.prim.ListConstruct : () -> !torch.list + %1 = torch.prim.ListConstruct : () -> !torch.list + %2 = torch.aten.append.t %1, %int4 : !torch.list, !torch.int -> !torch.list + // CHECK: torch.aten.add.t + %3 = torch.aten.add.t %0, %1 : !torch.list, !torch.list -> !torch.list + return %3 : !torch.list +} + +// CHECK-LABEL: func.func @torch.aten.add.t$concat( +// CHECK-SAME: %[[ARG0:.*]]: !torch.int, +// CHECK-SAME: %[[ARG1:.*]]: !torch.int) -> !torch.list { +// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[ARG0]], %[[ARG1]] : 
(!torch.int, !torch.int) -> !torch.list +// CHECK: return %[[LIST]] : !torch.list +func.func @torch.aten.add.t$concat(%arg0: !torch.int, %arg1: !torch.int) -> !torch.list { + %0 = torch.prim.ListConstruct %arg0 : (!torch.int) -> !torch.list + %1 = torch.prim.ListConstruct %arg1 : (!torch.int) -> !torch.list + %2 = torch.aten.add.t %0, %1 : !torch.list, !torch.list -> !torch.list + return %2 : !torch.list +} + +// CHECK-LABEL: func.func @torch.aten.add.t$concat_empty( +// CHECK-SAME: %[[ARG0:.*]]: !torch.int) -> !torch.list { +// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[ARG0]] : (!torch.int) -> !torch.list +// CHECK: return %[[LIST]] : !torch.list +func.func @torch.aten.add.t$concat_empty(%arg0: !torch.int) -> !torch.list { + %0 = torch.prim.ListConstruct %arg0 : (!torch.int) -> !torch.list + %1 = torch.prim.ListConstruct : () -> !torch.list + %2 = torch.aten.add.t %0, %1 : !torch.list, !torch.list -> !torch.list + return %2 : !torch.list +} + // CHECK-LABEL: func.func @torch.aten.eq.int_list$fold$literals_of_different_sizes // CHECK: %[[RET:.*]] = torch.constant.bool false // CHECK: return %[[RET]] : !torch.bool