Skip to content

Commit

Permalink
[mlir][ODS] Verify type constraints in Types and Attributes
Browse files Browse the repository at this point in the history
When a type/attribute is defined in TableGen, a type constraint can be used for parameters, but the type constraint verification was missing.

Example:
```
def TestTypeVerification : Test_Type<"TestTypeVerification"> {
  let parameters = (ins AnyTypeOf<[I16, I32]>:$param);
  // ...
}
```

No verification code was generated to ensure that `$param` is I16 or I32.

When type constraints a present, a new method will generated for types and attributes: `verifyInvariantsImpl`. (The naming is similar to op verifiers.) The user-provided verifier is called `verify` (no change). There is now a new entry point to type/attribute verification: `verifyInvariants`. This function calls both `verifyInvariantsImpl` and `verify`. If neither of those two verifications are present, the `verifyInvariants` function is not generated.

When a type/attribute is not defined in TableGen, but a verifier is needed, users can implement the `verifyInvariants` function. (This function was previously called `verify`.)
  • Loading branch information
matthias-springer committed Aug 7, 2024
1 parent 96d824d commit 8df1c42
Show file tree
Hide file tree
Showing 22 changed files with 234 additions and 101 deletions.
7 changes: 4 additions & 3 deletions mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,10 @@ class MMAMatrixType

/// Verify that shape and elementType are actually allowed for the
/// MMAMatrixType.
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType,
StringRef operand);
static LogicalResult
verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType,
StringRef operand);

/// Get number of dims.
unsigned getNumDims() const;
Expand Down
12 changes: 7 additions & 5 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,11 +180,13 @@ class LLVMStructType
ArrayRef<Type> getBody() const;

/// Verifies that the type about to be constructed is well-formed.
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
StringRef, bool);
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<Type> types, bool);
using Base::verify;
static LogicalResult
verifyInvariants(function_ref<InFlightDiagnostic()> emitError, StringRef,
bool);
static LogicalResult
verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<Type> types, bool);
using Base::verifyInvariants;

/// Hooks for DataLayoutTypeInterface. Should not be called directly. Obtain a
/// DataLayout instance and query it instead.
Expand Down
43 changes: 22 additions & 21 deletions mlir/include/mlir/Dialect/Quant/QuantTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@ class QuantizedType : public Type {
/// The maximum number of bits supported for storage types.
static constexpr unsigned MaxStorageBits = 32;

static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
unsigned flags, Type storageType,
Type expressedType, int64_t storageTypeMin,
int64_t storageTypeMax);
static LogicalResult
verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
Type storageType, Type expressedType, int64_t storageTypeMin,
int64_t storageTypeMax);

/// Support method to enable LLVM-style type casting.
static bool classof(Type type);
Expand Down Expand Up @@ -214,10 +214,10 @@ class AnyQuantizedType
int64_t storageTypeMax);

/// Verifies construction invariants and issues errors/warnings.
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
unsigned flags, Type storageType,
Type expressedType, int64_t storageTypeMin,
int64_t storageTypeMax);
static LogicalResult
verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
Type storageType, Type expressedType, int64_t storageTypeMin,
int64_t storageTypeMax);
};

/// Represents a family of uniform, quantized types.
Expand Down Expand Up @@ -276,11 +276,11 @@ class UniformQuantizedType
int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax);

/// Verifies construction invariants and issues errors/warnings.
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
unsigned flags, Type storageType,
Type expressedType, double scale,
int64_t zeroPoint, int64_t storageTypeMin,
int64_t storageTypeMax);
static LogicalResult
verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
Type storageType, Type expressedType, double scale,
int64_t zeroPoint, int64_t storageTypeMin,
int64_t storageTypeMax);

/// Gets the scale term. The scale designates the difference between the real
/// values corresponding to consecutive quantized values differing by 1.
Expand Down Expand Up @@ -338,12 +338,12 @@ class UniformQuantizedPerAxisType
int64_t storageTypeMin, int64_t storageTypeMax);

/// Verifies construction invariants and issues errors/warnings.
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
unsigned flags, Type storageType,
Type expressedType, ArrayRef<double> scales,
ArrayRef<int64_t> zeroPoints,
int32_t quantizedDimension,
int64_t storageTypeMin, int64_t storageTypeMax);
static LogicalResult
verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
Type storageType, Type expressedType,
ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
int32_t quantizedDimension, int64_t storageTypeMin,
int64_t storageTypeMax);

/// Gets the quantization scales. The scales designate the difference between
/// the real values corresponding to consecutive quantized values differing
Expand Down Expand Up @@ -403,8 +403,9 @@ class CalibratedQuantizedType
double min, double max);

/// Verifies construction invariants and issues errors/warnings.
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
Type expressedType, double min, double max);
static LogicalResult
verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
Type expressedType, double min, double max);
double getMin() const;
double getMax() const;
};
Expand Down
14 changes: 8 additions & 6 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,10 @@ class InterfaceVarABIAttr
/// Returns `spirv::StorageClass`.
std::optional<StorageClass> getStorageClass();

static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
IntegerAttr descriptorSet, IntegerAttr binding,
IntegerAttr storageClass);
static LogicalResult
verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
IntegerAttr descriptorSet, IntegerAttr binding,
IntegerAttr storageClass);

static constexpr StringLiteral name = "spirv.interface_var_abi";
};
Expand Down Expand Up @@ -128,9 +129,10 @@ class VerCapExtAttr
/// Returns the capabilities as an integer array attribute.
ArrayAttr getCapabilitiesAttr();

static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
IntegerAttr version, ArrayAttr capabilities,
ArrayAttr extensions);
static LogicalResult
verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
IntegerAttr version, ArrayAttr capabilities,
ArrayAttr extensions);

static constexpr StringLiteral name = "spirv.ver_cap_ext";
};
Expand Down
10 changes: 6 additions & 4 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -258,8 +258,9 @@ class SampledImageType
static SampledImageType
getChecked(function_ref<InFlightDiagnostic()> emitError, Type imageType);

static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
Type imageType);
static LogicalResult
verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
Type imageType);

Type getImageType() const;

Expand Down Expand Up @@ -462,8 +463,9 @@ class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
static MatrixType getChecked(function_ref<InFlightDiagnostic()> emitError,
Type columnType, uint32_t columnCount);

static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
Type columnType, uint32_t columnCount);
static LogicalResult
verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
Type columnType, uint32_t columnCount);

/// Returns true if the matrix elements are vectors of float elements.
static bool isValidColumnType(Type columnType);
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/IR/CommonTypeConstraints.td
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ class AnyTypeOf<list<Type> allowedTypeList, string summary = "",
summary),
cppClassName> {
list<Type> allowedTypes = allowedTypeList;
string cppType = cppClassName;
}

// A type that satisfies the constraints of all given types.
Expand Down
3 changes: 3 additions & 0 deletions mlir/include/mlir/IR/Constraints.td
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,9 @@ class TypeConstraint<Pred predicate, string summary = "",
Constraint<predicate, summary> {
// The name of the C++ Type class if known, or Type if not.
string cppClassName = cppClassNameParam;
// TODO: This field is sometimes called `cppClassName` and sometimes
// `cppType`. Use a single name consistently.
string cppType = cppClassNameParam;
}

// Subclass for constraints on an attribute.
Expand Down
10 changes: 6 additions & 4 deletions mlir/include/mlir/IR/StorageUniquerSupport.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,8 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
template <typename... Args>
static ConcreteT get(MLIRContext *ctx, Args &&...args) {
// Ensure that the invariants are correct for construction.
assert(
succeeded(ConcreteT::verify(getDefaultDiagnosticEmitFn(ctx), args...)));
assert(succeeded(
ConcreteT::verifyInvariants(getDefaultDiagnosticEmitFn(ctx), args...)));
return UniquerT::template get<ConcreteT>(ctx, std::forward<Args>(args)...);
}

Expand All @@ -198,7 +198,7 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
static ConcreteT getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
MLIRContext *ctx, Args... args) {
// If the construction invariants fail then we return a null attribute.
if (failed(ConcreteT::verify(emitErrorFn, args...)))
if (failed(ConcreteT::verifyInvariants(emitErrorFn, args...)))
return ConcreteT();
return UniquerT::template get<ConcreteT>(ctx, args...);
}
Expand Down Expand Up @@ -226,7 +226,9 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {

/// Default implementation that just returns success.
template <typename... Args>
static LogicalResult verify(Args... args) {
static LogicalResult
verifyInvariants(function_ref<InFlightDiagnostic()> emitErrorFn,
Args... args) {
return success();
}

Expand Down
19 changes: 8 additions & 11 deletions mlir/include/mlir/IR/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class AsmState;
/// Derived type classes are expected to implement several required
/// implementation hooks:
/// * Optional:
/// - static LogicalResult verify(
/// - static LogicalResult verifyInvariants(
/// function_ref<InFlightDiagnostic()> emitError,
/// Args... args)
/// * This method is invoked when calling the 'TypeBase::get/getChecked'
Expand Down Expand Up @@ -97,20 +97,17 @@ class Type {
bool operator!() const { return impl == nullptr; }

template <typename... Tys>
[[deprecated("Use mlir::isa<U>() instead")]]
bool isa() const;
[[deprecated("Use mlir::isa<U>() instead")]] bool isa() const;
template <typename... Tys>
[[deprecated("Use mlir::isa_and_nonnull<U>() instead")]]
bool isa_and_nonnull() const;
[[deprecated("Use mlir::isa_and_nonnull<U>() instead")]] bool
isa_and_nonnull() const;
template <typename U>
[[deprecated("Use mlir::dyn_cast<U>() instead")]]
U dyn_cast() const;
[[deprecated("Use mlir::dyn_cast<U>() instead")]] U dyn_cast() const;
template <typename U>
[[deprecated("Use mlir::dyn_cast_or_null<U>() instead")]]
U dyn_cast_or_null() const;
[[deprecated("Use mlir::dyn_cast_or_null<U>() instead")]] U
dyn_cast_or_null() const;
template <typename U>
[[deprecated("Use mlir::cast<U>() instead")]]
U cast() const;
[[deprecated("Use mlir::cast<U>() instead")]] U cast() const;

/// Return a unique identifier for the concrete type. This is used to support
/// dynamic type casting.
Expand Down
6 changes: 6 additions & 0 deletions mlir/include/mlir/TableGen/AttrOrTypeDef.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
#define MLIR_TABLEGEN_ATTRORTYPEDEF_H

#include "mlir/Support/LLVM.h"
#include "mlir/TableGen/Argument.h"
#include "mlir/TableGen/Builder.h"
#include "mlir/TableGen/Predicate.h"
#include "mlir/TableGen/Trait.h"

namespace llvm {
Expand Down Expand Up @@ -85,6 +87,8 @@ class AttrOrTypeParameter {
/// Get an optional C++ parameter parser.
std::optional<StringRef> getParser() const;

std::optional<TypeConstraint> getTypeConstraint() const;

/// Get an optional C++ parameter printer.
std::optional<StringRef> getPrinter() const;

Expand Down Expand Up @@ -198,6 +202,8 @@ class AttrOrTypeDef {
/// method.
bool genVerifyDecl() const;

bool genVerifyInvariantsImpl() const;

/// Returns the def's extra class declaration code.
std::optional<StringRef> getExtraDecls() const;

Expand Down
2 changes: 2 additions & 0 deletions mlir/include/mlir/TableGen/Class.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ class MethodParameter {

/// Get the C++ type.
StringRef getType() const { return type; }
/// Get the C++ parameter name.
StringRef getName() const { return name; }
/// Returns true if the parameter has a default value.
bool hasDefaultValue() const { return !defaultValue.empty(); }

Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,9 @@ bool MMAMatrixType::isValidElementType(Type elementType) {
}

LogicalResult
MMAMatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType,
StringRef operand) {
MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType,
StringRef operand) {
if (operand != "AOp" && operand != "BOp" && operand != "COp")
return emitError() << "operand expected to be one of AOp, BOp or COp";

Expand Down
12 changes: 6 additions & 6 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -418,8 +418,7 @@ LogicalResult LLVMPointerType::verifyEntries(DataLayoutEntryListRef entries,

bool LLVMStructType::isValidElementType(Type type) {
return !llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
LLVMFunctionType, LLVMTokenType, LLVMScalableVectorType>(
type);
LLVMFunctionType, LLVMTokenType>(type);
}

LLVMStructType LLVMStructType::getIdentified(MLIRContext *context,
Expand Down Expand Up @@ -492,14 +491,15 @@ ArrayRef<Type> LLVMStructType::getBody() const {
: getImpl()->getTypeList();
}

LogicalResult LLVMStructType::verify(function_ref<InFlightDiagnostic()>,
StringRef, bool) {
LogicalResult
LLVMStructType::verifyInvariants(function_ref<InFlightDiagnostic()>, StringRef,
bool) {
return success();
}

LogicalResult
LLVMStructType::verify(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<Type> types, bool) {
LLVMStructType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<Type> types, bool) {
for (Type t : types)
if (!isValidElementType(t))
return emitError() << "invalid LLVM structure element type: " << t;
Expand Down
Loading

0 comments on commit 8df1c42

Please sign in to comment.