From 5906761ad4c01610fb03c8987ffa993f292fbe53 Mon Sep 17 00:00:00 2001 From: George Petterson Date: Sat, 25 Jun 2022 02:36:05 -0400 Subject: [PATCH] Address comments --- lib/Conversion/TorchToLinalg/Linear.cpp | 13 ++----------- lib/Conversion/TorchToLinalg/Utils.cpp | 4 ++-- lib/Conversion/TorchToLinalg/Utils.h | 3 +-- .../Torch/Transforms/DecomposeComplexOps.cpp | 4 ++-- 4 files changed, 7 insertions(+), 17 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index d3cf28da4b63..9482187e5bed 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -7,7 +7,6 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" #include "../PassDetail.h" @@ -21,12 +20,9 @@ #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" -#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" -#include "llvm/ADT/ArrayRef.h" #include -#include using namespace mlir; using namespace mlir::torch; @@ -776,7 +772,6 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { Value outerSize = rewriter.create(loc, offset, c2); outerSize = rewriter.create(loc, outerSize, innerSize); - innerSizes.push_back(innerSize); outerSizes.push_back(outerSize); offsets.push_back(offset); } @@ -792,8 +787,8 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { SmallVector insertSizes = getTensorSizes(rewriter, loc, input); paddedInput = rewriter.create( - loc, torch_to_linalg::dynamicCast(rewriter, loc, input), initTensor, - offsets, insertSizes, strideIndexValues); + loc, torch_to_linalg::removeSizeInformation(rewriter, loc, input), + initTensor, offsets, insertSizes, strideIndexValues); // Calculate output dims for (size_t i = 0; i < numSpacialDims; i++) @@ -804,11 +799,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { // Set stride to 1 strideInts.clear(); strideInts.append(numSpacialDims, 1); - strideIntValues = getAsConstantIntValues(rewriter, loc, strideInts); - paddingIntValues.clear(); - for (auto pad = offsets.begin() + 2; pad < offsets.end(); pad++) - paddingIntValues.push_back(castIndexToInt(*pad)); } else { // Pad input SmallVector paddingIncludingNC = {0, 0}; diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp index f5331564d62c..57a50a688f81 100644 --- a/lib/Conversion/TorchToLinalg/Utils.cpp +++ b/lib/Conversion/TorchToLinalg/Utils.cpp @@ -20,7 +20,6 @@ #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" -#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" @@ -365,7 +364,8 @@ LogicalResult torch_to_linalg::broadcastToGivenShape( return success(); } -Value torch_to_linalg::dynamicCast(OpBuilder &b, Location loc, Value tensor) { +Value torch_to_linalg::removeSizeInformation(OpBuilder &b, Location loc, + Value tensor) { auto tensorType = tensor.getType().cast(); auto rank = tensorType.getRank(); SmallVector unknownSizes(rank, kUnknownSize); diff --git a/lib/Conversion/TorchToLinalg/Utils.h b/lib/Conversion/TorchToLinalg/Utils.h index ded9d6c12e60..f57c7eaa376d 100644 --- a/lib/Conversion/TorchToLinalg/Utils.h +++ b/lib/Conversion/TorchToLinalg/Utils.h @@ -71,8 +71,7 @@ LogicalResult broadcastToGivenShape(Operation *op, PatternRewriter &rewriter, // Cast a tensor to a rank-equivalent tensor of unknown size, i.e. <1x2xf32> -> // -Value dynamicCast(OpBuilder &b, Location loc, Value tensor); - +Value removeSizeInformation(OpBuilder &b, Location loc, Value tensor); } // namespace torch_to_linalg } // namespace torch } // namespace mlir diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index abe5216f5e64..f60f192ac014 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -980,8 +980,8 @@ class DecomposeAtenConvTranspose2dOp Value cstTrue = rewriter.create(op.getLoc(), true); rewriter.replaceOpWithNewOp( op, op->getResultTypes(), op.input(), op.weight(), op.bias(), - op.stride(), op.padding(), op.dilation(), cstTrue, op.output_padding(), - op.groups()); + op.stride(), op.padding(), op.dilation(), /*transposed=*/cstTrue, + op.output_padding(), op.groups()); return success(); }