Skip to content

Commit

Permalink
[mlir][vector] Disable vector.matrix_multiply for scalable vectors (#…
Browse files Browse the repository at this point in the history
…102573)

Disables `vector.matrix_multiply` for scalable vectors. As per the docs:

>  This is the counterpart of llvm.matrix.multiply in MLIR

I'm not aware of any use of matrix-multiply intrinsics in the context of
scalable vectors, hence disabling.
  • Loading branch information
banach-space authored Aug 9, 2024
1 parent 574e958 commit 4c19de9
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 5 deletions.
10 changes: 6 additions & 4 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2688,13 +2688,13 @@ def Vector_MatmulOp : Vector_Op<"matrix_multiply", [Pure,
TCresVTEtIsSameAsOpBase<0, 1>>]>,
Arguments<(
// TODO: tighten vector element types that make sense.
ins VectorOfRankAndType<[1],
ins FixedVectorOfRankAndType<[1],
[AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$lhs,
VectorOfRankAndType<[1],
FixedVectorOfRankAndType<[1],
[AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$rhs,
I32Attr:$lhs_rows, I32Attr:$lhs_columns, I32Attr:$rhs_columns)>,
Results<(
outs VectorOfRankAndType<[1],
outs FixedVectorOfRankAndType<[1],
[AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$res)>
{
let summary = "Vector matrix multiplication op that operates on flattened 1-D"
Expand All @@ -2712,7 +2712,9 @@ def Vector_MatmulOp : Vector_Op<"matrix_multiply", [Pure,
<rhs_columns> and multiplies them. The result matrix is returned embedded in
the result vector.

Also see:
Note, the corresponding LLVM intrinsic, `@llvm.matrix.multiply.*`, does not
support scalable vectors. Hence, this Op is only available for fixed-width
vectors. Also see:

http://llvm.org/docs/LangRef.html#llvm-matrix-multiply-intrinsic

Expand Down
10 changes: 9 additions & 1 deletion mlir/include/mlir/IR/CommonTypeConstraints.td
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,14 @@ class VectorOfRankAndType<list<int> allowedRanks,
VectorOf<allowedTypes>.summary # VectorOfRank<allowedRanks>.summary,
"::mlir::VectorType">;

// Fixed-width vector where the rank is from the given `allowedRanks` list and
// the type is from the given `allowedTypes` list
class FixedVectorOfRankAndType<list<int> allowedRanks,
list<Type> allowedTypes> : AllOfType<
[FixedVectorOf<allowedTypes>, VectorOfRank<allowedRanks>],
FixedVectorOf<allowedTypes>.summary # VectorOfRank<allowedRanks>.summary,
"::mlir::VectorType">;

// Whether the number of elements of a vector is from the given
// `allowedLengths` list
class IsVectorOfLengthPred<list<int> allowedLengths> :
Expand Down Expand Up @@ -592,7 +600,7 @@ class VectorOfLengthAndType<list<int> allowedLengths,
// Any fixed-length vector where the number of elements is from the given
// `allowedLengths` list and the type is from the given `allowedTypes` list
class FixedVectorOfLengthAndType<list<int> allowedLengths,
list<Type> allowedTypes> : AllOfType<
list<Type> allowedTypes> : AllOfType<
[FixedVectorOf<allowedTypes>, FixedVectorOfLength<allowedLengths>],
FixedVectorOf<allowedTypes>.summary #
FixedVectorOfLength<allowedLengths>.summary,
Expand Down
13 changes: 13 additions & 0 deletions mlir/test/Dialect/Vector/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1862,3 +1862,16 @@ func.func @invalid_step_2d() {
vector.step : vector<2x4xf32>
return
}

// -----

func.func @matrix_multiply_scalable(%a: vector<[4]xf64>, %b: vector<4xf64>) {
// expected-error @+1 {{'vector.matrix_multiply' op operand #0 must be fixed-length vector of signless integer or signed integer or index or floating-point values of ranks 1, but got 'vector<[4]xf64>'}}
%c = vector.matrix_multiply %a, %b {
lhs_rows = 2: i32,
lhs_columns = 2: i32 ,
rhs_columns = 2: i32 }
: (vector<[4]xf64>, vector<4xf64>) -> vector<4xf64>

return
}

0 comments on commit 4c19de9

Please sign in to comment.