Add decomposition for aten.roll #1170

Merged · 4 commits · Aug 24, 2022
Changes from all commits
2 changes: 2 additions & 0 deletions e2e_testing/torchscript/xfail_sets.py
@@ -127,6 +127,7 @@
"ReshapeAliasCollapseModule_basic",
"ReshapeAliasExpandModule_basic",
"ReshapeExpandModule_basic",
"RollModule_basic",
"TestMultipleTensorReturn_basic",
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
"BaddbmmStaticModule_basic",
@@ -447,6 +448,7 @@
"QuantizedMLP_basic",
"RandLikeDtypeModule_basic",
"RandLikeModule_basic",
"RollModule_basic",
"ScalarImplicitFloatModule_basic",
"ScalarImplicitIntModule_basic",
"SliceEndSleStartModule_basic",
25 changes: 25 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -3478,6 +3478,31 @@ def Torch_Aten_ConvolutionDeprecatedOp : Torch_Op<"aten._convolution.deprecated"
}];
}

def Torch_AtenRollOp : Torch_Op<"aten.roll", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::roll : (Tensor, int[], int[]) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$shifts,
AnyTorchListOfTorchIntType:$dims
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenRollOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenRollOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}

def Torch_AtenFlipOp : Torch_Op<"aten.flip", [
AllowsTypeRefinement,
HasValueSemantics,
2 changes: 1 addition & 1 deletion lib/Conversion/TorchToMhlo/Basic.cpp
@@ -979,7 +979,7 @@ LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(

size_t posDim = toPositiveDim(dim, outType.getRank());
rewriter.replaceOpWithNewOp<mhlo::ConcatenateOp>(
op, ValueRange(builtinTensors), posDim);
op, outType, ValueRange(builtinTensors), posDim);
return success();
}
} // namespace
73 changes: 73 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -709,6 +709,77 @@ class DecomposeAtenTOp : public OpRewritePattern<AtenTOp> {
};
} // namespace

// Decompose aten.roll into aten.slice and aten.cat ops.
// https://pytorch.org/docs/stable/generated/torch.roll.html
namespace {
class DecomposeAtenRollOp : public OpRewritePattern<AtenRollOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenRollOp op,
PatternRewriter &rewriter) const override {
SmallVector<Value> shifts;
if (!getListConstructElements(op.shifts(), shifts))
return rewriter.notifyMatchFailure(
op, "unimplemented: shifts not list of Scalar");
SmallVector<Value> dims;
if (!getListConstructElements(op.dims(), dims))
return rewriter.notifyMatchFailure(
op, "unimplemented: dims not list of Scalar");

if (shifts.size() != dims.size())
return op.emitError("list sizes of shifts and dims are not the same");

auto loc = op.getLoc();
Value constNone = rewriter.create<ConstantNoneOp>(loc);
Value constZero = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
Value constOne = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
auto self = op.self();
auto selfTy = self.getType().cast<BaseTensorType>();
// roll(input, shift, dim) = cat({
// slice(input, dim, -shift, none),
// slice(input, dim, 0, -shift)}, dim)
auto imitateRoll = [&](Value input, Value shift, Value dim,
int64_t cstDim) {
Value negShift = rewriter.create<AtenNegIntOp>(loc, shift);
ArrayRef<int64_t> inputShape = selfTy.getSizes();
SmallVector<int64_t> sizes;
sizes.append(inputShape.begin(), inputShape.end());
sizes[cstDim] = ShapedType::kDynamicSize;
Type sliceTy = selfTy.getWithSizesAndDtype(llvm::makeArrayRef(sizes),
selfTy.getDtype());
Value slice0 = rewriter.create<AtenSliceTensorOp>(
loc, sliceTy, input, dim, negShift, constNone, constOne);
Value slice1 = rewriter.create<AtenSliceTensorOp>(
loc, sliceTy, input, dim, constZero, negShift, constOne);

Type listType = Torch::ListType::get(sliceTy);
Value slices = rewriter.create<PrimListConstructOp>(
loc, listType, llvm::ArrayRef<Value>{slice0, slice1});
return rewriter.create<AtenCatOp>(loc, self.getType(), slices, dim);
};
int rank = getTensorRank(self);
if (rank < 0)
return rewriter.notifyMatchFailure(op, "Unimplemented: unranked tensor");
Value output = self;
auto nShifts = shifts.size();
for (size_t k = 0; k < nShifts; ++k) {
auto dim = dims[k];
int64_t cstDim = -1;
if (!matchPattern(dim, m_TorchConstantInt(&cstDim)))
return rewriter.notifyMatchFailure(
op, "unimplemented: dim must be constant");

cstDim = toPositiveDim(cstDim, rank);
output = imitateRoll(output, shifts[k], dim, cstDim);
}
rewriter.replaceOp(op, output);
return success();
}
};
} // namespace

// Decompose aten.repeat into aten.expand and aten.view ops.
//
// Ref: https://pytorch.org/docs/stable/generated/torch.Tensor.repeat.html
@@ -2555,6 +2626,8 @@ class DecomposeComplexOpsPass
patterns.add<DecomposeConstantTensorAllocLikeOp<AtenZerosLikeOp, 0>>(
context);
target.addIllegalOp<AtenZerosLikeOp>();
patterns.add<DecomposeAtenRollOp>(context);
target.addIllegalOp<AtenRollOp>();
patterns.add<DecomposeAtenRepeatOp>(context);
target.addIllegalOp<AtenRepeatOp>();
patterns.add<DecomposeAtenExpandOp>(context);
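To make the slice/cat identity used by `DecomposeAtenRollOp` concrete, here is a small PyTorch sketch (illustrative only, not part of this patch; the helper name `roll_via_slice_cat` is made up) that applies the same per-dimension rewrite. It assumes each shift magnitude is at most the size of its dimension, since the sketch does not normalize shifts modulo the dimension size:

```python
import torch

def roll_via_slice_cat(x: torch.Tensor, shifts, dims) -> torch.Tensor:
    # For each (shift, dim) pair, emulate
    #   roll(out, shift, dim) = cat(slice(out, dim, -shift, None),
    #                               slice(out, dim, 0, -shift), dim)
    out = x
    for shift, dim in zip(shifts, dims):
        front = [slice(None)] * out.dim()
        back = [slice(None)] * out.dim()
        front[dim] = slice(-shift, None)  # slice(input, dim, -shift, None)
        back[dim] = slice(0, -shift)      # slice(input, dim, 0, -shift)
        out = torch.cat([out[tuple(front)], out[tuple(back)]], dim=dim)
    return out

x = torch.rand(3, 1, 2)
assert torch.allclose(roll_via_slice_cat(x, [2, -1], [0, 2]),
                      torch.roll(x, [2, -1], [0, 2]))
```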
3 changes: 2 additions & 1 deletion lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -658,7 +658,8 @@ void TypeAnalysis::visitOperation(Operation *op,
AtenZero_Op, AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp,
AtenIndexPutOp, ValsemVariantAtenCopyOp, AtenZeroOp,
AtenIndexPutHackedTwinOp, AtenMaskedFillScalarOp, AtenFlipOp,
PrimAbsScalarOp, AtenNumpyTOp, AtenTriuOp, AtenMaskedFillTensorOp>(
PrimAbsScalarOp, AtenNumpyTOp, AtenTriuOp, AtenMaskedFillTensorOp,
AtenRollOp>(
op)) {
return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
}
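The `RefineTypes` addition above records that `aten.roll`, like the other ops in this list, carries its operand's knowledge (dtype in particular) through to its result, which matches eager PyTorch behavior. A quick illustrative check, not taken from this patch:

```python
import torch

x = torch.rand(2, 3, dtype=torch.float64)
# roll only rearranges elements, so the dtype is unchanged
assert torch.roll(x, 1, 0).dtype == torch.float64
```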
4 changes: 4 additions & 0 deletions lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
@@ -4213,6 +4213,10 @@ module {
}
return %7 : !torch.list<int>
}
func.func @__torch_mlir_shape_fn.aten.roll(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>) -> !torch.list<int> {
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
}
func.func @__torch__.torch.jit._shape_functions.expand(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
%int-1 = torch.constant.int -1
%true = torch.constant.bool true
@@ -635,6 +635,9 @@ def aten〇repeat(self: List[int], repeats: List[int]) -> List[int]:
out.append(self[i] * repeats[i + leading_rank])
return out

def aten〇roll(self: List[int], shifts: List[int], dims: List[int] = ()) -> List[int]:
return upstream_shape_functions.unary(self)

def aten〇expand(self: List[int], size: List[int], implicit: bool = False) -> List[int]:
return upstream_shape_functions.expand(self, size)

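Since rolling only permutes elements along the chosen dimensions, the result shape equals the input shape, which is why both shape functions above simply forward to `unary`. A minimal illustrative check (same input shape as the `RollModule` e2e test below; not part of the patch):

```python
import torch

x = torch.rand(3, 1, 2)
assert torch.roll(x, [2, -1], [0, 2]).shape == x.shape  # roll is shape-preserving
```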
@@ -338,6 +338,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::convolution_overrideable : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)")
emit("aten::_convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool, bool) -> (Tensor)")
emit("aten::_convolution.deprecated : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool) -> (Tensor)")
emit("aten::roll : (Tensor, int[], int[]) -> (Tensor)"),
emit("aten::flip : (Tensor, int[]) -> (Tensor)")
emit(
"aten::native_batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float) -> (Tensor, Tensor, Tensor)"
22 changes: 21 additions & 1 deletion python/torch_mlir_e2e_test/test_suite/basic.py
@@ -1047,6 +1047,27 @@ def BroadcastToModule_basic(module, tu: TestUtils):
# ==============================================================================


class RollModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([3, -1, 2], torch.float32, True),
])
def forward(self, x):
return x.roll([2, -1], [0, 2])


@register_test_case(module_factory=lambda: RollModule())
def RollModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 1, 2))

# ==============================================================================


class RepeatModule(torch.nn.Module):

def __init__(self):
@@ -1065,7 +1086,6 @@ def forward(self, x):
def RepeatModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 1, 2))


# ==============================================================================


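For reference, a 1-D illustration of the `torch.roll` semantics that `RollModule` exercises (the test itself rolls a rank-3 tensor along two dimensions at once); the expected values below follow from standard PyTorch behavior rather than output captured from this suite:

```python
import torch

t = torch.tensor([0, 1, 2, 3, 4])
print(torch.roll(t, 2))   # tensor([3, 4, 0, 1, 2]): elements move 2 positions toward higher indices
print(torch.roll(t, -1))  # tensor([1, 2, 3, 4, 0]): negative shifts move elements toward lower indices
```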
32 changes: 32 additions & 0 deletions test/Dialect/Torch/decompose-complex-ops.mlir
@@ -1336,6 +1336,7 @@ func.func @torch.aten.std.dim(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vten
}

// -----

// CHECK-LABEL: func.func @torch.aten.flatten.using_ints(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?],f32> {
// CHECK: %[[INT0:.*]] = torch.constant.int 0
@@ -1350,3 +1351,34 @@ func.func @torch.aten.flatten.using_ints(%arg0: !torch.vtensor<[?,?,?,?],f32>) -
%1 = torch.aten.flatten.using_ints %arg0, %int0, %int3: !torch.vtensor<[?,?,?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?],f32>
return %1 : !torch.vtensor<[?],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.roll(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.int, %[[ARG2:.*]]: !torch.int) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[T0:.*]] = torch.prim.ListConstruct %[[ARG1]], %[[ARG2]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[INT:.*]]-2 = torch.constant.int -2
// CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT]]-2 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[INT1_0:.*]] = torch.constant.int 1
// CHECK: %[[T2:.*]] = torch.aten.neg.int %[[ARG1]] : !torch.int -> !torch.int
// CHECK: %[[T3:.*]] = torch.aten.slice.Tensor %[[ARG0]], %[[INT1]], %[[T2]], %[[NONE]], %[[INT1]]_0 : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.none, !torch.int -> !torch.vtensor<[?,?],f32>
// CHECK: %[[T4:.*]] = torch.aten.slice.Tensor %[[ARG0]], %[[INT1]], %[[INT0]], %[[T2]], %[[INT1]]_0 : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
// CHECK: %[[T5:.*]] = torch.prim.ListConstruct %[[T3]], %[[T4]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) -> !torch.list<vtensor<[?,?],f32>>
// CHECK: %[[T6:.*]] = torch.aten.cat %[[T5]], %[[INT1]] : !torch.list<vtensor<[?,?],f32>>, !torch.int -> !torch.vtensor<[?,?],f32>
// CHECK: %[[T7:.*]] = torch.aten.neg.int %[[ARG2]] : !torch.int -> !torch.int
// CHECK: %[[T8:.*]] = torch.aten.slice.Tensor %[[T6]], %[[INT]]-2, %[[T7]], %[[NONE]], %[[INT]]1_0 : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.none, !torch.int -> !torch.vtensor<[?,?],f32>
// CHECK: %[[T9:.*]] = torch.aten.slice.Tensor %[[T6]], %[[INT]]-2, %[[INT]]0, %[[T7]], %[[INT]]1_0 : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
// CHECK: %[[T10:.*]] = torch.prim.ListConstruct %[[T8]], %[[T9]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) -> !torch.list<vtensor<[?,?],f32>>
// CHECK: %[[T11:.*]] = torch.aten.cat %[[T10]], %[[INT]]-2 : !torch.list<vtensor<[?,?],f32>>, !torch.int -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T11]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.roll(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.int, %arg2: !torch.int) -> !torch.vtensor<[?,?],f32> {
%0 = torch.prim.ListConstruct %arg1, %arg2: (!torch.int, !torch.int) -> !torch.list<int>
%int1 = torch.constant.int 1
%int-2 = torch.constant.int -2
%1 = torch.prim.ListConstruct %int1, %int-2: (!torch.int, !torch.int) -> !torch.list<int>
%2 = torch.aten.roll %arg0, %0, %1 : !torch.vtensor<[?,?],f32>, !torch.list<int>, !torch.list<int> -> !torch.vtensor<[?,?],f32>
return %2 : !torch.vtensor<[?,?],f32>
}