Skip to content

Commit

Permalink
Add transposed case for at::convolution
Browse files Browse the repository at this point in the history
  • Loading branch information
gpetters94 committed Jun 20, 2022
1 parent a34dad2 commit da81a2f
Show file tree
Hide file tree
Showing 11 changed files with 591 additions and 27 deletions.
90 changes: 90 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -3083,6 +3083,96 @@ def Torch_AtenConv2dOp : Torch_Op<"aten.conv2d", [
}];
}

// Auto-generated op definition (this file is GeneratedTorchOps.td) wrapping
// the PyTorch JIT schema shown in `summary`:
//   aten::conv_transpose1d : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)
// Traits: refinable result types, value semantics, and no side effects.
def Torch_AtenConvTranspose1dOp : Torch_Op<"aten.conv_transpose1d", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::conv_transpose1d : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)`";
// Operand order mirrors the JIT schema above; note `output_padding` precedes
// `groups` and `dilation` comes last, unlike aten::convolution.
let arguments = (ins
AnyTorchTensorType:$input,
AnyTorchTensorType:$weight,
AnyTorchOptionalTensorType:$bias,
AnyTorchListOfTorchIntType:$stride,
AnyTorchListOfTorchIntType:$padding,
AnyTorchListOfTorchIntType:$output_padding,
Torch_IntType:$groups,
AnyTorchListOfTorchIntType:$dilation
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
// Parse/print delegate to the shared default Torch-op format with
// 8 operands and 1 result.
let extraClassDefinition = [{
ParseResult AtenConvTranspose1dOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 8, 1);
}
void AtenConvTranspose1dOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 8, 1);
}
}];
}

// Auto-generated op definition (this file is GeneratedTorchOps.td) wrapping
// the PyTorch JIT schema shown in `summary`:
//   aten::conv_transpose2d.input : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)
// Traits: refinable result types, value semantics, and no side effects.
def Torch_AtenConvTranspose2dInputOp : Torch_Op<"aten.conv_transpose2d.input", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::conv_transpose2d.input : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)`";
// Operand order mirrors the JIT schema above; note `output_padding` precedes
// `groups` and `dilation` comes last, unlike aten::convolution.
let arguments = (ins
AnyTorchTensorType:$input,
AnyTorchTensorType:$weight,
AnyTorchOptionalTensorType:$bias,
AnyTorchListOfTorchIntType:$stride,
AnyTorchListOfTorchIntType:$padding,
AnyTorchListOfTorchIntType:$output_padding,
Torch_IntType:$groups,
AnyTorchListOfTorchIntType:$dilation
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
// Parse/print delegate to the shared default Torch-op format with
// 8 operands and 1 result.
let extraClassDefinition = [{
ParseResult AtenConvTranspose2dInputOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 8, 1);
}
void AtenConvTranspose2dInputOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 8, 1);
}
}];
}

// Auto-generated op definition (this file is GeneratedTorchOps.td) wrapping
// the PyTorch JIT schema shown in `summary`:
//   aten::conv_transpose3d.input : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)
// Traits: refinable result types, value semantics, and no side effects.
def Torch_AtenConvTranspose3dInputOp : Torch_Op<"aten.conv_transpose3d.input", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::conv_transpose3d.input : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)`";
// Operand order mirrors the JIT schema above; note `output_padding` precedes
// `groups` and `dilation` comes last, unlike aten::convolution.
let arguments = (ins
AnyTorchTensorType:$input,
AnyTorchTensorType:$weight,
AnyTorchOptionalTensorType:$bias,
AnyTorchListOfTorchIntType:$stride,
AnyTorchListOfTorchIntType:$padding,
AnyTorchListOfTorchIntType:$output_padding,
Torch_IntType:$groups,
AnyTorchListOfTorchIntType:$dilation
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
// Parse/print delegate to the shared default Torch-op format with
// 8 operands and 1 result.
let extraClassDefinition = [{
ParseResult AtenConvTranspose3dInputOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 8, 1);
}
void AtenConvTranspose3dInputOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 8, 1);
}
}];
}

def Torch_AtenConvolutionOp : Torch_Op<"aten.convolution", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
152 changes: 129 additions & 23 deletions lib/Conversion/TorchToLinalg/Linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"

#include "../PassDetail.h"
Expand All @@ -20,8 +21,12 @@
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/ADT/ArrayRef.h"
#include <algorithm>
#include <bits/stdint-intn.h>

using namespace mlir;
using namespace mlir::torch;
Expand Down Expand Up @@ -534,12 +539,18 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
Value input = adaptor.input(); /* in form of N*C*H*W */
Value weight = adaptor.weight(); /* in form of F*C*H*W */

bool transposed = true;
if (!matchPattern(op.transposed(), m_TorchConstantBool(&transposed)))
return rewriter.notifyMatchFailure(
op, "unimplemented: only constant transposed supported");

Type elementType =
input.getType().cast<RankedTensorType>().getElementType();
if (!elementType.isa<mlir::FloatType>())
return op.emitError("unimplemented: non-floating point type");
size_t inRank = input.getType().cast<RankedTensorType>().getRank();
if (inRank != 4)
size_t numSpacialDims = inRank - 2;
if (numSpacialDims != 2)
return rewriter.notifyMatchFailure(
op, "unimplemented: only 2D convolution currently supported");

Expand All @@ -563,6 +574,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
"only support constant int dilations");

Value N = getDimOp(rewriter, loc, input, 0);
Value inChannels = getDimOp(rewriter, loc, input, 1);
SmallVector<Value> inDims;
for (size_t i = 2; i < inRank; i++)
inDims.push_back(getDimOp(rewriter, loc, input, i));
Expand All @@ -571,37 +583,131 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
for (size_t i = 2; i < inRank; i++)
weightDims.push_back(getDimOp(rewriter, loc, weight, i));

// Guard unused values (transposed, groups)
SmallVector<Value> dilationIntValues =
getAsConstantIntValues(rewriter, loc, dilationInts);
SmallVector<Value> paddingIntValues =
getAsConstantIntValues(rewriter, loc, paddingInts);
SmallVector<Value> strideIntValues =
getAsConstantIntValues(rewriter, loc, strideInts);

// Guard unused values (groups)
int64_t group_size;
if (!matchPattern(op.groups(), m_TorchConstantInt(&group_size)) ||
group_size != 1)
return rewriter.notifyMatchFailure(
op, "unimplemented: only group size of 1 supported");
bool transposed = true;
if (!matchPattern(op.transposed(), m_TorchConstantBool(&transposed)) ||
transposed)
return rewriter.notifyMatchFailure(
op, "unimplemented: only non-transposed convolution supported");

// Pad the input tensor according to padding.
SmallVector<int64_t, 4> paddingIncludingNC = {0, 0};
paddingIncludingNC.insert(paddingIncludingNC.end(), paddingInts.begin(),
paddingInts.end());
Value paddedInput = torch_to_linalg::getZeroPaddedTensor(
op, rewriter, input, paddingIncludingNC);
SmallVector<Value> outDims{N, F};
Value paddedInput;
if (transposed) {
Value c0 =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
Value c1 =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
Value c2 =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(2));

// Transpose and flip weight
SmallVector<Value> weightInitDims = getTensorSizes(rewriter, loc, weight);
std::iter_swap(weightInitDims.begin(), weightInitDims.begin() + 1);
outDims[1] = weightInitDims[0];
Value weightInitTensor =
createZeroInitTensor(rewriter, loc, weightInitDims, elementType);
SmallVector<StringRef> iteratorTypes(inRank,
getParallelIteratorTypeName());
SmallVector<AffineMap> indexingMaps(
2, AffineMap::getMultiDimIdentityMap(inRank, context));
weight = rewriter
.create<linalg::GenericOp>(
loc, weightInitTensor.getType(), weight,
weightInitTensor, indexingMaps, iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
SmallVector<Value> indices;
for (size_t i = 0; i < inRank; i++)
indices.push_back(b.create<linalg::IndexOp>(loc, i));
std::iter_swap(indices.begin(), indices.begin() + 1);
// Flip only the spatial dimensions (from 2 to inRank)
for (size_t flipDim = 2; flipDim < inRank; flipDim++) {
indices[flipDim] = b.create<arith::SubIOp>(
loc,
b.create<arith::SubIOp>(
loc, weightInitDims[flipDim], c1),
indices[flipDim]);
}
Value res =
b.create<tensor::ExtractOp>(loc, weight, indices)
.getResult();
b.create<linalg::YieldOp>(loc, res);
})
.getResult(0);

SmallVector<Value> paddingIntValues =
getAsConstantIntValues(rewriter, loc, paddingInts);
SmallVector<Value> dilationIntValues =
getAsConstantIntValues(rewriter, loc, dilationInts);
SmallVector<Value> strideIntValues =
getAsConstantIntValues(rewriter, loc, strideInts);
// Calculate padded input size, allocate tensor
SmallVector<Value> outerSizes{N, inChannels};
SmallVector<Value> innerSizes{N, inChannels};
SmallVector<Value> offsets{c0, c0};
for (size_t i = 0; i < numSpacialDims; i++) {
Value innerSize = rewriter.create<arith::SubIOp>(loc, inDims[i], c1);
innerSize = rewriter.create<arith::MulIOp>(
loc, innerSize, castIntToIndex(rewriter, loc, strideIntValues[i]));
innerSize = rewriter.create<arith::AddIOp>(loc, innerSize, c1);

Value offset = rewriter.create<arith::SubIOp>(loc, weightDims[i], c1);
offset = rewriter.create<arith::MulIOp>(
loc, offset, castIntToIndex(rewriter, loc, dilationIntValues[i]));
offset = rewriter.create<arith::SubIOp>(
loc, offset, castIntToIndex(rewriter, loc, paddingIntValues[i]));

Value outerSize = rewriter.create<arith::MulIOp>(loc, offset, c2);
outerSize = rewriter.create<arith::AddIOp>(loc, outerSize, innerSize);

innerSizes.push_back(innerSize);
outerSizes.push_back(outerSize);
offsets.push_back(offset);
}

SmallVector<Value> outDims{N, F};
for (size_t i = 0; i < inRank - 2; i++)
outDims.push_back(torch_to_linalg::getOutputDimForConvOps(
rewriter, loc, inDims[i], paddingIntValues[i], dilationIntValues[i],
castIndexToInt(weightDims[i]), strideIntValues[i]));
// Allocate padded input tensor
Value initTensor =
createZeroInitTensor(rewriter, loc, outerSizes, elementType);

// Insert input into allocated tensor
SmallVector<Value> strideIndexValues{c1, c1};
for (auto stride : strideIntValues)
strideIndexValues.push_back(castIntToIndex(rewriter, loc, stride));
SmallVector<Value> insertSizes = getTensorSizes(rewriter, loc, input);

paddedInput = rewriter.create<tensor::InsertSliceOp>(
loc, torch_to_linalg::dynamicCast(rewriter, loc, input), initTensor,
offsets, insertSizes, strideIndexValues);

// Calculate output dims
for (size_t i = 0; i < numSpacialDims; i++)
outDims.push_back(torch_to_linalg::getOutputDimForConvTransposeOps(
rewriter, loc, inDims[i], paddingIntValues[i], dilationIntValues[i],
castIndexToInt(weightDims[i]), strideIntValues[i]));

// Set stride to 1
strideInts.clear();
strideInts.append(numSpacialDims, 1);
strideIntValues = getAsConstantIntValues(rewriter, loc, strideInts);

paddingIntValues.clear();
for (auto pad = offsets.begin() + 2; pad < offsets.end(); pad++)
paddingIntValues.push_back(castIndexToInt(*pad));
} else {
// Pad input
SmallVector<int64_t, 4> paddingIncludingNC = {0, 0};
paddingIncludingNC.insert(paddingIncludingNC.end(), paddingInts.begin(),
paddingInts.end());
paddedInput = torch_to_linalg::getZeroPaddedTensor(op, rewriter, input,
paddingIncludingNC);

// Calculate output dims
for (size_t i = 0; i < numSpacialDims; i++)
outDims.push_back(torch_to_linalg::getOutputDimForConvOps(
rewriter, loc, inDims[i], paddingIntValues[i], dilationIntValues[i],
castIndexToInt(weightDims[i]), strideIntValues[i]));
}

Value initTensor =
rewriter.create<linalg::InitTensorOp>(loc, outDims, elementType);
Expand Down
33 changes: 33 additions & 0 deletions lib/Conversion/TorchToLinalg/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"

Expand Down Expand Up @@ -97,6 +98,31 @@ Value torch_to_linalg::getOutputDimForConvOps(OpBuilder &b, Location loc,
return castIntToIndex(b, loc, out);
}

// Emit IR computing the spatial output size of a transposed convolution
// along one dimension:
//   out = (in - 1) * stride - 2 * padding + dilation * (kernelSize - 1) + 1
// `in` is index-typed; the arithmetic is performed on i64 values and the
// result is cast back to index. The other operands are assumed to already
// be i64-typed Values.
Value torch_to_linalg::getOutputDimForConvTransposeOps(
    OpBuilder &b, Location loc, Value in, Value paddingInt, Value dilationInt,
    Value kernelSizeInt, Value strideInt) {
  Value one = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(1));
  Value two = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(2));

  // (in - 1) * stride
  Value inMinusOne =
      b.create<arith::SubIOp>(loc, castIndexToInt64(b, loc, in), one);
  Value strided = b.create<arith::MulIOp>(loc, inMinusOne, strideInt);

  // 2 * padding
  Value totalPadding = b.create<arith::MulIOp>(loc, paddingInt, two);

  // (kernelSize - 1) * dilation
  Value kernelMinusOne = b.create<arith::SubIOp>(loc, kernelSizeInt, one);
  Value effectiveKernel =
      b.create<arith::MulIOp>(loc, kernelMinusOne, dilationInt);

  // Combine the three terms and add the trailing +1.
  Value result = b.create<arith::SubIOp>(loc, strided, totalPadding);
  result = b.create<arith::AddIOp>(loc, result, effectiveKernel);
  result = b.create<arith::AddIOp>(loc, result, one);

  return castIntToIndex(b, loc, result);
}

Value torch_to_linalg::createReductionLinalgGeneric(
OpBuilder &b, Location loc, const ReductionOpInfo &opInfo, Value initElem,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
Expand Down Expand Up @@ -256,3 +282,10 @@ Value torch_to_linalg::createElementwiseLinalgGeneric(
iteratorTypes, bodyBuild)
.getResult(0);
}

// Insert a tensor.cast from `tensor` to a tensor of the same rank and
// element type whose dimensions are all dynamic (kUnknownSize), e.g.
// tensor<1x2xf32> -> tensor<?x?xf32>.
Value torch_to_linalg::dynamicCast(OpBuilder &b, Location loc, Value tensor) {
  auto srcType = tensor.getType().cast<RankedTensorType>();
  SmallVector<int64_t> dynamicShape(srcType.getRank(), kUnknownSize);
  auto dstType = srcType.clone(dynamicShape);
  return b.create<tensor::CastOp>(loc, dstType, tensor);
}
12 changes: 12 additions & 0 deletions lib/Conversion/TorchToLinalg/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,14 @@ Value getOutputDimForConvOps(OpBuilder &b, Location loc, Value in,
Value kernelSizeInt, Value strideInt,
bool ceilMode = false);

// As above but for transposed convolution ops
// Along each dim:
// dim_out =
// (dim_in - 1) * stride - 2 * padding + dilation * (kernelSize - 1) + 1
Value getOutputDimForConvTransposeOps(OpBuilder &b, Location loc, Value in,
Value paddingInt, Value dilationInt,
Value kernelSizeInt, Value strideInt);

// Create a reduction of `opInfo.tensorOperand`, reducing along the dimensions
// in `opInfo.dimSet`. If `opInfo.keepDim` is true, the output tensor is the
// same rank as the `opInfo.tensorOperand` and reduced dimensions are set to
Expand All @@ -54,6 +62,10 @@ Value createElementwiseLinalgGeneric(
OpBuilder &b, Location loc, ValueRange tensorOperands,
Type resultElementType,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild);

// Cast a tensor to a rank-equivalent tensor of unknown size, i.e. <1x2xf32> ->
// <?x?xf32>
Value dynamicCast(OpBuilder &b, Location loc, Value tensor);
} // namespace torch_to_linalg
} // namespace torch
} // namespace mlir
Loading

0 comments on commit da81a2f

Please sign in to comment.