Skip to content

Commit

Permalink
[mlir][vector] Clarify the semantics of BroadcastOp (#101928)
Browse files Browse the repository at this point in the history
Clarifies the semantics of `vector.broadcast` in the context of scalable
vectors. In particular, broadcasting a unit scalable dim, `[1]`, is not
valid unless there's a match between the output and the input dims.
See the examples below for an illustration:

```mlir
// VALID
 %0 = vector.broadcast %arg0 : vector<[1]xf32> to vector<4x[1]xf32>
// INVALID
 %0 = vector.broadcast %arg0 : vector<[1]xf32> to vector<[4]xf32>
// VALID FIXED-WIDTH EQUIVALENT
 %0 = vector.broadcast %arg0 : vector<1xf32> to vector<4xf32>
```

Documentation, the Op verifier and tests are updated accordingly.
  • Loading branch information
banach-space authored Aug 8, 2024
1 parent 3423470 commit 1919db9
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 14 deletions.
7 changes: 6 additions & 1 deletion mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,14 @@ enum class BroadcastableToResult {
DimensionMismatch = 2,
SourceTypeNotAVector = 3
};

struct VectorDim {
int64_t dim;
bool isScalable;
};
BroadcastableToResult
isBroadcastableTo(Type srcType, VectorType dstVectorType,
std::pair<int, int> *mismatchingDims = nullptr);
std::pair<VectorDim, VectorDim> *mismatchingDims = nullptr);

/// Collect a set of vector-to-vector canonicalization patterns.
void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns,
Expand Down
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,8 @@ def Vector_BroadcastOp :
s_1 x .. x s_j x .. x s_k
<duplication> <potential stretch>
```
* in addition, any scalable unit dimension, `[1]`, must match exactly.

The source operand is duplicated over all the missing leading dimensions
and stretched over the trailing dimensions where the source has a non-equal
dimension of 1. These rules imply that any scalar broadcast (k=0) to any
Expand Down
50 changes: 37 additions & 13 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2371,9 +2371,9 @@ Value BroadcastOp::createOrFoldBroadcastOp(
return res;
}

BroadcastableToResult
mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
std::pair<int, int> *mismatchingDims) {
BroadcastableToResult mlir::vector::isBroadcastableTo(
Type srcType, VectorType dstVectorType,
std::pair<VectorDim, VectorDim> *mismatchingDims) {
// Broadcast scalar to vector of the same element type.
if (srcType.isIntOrIndexOrFloat() && dstVectorType &&
getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType))
Expand All @@ -2390,13 +2390,31 @@ mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
// Source has an exact match or singleton value for all trailing dimensions
// (all leading dimensions are simply duplicated).
int64_t lead = dstRank - srcRank;
for (int64_t r = 0; r < srcRank; ++r) {
int64_t srcDim = srcVectorType.getDimSize(r);
int64_t dstDim = dstVectorType.getDimSize(lead + r);
if (srcDim != 1 && srcDim != dstDim) {
if (mismatchingDims) {
mismatchingDims->first = srcDim;
mismatchingDims->second = dstDim;
for (int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
// Have mismatching dims (in the sense of vector.broadcast semantics) been
// encountered?
bool foundMismatchingDims = false;

// Check fixed-width dims.
int64_t srcDim = srcVectorType.getDimSize(dimIdx);
int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx);
if (srcDim != 1 && srcDim != dstDim)
foundMismatchingDims = true;

// Check scalable flags.
bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
(srcDimScalableFlag != dstDimScalableFlag))
foundMismatchingDims = true;

if (foundMismatchingDims) {
if (mismatchingDims != nullptr) {
mismatchingDims->first.dim = srcDim;
mismatchingDims->first.isScalable = srcDimScalableFlag;

mismatchingDims->second.dim = dstDim;
mismatchingDims->second.isScalable = dstDimScalableFlag;
}
return BroadcastableToResult::DimensionMismatch;
}
Expand All @@ -2406,16 +2424,22 @@ mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
}

LogicalResult BroadcastOp::verify() {
std::pair<int, int> mismatchingDims;
std::pair<VectorDim, VectorDim> mismatchingDims;
BroadcastableToResult res = isBroadcastableTo(
getSourceType(), getResultVectorType(), &mismatchingDims);
if (res == BroadcastableToResult::Success)
return success();
if (res == BroadcastableToResult::SourceRankHigher)
return emitOpError("source rank higher than destination rank");
if (res == BroadcastableToResult::DimensionMismatch)
if (res == BroadcastableToResult::DimensionMismatch) {
return emitOpError("dimension mismatch (")
<< mismatchingDims.first << " vs. " << mismatchingDims.second << ")";
<< (mismatchingDims.first.isScalable ? "[" : "")
<< mismatchingDims.first.dim
<< (mismatchingDims.first.isScalable ? "]" : "") << " vs. "
<< (mismatchingDims.second.isScalable ? "[" : "")
<< mismatchingDims.second.dim
<< (mismatchingDims.second.isScalable ? "]" : "") << ")";
}
if (res == BroadcastableToResult::SourceTypeNotAVector)
return emitOpError("source type is not a vector");
llvm_unreachable("unexpected vector.broadcast op error");
Expand Down
21 changes: 21 additions & 0 deletions mlir/test/Dialect/Vector/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,27 @@ func.func @broadcast_dim2_mismatch(%arg0: vector<4x8xf32>) {

// -----

func.func @broadcast_scalable_unit_dim(%arg0: vector<[1]xf32>) {
// expected-error@+1 {{'vector.broadcast' op dimension mismatch ([1] vs. [4])}}
%0 = vector.broadcast %arg0 : vector<[1]xf32> to vector<[4]xf32>
}

// -----

func.func @broadcast_fixed_to_scalable(%arg0: vector<2xf32>) {
// expected-error@+1 {{'vector.broadcast' op dimension mismatch (2 vs. [2])}}
%0 = vector.broadcast %arg0 : vector<2xf32> to vector<[2]xf32>
}

// -----

func.func @broadcast_scalable_to_fixed(%arg0: vector<[1]xf32>) {
// expected-error@+1 {{'vector.broadcast' op dimension mismatch ([1] vs. 1)}}
%0 = vector.broadcast %arg0 : vector<[1]xf32> to vector<4x1xf32>
}

// -----

func.func @broadcast_unknown(%arg0: memref<4x8xf32>) {
// expected-error@+1 {{'vector.broadcast' op source type is not a vector}}
%1 = vector.broadcast %arg0 : memref<4x8xf32> to vector<1x8xf32>
Expand Down

0 comments on commit 1919db9

Please sign in to comment.