Skip to content

Commit

Permalink
[mlir][ODS] Optionally generate public C++ functions for type constra…
Browse files Browse the repository at this point in the history
…ints (llvm#104577)

Add `gen-type-constraint-decls` and `gen-type-constraint-defs`, which
generate public C++ functions for type constraints. The name of the C++
function is specified in the `cppFunctionName` field.

Type constraints are typically used for op/type/attribute verification.
They are also sometimes called from builders and transformations. Until
now, this required duplicating the check in C++.

Note: This commit just adds the option for type constraints, but
attribute constraints could be supported in the same way.

Alternatives considered:
1. The C++ functions could also be generated as part of
`gen-typedef-decls/defs`, but that can be confusing because type
constraints may rely on type definitions from multiple `.td` files.
`#include`s could cause duplicate definitions of the same type
constraint.
2. The C++ functions could also be generated as static member functions
of dialects, but they don't really belong to a dialect. (Because they
may rely on type definitions from multiple dialects.)
  • Loading branch information
matthias-springer authored Aug 21, 2024
1 parent 90556ef commit a3d4187
Show file tree
Hide file tree
Showing 11 changed files with 181 additions and 11 deletions.
59 changes: 59 additions & 0 deletions mlir/docs/DefiningDialects/Constraints.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Constraints

[TOC]

## Attribute / Type Constraints

When defining the arguments of an operation in TableGen, users can specify
either plain attributes/types or use attribute/type constraints to levy
additional requirements on the attribute value or operand type.

```tablegen
def My_Type1 : MyDialect_Type<"Type1", "type1"> { ... }
def My_Type2 : MyDialect_Type<"Type2", "type2"> { ... }
// Plain type
let arguments = (ins MyType1:$val);
// Type constraint
let arguments = (ins AnyTypeOf<[MyType1, MyType2]>:$val);
```

`AnyTypeOf` is an example for a type constraints. Many useful type constraints
can be found in `mlir/IR/CommonTypeConstraints.td`. Additional verification
code is generated for type/attribute constraints. Type constraints can not only
be used when defining operation arguments, but also when defining type
parameters.

Optionally, C++ functions can be generated, so that type constraints can be
checked from C++. The name of the C++ function must be specified in the
`cppFunctionName` field. If no function name is specified, no C++ function is
emitted.

```tablegen
// Example: Element type constraint for VectorType
def Builtin_VectorTypeElementType : AnyTypeOf<[AnyInteger, Index, AnyFloat]> {
let cppFunctionName = "isValidVectorTypeElementType";
}
```

The above example tranlates into the following C++ code:
```c++
bool isValidVectorTypeElementType(::mlir::Type type) {
return (((::llvm::isa<::mlir::IntegerType>(type))) || ((::llvm::isa<::mlir::IndexType>(type))) || ((::llvm::isa<::mlir::FloatType>(type))));
}
```
An extra TableGen rule is needed to emit C++ code for type constraints. This
will generate only the declarations/definitions of the type constaraints that
are defined in the specified `.td` file, but not those that are in included
`.td` files.
```cmake
mlir_tablegen(<Your Dialect>TypeConstraints.h.inc -gen-type-constraint-decls)
mlir_tablegen(<Your Dialect>TypeConstraints.cpp.inc -gen-type-constraint-defs)
```

The generated `<Your Dialect>TypeConstraints.h.inc` will need to be included
whereever you are referencing the type constraint in C++. Note that no C++
namespace will be emitted by the code generator. The `#include` statements of
the `.h.inc`/`.cpp.inc` files should be wrapped in C++ namespaces by the user.
1 change: 1 addition & 0 deletions mlir/include/mlir/IR/BuiltinTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
#include "mlir/IR/BuiltinTypes.h.inc"

namespace mlir {
#include "mlir/IR/BuiltinTypeConstraints.h.inc"

//===----------------------------------------------------------------------===//
// MemRefType
Expand Down
14 changes: 7 additions & 7 deletions mlir/include/mlir/IR/BuiltinTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -1097,6 +1097,10 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
// VectorType
//===----------------------------------------------------------------------===//

def Builtin_VectorTypeElementType : AnyTypeOf<[AnyInteger, Index, AnyFloat]> {
let cppFunctionName = "isValidVectorTypeElementType";
}

def Builtin_Vector : Builtin_Type<"Vector", "vector",
[ShapedTypeInterface, ValueSemantics], "Type"> {
let summary = "Multi-dimensional SIMD vector type";
Expand Down Expand Up @@ -1147,7 +1151,7 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector",
}];
let parameters = (ins
ArrayRefParameter<"int64_t">:$shape,
AnyTypeOf<[AnyInteger, Index, AnyFloat]>:$elementType,
Builtin_VectorTypeElementType:$elementType,
ArrayRefParameter<"bool">:$scalableDims
);
let builders = [
Expand All @@ -1171,12 +1175,8 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector",
class Builder;

/// Returns true if the given type can be used as an element of a vector
/// type. In particular, vectors can consist of integer, index, or float
/// primitives.
static bool isValidElementType(Type t) {
// TODO: Auto-generate this function from $elementType.
return ::llvm::isa<IntegerType, IndexType, FloatType>(t);
}
/// type. See "Builtin_VectorTypeElementType" for allowed types.
static bool isValidElementType(Type t);

/// Returns true if the vector contains scalable dimensions.
bool isScalable() const {
Expand Down
3 changes: 3 additions & 0 deletions mlir/include/mlir/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ set(LLVM_TARGET_DEFINITIONS BuiltinTypes.td)
mlir_tablegen(BuiltinTypes.h.inc -gen-typedef-decls)
mlir_tablegen(BuiltinTypes.cpp.inc -gen-typedef-defs)
add_public_tablegen_target(MLIRBuiltinTypesIncGen)
mlir_tablegen(BuiltinTypeConstraints.h.inc -gen-type-constraint-decls)
mlir_tablegen(BuiltinTypeConstraints.cpp.inc -gen-type-constraint-defs)
add_public_tablegen_target(MLIRBuiltinTypeConstraintsIncGen)

set(LLVM_TARGET_DEFINITIONS BuiltinTypeInterfaces.td)
mlir_tablegen(BuiltinTypeInterfaces.h.inc -gen-type-interface-decls)
Expand Down
6 changes: 5 additions & 1 deletion mlir/include/mlir/IR/Constraints.td
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,14 @@ class Constraint<Pred pred, string desc = ""> {

// Subclass for constraints on a type.
class TypeConstraint<Pred predicate, string summary = "",
string cppTypeParam = "::mlir::Type"> :
string cppTypeParam = "::mlir::Type",
string cppFunctionNameParam = ""> :
Constraint<predicate, summary> {
// The name of the C++ Type class if known, or Type if not.
string cppType = cppTypeParam;
// The name of the C++ function that is generated for this type constraint.
// If empty, no C++ function is generated.
string cppFunctionName = cppFunctionNameParam;
}

// Subclass for constraints on an attribute.
Expand Down
4 changes: 4 additions & 0 deletions mlir/include/mlir/TableGen/Constraint.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ class Constraint {
/// context on the def).
std::string getUniqueDefName() const;

/// Returns the name of the C++ function that should be generated for this
/// constraint, or std::nullopt if no C++ function should be generated.
std::optional<StringRef> getCppFunctionName() const;

Kind getKind() const { return kind; }

/// Return the underlying def.
Expand Down
16 changes: 14 additions & 2 deletions mlir/lib/IR/BuiltinTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ using namespace mlir::detail;
#define GET_TYPEDEF_CLASSES
#include "mlir/IR/BuiltinTypes.cpp.inc"

namespace mlir {
#include "mlir/IR/BuiltinTypeConstraints.cpp.inc"
} // namespace mlir

//===----------------------------------------------------------------------===//
// BuiltinDialect
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -230,6 +234,10 @@ LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
// VectorType
//===----------------------------------------------------------------------===//

bool VectorType::isValidElementType(Type t) {
return isValidVectorTypeElementType(t);
}

LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType,
ArrayRef<bool> scalableDims) {
Expand Down Expand Up @@ -278,7 +286,9 @@ Type TensorType::getElementType() const {
[](auto type) { return type.getElementType(); });
}

bool TensorType::hasRank() const { return !llvm::isa<UnrankedTensorType>(*this); }
bool TensorType::hasRank() const {
return !llvm::isa<UnrankedTensorType>(*this);
}

ArrayRef<int64_t> TensorType::getShape() const {
return llvm::cast<RankedTensorType>(*this).getShape();
Expand Down Expand Up @@ -365,7 +375,9 @@ Type BaseMemRefType::getElementType() const {
[](auto type) { return type.getElementType(); });
}

bool BaseMemRefType::hasRank() const { return !llvm::isa<UnrankedMemRefType>(*this); }
bool BaseMemRefType::hasRank() const {
return !llvm::isa<UnrankedMemRefType>(*this);
}

ArrayRef<int64_t> BaseMemRefType::getShape() const {
return llvm::cast<MemRefType>(*this).getShape();
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ add_mlir_library(MLIRIR
MLIRBuiltinLocationAttributesIncGen
MLIRBuiltinOpsIncGen
MLIRBuiltinTypesIncGen
MLIRBuiltinTypeConstraintsIncGen
MLIRBuiltinTypeInterfacesIncGen
MLIRCallInterfacesIncGen
MLIRCastInterfacesIncGen
Expand Down
10 changes: 9 additions & 1 deletion mlir/lib/TableGen/Constraint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ Constraint::Constraint(const llvm::Record *record)
kind = CK_Region;
} else if (def->isSubClassOf("SuccessorConstraint")) {
kind = CK_Successor;
} else if(!def->isSubClassOf("Constraint")) {
} else if (!def->isSubClassOf("Constraint")) {
llvm::errs() << "Expected a constraint but got: \n" << *def << "\n";
llvm::report_fatal_error("Abort");
}
Expand Down Expand Up @@ -109,6 +109,14 @@ std::optional<StringRef> Constraint::getBaseDefName() const {
}
}

std::optional<StringRef> Constraint::getCppFunctionName() const {
std::optional<StringRef> name =
def->getValueAsOptionalString("cppFunctionName");
if (!name || *name == "")
return std::nullopt;
return name;
}

AppliedConstraint::AppliedConstraint(Constraint &&constraint,
llvm::StringRef self,
std::vector<std::string> &&entities)
Expand Down
14 changes: 14 additions & 0 deletions mlir/test/mlir-tblgen/type-constraints.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// RUN: mlir-tblgen -gen-type-constraint-decls -I %S/../../include %s | FileCheck %s --check-prefix=DECL
// RUN: mlir-tblgen -gen-type-constraint-defs -I %S/../../include %s | FileCheck %s --check-prefix=DEF

include "mlir/IR/CommonTypeConstraints.td"

def DummyConstraint : AnyTypeOf<[AnyInteger, Index, AnyFloat]> {
let cppFunctionName = "isValidDummy";
}

// DECL: bool isValidDummy(::mlir::Type type);

// DEF: bool isValidDummy(::mlir::Type type) {
// DEF: return (((::llvm::isa<::mlir::IntegerType>(type))) || ((::llvm::isa<::mlir::IndexType>(type))) || ((::llvm::isa<::mlir::FloatType>(type))));
// DEF: }
64 changes: 64 additions & 0 deletions mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1023,6 +1023,55 @@ bool DefGenerator::emitDefs(StringRef selectedDialect) {
return false;
}

//===----------------------------------------------------------------------===//
// Type Constraints
//===----------------------------------------------------------------------===//

/// Find all type constraints for which a C++ function should be generated.
static std::vector<Constraint>
getAllTypeConstraints(const llvm::RecordKeeper &records) {
std::vector<Constraint> result;
for (llvm::Record *def :
records.getAllDerivedDefinitionsIfDefined("TypeConstraint")) {
// Ignore constraints defined outside of the top-level file.
if (llvm::SrcMgr.FindBufferContainingLoc(def->getLoc()[0]) !=
llvm::SrcMgr.getMainFileID())
continue;
Constraint constr(def);
// Generate C++ function only if "cppFunctionName" is set.
if (!constr.getCppFunctionName())
continue;
result.push_back(constr);
}
return result;
}

static void emitTypeConstraintDecls(const llvm::RecordKeeper &records,
raw_ostream &os) {
static const char *const typeConstraintDecl = R"(
bool {0}(::mlir::Type type);
)";

for (Constraint constr : getAllTypeConstraints(records))
os << strfmt(typeConstraintDecl, *constr.getCppFunctionName());
}

static void emitTypeConstraintDefs(const llvm::RecordKeeper &records,
raw_ostream &os) {
static const char *const typeConstraintDef = R"(
bool {0}(::mlir::Type type) {
return ({1});
}
)";

for (Constraint constr : getAllTypeConstraints(records)) {
FmtContext ctx;
ctx.withSelf("type");
std::string condition = tgfmt(constr.getConditionTemplate(), &ctx);
os << strfmt(typeConstraintDef, *constr.getCppFunctionName(), condition);
}
}

//===----------------------------------------------------------------------===//
// GEN: Registration hooks
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1070,3 +1119,18 @@ static mlir::GenRegistration
TypeDefGenerator generator(records, os);
return generator.emitDecls(typeDialect);
});

static mlir::GenRegistration
genTypeConstrDefs("gen-type-constraint-defs",
"Generate type constraint definitions",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
emitTypeConstraintDefs(records, os);
return false;
});
static mlir::GenRegistration
genTypeConstrDecls("gen-type-constraint-decls",
"Generate type constraint declarations",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
emitTypeConstraintDecls(records, os);
return false;
});

0 comments on commit a3d4187

Please sign in to comment.