Skip to content

Commit

Permalink
[mlir][vector] Fix invalid LoadOp indices being created (#75519)
Browse files Browse the repository at this point in the history
Fixes #71326.

The cause of the issue was that a new `LoadOp` was created which looked
something like:
```mlir
func.func @main(%arg1 : index, %arg2 : index) {
  %alloca_0 = memref.alloca() : memref<vector<1x32xi1>>
  %1 = vector.type_cast %alloca_0 : memref<vector<1x32xi1>> to memref<1xvector<32xi1>>
  %2 = memref.load %1[%arg1, %arg2] : memref<1xvector<32xi1>>
  return
}
```
which crashed inside `LoadOp::verify`. Note that `%alloca_0` is
0-dimensional and `%1` has one dimension, but `memref.load` tries to index
`%1` with two indices.

This is now fixed by relying on the fact that `unpackOneDim` always unpacks
exactly one dimension:

https://github.com/llvm/llvm-project/blob/1bce61e6b01b38e04260be4f422bbae59c34c766/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp#L897-L903

Because of that, the `LoadOp` should index only that one dimension.

---------

Co-authored-by: Benjamin Maxwell <macdue@dueutil.tech>
  • Loading branch information
rikhuijzer and MacDue authored Dec 17, 2023
1 parent a3952b4 commit 3a1ae2f
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 12 deletions.
27 changes: 17 additions & 10 deletions mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ struct Strategy<TransferReadOp> {
/// Retrieve the indices of the current StoreOp that stores into the buffer.
static void getBufferIndices(TransferReadOp xferOp,
SmallVector<Value, 8> &indices) {
auto storeOp = getStoreOp(xferOp);
memref::StoreOp storeOp = getStoreOp(xferOp);
auto prevIndices = memref::StoreOpAdaptor(storeOp).getIndices();
indices.append(prevIndices.begin(), prevIndices.end());
}
Expand Down Expand Up @@ -591,8 +591,8 @@ struct PrepareTransferReadConversion
if (checkPrepareXferOp(xferOp, options).failed())
return failure();

auto buffers = allocBuffers(rewriter, xferOp);
auto *newXfer = rewriter.clone(*xferOp.getOperation());
BufferAllocs buffers = allocBuffers(rewriter, xferOp);
Operation *newXfer = rewriter.clone(*xferOp.getOperation());
newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
if (xferOp.getMask()) {
dyn_cast<TransferReadOp>(newXfer).getMaskMutable().assign(
Expand Down Expand Up @@ -885,8 +885,7 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
// If the xferOp has a mask: Find and cast mask buffer.
Value castedMaskBuffer;
if (xferOp.getMask()) {
auto maskBuffer = getMaskBuffer(xferOp);
auto maskBufferType = dyn_cast<MemRefType>(maskBuffer.getType());
Value maskBuffer = getMaskBuffer(xferOp);
if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
// Do not unpack a dimension of the mask, if:
// * To-be-unpacked transfer op dimension is a broadcast.
Expand All @@ -897,7 +896,8 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
} else {
// It's safe to assume the mask buffer can be unpacked if the data
// buffer was unpacked.
auto castedMaskType = *unpackOneDim(maskBufferType);
auto maskBufferType = dyn_cast<MemRefType>(maskBuffer.getType());
MemRefType castedMaskType = *unpackOneDim(maskBufferType);
castedMaskBuffer =
locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
}
Expand Down Expand Up @@ -938,11 +938,18 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
b.setInsertionPoint(newXfer); // Insert load before newXfer.

SmallVector<Value, 8> loadIndices;
Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
// In case of broadcast: Use same indices to load from memref
// as before.
if (!xferOp.isBroadcastDim(0))
if (auto memrefType =
castedMaskBuffer.getType().dyn_cast<MemRefType>()) {
// If castedMaskBuffer is a memref, then one dim was
// unpacked; see above.
loadIndices.push_back(iv);
} else {
Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
// In case of broadcast: Use same indices to load from
// memref as before.
if (!xferOp.isBroadcastDim(0))
loadIndices.push_back(iv);
}

auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
loadIndices);
Expand Down
6 changes: 4 additions & 2 deletions mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1615,8 +1615,10 @@ GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
//===----------------------------------------------------------------------===//

LogicalResult LoadOp::verify() {
if (getNumOperands() != 1 + getMemRefType().getRank())
return emitOpError("incorrect number of indices for load");
if (static_cast<int64_t>(getIndices().size()) != getMemRefType().getRank()) {
return emitOpError("incorrect number of indices for load, expected ")
<< getMemRefType().getRank() << " but got " << getIndices().size();
}
return success();
}

Expand Down
17 changes: 17 additions & 0 deletions mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -740,6 +740,23 @@ func.func @cannot_lower_transfer_read_with_leading_scalable(%arg0: memref<?x4xf3

// -----

// Check that the `unpackOneDim` case in the `TransferOpConversion` generates valid indices for the LoadOp.

#map1 = affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>
func.func @does_not_crash_on_unpack_one_dim(%subview: memref<1x1x1x1xi32>, %mask: vector<1x1xi1>) -> vector<1x1x1x1xi32> {
%c0 = arith.constant 0 : index
%c0_i32 = arith.constant 0 : i32
%3 = vector.transfer_read %subview[%c0, %c0, %c0, %c0], %c0_i32, %mask {permutation_map = #map1}
: memref<1x1x1x1xi32>, vector<1x1x1x1xi32>
return %3 : vector<1x1x1x1xi32>
}
// CHECK-LABEL: func.func @does_not_crash_on_unpack_one_dim
// CHECK: %[[ALLOCA_0:.*]] = memref.alloca() : memref<vector<1x1xi1>>
// CHECK: %[[MASK:.*]] = vector.type_cast %[[ALLOCA_0]] : memref<vector<1x1xi1>> to memref<1xvector<1xi1>>
// CHECK: memref.load %[[MASK]][%{{.*}}] : memref<1xvector<1xi1>>

// -----

// FULL-UNROLL-LABEL: @cannot_fully_unroll_transfer_write_of_nd_scalable_vector
func.func @cannot_fully_unroll_transfer_write_of_nd_scalable_vector(%vec: vector<[4]x[4]xf32>, %memref: memref<?x?xf32>) {
// FULL-UNROLL-NOT: vector.extract
Expand Down
9 changes: 9 additions & 0 deletions mlir/test/Dialect/MemRef/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -896,6 +896,15 @@ func.func @bad_alloc_wrong_symbol_count() {

// -----

func.func @load_invalid_memref_indexes() {
%0 = memref.alloca() : memref<10xi32>
%c0 = arith.constant 0 : index
// expected-error@+1 {{incorrect number of indices for load, expected 1 but got 2}}
%1 = memref.load %0[%c0, %c0] : memref<10xi32>
}

// -----

func.func @test_store_zero_results() {
^bb0:
%0 = memref.alloc() : memref<1024x64xf32, affine_map<(d0, d1) -> (d0, d1)>, 1>
Expand Down

0 comments on commit 3a1ae2f

Please sign in to comment.