Skip to content

Commit

Permalink
Add support for index.Tensor on dimensions other than the first
Browse files Browse the repository at this point in the history
This patch still only supports a single indexing tensor.
  • Loading branch information
qedawkins committed Jul 16, 2022
1 parent baa4383 commit c354787
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 13 deletions.
43 changes: 30 additions & 13 deletions lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,19 +262,28 @@ class ConvertAtenIndexTensorOp : public OpConversionPattern<AtenIndexTensorOp> {
return rewriter.notifyMatchFailure(
op, "unimplemented: the indices list is not from a list construct");
}
// NOTE(review): this diff rendering lost its +/- markers; the guard below
// is the PRE-patch code (a deletion in the original diff). The patch
// replaces it with the indexTensorDim scan further down, which still
// permits only a single real index tensor but allows it at any position.
if (indicesTuple.size() != 1) {
return rewriter.notifyMatchFailure(
op, "unimplemented: only one index tensor is supported");
}

// Convert the torch-dialect index values (some of which may be None) into
// the converted type system for use below.
SmallVector<Value> indicesVal =
getTypeConvertedValues(rewriter, loc, getTypeConverter(), indicesTuple);
// NOTE(review): the next two lines are also PRE-patch deletions (the patch
// selects indicesVal[indexTensorDim] instead of hard-coding element 0).
Value indexTensor = indicesVal[0];
if (failed(checkNotNone(rewriter, op, indexTensor))) {

// Scan for the single non-None index tensor. Its position in the indices
// list is the input dimension being indexed (torch semantics: a None
// entry means "keep this dimension unindexed").
int indexTensorDim = -1;
for (auto i : llvm::seq(0, (int)indicesVal.size())) {
Value index = indicesVal[i];
// Skip None / absent entries — they do not participate in the gather.
if (!index || failed(checkNotNone(rewriter, op, index)))
continue;
// A second real index tensor is still unsupported by this patch.
if (indexTensorDim >= 0) {
return rewriter.notifyMatchFailure(
op, "unimplemented: only one index tensor allowed");
}
indexTensorDim = i;
}

// Every entry was None: there is nothing to index with.
if (indexTensorDim == -1) {
return rewriter.notifyMatchFailure(
op, "unimplemented: index tensor must not be None");
}

Value indexTensor = indicesVal[indexTensorDim];
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
RankedTensorType indexTensorType =
indexTensor.getType().cast<RankedTensorType>();
Expand All @@ -289,10 +298,13 @@ class ConvertAtenIndexTensorOp : public OpConversionPattern<AtenIndexTensorOp> {
// index tensor and that it is indexing dimension `indexTensorDim` of the
// input tensor. The calculation for arbitrary inputs is much more complex.
// Result shape: input dims [0, indexTensorDim) ++ all index-tensor dims ++
// input dims (indexTensorDim, inputRank).
SmallVector<Value> resultShape;
for (auto i : llvm::seq(0, indexTensorDim)) {
resultShape.push_back(getDimOp(rewriter, loc, input, i));
}
for (auto i : llvm::seq(0, indexTensorRank)) {
resultShape.push_back(getDimOp(rewriter, loc, indexTensor, i));
}
// NOTE(review): the first `for` line below is the PRE-patch loop bound
// (a deletion in this diff view — it always skipped only dim 0); the
// second line is its patched replacement.
for (auto i : llvm::seq(1, inputRank)) {
for (auto i : llvm::seq(indexTensorDim + 1, inputRank)) {
resultShape.push_back(getDimOp(rewriter, loc, input, i));
}
int resultRank = resultShape.size();
Expand All @@ -302,7 +314,7 @@ class ConvertAtenIndexTensorOp : public OpConversionPattern<AtenIndexTensorOp> {
SmallVector<AffineExpr> indicesExpr, resultExpr;
SmallVector<StringRef> iteratorTypes;

// The index tensor's dimensions map onto result iteration dims
// [indexTensorDim, indexTensorDim + indexTensorRank).
// NOTE(review): the first `for` line below is the PRE-patch version
// (a deletion in this diff view, mapping from dim 0); the second line
// is its patched replacement.
for (auto i : llvm::seq(0, indexTensorRank))
for (auto i : llvm::seq(indexTensorDim, indexTensorDim + indexTensorRank))
indicesExpr.push_back(rewriter.getAffineDimExpr(i));
// The result is written with the identity map over all result dims.
for (auto i : llvm::seq(0, resultRank)) {
resultExpr.push_back(rewriter.getAffineDimExpr(i));
Expand All @@ -316,11 +328,16 @@ class ConvertAtenIndexTensorOp : public OpConversionPattern<AtenIndexTensorOp> {
loc, initTensor.getType(), indexTensor, initTensor,
indexingMaps, iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
// NOTE(review): the next five lines are the PRE-patch body (deletions
// in this diff view); they assumed the index tensor applied to input
// dim 0. The patched body follows them.
SmallVector<Value> extractionIndices{
castIntToIndex(b, loc, args[0])};
for (auto i : llvm::seq(1, inputRank)) {
extractionIndices.push_back(b.create<linalg::IndexOp>(
loc, i + indexTensorRank - 1));
// args[0] is the gathered index value read from the index tensor at
// the current iteration point; cast it from int to index type.
Value index = castIntToIndex(b, loc, args[0]);
// Build one extraction coordinate per input dimension. At
// indexTensorDim we use the gathered index; every other input dim
// reads the corresponding result iteration dim, offset by the extra
// dims the index tensor contributes once we have passed it.
SmallVector<Value> extractionIndices;
int extra_dims = 0;
for (auto i : llvm::seq(0, inputRank)) {
if (i == indexTensorDim) {
extractionIndices.push_back(index);
// The index tensor replaces 1 input dim with indexTensorRank
// result dims, so later input dims are shifted by rank - 1.
extra_dims += indexTensorRank - 1;
} else {
extractionIndices.push_back(b.create<linalg::IndexOp>(loc, i + extra_dims));
}
}
Value extractedElement = b.create<tensor::ExtractOp>(
loc, input, extractionIndices);
Expand Down
22 changes: 22 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1632,6 +1632,28 @@ def IndexTensorModule3dInput_basic(module, tu: TestUtils):
# ==============================================================================


class IndexTensorSelectDimModule(torch.nn.Module):
    """E2E test module for aten.index where the index tensor selects along
    a dimension other than the leading one (dim 1 of a rank-3 input)."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
        ([-1, -1], torch.int64, True),
    ])
    def forward(self, a, ind):
        # None entries leave dims 0 and 2 untouched; `ind` gathers along dim 1.
        indices = (None, ind, None)
        return torch.ops.aten.index(a, indices)


@register_test_case(module_factory=lambda: IndexTensorSelectDimModule())
def IndexTensorSelectDimModule_basic(module, tu: TestUtils):
    # Rank-3 float input; index values drawn from [0, 3) are valid for
    # dim 1, which has size 4.
    input_tensor = tu.rand(2, 4, 6)
    index_tensor = torch.randint(3, (2, 3))
    module.forward(input_tensor, index_tensor)

# ==============================================================================


class SquareModule(torch.nn.Module):

def __init__(self):
Expand Down

0 comments on commit c354787

Please sign in to comment.