Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
gpetters94 committed Jun 25, 2022
1 parent 864b19b commit e683a3c
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 40 deletions.
14 changes: 2 additions & 12 deletions lib/Conversion/TorchToLinalg/Linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"

#include "../PassDetail.h"
Expand All @@ -21,12 +20,9 @@
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/ADT/ArrayRef.h"
#include <algorithm>
#include <bits/stdint-intn.h>

using namespace mlir;
using namespace mlir::torch;
Expand Down Expand Up @@ -644,7 +640,6 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {

// Calculate padded input size, allocate tensor
SmallVector<Value> outerSizes{N, inChannels};
SmallVector<Value> innerSizes{N, inChannels};
SmallVector<Value> offsets{c0, c0};
for (size_t i = 0; i < numSpacialDims; i++) {
Value innerSize = rewriter.create<arith::SubIOp>(loc, inDims[i], c1);
Expand All @@ -661,7 +656,6 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
Value outerSize = rewriter.create<arith::MulIOp>(loc, offset, c2);
outerSize = rewriter.create<arith::AddIOp>(loc, outerSize, innerSize);

innerSizes.push_back(innerSize);
outerSizes.push_back(outerSize);
offsets.push_back(offset);
}
Expand All @@ -677,8 +671,8 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
SmallVector<Value> insertSizes = getTensorSizes(rewriter, loc, input);

paddedInput = rewriter.create<tensor::InsertSliceOp>(
loc, torch_to_linalg::dynamicCast(rewriter, loc, input), initTensor,
offsets, insertSizes, strideIndexValues);
loc, torch_to_linalg::removeSizeInformation(rewriter, loc, input),
initTensor, offsets, insertSizes, strideIndexValues);

// Calculate output dims
for (size_t i = 0; i < numSpacialDims; i++)
Expand All @@ -689,11 +683,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
// Set stride to 1
strideInts.clear();
strideInts.append(numSpacialDims, 1);
strideIntValues = getAsConstantIntValues(rewriter, loc, strideInts);

paddingIntValues.clear();
for (auto pad = offsets.begin() + 2; pad < offsets.end(); pad++)
paddingIntValues.push_back(castIndexToInt(*pad));
} else {
// Pad input
SmallVector<int64_t, 4> paddingIncludingNC = {0, 0};
Expand Down
4 changes: 2 additions & 2 deletions lib/Conversion/TorchToLinalg/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"

Expand Down Expand Up @@ -283,7 +282,8 @@ Value torch_to_linalg::createElementwiseLinalgGeneric(
.getResult(0);
}

Value torch_to_linalg::dynamicCast(OpBuilder &b, Location loc, Value tensor) {
Value torch_to_linalg::removeSizeInformation(OpBuilder &b, Location loc,
Value tensor) {
auto tensorType = tensor.getType().cast<RankedTensorType>();
auto rank = tensorType.getRank();
SmallVector<int64_t> unknownSizes(rank, kUnknownSize);
Expand Down
2 changes: 1 addition & 1 deletion lib/Conversion/TorchToLinalg/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ Value createElementwiseLinalgGeneric(

// Cast a tensor to a rank-equivalent tensor of unknown size, i.e. <1x2xf32> ->
// <?x?xf32>
Value dynamicCast(OpBuilder &b, Location loc, Value tensor);
Value removeSizeInformation(OpBuilder &b, Location loc, Value tensor);
} // namespace torch_to_linalg
} // namespace torch
} // namespace mlir
4 changes: 2 additions & 2 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -803,8 +803,8 @@ class DecomposeAtenConvTranspose2dOp
Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), true);
rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
op, op->getResultTypes(), op.input(), op.weight(), op.bias(),
op.stride(), op.padding(), op.dilation(), cstTrue, op.output_padding(),
op.groups());
op.stride(), op.padding(), op.dilation(), /*transposed=*/cstTrue,
op.output_padding(), op.groups());

return success();
}
Expand Down
46 changes: 23 additions & 23 deletions python/torch_mlir_e2e_test/test_suite/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,17 +358,17 @@ def __init__(self):
])
def forward(self, inputVec, weight):
return torch.ops.aten.conv_transpose1d(inputVec,
weight,
bias=None,
stride=[2],
padding=[1],
dilation=[1],
output_padding=[0],
groups=1)
weight,
bias=None,
stride=[2],
padding=[1],
dilation=[1],
output_padding=[0],
groups=1)


@register_test_case(module_factory=lambda: Conv_Transpose3dModule())
def Conv_Transpose3dModule_basic(module, tu: TestUtils):
@register_test_case(module_factory=lambda: Conv_Transpose1dModule())
def Conv_Transpose1dModule_basic(module, tu: TestUtils):
    # Smoke test for 1-D transposed convolution lowering.
    # input:  (N=5, C_in=2, L=5); weight: (C_in=2, C_out=5, kernel=2),
    # matching conv_transpose1d's (in_channels, out_channels, kW) layout.
    # Conv_Transpose1dModule.forward (see above) runs this with stride=[2],
    # padding=[1], dilation=[1], output_padding=[0], groups=1.
    module.forward(torch.randn(5, 2, 5), torch.randn(2, 5, 2))


Expand All @@ -385,13 +385,13 @@ def __init__(self):
])
def forward(self, inputVec, weight):
return torch.ops.aten.conv_transpose2d(inputVec,
weight,
bias=None,
stride=[2, 2],
padding=[1, 1],
dilation=[1, 1],
output_padding=[0, 0],
groups=1)
weight,
bias=None,
stride=[2, 2],
padding=[1, 1],
dilation=[1, 1],
output_padding=[0, 0],
groups=1)


@register_test_case(module_factory=lambda: Conv_Transpose2dModule())
Expand All @@ -411,13 +411,13 @@ def __init__(self):
])
def forward(self, inputVec, weight):
return torch.ops.aten.conv_transpose3d(inputVec,
weight,
bias=None,
stride=[2, 2, 2],
padding=[1, 1, 1],
dilation=[1, 1, 1],
output_padding=[0, 0, 0],
groups=1)
weight,
bias=None,
stride=[2, 2, 2],
padding=[1, 1, 1],
dilation=[1, 1, 1],
output_padding=[0, 0, 0],
groups=1)


@register_test_case(module_factory=lambda: Conv_Transpose3dModule())
Expand Down

0 comments on commit e683a3c

Please sign in to comment.