[MLIR][TORCH] Add support for multiple indexing tensors for aten.index.Tensor #1097

Merged · 1 commit · Jul 28, 2022
1 change: 1 addition & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -6619,6 +6619,7 @@ def Torch_AtenAddTOp : Torch_Op<"aten.add.t", [
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
let hasCanonicalizer = 1;
}

def Torch_AtenEqIntListOp : Torch_Op<"aten.eq.int_list", [
175 changes: 139 additions & 36 deletions lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp
@@ -244,13 +244,29 @@ class ConvertAtenIndexSelectOp : public OpConversionPattern<AtenIndexSelectOp> {
};
} // namespace

// IndexTensor for multiple input tensors broadcasts their shapes to a common
// shape and then replaces the indexed dims with the indices given by the
// indexing tensors:
// x[i_1, i_2, ..., i_M] = result
// result[...] = x[i_1[...], i_2[...], ..., i_M[...]]
//
// where the result shape is computed as follows:
// 1. broadcast i_1, i_2, ..., i_M to a common shape
// 2. if i_1, i_2, ..., i_M is not contiguous, transpose the broadcasted
// shape to the beginning of the result shape, while removing the
// unchanged dims (marked by None)
// 3. Otherwise replace the indexed dims with the broadcasted shape
//
// e.g. x: [2, 3]
// x[[4], [6, 1]] -> x[6, 4]
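
(Editor's aside, not part of this diff: the shape rules described in the comment above can be checked directly in eager PyTorch. The tensor contents below are placeholders; only the shapes matter.)

import torch

x = torch.rand(2, 3)
i1 = torch.zeros(4, dtype=torch.long)       # index tensor of shape [4]
i2 = torch.zeros(6, 1, dtype=torch.long)    # index tensor of shape [6, 1]

# Contiguous case: both dims are indexed, so the index tensors broadcast to
# [6, 4] and that becomes the result shape, matching the comment's example.
print(x[i1, i2].shape)       # torch.Size([6, 4])

# Non-contiguous case: dims 0 and 2 are indexed, dim 1 is untouched, so the
# broadcasted index shape moves to the front and the unchanged dim follows.
y = torch.rand(2, 3, 4)
print(y[i1, :, i2].shape)    # torch.Size([6, 4, 3])
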
namespace {
class ConvertAtenIndexTensorOp : public OpConversionPattern<AtenIndexTensorOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenIndexTensorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();

@@ -266,78 +282,165 @@ class ConvertAtenIndexTensorOp : public OpConversionPattern<AtenIndexTensorOp> {
SmallVector<Value> indicesVal =
getTypeConvertedValues(rewriter, loc, getTypeConverter(), indicesTuple);

int indexTensorDim = -1;
// Identify the indices with non-None index tensors and determine if they
// are contiguous within the input list.
SmallVector<int> indexTensorDims;
SmallVector<Value> indexTensors;
bool contiguous = true;
for (auto i : llvm::seq(0, (int)indicesVal.size())) {
Value index = indicesVal[i];
if (!index || failed(checkNotNone(rewriter, op, index)))
continue;
if (indexTensorDim >= 0) {
return rewriter.notifyMatchFailure(
op, "unimplemented: only one index tensor allowed");
}
indexTensorDim = i;
if (!indexTensorDims.empty() && indexTensorDims.back() != i - 1)
contiguous = false;
indexTensorDims.push_back(i);
indexTensors.push_back(index);
}

if (indexTensorDim == -1) {
if (indexTensors.empty()) {
return rewriter.notifyMatchFailure(
op, "unimplemented: index tensor must not be None");
op, "aten.index.Tensor: index tensor must not be None");
}

Value indexTensor = indicesVal[indexTensorDim];
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
RankedTensorType indexTensorType =
indexTensor.getType().cast<RankedTensorType>();
RankedTensorType resultType = getTypeConverter()
->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
Type elementType = resultType.getElementType();
int inputRank = inputType.getRank();
int indexTensorRank = indexTensorType.getRank();
int resultRank = resultType.getRank();
int firstIndexDim = indexTensorDims[0];
int replacedIndexCount = indexTensorDims.size();
int64_t startIndex = contiguous ? firstIndexDim : 0;

// Currently we only support statically sized index tensors
// when there is more than one index tensor.
// TODO: Add support for dynamic size index tensors. This will probably
// require broadcasting the index tensors to a common shape.
SmallVector<Value> broadcastedIndexShape;
if (indexTensors.size() > 1) {
int maxRank = -1;
for (auto indexTensor : indexTensors) {
RankedTensorType indexTensorType =
indexTensor.getType().cast<RankedTensorType>();
maxRank = std::max(maxRank, (int)indexTensorType.getRank());
}

// Because we are assuming static shapes, we can get the shape of the
// broadcasted index tensors from the shape refinement pass
auto refinedResultShape = resultType.getShape();
for (auto i : llvm::seq(startIndex, startIndex + maxRank)) {
auto resultDimSize = refinedResultShape[i];
if (ShapedType::isDynamic(resultDimSize)) {
return rewriter.notifyMatchFailure(
op, "unimplemented: index tensors must have static shape if "
"there is more than one index tensor");
}
broadcastedIndexShape.push_back(
getConstant(rewriter, loc, resultDimSize, rewriter.getIndexType()));
}
} else {
// For a single indexing tensor we can simply use its (dynamic) sizes
broadcastedIndexShape =
getTensorSizes(rewriter, loc, indexTensors.front());
}

// This result shape calculation assumes that there is only one
// index tensor of the input tensor. The calculation for arbitrary inputs is
// much more complex.
// index tensor, or all of the index tensors are statically shaped.
int broadcastRank = broadcastedIndexShape.size();

SmallVector<Value> resultShape;
for (auto i : llvm::seq(0, indexTensorDim)) {
resultShape.push_back(getDimOp(rewriter, loc, input, i));
}
for (auto i : llvm::seq(0, indexTensorRank)) {
resultShape.push_back(getDimOp(rewriter, loc, indexTensor, i));
}
for (auto i : llvm::seq(indexTensorDim + 1, inputRank)) {
resultShape.push_back(getDimOp(rewriter, loc, input, i));
if (contiguous) {
for (auto i : llvm::seq(0, firstIndexDim)) {
resultShape.push_back(getDimOp(rewriter, loc, input, i));
}
resultShape.append(broadcastedIndexShape);
for (auto i : llvm::seq((int)resultShape.size(), resultRank)) {
resultShape.push_back(getDimOp(rewriter, loc, input,
i - broadcastRank + replacedIndexCount));
}
} else {
resultShape.append(broadcastedIndexShape);
int j = 0;
for (auto i : llvm::seq(0, inputRank)) {
if (j < replacedIndexCount && i == indexTensorDims[j]) {
j++;
continue;
}
resultShape.push_back(getDimOp(rewriter, loc, input, i));
}
}
int resultRank = resultShape.size();

// Initialize the indexing maps for the generic op. Because we are assuming
// static shapes for the indexing tensors when there are more than 1, we can
// safely map all size 1 dims to 0 in the corresponding affine maps.
// TODO: For dynamic shapes, we have to either broadcast the index tensors
// to a common shape or introduce some form of control flow.
Value initTensor =
rewriter.create<linalg::InitTensorOp>(loc, resultShape, elementType);
SmallVector<AffineExpr> indicesExpr, resultExpr;
SmallVector<AffineMap> indexingMaps;
SmallVector<StringRef> iteratorTypes;

for (auto i : llvm::seq(indexTensorDim, indexTensorDim + indexTensorRank))
indicesExpr.push_back(rewriter.getAffineDimExpr(i));
for (auto indexTensor : indexTensors) {
RankedTensorType indexTensorType =
indexTensor.getType().cast<RankedTensorType>();
auto indexTensorShape = indexTensorType.getShape();
int rank = indexTensorShape.size();
SmallVector<AffineExpr> indicesExpr;
for (auto dim : llvm::seq(0, rank)) {
if (indexTensorShape[dim] == 1) {
indicesExpr.push_back(rewriter.getAffineConstantExpr(0));
continue;
}
indicesExpr.push_back(
rewriter.getAffineDimExpr(startIndex + broadcastRank - rank + dim));
}
indexingMaps.push_back(
AffineMap::get(resultRank, 0, indicesExpr, op->getContext()));
}

SmallVector<AffineExpr> resultExpr;
for (auto i : llvm::seq(0, resultRank)) {
resultExpr.push_back(rewriter.getAffineDimExpr(i));
iteratorTypes.push_back(getParallelIteratorTypeName());
}
auto indexingMaps = AffineMap::inferFromExprList({indicesExpr, resultExpr});

indexingMaps.push_back(
AffineMap::get(resultRank, 0, resultExpr, op->getContext()));

Value finalRes =
rewriter
.create<linalg::GenericOp>(
loc, initTensor.getType(), indexTensor, initTensor,
loc, initTensor.getType(), indexTensors, initTensor,
indexingMaps, iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value index = castIntToIndex(b, loc, args[0]);
SmallVector<Value> extractionIndices;
int extra_dims = 0;
for (auto i : llvm::seq(0, inputRank)) {
if (i == indexTensorDim) {
extractionIndices.push_back(index);
extra_dims += indexTensorRank - 1;
} else {
if (contiguous) {
for (auto i : llvm::seq(0, firstIndexDim)) {
extractionIndices.push_back(
b.create<linalg::IndexOp>(loc, i));
}
for (auto i : llvm::seq(0, (int)indexTensorDims.size())) {
extractionIndices.push_back(
b.create<linalg::IndexOp>(loc, i + extra_dims));
castIntToIndex(b, loc, args[i]));
}
for (auto i :
llvm::seq((int)extractionIndices.size(), inputRank)) {
extractionIndices.push_back(b.create<linalg::IndexOp>(
loc, i + broadcastRank - replacedIndexCount));
}
} else {
int indexCount = 0, unchanged = 0;
for (auto i : llvm::seq(0, inputRank)) {
if (indexCount < replacedIndexCount &&
i == indexTensorDims[indexCount]) {
extractionIndices.push_back(
castIntToIndex(b, loc, args[indexCount++]));
continue;
}
extractionIndices.push_back(b.create<linalg::IndexOp>(
loc, broadcastRank + unchanged));
unchanged++;
}
}
Value extractedElement = b.create<tensor::ExtractOp>(
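(Editor's note, not part of the patch: the result-shape computation in ConvertAtenIndexTensorOp above can be summarized with a small Python model. The helper name and the shapes in the example are made up for illustration.)

from typing import List

def index_result_shape(input_shape: List[int],
                       index_dims: List[int],
                       broadcast_shape: List[int]) -> List[int]:
    # index_dims: positions of the non-None index tensors (in order);
    # broadcast_shape: common shape of the broadcasted index tensors.
    contiguous = all(b - a == 1 for a, b in zip(index_dims, index_dims[1:]))
    if contiguous:
        first = index_dims[0]
        # Dims before the indexed block, then the broadcast shape, then the rest.
        return (input_shape[:first] + broadcast_shape
                + input_shape[first + len(index_dims):])
    # Non-contiguous: broadcast shape goes to the front, untouched dims follow.
    rest = [s for d, s in enumerate(input_shape) if d not in index_dims]
    return broadcast_shape + rest

print(index_result_shape([2, 3, 4, 5], [1, 2], [6, 4]))  # [2, 6, 4, 5]
print(index_result_shape([2, 3, 4, 5], [1, 3], [6, 4]))  # [6, 4, 2, 4]
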
29 changes: 29 additions & 0 deletions lib/Dialect/Torch/IR/TorchOps.cpp
@@ -1479,6 +1479,35 @@ void Aten__Getitem__TOp::getCanonicalizationPatterns(
});
}

//===----------------------------------------------------------------------===//
// AtenAddTOp
//===----------------------------------------------------------------------===//

void AtenAddTOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add(+[](AtenAddTOp op, PatternRewriter &rewriter) {
auto lhsListConstruct = op.a().getDefiningOp<Torch::PrimListConstructOp>();
if (!lhsListConstruct || isListPotentiallyMutated(lhsListConstruct))
return failure();

auto rhsListConstruct = op.b().getDefiningOp<Torch::PrimListConstructOp>();
if (!rhsListConstruct || isListPotentiallyMutated(rhsListConstruct))
return failure();

SmallVector<Value> concatenatedList;
for (auto a : lhsListConstruct.getOperands()) {
concatenatedList.push_back(a);
}
for (auto b : rhsListConstruct.getOperands()) {
concatenatedList.push_back(b);
}

rewriter.replaceOpWithNewOp<Torch::PrimListConstructOp>(op, op.getType(),
concatenatedList);
return success();
});
}

//===----------------------------------------------------------------------===//
// AtenEqIntListOp
//===----------------------------------------------------------------------===//
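(Editor's note, not from the patch: the new AtenAddTOp canonicalization above rewrites aten.add.t of two non-mutated prim.ListConstruct ops into a single prim.ListConstruct over the concatenated operands, which lets later shape simplification see the individual index values. A rough Python model of the rewrite decision, with made-up names:)

from typing import List, Optional

def fold_add_t(lhs_ops: List[str], rhs_ops: List[str],
               lhs_mutated: bool, rhs_mutated: bool) -> Optional[List[str]]:
    # Match failure if either operand list may be mutated elsewhere.
    if lhs_mutated or rhs_mutated:
        return None
    # Otherwise the add is replaced by one ListConstruct of all elements.
    return lhs_ops + rhs_ops

print(fold_add_t(["%none"], ["%index0", "%index1"], False, False))
# ['%none', '%index0', '%index1']
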
22 changes: 11 additions & 11 deletions lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
@@ -6590,30 +6590,30 @@ module {
%10 = torch.aten.len.t %arg1 : !torch.list<optional<list<int>>> -> !torch.int
%11 = torch.prim.ListConstruct %int9223372036854775807, %10 : (!torch.int, !torch.int) -> !torch.list<int>
%12 = torch.prim.min.self_int %11 : !torch.list<int> -> !torch.int
%13:3 = torch.prim.Loop %12, %true, init(%true, %int-1, %int-1) {
^bb0(%arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.int, %arg5: !torch.int):
%13:2 = torch.prim.Loop %12, %true, init(%true, %int-1) {
^bb0(%arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.int):
%16 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list<optional<list<int>>>, !torch.int -> !torch.optional<list<int>>
%17 = torch.aten.__isnot__ %16, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
%18:3 = torch.prim.If %17 -> (!torch.bool, !torch.int, !torch.int) {
%18:2 = torch.prim.If %17 -> (!torch.bool, !torch.int) {
%19 = torch.aten.eq.int %arg4, %int-1 : !torch.int, !torch.int -> !torch.bool
%20:3 = torch.prim.If %19 -> (!torch.bool, !torch.int, !torch.int) {
torch.prim.If.yield %arg3, %arg2, %arg2 : !torch.bool, !torch.int, !torch.int
%20:2 = torch.prim.If %19 -> (!torch.bool, !torch.int) {
torch.prim.If.yield %arg3, %arg2 : !torch.bool, !torch.int
} else {
%21 = torch.aten.sub.int %arg2, %arg5 : !torch.int, !torch.int -> !torch.int
%21 = torch.aten.sub.int %arg2, %arg4 : !torch.int, !torch.int -> !torch.int
%22 = torch.aten.ne.int %21, %int1 : !torch.int, !torch.int -> !torch.bool
%23 = torch.prim.If %22 -> (!torch.bool) {
torch.prim.If.yield %false : !torch.bool
} else {
torch.prim.If.yield %arg3 : !torch.bool
}
torch.prim.If.yield %23, %arg4, %arg5 : !torch.bool, !torch.int, !torch.int
torch.prim.If.yield %23, %arg4 : !torch.bool, !torch.int
}
torch.prim.If.yield %20#0, %20#1, %20#2 : !torch.bool, !torch.int, !torch.int
torch.prim.If.yield %20#0, %20#1 : !torch.bool, !torch.int
} else {
torch.prim.If.yield %arg3, %arg4, %arg5 : !torch.bool, !torch.int, !torch.int
torch.prim.If.yield %arg3, %arg4 : !torch.bool, !torch.int
}
torch.prim.Loop.condition %true, iter(%18#0, %18#1, %18#2 : !torch.bool, !torch.int, !torch.int)
} : (!torch.int, !torch.bool, !torch.bool, !torch.int, !torch.int) -> (!torch.bool, !torch.int, !torch.int)
torch.prim.Loop.condition %true, iter(%18#0, %18#1 : !torch.bool, !torch.int)
} : (!torch.int, !torch.bool, !torch.bool, !torch.int) -> (!torch.bool, !torch.int)
%14 = torch.aten.__not__ %13#0 : !torch.bool -> !torch.bool
%15 = torch.prim.If %14 -> (!torch.list<int>) {
%16 = torch.aten.add.t %6, %4 : !torch.list<int>, !torch.list<int> -> !torch.list<int>
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp
@@ -418,6 +418,7 @@ class SimplifyShapeCalculationsPass
Aten__Getitem__TOp::getCanonicalizationPatterns(patterns, context);
AtenSizeOp::getCanonicalizationPatterns(patterns, context);
AtenLenTOp::getCanonicalizationPatterns(patterns, context);
AtenAddTOp::getCanonicalizationPatterns(patterns, context);

// TODO: Debug visitation order to make this more efficient.
// A single linear scan should suffice.
@@ -1016,6 +1016,7 @@ def aten〇pad(self: List[int], pad: List[int], mode: str = "constant", value: O
Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4), None]), # Explicit None value.
Invocation(TensorOfShape(2, 3, 4, 5), [None, LongTensorOfShape(4), LongTensorOfShape(4)]), # Indexing tensors on consecutive dimensions.
Invocation(TensorOfShape(2, 3, 4, 5), [None, LongTensorOfShape(4), None, LongTensorOfShape(4)]), # Indexing tensors on non-consecutive dimensions.
Invocation(TensorOfShape(2, 3, 4, 5), [LongTensorOfShape(4, 2), None, LongTensorOfShape(2)]), # Indexing tensors on non-consecutive dimensions.
Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4, 5, 6), LongTensorOfShape(1, 5, 1)]), # Broadcasting of index tensors.
Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4)]), # Fewer index tensors than dimensions.
ErrorInvocation(TensorOfShape(2, 3), [LongTensorOfShape(4), LongTensorOfShape(4), LongTensorOfShape(4)]), # More index tensors than dimensions.
@@ -1037,15 +1038,13 @@ def aten〇index〇Tensor(self: List[int], indices: List[Optional[List[int]]]) -
if len(unused_dim_sizes) == 0:
return broadcasted_shape

prev_index_tensor_location = -1
first_index_tensor_location = -1
index_tensors_are_together = True
for e, index_tensor_shape in enumerate(indices):
if index_tensor_shape is not None:
if first_index_tensor_location == -1:
first_index_tensor_location = e
prev_index_tensor_location = e
elif e - prev_index_tensor_location != 1:
elif e - first_index_tensor_location != 1:
index_tensors_are_together = False

if not index_tensors_are_together:
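(Editor's sanity check, not part of the test suite: the result shapes exercised by the multi-index Invocation cases above can be reproduced with eager PyTorch; the index values below are placeholders.)

import torch

x = torch.rand(2, 3, 4, 5)
i = torch.zeros(4, dtype=torch.long)
j = torch.zeros(2, dtype=torch.long)
k = torch.zeros(4, 2, dtype=torch.long)

print(x[:, i, i].shape)     # consecutive dims 1, 2       -> torch.Size([2, 4, 5])
print(x[:, i, :, i].shape)  # non-consecutive dims 1, 3   -> torch.Size([4, 2, 4])
print(x[k, :, j].shape)     # dims 0, 2; broadcast [4, 2] -> torch.Size([4, 2, 3, 5])
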
@@ -489,7 +489,7 @@ def emit_with_mutating_variants(key, **kwargs):
# List ops.
emit("aten::cat : (Tensor[], int) -> (Tensor)")
emit("aten::append.t : (t[], t) -> (t[])")
emit("aten::add.t : (t[], t[]) -> (t[])")
emit("aten::add.t : (t[], t[]) -> (t[])", has_canonicalizer=True)
emit("aten::eq.int_list : (int[], int[]) -> (bool)", has_folder=True)
emit("aten::list.t : (t[]) -> (t[])")
emit("aten::slice.t : (t[], int?, int?, int) -> (t[])")