From 18407076b046177fccafdfbc2a9a7cd35ee0f5eb Mon Sep 17 00:00:00 2001
From: Tanyo Kwok
Date: Mon, 8 Aug 2022 12:35:29 +0800
Subject: [PATCH] Add decomposition for aten.roll

---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     | 25 ++++++++
 .../Torch/Transforms/DecomposeComplexOps.cpp  | 58 +++++++++++++++++++
 .../torch_mlir_e2e_test/test_suite/basic.py   | 20 +++++++
 test/Dialect/Torch/decompose-complex-ops.mlir | 26 +++++++++
 4 files changed, 129 insertions(+)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index ec73d5d33ae4..b13348a4e6e9 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -3389,6 +3389,31 @@ def Torch_Aten_ConvolutionOp : Torch_Op<"aten._convolution", [
   }];
 }
 
+def Torch_AtenRollOp : Torch_Op<"aten.roll", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::roll : (Tensor, int[], int[]) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchListOfTorchIntType:$shifts,
+    AnyTorchListOfTorchIntType:$dims
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenRollOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 3, 1);
+    }
+    void AtenRollOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 3, 1);
+    }
+  }];
+}
+
 def Torch_AtenFlipOp : Torch_Op<"aten.flip", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index d020c57baca6..931528b0bf15 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -709,6 +709,62 @@ class DecomposeAtenTOp : public OpRewritePattern<AtenTOp> {
 };
 } // namespace
 
+// Decompose aten.roll into aten.slice and aten.cat ops.
+// https://pytorch.org/docs/stable/generated/torch.roll.html
+namespace {
+class DecomposeAtenRollOp : public OpRewritePattern<AtenRollOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenRollOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<Value> shifts;
+    if (!getListConstructElements(op.shifts(), shifts))
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: shifts not list of Scalar");
+    SmallVector<Value> dims;
+    if (!getListConstructElements(op.dims(), dims))
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: dims not list of Scalar");
+
+    if (shifts.size() != dims.size())
+      return op.emitError("list sizes of shifts and dims are not the same");
+
+    auto loc = op.getLoc();
+    Value constNone = rewriter.create<ConstantNoneOp>(loc);
+    Value constZero = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(0));
+    Value constOne = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(1));
+    auto self = op.self();
+    Type listType = Torch::ListType::get(self.getType());
+    // roll(input, shift, dim) = cat({
+    //   slice(input, dim, -shift, none),
+    //   slice(input, dim, 0, -shift)}, dim)
+    auto imitateRoll = [&](Value input, Value shift, Value dim) {
+      // Roll by `shift` along `dim`: the last `shift` elements wrap around
+      // to the front, followed by the remaining elements.
+      Value negShift = rewriter.create<AtenNegIntOp>(loc, shift);
+      Type sliceType = computeReductionType(
+          rewriter, op, self.getType().cast<BaseTensorType>(), dim,
+          /*keepDim=*/true);
+      Value slice0 = rewriter.create<AtenSliceTensorOp>(
+          loc, sliceType, input, dim, negShift, constNone, constOne);
+      Value slice1 = rewriter.create<AtenSliceTensorOp>(
+          loc, sliceType, input, dim, constZero, negShift, constOne);
+
+      Value slices = rewriter.create<PrimListConstructOp>(
+          loc, listType, llvm::ArrayRef<Value>{slice0, slice1});
+      return rewriter.create<AtenCatOp>(loc, self.getType(), slices, dim);
+    };
+    // Apply the slice/cat rewrite once per (shift, dim) pair.
+    auto output = self;
+    auto nShifts = shifts.size();
+    for (size_t k = 0; k < nShifts; ++k) {
+      output = imitateRoll(output, shifts[k], dims[k]);
+    }
+    rewriter.replaceOp(op, output);
+    return success();
+  }
+};
+} // namespace
+
 // Decompose aten.repeat into aten.expand and aten.view ops.
 //
 // Ref: https://pytorch.org/docs/stable/generated/torch.Tensor.repeat.html
@@ -2434,6 +2490,8 @@ class DecomposeComplexOpsPass
     patterns.add<DecomposeConstantTensorAllocLikeOp<AtenZerosLikeOp, 0>>(
         context);
     target.addIllegalOp<AtenZerosLikeOp>();
+    patterns.add<DecomposeAtenRollOp>(context);
+    target.addIllegalOp<AtenRollOp>();
     patterns.add<DecomposeAtenRepeatOp>(context);
     target.addIllegalOp<AtenRepeatOp>();
     patterns.add<DecomposeAtenExpandOp>(context);
diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py
index 7b687d4194fd..d52245005867 100644
--- a/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -1065,6 +1065,26 @@ def forward(self, x):
 def RepeatModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 1, 2))
 
+# ==============================================================================
+
+
+class RollModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([3, 1, 2], torch.float32, True),
+    ])
+    def forward(self, x):
+        return x.roll([2, 1], [0, 2])
+
+
+@register_test_case(module_factory=lambda: RollModule())
+def RollModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 1, 2))
 
 # ==============================================================================
diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir
index 8c37005db913..9a00864013fc 100644
--- a/test/Dialect/Torch/decompose-complex-ops.mlir
+++ b/test/Dialect/Torch/decompose-complex-ops.mlir
@@ -1332,3 +1332,29 @@ func.func @torch.aten.std.dim(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vten
   %0 = torch.aten.std.dim %arg0, %dims, %unbiased, %keepdim: !torch.vtensor<[3,4,5],f32>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[3,4,1],f32>
   return %0 : !torch.vtensor<[3,4,1],f32>
 }
+
+// -----
+// CHECK-LABEL: func.func @torch.aten.roll(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.int, %[[ARG2:.*]]: !torch.int, %[[ARG3:.*]]: !torch.int, %[[ARG4:.*]]: !torch.int) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[T0:.*]] = torch.prim.ListConstruct %[[ARG1]], %[[ARG2]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[ARG2]], %[[ARG3]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[NONE:.*]] = torch.constant.none
+// CHECK: %[[INT0:.*]] = torch.constant.int 0
+// CHECK: %[[INT1:.*]] = torch.constant.int 1
+// CHECK: %[[T2:.*]] = torch.aten.neg.int %[[ARG1]] : !torch.int -> !torch.int
+// CHECK: %[[T3:.*]] = torch.aten.slice.Tensor %[[ARG0]], %[[ARG2]], %[[T2]], %[[NONE]], %[[INT1]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.none, !torch.int -> !torch.vtensor<[?,?],f32>
+// CHECK: %[[T4:.*]] = torch.aten.slice.Tensor %[[ARG0]], %[[ARG2]], %[[INT0]], %[[T2]], %[[INT1]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
+// CHECK: %[[T5:.*]] = torch.prim.ListConstruct %[[T3]], %[[T4]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) -> !torch.list<vtensor<[?,?],f32>>
+// CHECK: %[[T6:.*]] = torch.aten.cat %[[T5]], %[[ARG2]] : !torch.list<vtensor<[?,?],f32>>, !torch.int -> !torch.vtensor<[?,?],f32>
+// CHECK: %[[T7:.*]] = torch.aten.neg.int %[[ARG2]] : !torch.int -> !torch.int
+// CHECK: %[[T8:.*]] = torch.aten.slice.Tensor %[[T6]], %[[ARG3]], %[[T7]], %[[NONE]], %[[INT1]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.none, !torch.int -> !torch.vtensor<[?,?],f32>
+// CHECK: %[[T9:.*]] = torch.aten.slice.Tensor %[[T6]], %[[ARG3]], %[[INT0]], %[[T7]], %[[INT1]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
+// CHECK: %[[T10:.*]] = torch.prim.ListConstruct %[[T8]], %[[T9]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) -> !torch.list<vtensor<[?,?],f32>>
+// CHECK: %[[T11:.*]] = torch.aten.cat %[[T10]], %[[ARG3]] : !torch.list<vtensor<[?,?],f32>>, !torch.int -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[T11]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.roll(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.int) -> !torch.vtensor<[?,?],f32> {
+  %0 = torch.prim.ListConstruct %arg1, %arg2: (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %arg2, %arg3: (!torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.aten.roll %arg0, %0, %1 : !torch.vtensor<[?,?],f32>, !torch.list<int>, !torch.list<int> -> !torch.vtensor<[?,?],f32>
+  return %2 : !torch.vtensor<[?,?],f32>
+}
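
The pattern above implements roll(input, shift, dim) = cat({slice(input, dim, -shift, none), slice(input, dim, 0, -shift)}, dim), applied once per (shift, dim) pair. The sketch below is not part of the patch; it mirrors the decomposition in eager PyTorch so the equivalence with torch.roll can be spot-checked. The helper name roll_via_slices is invented for illustration, and it normalizes the shift with a modulo, a wrap-around edge case (shift >= dimension size) that the e2e test above does not exercise.

import torch

def roll_via_slices(input, shifts, dims):
    # Mirrors DecomposeAtenRollOp: one slice/slice/cat round per (shift, dim).
    assert len(shifts) == len(dims), "list sizes of shifts and dims are not the same"
    output = input
    for shift, dim in zip(shifts, dims):
        size = output.size(dim)
        shift = shift % size  # normalize, matching torch.roll's wrap-around
        if shift == 0:
            continue
        # slice(input, dim, -shift, none): the last `shift` elements along `dim`.
        tail = output.narrow(dim, size - shift, shift)
        # slice(input, dim, 0, -shift): everything except those last elements.
        head = output.narrow(dim, 0, size - shift)
        output = torch.cat([tail, head], dim=dim)
    return output

x = torch.rand(3, 1, 2)  # same shape and arguments as RollModule above
assert torch.equal(roll_via_slices(x, [2, 1], [0, 2]), x.roll([2, 1], [0, 2]))

Because each (shift, dim) pair is lowered independently, a two-axis roll becomes two slice/slice/cat rounds; this matches the loop over shifts in DecomposeAtenRollOp and the two torch.aten.cat ops in the FileCheck test.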