diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h index 96e1935bd0a8414..57acd72610415fe 100644 --- a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h +++ b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h @@ -148,9 +148,10 @@ class MMAMatrixType /// Verify that shape and elementType are actually allowed for the /// MMAMatrixType. - static LogicalResult verify(function_ref emitError, - ArrayRef shape, Type elementType, - StringRef operand); + static LogicalResult + verifyInvariants(function_ref emitError, + ArrayRef shape, Type elementType, + StringRef operand); /// Get number of dims. unsigned getNumDims() const; diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h index 1befdfa74f67c53..2ea589a7c4c3bdc 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h @@ -180,11 +180,13 @@ class LLVMStructType ArrayRef getBody() const; /// Verifies that the type about to be constructed is well-formed. - static LogicalResult verify(function_ref emitError, - StringRef, bool); - static LogicalResult verify(function_ref emitError, - ArrayRef types, bool); - using Base::verify; + static LogicalResult + verifyInvariants(function_ref emitError, StringRef, + bool); + static LogicalResult + verifyInvariants(function_ref emitError, + ArrayRef types, bool); + using Base::verifyInvariants; /// Hooks for DataLayoutTypeInterface. Should not be called directly. Obtain a /// DataLayout instance and query it instead. diff --git a/mlir/include/mlir/Dialect/Quant/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/QuantTypes.h index de5aed0a91a2096..57a2aa298336571 100644 --- a/mlir/include/mlir/Dialect/Quant/QuantTypes.h +++ b/mlir/include/mlir/Dialect/Quant/QuantTypes.h @@ -54,10 +54,10 @@ class QuantizedType : public Type { /// The maximum number of bits supported for storage types. static constexpr unsigned MaxStorageBits = 32; - static LogicalResult verify(function_ref emitError, - unsigned flags, Type storageType, - Type expressedType, int64_t storageTypeMin, - int64_t storageTypeMax); + static LogicalResult + verifyInvariants(function_ref emitError, unsigned flags, + Type storageType, Type expressedType, int64_t storageTypeMin, + int64_t storageTypeMax); /// Support method to enable LLVM-style type casting. static bool classof(Type type); @@ -214,10 +214,10 @@ class AnyQuantizedType int64_t storageTypeMax); /// Verifies construction invariants and issues errors/warnings. - static LogicalResult verify(function_ref emitError, - unsigned flags, Type storageType, - Type expressedType, int64_t storageTypeMin, - int64_t storageTypeMax); + static LogicalResult + verifyInvariants(function_ref emitError, unsigned flags, + Type storageType, Type expressedType, int64_t storageTypeMin, + int64_t storageTypeMax); }; /// Represents a family of uniform, quantized types. @@ -276,11 +276,11 @@ class UniformQuantizedType int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax); /// Verifies construction invariants and issues errors/warnings. - static LogicalResult verify(function_ref emitError, - unsigned flags, Type storageType, - Type expressedType, double scale, - int64_t zeroPoint, int64_t storageTypeMin, - int64_t storageTypeMax); + static LogicalResult + verifyInvariants(function_ref emitError, unsigned flags, + Type storageType, Type expressedType, double scale, + int64_t zeroPoint, int64_t storageTypeMin, + int64_t storageTypeMax); /// Gets the scale term. The scale designates the difference between the real /// values corresponding to consecutive quantized values differing by 1. @@ -338,12 +338,12 @@ class UniformQuantizedPerAxisType int64_t storageTypeMin, int64_t storageTypeMax); /// Verifies construction invariants and issues errors/warnings. - static LogicalResult verify(function_ref emitError, - unsigned flags, Type storageType, - Type expressedType, ArrayRef scales, - ArrayRef zeroPoints, - int32_t quantizedDimension, - int64_t storageTypeMin, int64_t storageTypeMax); + static LogicalResult + verifyInvariants(function_ref emitError, unsigned flags, + Type storageType, Type expressedType, + ArrayRef scales, ArrayRef zeroPoints, + int32_t quantizedDimension, int64_t storageTypeMin, + int64_t storageTypeMax); /// Gets the quantization scales. The scales designate the difference between /// the real values corresponding to consecutive quantized values differing @@ -403,8 +403,9 @@ class CalibratedQuantizedType double min, double max); /// Verifies construction invariants and issues errors/warnings. - static LogicalResult verify(function_ref emitError, - Type expressedType, double min, double max); + static LogicalResult + verifyInvariants(function_ref emitError, + Type expressedType, double min, double max); double getMin() const; double getMax() const; }; diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h index 5ebfa9ca5ec25cc..2bdd7a5bf3dd83c 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h @@ -76,9 +76,10 @@ class InterfaceVarABIAttr /// Returns `spirv::StorageClass`. std::optional getStorageClass(); - static LogicalResult verify(function_ref emitError, - IntegerAttr descriptorSet, IntegerAttr binding, - IntegerAttr storageClass); + static LogicalResult + verifyInvariants(function_ref emitError, + IntegerAttr descriptorSet, IntegerAttr binding, + IntegerAttr storageClass); static constexpr StringLiteral name = "spirv.interface_var_abi"; }; @@ -128,9 +129,10 @@ class VerCapExtAttr /// Returns the capabilities as an integer array attribute. ArrayAttr getCapabilitiesAttr(); - static LogicalResult verify(function_ref emitError, - IntegerAttr version, ArrayAttr capabilities, - ArrayAttr extensions); + static LogicalResult + verifyInvariants(function_ref emitError, + IntegerAttr version, ArrayAttr capabilities, + ArrayAttr extensions); static constexpr StringLiteral name = "spirv.ver_cap_ext"; }; diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h index 55f0c787b444037..e2d04553d91b8bf 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h @@ -258,8 +258,9 @@ class SampledImageType static SampledImageType getChecked(function_ref emitError, Type imageType); - static LogicalResult verify(function_ref emitError, - Type imageType); + static LogicalResult + verifyInvariants(function_ref emitError, + Type imageType); Type getImageType() const; @@ -462,8 +463,9 @@ class MatrixType : public Type::TypeBase emitError, Type columnType, uint32_t columnCount); - static LogicalResult verify(function_ref emitError, - Type columnType, uint32_t columnCount); + static LogicalResult + verifyInvariants(function_ref emitError, + Type columnType, uint32_t columnCount); /// Returns true if the matrix elements are vectors of float elements. static bool isValidColumnType(Type columnType); diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td index 5b6ec167fa2420f..4d3e1428e6c40b9 100644 --- a/mlir/include/mlir/IR/CommonTypeConstraints.td +++ b/mlir/include/mlir/IR/CommonTypeConstraints.td @@ -180,6 +180,7 @@ class AnyTypeOf allowedTypeList, string summary = "", summary), cppClassName> { list allowedTypes = allowedTypeList; + string cppType = cppClassName; } // A type that satisfies the constraints of all given types. diff --git a/mlir/include/mlir/IR/Constraints.td b/mlir/include/mlir/IR/Constraints.td index a026d58ccffb8ec..242c850f38f309a 100644 --- a/mlir/include/mlir/IR/Constraints.td +++ b/mlir/include/mlir/IR/Constraints.td @@ -153,6 +153,9 @@ class TypeConstraint { // The name of the C++ Type class if known, or Type if not. string cppClassName = cppClassNameParam; + // TODO: This field is sometimes called `cppClassName` and sometimes + // `cppType`. Use a single name consistently. + string cppType = cppClassNameParam; } // Subclass for constraints on an attribute. diff --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h index fb64f15162df5b0..d6ccbbd85799479 100644 --- a/mlir/include/mlir/IR/StorageUniquerSupport.h +++ b/mlir/include/mlir/IR/StorageUniquerSupport.h @@ -176,8 +176,8 @@ class StorageUserBase : public BaseT, public Traits... { template static ConcreteT get(MLIRContext *ctx, Args &&...args) { // Ensure that the invariants are correct for construction. - assert( - succeeded(ConcreteT::verify(getDefaultDiagnosticEmitFn(ctx), args...))); + assert(succeeded( + ConcreteT::verifyInvariants(getDefaultDiagnosticEmitFn(ctx), args...))); return UniquerT::template get(ctx, std::forward(args)...); } @@ -198,7 +198,7 @@ class StorageUserBase : public BaseT, public Traits... { static ConcreteT getChecked(function_ref emitErrorFn, MLIRContext *ctx, Args... args) { // If the construction invariants fail then we return a null attribute. - if (failed(ConcreteT::verify(emitErrorFn, args...))) + if (failed(ConcreteT::verifyInvariants(emitErrorFn, args...))) return ConcreteT(); return UniquerT::template get(ctx, args...); } @@ -226,7 +226,9 @@ class StorageUserBase : public BaseT, public Traits... { /// Default implementation that just returns success. template - static LogicalResult verify(Args... args) { + static LogicalResult + verifyInvariants(function_ref emitErrorFn, + Args... args) { return success(); } diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h index 60dc8fee0f4a96d..91b457deeba2f68 100644 --- a/mlir/include/mlir/IR/Types.h +++ b/mlir/include/mlir/IR/Types.h @@ -34,7 +34,7 @@ class AsmState; /// Derived type classes are expected to implement several required /// implementation hooks: /// * Optional: -/// - static LogicalResult verify( +/// - static LogicalResult verifyInvariants( /// function_ref emitError, /// Args... args) /// * This method is invoked when calling the 'TypeBase::get/getChecked' @@ -97,20 +97,17 @@ class Type { bool operator!() const { return impl == nullptr; } template - [[deprecated("Use mlir::isa() instead")]] - bool isa() const; + [[deprecated("Use mlir::isa() instead")]] bool isa() const; template - [[deprecated("Use mlir::isa_and_nonnull() instead")]] - bool isa_and_nonnull() const; + [[deprecated("Use mlir::isa_and_nonnull() instead")]] bool + isa_and_nonnull() const; template - [[deprecated("Use mlir::dyn_cast() instead")]] - U dyn_cast() const; + [[deprecated("Use mlir::dyn_cast() instead")]] U dyn_cast() const; template - [[deprecated("Use mlir::dyn_cast_or_null() instead")]] - U dyn_cast_or_null() const; + [[deprecated("Use mlir::dyn_cast_or_null() instead")]] U + dyn_cast_or_null() const; template - [[deprecated("Use mlir::cast() instead")]] - U cast() const; + [[deprecated("Use mlir::cast() instead")]] U cast() const; /// Return a unique identifier for the concrete type. This is used to support /// dynamic type casting. diff --git a/mlir/include/mlir/TableGen/AttrOrTypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h index 19c3a9183ec2cfb..c7819ef7c0ffa3e 100644 --- a/mlir/include/mlir/TableGen/AttrOrTypeDef.h +++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h @@ -15,7 +15,9 @@ #define MLIR_TABLEGEN_ATTRORTYPEDEF_H #include "mlir/Support/LLVM.h" +#include "mlir/TableGen/Argument.h" #include "mlir/TableGen/Builder.h" +#include "mlir/TableGen/Predicate.h" #include "mlir/TableGen/Trait.h" namespace llvm { @@ -85,6 +87,8 @@ class AttrOrTypeParameter { /// Get an optional C++ parameter parser. std::optional getParser() const; + std::optional getTypeConstraint() const; + /// Get an optional C++ parameter printer. std::optional getPrinter() const; @@ -198,6 +202,8 @@ class AttrOrTypeDef { /// method. bool genVerifyDecl() const; + bool genVerifyInvariantsImpl() const; + /// Returns the def's extra class declaration code. std::optional getExtraDecls() const; diff --git a/mlir/include/mlir/TableGen/Class.h b/mlir/include/mlir/TableGen/Class.h index 855952d19492db8..f750a34a3b2ba40 100644 --- a/mlir/include/mlir/TableGen/Class.h +++ b/mlir/include/mlir/TableGen/Class.h @@ -67,6 +67,8 @@ class MethodParameter { /// Get the C++ type. StringRef getType() const { return type; } + /// Get the C++ parameter name. + StringRef getName() const { return name; } /// Returns true if the parameter has a default value. bool hasDefaultValue() const { return !defaultValue.empty(); } diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index 7bc2668310ddb0d..a1f87a637a61415 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -148,9 +148,9 @@ bool MMAMatrixType::isValidElementType(Type elementType) { } LogicalResult -MMAMatrixType::verify(function_ref emitError, - ArrayRef shape, Type elementType, - StringRef operand) { +MMAMatrixType::verifyInvariants(function_ref emitError, + ArrayRef shape, Type elementType, + StringRef operand) { if (operand != "AOp" && operand != "BOp" && operand != "COp") return emitError() << "operand expected to be one of AOp, BOp or COp"; diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp index dc7aef8ef7f850f..7f10a15ff31ff94 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -418,8 +418,7 @@ LogicalResult LLVMPointerType::verifyEntries(DataLayoutEntryListRef entries, bool LLVMStructType::isValidElementType(Type type) { return !llvm::isa( - type); + LLVMFunctionType, LLVMTokenType>(type); } LLVMStructType LLVMStructType::getIdentified(MLIRContext *context, @@ -492,14 +491,15 @@ ArrayRef LLVMStructType::getBody() const { : getImpl()->getTypeList(); } -LogicalResult LLVMStructType::verify(function_ref, - StringRef, bool) { +LogicalResult +LLVMStructType::verifyInvariants(function_ref, StringRef, + bool) { return success(); } LogicalResult -LLVMStructType::verify(function_ref emitError, - ArrayRef types, bool) { +LLVMStructType::verifyInvariants(function_ref emitError, + ArrayRef types, bool) { for (Type t : types) if (!isValidElementType(t)) return emitError() << "invalid LLVM structure element type: " << t; diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp index 81e3b914755be2e..c2ba9c04e8771db 100644 --- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp @@ -29,9 +29,10 @@ bool QuantizedType::classof(Type type) { } LogicalResult -QuantizedType::verify(function_ref emitError, - unsigned flags, Type storageType, Type expressedType, - int64_t storageTypeMin, int64_t storageTypeMax) { +QuantizedType::verifyInvariants(function_ref emitError, + unsigned flags, Type storageType, + Type expressedType, int64_t storageTypeMin, + int64_t storageTypeMax) { // Verify that the storage type is integral. // This restriction may be lifted at some point in favor of using bf16 // or f16 as exact representations on hardware where that is advantageous. @@ -233,11 +234,13 @@ AnyQuantizedType::getChecked(function_ref emitError, } LogicalResult -AnyQuantizedType::verify(function_ref emitError, - unsigned flags, Type storageType, Type expressedType, - int64_t storageTypeMin, int64_t storageTypeMax) { - if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType, - storageTypeMin, storageTypeMax))) { +AnyQuantizedType::verifyInvariants(function_ref emitError, + unsigned flags, Type storageType, + Type expressedType, int64_t storageTypeMin, + int64_t storageTypeMax) { + if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType, + expressedType, storageTypeMin, + storageTypeMax))) { return failure(); } @@ -268,12 +271,13 @@ UniformQuantizedType UniformQuantizedType::getChecked( storageTypeMin, storageTypeMax); } -LogicalResult UniformQuantizedType::verify( +LogicalResult UniformQuantizedType::verifyInvariants( function_ref emitError, unsigned flags, Type storageType, Type expressedType, double scale, int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax) { - if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType, - storageTypeMin, storageTypeMax))) { + if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType, + expressedType, storageTypeMin, + storageTypeMax))) { return failure(); } @@ -321,13 +325,14 @@ UniformQuantizedPerAxisType UniformQuantizedPerAxisType::getChecked( quantizedDimension, storageTypeMin, storageTypeMax); } -LogicalResult UniformQuantizedPerAxisType::verify( +LogicalResult UniformQuantizedPerAxisType::verifyInvariants( function_ref emitError, unsigned flags, Type storageType, Type expressedType, ArrayRef scales, ArrayRef zeroPoints, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax) { - if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType, - storageTypeMin, storageTypeMax))) { + if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType, + expressedType, storageTypeMin, + storageTypeMax))) { return failure(); } @@ -380,9 +385,9 @@ CalibratedQuantizedType CalibratedQuantizedType::getChecked( min, max); } -LogicalResult -CalibratedQuantizedType::verify(function_ref emitError, - Type expressedType, double min, double max) { +LogicalResult CalibratedQuantizedType::verifyInvariants( + function_ref emitError, Type expressedType, + double min, double max) { // Verify that the expressed type is floating point. // If this restriction is ever eliminated, the parser/printer must be // extended. diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp index 8a0ee7a3d813675..b71be23fdf47d0e 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp @@ -162,7 +162,7 @@ spirv::InterfaceVarABIAttr::getStorageClass() { return std::nullopt; } -LogicalResult spirv::InterfaceVarABIAttr::verify( +LogicalResult spirv::InterfaceVarABIAttr::verifyInvariants( function_ref emitError, IntegerAttr descriptorSet, IntegerAttr binding, IntegerAttr storageClass) { if (!descriptorSet.getType().isSignlessInteger(32)) @@ -257,10 +257,9 @@ ArrayAttr spirv::VerCapExtAttr::getCapabilitiesAttr() { return llvm::cast(getImpl()->capabilities); } -LogicalResult -spirv::VerCapExtAttr::verify(function_ref emitError, - IntegerAttr version, ArrayAttr capabilities, - ArrayAttr extensions) { +LogicalResult spirv::VerCapExtAttr::verifyInvariants( + function_ref emitError, IntegerAttr version, + ArrayAttr capabilities, ArrayAttr extensions) { if (!version.getType().isSignlessInteger(32)) return emitError() << "expected 32-bit integer for version"; diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp index 3808620bdffa6d6..c5590905b750453 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp @@ -817,8 +817,8 @@ SampledImageType::getChecked(function_ref emitError, Type SampledImageType::getImageType() const { return getImpl()->imageType; } LogicalResult -SampledImageType::verify(function_ref emitError, - Type imageType) { +SampledImageType::verifyInvariants(function_ref emitError, + Type imageType) { if (!llvm::isa(imageType)) return emitError() << "expected image type"; @@ -1181,8 +1181,9 @@ MatrixType MatrixType::getChecked(function_ref emitError, columnCount); } -LogicalResult MatrixType::verify(function_ref emitError, - Type columnType, uint32_t columnCount) { +LogicalResult +MatrixType::verifyInvariants(function_ref emitError, + Type columnType, uint32_t columnCount) { if (columnCount < 2 || columnCount > 4) return emitError() << "matrix can have 2, 3, or 4 columns only"; diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp index c9dbb3bc76b1fa0..ed727a834e34d01 100644 --- a/mlir/lib/TableGen/AttrOrTypeDef.cpp +++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp @@ -184,6 +184,12 @@ bool AttrOrTypeDef::genVerifyDecl() const { return def->getValueAsBit("genVerifyDecl"); } +bool AttrOrTypeDef::genVerifyInvariantsImpl() const { + return any_of(parameters, [](const AttrOrTypeParameter &p) { + return p.getTypeConstraint() != std::nullopt; + }); +} + std::optional AttrOrTypeDef::getExtraDecls() const { auto value = def->getValueAsString("extraClassDeclaration"); return value.empty() ? std::optional() : value; @@ -331,6 +337,13 @@ std::optional AttrOrTypeParameter::getDefaultValue() const { llvm::Init *AttrOrTypeParameter::getDef() const { return def->getArg(index); } +std::optional AttrOrTypeParameter::getTypeConstraint() const { + if (auto *param = dyn_cast(getDef())) + if (param->getDef()->isSubClassOf("TypeConstraint")) + return TypeConstraint(param); + return std::nullopt; +} + //===----------------------------------------------------------------------===// // AttributeSelfTypeParameter //===----------------------------------------------------------------------===// diff --git a/mlir/test/IR/test-verifiers-type.mlir b/mlir/test/IR/test-verifiers-type.mlir new file mode 100644 index 000000000000000..96d0005eb7a19d8 --- /dev/null +++ b/mlir/test/IR/test-verifiers-type.mlir @@ -0,0 +1,9 @@ +// RUN: mlir-opt %s -verify-diagnostics -split-input-file | FileCheck %s + +// CHECK: "test.type_producer"() : () -> !test.type_verification +"test.type_producer"() : () -> !test.type_verification + +// ----- + +// expected-error @below{{failed to verify 'param': 16-bit signless integer or 32-bit signless integer}} +"test.type_producer"() : () -> !test.type_verification diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td index 8f109f8ce5e6ddf..b3b94bd0ffea31a 100644 --- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td +++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td @@ -270,7 +270,6 @@ def TestOverrideBuilderAttr : Test_Attr<"TestOverrideBuilder"> { let assemblyFormat = "`<` $a `>`"; let skipDefaultBuilders = 1; - let genVerifyDecl = 1; let builders = [AttrBuilder<(ins "int":$a), [{ return ::mlir::IntegerAttr::get(::mlir::IndexType::get($_ctxt), a); }], "::mlir::Attribute">]; diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td index d96152a0826f960..830475bed4e4440 100644 --- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td +++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td @@ -392,4 +392,10 @@ def TestRecursiveAlias }]; } +def TestTypeVerification : Test_Type<"TestTypeVerification"> { + let parameters = (ins AnyTypeOf<[I16, I32]>:$param); + let mnemonic = "type_verification"; + let assemblyFormat = "`<` $param `>`"; +} + #endif // TEST_TYPEDEFS diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp index 8cc8314418104c0..03fdefde766cc49 100644 --- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp @@ -93,8 +93,14 @@ class DefGen { void emitDialectName(); /// Emit attribute or type builders. void emitBuilders(); - /// Emit a verifier for the def. - void emitVerifier(); + /// Emit a verifier declaration for custom verification (impl. provided by + /// the users). + void emitVerifierDecl(); + /// Emit a verifier that checks type constraints. + void emitInvariantsVerifierImpl(); + /// Emit an entry poiunt for verification that calls the invariants and + /// custom verifier. + void emitInvariantsVerifier(bool hasImpl, bool hasCustomVerifier); /// Emit parsers and printers. void emitParserPrinter(); /// Emit parameter accessors, if required. @@ -169,9 +175,10 @@ DefGen::DefGen(const AttrOrTypeDef &def) valueType(isa(def) ? "Attribute" : "Type"), defType(isa(def) ? "Attr" : "Type") { // Check that all parameters have names. - for (const AttrOrTypeParameter ¶m : def.getParameters()) + for (const AttrOrTypeParameter ¶m : def.getParameters()) { if (param.isAnonymous()) llvm::PrintFatalError("all parameters must have a name"); + } // If a storage class is needed, create one. if (def.getNumParameters() > 0) @@ -188,9 +195,17 @@ DefGen::DefGen(const AttrOrTypeDef &def) emitName(); // Emit the dialect name. emitDialectName(); - // Emit the verifier. - if (storageCls && def.genVerifyDecl()) - emitVerifier(); + // Emit verification of type constraints. + bool genVerifyInvariantsImpl = def.genVerifyInvariantsImpl(); + if (storageCls && genVerifyInvariantsImpl) + emitInvariantsVerifierImpl(); + // Emit the custom verifier (written by the user). + bool genVerifyDecl = def.genVerifyDecl(); + if (storageCls && genVerifyDecl) + emitVerifierDecl(); + // Emit the "verifyInvariants" function if there is any verification at all. + if (storageCls) + emitInvariantsVerifier(genVerifyInvariantsImpl, genVerifyDecl); // Emit the mnemonic, if there is one, and any associated parser and printer. if (def.getMnemonic()) emitParserPrinter(); @@ -295,24 +310,91 @@ void DefGen::emitDialectName() { void DefGen::emitBuilders() { if (!def.skipDefaultBuilders()) { emitDefaultBuilder(); - if (def.genVerifyDecl()) + if (def.genVerifyDecl() || def.genVerifyInvariantsImpl()) emitCheckedBuilder(); } for (auto &builder : def.getBuilders()) { emitCustomBuilder(builder); - if (def.genVerifyDecl()) + if (def.genVerifyDecl() || def.genVerifyInvariantsImpl()) emitCheckedCustomBuilder(builder); } } -void DefGen::emitVerifier() { - defCls.declare("Base::getChecked"); +void DefGen::emitVerifierDecl() { defCls.declareStaticMethod( "::llvm::LogicalResult", "verify", getBuilderParams({{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"}})); } +static const char *const patternParameterVerificationCode = R"( +if (!({0})) { + emitError() << "failed to verify '{1}': {2}"; + return ::mlir::failure(); +} +)"; + +void DefGen::emitInvariantsVerifierImpl() { + SmallVector builderParams = getBuilderParams( + {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"}}); + Method *verifier = + defCls.addMethod("::llvm::LogicalResult", "verifyInvariantsImpl", + Method::Static, builderParams); + verifier->body().indent(); + + // Generate verification for each parameter that is a type constraint. + for (auto it : llvm::enumerate(def.getParameters())) { + const AttrOrTypeParameter ¶m = it.value(); + std::optional constraint = param.getTypeConstraint(); + // No verification needed for parameters that are not type constraints. + if (!constraint.has_value()) + continue; + FmtContext ctx; + // Note: Skip over the first method parameter (`emitError`). + ctx.withSelf(builderParams[it.index() + 1].getName()); + std::string condition = tgfmt(constraint->getConditionTemplate(), &ctx); + verifier->body() << formatv(patternParameterVerificationCode, condition, + param.getName(), constraint->getSummary()) + << "\n"; + } + verifier->body() << "return ::mlir::success();"; +} + +void DefGen::emitInvariantsVerifier(bool hasImpl, bool hasCustomVerifier) { + if (!hasImpl && !hasCustomVerifier) + return; + defCls.declare("Base::getChecked"); + SmallVector builderParams = getBuilderParams( + {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"}}); + Method *verifier = + defCls.addMethod("::llvm::LogicalResult", "verifyInvariants", + Method::Static, builderParams); + verifier->body().indent(); + if (hasImpl) { + // Call the verifier that checks the type constraints. + verifier->body() << "if (::mlir::failed(verifyInvariantsImpl("; + for (int i = 0, e = builderParams.size(); i < e; ++i) { + if (i > 0) + verifier->body() << ", "; + verifier->body() << builderParams[i].getName(); + } + verifier->body() << ")))\n"; + verifier->body() << " return ::mlir::failure();\n"; + } + if (hasCustomVerifier) { + // Call the custom verifier that is provided by the user. + verifier->body() << "if (::mlir::failed(verify("; + for (int i = 0, e = builderParams.size(); i < e; ++i) { + if (i > 0) + verifier->body() << ", "; + verifier->body() << builderParams[i].getName(); + } + verifier->body() << ")))\n"; + verifier->body() << " return ::mlir::failure();\n"; + } + verifier->body() << "return ::mlir::success();"; +} + void DefGen::emitParserPrinter() { auto *mnemonic = defCls.addStaticMethod( "::llvm::StringLiteral", "getMnemonic"); diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp index dacc20b6ba20866..a4ae271edb6bd24 100644 --- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp @@ -323,7 +323,7 @@ void DefFormat::genParser(MethodBody &os) { // Generate call to the attribute or type builder. Use the checked getter // if one was generated. - if (def.genVerifyDecl()) { + if (def.genVerifyDecl() || def.genVerifyInvariantsImpl()) { os << tgfmt("return $_parser.getChecked<$0>($_loc, $_parser.getContext()", &ctx, def.getCppClassName()); } else {