Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
gpetters94 committed Aug 6, 2022
1 parent 11d3412 commit de25c49
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 17 deletions.
13 changes: 2 additions & 11 deletions lib/Conversion/TorchToLinalg/Linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"

#include "../PassDetail.h"
Expand All @@ -21,12 +20,9 @@
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/ADT/ArrayRef.h"
#include <algorithm>
#include <bits/stdint-intn.h>

using namespace mlir;
using namespace mlir::torch;
Expand Down Expand Up @@ -776,7 +772,6 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
Value outerSize = rewriter.create<arith::MulIOp>(loc, offset, c2);
outerSize = rewriter.create<arith::AddIOp>(loc, outerSize, innerSize);

innerSizes.push_back(innerSize);
outerSizes.push_back(outerSize);
offsets.push_back(offset);
}
Expand All @@ -792,8 +787,8 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
SmallVector<Value> insertSizes = getTensorSizes(rewriter, loc, input);

paddedInput = rewriter.create<tensor::InsertSliceOp>(
loc, torch_to_linalg::dynamicCast(rewriter, loc, input), initTensor,
offsets, insertSizes, strideIndexValues);
loc, torch_to_linalg::removeSizeInformation(rewriter, loc, input),
initTensor, offsets, insertSizes, strideIndexValues);

// Calculate output dims
for (size_t i = 0; i < numSpacialDims; i++)
Expand All @@ -804,11 +799,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
// Set stride to 1
strideInts.clear();
strideInts.append(numSpacialDims, 1);
strideIntValues = getAsConstantIntValues(rewriter, loc, strideInts);

paddingIntValues.clear();
for (auto pad = offsets.begin() + 2; pad < offsets.end(); pad++)
paddingIntValues.push_back(castIndexToInt(*pad));
} else {
// Pad input
SmallVector<int64_t, 4> paddingIncludingNC = {0, 0};
Expand Down
4 changes: 2 additions & 2 deletions lib/Conversion/TorchToLinalg/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"

Expand Down Expand Up @@ -365,7 +364,8 @@ LogicalResult torch_to_linalg::broadcastToGivenShape(
return success();
}

Value torch_to_linalg::dynamicCast(OpBuilder &b, Location loc, Value tensor) {
Value torch_to_linalg::removeSizeInformation(OpBuilder &b, Location loc,
Value tensor) {
auto tensorType = tensor.getType().cast<RankedTensorType>();
auto rank = tensorType.getRank();
SmallVector<int64_t> unknownSizes(rank, kUnknownSize);
Expand Down
3 changes: 1 addition & 2 deletions lib/Conversion/TorchToLinalg/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,7 @@ LogicalResult broadcastToGivenShape(Operation *op, PatternRewriter &rewriter,

// Cast a tensor to a rank-equivalent tensor of unknown size, i.e. <1x2xf32> ->
// <?x?xf32>
Value dynamicCast(OpBuilder &b, Location loc, Value tensor);

Value removeSizeInformation(OpBuilder &b, Location loc, Value tensor);
} // namespace torch_to_linalg
} // namespace torch
} // namespace mlir
4 changes: 2 additions & 2 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -980,8 +980,8 @@ class DecomposeAtenConvTranspose2dOp
Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), true);
rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
op, op->getResultTypes(), op.input(), op.weight(), op.bias(),
op.stride(), op.padding(), op.dilation(), cstTrue, op.output_padding(),
op.groups());
op.stride(), op.padding(), op.dilation(), /*transposed=*/cstTrue,
op.output_padding(), op.groups());

return success();
}
Expand Down

0 comments on commit de25c49

Please sign in to comment.