Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][ODS] Consistent cppType / cppClassName usage #102657

Merged
merged 1 commit into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions mlir/include/mlir/IR/AttrTypeBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ class AttrDef<Dialect dialect, string name, list<Trait> traits = [],
AttrOrTypeDef<"Attr", name, traits, baseCppClass> {
// The name of the C++ Attribute class.
string cppClassName = name # "Attr";
let storageType = dialect.cppNamespace # "::" # name # "Attr";
let storageType = dialect.cppNamespace # "::" # cppClassName;

// The underlying C++ value type
let returnType = dialect.cppNamespace # "::" # cppClassName;
Expand All @@ -275,12 +275,10 @@ class AttrDef<Dialect dialect, string name, list<Trait> traits = [],
//
// For example, `$_self.getValue().getSExtValue()` for `IntegerAttr val` will
// expand to `getAttrOfType<IntegerAttr>("val").getValue().getSExtValue()`.
let convertFromStorage = "::llvm::cast<" # dialect.cppNamespace #
"::" # cppClassName # ">($_self)";
let convertFromStorage = "::llvm::cast<" # cppType # ">($_self)";

// The predicate for when this def is used as a constraint.
let predicate = CPred<"::llvm::isa<" # dialect.cppNamespace #
"::" # cppClassName # ">($_self)">;
let predicate = CPred<"::llvm::isa<" # cppType # ">($_self)">;
}

// Define a new type, named `name`, belonging to `dialect` that inherits from
Expand All @@ -289,6 +287,9 @@ class TypeDef<Dialect dialect, string name, list<Trait> traits = [],
string baseCppClass = "::mlir::Type">
: DialectType<dialect, CPred<"">, /*descr*/"", name # "Type">,
AttrOrTypeDef<"Type", name, traits, baseCppClass> {
// The name of the C++ Type class.
string cppClassName = name # "Type";

// Make it possible to use such type as parameters for other types.
string cppType = dialect.cppNamespace # "::" # cppClassName;

Expand All @@ -297,12 +298,11 @@ class TypeDef<Dialect dialect, string name, list<Trait> traits = [],

// A constant builder provided when the type has no parameters.
let builderCall = !if(!empty(parameters),
"$_builder.getType<" # dialect.cppNamespace #
"::" # cppClassName # ">()",
"$_builder.getType<" # cppType # ">()",
"");

// The predicate for when this def is used as a constraint.
let predicate = CPred<"::llvm::isa<" # dialect.cppNamespace #
"::" # cppClassName # ">($_self)">;
let predicate = CPred<"::llvm::isa<" # cppType # ">($_self)">;
}

//===----------------------------------------------------------------------===//
Expand Down
6 changes: 3 additions & 3 deletions mlir/include/mlir/IR/CommonAttrConstraints.td
Original file line number Diff line number Diff line change
Expand Up @@ -169,14 +169,14 @@ def AnyAttr : Attr<CPred<"true">, "any attribute"> {

// Any attribute from the given list
class AnyAttrOf<list<Attr> allowedAttrs, string summary = "",
string cppClassName = "::mlir::Attribute",
string cppType = "::mlir::Attribute",
string fromStorage = "$_self"> : Attr<
// Satisfy any of the allowed attribute's condition
Or<!foreach(allowedattr, allowedAttrs, allowedattr.predicate)>,
!if(!eq(summary, ""),
!interleave(!foreach(t, allowedAttrs, t.summary), " or "),
summary)> {
let returnType = cppClassName;
let returnType = cppType;
let convertFromStorage = fromStorage;
}

Expand Down Expand Up @@ -369,7 +369,7 @@ def TypeAttr : TypeAttrBase<"::mlir::Type", "any type attribute"> {
}

class TypeAttrOf<Type ty>
: TypeAttrBase<ty.cppClassName, "type attribute of " # ty.summary,
: TypeAttrBase<ty.cppType, "type attribute of " # ty.summary,
ty.predicate> {
let constBuilderCall = "::mlir::TypeAttr::get($0)";
}
Expand Down
34 changes: 17 additions & 17 deletions mlir/include/mlir/IR/CommonTypeConstraints.td
Original file line number Diff line number Diff line change
Expand Up @@ -98,31 +98,31 @@ def HasValueSemanticsPred : CPred<"$_self.hasTrait<::mlir::ValueSemantics>()">;

// A type, carries type constraints.
class Type<Pred condition, string descr = "",
string cppClassName = "::mlir::Type"> :
TypeConstraint<condition, descr, cppClassName> {
string cppType = "::mlir::Type"> :
TypeConstraint<condition, descr, cppType> {
string description = "";
string builderCall = "";
}

// Allows providing an alternative name and summary to an existing type def.
class TypeAlias<Type t, string summary = t.summary> :
Type<t.predicate, summary, t.cppClassName> {
Type<t.predicate, summary, t.cppType> {
let description = t.description;
let builderCall = t.builderCall;
}

// A type of a specific dialect.
class DialectType<Dialect d, Pred condition, string descr = "",
string cppClassName = "::mlir::Type"> :
Type<condition, descr, cppClassName> {
string cppType = "::mlir::Type"> :
Type<condition, descr, cppType> {
Dialect dialect = d;
}

// A variadic type constraint. It expands to zero or more of the base type. This
// class is used for supporting variadic operands/results.
class Variadic<Type type> : TypeConstraint<type.predicate,
"variadic of " # type.summary,
type.cppClassName> {
type.cppType> {
Type baseType = type;
int minSize = 0;
}
Expand All @@ -140,7 +140,7 @@ class VariadicOfVariadic<Type type, string variadicSegmentAttrName>
// An optional type constraint. It expands to either zero or one of the base
// type. This class is used for supporting optional operands/results.
class Optional<Type type> : TypeConstraint<type.predicate, type.summary,
type.cppClassName> {
type.cppType> {
Type baseType = type;
}

Expand Down Expand Up @@ -172,33 +172,33 @@ def NoneType : Type<CPred<"::llvm::isa<::mlir::NoneType>($_self)">, "none type",

// Any type from the given list
class AnyTypeOf<list<Type> allowedTypeList, string summary = "",
string cppClassName = "::mlir::Type"> : Type<
string cppType = "::mlir::Type"> : Type<
// Satisfy any of the allowed types' conditions.
Or<!foreach(allowedtype, allowedTypeList, allowedtype.predicate)>,
!if(!eq(summary, ""),
!interleave(!foreach(t, allowedTypeList, t.summary), " or "),
summary),
cppClassName> {
cppType> {
list<Type> allowedTypes = allowedTypeList;
}

// A type that satisfies the constraints of all given types.
class AllOfType<list<Type> allowedTypeList, string summary = "",
string cppClassName = "::mlir::Type"> : Type<
string cppType = "::mlir::Type"> : Type<
// Satisfy all of the allowed types' conditions.
And<!foreach(allowedType, allowedTypeList, allowedType.predicate)>,
!if(!eq(summary, ""),
!interleave(!foreach(t, allowedTypeList, t.summary), " and "),
summary),
cppClassName> {
cppType> {
list<Type> allowedTypes = allowedTypeList;
}

// A type that satisfies additional predicates.
class ConfinedType<Type type, list<Pred> predicates, string summary = "",
string cppClassName = type.cppClassName> : Type<
string cppType = type.cppType> : Type<
And<!listconcat([type.predicate], !foreach(pred, predicates, pred))>,
summary, cppClassName>;
summary, cppType>;

// Integer types.

Expand Down Expand Up @@ -375,23 +375,23 @@ def FunctionType : Type<CPred<"::llvm::isa<::mlir::FunctionType>($_self)">,

// A container type is a type that has another type embedded within it.
class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
string descr, string cppClassName = "::mlir::Type"> :
string descr, string cppType = "::mlir::Type"> :
// First, check the container predicate. Then, substitute the extracted
// element into the element type checker.
Type<And<[containerPred,
SubstLeaves<"$_self", !cast<string>(elementTypeCall),
etype.predicate>]>,
descr # " of " # etype.summary # " values", cppClassName>;
descr # " of " # etype.summary # " values", cppType>;

class ShapedContainerType<list<Type> allowedTypes,
Pred containerPred, string descr,
string cppClassName = "::mlir::Type"> :
string cppType = "::mlir::Type"> :
Type<And<[containerPred,
Concat<"[](::mlir::Type elementType) { return ",
SubstLeaves<"$_self", "elementType",
AnyTypeOf<allowedTypes>.predicate>,
"; }(::llvm::cast<::mlir::ShapedType>($_self).getElementType())">]>,
descr # " of " # AnyTypeOf<allowedTypes>.summary # " values", cppClassName>;
descr # " of " # AnyTypeOf<allowedTypes>.summary # " values", cppType>;

// Whether a shaped type is ranked.
def HasRankPred : CPred<"::llvm::cast<::mlir::ShapedType>($_self).hasRank()">;
Expand Down
4 changes: 2 additions & 2 deletions mlir/include/mlir/IR/Constraints.td
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,10 @@ class Constraint<Pred pred, string desc = ""> {

// Subclass for constraints on a type.
class TypeConstraint<Pred predicate, string summary = "",
string cppClassNameParam = "::mlir::Type"> :
string cppTypeParam = "::mlir::Type"> :
Constraint<predicate, summary> {
// The name of the C++ Type class if known, or Type if not.
string cppClassName = cppClassNameParam;
string cppType = cppTypeParam;
}

// Subclass for constraints on an attribute.
Expand Down
4 changes: 2 additions & 2 deletions mlir/include/mlir/TableGen/Type.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ class TypeConstraint : public Constraint {
// returns std::nullopt otherwise.
std::optional<StringRef> getBuilderCall() const;

// Return the C++ class name for this type (which may just be ::mlir::Type).
std::string getCPPClassName() const;
// Return the C++ type for this type (which may just be ::mlir::Type).
StringRef getCppType() const;
};

// Wrapper class with helper methods for accessing Types defined in TableGen.
Expand Down
17 changes: 3 additions & 14 deletions mlir/lib/TableGen/Type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,20 +59,9 @@ std::optional<StringRef> TypeConstraint::getBuilderCall() const {
.Default([](auto *) { return std::nullopt; });
}

// Return the C++ class name for this type (which may just be ::mlir::Type).
std::string TypeConstraint::getCPPClassName() const {
StringRef className = def->getValueAsString("cppClassName");

// If the class name is already namespace resolved, use it.
if (className.contains("::"))
return className.str();

// Otherwise, check to see if there is a namespace from a dialect to prepend.
if (const llvm::RecordVal *value = def->getValue("dialect")) {
Dialect dialect(cast<const llvm::DefInit>(value->getValue())->getDef());
return (dialect.getCppNamespace() + "::" + className).str();
}
return className.str();
// Return the C++ type for this type (which may just be ::mlir::Type).
StringRef TypeConstraint::getCppType() const {
return def->getValueAsString("cppType");
}

Type::Type(const llvm::Record *record) : TypeConstraint(record) {}
Expand Down
5 changes: 2 additions & 3 deletions mlir/lib/Tools/PDLL/Parser/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -879,8 +879,7 @@ void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
-> const ods::TypeConstraint & {
return odsContext.insertTypeConstraint(
cst.constraint.getUniqueDefName(),
processDoc(cst.constraint.getSummary()),
cst.constraint.getCPPClassName());
processDoc(cst.constraint.getSummary()), cst.constraint.getCppType());
};
auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange {
return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)};
Expand Down Expand Up @@ -944,7 +943,7 @@ void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
tblgen::TypeConstraint constraint(def);
decls.push_back(createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
constraint, convertLocToRange(def->getLoc().front()), typeTy,
constraint.getCPPClassName()));
constraint.getCppType()));
}
/// OpInterfaces.
ast::Type opTy = ast::OperationType::get(ctx);
Expand Down
19 changes: 3 additions & 16 deletions mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2085,21 +2085,8 @@ static void generateValueRangeStartAndEnd(
}

static std::string generateTypeForGetter(const NamedTypeConstraint &value) {
std::string str = "::mlir::Value";
/// If the CPPClassName is not a fully qualified type. Uses of types
/// across Dialect fail because they are not in the correct namespace. So we
/// dont generate TypedValue unless the type is fully qualified.
/// getCPPClassName doesn't return the fully qualified path for
/// `mlir::pdl::OperationType` see
/// https://github.com/llvm/llvm-project/issues/57279.
/// Adaptor will have values that are not from the type of their operation and
/// this is expected, so we dont generate TypedValue for Adaptor
if (value.constraint.getCPPClassName() != "::mlir::Type" &&
StringRef(value.constraint.getCPPClassName()).starts_with("::"))
str = llvm::formatv("::mlir::TypedValue<{0}>",
value.constraint.getCPPClassName())
.str();
return str;
return llvm::formatv("::mlir::TypedValue<{0}>", value.constraint.getCppType())
.str();
}

// Generates the named operand getter methods for the given Operator `op` and
Expand Down Expand Up @@ -3944,7 +3931,7 @@ void OpEmitter::genTraits() {
// For single result ops with a known specific type, generate a OneTypedResult
// trait.
if (numResults == 1 && numVariadicResults == 0) {
auto cppName = op.getResults().begin()->constraint.getCPPClassName();
auto cppName = op.getResults().begin()->constraint.getCppType();
opClass.addTrait("::mlir::OpTrait::OneTypedResult<" + cppName + ">::Impl");
}

Expand Down
6 changes: 3 additions & 3 deletions mlir/tools/mlir-tblgen/OpFormatGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1657,7 +1657,7 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
TypeSwitch<FormatElement *>(dir->getArg())
.Case<OperandVariable, ResultVariable>([&](auto operand) {
body << formatv(parserCode,
operand->getVar()->constraint.getCPPClassName(),
operand->getVar()->constraint.getCppType(),
listName);
})
.Default([&](auto operand) {
Expand Down Expand Up @@ -2603,7 +2603,7 @@ void OperationFormat::genElementPrinter(FormatElement *element,
}
if (var && !var->isVariadicOfVariadic() && !var->isVariadic() &&
!var->isOptional()) {
std::string cppClass = var->constraint.getCPPClassName();
StringRef cppType = var->constraint.getCppType();
if (dir->shouldBeQualified()) {
body << " _odsPrinter << " << op.getGetterName(var->name)
<< "().getType();\n";
Expand All @@ -2612,7 +2612,7 @@ void OperationFormat::genElementPrinter(FormatElement *element,
body << " {\n"
<< " auto type = " << op.getGetterName(var->name)
<< "().getType();\n"
<< " if (auto validType = ::llvm::dyn_cast<" << cppClass
<< " if (auto validType = ::llvm::dyn_cast<" << cppType
<< ">(type))\n"
<< " _odsPrinter.printStrippedAttrOrType(validType);\n"
<< " else\n"
Expand Down
Loading