Add decomposition for at::_convolution #956

Merged
merged 1 commit on Jul 11, 2022
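at::_convolution takes the same nine leading arguments as aten::convolution plus four trailing backend flags (benchmark, deterministic, cudnn_enabled, allow_tf32); this patch registers the op and decomposes it by dropping the flags and forwarding the remaining operands to aten.convolution. A minimal plain-PyTorch sketch of that equivalence (not part of the patch; shapes follow the new e2e tests):

import torch

x = torch.randn(3, 3, 10, 10)
w = torch.randn(3, 3, 2, 2)

# aten._convolution: the nine convolution arguments plus four backend flags.
full = torch.ops.aten._convolution(x, w, None, [3, 3], [2, 2], [1, 1],
                                   False, [0, 0], 1,
                                   False, False, False, False)
# aten.convolution: the same nine arguments, no flags.
plain = torch.ops.aten.convolution(x, w, None, [3, 3], [2, 2], [1, 1],
                                   False, [0, 0], 1)
assert torch.allclose(full, plain)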
2 changes: 1 addition & 1 deletion e2e_testing/torchscript/xfail_sets.py
@@ -153,7 +153,7 @@
"GeluBackwardModule_basic",
"ElementwiseNeIntScalarModule_basic",
"ElementwiseNeFloatTensorModule_basic",
"ConvolutionModule2DStatic_basic",
"Convolution2DStaticModule_basic",
"ElementwiseNegModule_basic",
"TestMultipleTensorReturn_basic",
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
35 changes: 35 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -3146,6 +3146,41 @@ def Torch_AtenConvolutionOverrideableOp : Torch_Op<"aten.convolution_overrideabl
}];
}

def Torch_Aten_ConvolutionOp : Torch_Op<"aten._convolution", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::_convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
AnyTorchTensorType:$weight,
AnyTorchOptionalTensorType:$bias,
AnyTorchListOfTorchIntType:$stride,
AnyTorchListOfTorchIntType:$padding,
AnyTorchListOfTorchIntType:$dilation,
Torch_BoolType:$transposed,
AnyTorchListOfTorchIntType:$output_padding,
Torch_IntType:$groups,
Torch_BoolType:$benchmark,
Torch_BoolType:$deterministic,
Torch_BoolType:$cudnn_enabled,
Torch_BoolType:$allow_tf32
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult Aten_ConvolutionOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 13, 1);
}
void Aten_ConvolutionOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 13, 1);
}
}];
}

def Torch_AtenFlipOp : Torch_Op<"aten.flip", [
AllowsTypeRefinement,
HasValueSemantics,
21 changes: 21 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -890,6 +890,25 @@ class DecomposeAtenConvolutionOverrideableOp
};
} // namespace

// Decompose aten._convolution to aten.convolution
namespace {
class DecomposeAten_ConvolutionOp
: public OpRewritePattern<Aten_ConvolutionOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(Aten_ConvolutionOp op,
PatternRewriter &rewriter) const override {
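// The trailing flags (benchmark, deterministic, cudnn_enabled, allow_tf32)
// only tune backend/algorithm selection, so they are dropped here and the
// first nine operands are forwarded to aten.convolution unchanged.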

rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
op, op->getResultTypes(), op.input(), op.weight(), op.bias(),
op.stride(), op.padding(), op.dilation(), op.transposed(),
op.output_padding(), op.groups());

return success();
}
};
} // namespace

// Decompose aten.conv2d to aten.convolution
namespace {
class DecomposeAtenConv2dOp : public OpRewritePattern<AtenConv2dOp> {
@@ -2176,6 +2195,8 @@ class DecomposeComplexOpsPass
patterns.add<DecomposeAtenNativeBatchNormOp>(context);
target.addIllegalOp<AtenConvolutionOverrideableOp>();
patterns.add<DecomposeAtenConvolutionOverrideableOp>(context);
target.addIllegalOp<Aten_ConvolutionOp>();
patterns.add<DecomposeAten_ConvolutionOp>(context);
target.addIllegalOp<AtenConv2dOp>();
patterns.add<DecomposeAtenConv2dOp>(context);
patterns.add<DecomposeAtenArangeOp>(context);
2 changes: 1 addition & 1 deletion lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -691,7 +691,7 @@ ChangeResult TypeAnalyzer::visitOperation(

// Promote the two dtypes assuming non-zero rank.
if (isa<AtenMmOp, AtenBmmOp, AtenMatmulOp, AtenConv2dOp, AtenConvolutionOp,
AtenConvolutionOverrideableOp>(op)) {
Aten_ConvolutionOp, AtenConvolutionOverrideableOp>(op)) {
auto knowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(
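For reference, the promoted result dtype used above is PyTorch's ordinary two-dtype promotion; a small sketch of the rule (plain PyTorch, independent of the analysis):

import torch

# For the matmul/convolution ops listed above, the result dtype is the
# promotion of the two input dtypes (assuming non-zero rank).
assert torch.promote_types(torch.float16, torch.float32) == torch.float32
assert torch.promote_types(torch.int64, torch.float32) == torch.float32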
4 changes: 4 additions & 0 deletions lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
@@ -6272,6 +6272,10 @@ module {
%0 = call @__torch__.torch.jit._shape_functions.conv_output_size(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg8) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.int) -> !torch.list<int>
return %0 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten._convolution"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.list<int>, %arg8: !torch.int, %arg9: !torch.bool, %arg10: !torch.bool, %arg11: !torch.bool, %arg12: !torch.bool) -> !torch.list<int> {
%0 = call @"__torch_mlir_shape_fn.aten.convolution"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.list<int>
return %0 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.flip"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
return %arg0 : !torch.list<int>
}
@@ -885,6 +885,9 @@ def aten〇conv2d(input: List[int], weight: List[int], bias: Optional[List[int]]

def aten〇convolution(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int) -> List[int]:
return upstream_shape_functions.conv_output_size(input, weight, bias, stride, padding, dilation, groups)

def aten〇_convolution(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool, allow_tf32: bool) -> List[int]:
return aten〇convolution(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups)

def aten〇flip(self: List[int], dims: List[int]) -> List[int]:
return self
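Since aten._convolution reuses aten.convolution's shape function, its output size follows the standard convolution formula. A worked check for the configuration used by the new e2e tests below (10x10 input, 2x2 kernel, stride 3, padding 2), written out independently of the upstream helper:

# Standard convolution output-size formula (floor division); a sanity check,
# not the upstream shape helper itself.
def conv_out(in_size, kernel, stride, padding, dilation=1):
    return (in_size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

# (10 + 2*2 - 1*(2 - 1) - 1) // 3 + 1 = 12 // 3 + 1 = 5
assert conv_out(10, kernel=2, stride=3, padding=2) == 5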
@@ -329,6 +329,7 @@ def emit_with_mutating_variants(key, **kwargs):
)
emit("aten::convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)")
emit("aten::convolution_overrideable : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)")
emit("aten::_convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool, bool) -> (Tensor)")
emit("aten::flip : (Tensor, int[]) -> (Tensor)")
emit(
"aten::native_batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float) -> (Tensor, Tensor, Tensor)"
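The signature string passed to emit() mirrors the schema registered with the TorchScript JIT. One way to inspect that schema locally (assuming a PyTorch install; this lookup is not part of the generator):

import torch

# Print every registered JIT schema for aten::_convolution; the default
# overload should match the emit(...) string above.
for schema in torch._C._jit_get_all_schemas():
    if schema.name == "aten::_convolution":
        print(schema)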
4 changes: 2 additions & 2 deletions python/torch_mlir_e2e_test/test_suite/__init__.py
@@ -11,8 +11,8 @@
"TableBatchEmbeddingModule_basic",
"MobilenetV2Module_basic",
"MobilenetV3Module_basic",
"ConvolutionModule3D_basic",
"ConvolutionModule1D_basic",
"Convolution3DModule_basic",
"Convolution1DModule_basic",
"MaxPool2dWith3dInputModule_basic",
"MaxPool2dWithIndicesWith3dInputModule_basic",
}
175 changes: 160 additions & 15 deletions python/torch_mlir_e2e_test/test_suite/conv.py
@@ -136,7 +136,7 @@ def Conv2dWithPaddingDilationStrideStaticModule_basic(module, tu: TestUtils):

# ==============================================================================

class ConvolutionModule1D(torch.nn.Module):
class Convolution1DModule(torch.nn.Module):
def __init__(self):
super().__init__()

@@ -157,11 +157,11 @@ def forward(self, inputVec, weight):
output_padding=[0],
groups=1)

@register_test_case(module_factory=lambda: ConvolutionModule1D())
def ConvolutionModule1D_basic(module, tu: TestUtils):
@register_test_case(module_factory=lambda: Convolution1DModule())
def Convolution1DModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 3, 10), torch.randn(3, 3, 2))

class ConvolutionModule2D(torch.nn.Module):
class Convolution2DModule(torch.nn.Module):
def __init__(self):
super().__init__()

@@ -182,11 +182,11 @@ def forward(self, inputVec, weight):
output_padding=[0, 0],
groups=1)

@register_test_case(module_factory=lambda: ConvolutionModule2D())
def ConvolutionModule2D_basic(module, tu: TestUtils):
@register_test_case(module_factory=lambda: Convolution2DModule())
def Convolution2DModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))

class ConvolutionModule3D(torch.nn.Module):
class Convolution3DModule(torch.nn.Module):
def __init__(self):
super().__init__()

@@ -207,11 +207,11 @@ def forward(self, inputVec, weight):
output_padding=[0, 0, 0],
groups=1)

@register_test_case(module_factory=lambda: ConvolutionModule3D())
def ConvolutionModule3D_basic(module, tu: TestUtils):
@register_test_case(module_factory=lambda: Convolution3DModule())
def Convolution3DModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 3, 10, 10, 10), torch.randn(3, 3, 2, 2, 2))

class ConvolutionModule2DStatic(torch.nn.Module):
class Convolution2DStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()

@@ -232,11 +232,11 @@ def forward(self, inputVec, weight):
output_padding=[0, 0],
groups=1)

@register_test_case(module_factory=lambda: ConvolutionModule2DStatic())
def ConvolutionModule2DStatic_basic(module, tu: TestUtils):
@register_test_case(module_factory=lambda: Convolution2DStaticModule())
def Convolution2DStaticModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))

class ConvolutionModule2DStrided(torch.nn.Module):
class Convolution2DStridedModule(torch.nn.Module):
def __init__(self):
super().__init__()

@@ -257,6 +257,151 @@ def forward(self, inputVec, weight):
output_padding=[0, 0],
groups=1)

@register_test_case(module_factory=lambda: ConvolutionModule2DStrided())
def ConvolutionModule2DStrided_basic(module, tu: TestUtils):
@register_test_case(module_factory=lambda: Convolution2DStridedModule())
def Convolution2DStridedModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))

class _Convolution2DAllFalseModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1, -1], torch.float32, True),
([-1, -1, -1, -1], torch.float32, True),
])
def forward(self, inputVec, weight):
return torch.ops.aten._convolution(inputVec,
weight,
bias=None,
stride=[3, 3],
padding=[2, 2],
dilation=[1, 1],
transposed=False,
output_padding=[0, 0],
groups=1,
benchmark=False,
deterministic=False,
cudnn_enabled=False,
allow_tf32=False)

@register_test_case(module_factory=lambda: _Convolution2DAllFalseModule())
def _Convolution2DAllFalseModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))

class _Convolution2DBenchmarkModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1, -1], torch.float32, True),
([-1, -1, -1, -1], torch.float32, True),
])
def forward(self, inputVec, weight):
return torch.ops.aten._convolution(inputVec,
weight,
bias=None,
stride=[3, 3],
padding=[2, 2],
dilation=[1, 1],
transposed=False,
output_padding=[0, 0],
groups=1,
benchmark=True,
deterministic=False,
cudnn_enabled=False,
allow_tf32=False)

@register_test_case(module_factory=lambda: _Convolution2DBenchmarkModule())
def _Convolution2DBenchmarkModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))

class _Convolution2DDeterministicModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1, -1], torch.float32, True),
([-1, -1, -1, -1], torch.float32, True),
])
def forward(self, inputVec, weight):
return torch.ops.aten._convolution(inputVec,
weight,
bias=None,
stride=[3, 3],
padding=[2, 2],
dilation=[1, 1],
transposed=False,
output_padding=[0, 0],
groups=1,
benchmark=False,
deterministic=True,
cudnn_enabled=False,
allow_tf32=False)

@register_test_case(module_factory=lambda: _Convolution2DDeterministicModule())
def _Convolution2DDeterministicModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))

class _Convolution2DCudnnModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1, -1], torch.float32, True),
([-1, -1, -1, -1], torch.float32, True),
])
def forward(self, inputVec, weight):
return torch.ops.aten._convolution(inputVec,
weight,
bias=None,
stride=[3, 3],
padding=[2, 2],
dilation=[1, 1],
transposed=False,
output_padding=[0, 0],
groups=1,
benchmark=False,
deterministic=False,
cudnn_enabled=True,
allow_tf32=False)

@register_test_case(module_factory=lambda: _Convolution2DCudnnModule())
def _Convolution2DCudnnModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))

class _Convolution2DTF32Module(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1, -1], torch.float32, True),
([-1, -1, -1, -1], torch.float32, True),
])
def forward(self, inputVec, weight):
return torch.ops.aten._convolution(inputVec,
weight,
bias=None,
stride=[3, 3],
padding=[2, 2],
dilation=[1, 1],
transposed=False,
output_padding=[0, 0],
groups=1,
benchmark=False,
deterministic=False,
cudnn_enabled=False,
allow_tf32=True)

@register_test_case(module_factory=lambda: _Convolution2DTF32Module())
def _Convolution2DTF32Module_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))