Skip to content

Commit

Permalink
[mlir][ODS] Consistent cppType / cppClassName usage
Browse files Browse the repository at this point in the history
Make sure that the usage of `cppType` and `cppClassName` of type and attribute definitions/constraints is consistent in TableGen.

- `cppClassName`: The C++ class name of the type or attribute.
- `cppType`: The fully qualified C++ class name: C++ namespace and C++ class name.
  • Loading branch information
matthias-springer committed Aug 9, 2024
1 parent 16dadec commit 9968cd5
Show file tree
Hide file tree
Showing 9 changed files with 44 additions and 69 deletions.
18 changes: 9 additions & 9 deletions mlir/include/mlir/IR/AttrTypeBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ class AttrDef<Dialect dialect, string name, list<Trait> traits = [],
AttrOrTypeDef<"Attr", name, traits, baseCppClass> {
// The name of the C++ Attribute class.
string cppClassName = name # "Attr";
let storageType = dialect.cppNamespace # "::" # name # "Attr";
let storageType = dialect.cppNamespace # "::" # cppClassName;

// The underlying C++ value type
let returnType = dialect.cppNamespace # "::" # cppClassName;
Expand All @@ -275,12 +275,10 @@ class AttrDef<Dialect dialect, string name, list<Trait> traits = [],
//
// For example, `$_self.getValue().getSExtValue()` for `IntegerAttr val` will
// expand to `getAttrOfType<IntegerAttr>("val").getValue().getSExtValue()`.
let convertFromStorage = "::llvm::cast<" # dialect.cppNamespace #
"::" # cppClassName # ">($_self)";
let convertFromStorage = "::llvm::cast<" # cppType # ">($_self)";

// The predicate for when this def is used as a constraint.
let predicate = CPred<"::llvm::isa<" # dialect.cppNamespace #
"::" # cppClassName # ">($_self)">;
let predicate = CPred<"::llvm::isa<" # cppType # ">($_self)">;
}

// Define a new type, named `name`, belonging to `dialect` that inherits from
Expand All @@ -289,6 +287,9 @@ class TypeDef<Dialect dialect, string name, list<Trait> traits = [],
string baseCppClass = "::mlir::Type">
: DialectType<dialect, CPred<"">, /*descr*/"", name # "Type">,
AttrOrTypeDef<"Type", name, traits, baseCppClass> {
// The name of the C++ Type class.
string cppClassName = name # "Type";

// Make it possible to use such type as parameters for other types.
string cppType = dialect.cppNamespace # "::" # cppClassName;

Expand All @@ -297,12 +298,11 @@ class TypeDef<Dialect dialect, string name, list<Trait> traits = [],

// A constant builder provided when the type has no parameters.
let builderCall = !if(!empty(parameters),
"$_builder.getType<" # dialect.cppNamespace #
"::" # cppClassName # ">()",
"$_builder.getType<" # cppType # ">()",
"");

// The predicate for when this def is used as a constraint.
let predicate = CPred<"::llvm::isa<" # dialect.cppNamespace #
"::" # cppClassName # ">($_self)">;
let predicate = CPred<"::llvm::isa<" # cppType # ">($_self)">;
}

//===----------------------------------------------------------------------===//
Expand Down
6 changes: 3 additions & 3 deletions mlir/include/mlir/IR/CommonAttrConstraints.td
Original file line number Diff line number Diff line change
Expand Up @@ -169,14 +169,14 @@ def AnyAttr : Attr<CPred<"true">, "any attribute"> {

// Any attribute from the given list
class AnyAttrOf<list<Attr> allowedAttrs, string summary = "",
string cppClassName = "::mlir::Attribute",
string cppType = "::mlir::Attribute",
string fromStorage = "$_self"> : Attr<
// Satisfy any of the allowed attribute's condition
Or<!foreach(allowedattr, allowedAttrs, allowedattr.predicate)>,
!if(!eq(summary, ""),
!interleave(!foreach(t, allowedAttrs, t.summary), " or "),
summary)> {
let returnType = cppClassName;
let returnType = cppType;
let convertFromStorage = fromStorage;
}

Expand Down Expand Up @@ -369,7 +369,7 @@ def TypeAttr : TypeAttrBase<"::mlir::Type", "any type attribute"> {
}

class TypeAttrOf<Type ty>
: TypeAttrBase<ty.cppClassName, "type attribute of " # ty.summary,
: TypeAttrBase<ty.cppType, "type attribute of " # ty.summary,
ty.predicate> {
let constBuilderCall = "::mlir::TypeAttr::get($0)";
}
Expand Down
34 changes: 17 additions & 17 deletions mlir/include/mlir/IR/CommonTypeConstraints.td
Original file line number Diff line number Diff line change
Expand Up @@ -98,31 +98,31 @@ def HasValueSemanticsPred : CPred<"$_self.hasTrait<::mlir::ValueSemantics>()">;

// A type, carries type constraints.
class Type<Pred condition, string descr = "",
string cppClassName = "::mlir::Type"> :
TypeConstraint<condition, descr, cppClassName> {
string cppType = "::mlir::Type"> :
TypeConstraint<condition, descr, cppType> {
string description = "";
string builderCall = "";
}

// Allows providing an alternative name and summary to an existing type def.
class TypeAlias<Type t, string summary = t.summary> :
Type<t.predicate, summary, t.cppClassName> {
Type<t.predicate, summary, t.cppType> {
let description = t.description;
let builderCall = t.builderCall;
}

// A type of a specific dialect.
class DialectType<Dialect d, Pred condition, string descr = "",
string cppClassName = "::mlir::Type"> :
Type<condition, descr, cppClassName> {
string cppType = "::mlir::Type"> :
Type<condition, descr, cppType> {
Dialect dialect = d;
}

// A variadic type constraint. It expands to zero or more of the base type. This
// class is used for supporting variadic operands/results.
class Variadic<Type type> : TypeConstraint<type.predicate,
"variadic of " # type.summary,
type.cppClassName> {
type.cppType> {
Type baseType = type;
int minSize = 0;
}
Expand All @@ -140,7 +140,7 @@ class VariadicOfVariadic<Type type, string variadicSegmentAttrName>
// An optional type constraint. It expands to either zero or one of the base
// type. This class is used for supporting optional operands/results.
class Optional<Type type> : TypeConstraint<type.predicate, type.summary,
type.cppClassName> {
type.cppType> {
Type baseType = type;
}

Expand Down Expand Up @@ -172,33 +172,33 @@ def NoneType : Type<CPred<"::llvm::isa<::mlir::NoneType>($_self)">, "none type",

// Any type from the given list
class AnyTypeOf<list<Type> allowedTypeList, string summary = "",
string cppClassName = "::mlir::Type"> : Type<
string cppType = "::mlir::Type"> : Type<
// Satisfy any of the allowed types' conditions.
Or<!foreach(allowedtype, allowedTypeList, allowedtype.predicate)>,
!if(!eq(summary, ""),
!interleave(!foreach(t, allowedTypeList, t.summary), " or "),
summary),
cppClassName> {
cppType> {
list<Type> allowedTypes = allowedTypeList;
}

// A type that satisfies the constraints of all given types.
class AllOfType<list<Type> allowedTypeList, string summary = "",
string cppClassName = "::mlir::Type"> : Type<
string cppType = "::mlir::Type"> : Type<
// Satisfy all of the allowed types' conditions.
And<!foreach(allowedType, allowedTypeList, allowedType.predicate)>,
!if(!eq(summary, ""),
!interleave(!foreach(t, allowedTypeList, t.summary), " and "),
summary),
cppClassName> {
cppType> {
list<Type> allowedTypes = allowedTypeList;
}

// A type that satisfies additional predicates.
class ConfinedType<Type type, list<Pred> predicates, string summary = "",
string cppClassName = type.cppClassName> : Type<
string cppType = type.cppType> : Type<
And<!listconcat([type.predicate], !foreach(pred, predicates, pred))>,
summary, cppClassName>;
summary, cppType>;

// Integer types.

Expand Down Expand Up @@ -375,23 +375,23 @@ def FunctionType : Type<CPred<"::llvm::isa<::mlir::FunctionType>($_self)">,

// A container type is a type that has another type embedded within it.
class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
string descr, string cppClassName = "::mlir::Type"> :
string descr, string cppType = "::mlir::Type"> :
// First, check the container predicate. Then, substitute the extracted
// element into the element type checker.
Type<And<[containerPred,
SubstLeaves<"$_self", !cast<string>(elementTypeCall),
etype.predicate>]>,
descr # " of " # etype.summary # " values", cppClassName>;
descr # " of " # etype.summary # " values", cppType>;

class ShapedContainerType<list<Type> allowedTypes,
Pred containerPred, string descr,
string cppClassName = "::mlir::Type"> :
string cppType = "::mlir::Type"> :
Type<And<[containerPred,
Concat<"[](::mlir::Type elementType) { return ",
SubstLeaves<"$_self", "elementType",
AnyTypeOf<allowedTypes>.predicate>,
"; }(::llvm::cast<::mlir::ShapedType>($_self).getElementType())">]>,
descr # " of " # AnyTypeOf<allowedTypes>.summary # " values", cppClassName>;
descr # " of " # AnyTypeOf<allowedTypes>.summary # " values", cppType>;

// Whether a shaped type is ranked.
def HasRankPred : CPred<"::llvm::cast<::mlir::ShapedType>($_self).hasRank()">;
Expand Down
4 changes: 2 additions & 2 deletions mlir/include/mlir/IR/Constraints.td
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,10 @@ class Constraint<Pred pred, string desc = ""> {

// Subclass for constraints on a type.
class TypeConstraint<Pred predicate, string summary = "",
string cppClassNameParam = "::mlir::Type"> :
string cppTypeParam = "::mlir::Type"> :
Constraint<predicate, summary> {
// The name of the C++ Type class if known, or Type if not.
string cppClassName = cppClassNameParam;
string cppType = cppTypeParam;
}

// Subclass for constraints on an attribute.
Expand Down
4 changes: 2 additions & 2 deletions mlir/include/mlir/TableGen/Type.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ class TypeConstraint : public Constraint {
// returns std::nullopt otherwise.
std::optional<StringRef> getBuilderCall() const;

// Return the C++ class name for this type (which may just be ::mlir::Type).
std::string getCPPClassName() const;
// Return the C++ type for this type (which may just be ::mlir::Type).
StringRef getCppType() const;
};

// Wrapper class with helper methods for accessing Types defined in TableGen.
Expand Down
17 changes: 3 additions & 14 deletions mlir/lib/TableGen/Type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,20 +59,9 @@ std::optional<StringRef> TypeConstraint::getBuilderCall() const {
.Default([](auto *) { return std::nullopt; });
}

// Return the C++ class name for this type (which may just be ::mlir::Type).
std::string TypeConstraint::getCPPClassName() const {
StringRef className = def->getValueAsString("cppClassName");

// If the class name is already namespace resolved, use it.
if (className.contains("::"))
return className.str();

// Otherwise, check to see if there is a namespace from a dialect to prepend.
if (const llvm::RecordVal *value = def->getValue("dialect")) {
Dialect dialect(cast<const llvm::DefInit>(value->getValue())->getDef());
return (dialect.getCppNamespace() + "::" + className).str();
}
return className.str();
// Return the C++ type for this type (which may just be ::mlir::Type).
StringRef TypeConstraint::getCppType() const {
return def->getValueAsString("cppType");
}

Type::Type(const llvm::Record *record) : TypeConstraint(record) {}
Expand Down
5 changes: 2 additions & 3 deletions mlir/lib/Tools/PDLL/Parser/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -879,8 +879,7 @@ void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
-> const ods::TypeConstraint & {
return odsContext.insertTypeConstraint(
cst.constraint.getUniqueDefName(),
processDoc(cst.constraint.getSummary()),
cst.constraint.getCPPClassName());
processDoc(cst.constraint.getSummary()), cst.constraint.getCppType());
};
auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange {
return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)};
Expand Down Expand Up @@ -944,7 +943,7 @@ void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
tblgen::TypeConstraint constraint(def);
decls.push_back(createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
constraint, convertLocToRange(def->getLoc().front()), typeTy,
constraint.getCPPClassName()));
constraint.getCppType()));
}
/// OpInterfaces.
ast::Type opTy = ast::OperationType::get(ctx);
Expand Down
19 changes: 3 additions & 16 deletions mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2085,21 +2085,8 @@ static void generateValueRangeStartAndEnd(
}

static std::string generateTypeForGetter(const NamedTypeConstraint &value) {
std::string str = "::mlir::Value";
/// If the CPPClassName is not a fully qualified type. Uses of types
/// across Dialect fail because they are not in the correct namespace. So we
/// dont generate TypedValue unless the type is fully qualified.
/// getCPPClassName doesn't return the fully qualified path for
/// `mlir::pdl::OperationType` see
/// https://github.com/llvm/llvm-project/issues/57279.
/// Adaptor will have values that are not from the type of their operation and
/// this is expected, so we dont generate TypedValue for Adaptor
if (value.constraint.getCPPClassName() != "::mlir::Type" &&
StringRef(value.constraint.getCPPClassName()).starts_with("::"))
str = llvm::formatv("::mlir::TypedValue<{0}>",
value.constraint.getCPPClassName())
.str();
return str;
return llvm::formatv("::mlir::TypedValue<{0}>", value.constraint.getCppType())
.str();
}

// Generates the named operand getter methods for the given Operator `op` and
Expand Down Expand Up @@ -3944,7 +3931,7 @@ void OpEmitter::genTraits() {
// For single result ops with a known specific type, generate a OneTypedResult
// trait.
if (numResults == 1 && numVariadicResults == 0) {
auto cppName = op.getResults().begin()->constraint.getCPPClassName();
auto cppName = op.getResults().begin()->constraint.getCppType();
opClass.addTrait("::mlir::OpTrait::OneTypedResult<" + cppName + ">::Impl");
}

Expand Down
6 changes: 3 additions & 3 deletions mlir/tools/mlir-tblgen/OpFormatGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1657,7 +1657,7 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
TypeSwitch<FormatElement *>(dir->getArg())
.Case<OperandVariable, ResultVariable>([&](auto operand) {
body << formatv(parserCode,
operand->getVar()->constraint.getCPPClassName(),
operand->getVar()->constraint.getCppType(),
listName);
})
.Default([&](auto operand) {
Expand Down Expand Up @@ -2603,7 +2603,7 @@ void OperationFormat::genElementPrinter(FormatElement *element,
}
if (var && !var->isVariadicOfVariadic() && !var->isVariadic() &&
!var->isOptional()) {
std::string cppClass = var->constraint.getCPPClassName();
StringRef cppType = var->constraint.getCppType();
if (dir->shouldBeQualified()) {
body << " _odsPrinter << " << op.getGetterName(var->name)
<< "().getType();\n";
Expand All @@ -2612,7 +2612,7 @@ void OperationFormat::genElementPrinter(FormatElement *element,
body << " {\n"
<< " auto type = " << op.getGetterName(var->name)
<< "().getType();\n"
<< " if (auto validType = ::llvm::dyn_cast<" << cppClass
<< " if (auto validType = ::llvm::dyn_cast<" << cppType
<< ">(type))\n"
<< " _odsPrinter.printStrippedAttrOrType(validType);\n"
<< " else\n"
Expand Down

0 comments on commit 9968cd5

Please sign in to comment.