Skip to content

Commit

Permalink
Move helper out to Utils.cpp
Browse files Browse the repository at this point in the history
  • Loading branch information
gpetters94 committed Jun 15, 2022
1 parent 298dff9 commit 03fd7d0
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 12 deletions.
14 changes: 2 additions & 12 deletions lib/Conversion/TorchToLinalg/Linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -544,14 +544,6 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
return rewriter.create<arith::IndexCastOp>(loc, intType, v);
};

auto dynamicCast = [&](Value tensor) {
SmallVector<int64_t> dynamicDims(inRank, kUnknownSize);
return rewriter.create<tensor::CastOp>(
loc,
RankedTensorType::get(llvm::makeArrayRef(dynamicDims), elementType),
tensor);
};

SmallVector<int64_t> paddingInts;
if (!matchPattern(op.padding(), m_TorchConstantIntList(paddingInts))) {
return rewriter.notifyMatchFailure(
Expand Down Expand Up @@ -670,10 +662,8 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
SmallVector<Value> insertSizes = getTensorSizes(rewriter, loc, input);

paddedInput = rewriter.create<tensor::InsertSliceOp>(
loc, dynamicCast(input), initTensor, offsets, insertSizes,
strideIndexValues);

SmallVector<Value> offsetPad(offsets.begin() + 2, offsets.end());
loc, torch_to_linalg::dynamicCast(rewriter, loc, input), initTensor,
offsets, insertSizes, strideIndexValues);

// Calculate output dims
for (size_t i = 0; i < inRank - 2; i++)
Expand Down
8 changes: 8 additions & 0 deletions lib/Conversion/TorchToLinalg/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"

Expand Down Expand Up @@ -281,3 +282,10 @@ Value torch_to_linalg::createElementwiseLinalgGeneric(
iteratorTypes, bodyBuild)
.getResult(0);
}

Value torch_to_linalg::dynamicCast(OpBuilder &b, Location loc, Value tensor) {
auto tensorType = tensor.getType().cast<RankedTensorType>();
auto rank = tensorType.getRank();
SmallVector<int64_t> unknownSizes(rank, kUnknownSize);
return b.create<tensor::CastOp>(loc, tensorType.clone(unknownSizes), tensor);
}
4 changes: 4 additions & 0 deletions lib/Conversion/TorchToLinalg/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ Value createElementwiseLinalgGeneric(
OpBuilder &b, Location loc, ValueRange tensorOperands,
Type resultElementType,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild);

// Cast a tensor to a rank-equivalent tensor of unknown size, i.e. <1x2xf32> ->
// <?x?xf32>
Value dynamicCast(OpBuilder &b, Location loc, Value tensor);
} // namespace torch_to_linalg
} // namespace torch
} // namespace mlir

0 comments on commit 03fd7d0

Please sign in to comment.