Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][ODS] Verify type constraints in Types and Attributes #102326

Merged
merged 1 commit into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,10 @@ class MMAMatrixType

/// Verify that shape and elementType are actually allowed for the
/// MMAMatrixType.
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType,
StringRef operand);
static LogicalResult
verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType,
StringRef operand);

/// Get number of dims.
unsigned getNumDims() const;
Expand Down
12 changes: 7 additions & 5 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,11 +180,13 @@ class LLVMStructType
ArrayRef<Type> getBody() const;

/// Verifies that the type about to be constructed is well-formed.
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
StringRef, bool);
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<Type> types, bool);
using Base::verify;
static LogicalResult
verifyInvariants(function_ref<InFlightDiagnostic()> emitError, StringRef,
bool);
static LogicalResult
verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<Type> types, bool);
using Base::verifyInvariants;

/// Hooks for DataLayoutTypeInterface. Should not be called directly. Obtain a
/// DataLayout instance and query it instead.
Expand Down
43 changes: 22 additions & 21 deletions mlir/include/mlir/Dialect/Quant/QuantTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@ class QuantizedType : public Type {
/// The maximum number of bits supported for storage types.
static constexpr unsigned MaxStorageBits = 32;

static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
unsigned flags, Type storageType,
Type expressedType, int64_t storageTypeMin,
int64_t storageTypeMax);
static LogicalResult
verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
Type storageType, Type expressedType, int64_t storageTypeMin,
int64_t storageTypeMax);

/// Support method to enable LLVM-style type casting.
static bool classof(Type type);
Expand Down Expand Up @@ -214,10 +214,10 @@ class AnyQuantizedType
int64_t storageTypeMax);

/// Verifies construction invariants and issues errors/warnings.
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
unsigned flags, Type storageType,
Type expressedType, int64_t storageTypeMin,
int64_t storageTypeMax);
static LogicalResult
verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
Type storageType, Type expressedType, int64_t storageTypeMin,
int64_t storageTypeMax);
};

/// Represents a family of uniform, quantized types.
Expand Down Expand Up @@ -276,11 +276,11 @@ class UniformQuantizedType
int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax);

/// Verifies construction invariants and issues errors/warnings.
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
unsigned flags, Type storageType,
Type expressedType, double scale,
int64_t zeroPoint, int64_t storageTypeMin,
int64_t storageTypeMax);
static LogicalResult
verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
Type storageType, Type expressedType, double scale,
int64_t zeroPoint, int64_t storageTypeMin,
int64_t storageTypeMax);

/// Gets the scale term. The scale designates the difference between the real
/// values corresponding to consecutive quantized values differing by 1.
Expand Down Expand Up @@ -338,12 +338,12 @@ class UniformQuantizedPerAxisType
int64_t storageTypeMin, int64_t storageTypeMax);

/// Verifies construction invariants and issues errors/warnings.
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
unsigned flags, Type storageType,
Type expressedType, ArrayRef<double> scales,
ArrayRef<int64_t> zeroPoints,
int32_t quantizedDimension,
int64_t storageTypeMin, int64_t storageTypeMax);
static LogicalResult
verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
Type storageType, Type expressedType,
ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
int32_t quantizedDimension, int64_t storageTypeMin,
int64_t storageTypeMax);

/// Gets the quantization scales. The scales designate the difference between
/// the real values corresponding to consecutive quantized values differing
Expand Down Expand Up @@ -403,8 +403,9 @@ class CalibratedQuantizedType
double min, double max);

/// Verifies construction invariants and issues errors/warnings.
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
Type expressedType, double min, double max);
static LogicalResult
verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
Type expressedType, double min, double max);
double getMin() const;
double getMax() const;
};
Expand Down
14 changes: 8 additions & 6 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,10 @@ class InterfaceVarABIAttr
/// Returns `spirv::StorageClass`.
std::optional<StorageClass> getStorageClass();

static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
IntegerAttr descriptorSet, IntegerAttr binding,
IntegerAttr storageClass);
static LogicalResult
verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
IntegerAttr descriptorSet, IntegerAttr binding,
IntegerAttr storageClass);

static constexpr StringLiteral name = "spirv.interface_var_abi";
};
Expand Down Expand Up @@ -128,9 +129,10 @@ class VerCapExtAttr
/// Returns the capabilities as an integer array attribute.
ArrayAttr getCapabilitiesAttr();

static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
IntegerAttr version, ArrayAttr capabilities,
ArrayAttr extensions);
static LogicalResult
verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
IntegerAttr version, ArrayAttr capabilities,
ArrayAttr extensions);

static constexpr StringLiteral name = "spirv.ver_cap_ext";
};
Expand Down
10 changes: 6 additions & 4 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -258,8 +258,9 @@ class SampledImageType
static SampledImageType
getChecked(function_ref<InFlightDiagnostic()> emitError, Type imageType);

static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
Type imageType);
static LogicalResult
verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
Type imageType);

Type getImageType() const;

Expand Down Expand Up @@ -462,8 +463,9 @@ class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
static MatrixType getChecked(function_ref<InFlightDiagnostic()> emitError,
Type columnType, uint32_t columnCount);

static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
Type columnType, uint32_t columnCount);
static LogicalResult
verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
Type columnType, uint32_t columnCount);

/// Returns true if the matrix elements are vectors of float elements.
static bool isValidColumnType(Type columnType);
Expand Down
10 changes: 6 additions & 4 deletions mlir/include/mlir/IR/StorageUniquerSupport.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,8 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
template <typename... Args>
static ConcreteT get(MLIRContext *ctx, Args &&...args) {
// Ensure that the invariants are correct for construction.
assert(
succeeded(ConcreteT::verify(getDefaultDiagnosticEmitFn(ctx), args...)));
assert(succeeded(
ConcreteT::verifyInvariants(getDefaultDiagnosticEmitFn(ctx), args...)));
return UniquerT::template get<ConcreteT>(ctx, std::forward<Args>(args)...);
}

Expand All @@ -198,7 +198,7 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
static ConcreteT getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
MLIRContext *ctx, Args... args) {
// If the construction invariants fail then we return a null attribute.
if (failed(ConcreteT::verify(emitErrorFn, args...)))
if (failed(ConcreteT::verifyInvariants(emitErrorFn, args...)))
return ConcreteT();
return UniquerT::template get<ConcreteT>(ctx, args...);
}
Expand Down Expand Up @@ -226,7 +226,9 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {

/// Default implementation that just returns success.
template <typename... Args>
static LogicalResult verify(Args... args) {
static LogicalResult
verifyInvariants(function_ref<InFlightDiagnostic()> emitErrorFn,
Args... args) {
return success();
}

Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/IR/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class AsmState;
/// Derived type classes are expected to implement several required
/// implementation hooks:
/// * Optional:
/// - static LogicalResult verify(
/// - static LogicalResult verifyInvariants(
/// function_ref<InFlightDiagnostic()> emitError,
/// Args... args)
/// * This method is invoked when calling the 'TypeBase::get/getChecked'
Expand Down
8 changes: 8 additions & 0 deletions mlir/include/mlir/TableGen/AttrOrTypeDef.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "mlir/Support/LLVM.h"
#include "mlir/TableGen/Builder.h"
#include "mlir/TableGen/Constraint.h"
#include "mlir/TableGen/Trait.h"

namespace llvm {
Expand Down Expand Up @@ -85,6 +86,9 @@ class AttrOrTypeParameter {
/// Get an optional C++ parameter parser.
std::optional<StringRef> getParser() const;

/// If this is a type constraint, return it.
std::optional<Constraint> getConstraint() const;

/// Get an optional C++ parameter printer.
std::optional<StringRef> getPrinter() const;

Expand Down Expand Up @@ -198,6 +202,10 @@ class AttrOrTypeDef {
/// method.
bool genVerifyDecl() const;

/// Return true if we need to generate any type constraint verification and
/// the getChecked method.
bool genVerifyInvariantsImpl() const;

/// Returns the def's extra class declaration code.
std::optional<StringRef> getExtraDecls() const;

Expand Down
2 changes: 2 additions & 0 deletions mlir/include/mlir/TableGen/Class.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ class MethodParameter {

/// Get the C++ type.
StringRef getType() const { return type; }
/// Get the C++ parameter name.
StringRef getName() const { return name; }
/// Returns true if the parameter has a default value.
bool hasDefaultValue() const { return !defaultValue.empty(); }

Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,9 @@ bool MMAMatrixType::isValidElementType(Type elementType) {
}

LogicalResult
MMAMatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType,
StringRef operand) {
MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType,
StringRef operand) {
if (operand != "AOp" && operand != "BOp" && operand != "COp")
return emitError() << "operand expected to be one of AOp, BOp or COp";

Expand Down
12 changes: 6 additions & 6 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -418,8 +418,7 @@ LogicalResult LLVMPointerType::verifyEntries(DataLayoutEntryListRef entries,

bool LLVMStructType::isValidElementType(Type type) {
return !llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
LLVMFunctionType, LLVMTokenType, LLVMScalableVectorType>(
type);
LLVMFunctionType, LLVMTokenType>(type);
}

LLVMStructType LLVMStructType::getIdentified(MLIRContext *context,
Expand Down Expand Up @@ -492,14 +491,15 @@ ArrayRef<Type> LLVMStructType::getBody() const {
: getImpl()->getTypeList();
}

LogicalResult LLVMStructType::verify(function_ref<InFlightDiagnostic()>,
StringRef, bool) {
LogicalResult
LLVMStructType::verifyInvariants(function_ref<InFlightDiagnostic()>, StringRef,
bool) {
return success();
}

LogicalResult
LLVMStructType::verify(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<Type> types, bool) {
LLVMStructType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<Type> types, bool) {
for (Type t : types)
if (!isValidElementType(t))
return emitError() << "invalid LLVM structure element type: " << t;
Expand Down
39 changes: 22 additions & 17 deletions mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@ bool QuantizedType::classof(Type type) {
}

LogicalResult
QuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
unsigned flags, Type storageType, Type expressedType,
int64_t storageTypeMin, int64_t storageTypeMax) {
QuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
unsigned flags, Type storageType,
Type expressedType, int64_t storageTypeMin,
int64_t storageTypeMax) {
// Verify that the storage type is integral.
// This restriction may be lifted at some point in favor of using bf16
// or f16 as exact representations on hardware where that is advantageous.
Expand Down Expand Up @@ -233,11 +234,13 @@ AnyQuantizedType::getChecked(function_ref<InFlightDiagnostic()> emitError,
}

LogicalResult
AnyQuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
unsigned flags, Type storageType, Type expressedType,
int64_t storageTypeMin, int64_t storageTypeMax) {
if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
storageTypeMin, storageTypeMax))) {
AnyQuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
unsigned flags, Type storageType,
Type expressedType, int64_t storageTypeMin,
int64_t storageTypeMax) {
if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
expressedType, storageTypeMin,
storageTypeMax))) {
return failure();
}

Expand Down Expand Up @@ -268,12 +271,13 @@ UniformQuantizedType UniformQuantizedType::getChecked(
storageTypeMin, storageTypeMax);
}

LogicalResult UniformQuantizedType::verify(
LogicalResult UniformQuantizedType::verifyInvariants(
function_ref<InFlightDiagnostic()> emitError, unsigned flags,
Type storageType, Type expressedType, double scale, int64_t zeroPoint,
int64_t storageTypeMin, int64_t storageTypeMax) {
if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
storageTypeMin, storageTypeMax))) {
if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
expressedType, storageTypeMin,
storageTypeMax))) {
return failure();
}

Expand Down Expand Up @@ -321,13 +325,14 @@ UniformQuantizedPerAxisType UniformQuantizedPerAxisType::getChecked(
quantizedDimension, storageTypeMin, storageTypeMax);
}

LogicalResult UniformQuantizedPerAxisType::verify(
LogicalResult UniformQuantizedPerAxisType::verifyInvariants(
function_ref<InFlightDiagnostic()> emitError, unsigned flags,
Type storageType, Type expressedType, ArrayRef<double> scales,
ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
int64_t storageTypeMin, int64_t storageTypeMax) {
if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
storageTypeMin, storageTypeMax))) {
if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
expressedType, storageTypeMin,
storageTypeMax))) {
return failure();
}

Expand Down Expand Up @@ -380,9 +385,9 @@ CalibratedQuantizedType CalibratedQuantizedType::getChecked(
min, max);
}

LogicalResult
CalibratedQuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
Type expressedType, double min, double max) {
LogicalResult CalibratedQuantizedType::verifyInvariants(
function_ref<InFlightDiagnostic()> emitError, Type expressedType,
double min, double max) {
// Verify that the expressed type is floating point.
// If this restriction is ever eliminated, the parser/printer must be
// extended.
Expand Down
9 changes: 4 additions & 5 deletions mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ spirv::InterfaceVarABIAttr::getStorageClass() {
return std::nullopt;
}

LogicalResult spirv::InterfaceVarABIAttr::verify(
LogicalResult spirv::InterfaceVarABIAttr::verifyInvariants(
function_ref<InFlightDiagnostic()> emitError, IntegerAttr descriptorSet,
IntegerAttr binding, IntegerAttr storageClass) {
if (!descriptorSet.getType().isSignlessInteger(32))
Expand Down Expand Up @@ -257,10 +257,9 @@ ArrayAttr spirv::VerCapExtAttr::getCapabilitiesAttr() {
return llvm::cast<ArrayAttr>(getImpl()->capabilities);
}

LogicalResult
spirv::VerCapExtAttr::verify(function_ref<InFlightDiagnostic()> emitError,
IntegerAttr version, ArrayAttr capabilities,
ArrayAttr extensions) {
LogicalResult spirv::VerCapExtAttr::verifyInvariants(
function_ref<InFlightDiagnostic()> emitError, IntegerAttr version,
ArrayAttr capabilities, ArrayAttr extensions) {
if (!version.getType().isSignlessInteger(32))
return emitError() << "expected 32-bit integer for version";

Expand Down
Loading
Loading