Skip to content

Commit

Permalink
add general support for statically sized index tensors for aten.index…
Browse files Browse the repository at this point in the history
….Tensor

 - Includes a canonicalization for aten.add.t needed to properly lower
the shape function calculation
  • Loading branch information
qedawkins committed Jul 22, 2022
1 parent a02dbb2 commit dcd2c04
Show file tree
Hide file tree
Showing 10 changed files with 381 additions and 51 deletions.
1 change: 1 addition & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -6504,6 +6504,7 @@ def Torch_AtenAddTOp : Torch_Op<"aten.add.t", [
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
let hasCanonicalizer = 1;
}

def Torch_AtenEqIntListOp : Torch_Op<"aten.eq.int_list", [
Expand Down
196 changes: 161 additions & 35 deletions lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
Expand Down Expand Up @@ -244,13 +245,29 @@ class ConvertAtenIndexSelectOp : public OpConversionPattern<AtenIndexSelectOp> {
};
} // namespace

// IndexTensor for multiple input tensors broadcasts their shapes to a common
// shape and then replaces the indexed dims with the indices given by the
// indexing tensors:
// x[i_1, i_2, ..., i_M] = result
// result[...] = x[i_1[...], i_2[...], ..., i_M[...]]
//
// where the result shape is computed as follows:
// 1. broadcast i_1, i_2, ..., i_M to a common shape
// 2. if i_1, i_2, ..., i_M is not contiguous, transpose the broadcasted
// shape to the beginning of the result shape, while removing the
// unchanged dims (marked by None)
// 3. Otherwise replace the indexed dims with the broadcasted shape
//
// e.g. x: [2, 3]
// x[[4], [6, 1]] -> x[6, 4]
namespace {
class ConvertAtenIndexTensorOp : public OpConversionPattern<AtenIndexTensorOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenIndexTensorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();

Expand All @@ -266,78 +283,187 @@ class ConvertAtenIndexTensorOp : public OpConversionPattern<AtenIndexTensorOp> {
SmallVector<Value> indicesVal =
getTypeConvertedValues(rewriter, loc, getTypeConverter(), indicesTuple);

int indexTensorDim = -1;
// Identify the indices with non-None index tensors and determine if they
// are contiguous within the input list.
SmallVector<int> indexTensorDims;
bool contiguous = true;
for (auto i : llvm::seq(0, (int)indicesVal.size())) {
Value index = indicesVal[i];
if (!index || failed(checkNotNone(rewriter, op, index)))
continue;
if (indexTensorDim >= 0) {
return rewriter.notifyMatchFailure(
op, "unimplemented: only one index tensor allowed");
}
indexTensorDim = i;
if (indexTensorDims.size())
if (indexTensorDims.back() != i - 1)
contiguous = false;
indexTensorDims.push_back(i);
}

if (indexTensorDim == -1) {
if (!indexTensorDims.size()) {
return rewriter.notifyMatchFailure(
op, "unimplemented: index tensor must not be None");
}

Value indexTensor = indicesVal[indexTensorDim];
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
RankedTensorType indexTensorType =
indexTensor.getType().cast<RankedTensorType>();
RankedTensorType resultType = getTypeConverter()
->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
Type elementType = resultType.getElementType();
int inputRank = inputType.getRank();
int indexTensorRank = indexTensorType.getRank();
int resultRank = resultType.getRank();
int firstIndexDim = indexTensorDims[0];
int replacedIndexCount = indexTensorDims.size();

SmallVector<int> indexRanks;
SmallVector<Value> indexTensors;
for (auto i : indexTensorDims) {
indexTensors.push_back(indicesVal[i]);
RankedTensorType indexTensorType =
indexTensors.back().getType().cast<RankedTensorType>();
indexRanks.push_back(indexTensorType.getRank());
}

int maxRankIndex =
std::distance(indexRanks.begin(),
std::max_element(indexRanks.begin(), indexRanks.end()));
int maxRank = indexRanks[maxRankIndex];
Value maxRankIndexTensor = indexTensors[maxRankIndex];

SmallVector<Value> broadcastedIndexShape;

// Currently we only support statically sized index tensors
// when there is more than one index tensor.
// TODO: Add support for dynamic size index tensors. This will probably
// require broadcasting the index tensors to a common shape.
if (indexTensorDims.size() > 1) {
for (auto i : llvm::seq(0, (int)indexTensors.size())) {
Value indexTensor = indexTensors[i];
RankedTensorType indexTensorType =
indexTensor.getType().cast<RankedTensorType>();
auto indexTensorShape = indexTensorType.getShape();
if (llvm::any_of(indexTensorShape, ShapedType::isDynamic)) {
return rewriter.notifyMatchFailure(
op, "unimplemented: index tensors must have static shape if "
"there is more than one index tensor");
}
}

// The following ops from the mlir::shape dialect rely on the lack of
// dynamic shapes to properly fold
SmallVector<Value> indexTensorExtents;
for (auto indexTensor : indexTensors) {
indexTensorExtents.push_back(
rewriter.createOrFold<shape::ShapeOfOp>(loc, indexTensor));
}
Value broadcastedShape = rewriter.createOrFold<shape::BroadcastOp>(
loc, shape::getExtentTensorType(op->getContext(), maxRank),
indexTensorExtents);
for (auto i : llvm::seq(0, maxRank)) {
broadcastedIndexShape.push_back(
rewriter.createOrFold<shape::GetExtentOp>(loc, broadcastedShape,
i));
}
} else {
for (auto i : llvm::seq(0, maxRank)) {
broadcastedIndexShape.push_back(
getDimOp(rewriter, loc, maxRankIndexTensor, i));
}
}

// This result shape calculation assumes that there is only one
// index tensor of the input tensor. The calculation for arbitrary inputs is
// much more complex.
// index tensor, or all of the index tensors are statically shaped.
int broadcastRank = broadcastedIndexShape.size();

SmallVector<Value> resultShape;
for (auto i : llvm::seq(0, indexTensorDim)) {
resultShape.push_back(getDimOp(rewriter, loc, input, i));
}
for (auto i : llvm::seq(0, indexTensorRank)) {
resultShape.push_back(getDimOp(rewriter, loc, indexTensor, i));
}
for (auto i : llvm::seq(indexTensorDim + 1, inputRank)) {
resultShape.push_back(getDimOp(rewriter, loc, input, i));
if (contiguous) {
for (auto i : llvm::seq(0, firstIndexDim)) {
resultShape.push_back(getDimOp(rewriter, loc, input, i));
}
resultShape.append(broadcastedIndexShape);
for (auto i : llvm::seq((int)resultShape.size(), resultRank)) {
resultShape.push_back(getDimOp(rewriter, loc, input,
i - broadcastRank + replacedIndexCount));
}
} else {
resultShape.append(broadcastedIndexShape);
int j = 0;
for (auto i : llvm::seq(0, inputRank)) {
if (i == indexTensorDims[j]) {
j++;
continue;
}
resultShape.push_back(getDimOp(rewriter, loc, input, i));
}
}
int resultRank = resultShape.size();

// Initialize the indexing maps for the generic op. Because we are assuming
// static shapes for the indexing tensors when there are more than 1, we can
// safely map all size 1 dims to 0 in the corresponding affine maps.
// TODO: For dynamic shapes, we have to either broadcast the index tensors
// to a common shape or introduce some form of control flow.
Value initTensor =
rewriter.create<linalg::InitTensorOp>(loc, resultShape, elementType);
SmallVector<AffineExpr> indicesExpr, resultExpr;
SmallVector<AffineMap> indexingMaps;
SmallVector<StringRef> iteratorTypes;

for (auto i : llvm::seq(indexTensorDim, indexTensorDim + indexTensorRank))
indicesExpr.push_back(rewriter.getAffineDimExpr(i));
for (auto indexTensor : indexTensors) {
RankedTensorType indexTensorType =
indexTensor.getType().cast<RankedTensorType>();
auto indexTensorShape = indexTensorType.getShape();
int rank = indexTensorShape.size();
SmallVector<AffineExpr> indicesExpr;
for (auto dim : llvm::seq(0, rank)) {
if (indexTensorShape[dim] == 1) {
indicesExpr.push_back(rewriter.getAffineConstantExpr(0));
continue;
}
indicesExpr.push_back(rewriter.getAffineDimExpr(
contiguous ? firstIndexDim + maxRank - rank + dim
: maxRank - rank + dim));
}
indexingMaps.push_back(
AffineMap::get(resultRank, 0, indicesExpr, op->getContext()));
}

SmallVector<AffineExpr> resultExpr;
for (auto i : llvm::seq(0, resultRank)) {
resultExpr.push_back(rewriter.getAffineDimExpr(i));
iteratorTypes.push_back(getParallelIteratorTypeName());
}
auto indexingMaps = AffineMap::inferFromExprList({indicesExpr, resultExpr});

indexingMaps.push_back(
AffineMap::get(resultRank, 0, resultExpr, op->getContext()));

Value finalRes =
rewriter
.create<linalg::GenericOp>(
loc, initTensor.getType(), indexTensor, initTensor,
loc, initTensor.getType(), indexTensors, initTensor,
indexingMaps, iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value index = castIntToIndex(b, loc, args[0]);
SmallVector<Value> extractionIndices;
int extra_dims = 0;
for (auto i : llvm::seq(0, inputRank)) {
if (i == indexTensorDim) {
extractionIndices.push_back(index);
extra_dims += indexTensorRank - 1;
} else {
if (contiguous) {
for (auto i : llvm::seq(0, firstIndexDim)) {
extractionIndices.push_back(
b.create<linalg::IndexOp>(loc, i + extra_dims));
b.create<linalg::IndexOp>(loc, i));
}
for (auto i : llvm::seq(0, (int)indexTensorDims.size())) {
extractionIndices.push_back(
castIntToIndex(b, loc, args[i]));
}
for (auto i :
llvm::seq((int)extractionIndices.size(), inputRank)) {
extractionIndices.push_back(b.create<linalg::IndexOp>(
loc, i + broadcastRank - replacedIndexCount));
}
} else {
int indexCount = 0, unchanged = 0;
for (auto i : llvm::seq(0, inputRank)) {
if (i == indexTensorDims[indexCount]) {
extractionIndices.push_back(
castIntToIndex(b, loc, args[indexCount++]));
continue;
}
extractionIndices.push_back(b.create<linalg::IndexOp>(
loc, broadcastRank + unchanged));
unchanged++;
}
}
Value extractedElement = b.create<tensor::ExtractOp>(
Expand Down
5 changes: 4 additions & 1 deletion lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
Expand Down Expand Up @@ -45,6 +46,7 @@ class ConvertTorchToLinalg
registry.insert<tensor::TensorDialect>();
registry.insert<arith::ArithmeticDialect>();
registry.insert<cf::ControlFlowDialect>();
registry.insert<shape::ShapeDialect>();
TorchConversion::getBackendTypeConversionDependentDialects(registry);
}

Expand All @@ -53,7 +55,8 @@ class ConvertTorchToLinalg
ConversionTarget target(*context);
target.addLegalDialect<linalg::LinalgDialect, func::FuncDialect,
cf::ControlFlowDialect, math::MathDialect,
tensor::TensorDialect, arith::ArithmeticDialect>();
tensor::TensorDialect, arith::ArithmeticDialect,
shape::ShapeDialect>();
target.addLegalOp<TorchConversion::GetNextSeedOp>();

TypeConverter typeConverter;
Expand Down
29 changes: 29 additions & 0 deletions lib/Dialect/Torch/IR/TorchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1479,6 +1479,35 @@ void Aten__Getitem__TOp::getCanonicalizationPatterns(
});
}

//===----------------------------------------------------------------------===//
// AtenAddTOp
//===----------------------------------------------------------------------===//

void AtenAddTOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add(+[](AtenAddTOp op, PatternRewriter &rewriter) {
auto lhsListConstruct = op.a().getDefiningOp<Torch::PrimListConstructOp>();
if (!lhsListConstruct || isListPotentiallyMutated(lhsListConstruct))
return failure();

auto rhsListConstruct = op.b().getDefiningOp<Torch::PrimListConstructOp>();
if (!rhsListConstruct || isListPotentiallyMutated(rhsListConstruct))
return failure();

SmallVector<Value> concatenatedList;
for (auto a : lhsListConstruct.getOperands()) {
concatenatedList.push_back(a);
}
for (auto b : rhsListConstruct.getOperands()) {
concatenatedList.push_back(b);
}

rewriter.replaceOpWithNewOp<Torch::PrimListConstructOp>(op, op.getType(),
concatenatedList);
return success();
});
}

//===----------------------------------------------------------------------===//
// AtenEqIntListOp
//===----------------------------------------------------------------------===//
Expand Down
22 changes: 11 additions & 11 deletions lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6578,30 +6578,30 @@ module {
%10 = torch.aten.len.t %arg1 : !torch.list<optional<list<int>>> -> !torch.int
%11 = torch.prim.ListConstruct %int9223372036854775807, %10 : (!torch.int, !torch.int) -> !torch.list<int>
%12 = torch.prim.min.self_int %11 : !torch.list<int> -> !torch.int
%13:3 = torch.prim.Loop %12, %true, init(%true, %int-1, %int-1) {
^bb0(%arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.int, %arg5: !torch.int):
%13:2 = torch.prim.Loop %12, %true, init(%true, %int-1) {
^bb0(%arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.int):
%16 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list<optional<list<int>>>, !torch.int -> !torch.optional<list<int>>
%17 = torch.aten.__isnot__ %16, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
%18:3 = torch.prim.If %17 -> (!torch.bool, !torch.int, !torch.int) {
%18:2 = torch.prim.If %17 -> (!torch.bool, !torch.int) {
%19 = torch.aten.eq.int %arg4, %int-1 : !torch.int, !torch.int -> !torch.bool
%20:3 = torch.prim.If %19 -> (!torch.bool, !torch.int, !torch.int) {
torch.prim.If.yield %arg3, %arg2, %arg2 : !torch.bool, !torch.int, !torch.int
%20:2 = torch.prim.If %19 -> (!torch.bool, !torch.int) {
torch.prim.If.yield %arg3, %arg2 : !torch.bool, !torch.int
} else {
%21 = torch.aten.sub.int %arg2, %arg5 : !torch.int, !torch.int -> !torch.int
%21 = torch.aten.sub.int %arg2, %arg4 : !torch.int, !torch.int -> !torch.int
%22 = torch.aten.ne.int %21, %int1 : !torch.int, !torch.int -> !torch.bool
%23 = torch.prim.If %22 -> (!torch.bool) {
torch.prim.If.yield %false : !torch.bool
} else {
torch.prim.If.yield %arg3 : !torch.bool
}
torch.prim.If.yield %23, %arg4, %arg5 : !torch.bool, !torch.int, !torch.int
torch.prim.If.yield %23, %arg4 : !torch.bool, !torch.int
}
torch.prim.If.yield %20#0, %20#1, %20#2 : !torch.bool, !torch.int, !torch.int
torch.prim.If.yield %20#0, %20#1 : !torch.bool, !torch.int
} else {
torch.prim.If.yield %arg3, %arg4, %arg5 : !torch.bool, !torch.int, !torch.int
torch.prim.If.yield %arg3, %arg4 : !torch.bool, !torch.int
}
torch.prim.Loop.condition %true, iter(%18#0, %18#1, %18#2 : !torch.bool, !torch.int, !torch.int)
} : (!torch.int, !torch.bool, !torch.bool, !torch.int, !torch.int) -> (!torch.bool, !torch.int, !torch.int)
torch.prim.Loop.condition %true, iter(%18#0, %18#1 : !torch.bool, !torch.int)
} : (!torch.int, !torch.bool, !torch.bool, !torch.int) -> (!torch.bool, !torch.int)
%14 = torch.aten.__not__ %13#0 : !torch.bool -> !torch.bool
%15 = torch.prim.If %14 -> (!torch.list<int>) {
%16 = torch.aten.add.t %6, %4 : !torch.list<int>, !torch.list<int> -> !torch.list<int>
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ class SimplifyShapeCalculationsPass
Aten__Getitem__TOp::getCanonicalizationPatterns(patterns, context);
AtenSizeOp::getCanonicalizationPatterns(patterns, context);
AtenLenTOp::getCanonicalizationPatterns(patterns, context);
AtenAddTOp::getCanonicalizationPatterns(patterns, context);

// TODO: Debug visitation order to make this more efficient.
// A single linear scan should suffice.
Expand Down
Loading

0 comments on commit dcd2c04

Please sign in to comment.