Skip to content

Commit

Permalink
Add lowering for _convolution.deprecated (#1259)
Browse files Browse the repository at this point in the history
* Add lowering for _convolution.deprecated
  • Loading branch information
alextsao1999 authored Aug 22, 2022
1 parent 99fb4c8 commit c38308f
Show file tree
Hide file tree
Showing 7 changed files with 166 additions and 9 deletions.
34 changes: 34 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -3444,6 +3444,40 @@ def Torch_Aten_ConvolutionOp : Torch_Op<"aten._convolution", [
}];
}

def Torch_Aten_ConvolutionDeprecatedOp : Torch_Op<"aten._convolution.deprecated", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::_convolution.deprecated : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
AnyTorchTensorType:$weight,
AnyTorchOptionalTensorType:$bias,
AnyTorchListOfTorchIntType:$stride,
AnyTorchListOfTorchIntType:$padding,
AnyTorchListOfTorchIntType:$dilation,
Torch_BoolType:$transposed,
AnyTorchListOfTorchIntType:$output_padding,
Torch_IntType:$groups,
Torch_BoolType:$benchmark,
Torch_BoolType:$deterministic,
Torch_BoolType:$cudnn_enabled
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult Aten_ConvolutionDeprecatedOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 12, 1);
}
void Aten_ConvolutionDeprecatedOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 12, 1);
}
}];
}

def Torch_AtenFlipOp : Torch_Op<"aten.flip", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
17 changes: 10 additions & 7 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -927,13 +927,14 @@ class DecomposeAtenConvolutionOverrideableOp
};
} // namespace

// Decompose aten.convolution_overrideable to aten.convolution
// Decompose aten._convolution-like to aten.convolution
namespace {
class DecomposeAten_ConvolutionOp
: public OpRewritePattern<Aten_ConvolutionOp> {
template<typename ConvolutionLikeOp>
class DecomposeAten_ConvolutionLikeOp
: public OpRewritePattern<ConvolutionLikeOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(Aten_ConvolutionOp op,
using OpRewritePattern<ConvolutionLikeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ConvolutionLikeOp op,
PatternRewriter &rewriter) const override {

rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
Expand Down Expand Up @@ -2542,8 +2543,10 @@ class DecomposeComplexOpsPass
patterns.add<DecomposeAtenNativeBatchNormOp>(context);
target.addIllegalOp<AtenConvolutionOverrideableOp>();
patterns.add<DecomposeAtenConvolutionOverrideableOp>(context);
target.addIllegalOp<Aten_ConvolutionOp>();
patterns.add<DecomposeAten_ConvolutionOp>(context);
target.addIllegalOp<Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp>();
patterns.add<DecomposeAten_ConvolutionLikeOp<Aten_ConvolutionOp>,
DecomposeAten_ConvolutionLikeOp<Aten_ConvolutionDeprecatedOp>>(
context);
target.addIllegalOp<AtenConv2dOp>();
patterns.add<DecomposeAtenConv2dOp>(context);
patterns.add<DecomposeAtenArangeOp>(context);
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/Torch/Transforms/RefineTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -712,7 +712,7 @@ void TypeAnalysis::visitOperation(Operation *op,

// Promote the two dtypes assuming non-zero rank.
if (isa<AtenMmOp, AtenBmmOp, AtenMatmulOp, AtenConv2dOp, AtenConvolutionOp,
Aten_ConvolutionOp, AtenConvolutionOverrideableOp>(op)) {
Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp, AtenConvolutionOverrideableOp>(op)) {
auto knowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(
Expand Down
4 changes: 4 additions & 0 deletions lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6341,6 +6341,10 @@ module {
%0 = call @"__torch_mlir_shape_fn.aten.convolution"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.list<int>
return %0 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten._convolution.deprecated"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.list<int>, %arg8: !torch.int, %arg9: !torch.bool, %arg10: !torch.bool, %arg11: !torch.bool) -> !torch.list<int> {
%0 = call @"__torch_mlir_shape_fn.aten.convolution"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.list<int>
return %0 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.flip"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
return %arg0 : !torch.list<int>
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -940,7 +940,10 @@ def aten〇convolution(input: List[int], weight: List[int], bias: Optional[List[

def aten〇_convolution(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool, allow_tf32: bool) -> List[int]:
return aten〇convolution(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups)


def aten〇_convolution〇deprecated(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool) -> List[int]:
return aten〇convolution(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups)

def aten〇flip(self: List[int], dims: List[int]) -> List[int]:
return self

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)")
emit("aten::convolution_overrideable : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)")
emit("aten::_convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool, bool) -> (Tensor)")
emit("aten::_convolution.deprecated : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool) -> (Tensor)")
emit("aten::flip : (Tensor, int[]) -> (Tensor)")
emit(
"aten::native_batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float) -> (Tensor, Tensor, Tensor)"
Expand Down
112 changes: 112 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,118 @@ def forward(self, inputVec, weight):
def _Convolution2DTF32Module_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))

class _ConvolutionDeprecated2DAllFalseModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1, -1], torch.float32, True),
([-1, -1, -1, -1], torch.float32, True),
])
def forward(self, inputVec, weight):
return torch.ops.aten._convolution(inputVec,
weight,
bias=None,
stride=[3, 3],
padding=[2, 2],
dilation=[1, 1],
transposed=False,
output_padding=[0, 0],
groups=1,
benchmark=False,
deterministic=False,
cudnn_enabled=False)

@register_test_case(module_factory=lambda: _ConvolutionDeprecated2DAllFalseModule())
def _ConvolutionDeprecated2DAllFalseModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))

class _ConvolutionDeprecated2DBenchmarkModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1, -1], torch.float32, True),
([-1, -1, -1, -1], torch.float32, True),
])
def forward(self, inputVec, weight):
return torch.ops.aten._convolution(inputVec,
weight,
bias=None,
stride=[3, 3],
padding=[2, 2],
dilation=[1, 1],
transposed=False,
output_padding=[0, 0],
groups=1,
benchmark=True,
deterministic=False,
cudnn_enabled=False)

@register_test_case(module_factory=lambda: _ConvolutionDeprecated2DBenchmarkModule())
def _ConvolutionDeprecated2DBenchmarkModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))

class _ConvolutionDeprecated2DDeterministicModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1, -1], torch.float32, True),
([-1, -1, -1, -1], torch.float32, True),
])
def forward(self, inputVec, weight):
return torch.ops.aten._convolution(inputVec,
weight,
bias=None,
stride=[3, 3],
padding=[2, 2],
dilation=[1, 1],
transposed=False,
output_padding=[0, 0],
groups=1,
benchmark=False,
deterministic=True,
cudnn_enabled=False)

@register_test_case(module_factory=lambda: _ConvolutionDeprecated2DDeterministicModule())
def _ConvolutionDeprecated2DDeterministicModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))

class _ConvolutionDeprecated2DCudnnModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1, -1], torch.float32, True),
([-1, -1, -1, -1], torch.float32, True),
])
def forward(self, inputVec, weight):
return torch.ops.aten._convolution(inputVec,
weight,
bias=None,
stride=[3, 3],
padding=[2, 2],
dilation=[1, 1],
transposed=False,
output_padding=[0, 0],
groups=1,
benchmark=False,
deterministic=False,
cudnn_enabled=True)

@register_test_case(module_factory=lambda: _ConvolutionDeprecated2DCudnnModule())
def _Convolution2DCudnnModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))

class ConvolutionModule2DGroups(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down

0 comments on commit c38308f

Please sign in to comment.