Skip to content

Commit

Permalink
Add transposed case for at::convolution
Browse files Browse the repository at this point in the history
  • Loading branch information
gpetters94 committed Aug 23, 2022
1 parent 8cad02f commit 063feb2
Show file tree
Hide file tree
Showing 11 changed files with 554 additions and 27 deletions.
90 changes: 90 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -3347,6 +3347,96 @@ def Torch_AtenConv2dOp : Torch_Op<"aten.conv2d", [
}];
}

def Torch_AtenConvTranspose1dOp : Torch_Op<"aten.conv_transpose1d", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::conv_transpose1d : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
AnyTorchTensorType:$weight,
AnyTorchOptionalTensorType:$bias,
AnyTorchListOfTorchIntType:$stride,
AnyTorchListOfTorchIntType:$padding,
AnyTorchListOfTorchIntType:$output_padding,
Torch_IntType:$groups,
AnyTorchListOfTorchIntType:$dilation
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenConvTranspose1dOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 8, 1);
}
void AtenConvTranspose1dOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 8, 1);
}
}];
}

def Torch_AtenConvTranspose2dInputOp : Torch_Op<"aten.conv_transpose2d.input", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::conv_transpose2d.input : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
AnyTorchTensorType:$weight,
AnyTorchOptionalTensorType:$bias,
AnyTorchListOfTorchIntType:$stride,
AnyTorchListOfTorchIntType:$padding,
AnyTorchListOfTorchIntType:$output_padding,
Torch_IntType:$groups,
AnyTorchListOfTorchIntType:$dilation
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenConvTranspose2dInputOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 8, 1);
}
void AtenConvTranspose2dInputOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 8, 1);
}
}];
}

def Torch_AtenConvTranspose3dInputOp : Torch_Op<"aten.conv_transpose3d.input", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::conv_transpose3d.input : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
AnyTorchTensorType:$weight,
AnyTorchOptionalTensorType:$bias,
AnyTorchListOfTorchIntType:$stride,
AnyTorchListOfTorchIntType:$padding,
AnyTorchListOfTorchIntType:$output_padding,
Torch_IntType:$groups,
AnyTorchListOfTorchIntType:$dilation
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenConvTranspose3dInputOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 8, 1);
}
void AtenConvTranspose3dInputOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 8, 1);
}
}];
}

def Torch_AtenConvolutionOp : Torch_Op<"aten.convolution", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
138 changes: 115 additions & 23 deletions lib/Conversion/TorchToLinalg/Linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include <algorithm>

using namespace mlir;
using namespace mlir::torch;
Expand Down Expand Up @@ -635,12 +636,18 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
Value input = adaptor.input(); /* in form of N*C*H*W */
Value weight = adaptor.weight(); /* in form of F*C*H*W */

bool transposed = true;
if (!matchPattern(op.transposed(), m_TorchConstantBool(&transposed)))
return rewriter.notifyMatchFailure(
op, "unimplemented: only constant transposed supported");

Type elementType =
input.getType().cast<RankedTensorType>().getElementType();
if (!elementType.isa<mlir::FloatType>())
return op.emitError("unimplemented: non-floating point type");
size_t inRank = input.getType().cast<RankedTensorType>().getRank();
if (inRank != 4)
size_t numSpacialDims = inRank - 2;
if (numSpacialDims != 2)
return rewriter.notifyMatchFailure(
op, "unimplemented: only 2D convolution currently supported");

Expand Down Expand Up @@ -674,13 +681,6 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
for (size_t i = 2; i < inRank; i++)
weightDims.push_back(getDimOp(rewriter, loc, weight, i));

// Guard unused values (transposed)
bool transposed = true;
if (!matchPattern(op.transposed(), m_TorchConstantBool(&transposed)) ||
transposed)
return rewriter.notifyMatchFailure(
op, "unimplemented: only non-transposed convolution supported");

// Checks for valid group size
int64_t groupSize;
if (!matchPattern(op.groups(), m_TorchConstantInt(&groupSize)))
Expand All @@ -701,19 +701,119 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
"invalid: groups must divide input channel size evenly.");
validate(weightBatch,
"invalid: groups must divide weight batch size evenly.");

SmallVector<Value> paddingIntValues =
getAsConstantIntValues(rewriter, loc, paddingInts);
SmallVector<Value> dilationIntValues =
getAsConstantIntValues(rewriter, loc, dilationInts);
SmallVector<Value> paddingIntValues =
getAsConstantIntValues(rewriter, loc, paddingInts);
SmallVector<Value> strideIntValues =
getAsConstantIntValues(rewriter, loc, strideInts);

// Pad the input tensor according to padding.
SmallVector<Value> outDims{inBatch, weightBatch};
for (size_t i = 0; i < inRank - 2; i++)
outDims.push_back(torch_to_linalg::getOutputDimForConvOps(
rewriter, loc, inDims[i], paddingIntValues[i], dilationIntValues[i],
castIndexToInt(weightDims[i]), strideIntValues[i]));
Value paddedInput;
if (transposed) {
Value c0 =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
Value c1 =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
Value c2 =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(2));

// Transpose and flip weight
SmallVector<Value> weightInitDims = getTensorSizes(rewriter, loc, weight);
std::iter_swap(weightInitDims.begin(), weightInitDims.begin() + 1);
outDims[1] = weightInitDims[0];
Value weightInitTensor =
createZeroInitTensor(rewriter, loc, weightInitDims, elementType);
SmallVector<StringRef> iteratorTypes(inRank,
getParallelIteratorTypeName());
SmallVector<AffineMap> indexingMaps(
2, AffineMap::getMultiDimIdentityMap(inRank, context));
weight = rewriter
.create<linalg::GenericOp>(
loc, weightInitTensor.getType(), weight,
weightInitTensor, indexingMaps, iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
SmallVector<Value> indices;
for (size_t i = 0; i < inRank; i++)
indices.push_back(b.create<linalg::IndexOp>(loc, i));
std::iter_swap(indices.begin(), indices.begin() + 1);
// Flip only the spatial dimensions (from 2 to inRank)
for (size_t flipDim = 2; flipDim < inRank; flipDim++) {
indices[flipDim] = b.create<arith::SubIOp>(
loc,
b.create<arith::SubIOp>(
loc, weightInitDims[flipDim], c1),
indices[flipDim]);
}
Value res =
b.create<tensor::ExtractOp>(loc, weight, indices)
.getResult();
b.create<linalg::YieldOp>(loc, res);
})
.getResult(0);

// Calculate padded input size, allocate tensor
SmallVector<Value> outerSizes{inBatch, inChannels};
SmallVector<Value> innerSizes{inBatch, inChannels};
SmallVector<Value> offsets{c0, c0};
for (size_t i = 0; i < numSpacialDims; i++) {
Value innerSize = rewriter.create<arith::SubIOp>(loc, inDims[i], c1);
innerSize = rewriter.create<arith::MulIOp>(
loc, innerSize, castIntToIndex(rewriter, loc, strideIntValues[i]));
innerSize = rewriter.create<arith::AddIOp>(loc, innerSize, c1);

Value offset = rewriter.create<arith::SubIOp>(loc, weightDims[i], c1);
offset = rewriter.create<arith::MulIOp>(
loc, offset, castIntToIndex(rewriter, loc, dilationIntValues[i]));
offset = rewriter.create<arith::SubIOp>(
loc, offset, castIntToIndex(rewriter, loc, paddingIntValues[i]));

Value outerSize = rewriter.create<arith::MulIOp>(loc, offset, c2);
outerSize = rewriter.create<arith::AddIOp>(loc, outerSize, innerSize);

outerSizes.push_back(outerSize);
offsets.push_back(offset);
}

// Allocate padded input tensor
Value initTensor =
createZeroInitTensor(rewriter, loc, outerSizes, elementType);

// Insert input into allocated tensor
SmallVector<Value> strideIndexValues{c1, c1};
for (auto stride : strideIntValues)
strideIndexValues.push_back(castIntToIndex(rewriter, loc, stride));
SmallVector<Value> insertSizes = getTensorSizes(rewriter, loc, input);

paddedInput = rewriter.create<tensor::InsertSliceOp>(
loc, torch_to_linalg::removeSizeInformation(rewriter, loc, input),
initTensor, offsets, insertSizes, strideIndexValues);

// Calculate output dims
for (size_t i = 0; i < numSpacialDims; i++)
outDims.push_back(torch_to_linalg::getOutputDimForConvTransposeOps(
rewriter, loc, inDims[i], paddingIntValues[i], dilationIntValues[i],
castIndexToInt(weightDims[i]), strideIntValues[i]));

// Set stride to 1
strideInts.clear();
strideInts.append(numSpacialDims, 1);

} else {
// Pad input
SmallVector<int64_t, 4> paddingIncludingNC = {0, 0};
paddingIncludingNC.insert(paddingIncludingNC.end(), paddingInts.begin(),
paddingInts.end());
paddedInput = torch_to_linalg::getZeroPaddedTensor(op, rewriter, input,
paddingIncludingNC);

// Calculate output dims
for (size_t i = 0; i < numSpacialDims; i++)
outDims.push_back(torch_to_linalg::getOutputDimForConvOps(
rewriter, loc, inDims[i], paddingIntValues[i], dilationIntValues[i],
castIndexToInt(weightDims[i]), strideIntValues[i]));
}

Value initTensor =
rewriter.create<linalg::InitTensorOp>(loc, outDims, elementType);
Expand Down Expand Up @@ -769,14 +869,6 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
SmallVector<Value> weightSliceSizes{weightStride, weightChannels};
weightSliceSizes.append(weightDims);

// Pad the input tensor according to padding.
SmallVector<int64_t, 4> paddingIncludingNC = {0, 0};
paddingIncludingNC.append(paddingInts);

// Pad inputSlice
Value paddedInput = torch_to_linalg::getZeroPaddedTensor(
op, rewriter, input, paddingIncludingNC);

Value conv;
if (groupSize == 1) {
// TODO: add 1D and 3D case
Expand Down
33 changes: 33 additions & 0 deletions lib/Conversion/TorchToLinalg/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,31 @@ Value torch_to_linalg::getOutputDimForConvOps(OpBuilder &b, Location loc,
return castIntToIndex(b, loc, out);
}

Value torch_to_linalg::getOutputDimForConvTransposeOps(
OpBuilder &b, Location loc, Value in, Value paddingInt, Value dilationInt,
Value kernelSizeInt, Value strideInt) {
Value c1 = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(1));
Value c2 = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(2));

// (in - 1) * stride
Value inStrided =
b.create<arith::SubIOp>(loc, castIndexToInt64(b, loc, in), c1);
inStrided = b.create<arith::MulIOp>(loc, inStrided, strideInt);

// 2 * padding
Value doublePadding = b.create<arith::MulIOp>(loc, paddingInt, c2);

// (kernelSize - 1) * dilation
Value kernelDilated = b.create<arith::SubIOp>(loc, kernelSizeInt, c1);
kernelDilated = b.create<arith::MulIOp>(loc, kernelDilated, dilationInt);

Value out = b.create<arith::SubIOp>(loc, inStrided, doublePadding);
out = b.create<arith::AddIOp>(loc, out, kernelDilated);
out = b.create<arith::AddIOp>(loc, out, c1);

return castIntToIndex(b, loc, out);
}

Value torch_to_linalg::createReductionLinalgGeneric(
OpBuilder &b, Location loc, const ReductionOpInfo &opInfo, Value initElem,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
Expand Down Expand Up @@ -338,3 +363,11 @@ LogicalResult torch_to_linalg::broadcastToGivenShape(

return success();
}

Value torch_to_linalg::removeSizeInformation(OpBuilder &b, Location loc,
Value tensor) {
auto tensorType = tensor.getType().cast<RankedTensorType>();
auto rank = tensorType.getRank();
SmallVector<int64_t> unknownSizes(rank, kUnknownSize);
return b.create<tensor::CastOp>(loc, tensorType.clone(unknownSizes), tensor);
}
11 changes: 11 additions & 0 deletions lib/Conversion/TorchToLinalg/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,14 @@ Value getOutputDimForConvOps(OpBuilder &b, Location loc, Value in,
Value kernelSizeInt, Value strideInt,
bool ceilMode = false);

// As above but for transposed convolution ops
// Along each dim:
// dim_out =
// (dim_in - 1) * stride - 2 * padding + dilation * (kernelSize - 1) + 1
Value getOutputDimForConvTransposeOps(OpBuilder &b, Location loc, Value in,
Value paddingInt, Value dilationInt,
Value kernelSizeInt, Value strideInt);

// Create a reduction of `opInfo.tensorOperand`, reducing along the dimensions
// in `opInfo.dimSet`. If `opInfo.keepDim` is true, the output tensor is the
// same rank as the `opInfo.tensorOperand` and reduced dimensions are set to
Expand All @@ -61,6 +69,9 @@ LogicalResult broadcastToGivenShape(Operation *op, PatternRewriter &rewriter,
SmallVector<Value> broadcastToShape,
Value &result);

// Cast a tensor to a rank-equivalent tensor of unknown size, i.e. <1x2xf32> ->
// <?x?xf32>
Value removeSizeInformation(OpBuilder &b, Location loc, Value tensor);
} // namespace torch_to_linalg
} // namespace torch
} // namespace mlir
22 changes: 22 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1031,6 +1031,26 @@ class DecomposeAtenConv2dOp : public OpRewritePattern<AtenConv2dOp> {
};
} // namespace

// Decompose aten.conv_transpose2d to aten.convolution
namespace {
class DecomposeAtenConvTranspose2dOp
: public OpRewritePattern<AtenConvTranspose2dInputOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenConvTranspose2dInputOp op,
PatternRewriter &rewriter) const override {

Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), true);
rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
op, op->getResultTypes(), op.input(), op.weight(), op.bias(),
op.stride(), op.padding(), op.dilation(), /*transposed=*/cstTrue,
op.output_padding(), op.groups());

return success();
}
};
} // namespace

// Decompose aten.addmm into aten.mm and aten.add.Tensor op.
namespace {
class DecomposeAtenAddmmOp : public OpRewritePattern<AtenAddmmOp> {
Expand Down Expand Up @@ -2613,6 +2633,8 @@ class DecomposeComplexOpsPass
context);
target.addIllegalOp<AtenConv2dOp>();
patterns.add<DecomposeAtenConv2dOp>(context);
target.addIllegalOp<AtenConvTranspose2dInputOp>();
patterns.add<DecomposeAtenConvTranspose2dOp>(context);
patterns.add<DecomposeAtenArangeOp>(context);
target.addIllegalOp<AtenArangeOp>();
patterns.add<DecomposeAtenArangeStartOp>(context);
Expand Down
3 changes: 2 additions & 1 deletion lib/Dialect/Torch/Transforms/RefineTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -712,7 +712,8 @@ void TypeAnalysis::visitOperation(Operation *op,

// Promote the two dtypes assuming non-zero rank.
if (isa<AtenMmOp, AtenBmmOp, AtenMatmulOp, AtenConv2dOp, AtenConvolutionOp,
Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp, AtenConvolutionOverrideableOp>(op)) {
Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp,
AtenConvolutionOverrideableOp, AtenConvTranspose2dInputOp>(op)) {
auto knowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(
Expand Down
Loading

0 comments on commit 063feb2

Please sign in to comment.