From 296bd8b870b05e5bf8b712d3945c5833f7b4d4e7 Mon Sep 17 00:00:00 2001
From: Alex <814943412@qq.com>
Date: Sun, 21 Aug 2022 16:45:06 +0800
Subject: [PATCH 1/3] Add lowering for _convolution.deprecated

---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     |  34 ++++++
 .../Torch/Transforms/DecomposeComplexOps.cpp  |  15 ++-
 lib/Dialect/Torch/Transforms/RefineTypes.cpp  |   2 +-
 lib/Dialect/Torch/Transforms/ShapeLibrary.cpp |   4 +
 .../jit_ir/build_tools/shape_lib_gen.py       |   5 +-
 .../jit_ir/build_tools/torch_ods_gen.py       |   1 +
 python/torch_mlir_e2e_test/test_suite/conv.py | 112 ++++++++++++++++++
 7 files changed, 165 insertions(+), 8 deletions(-)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 6b9ead8d506..235d910c9c4 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -3444,6 +3444,40 @@ def Torch_Aten_ConvolutionOp : Torch_Op<"aten._convolution", [
   }];
 }
 
+def Torch_Aten_ConvolutionDeprecatedOp : Torch_Op<"aten._convolution.deprecated", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::_convolution.deprecated : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input,
+    AnyTorchTensorType:$weight,
+    AnyTorchOptionalTensorType:$bias,
+    AnyTorchListOfTorchIntType:$stride,
+    AnyTorchListOfTorchIntType:$padding,
+    AnyTorchListOfTorchIntType:$dilation,
+    Torch_BoolType:$transposed,
+    AnyTorchListOfTorchIntType:$output_padding,
+    Torch_IntType:$groups,
+    Torch_BoolType:$benchmark,
+    Torch_BoolType:$deterministic,
+    Torch_BoolType:$cudnn_enabled
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult Aten_ConvolutionDeprecatedOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 12, 1);
+    }
+    void Aten_ConvolutionDeprecatedOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 12, 1);
+    }
+  }];
+}
+
 def Torch_AtenFlipOp : Torch_Op<"aten.flip", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 555b9685332..acfbc27fbee 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -929,11 +929,12 @@ class DecomposeAtenConvolutionOverrideableOp
 // Decompose aten.convolution_overrideable to aten.convolution
 namespace {
-class DecomposeAten_ConvolutionOp
-    : public OpRewritePattern<Aten_ConvolutionOp> {
+template <typename ConvolutionLikeOp>
+class DecomposeAten_ConvolutionLikeOp
+    : public OpRewritePattern<ConvolutionLikeOp> {
 public:
-  using OpRewritePattern<Aten_ConvolutionOp>::OpRewritePattern;
-  LogicalResult matchAndRewrite(Aten_ConvolutionOp op,
+  using OpRewritePattern<ConvolutionLikeOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(ConvolutionLikeOp op,
                                 PatternRewriter &rewriter) const override {
 
     rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
@@ -2542,8 +2543,10 @@ class DecomposeComplexOpsPass
     patterns.add(context);
     target.addIllegalOp();
     patterns.add(context);
-    target.addIllegalOp<Aten_ConvolutionOp>();
-    patterns.add<DecomposeAten_ConvolutionOp>(context);
+    target.addIllegalOp<Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp>();
+    patterns.add<DecomposeAten_ConvolutionLikeOp<Aten_ConvolutionOp>,
+                 DecomposeAten_ConvolutionLikeOp<Aten_ConvolutionDeprecatedOp>>(
+        context);
     target.addIllegalOp();
     patterns.add(context);
     patterns.add(context);
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 
80ae4661431..8696e3c4c0b 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -712,7 +712,7 @@ void TypeAnalysis::visitOperation(Operation *op,
   // Promote the two dtypes assuming non-zero rank.
   if (isa<AtenMmOp, AtenBmmOp, AtenMatmulOp, AtenConvolutionOp,
-          Aten_ConvolutionOp, AtenConvolutionOverrideableOp>(op)) {
+          Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp, AtenConvolutionOverrideableOp>(op)) {
     auto knowledge =
         ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(
diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
index ff42c75bd14..7f435531579 100644
--- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
@@ -6341,6 +6341,10 @@ module {
     %0 = call @"__torch_mlir_shape_fn.aten.convolution"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.list<int>
     return %0 : !torch.list<int>
   }
+  func.func @"__torch_mlir_shape_fn.aten._convolution.deprecated"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.list<int>, %arg8: !torch.int, %arg9: !torch.bool, %arg10: !torch.bool, %arg11: !torch.bool) -> !torch.list<int> {
+    %0 = call @"__torch_mlir_shape_fn.aten.convolution"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.list<int>
+    return %0 : !torch.list<int>
+  }
   func.func @"__torch_mlir_shape_fn.aten.flip"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
     return %arg0 : !torch.list<int>
   }
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
index 7ec3b67f862..2d8ba3b6e2c 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
@@ -940,7 +940,10 @@ def aten〇convolution(input: List[int], weight: List[int], bias: Optional[List[
 def aten〇_convolution(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool, allow_tf32: bool) -> List[int]:
     return aten〇convolution(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups)
-
+
+def aten〇_convolution〇deprecated(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool) -> List[int]:
+    return aten〇convolution(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups)
+
 def aten〇flip(self: List[int], dims: List[int]) -> List[int]:
     return self
 
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index a2151fdec1f..49b21799cd2 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -337,6 
+337,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)") emit("aten::convolution_overrideable : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)") emit("aten::_convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool, bool) -> (Tensor)") + emit("aten::_convolution.deprecated : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool) -> (Tensor)") emit("aten::flip : (Tensor, int[]) -> (Tensor)") emit( "aten::native_batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float) -> (Tensor, Tensor, Tensor)" diff --git a/python/torch_mlir_e2e_test/test_suite/conv.py b/python/torch_mlir_e2e_test/test_suite/conv.py index d3bc77e7197..459a08b08ca 100644 --- a/python/torch_mlir_e2e_test/test_suite/conv.py +++ b/python/torch_mlir_e2e_test/test_suite/conv.py @@ -406,6 +406,118 @@ def forward(self, inputVec, weight): def _Convolution2DTF32Module_basic(module, tu: TestUtils): module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2)) +class _ConvolutionDreprecated2DAllFalseModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1], torch.float32, True), + ]) + def forward(self, inputVec, weight): + return torch.ops.aten._convolution(inputVec, + weight, + bias=None, + stride=[3, 3], + padding=[2, 2], + dilation=[1, 1], + transposed=False, + output_padding=[0, 0], + groups=1, + benchmark=False, + deterministic=False, + cudnn_enabled=False) + +@register_test_case(module_factory=lambda: _ConvolutionDreprecated2DAllFalseModule()) +def _ConvolutionDreprecated2DAllFalseModule_basic(module, tu: TestUtils): + module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2)) + +class _ConvolutionDreprecated2DBenchmarkModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1], torch.float32, True), + ]) + def forward(self, inputVec, weight): + return torch.ops.aten._convolution(inputVec, + weight, + bias=None, + stride=[3, 3], + padding=[2, 2], + dilation=[1, 1], + transposed=False, + output_padding=[0, 0], + groups=1, + benchmark=True, + deterministic=False, + cudnn_enabled=False) + +@register_test_case(module_factory=lambda: _ConvolutionDreprecated2DBenchmarkModule()) +def _ConvolutionDreprecated2DBenchmarkModule_basic(module, tu: TestUtils): + module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2)) + +class _ConvolutionDreprecated2DDeterministicModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1], torch.float32, True), + ]) + def forward(self, inputVec, weight): + return torch.ops.aten._convolution(inputVec, + weight, + bias=None, + stride=[3, 3], + padding=[2, 2], + dilation=[1, 1], + transposed=False, + output_padding=[0, 0], + groups=1, + benchmark=False, + deterministic=True, + cudnn_enabled=False) + +@register_test_case(module_factory=lambda: _ConvolutionDreprecated2DDeterministicModule()) +def _ConvolutionDreprecated2DDeterministicModule_basic(module, tu: TestUtils): + module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2)) + +class _ConvolutionDreprecated2DCudnnModule(torch.nn.Module): + def __init__(self): + super().__init__() + 
+ @export + @annotate_args([ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1], torch.float32, True), + ]) + def forward(self, inputVec, weight): + return torch.ops.aten._convolution(inputVec, + weight, + bias=None, + stride=[3, 3], + padding=[2, 2], + dilation=[1, 1], + transposed=False, + output_padding=[0, 0], + groups=1, + benchmark=False, + deterministic=False, + cudnn_enabled=True) + +@register_test_case(module_factory=lambda: _ConvolutionDreprecated2DCudnnModule()) +def _Convolution2DCudnnModule_basic(module, tu: TestUtils): + module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2)) + class ConvolutionModule2DGroups(torch.nn.Module): def __init__(self): super().__init__() From d3e9d1f4d43cb84369badf0daafc5b64941209bc Mon Sep 17 00:00:00 2001 From: Alex <814943412@qq.com> Date: Sun, 21 Aug 2022 17:09:26 +0800 Subject: [PATCH 2/3] Fix typo for conv test --- python/torch_mlir_e2e_test/test_suite/conv.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/python/torch_mlir_e2e_test/test_suite/conv.py b/python/torch_mlir_e2e_test/test_suite/conv.py index 459a08b08ca..cb92710d2e6 100644 --- a/python/torch_mlir_e2e_test/test_suite/conv.py +++ b/python/torch_mlir_e2e_test/test_suite/conv.py @@ -406,7 +406,7 @@ def forward(self, inputVec, weight): def _Convolution2DTF32Module_basic(module, tu: TestUtils): module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2)) -class _ConvolutionDreprecated2DAllFalseModule(torch.nn.Module): +class _ConvolutionDeprecated2DAllFalseModule(torch.nn.Module): def __init__(self): super().__init__() @@ -430,11 +430,11 @@ def forward(self, inputVec, weight): deterministic=False, cudnn_enabled=False) -@register_test_case(module_factory=lambda: _ConvolutionDreprecated2DAllFalseModule()) -def _ConvolutionDreprecated2DAllFalseModule_basic(module, tu: TestUtils): +@register_test_case(module_factory=lambda: _ConvolutionDeprecated2DAllFalseModule()) +def _ConvolutionDeprecated2DAllFalseModule_basic(module, tu: TestUtils): module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2)) -class _ConvolutionDreprecated2DBenchmarkModule(torch.nn.Module): +class _ConvolutionDeprecated2DBenchmarkModule(torch.nn.Module): def __init__(self): super().__init__() @@ -458,11 +458,11 @@ def forward(self, inputVec, weight): deterministic=False, cudnn_enabled=False) -@register_test_case(module_factory=lambda: _ConvolutionDreprecated2DBenchmarkModule()) -def _ConvolutionDreprecated2DBenchmarkModule_basic(module, tu: TestUtils): +@register_test_case(module_factory=lambda: _ConvolutionDeprecated2DBenchmarkModule()) +def _ConvolutionDeprecated2DBenchmarkModule_basic(module, tu: TestUtils): module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2)) -class _ConvolutionDreprecated2DDeterministicModule(torch.nn.Module): +class _ConvolutionDeprecated2DDeterministicModule(torch.nn.Module): def __init__(self): super().__init__() @@ -486,11 +486,11 @@ def forward(self, inputVec, weight): deterministic=True, cudnn_enabled=False) -@register_test_case(module_factory=lambda: _ConvolutionDreprecated2DDeterministicModule()) -def _ConvolutionDreprecated2DDeterministicModule_basic(module, tu: TestUtils): +@register_test_case(module_factory=lambda: _ConvolutionDeprecated2DDeterministicModule()) +def _ConvolutionDeprecated2DDeterministicModule_basic(module, tu: TestUtils): module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2)) -class _ConvolutionDreprecated2DCudnnModule(torch.nn.Module): +class 
_ConvolutionDeprecated2DCudnnModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
 
@@ -514,7 +514,7 @@ def forward(self, inputVec, weight):
                                         deterministic=False,
                                         cudnn_enabled=True)
 
-@register_test_case(module_factory=lambda: _ConvolutionDreprecated2DCudnnModule())
+@register_test_case(module_factory=lambda: _ConvolutionDeprecated2DCudnnModule())
 def _Convolution2DCudnnModule_basic(module, tu: TestUtils):
     module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
 

From 5ae51f05b573df531bdbaca03fba116530c8f56d Mon Sep 17 00:00:00 2001
From: Alex <814943412@qq.com>
Date: Mon, 22 Aug 2022 10:56:23 +0800
Subject: [PATCH 3/3] Fix comment

---
 lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index acfbc27fbee..b36938afc98 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -927,7 +927,7 @@ class DecomposeAtenConvolutionOverrideableOp
 };
 } // namespace
 
-// Decompose aten.convolution_overrideable to aten.convolution
+// Decompose aten._convolution-like to aten.convolution
 namespace {
 template <typename ConvolutionLikeOp>
 class DecomposeAten_ConvolutionLikeOp
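
Illustration (not part of the patch series): the decomposition these patches add rewrites aten._convolution.deprecated into aten.convolution by forwarding the first nine operands and dropping the trailing backend flags (benchmark, deterministic, cudnn_enabled), which select backend behavior but do not change the math. A minimal Python sketch of that rewrite follows; the helper name decompose_convolution_deprecated is chosen here purely for illustration. Note that calling torch.ops.aten._convolution with twelve arguments (no allow_tf32) resolves to the deprecated overload, which appears to be how the e2e tests above reach it.

    import torch

    def decompose_convolution_deprecated(input, weight, bias, stride, padding,
                                         dilation, transposed, output_padding,
                                         groups, benchmark, deterministic,
                                         cudnn_enabled):
        # The three trailing flags only tune backend selection; the result is
        # numerically the same, so the decomposition forwards everything else
        # to aten.convolution unchanged.
        return torch.ops.aten.convolution(input, weight, bias, stride, padding,
                                          dilation, transposed, output_padding,
                                          groups)

    # Quick check mirroring the e2e tests above: twelve positional arguments
    # dispatch _convolution to the deprecated overload.
    x = torch.randn(3, 3, 10, 10)
    w = torch.randn(3, 3, 2, 2)
    ref = torch.ops.aten._convolution(x, w, None, [3, 3], [2, 2], [1, 1],
                                      False, [0, 0], 1, False, False, False)
    out = decompose_convolution_deprecated(x, w, None, [3, 3], [2, 2], [1, 1],
                                           False, [0, 0], 1, False, False,
                                           False)
    assert torch.allclose(ref, out)

This is also why the shape function for aten._convolution.deprecated in ShapeLibrary.cpp simply calls the aten.convolution shape function with the first nine arguments.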