diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index bc97a5ae7d2f70..a2a317109e29d8 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -2683,13 +2683,13 @@ def Vector_MatmulOp : Vector_Op<"matrix_multiply", [Pure, TCresVTEtIsSameAsOpBase<0, 1>>]>, Arguments<( // TODO: tighten vector element types that make sense. - ins VectorOfRankAndType<[1], + ins FixedVectorOfRankAndType<[1], [AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$lhs, - VectorOfRankAndType<[1], + FixedVectorOfRankAndType<[1], [AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$rhs, I32Attr:$lhs_rows, I32Attr:$lhs_columns, I32Attr:$rhs_columns)>, Results<( - outs VectorOfRankAndType<[1], + outs FixedVectorOfRankAndType<[1], [AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$res)> { let summary = "Vector matrix multiplication op that operates on flattened 1-D" @@ -2707,7 +2707,9 @@ def Vector_MatmulOp : Vector_Op<"matrix_multiply", [Pure, and multiplies them. The result matrix is returned embedded in the result vector. - Also see: + Note, the corresponding LLVM intrinsic, `@llvm.matrix.multiply.*`, does not + support scalable vectors. Hence, this Op is only available for fixed-width + vectors. Also see: http://llvm.org/docs/LangRef.html#llvm-matrix-multiply-intrinsic diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td index 5b6ec167fa2420..2493f212a356a4 100644 --- a/mlir/include/mlir/IR/CommonTypeConstraints.td +++ b/mlir/include/mlir/IR/CommonTypeConstraints.td @@ -494,6 +494,14 @@ class VectorOfRankAndType allowedRanks, VectorOf.summary # VectorOfRank.summary, "::mlir::VectorType">; +// Fixed-width vector where the rank is from the given `allowedRanks` list and +// the type is from the given `allowedTypes` list +class FixedVectorOfRankAndType allowedRanks, + list allowedTypes> : AllOfType< + [FixedVectorOf, VectorOfRank], + FixedVectorOf.summary # VectorOfRank.summary, + "::mlir::VectorType">; + // Whether the number of elements of a vector is from the given // `allowedLengths` list class IsVectorOfLengthPred allowedLengths> : @@ -592,7 +600,7 @@ class VectorOfLengthAndType allowedLengths, // Any fixed-length vector where the number of elements is from the given // `allowedLengths` list and the type is from the given `allowedTypes` list class FixedVectorOfLengthAndType allowedLengths, - list allowedTypes> : AllOfType< + list allowedTypes> : AllOfType< [FixedVectorOf, FixedVectorOfLength], FixedVectorOf.summary # FixedVectorOfLength.summary, diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index ba1efe8b3c2d38..c95b8bd5ed6147 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1862,3 +1862,16 @@ func.func @invalid_step_2d() { vector.step : vector<2x4xf32> return } + +// ----- + +func.func @matrix_multiply_scalable(%a: vector<[4]xf64>, %b: vector<4xf64>) { + // expected-error @+1 {{'vector.matrix_multiply' op operand #0 must be fixed-length vector of signless integer or signed integer or index or floating-point values of ranks 1, but got 'vector<[4]xf64>'}} + %c = vector.matrix_multiply %a, %b { + lhs_rows = 2: i32, + lhs_columns = 2: i32 , + rhs_columns = 2: i32 } + : (vector<[4]xf64>, vector<4xf64>) -> vector<4xf64> + + return +}