From f012279fa28e7ef1c80a6e2fbaeccac32d44ad89 Mon Sep 17 00:00:00 2001
From: gpetters94
Date: Wed, 24 Aug 2022 12:19:35 -0400
Subject: [PATCH] Add transposed case for at::convolution (#917)

Also adds a decomposition for aten::conv_transpose2d.input
---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     |  90 ++++++++++
 lib/Conversion/TorchToLinalg/Linear.cpp       | 138 ++++++++++++---
 lib/Conversion/TorchToLinalg/Utils.cpp        |  33 ++++
 lib/Conversion/TorchToLinalg/Utils.h          |  11 ++
 .../Torch/Transforms/DecomposeComplexOps.cpp  |  22 +++
 lib/Dialect/Torch/Transforms/RefineTypes.cpp  |   3 +-
 lib/Dialect/Torch/Transforms/ShapeLibrary.cpp | 141 ++++++++++++++-
 .../jit_ir/build_tools/shape_lib_gen.py       |   5 +-
 .../jit_ir/build_tools/torch_ods_gen.py       |   3 +
 .../test_suite/__init__.py                    |   2 +
 python/torch_mlir_e2e_test/test_suite/conv.py | 163 ++++++++++++++++++
 11 files changed, 585 insertions(+), 26 deletions(-)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 371fdcf2a1de..e2f2b5061e12 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -3347,6 +3347,96 @@ def Torch_AtenConv2dOp : Torch_Op<"aten.conv2d", [
   }];
 }
 
+def Torch_AtenConvTranspose1dOp : Torch_Op<"aten.conv_transpose1d", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::conv_transpose1d : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input,
+    AnyTorchTensorType:$weight,
+    AnyTorchOptionalTensorType:$bias,
+    AnyTorchListOfTorchIntType:$stride,
+    AnyTorchListOfTorchIntType:$padding,
+    AnyTorchListOfTorchIntType:$output_padding,
+    Torch_IntType:$groups,
+    AnyTorchListOfTorchIntType:$dilation
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenConvTranspose1dOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 8, 1);
+    }
+    void AtenConvTranspose1dOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 8, 1);
+    }
+  }];
+}
+
+def Torch_AtenConvTranspose2dInputOp : Torch_Op<"aten.conv_transpose2d.input", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::conv_transpose2d.input : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input,
+    AnyTorchTensorType:$weight,
+    AnyTorchOptionalTensorType:$bias,
+    AnyTorchListOfTorchIntType:$stride,
+    AnyTorchListOfTorchIntType:$padding,
+    AnyTorchListOfTorchIntType:$output_padding,
+    Torch_IntType:$groups,
+    AnyTorchListOfTorchIntType:$dilation
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenConvTranspose2dInputOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 8, 1);
+    }
+    void AtenConvTranspose2dInputOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 8, 1);
+    }
+  }];
+}
+
+def Torch_AtenConvTranspose3dInputOp : Torch_Op<"aten.conv_transpose3d.input", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::conv_transpose3d.input : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)`";
+  let arguments = (ins
+
AnyTorchTensorType:$input, + AnyTorchTensorType:$weight, + AnyTorchOptionalTensorType:$bias, + AnyTorchListOfTorchIntType:$stride, + AnyTorchListOfTorchIntType:$padding, + AnyTorchListOfTorchIntType:$output_padding, + Torch_IntType:$groups, + AnyTorchListOfTorchIntType:$dilation + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenConvTranspose3dInputOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 8, 1); + } + void AtenConvTranspose3dInputOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 8, 1); + } + }]; +} + def Torch_AtenConvolutionOp : Torch_Op<"aten.convolution", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 1425e0d67c46..9482187e5bed 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -22,6 +22,7 @@ #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" +#include using namespace mlir; using namespace mlir::torch; @@ -635,12 +636,18 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { Value input = adaptor.input(); /* in form of N*C*H*W */ Value weight = adaptor.weight(); /* in form of F*C*H*W */ + bool transposed = true; + if (!matchPattern(op.transposed(), m_TorchConstantBool(&transposed))) + return rewriter.notifyMatchFailure( + op, "unimplemented: only constant transposed supported"); + Type elementType = input.getType().cast().getElementType(); if (!elementType.isa()) return op.emitError("unimplemented: non-floating point type"); size_t inRank = input.getType().cast().getRank(); - if (inRank != 4) + size_t numSpacialDims = inRank - 2; + if (numSpacialDims != 2) return rewriter.notifyMatchFailure( op, "unimplemented: only 2D convolution currently supported"); @@ -674,13 +681,6 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { for (size_t i = 2; i < inRank; i++) weightDims.push_back(getDimOp(rewriter, loc, weight, i)); - // Guard unused values (transposed) - bool transposed = true; - if (!matchPattern(op.transposed(), m_TorchConstantBool(&transposed)) || - transposed) - return rewriter.notifyMatchFailure( - op, "unimplemented: only non-transposed convolution supported"); - // Checks for valid group size int64_t groupSize; if (!matchPattern(op.groups(), m_TorchConstantInt(&groupSize))) @@ -701,19 +701,119 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { "invalid: groups must divide input channel size evenly."); validate(weightBatch, "invalid: groups must divide weight batch size evenly."); - - SmallVector paddingIntValues = - getAsConstantIntValues(rewriter, loc, paddingInts); SmallVector dilationIntValues = getAsConstantIntValues(rewriter, loc, dilationInts); + SmallVector paddingIntValues = + getAsConstantIntValues(rewriter, loc, paddingInts); SmallVector strideIntValues = getAsConstantIntValues(rewriter, loc, strideInts); + // Pad the input tensor according to padding. 
SmallVector outDims{inBatch, weightBatch}; - for (size_t i = 0; i < inRank - 2; i++) - outDims.push_back(torch_to_linalg::getOutputDimForConvOps( - rewriter, loc, inDims[i], paddingIntValues[i], dilationIntValues[i], - castIndexToInt(weightDims[i]), strideIntValues[i])); + Value paddedInput; + if (transposed) { + Value c0 = + rewriter.create(loc, rewriter.getIndexAttr(0)); + Value c1 = + rewriter.create(loc, rewriter.getIndexAttr(1)); + Value c2 = + rewriter.create(loc, rewriter.getIndexAttr(2)); + + // Transpose and flip weight + SmallVector weightInitDims = getTensorSizes(rewriter, loc, weight); + std::iter_swap(weightInitDims.begin(), weightInitDims.begin() + 1); + outDims[1] = weightInitDims[0]; + Value weightInitTensor = + createZeroInitTensor(rewriter, loc, weightInitDims, elementType); + SmallVector iteratorTypes(inRank, + getParallelIteratorTypeName()); + SmallVector indexingMaps( + 2, AffineMap::getMultiDimIdentityMap(inRank, context)); + weight = rewriter + .create( + loc, weightInitTensor.getType(), weight, + weightInitTensor, indexingMaps, iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + SmallVector indices; + for (size_t i = 0; i < inRank; i++) + indices.push_back(b.create(loc, i)); + std::iter_swap(indices.begin(), indices.begin() + 1); + // Flip only the spatial dimensions (from 2 to inRank) + for (size_t flipDim = 2; flipDim < inRank; flipDim++) { + indices[flipDim] = b.create( + loc, + b.create( + loc, weightInitDims[flipDim], c1), + indices[flipDim]); + } + Value res = + b.create(loc, weight, indices) + .getResult(); + b.create(loc, res); + }) + .getResult(0); + + // Calculate padded input size, allocate tensor + SmallVector outerSizes{inBatch, inChannels}; + SmallVector innerSizes{inBatch, inChannels}; + SmallVector offsets{c0, c0}; + for (size_t i = 0; i < numSpacialDims; i++) { + Value innerSize = rewriter.create(loc, inDims[i], c1); + innerSize = rewriter.create( + loc, innerSize, castIntToIndex(rewriter, loc, strideIntValues[i])); + innerSize = rewriter.create(loc, innerSize, c1); + + Value offset = rewriter.create(loc, weightDims[i], c1); + offset = rewriter.create( + loc, offset, castIntToIndex(rewriter, loc, dilationIntValues[i])); + offset = rewriter.create( + loc, offset, castIntToIndex(rewriter, loc, paddingIntValues[i])); + + Value outerSize = rewriter.create(loc, offset, c2); + outerSize = rewriter.create(loc, outerSize, innerSize); + + outerSizes.push_back(outerSize); + offsets.push_back(offset); + } + + // Allocate padded input tensor + Value initTensor = + createZeroInitTensor(rewriter, loc, outerSizes, elementType); + + // Insert input into allocated tensor + SmallVector strideIndexValues{c1, c1}; + for (auto stride : strideIntValues) + strideIndexValues.push_back(castIntToIndex(rewriter, loc, stride)); + SmallVector insertSizes = getTensorSizes(rewriter, loc, input); + + paddedInput = rewriter.create( + loc, torch_to_linalg::removeSizeInformation(rewriter, loc, input), + initTensor, offsets, insertSizes, strideIndexValues); + + // Calculate output dims + for (size_t i = 0; i < numSpacialDims; i++) + outDims.push_back(torch_to_linalg::getOutputDimForConvTransposeOps( + rewriter, loc, inDims[i], paddingIntValues[i], dilationIntValues[i], + castIndexToInt(weightDims[i]), strideIntValues[i])); + + // Set stride to 1 + strideInts.clear(); + strideInts.append(numSpacialDims, 1); + + } else { + // Pad input + SmallVector paddingIncludingNC = {0, 0}; + paddingIncludingNC.insert(paddingIncludingNC.end(), paddingInts.begin(), + 
paddingInts.end()); + paddedInput = torch_to_linalg::getZeroPaddedTensor(op, rewriter, input, + paddingIncludingNC); + + // Calculate output dims + for (size_t i = 0; i < numSpacialDims; i++) + outDims.push_back(torch_to_linalg::getOutputDimForConvOps( + rewriter, loc, inDims[i], paddingIntValues[i], dilationIntValues[i], + castIndexToInt(weightDims[i]), strideIntValues[i])); + } Value initTensor = rewriter.create(loc, outDims, elementType); @@ -769,14 +869,6 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { SmallVector weightSliceSizes{weightStride, weightChannels}; weightSliceSizes.append(weightDims); - // Pad the input tensor according to padding. - SmallVector paddingIncludingNC = {0, 0}; - paddingIncludingNC.append(paddingInts); - - // Pad inputSlice - Value paddedInput = torch_to_linalg::getZeroPaddedTensor( - op, rewriter, input, paddingIncludingNC); - Value conv; if (groupSize == 1) { // TODO: add 1D and 3D case diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp index 0d04dc552640..57a50a688f81 100644 --- a/lib/Conversion/TorchToLinalg/Utils.cpp +++ b/lib/Conversion/TorchToLinalg/Utils.cpp @@ -97,6 +97,31 @@ Value torch_to_linalg::getOutputDimForConvOps(OpBuilder &b, Location loc, return castIntToIndex(b, loc, out); } +Value torch_to_linalg::getOutputDimForConvTransposeOps( + OpBuilder &b, Location loc, Value in, Value paddingInt, Value dilationInt, + Value kernelSizeInt, Value strideInt) { + Value c1 = b.create(loc, b.getI64IntegerAttr(1)); + Value c2 = b.create(loc, b.getI64IntegerAttr(2)); + + // (in - 1) * stride + Value inStrided = + b.create(loc, castIndexToInt64(b, loc, in), c1); + inStrided = b.create(loc, inStrided, strideInt); + + // 2 * padding + Value doublePadding = b.create(loc, paddingInt, c2); + + // (kernelSize - 1) * dilation + Value kernelDilated = b.create(loc, kernelSizeInt, c1); + kernelDilated = b.create(loc, kernelDilated, dilationInt); + + Value out = b.create(loc, inStrided, doublePadding); + out = b.create(loc, out, kernelDilated); + out = b.create(loc, out, c1); + + return castIntToIndex(b, loc, out); +} + Value torch_to_linalg::createReductionLinalgGeneric( OpBuilder &b, Location loc, const ReductionOpInfo &opInfo, Value initElem, function_ref bodyBuild) { @@ -338,3 +363,11 @@ LogicalResult torch_to_linalg::broadcastToGivenShape( return success(); } + +Value torch_to_linalg::removeSizeInformation(OpBuilder &b, Location loc, + Value tensor) { + auto tensorType = tensor.getType().cast(); + auto rank = tensorType.getRank(); + SmallVector unknownSizes(rank, kUnknownSize); + return b.create(loc, tensorType.clone(unknownSizes), tensor); +} diff --git a/lib/Conversion/TorchToLinalg/Utils.h b/lib/Conversion/TorchToLinalg/Utils.h index 6279b8c9e802..f57c7eaa376d 100644 --- a/lib/Conversion/TorchToLinalg/Utils.h +++ b/lib/Conversion/TorchToLinalg/Utils.h @@ -39,6 +39,14 @@ Value getOutputDimForConvOps(OpBuilder &b, Location loc, Value in, Value kernelSizeInt, Value strideInt, bool ceilMode = false); +// As above but for transposed convolution ops +// Along each dim: +// dim_out = +// (dim_in - 1) * stride - 2 * padding + dilation * (kernelSize - 1) + 1 +Value getOutputDimForConvTransposeOps(OpBuilder &b, Location loc, Value in, + Value paddingInt, Value dilationInt, + Value kernelSizeInt, Value strideInt); + // Create a reduction of `opInfo.tensorOperand`, reducing along the dimensions // in `opInfo.dimSet`. 
If `opInfo.keepDim` is true, the output tensor is the // same rank as the `opInfo.tensorOperand` and reduced dimensions are set to @@ -61,6 +69,9 @@ LogicalResult broadcastToGivenShape(Operation *op, PatternRewriter &rewriter, SmallVector broadcastToShape, Value &result); +// Cast a tensor to a rank-equivalent tensor of unknown size, i.e. <1x2xf32> -> +// +Value removeSizeInformation(OpBuilder &b, Location loc, Value tensor); } // namespace torch_to_linalg } // namespace torch } // namespace mlir diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 77c0ec6d2610..6d95ccc09215 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -1102,6 +1102,26 @@ class DecomposeAtenConv2dOp : public OpRewritePattern { }; } // namespace +// Decompose aten.conv_transpose2d to aten.convolution +namespace { +class DecomposeAtenConvTranspose2dOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenConvTranspose2dInputOp op, + PatternRewriter &rewriter) const override { + + Value cstTrue = rewriter.create(op.getLoc(), true); + rewriter.replaceOpWithNewOp( + op, op->getResultTypes(), op.input(), op.weight(), op.bias(), + op.stride(), op.padding(), op.dilation(), /*transposed=*/cstTrue, + op.output_padding(), op.groups()); + + return success(); + } +}; +} // namespace + // Decompose aten.addmm into aten.mm and aten.add.Tensor op. namespace { class DecomposeAtenAddmmOp : public OpRewritePattern { @@ -2686,6 +2706,8 @@ class DecomposeComplexOpsPass context); target.addIllegalOp(); patterns.add(context); + target.addIllegalOp(); + patterns.add(context); patterns.add(context); target.addIllegalOp(); patterns.add(context); diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp index 4597208fcb32..a67f65ec3444 100644 --- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp +++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp @@ -713,7 +713,8 @@ void TypeAnalysis::visitOperation(Operation *op, // Promote the two dtypes assuming non-zero rank. 
if (isa(op)) { + Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp, + AtenConvolutionOverrideableOp, AtenConvTranspose2dInputOp>(op)) { auto knowledge = ValueKnowledge::getTensorPessimisticValueState(op->getContext()); knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank( diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp index 2afd5640c1d7..4868a711b47e 100644 --- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp @@ -3494,6 +3494,137 @@ module { %6 = torch.prim.TupleConstruct %0, %2, %5 : !torch.list, !torch.list, !torch.list -> !torch.tuple, list, list> return %6 : !torch.tuple, list, list> } + func.func @__torch__.torch.jit._shape_functions.conv_forwards(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int) -> !torch.list { + %true = torch.constant.bool true + %int0 = torch.constant.int 0 + %int2 = torch.constant.int 2 + %int1 = torch.constant.int 1 + %0 = torch.aten.len.t %arg5 : !torch.list -> !torch.int + %1 = torch.aten.gt.int %0, %int0 : !torch.int, !torch.int -> !torch.bool + %2 = torch.aten.len.t %arg0 : !torch.list -> !torch.int + %3 = torch.prim.ListConstruct : () -> !torch.list + %4 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int + %5 = torch.aten.append.t %3, %4 : !torch.list, !torch.int -> !torch.list + %6 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int + %7 = torch.aten.append.t %3, %6 : !torch.list, !torch.int -> !torch.list + %8 = torch.aten.__range_length %int2, %2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int + torch.prim.Loop %8, %true, init() { + ^bb0(%arg9: !torch.int): + %9 = torch.aten.__derive_index %arg9, %int2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int + %10 = torch.prim.If %1 -> (!torch.int) { + %11 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int + %12 = torch.aten.__getitem__.t %arg5, %11 : !torch.list, !torch.int -> !torch.int + torch.prim.If.yield %12 : !torch.int + } else { + torch.prim.If.yield %int1 : !torch.int + } + torch.prim.If %arg6 -> () { + %11 = torch.aten.__getitem__.t %arg1, %9 : !torch.list, !torch.int -> !torch.int + %12 = torch.aten.sub.int %11, %int1 : !torch.int, !torch.int -> !torch.int + %13 = torch.aten.mul.int %10, %12 : !torch.int, !torch.int -> !torch.int + %14 = torch.aten.__getitem__.t %arg0, %9 : !torch.list, !torch.int -> !torch.int + %15 = torch.aten.sub.int %14, %int1 : !torch.int, !torch.int -> !torch.int + %16 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int + %17 = torch.aten.__getitem__.t %arg3, %16 : !torch.list, !torch.int -> !torch.int + %18 = torch.aten.mul.int %15, %17 : !torch.int, !torch.int -> !torch.int + %19 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int + %20 = torch.aten.__getitem__.t %arg4, %19 : !torch.list, !torch.int -> !torch.int + %21 = torch.aten.mul.int %20, %int2 : !torch.int, !torch.int -> !torch.int + %22 = torch.aten.sub.int %18, %21 : !torch.int, !torch.int -> !torch.int + %23 = torch.aten.add.int %22, %13 : !torch.int, !torch.int -> !torch.int + %24 = torch.aten.add.int %23, %int1 : !torch.int, !torch.int -> !torch.int + %25 = torch.aten.append.t %3, %24 : !torch.list, !torch.int -> !torch.list + torch.prim.If.yield + } else { + %11 = torch.aten.__getitem__.t %arg1, %9 : !torch.list, !torch.int -> !torch.int 
+ %12 = torch.aten.sub.int %11, %int1 : !torch.int, !torch.int -> !torch.int + %13 = torch.aten.mul.int %10, %12 : !torch.int, !torch.int -> !torch.int + %14 = torch.aten.add.int %13, %int1 : !torch.int, !torch.int -> !torch.int + %15 = torch.aten.__getitem__.t %arg0, %9 : !torch.list, !torch.int -> !torch.int + %16 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int + %17 = torch.aten.__getitem__.t %arg4, %16 : !torch.list, !torch.int -> !torch.int + %18 = torch.aten.mul.int %17, %int2 : !torch.int, !torch.int -> !torch.int + %19 = torch.aten.add.int %15, %18 : !torch.int, !torch.int -> !torch.int + %20 = torch.aten.sub.int %19, %14 : !torch.int, !torch.int -> !torch.int + %21 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int + %22 = torch.aten.__getitem__.t %arg3, %21 : !torch.list, !torch.int -> !torch.int + %23 = torch.aten.floordiv.int %20, %22 : !torch.int, !torch.int -> !torch.int + %24 = torch.aten.add.int %23, %int1 : !torch.int, !torch.int -> !torch.int + %25 = torch.aten.append.t %3, %24 : !torch.list, !torch.int -> !torch.list + torch.prim.If.yield + } + torch.prim.Loop.condition %true, iter() + } : (!torch.int, !torch.bool) -> () + return %3 : !torch.list + } + func.func @__torch__.torch.jit._shape_functions.conv_transpose2d_input(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.optional>, %arg5: !torch.optional>, %arg6: !torch.int, %arg7: !torch.optional>) -> !torch.list { + %true = torch.constant.bool true + %none = torch.constant.none + %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %int2 = torch.constant.int 2 + %0 = torch.aten.__is__ %arg3, %none : !torch.optional>, !torch.none -> !torch.bool + %1 = torch.prim.If %0 -> (!torch.list) { + %15 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + torch.prim.If.yield %15 : !torch.list + } else { + %15 = torch.prim.unchecked_cast %arg3 : !torch.optional> -> !torch.list + torch.prim.If.yield %15 : !torch.list + } + %2 = torch.aten.__is__ %arg4, %none : !torch.optional>, !torch.none -> !torch.bool + %3 = torch.prim.If %2 -> (!torch.list) { + %15 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list + torch.prim.If.yield %15 : !torch.list + } else { + %15 = torch.prim.unchecked_cast %arg4 : !torch.optional> -> !torch.list + torch.prim.If.yield %15 : !torch.list + } + %4 = torch.aten.__is__ %arg7, %none : !torch.optional>, !torch.none -> !torch.bool + %5 = torch.prim.If %4 -> (!torch.list) { + %15 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + torch.prim.If.yield %15 : !torch.list + } else { + %15 = torch.prim.unchecked_cast %arg7 : !torch.optional> -> !torch.list + torch.prim.If.yield %15 : !torch.list + } + %6 = torch.aten.len.t %5 : !torch.list -> !torch.int + %7 = torch.aten.gt.int %6, %int0 : !torch.int, !torch.int -> !torch.bool + %8 = torch.aten.len.t %arg0 : !torch.list -> !torch.int + %9 = torch.prim.ListConstruct : () -> !torch.list + %10 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int + %11 = torch.aten.append.t %9, %10 : !torch.list, !torch.int -> !torch.list + %12 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int + %13 = torch.aten.append.t %9, %12 : !torch.list, !torch.int -> !torch.list + %14 = torch.aten.__range_length %int2, %8, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int + torch.prim.Loop %14, %true, init() { + ^bb0(%arg8: !torch.int): + 
%15 = torch.aten.__derive_index %arg8, %int2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int + %16 = torch.prim.If %7 -> (!torch.int) { + %32 = torch.aten.sub.int %15, %int2 : !torch.int, !torch.int -> !torch.int + %33 = torch.aten.__getitem__.t %5, %32 : !torch.list, !torch.int -> !torch.int + torch.prim.If.yield %33 : !torch.int + } else { + torch.prim.If.yield %int1 : !torch.int + } + %17 = torch.aten.__getitem__.t %arg1, %15 : !torch.list, !torch.int -> !torch.int + %18 = torch.aten.sub.int %17, %int1 : !torch.int, !torch.int -> !torch.int + %19 = torch.aten.mul.int %16, %18 : !torch.int, !torch.int -> !torch.int + %20 = torch.aten.__getitem__.t %arg0, %15 : !torch.list, !torch.int -> !torch.int + %21 = torch.aten.sub.int %20, %int1 : !torch.int, !torch.int -> !torch.int + %22 = torch.aten.sub.int %15, %int2 : !torch.int, !torch.int -> !torch.int + %23 = torch.aten.__getitem__.t %1, %22 : !torch.list, !torch.int -> !torch.int + %24 = torch.aten.mul.int %21, %23 : !torch.int, !torch.int -> !torch.int + %25 = torch.aten.sub.int %15, %int2 : !torch.int, !torch.int -> !torch.int + %26 = torch.aten.__getitem__.t %3, %25 : !torch.list, !torch.int -> !torch.int + %27 = torch.aten.mul.int %26, %int2 : !torch.int, !torch.int -> !torch.int + %28 = torch.aten.sub.int %24, %27 : !torch.int, !torch.int -> !torch.int + %29 = torch.aten.add.int %28, %19 : !torch.int, !torch.int -> !torch.int + %30 = torch.aten.add.int %29, %int1 : !torch.int, !torch.int -> !torch.int + %31 = torch.aten.append.t %9, %30 : !torch.list, !torch.int -> !torch.list + torch.prim.Loop.condition %true, iter() + } : (!torch.int, !torch.bool) -> () + return %9 : !torch.list + } func.func @__torch__.torch.jit._shape_functions.flatten(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list { %none = torch.constant.none %str = torch.constant.str "AssertionError: " @@ -6337,8 +6468,16 @@ module { %0 = call @__torch__.torch.jit._shape_functions.conv2d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.int) -> !torch.list return %0 : !torch.list } + func.func @"__torch_mlir_shape_fn.aten.conv_transpose2d.input"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int, %arg7: !torch.list) -> !torch.list { + %0 = torch.derefine %arg3 : !torch.list to !torch.optional> + %1 = torch.derefine %arg4 : !torch.list to !torch.optional> + %2 = torch.derefine %arg5 : !torch.list to !torch.optional> + %3 = torch.derefine %arg7 : !torch.list to !torch.optional> + %4 = call @__torch__.torch.jit._shape_functions.conv_transpose2d_input(%arg0, %arg1, %arg2, %0, %1, %2, %arg6, %3) : (!torch.list, !torch.list, !torch.optional>, !torch.optional>, !torch.optional>, !torch.optional>, !torch.int, !torch.optional>) -> !torch.list + return %4 : !torch.list + } func.func @"__torch_mlir_shape_fn.aten.convolution"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.conv_output_size(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg8) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.int) -> !torch.list + %0 = call @__torch__.torch.jit._shape_functions.conv_forwards(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : 
(!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int) -> !torch.list return %0 : !torch.list } func.func @"__torch_mlir_shape_fn.aten._convolution"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int, %arg9: !torch.bool, %arg10: !torch.bool, %arg11: !torch.bool, %arg12: !torch.bool) -> !torch.list { diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py index 41379eb66b51..3f547720a40a 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py @@ -938,8 +938,11 @@ def aten〇topk(self: List[int], k: int, dim: int = -1, largest: bool = True, so def aten〇conv2d(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1), padding: List[int] = (0, 0), dilation: List[int] = (1, 1), groups: int = 1) -> List[int]: return upstream_shape_functions.conv2d(input, weight, bias, stride, padding, dilation, groups) +def aten〇conv_transpose2d〇input(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1), padding: List[int] = (0, 0), output_padding: List[int] = (0, 0), groups: int = 1, dilation: List[int] = (1, 1)) -> List[int]: + return upstream_shape_functions.conv_transpose2d_input(input, weight, bias, stride, padding, output_padding, groups, dilation) + def aten〇convolution(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int) -> List[int]: - return upstream_shape_functions.conv_output_size(input, weight, bias, stride, padding, dilation, groups) + return upstream_shape_functions.conv_forwards(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups) def aten〇_convolution(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool, allow_tf32: bool) -> List[int]: return aten〇convolution(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups) diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index 4f223b1e360b..e2ff4146bc16 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -334,6 +334,9 @@ def emit_with_mutating_variants(key, **kwargs): emit( "aten::conv2d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)" ) + emit("aten::conv_transpose1d : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)") + emit("aten::conv_transpose2d.input : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)") + emit("aten::conv_transpose3d.input : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)") emit("aten::convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)") emit("aten::convolution_overrideable : (Tensor, Tensor, 
Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)") emit("aten::_convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool, bool) -> (Tensor)") diff --git a/python/torch_mlir_e2e_test/test_suite/__init__.py b/python/torch_mlir_e2e_test/test_suite/__init__.py index 63d30d8fd43c..af77f919b306 100644 --- a/python/torch_mlir_e2e_test/test_suite/__init__.py +++ b/python/torch_mlir_e2e_test/test_suite/__init__.py @@ -11,6 +11,8 @@ "TableBatchEmbeddingModule_basic", "Convolution3DModule_basic", "Convolution1DModule_basic", + "Conv_Transpose3dModule_basic", + "Conv_Transpose1dModule_basic", "MaxPool2dWith3dInputModule_basic", "MaxPool2dWithIndicesWith3dInputModule_basic", } diff --git a/python/torch_mlir_e2e_test/test_suite/conv.py b/python/torch_mlir_e2e_test/test_suite/conv.py index cb92710d2e6c..5175adf20e06 100644 --- a/python/torch_mlir_e2e_test/test_suite/conv.py +++ b/python/torch_mlir_e2e_test/test_suite/conv.py @@ -542,3 +542,166 @@ def forward(self, inputVec, weight): @register_test_case(module_factory=lambda: ConvolutionModule2DGroups()) def ConvolutionModule2DGroups_basic(module, tu: TestUtils): module.forward(torch.randn(1, 32, 4, 4), torch.randn(32, 8, 3, 3)) + +# ============================================================================== + +class ConvolutionModule2DTranspose(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1], torch.float32, True), + ]) + def forward(self, inputVec, weight): + return torch.ops.aten.convolution(inputVec, + weight, + bias=None, + stride=[1, 1], + padding=[1, 1], + dilation=[1, 1], + transposed=True, + output_padding=[0, 0], + groups=1) + + +@register_test_case(module_factory=lambda: ConvolutionModule2DTranspose()) +def ConvolutionModule2DTranspose_basic(module, tu: TestUtils): + module.forward(torch.randn(3, 3, 4, 4), torch.randn(3, 3, 2, 2)) + +class ConvolutionModule2DTransposeStrided(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1], torch.float32, True), + ]) + def forward(self, inputVec, weight): + return torch.ops.aten.convolution(inputVec, + weight, + bias=None, + stride=[2, 2], + padding=[1, 1], + dilation=[1, 1], + transposed=True, + output_padding=[0, 0], + groups=1) + + +@register_test_case(module_factory=lambda: ConvolutionModule2DTransposeStrided()) +def ConvolutionModule2DTransposeStrided_basic(module, tu: TestUtils): + module.forward(torch.randn(5, 2, 5, 6), torch.randn(2, 5, 2, 2)) + +class ConvolutionModule2DTransposeStridedStatic(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([5, 2, 5, 6], torch.float32, True), + ([2, 5, 2, 2], torch.float32, True), + ]) + def forward(self, inputVec, weight): + return torch.ops.aten.convolution(inputVec, + weight, + bias=None, + stride=[2, 2], + padding=[1, 1], + dilation=[1, 1], + transposed=True, + output_padding=[0, 0], + groups=1) + + +@register_test_case(module_factory=lambda: ConvolutionModule2DTransposeStridedStatic()) +def ConvolutionModule2DTransposeStridedStatic_basic(module, tu: TestUtils): + module.forward(torch.randn(5, 2, 5, 6), torch.randn(2, 5, 2, 2)) + + +class Conv_Transpose1dModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, 
True), + ]) + def forward(self, inputVec, weight): + return torch.ops.aten.conv_transpose1d(inputVec, + weight, + bias=None, + stride=[2], + padding=[1], + dilation=[1], + output_padding=[0], + groups=1) + + +@register_test_case(module_factory=lambda: Conv_Transpose1dModule()) +def Conv_Transpose1dModule_basic(module, tu: TestUtils): + module.forward(torch.randn(5, 2, 5), torch.randn(2, 5, 2)) + + +class Conv_Transpose2dModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1], torch.float32, True), + ]) + def forward(self, inputVec, weight): + return torch.ops.aten.conv_transpose2d(inputVec, + weight, + bias=None, + stride=[2, 2], + padding=[1, 1], + dilation=[1, 1], + output_padding=[0, 0], + groups=1) + + +@register_test_case(module_factory=lambda: Conv_Transpose2dModule()) +def Conv_Transpose2dModule_basic(module, tu: TestUtils): + module.forward(torch.randn(5, 2, 5, 6), torch.randn(2, 5, 2, 2)) + +class Conv_Transpose3dModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1, -1], torch.float32, True), + ]) + def forward(self, inputVec, weight): + return torch.ops.aten.conv_transpose3d(inputVec, + weight, + bias=None, + stride=[2, 2, 2], + padding=[1, 1, 1], + dilation=[1, 1, 1], + output_padding=[0, 0, 0], + groups=1) + + +@register_test_case(module_factory=lambda: Conv_Transpose3dModule()) +def Conv_Transpose3dModule_basic(module, tu: TestUtils): + module.forward(torch.randn(5, 2, 5, 6, 4), torch.randn(2, 5, 2, 2, 2)) \ No newline at end of file
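
Reviewer note (not part of the patch): the transposed path added to
ConvertAtenConvolutionOp reduces a transposed convolution to an ordinary one:
the weight has its two channel dims swapped and its spatial dims flipped, the
input is scattered into a zero-initialized tensor (samples placed `stride`
apart, with a border of dilation * (kernel - 1) - padding zeros on each side),
and a stride-1 convolution is run over the result. The Python sketch below
replays that strategy with stock PyTorch and checks it against
aten::convolution with transposed=True, the op that the new
DecomposeAtenConvTranspose2dOp pattern targets. It is only an illustration
under the same assumptions this lowering makes (groups=1,
output_padding=[0, 0], padding <= dilation * (kernel - 1)); the helper name is
made up.

import torch
import torch.nn.functional as F

def conv_transpose2d_as_conv2d(x, w, stride, padding, dilation):
    # Weight arrives as (C_in, C_out, kH, kW); swap the channel dims and flip
    # the spatial dims, mirroring the linalg.generic that rewrites the weight.
    w_flipped = torch.flip(w.transpose(0, 1), dims=[2, 3])
    kh, kw = w.shape[2], w.shape[3]
    n, c, h_in, w_in = x.shape
    # Scatter the input into zeros: elements `stride` apart, with an outer
    # border of dilation * (k - 1) - padding zeros, mirroring the
    # tensor.insert_slice into the zero-initialized padded tensor.
    off_h = dilation[0] * (kh - 1) - padding[0]
    off_w = dilation[1] * (kw - 1) - padding[1]
    inner_h = (h_in - 1) * stride[0] + 1
    inner_w = (w_in - 1) * stride[1] + 1
    padded = x.new_zeros(n, c, inner_h + 2 * off_h, inner_w + 2 * off_w)
    padded[:, :, off_h:off_h + inner_h:stride[0],
           off_w:off_w + inner_w:stride[1]] = x
    # An ordinary stride-1 convolution now yields, per spatial dim,
    #   dim_out = (dim_in - 1) * stride - 2 * padding + dilation * (kernel - 1) + 1,
    # which is the formula documented on getOutputDimForConvTransposeOps.
    return F.conv2d(padded, w_flipped, stride=1, dilation=dilation)

x = torch.randn(5, 2, 5, 6)
w = torch.randn(2, 5, 2, 2)
expected = torch.ops.aten.convolution(x, w, None, [2, 2], [1, 1], [1, 1],
                                      True, [0, 0], 1)  # transposed=True
assert torch.allclose(conv_transpose2d_as_conv2d(x, w, [2, 2], [1, 1], [1, 1]),
                      expected, atol=1e-5)
# conv_transpose2d.input decomposes to that same convolution call
# (DecomposeAtenConvTranspose2dOp), so it must agree as well.
assert torch.allclose(
    torch.ops.aten.conv_transpose2d(x, w, None, [2, 2], [1, 1], [0, 0], 1, [1, 1]),
    expected, atol=1e-5)
assert list(expected.shape) == [5, 5, 8, 10]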