Skip to content

Commit

Permalink
Add decomposition for aten.roll
Browse files Browse the repository at this point in the history
  • Loading branch information
Tanyo Kwok committed Aug 8, 2022
1 parent 1ee8659 commit 1840707
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 0 deletions.
25 changes: 25 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -3389,6 +3389,31 @@ def Torch_Aten_ConvolutionOp : Torch_Op<"aten._convolution", [
}];
}

def Torch_AtenRollOp : Torch_Op<"aten.roll", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::roll : (Tensor, int[], int[]) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$shifts,
AnyTorchListOfTorchIntType:$dims
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenRollOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenRollOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}

def Torch_AtenFlipOp : Torch_Op<"aten.flip", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
58 changes: 58 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,62 @@ class DecomposeAtenTOp : public OpRewritePattern<AtenTOp> {
};
} // namespace

// Decompose aten.roll into aten.expand and aten.slice and aten.cat ops.
// https://pytorch.org/docs/stable/generated/torch.roll.html
namespace {
class DecomposeAtenRollOp : public OpRewritePattern<AtenRollOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenRollOp op,
PatternRewriter &rewriter) const override {
SmallVector<Value> shifts;
if (!getListConstructElements(op.shifts(), shifts))
return rewriter.notifyMatchFailure(
op, "unimplemented: shifts not list of Scalar");
SmallVector<Value> dims;
if (!getListConstructElements(op.dims(), dims))
return rewriter.notifyMatchFailure(
op, "unimplemented: dims not list of Scalar");

if (shifts.size() != dims.size())
return op.emitError("list sizes of shifts and dims are not the same");

auto loc = op.getLoc();
Value constNone = rewriter.create<ConstantNoneOp>(loc);
Value constZero = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
Value constOne = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
auto self = op.self();
Type listType = Torch::ListType::get(self.getType());
// roll(input, shift, dim) = cat({
// slice(input, dim, -shift, none),
// slice(input, dim, 0, -shift)}, dim)
auto ImitateRoll = [&](Value input, Value shift, Value dim) {
Value negShift = rewriter.create<AtenNegIntOp>(loc, shift);
Type sliceType = computeReductionType(
rewriter, op, self.getType().cast<BaseTensorType>(), dim,
/*keepDim=*/true);
Value slice0 = rewriter.create<AtenSliceTensorOp>(
loc, sliceType, input, dim, negShift, constNone, constOne);
Value slice1 = rewriter.create<AtenSliceTensorOp>(
loc, sliceType, input, dim, constZero, negShift, constOne);

Value slices = rewriter.create<PrimListConstructOp>(
loc, listType, llvm::ArrayRef<Value>{slice0, slice1});
return rewriter.create<AtenCatOp>(loc, self.getType(), slices, dim);
};
auto output = self;
auto nShifts = shifts.size();
for (size_t k = 0; k < nShifts; ++k) {
output = ImitateRoll(output, shifts[k], dims[k]);
}
rewriter.replaceOp(op, output);
return success();
}
};
} // namespace

// Decompose aten.repeat into aten.expand and aten.view ops.
//
// Ref: https://pytorch.org/docs/stable/generated/torch.Tensor.repeat.html
Expand Down Expand Up @@ -2434,6 +2490,8 @@ class DecomposeComplexOpsPass
patterns.add<DecomposeConstantTensorAllocLikeOp<AtenZerosLikeOp, 0>>(
context);
target.addIllegalOp<AtenZerosLikeOp>();
patterns.add<DecomposeAtenRollOp>(context);
target.addIllegalOp<AtenRollOp>();
patterns.add<DecomposeAtenRepeatOp>(context);
target.addIllegalOp<AtenRepeatOp>();
patterns.add<DecomposeAtenExpandOp>(context);
Expand Down
20 changes: 20 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1065,6 +1065,26 @@ def forward(self, x):
def RepeatModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 1, 2))

# ==============================================================================


class RollModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([3, 1, 2], torch.float32, True),
])
def forward(self, x):
return x.roll([2, 1], [0, 2])


@register_test_case(module_factory=lambda: RollModule())
def RollModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 1, 2))

# ==============================================================================

Expand Down
26 changes: 26 additions & 0 deletions test/Dialect/Torch/decompose-complex-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1332,3 +1332,29 @@ func.func @torch.aten.std.dim(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vten
%0 = torch.aten.std.dim %arg0, %dims, %unbiased, %keepdim: !torch.vtensor<[3,4,5],f32>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[3,4,1],f32>
return %0 : !torch.vtensor<[3,4,1],f32>
}

// -----
// CHECK-LABEL: func.func @torch.aten.roll(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.int, %[[ARG2:.*]]: !torch.int, %[[ARG3:.*]]: !torch.int, %[[ARG4:.*]]: !torch.int) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[T0:.*]] = torch.prim.ListConstruct %[[ARG1]], %[[ARG2]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[ARG2]], %[[ARG3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[T2:.*]] = torch.aten.neg.int %[[ARG1]] : !torch.int -> !torch.int
// CHECK: %[[T3:.*]] = torch.aten.slice.Tensor %[[ARG0]], %[[ARG2]], %[[T2]], %[[NONE]], %[[INT1]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.none, !torch.int -> !torch.vtensor<[?,?],f32>
// CHECK: %[[T4:.*]] = torch.aten.slice.Tensor %[[ARG0]], %[[ARG2]], %[[INT0]], %[[T2]], %[[INT1]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
// CHECK: %[[T5:.*]] = torch.prim.ListConstruct %[[T3]], %[[T4]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) -> !torch.list<vtensor<[?,?],f32>>
// CHECK: %[[T6:.*]] = torch.aten.cat %[[T5]], %[[ARG2]] : !torch.list<vtensor<[?,?],f32>>, !torch.int -> !torch.vtensor<[?,?],f32>
// CHECK: %[[T7:.*]] = torch.aten.neg.int %[[ARG2]] : !torch.int -> !torch.int
// CHECK: %[[T8:.*]] = torch.aten.slice.Tensor %[[T6]], %[[ARG3]], %[[T7]], %[[NONE]], %[[INT1]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.none, !torch.int -> !torch.vtensor<[?,?],f32>
// CHECK: %[[T9:.*]] = torch.aten.slice.Tensor %[[T6]], %[[ARG3]], %[[INT0]], %[[T7]], %[[INT1]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
// CHECK: %[[T10:.*]] = torch.prim.ListConstruct %[[T8]], %[[T9]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) -> !torch.list<vtensor<[?,?],f32>>
// CHECK: %[[T11:.*]] = torch.aten.cat %[[T10]], %[[ARG3]] : !torch.list<vtensor<[?,?],f32>>, !torch.int -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T11]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.roll(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.int) -> !torch.vtensor<[?,?],f32> {
%0 = torch.prim.ListConstruct %arg1, %arg2: (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct %arg2, %arg3: (!torch.int, !torch.int) -> !torch.list<int>
%2 = torch.aten.roll %arg0, %0, %1 : !torch.vtensor<[?,?],f32>, !torch.list<int>, !torch.list<int> -> !torch.vtensor<[?,?],f32>
return %2 : !torch.vtensor<[?,?],f32>
}

0 comments on commit 1840707

Please sign in to comment.