From 3d0e18bbe7e79df4c715e5fc1b4a2c1282eaa3c7 Mon Sep 17 00:00:00 2001 From: Tanyo Kwok Date: Wed, 24 Aug 2022 08:36:05 +0800 Subject: [PATCH] Add decomposition for aten.roll (#1170) * Add decomposition for aten.roll * add e2e unittest * refine type of torch.roll * fix aten::cat output type --- e2e_testing/torchscript/xfail_sets.py | 2 + .../Dialect/Torch/IR/GeneratedTorchOps.td | 25 +++++++ lib/Conversion/TorchToMhlo/Basic.cpp | 2 +- .../Torch/Transforms/DecomposeComplexOps.cpp | 73 +++++++++++++++++++ lib/Dialect/Torch/Transforms/RefineTypes.cpp | 3 +- lib/Dialect/Torch/Transforms/ShapeLibrary.cpp | 4 + .../jit_ir/build_tools/shape_lib_gen.py | 3 + .../jit_ir/build_tools/torch_ods_gen.py | 1 + .../torch_mlir_e2e_test/test_suite/basic.py | 22 +++++- test/Dialect/Torch/decompose-complex-ops.mlir | 32 ++++++++ 10 files changed, 164 insertions(+), 3 deletions(-) diff --git a/e2e_testing/torchscript/xfail_sets.py b/e2e_testing/torchscript/xfail_sets.py index 142419d45fb..8904d7faa91 100644 --- a/e2e_testing/torchscript/xfail_sets.py +++ b/e2e_testing/torchscript/xfail_sets.py @@ -127,6 +127,7 @@ "ReshapeAliasCollapseModule_basic", "ReshapeAliasExpandModule_basic", "ReshapeExpandModule_basic", + "RollModule_basic", "TestMultipleTensorReturn_basic", "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic", "BaddbmmStaticModule_basic", @@ -447,6 +448,7 @@ "QuantizedMLP_basic", "RandLikeDtypeModule_basic", "RandLikeModule_basic", + "RollModule_basic", "ScalarImplicitFloatModule_basic", "ScalarImplicitIntModule_basic", "SliceEndSleStartModule_basic", diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 235d910c9c4..371fdcf2a1d 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -3478,6 +3478,31 @@ def Torch_Aten_ConvolutionDeprecatedOp : Torch_Op<"aten._convolution.deprecated" }]; } +def Torch_AtenRollOp : Torch_Op<"aten.roll", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::roll : (Tensor, int[], int[]) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$shifts, + AnyTorchListOfTorchIntType:$dims + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenRollOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenRollOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + def Torch_AtenFlipOp : Torch_Op<"aten.flip", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchToMhlo/Basic.cpp b/lib/Conversion/TorchToMhlo/Basic.cpp index bf3e073c79a..81ba3dfbdef 100644 --- a/lib/Conversion/TorchToMhlo/Basic.cpp +++ b/lib/Conversion/TorchToMhlo/Basic.cpp @@ -979,7 +979,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( size_t posDim = toPositiveDim(dim, outType.getRank()); rewriter.replaceOpWithNewOp( - op, ValueRange(builtinTensors), posDim); + op, outType, ValueRange(builtinTensors), posDim); return success(); } } // namespace diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 960ba274e6d..77c0ec6d261 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -709,6 +709,77 @@ class DecomposeAtenTOp : public OpRewritePattern { }; } // namespace +// Decompose aten.roll into aten.slice and aten.cat ops. +// https://pytorch.org/docs/stable/generated/torch.roll.html +namespace { +class DecomposeAtenRollOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenRollOp op, + PatternRewriter &rewriter) const override { + SmallVector shifts; + if (!getListConstructElements(op.shifts(), shifts)) + return rewriter.notifyMatchFailure( + op, "unimplemented: shifts not list of Scalar"); + SmallVector dims; + if (!getListConstructElements(op.dims(), dims)) + return rewriter.notifyMatchFailure( + op, "unimplemented: dims not list of Scalar"); + + if (shifts.size() != dims.size()) + return op.emitError("list sizes of shifts and dims are not the same"); + + auto loc = op.getLoc(); + Value constNone = rewriter.create(loc); + Value constZero = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + Value constOne = rewriter.create( + loc, rewriter.getI64IntegerAttr(1)); + auto self = op.self(); + auto selfTy = self.getType().cast(); + // roll(input, shift, dim) = cat({ + // slice(input, dim, -shift, none), + // slice(input, dim, 0, -shift)}, dim) + auto imitateRoll = [&](Value input, Value shift, Value dim, + int64_t cstDim) { + Value negShift = rewriter.create(loc, shift); + ArrayRef inputShape = selfTy.getSizes(); + SmallVector sizes; + sizes.append(inputShape.begin(), inputShape.end()); + sizes[cstDim] = ShapedType::kDynamicSize; + Type sliceTy = selfTy.getWithSizesAndDtype(llvm::makeArrayRef(sizes), + selfTy.getDtype()); + Value slice0 = rewriter.create( + loc, sliceTy, input, dim, negShift, constNone, constOne); + Value slice1 = rewriter.create( + loc, sliceTy, input, dim, constZero, negShift, constOne); + + Type listType = Torch::ListType::get(sliceTy); + Value slices = rewriter.create( + loc, listType, llvm::ArrayRef{slice0, slice1}); + return rewriter.create(loc, self.getType(), slices, dim); + }; + int rank = getTensorRank(self); + if (rank < 0) + return rewriter.notifyMatchFailure(op, "Unimplemented: unranked tensor"); + Value output = self; + auto nShifts = shifts.size(); + for (size_t k = 0; k < nShifts; ++k) { + auto dim = dims[k]; + int64_t cstDim = -1; + if (!matchPattern(dim, m_TorchConstantInt(&cstDim))) + return rewriter.notifyMatchFailure( + op, "unimplemented: dim must be constant"); + + cstDim = toPositiveDim(cstDim, rank); + output = imitateRoll(output, shifts[k], dim, cstDim); + } + rewriter.replaceOp(op, output); + return success(); + } +}; +} // namespace + // Decompose aten.repeat into aten.expand and aten.view ops. // // Ref: https://pytorch.org/docs/stable/generated/torch.Tensor.repeat.html @@ -2555,6 +2626,8 @@ class DecomposeComplexOpsPass patterns.add>( context); target.addIllegalOp(); + patterns.add(context); + target.addIllegalOp(); patterns.add(context); target.addIllegalOp(); patterns.add(context); diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp index 8696e3c4c0b..4597208fcb3 100644 --- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp +++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp @@ -658,7 +658,8 @@ void TypeAnalysis::visitOperation(Operation *op, AtenZero_Op, AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp, AtenIndexPutOp, ValsemVariantAtenCopyOp, AtenZeroOp, AtenIndexPutHackedTwinOp, AtenMaskedFillScalarOp, AtenFlipOp, - PrimAbsScalarOp, AtenNumpyTOp, AtenTriuOp, AtenMaskedFillTensorOp>( + PrimAbsScalarOp, AtenNumpyTOp, AtenTriuOp, AtenMaskedFillTensorOp, + AtenRollOp>( op)) { return incorporateKnowledge(op->getResult(0), operands[0]->getValue()); } diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp index 7f435531579..2afd5640c1d 100644 --- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp @@ -4213,6 +4213,10 @@ module { } return %7 : !torch.list } + func.func @__torch_mlir_shape_fn.aten.roll(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list) -> !torch.list { + %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list + return %0 : !torch.list + } func.func @__torch__.torch.jit._shape_functions.expand(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { %int-1 = torch.constant.int -1 %true = torch.constant.bool true diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py index 2d8ba3b6e2c..41379eb66b5 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py @@ -635,6 +635,9 @@ def aten〇repeat(self: List[int], repeats: List[int]) -> List[int]: out.append(self[i] * repeats[i + leading_rank]) return out +def aten〇roll(self: List[int], shifts: List[int], dims: List[int] = ()) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇expand(self: List[int], size: List[int], implicit: bool = False) -> List[int]: return upstream_shape_functions.expand(self, size) diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index 49b21799cd2..4f223b1e360 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -338,6 +338,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::convolution_overrideable : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)") emit("aten::_convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool, bool) -> (Tensor)") emit("aten::_convolution.deprecated : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool) -> (Tensor)") + emit("aten::roll : (Tensor, int[], int[]) -> (Tensor)"), emit("aten::flip : (Tensor, int[]) -> (Tensor)") emit( "aten::native_batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float) -> (Tensor, Tensor, Tensor)" diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py index 8a0b05ab2ff..6c155fa1ee4 100644 --- a/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/python/torch_mlir_e2e_test/test_suite/basic.py @@ -1047,6 +1047,27 @@ def BroadcastToModule_basic(module, tu: TestUtils): # ============================================================================== +class RollModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([3, -1, 2], torch.float32, True), + ]) + def forward(self, x): + return x.roll([2, -1], [0, 2]) + + +@register_test_case(module_factory=lambda: RollModule()) +def RollModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 1, 2)) + +# ============================================================================== + + class RepeatModule(torch.nn.Module): def __init__(self): @@ -1065,7 +1086,6 @@ def forward(self, x): def RepeatModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 1, 2)) - # ============================================================================== diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir index 60d5589ec24..9cd2d18538b 100644 --- a/test/Dialect/Torch/decompose-complex-ops.mlir +++ b/test/Dialect/Torch/decompose-complex-ops.mlir @@ -1336,6 +1336,7 @@ func.func @torch.aten.std.dim(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vten } // ----- + // CHECK-LABEL: func.func @torch.aten.flatten.using_ints( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?],f32> { // CHECK: %[[INT0:.*]] = torch.constant.int 0 @@ -1350,3 +1351,34 @@ func.func @torch.aten.flatten.using_ints(%arg0: !torch.vtensor<[?,?,?,?],f32>) - %1 = torch.aten.flatten.using_ints %arg0, %int0, %int3: !torch.vtensor<[?,?,?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?],f32> return %1 : !torch.vtensor<[?],f32> } + +// ----- + +// CHECK-LABEL: func.func @torch.aten.roll( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.int, %[[ARG2:.*]]: !torch.int) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[T0:.*]] = torch.prim.ListConstruct %[[ARG1]], %[[ARG2]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[INT1:.*]] = torch.constant.int 1 +// CHECK: %[[INT:.*]]-2 = torch.constant.int -2 +// CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT]]-2 : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[INT0:.*]] = torch.constant.int 0 +// CHECK: %[[INT1_0:.*]] = torch.constant.int 1 +// CHECK: %[[T2:.*]] = torch.aten.neg.int %[[ARG1]] : !torch.int -> !torch.int +// CHECK: %[[T3:.*]] = torch.aten.slice.Tensor %[[ARG0]], %[[INT1]], %[[T2]], %[[NONE]], %[[INT1]]_0 : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.none, !torch.int -> !torch.vtensor<[?,?],f32> +// CHECK: %[[T4:.*]] = torch.aten.slice.Tensor %[[ARG0]], %[[INT1]], %[[INT0]], %[[T2]], %[[INT1]]_0 : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32> +// CHECK: %[[T5:.*]] = torch.prim.ListConstruct %[[T3]], %[[T4]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) -> !torch.list> +// CHECK: %[[T6:.*]] = torch.aten.cat %[[T5]], %[[INT1]] : !torch.list>, !torch.int -> !torch.vtensor<[?,?],f32> +// CHECK: %[[T7:.*]] = torch.aten.neg.int %[[ARG2]] : !torch.int -> !torch.int +// CHECK: %[[T8:.*]] = torch.aten.slice.Tensor %[[T6]], %[[INT]]-2, %[[T7]], %[[NONE]], %[[INT]]1_0 : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.none, !torch.int -> !torch.vtensor<[?,?],f32> +// CHECK: %[[T9:.*]] = torch.aten.slice.Tensor %[[T6]], %[[INT]]-2, %[[INT]]0, %[[T7]], %[[INT]]1_0 : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32> +// CHECK: %[[T10:.*]] = torch.prim.ListConstruct %[[T8]], %[[T9]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) -> !torch.list> +// CHECK: %[[T11:.*]] = torch.aten.cat %[[T10]], %[[INT]]-2 : !torch.list>, !torch.int -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[T11]] : !torch.vtensor<[?,?],f32> +func.func @torch.aten.roll(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.int, %arg2: !torch.int) -> !torch.vtensor<[?,?],f32> { + %0 = torch.prim.ListConstruct %arg1, %arg2: (!torch.int, !torch.int) -> !torch.list + %int1 = torch.constant.int 1 + %int-2 = torch.constant.int -2 + %1 = torch.prim.ListConstruct %int1, %int-2: (!torch.int, !torch.int) -> !torch.list + %2 = torch.aten.roll %arg0, %0, %1 : !torch.vtensor<[?,?],f32>, !torch.list, !torch.list -> !torch.vtensor<[?,?],f32> + return %2 : !torch.vtensor<[?,?],f32> +}