diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 5712da0cfa0..b8a3177890c 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -7,7 +7,6 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" #include "../PassDetail.h" @@ -21,12 +20,9 @@ #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" -#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" -#include "llvm/ADT/ArrayRef.h" #include -#include using namespace mlir; using namespace mlir::torch; @@ -745,7 +741,6 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { // Calculate padded input size, allocate tensor SmallVector outerSizes{N, inChannels}; - SmallVector innerSizes{N, inChannels}; SmallVector offsets{c0, c0}; for (size_t i = 0; i < numSpacialDims; i++) { Value innerSize = rewriter.create(loc, inDims[i], c1); @@ -762,7 +757,6 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { Value outerSize = rewriter.create(loc, offset, c2); outerSize = rewriter.create(loc, outerSize, innerSize); - innerSizes.push_back(innerSize); outerSizes.push_back(outerSize); offsets.push_back(offset); } @@ -778,8 +772,8 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { SmallVector insertSizes = getTensorSizes(rewriter, loc, input); paddedInput = rewriter.create( - loc, torch_to_linalg::dynamicCast(rewriter, loc, input), initTensor, - offsets, insertSizes, strideIndexValues); + loc, torch_to_linalg::removeSizeInformation(rewriter, loc, input), + initTensor, offsets, insertSizes, strideIndexValues); // Calculate output dims for (size_t i = 0; i < numSpacialDims; i++) @@ -790,11 +784,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { // Set stride to 1 strideInts.clear(); strideInts.append(numSpacialDims, 1); - strideIntValues = getAsConstantIntValues(rewriter, loc, strideInts); - paddingIntValues.clear(); - for (auto pad = offsets.begin() + 2; pad < offsets.end(); pad++) - paddingIntValues.push_back(castIndexToInt(*pad)); } else { // Pad input SmallVector paddingIncludingNC = {0, 0}; diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp index f5331564d62..57a50a688f8 100644 --- a/lib/Conversion/TorchToLinalg/Utils.cpp +++ b/lib/Conversion/TorchToLinalg/Utils.cpp @@ -20,7 +20,6 @@ #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" -#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" @@ -365,7 +364,8 @@ LogicalResult torch_to_linalg::broadcastToGivenShape( return success(); } -Value torch_to_linalg::dynamicCast(OpBuilder &b, Location loc, Value tensor) { +Value torch_to_linalg::removeSizeInformation(OpBuilder &b, Location loc, + Value tensor) { auto tensorType = tensor.getType().cast(); auto rank = tensorType.getRank(); SmallVector unknownSizes(rank, kUnknownSize); diff --git a/lib/Conversion/TorchToLinalg/Utils.h b/lib/Conversion/TorchToLinalg/Utils.h index ded9d6c12e6..f57c7eaa376 100644 --- a/lib/Conversion/TorchToLinalg/Utils.h +++ b/lib/Conversion/TorchToLinalg/Utils.h @@ -71,8 +71,7 @@ LogicalResult broadcastToGivenShape(Operation *op, PatternRewriter &rewriter, // Cast a tensor to a rank-equivalent tensor of unknown size, i.e. <1x2xf32> -> // -Value dynamicCast(OpBuilder &b, Location loc, Value tensor); - +Value removeSizeInformation(OpBuilder &b, Location loc, Value tensor); } // namespace torch_to_linalg } // namespace torch } // namespace mlir diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 7146c767f82..5dc38cb85b2 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -803,8 +803,8 @@ class DecomposeAtenConvTranspose2dOp Value cstTrue = rewriter.create(op.getLoc(), true); rewriter.replaceOpWithNewOp( op, op->getResultTypes(), op.input(), op.weight(), op.bias(), - op.stride(), op.padding(), op.dilation(), cstTrue, op.output_padding(), - op.groups()); + op.stride(), op.padding(), op.dilation(), /*transposed=*/cstTrue, + op.output_padding(), op.groups()); return success(); } diff --git a/python/torch_mlir_e2e_test/test_suite/conv.py b/python/torch_mlir_e2e_test/test_suite/conv.py index 66650c678af..76d1baa66b7 100644 --- a/python/torch_mlir_e2e_test/test_suite/conv.py +++ b/python/torch_mlir_e2e_test/test_suite/conv.py @@ -358,17 +358,17 @@ def __init__(self): ]) def forward(self, inputVec, weight): return torch.ops.aten.conv_transpose1d(inputVec, - weight, - bias=None, - stride=[2], - padding=[1], - dilation=[1], - output_padding=[0], - groups=1) + weight, + bias=None, + stride=[2], + padding=[1], + dilation=[1], + output_padding=[0], + groups=1) -@register_test_case(module_factory=lambda: Conv_Transpose3dModule()) -def Conv_Transpose3dModule_basic(module, tu: TestUtils): +@register_test_case(module_factory=lambda: Conv_Transpose1dModule()) +def Conv_Transpose1dModule_basic(module, tu: TestUtils): module.forward(torch.randn(5, 2, 5), torch.randn(2, 5, 2)) @@ -385,13 +385,13 @@ def __init__(self): ]) def forward(self, inputVec, weight): return torch.ops.aten.conv_transpose2d(inputVec, - weight, - bias=None, - stride=[2, 2], - padding=[1, 1], - dilation=[1, 1], - output_padding=[0, 0], - groups=1) + weight, + bias=None, + stride=[2, 2], + padding=[1, 1], + dilation=[1, 1], + output_padding=[0, 0], + groups=1) @register_test_case(module_factory=lambda: Conv_Transpose2dModule()) @@ -411,13 +411,13 @@ def __init__(self): ]) def forward(self, inputVec, weight): return torch.ops.aten.conv_transpose3d(inputVec, - weight, - bias=None, - stride=[2, 2, 2], - padding=[1, 1, 1], - dilation=[1, 1, 1], - output_padding=[0, 0, 0], - groups=1) + weight, + bias=None, + stride=[2, 2, 2], + padding=[1, 1, 1], + dilation=[1, 1, 1], + output_padding=[0, 0, 0], + groups=1) @register_test_case(module_factory=lambda: Conv_Transpose3dModule())