Skip to content

Commit

Permalink
[mlir][ODS] Verify type constraints in Types and Attributes
Browse files Browse the repository at this point in the history
When a type/attribute is defined in TableGen, a type constraint can be used for parameters, but the type constraint verification was missing.

Example:
```
def TestTypeVerification : Test_Type<"TestTypeVerification"> {
  let parameters = (ins AnyTypeOf<[I16, I32]>:$param);
  // ...
}
```

No verification code was generated to ensure that `$param` is I16 or I32.

When type constraints a present, a new method will generated for types and attributes: `verifyInvariantsImpl`. (The naming is similar to op verifiers.) The user-provided verifier is called `verify` (no change). There is now a new entry point to type/attribute verification: `verifyInvariants`. This function calls both `verifyInvariantsImpl` and `verify`. If neither of those two verifications are present, the `verifyInvariants` function is not generated.

When a type/attribute is not defined in TableGen, but a verifier is needed, users can implement the `verifyInvariants` function. (This function was previously called `verify`.)
  • Loading branch information
matthias-springer committed Aug 9, 2024
1 parent 9968cd5 commit 27f3ffa
Show file tree
Hide file tree
Showing 22 changed files with 304 additions and 93 deletions.
7 changes: 4 additions & 3 deletions mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,10 @@ class MMAMatrixType

/// Verify that shape and elementType are actually allowed for the
/// MMAMatrixType.
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType,
StringRef operand);
static LogicalResult
verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType,
StringRef operand);

/// Get number of dims.
unsigned getNumDims() const;
Expand Down
12 changes: 7 additions & 5 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,11 +180,13 @@ class LLVMStructType
ArrayRef<Type> getBody() const;

/// Verifies that the type about to be constructed is well-formed.
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
StringRef, bool);
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<Type> types, bool);
using Base::verify;
static LogicalResult
verifyInvariants(function_ref<InFlightDiagnostic()> emitError, StringRef,
bool);
static LogicalResult
verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<Type> types, bool);
using Base::verifyInvariants;

/// Hooks for DataLayoutTypeInterface. Should not be called directly. Obtain a
/// DataLayout instance and query it instead.
Expand Down
43 changes: 22 additions & 21 deletions mlir/include/mlir/Dialect/Quant/QuantTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@ class QuantizedType : public Type {
/// The maximum number of bits supported for storage types.
static constexpr unsigned MaxStorageBits = 32;

static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
unsigned flags, Type storageType,
Type expressedType, int64_t storageTypeMin,
int64_t storageTypeMax);
static LogicalResult
verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
Type storageType, Type expressedType, int64_t storageTypeMin,
int64_t storageTypeMax);

/// Support method to enable LLVM-style type casting.
static bool classof(Type type);
Expand Down Expand Up @@ -214,10 +214,10 @@ class AnyQuantizedType
int64_t storageTypeMax);

/// Verifies construction invariants and issues errors/warnings.
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
unsigned flags, Type storageType,
Type expressedType, int64_t storageTypeMin,
int64_t storageTypeMax);
static LogicalResult
verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
Type storageType, Type expressedType, int64_t storageTypeMin,
int64_t storageTypeMax);
};

/// Represents a family of uniform, quantized types.
Expand Down Expand Up @@ -276,11 +276,11 @@ class UniformQuantizedType
int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax);

/// Verifies construction invariants and issues errors/warnings.
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
unsigned flags, Type storageType,
Type expressedType, double scale,
int64_t zeroPoint, int64_t storageTypeMin,
int64_t storageTypeMax);
static LogicalResult
verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
Type storageType, Type expressedType, double scale,
int64_t zeroPoint, int64_t storageTypeMin,
int64_t storageTypeMax);

/// Gets the scale term. The scale designates the difference between the real
/// values corresponding to consecutive quantized values differing by 1.
Expand Down Expand Up @@ -338,12 +338,12 @@ class UniformQuantizedPerAxisType
int64_t storageTypeMin, int64_t storageTypeMax);

/// Verifies construction invariants and issues errors/warnings.
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
unsigned flags, Type storageType,
Type expressedType, ArrayRef<double> scales,
ArrayRef<int64_t> zeroPoints,
int32_t quantizedDimension,
int64_t storageTypeMin, int64_t storageTypeMax);
static LogicalResult
verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
Type storageType, Type expressedType,
ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
int32_t quantizedDimension, int64_t storageTypeMin,
int64_t storageTypeMax);

/// Gets the quantization scales. The scales designate the difference between
/// the real values corresponding to consecutive quantized values differing
Expand Down Expand Up @@ -403,8 +403,9 @@ class CalibratedQuantizedType
double min, double max);

/// Verifies construction invariants and issues errors/warnings.
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
Type expressedType, double min, double max);
static LogicalResult
verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
Type expressedType, double min, double max);
double getMin() const;
double getMax() const;
};
Expand Down
14 changes: 8 additions & 6 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,10 @@ class InterfaceVarABIAttr
/// Returns `spirv::StorageClass`.
std::optional<StorageClass> getStorageClass();

static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
IntegerAttr descriptorSet, IntegerAttr binding,
IntegerAttr storageClass);
static LogicalResult
verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
IntegerAttr descriptorSet, IntegerAttr binding,
IntegerAttr storageClass);

static constexpr StringLiteral name = "spirv.interface_var_abi";
};
Expand Down Expand Up @@ -128,9 +129,10 @@ class VerCapExtAttr
/// Returns the capabilities as an integer array attribute.
ArrayAttr getCapabilitiesAttr();

static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
IntegerAttr version, ArrayAttr capabilities,
ArrayAttr extensions);
static LogicalResult
verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
IntegerAttr version, ArrayAttr capabilities,
ArrayAttr extensions);

static constexpr StringLiteral name = "spirv.ver_cap_ext";
};
Expand Down
10 changes: 6 additions & 4 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -258,8 +258,9 @@ class SampledImageType
static SampledImageType
getChecked(function_ref<InFlightDiagnostic()> emitError, Type imageType);

static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
Type imageType);
static LogicalResult
verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
Type imageType);

Type getImageType() const;

Expand Down Expand Up @@ -462,8 +463,9 @@ class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
static MatrixType getChecked(function_ref<InFlightDiagnostic()> emitError,
Type columnType, uint32_t columnCount);

static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
Type columnType, uint32_t columnCount);
static LogicalResult
verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
Type columnType, uint32_t columnCount);

/// Returns true if the matrix elements are vectors of float elements.
static bool isValidColumnType(Type columnType);
Expand Down
10 changes: 6 additions & 4 deletions mlir/include/mlir/IR/StorageUniquerSupport.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,8 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
template <typename... Args>
static ConcreteT get(MLIRContext *ctx, Args &&...args) {
// Ensure that the invariants are correct for construction.
assert(
succeeded(ConcreteT::verify(getDefaultDiagnosticEmitFn(ctx), args...)));
assert(succeeded(
ConcreteT::verifyInvariants(getDefaultDiagnosticEmitFn(ctx), args...)));
return UniquerT::template get<ConcreteT>(ctx, std::forward<Args>(args)...);
}

Expand All @@ -198,7 +198,7 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
static ConcreteT getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
MLIRContext *ctx, Args... args) {
// If the construction invariants fail then we return a null attribute.
if (failed(ConcreteT::verify(emitErrorFn, args...)))
if (failed(ConcreteT::verifyInvariants(emitErrorFn, args...)))
return ConcreteT();
return UniquerT::template get<ConcreteT>(ctx, args...);
}
Expand Down Expand Up @@ -226,7 +226,9 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {

/// Default implementation that just returns success.
template <typename... Args>
static LogicalResult verify(Args... args) {
static LogicalResult
verifyInvariants(function_ref<InFlightDiagnostic()> emitErrorFn,
Args... args) {
return success();
}

Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/IR/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class AsmState;
/// Derived type classes are expected to implement several required
/// implementation hooks:
/// * Optional:
/// - static LogicalResult verify(
/// - static LogicalResult verifyInvariants(
/// function_ref<InFlightDiagnostic()> emitError,
/// Args... args)
/// * This method is invoked when calling the 'TypeBase::get/getChecked'
Expand Down
8 changes: 8 additions & 0 deletions mlir/include/mlir/TableGen/AttrOrTypeDef.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "mlir/Support/LLVM.h"
#include "mlir/TableGen/Builder.h"
#include "mlir/TableGen/Constraint.h"
#include "mlir/TableGen/Trait.h"

namespace llvm {
Expand Down Expand Up @@ -85,6 +86,9 @@ class AttrOrTypeParameter {
/// Get an optional C++ parameter parser.
std::optional<StringRef> getParser() const;

/// If this is a type constraint, return it.
std::optional<Constraint> getConstraint() const;

/// Get an optional C++ parameter printer.
std::optional<StringRef> getPrinter() const;

Expand Down Expand Up @@ -198,6 +202,10 @@ class AttrOrTypeDef {
/// method.
bool genVerifyDecl() const;

/// Return true if we need to generate any type constraint verification and
/// the getChecked method.
bool genVerifyInvariantsImpl() const;

/// Returns the def's extra class declaration code.
std::optional<StringRef> getExtraDecls() const;

Expand Down
2 changes: 2 additions & 0 deletions mlir/include/mlir/TableGen/Class.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ class MethodParameter {

/// Get the C++ type.
StringRef getType() const { return type; }
/// Get the C++ parameter name.
StringRef getName() const { return name; }
/// Returns true if the parameter has a default value.
bool hasDefaultValue() const { return !defaultValue.empty(); }

Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,9 @@ bool MMAMatrixType::isValidElementType(Type elementType) {
}

LogicalResult
MMAMatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType,
StringRef operand) {
MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType,
StringRef operand) {
if (operand != "AOp" && operand != "BOp" && operand != "COp")
return emitError() << "operand expected to be one of AOp, BOp or COp";

Expand Down
12 changes: 6 additions & 6 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -418,8 +418,7 @@ LogicalResult LLVMPointerType::verifyEntries(DataLayoutEntryListRef entries,

bool LLVMStructType::isValidElementType(Type type) {
return !llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
LLVMFunctionType, LLVMTokenType, LLVMScalableVectorType>(
type);
LLVMFunctionType, LLVMTokenType>(type);
}

LLVMStructType LLVMStructType::getIdentified(MLIRContext *context,
Expand Down Expand Up @@ -492,14 +491,15 @@ ArrayRef<Type> LLVMStructType::getBody() const {
: getImpl()->getTypeList();
}

LogicalResult LLVMStructType::verify(function_ref<InFlightDiagnostic()>,
StringRef, bool) {
LogicalResult
LLVMStructType::verifyInvariants(function_ref<InFlightDiagnostic()>, StringRef,
bool) {
return success();
}

LogicalResult
LLVMStructType::verify(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<Type> types, bool) {
LLVMStructType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<Type> types, bool) {
for (Type t : types)
if (!isValidElementType(t))
return emitError() << "invalid LLVM structure element type: " << t;
Expand Down
39 changes: 22 additions & 17 deletions mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@ bool QuantizedType::classof(Type type) {
}

LogicalResult
QuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
unsigned flags, Type storageType, Type expressedType,
int64_t storageTypeMin, int64_t storageTypeMax) {
QuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
unsigned flags, Type storageType,
Type expressedType, int64_t storageTypeMin,
int64_t storageTypeMax) {
// Verify that the storage type is integral.
// This restriction may be lifted at some point in favor of using bf16
// or f16 as exact representations on hardware where that is advantageous.
Expand Down Expand Up @@ -233,11 +234,13 @@ AnyQuantizedType::getChecked(function_ref<InFlightDiagnostic()> emitError,
}

LogicalResult
AnyQuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
unsigned flags, Type storageType, Type expressedType,
int64_t storageTypeMin, int64_t storageTypeMax) {
if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
storageTypeMin, storageTypeMax))) {
AnyQuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
unsigned flags, Type storageType,
Type expressedType, int64_t storageTypeMin,
int64_t storageTypeMax) {
if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
expressedType, storageTypeMin,
storageTypeMax))) {
return failure();
}

Expand Down Expand Up @@ -268,12 +271,13 @@ UniformQuantizedType UniformQuantizedType::getChecked(
storageTypeMin, storageTypeMax);
}

LogicalResult UniformQuantizedType::verify(
LogicalResult UniformQuantizedType::verifyInvariants(
function_ref<InFlightDiagnostic()> emitError, unsigned flags,
Type storageType, Type expressedType, double scale, int64_t zeroPoint,
int64_t storageTypeMin, int64_t storageTypeMax) {
if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
storageTypeMin, storageTypeMax))) {
if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
expressedType, storageTypeMin,
storageTypeMax))) {
return failure();
}

Expand Down Expand Up @@ -321,13 +325,14 @@ UniformQuantizedPerAxisType UniformQuantizedPerAxisType::getChecked(
quantizedDimension, storageTypeMin, storageTypeMax);
}

LogicalResult UniformQuantizedPerAxisType::verify(
LogicalResult UniformQuantizedPerAxisType::verifyInvariants(
function_ref<InFlightDiagnostic()> emitError, unsigned flags,
Type storageType, Type expressedType, ArrayRef<double> scales,
ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
int64_t storageTypeMin, int64_t storageTypeMax) {
if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
storageTypeMin, storageTypeMax))) {
if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
expressedType, storageTypeMin,
storageTypeMax))) {
return failure();
}

Expand Down Expand Up @@ -380,9 +385,9 @@ CalibratedQuantizedType CalibratedQuantizedType::getChecked(
min, max);
}

LogicalResult
CalibratedQuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
Type expressedType, double min, double max) {
LogicalResult CalibratedQuantizedType::verifyInvariants(
function_ref<InFlightDiagnostic()> emitError, Type expressedType,
double min, double max) {
// Verify that the expressed type is floating point.
// If this restriction is ever eliminated, the parser/printer must be
// extended.
Expand Down
9 changes: 4 additions & 5 deletions mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ spirv::InterfaceVarABIAttr::getStorageClass() {
return std::nullopt;
}

LogicalResult spirv::InterfaceVarABIAttr::verify(
LogicalResult spirv::InterfaceVarABIAttr::verifyInvariants(
function_ref<InFlightDiagnostic()> emitError, IntegerAttr descriptorSet,
IntegerAttr binding, IntegerAttr storageClass) {
if (!descriptorSet.getType().isSignlessInteger(32))
Expand Down Expand Up @@ -257,10 +257,9 @@ ArrayAttr spirv::VerCapExtAttr::getCapabilitiesAttr() {
return llvm::cast<ArrayAttr>(getImpl()->capabilities);
}

LogicalResult
spirv::VerCapExtAttr::verify(function_ref<InFlightDiagnostic()> emitError,
IntegerAttr version, ArrayAttr capabilities,
ArrayAttr extensions) {
LogicalResult spirv::VerCapExtAttr::verifyInvariants(
function_ref<InFlightDiagnostic()> emitError, IntegerAttr version,
ArrayAttr capabilities, ArrayAttr extensions) {
if (!version.getType().isSignlessInteger(32))
return emitError() << "expected 32-bit integer for version";

Expand Down
Loading

0 comments on commit 27f3ffa

Please sign in to comment.