From da81a2f219fba38e0b603ffe3459e2017f7b54dd Mon Sep 17 00:00:00 2001
From: George Petterson
Date: Mon, 20 Jun 2022 03:01:39 -0400
Subject: [PATCH] Add transposed case for at::convolution

---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     |  90 ++++++++++
 lib/Conversion/TorchToLinalg/Linear.cpp       | 152 +++++++++++++---
 lib/Conversion/TorchToLinalg/Utils.cpp        |  33 ++++
 lib/Conversion/TorchToLinalg/Utils.h          |  12 ++
 .../Torch/Transforms/DecomposeComplexOps.cpp  |  22 +++
 lib/Dialect/Torch/Transforms/RefineTypes.cpp  |   2 +-
 lib/Dialect/Torch/Transforms/ShapeLibrary.cpp | 106 +++++++++++-
 .../jit_ir/build_tools/shape_lib_gen.py       |  33 +++-
 .../jit_ir/build_tools/torch_ods_gen.py       |   3 +
 .../test_suite/__init__.py                    |   2 +
 python/torch_mlir_e2e_test/test_suite/conv.py | 163 ++++++++++++++++++
 11 files changed, 591 insertions(+), 27 deletions(-)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index bf955203ab79..d7732144b3dc 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -3083,6 +3083,96 @@ def Torch_AtenConv2dOp : Torch_Op<"aten.conv2d", [
   }];
 }
 
+def Torch_AtenConvTranspose1dOp : Torch_Op<"aten.conv_transpose1d", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::conv_transpose1d : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input,
+    AnyTorchTensorType:$weight,
+    AnyTorchOptionalTensorType:$bias,
+    AnyTorchListOfTorchIntType:$stride,
+    AnyTorchListOfTorchIntType:$padding,
+    AnyTorchListOfTorchIntType:$output_padding,
+    Torch_IntType:$groups,
+    AnyTorchListOfTorchIntType:$dilation
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenConvTranspose1dOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 8, 1);
+    }
+    void AtenConvTranspose1dOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 8, 1);
+    }
+  }];
+}
+
+def Torch_AtenConvTranspose2dInputOp : Torch_Op<"aten.conv_transpose2d.input", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::conv_transpose2d.input : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input,
+    AnyTorchTensorType:$weight,
+    AnyTorchOptionalTensorType:$bias,
+    AnyTorchListOfTorchIntType:$stride,
+    AnyTorchListOfTorchIntType:$padding,
+    AnyTorchListOfTorchIntType:$output_padding,
+    Torch_IntType:$groups,
+    AnyTorchListOfTorchIntType:$dilation
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenConvTranspose2dInputOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 8, 1);
+    }
+    void AtenConvTranspose2dInputOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 8, 1);
+    }
+  }];
+}
+
+def Torch_AtenConvTranspose3dInputOp : Torch_Op<"aten.conv_transpose3d.input", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::conv_transpose3d.input : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input,
+    AnyTorchTensorType:$weight,
+    AnyTorchOptionalTensorType:$bias,
+    AnyTorchListOfTorchIntType:$stride,
+    AnyTorchListOfTorchIntType:$padding,
+    AnyTorchListOfTorchIntType:$output_padding,
+    Torch_IntType:$groups,
+    AnyTorchListOfTorchIntType:$dilation
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenConvTranspose3dInputOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 8, 1);
+    }
+    void AtenConvTranspose3dInputOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 8, 1);
+    }
+  }];
+}
+
 def Torch_AtenConvolutionOp : Torch_Op<"aten.convolution", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp
index 8fa598f8d95a..8502381b8f8a 100644
--- a/lib/Conversion/TorchToLinalg/Linear.cpp
+++ b/lib/Conversion/TorchToLinalg/Linear.cpp
@@ -7,6 +7,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
 
 #include "../PassDetail.h"
@@ -20,8 +21,12 @@
 #include "torch-mlir/Conversion/Utils/Utils.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
 #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
+#include "llvm/ADT/ArrayRef.h"
+#include <algorithm>
+#include <numeric>
 
 using namespace mlir;
 using namespace mlir::torch;
@@ -534,12 +539,18 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
     Value input = adaptor.input();   /* in form of N*C*H*W */
     Value weight = adaptor.weight(); /* in form of F*C*H*W */
 
+    bool transposed = true;
+    if (!matchPattern(op.transposed(), m_TorchConstantBool(&transposed)))
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: only constant transposed supported");
+
     Type elementType =
         input.getType().cast<RankedTensorType>().getElementType();
     if (!elementType.isa<mlir::FloatType>())
       return op.emitError("unimplemented: non-floating point type");
     size_t inRank = input.getType().cast<RankedTensorType>().getRank();
-    if (inRank != 4)
+    size_t numSpatialDims = inRank - 2;
+    if (numSpatialDims != 2)
       return rewriter.notifyMatchFailure(
           op, "unimplemented: only 2D convolution currently supported");
 
@@ -563,6 +574,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
                                          "only support constant int dilations");
 
     Value N = getDimOp(rewriter, loc, input, 0);
+    Value inChannels = getDimOp(rewriter, loc, input, 1);
     SmallVector<Value> inDims;
     for (size_t i = 2; i < inRank; i++)
       inDims.push_back(getDimOp(rewriter, loc, input, i));
@@ -571,37 +583,131 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
     for (size_t i = 2; i < inRank; i++)
       weightDims.push_back(getDimOp(rewriter, loc, weight, i));
 
-    // Guard unused values (transposed, groups)
+    SmallVector<Value> dilationIntValues =
+        getAsConstantIntValues(rewriter, loc, dilationInts);
+    SmallVector<Value> paddingIntValues =
+        getAsConstantIntValues(rewriter, loc, paddingInts);
+    SmallVector<Value> strideIntValues =
+        getAsConstantIntValues(rewriter, loc, strideInts);
+
+    // Guard unused values (groups)
     int64_t group_size;
     if (!matchPattern(op.groups(), m_TorchConstantInt(&group_size)) ||
         group_size != 1)
       return rewriter.notifyMatchFailure(
          op, "unimplemented: only group size of 1 supported");
-    bool transposed = true;
-    if (!matchPattern(op.transposed(), m_TorchConstantBool(&transposed)) ||
-        transposed)
-      return rewriter.notifyMatchFailure(
-          op, "unimplemented: only non-transposed convolution supported");
 
     // Pad the input tensor according to padding.
-    SmallVector<int64_t> paddingIncludingNC = {0, 0};
-    paddingIncludingNC.insert(paddingIncludingNC.end(), paddingInts.begin(),
-                              paddingInts.end());
-    Value paddedInput = torch_to_linalg::getZeroPaddedTensor(
-        op, rewriter, input, paddingIncludingNC);
+    SmallVector<Value> outDims{N, F};
+    Value paddedInput;
+    if (transposed) {
+      Value c0 =
+          rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
+      Value c1 =
+          rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
+      Value c2 =
+          rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(2));
+
+      // Transpose and flip weight
+      SmallVector<Value> weightInitDims = getTensorSizes(rewriter, loc, weight);
+      std::iter_swap(weightInitDims.begin(), weightInitDims.begin() + 1);
+      outDims[1] = weightInitDims[0];
+      Value weightInitTensor =
+          createZeroInitTensor(rewriter, loc, weightInitDims, elementType);
+      SmallVector<StringRef> iteratorTypes(inRank,
+                                           getParallelIteratorTypeName());
+      SmallVector<AffineMap> indexingMaps(
+          2, AffineMap::getMultiDimIdentityMap(inRank, context));
+      weight = rewriter
+                   .create<linalg::GenericOp>(
+                       loc, weightInitTensor.getType(), weight,
+                       weightInitTensor, indexingMaps, iteratorTypes,
+                       [&](OpBuilder &b, Location loc, ValueRange args) {
+                         SmallVector<Value> indices;
+                         for (size_t i = 0; i < inRank; i++)
+                           indices.push_back(b.create<linalg::IndexOp>(loc, i));
+                         std::iter_swap(indices.begin(), indices.begin() + 1);
+                         // Flip only the spatial dimensions (from 2 to inRank)
+                         for (size_t flipDim = 2; flipDim < inRank; flipDim++) {
+                           indices[flipDim] = b.create<arith::SubIOp>(
+                               loc,
+                               b.create<arith::SubIOp>(
+                                   loc, weightInitDims[flipDim], c1),
+                               indices[flipDim]);
+                         }
+                         Value res =
+                             b.create<tensor::ExtractOp>(loc, weight, indices)
+                                 .getResult();
+                         b.create<linalg::YieldOp>(loc, res);
+                       })
+                   .getResult(0);
 
-    SmallVector<Value> paddingIntValues =
-        getAsConstantIntValues(rewriter, loc, paddingInts);
-    SmallVector<Value> dilationIntValues =
-        getAsConstantIntValues(rewriter, loc, dilationInts);
-    SmallVector<Value> strideIntValues =
-        getAsConstantIntValues(rewriter, loc, strideInts);
+      // Calculate padded input size, allocate tensor
+      SmallVector<Value> outerSizes{N, inChannels};
+      SmallVector<Value> innerSizes{N, inChannels};
+      SmallVector<Value> offsets{c0, c0};
+      for (size_t i = 0; i < numSpatialDims; i++) {
+        Value innerSize = rewriter.create<arith::SubIOp>(loc, inDims[i], c1);
+        innerSize = rewriter.create<arith::MulIOp>(
+            loc, innerSize, castIntToIndex(rewriter, loc, strideIntValues[i]));
+        innerSize = rewriter.create<arith::AddIOp>(loc, innerSize, c1);
+
+        Value offset = rewriter.create<arith::SubIOp>(loc, weightDims[i], c1);
+        offset = rewriter.create<arith::MulIOp>(
+            loc, offset, castIntToIndex(rewriter, loc, dilationIntValues[i]));
+        offset = rewriter.create<arith::SubIOp>(
+            loc, offset, castIntToIndex(rewriter, loc, paddingIntValues[i]));
+
+        Value outerSize = rewriter.create<arith::MulIOp>(loc, offset, c2);
+        outerSize = rewriter.create<arith::AddIOp>(loc, outerSize, innerSize);
+
+        innerSizes.push_back(innerSize);
+        outerSizes.push_back(outerSize);
+        offsets.push_back(offset);
+      }
 
-    SmallVector<Value> outDims{N, F};
-    for (size_t i = 0; i < inRank - 2; i++)
-      outDims.push_back(torch_to_linalg::getOutputDimForConvOps(
-          rewriter, loc, inDims[i], paddingIntValues[i], dilationIntValues[i],
-          castIndexToInt(weightDims[i]), strideIntValues[i]));
+      // Allocate padded input tensor
+      Value initTensor =
+          createZeroInitTensor(rewriter, loc, outerSizes, elementType);
+
+      // Insert input into allocated tensor
+      SmallVector<Value> strideIndexValues{c1, c1};
+      for (auto stride : strideIntValues)
+        strideIndexValues.push_back(castIntToIndex(rewriter, loc, stride));
+      SmallVector<Value> insertSizes = getTensorSizes(rewriter, loc, input);
+
+      paddedInput = rewriter.create<tensor::InsertSliceOp>(
+          loc, torch_to_linalg::dynamicCast(rewriter, loc, input), initTensor,
+          offsets, insertSizes, strideIndexValues);
+
+      // Calculate output dims
+      for (size_t i = 0; i < numSpatialDims; i++)
+        outDims.push_back(torch_to_linalg::getOutputDimForConvTransposeOps(
+            rewriter, loc, inDims[i], paddingIntValues[i], dilationIntValues[i],
+            castIndexToInt(weightDims[i]), strideIntValues[i]));
+
+      // Set stride to 1
+      strideInts.clear();
+      strideInts.append(numSpatialDims, 1);
+      strideIntValues = getAsConstantIntValues(rewriter, loc, strideInts);
+
+      paddingIntValues.clear();
+      for (auto pad = offsets.begin() + 2; pad < offsets.end(); pad++)
+        paddingIntValues.push_back(castIndexToInt(*pad));
+    } else {
+      // Pad input
+      SmallVector<int64_t> paddingIncludingNC = {0, 0};
+      paddingIncludingNC.insert(paddingIncludingNC.end(), paddingInts.begin(),
+                                paddingInts.end());
+      paddedInput = torch_to_linalg::getZeroPaddedTensor(op, rewriter, input,
+                                                         paddingIncludingNC);
+
+      // Calculate output dims
+      for (size_t i = 0; i < numSpatialDims; i++)
+        outDims.push_back(torch_to_linalg::getOutputDimForConvOps(
+            rewriter, loc, inDims[i], paddingIntValues[i], dilationIntValues[i],
+            castIndexToInt(weightDims[i]), strideIntValues[i]));
+    }
 
     Value initTensor =
         rewriter.create<linalg::InitTensorOp>(loc, outDims, elementType);
diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp
index 4d8492b454c3..312be75aa3dd 100644
--- a/lib/Conversion/TorchToLinalg/Utils.cpp
+++ b/lib/Conversion/TorchToLinalg/Utils.cpp
@@ -20,6 +20,7 @@
 #include "torch-mlir/Conversion/Utils/Utils.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
 #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
 
@@ -97,6 +98,31 @@ Value torch_to_linalg::getOutputDimForConvOps(OpBuilder &b, Location loc,
   return castIntToIndex(b, loc, out);
 }
 
+Value torch_to_linalg::getOutputDimForConvTransposeOps(
+    OpBuilder &b, Location loc, Value in, Value paddingInt, Value dilationInt,
+    Value kernelSizeInt, Value strideInt) {
+  Value c1 = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(1));
+  Value c2 = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(2));
+
+  // (in - 1) * stride
+  Value inStrided =
+      b.create<arith::SubIOp>(loc, castIndexToInt64(b, loc, in), c1);
+  inStrided = b.create<arith::MulIOp>(loc, inStrided, strideInt);
+
+  // 2 * padding
+  Value doublePadding = b.create<arith::MulIOp>(loc, paddingInt, c2);
+
+  // (kernelSize - 1) * dilation
+  Value kernelDilated = b.create<arith::SubIOp>(loc, kernelSizeInt, c1);
+  kernelDilated = b.create<arith::MulIOp>(loc, kernelDilated, dilationInt);
+
+  Value out = b.create<arith::SubIOp>(loc, inStrided, doublePadding);
+  out = b.create<arith::AddIOp>(loc, out, kernelDilated);
+  out = b.create<arith::AddIOp>(loc, out, c1);
+
+  return castIntToIndex(b, loc, out);
+}
+
 Value torch_to_linalg::createReductionLinalgGeneric(
     OpBuilder &b, Location loc, const ReductionOpInfo &opInfo, Value initElem,
     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
@@ -256,3 +282,10 @@ Value torch_to_linalg::createElementwiseLinalgGeneric(
                              iteratorTypes, bodyBuild)
       .getResult(0);
 }
+
+Value torch_to_linalg::dynamicCast(OpBuilder &b, Location loc, Value tensor) {
+  auto tensorType = tensor.getType().cast<RankedTensorType>();
+  auto rank = tensorType.getRank();
+  SmallVector<int64_t> unknownSizes(rank, kUnknownSize);
+  return b.create<tensor::CastOp>(loc, tensorType.clone(unknownSizes), tensor);
+}
diff --git a/lib/Conversion/TorchToLinalg/Utils.h b/lib/Conversion/TorchToLinalg/Utils.h
index eb16387e0eca..a47b63378865 100644
--- a/lib/Conversion/TorchToLinalg/Utils.h
+++ b/lib/Conversion/TorchToLinalg/Utils.h
@@ -39,6 +39,14 @@ Value getOutputDimForConvOps(OpBuilder &b, Location loc, Value in,
                              Value kernelSizeInt, Value strideInt,
                              bool ceilMode = false);
 
+// As above but for transposed convolution ops
+// Along each dim:
+// dim_out =
+//   (dim_in - 1) * stride - 2 * padding + dilation * (kernelSize - 1) + 1
+Value getOutputDimForConvTransposeOps(OpBuilder &b, Location loc, Value in,
+                                      Value paddingInt, Value dilationInt,
+                                      Value kernelSizeInt, Value strideInt);
+
 // Create a reduction of `opInfo.tensorOperand`, reducing along the dimensions
 // in `opInfo.dimSet`. If `opInfo.keepDim` is true, the output tensor is the
 // same rank as the `opInfo.tensorOperand` and reduced dimensions are set to
@@ -54,6 +62,10 @@ Value createElementwiseLinalgGeneric(
     OpBuilder &b, Location loc, ValueRange tensorOperands,
     Type resultElementType,
     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild);
+
+// Cast a tensor to a rank-equivalent tensor of unknown size, i.e. <1x2xf32> ->
+// <?x?xf32>
+Value dynamicCast(OpBuilder &b, Location loc, Value tensor);
 } // namespace torch_to_linalg
 } // namespace torch
 } // namespace mlir
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 1db2ed9ecda6..7146c767f827 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -791,6 +791,26 @@ class DecomposeAtenConv2dOp : public OpRewritePattern<AtenConv2dOp> {
 };
 } // namespace
 
+// Decompose aten.conv_transpose2d to aten.convolution
+namespace {
+class DecomposeAtenConvTranspose2dOp
+    : public OpRewritePattern<AtenConvTranspose2dInputOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenConvTranspose2dInputOp op,
+                                PatternRewriter &rewriter) const override {
+
+    Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), true);
+    rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
+        op, op->getResultTypes(), op.input(), op.weight(), op.bias(),
+        op.stride(), op.padding(), op.dilation(), cstTrue, op.output_padding(),
+        op.groups());
+
+    return success();
+  }
+};
+} // namespace
+
 // Decompose aten.addmm into aten.mm and aten.add.Tensor op.
 namespace {
 class DecomposeAtenAddmmOp : public OpRewritePattern<AtenAddmmOp> {
@@ -2055,6 +2075,8 @@ class DecomposeComplexOpsPass
     patterns.add(context);
     target.addIllegalOp();
     patterns.add(context);
+    target.addIllegalOp<AtenConvTranspose2dInputOp>();
+    patterns.add<DecomposeAtenConvTranspose2dOp>(context);
     patterns.add(context);
     target.addIllegalOp();
     patterns.add(context);
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 5cecc1189b3a..003f5710298a 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -691,7 +691,7 @@ ChangeResult TypeAnalyzer::visitOperation(
 
   // Promote the two dtypes assuming non-zero rank.
   if (isa<AtenMmOp, AtenBmmOp, AtenMatmulOp, AtenConv2dOp, AtenConvolutionOp,
-          AtenConvolutionOverrideableOp>(op)) {
+          AtenConvolutionOverrideableOp, AtenConvTranspose2dInputOp>(op)) {
     auto knowledge =
         ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(
diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
index eae3954127c0..5c2c5bcc96c8 100644
--- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
@@ -6222,9 +6222,111 @@ module {
     %0 = call @__torch__.torch.jit._shape_functions.conv2d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.int) -> !torch.list<int>
     return %0 : !torch.list<int>
   }
+  func.func @"__torch_mlir_shape_fn.aten.conv_transpose2d.input"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.int, %arg7: !torch.list<int>) -> !torch.list<int> {
+    %int1 = torch.constant.int 1
+    %int2 = torch.constant.int 2
+    %int0 = torch.constant.int 0
+    %true = torch.constant.bool true
+    %0 = torch.aten.len.t %arg7 : !torch.list<int> -> !torch.int
+    %1 = torch.aten.gt.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
+    %2 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
+    %3 = torch.prim.ListConstruct  : () -> !torch.list<int>
+    %4 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int
+    %5 = torch.aten.append.t %3, %4 : !torch.list<int>, !torch.int -> !torch.list<int>
+    %6 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int
+    %7 = torch.aten.append.t %3, %6 : !torch.list<int>, !torch.int -> !torch.list<int>
+    %8 = torch.aten.__range_length %int2, %2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
+    torch.prim.Loop %8, %true, init() {
+    ^bb0(%arg8: !torch.int):
+      %9 = torch.aten.__derive_index %arg8, %int2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
+      %10 = torch.prim.If %1 -> (!torch.int) {
+        %26 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int
+        %27 = torch.aten.__getitem__.t %arg7, %26 : !torch.list<int>, !torch.int -> !torch.int
+        torch.prim.If.yield %27 : !torch.int
+      } else {
+        torch.prim.If.yield %int1 : !torch.int
+      }
+      %11 = torch.aten.__getitem__.t %arg1, %9 : !torch.list<int>, !torch.int -> !torch.int
+      %12 = torch.aten.sub.int %11, %int1 : !torch.int, !torch.int -> !torch.int
+      %13 = torch.aten.mul.int %10, %12 : !torch.int, !torch.int -> !torch.int
+      %14 = torch.aten.__getitem__.t %arg0, %9 : !torch.list<int>, !torch.int -> !torch.int
+      %15 = torch.aten.sub.int %14, %int1 : !torch.int, !torch.int -> !torch.int
+      %16 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int
+      %17 = torch.aten.__getitem__.t %arg3, %16 : !torch.list<int>, !torch.int -> !torch.int
+      %18 = torch.aten.mul.int %15, %17 : !torch.int, !torch.int -> !torch.int
+      %19 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int
+      %20 = torch.aten.__getitem__.t %arg4, %19 : !torch.list<int>, !torch.int -> !torch.int
+      %21 = torch.aten.mul.int %int2, %20 : !torch.int, !torch.int -> !torch.int
+      %22 = torch.aten.sub.int %18, %21 : !torch.int, !torch.int -> !torch.int
+      %23 = torch.aten.add.int %22, %13 : !torch.int, !torch.int -> !torch.int
+      %24 = torch.aten.add.int %23, %int1 : !torch.int, !torch.int -> !torch.int
+      %25 = torch.aten.append.t %3, %24 : !torch.list<int>, !torch.int -> !torch.list<int>
+      torch.prim.Loop.condition %true, iter()
+    } : (!torch.int, !torch.bool) -> ()
+    return %3 : !torch.list<int>
+  }
   func.func @"__torch_mlir_shape_fn.aten.convolution"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.list<int>, %arg8: !torch.int) -> !torch.list<int> {
@"__torch_mlir_shape_fn.aten.convolution"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.conv_output_size(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg8) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.int) -> !torch.list - return %0 : !torch.list + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %int0 = torch.constant.int 0 + %true = torch.constant.bool true + %0 = torch.aten.len.t %arg5 : !torch.list -> !torch.int + %1 = torch.aten.gt.int %0, %int0 : !torch.int, !torch.int -> !torch.bool + %2 = torch.aten.len.t %arg0 : !torch.list -> !torch.int + %3 = torch.prim.ListConstruct : () -> !torch.list + %4 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int + %5 = torch.aten.append.t %3, %4 : !torch.list, !torch.int -> !torch.list + %6 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int + %7 = torch.aten.append.t %3, %6 : !torch.list, !torch.int -> !torch.list + %8 = torch.aten.__range_length %int2, %2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int + torch.prim.Loop %8, %true, init() { + ^bb0(%arg9: !torch.int): + %9 = torch.aten.__derive_index %arg9, %int2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int + %10 = torch.prim.If %1 -> (!torch.int) { + %11 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int + %12 = torch.aten.__getitem__.t %arg5, %11 : !torch.list, !torch.int -> !torch.int + torch.prim.If.yield %12 : !torch.int + } else { + torch.prim.If.yield %int1 : !torch.int + } + torch.prim.If %arg6 -> () { + %11 = torch.aten.__getitem__.t %arg1, %9 : !torch.list, !torch.int -> !torch.int + %12 = torch.aten.sub.int %11, %int1 : !torch.int, !torch.int -> !torch.int + %13 = torch.aten.mul.int %10, %12 : !torch.int, !torch.int -> !torch.int + %14 = torch.aten.__getitem__.t %arg0, %9 : !torch.list, !torch.int -> !torch.int + %15 = torch.aten.sub.int %14, %int1 : !torch.int, !torch.int -> !torch.int + %16 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int + %17 = torch.aten.__getitem__.t %arg3, %16 : !torch.list, !torch.int -> !torch.int + %18 = torch.aten.mul.int %15, %17 : !torch.int, !torch.int -> !torch.int + %19 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int + %20 = torch.aten.__getitem__.t %arg4, %19 : !torch.list, !torch.int -> !torch.int + %21 = torch.aten.mul.int %int2, %20 : !torch.int, !torch.int -> !torch.int + %22 = torch.aten.sub.int %18, %21 : !torch.int, !torch.int -> !torch.int + %23 = torch.aten.add.int %22, %13 : !torch.int, !torch.int -> !torch.int + %24 = torch.aten.add.int %23, %int1 : !torch.int, !torch.int -> !torch.int + %25 = torch.aten.append.t %3, %24 : !torch.list, !torch.int -> !torch.list + torch.prim.If.yield + } else { + %11 = torch.aten.__getitem__.t %arg1, %9 : !torch.list, !torch.int -> !torch.int + %12 = torch.aten.sub.int %11, %int1 : !torch.int, !torch.int -> !torch.int + %13 = torch.aten.mul.int %10, %12 : !torch.int, !torch.int -> !torch.int + %14 = torch.aten.add.int %13, %int1 : !torch.int, !torch.int -> !torch.int + %15 = torch.aten.__getitem__.t %arg0, %9 : !torch.list, !torch.int -> !torch.int + %16 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int + %17 = torch.aten.__getitem__.t %arg4, %16 : !torch.list, !torch.int -> !torch.int + %18 = 
+        %19 = torch.aten.add.int %15, %18 : !torch.int, !torch.int -> !torch.int
+        %20 = torch.aten.sub.int %19, %14 : !torch.int, !torch.int -> !torch.int
+        %21 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int
+        %22 = torch.aten.__getitem__.t %arg3, %21 : !torch.list<int>, !torch.int -> !torch.int
+        %23 = torch.aten.floordiv.int %20, %22 : !torch.int, !torch.int -> !torch.int
+        %24 = torch.aten.add.int %23, %int1 : !torch.int, !torch.int -> !torch.int
+        %25 = torch.aten.append.t %3, %24 : !torch.list<int>, !torch.int -> !torch.list<int>
+        torch.prim.If.yield
+      }
+      torch.prim.Loop.condition %true, iter()
+    } : (!torch.int, !torch.bool) -> ()
+    return %3 : !torch.list<int>
   }
   func.func @"__torch_mlir_shape_fn.aten.flip"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
     return %arg0 : !torch.list<int>
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
index afd1d3297f30..9f6243929c7e 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
@@ -866,8 +866,39 @@ def aten〇topk(self: List[int], k: int, dim: int = -1, largest: bool = True, so
 def aten〇conv2d(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1), padding: List[int] = (0, 0), dilation: List[int] = (1, 1), groups: int = 1) -> List[int]:
     return upstream_shape_functions.conv2d(input, weight, bias, stride, padding, dilation, groups)
 
+def aten〇conv_transpose2d〇input(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1), padding: List[int] = (0, 0), output_padding: List[int] = (0, 0), groups: int = 1, dilation: List[int] = (1, 1)) -> List[int]:
+    has_dilation = len(dilation) > 0
+    dim = len(input)
+    output_size: List[int] = []
+    input_batch_size_dim = 0
+    weight_output_channels_dim = 0
+    output_size.append(input[input_batch_size_dim])
+    output_size.append(weight[weight_output_channels_dim])
+
+    for d in range(2, dim):
+        dilation_ = dilation[d - 2] if has_dilation else 1
+        kernel = dilation_ * (weight[d] - 1)
+        output_size.append((input[d] - 1) * stride[d - 2] - 2 * padding[d - 2] + kernel + 1)
+    return output_size
+
 def aten〇convolution(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int) -> List[int]:
-    return upstream_shape_functions.conv_output_size(input, weight, bias, stride, padding, dilation, groups)
+    has_dilation = len(dilation) > 0
+    dim = len(input)
+    output_size: List[int] = []
+    input_batch_size_dim = 0
+    weight_output_channels_dim = 0
+    output_size.append(input[input_batch_size_dim])
+    output_size.append(weight[weight_output_channels_dim])
+
+    for d in range(2, dim):
+        dilation_ = dilation[d - 2] if has_dilation else 1
+        if transposed:
+            kernel = dilation_ * (weight[d] - 1)
+            output_size.append((input[d] - 1) * stride[d - 2] - 2 * padding[d - 2] + kernel + 1)
+        else:
+            kernel = dilation_ * (weight[d] - 1) + 1
+            output_size.append((input[d] + (2 * padding[d - 2]) - kernel) // stride[d - 2] + 1)
+    return output_size
 
 def aten〇flip(self: List[int], dims: List[int]) -> List[int]:
     return self
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index cd9b149af1e9..6bec3a35e977 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -327,6 +327,9 @@ def emit_with_mutating_variants(key, **kwargs):
     emit(
         "aten::conv2d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)"
     )
+    emit("aten::conv_transpose1d : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)")
+    emit("aten::conv_transpose2d.input : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)")
+    emit("aten::conv_transpose3d.input : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)")
     emit("aten::convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)")
     emit("aten::convolution_overrideable : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)")
     emit("aten::flip : (Tensor, int[]) -> (Tensor)")
diff --git a/python/torch_mlir_e2e_test/test_suite/__init__.py b/python/torch_mlir_e2e_test/test_suite/__init__.py
index 584d4c0310fd..481a0c8a7ef2 100644
--- a/python/torch_mlir_e2e_test/test_suite/__init__.py
+++ b/python/torch_mlir_e2e_test/test_suite/__init__.py
@@ -13,6 +13,8 @@
     "MobilenetV3Module_basic",
     "ConvolutionModule3D_basic",
     "ConvolutionModule1D_basic",
+    "Conv_Transpose3dModule_basic",
+    "Conv_Transpose1dModule_basic",
     "MaxPool2dWith3dInputModule_basic",
     "MaxPool2dWithIndicesWith3dInputModule_basic",
 }
diff --git a/python/torch_mlir_e2e_test/test_suite/conv.py b/python/torch_mlir_e2e_test/test_suite/conv.py
index aa6b5eeb3fe5..66650c678af4 100644
--- a/python/torch_mlir_e2e_test/test_suite/conv.py
+++ b/python/torch_mlir_e2e_test/test_suite/conv.py
@@ -260,3 +260,166 @@ def forward(self, inputVec, weight):
 @register_test_case(module_factory=lambda: ConvolutionModule2DStrided())
 def ConvolutionModule2DStrided_basic(module, tu: TestUtils):
     module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
+
+# ==============================================================================
+
+class ConvolutionModule2DTranspose(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1], torch.float32, True),
+        ([-1, -1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, inputVec, weight):
+        return torch.ops.aten.convolution(inputVec,
+                                          weight,
+                                          bias=None,
+                                          stride=[1, 1],
+                                          padding=[1, 1],
+                                          dilation=[1, 1],
+                                          transposed=True,
+                                          output_padding=[0, 0],
+                                          groups=1)
+
+
+@register_test_case(module_factory=lambda: ConvolutionModule2DTranspose())
+def ConvolutionModule2DTranspose_basic(module, tu: TestUtils):
+    module.forward(torch.randn(3, 3, 4, 4), torch.randn(3, 3, 2, 2))
+
+class ConvolutionModule2DTransposeStrided(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1], torch.float32, True),
+        ([-1, -1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, inputVec, weight):
+        return torch.ops.aten.convolution(inputVec,
+                                          weight,
+                                          bias=None,
+                                          stride=[2, 2],
+                                          padding=[1, 1],
+                                          dilation=[1, 1],
+                                          transposed=True,
+                                          output_padding=[0, 0],
+                                          groups=1)
+
+
+@register_test_case(module_factory=lambda: ConvolutionModule2DTransposeStrided())
+def ConvolutionModule2DTransposeStrided_basic(module, tu: TestUtils):
+    module.forward(torch.randn(5, 2, 5, 6), torch.randn(2, 5, 2, 2))
+
+class ConvolutionModule2DTransposeStridedStatic(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([5, 2, 5, 6], torch.float32, True),
+        ([2, 5, 2, 2], torch.float32, True),
+    ])
+    def forward(self, inputVec, weight):
+        return torch.ops.aten.convolution(inputVec,
+                                          weight,
+                                          bias=None,
+                                          stride=[2, 2],
+                                          padding=[1, 1],
+                                          dilation=[1, 1],
+                                          transposed=True,
+                                          output_padding=[0, 0],
+                                          groups=1)
+
+
+@register_test_case(module_factory=lambda: ConvolutionModule2DTransposeStridedStatic())
+def ConvolutionModule2DTransposeStridedStatic_basic(module, tu: TestUtils):
+    module.forward(torch.randn(5, 2, 5, 6), torch.randn(2, 5, 2, 2))
+
+
+class Conv_Transpose1dModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, inputVec, weight):
+        return torch.ops.aten.conv_transpose1d(inputVec,
+                                               weight,
+                                               bias=None,
+                                               stride=[2],
+                                               padding=[1],
+                                               dilation=[1],
+                                               output_padding=[0],
+                                               groups=1)
+
+
+@register_test_case(module_factory=lambda: Conv_Transpose1dModule())
+def Conv_Transpose1dModule_basic(module, tu: TestUtils):
+    module.forward(torch.randn(5, 2, 5), torch.randn(2, 5, 2))
+
+
+class Conv_Transpose2dModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1], torch.float32, True),
+        ([-1, -1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, inputVec, weight):
+        return torch.ops.aten.conv_transpose2d(inputVec,
+                                               weight,
+                                               bias=None,
+                                               stride=[2, 2],
+                                               padding=[1, 1],
+                                               dilation=[1, 1],
+                                               output_padding=[0, 0],
+                                               groups=1)
+
+
+@register_test_case(module_factory=lambda: Conv_Transpose2dModule())
+def Conv_Transpose2dModule_basic(module, tu: TestUtils):
+    module.forward(torch.randn(5, 2, 5, 6), torch.randn(2, 5, 2, 2))
+
+class Conv_Transpose3dModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1, -1], torch.float32, True),
+        ([-1, -1, -1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, inputVec, weight):
+        return torch.ops.aten.conv_transpose3d(inputVec,
+                                               weight,
+                                               bias=None,
+                                               stride=[2, 2, 2],
+                                               padding=[1, 1, 1],
+                                               dilation=[1, 1, 1],
+                                               output_padding=[0, 0, 0],
+                                               groups=1)
+
+
+@register_test_case(module_factory=lambda: Conv_Transpose3dModule())
+def Conv_Transpose3dModule_basic(module, tu: TestUtils):
+    module.forward(torch.randn(5, 2, 5, 6, 4), torch.randn(2, 5, 2, 2, 2))
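
For reference, the transposed-convolution output-size rule this patch implements in three places (Utils.h/Utils.cpp, ShapeLibrary.cpp, shape_lib_gen.py), dim_out = (dim_in - 1) * stride - 2 * padding + dilation * (kernelSize - 1) + 1, can be sanity-checked directly against PyTorch. The following is a minimal standalone sketch, not part of the patch; it assumes a local PyTorch install, and the helper name conv_transpose_out_dim is illustrative only:

    import torch

    # Output-size rule used by getOutputDimForConvTransposeOps and the shape functions.
    def conv_transpose_out_dim(in_dim, stride, padding, dilation, kernel_size):
        return (in_dim - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + 1

    inp = torch.randn(5, 2, 5, 6)      # N, C_in, H, W (matches ConvolutionModule2DTransposeStrided)
    weight = torch.randn(2, 5, 2, 2)   # C_in, C_out, kH, kW (transposed-conv weight layout)
    out = torch.ops.aten.conv_transpose2d(inp, weight, bias=None, stride=[2, 2],
                                          padding=[1, 1], output_padding=[0, 0],
                                          groups=1, dilation=[1, 1])
    expected = [conv_transpose_out_dim(d, 2, 1, 1, 2) for d in inp.shape[2:]]
    assert list(out.shape[2:]) == expected  # 5x6 input -> 8x10 output

With stride=[2, 2], padding=[1, 1] and a 2x2 kernel this gives an 8x10 spatial output for a 5x6 input, which is the case exercised by the strided transpose tests above (output_padding is assumed to be zero, as in the lowering).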