From e05b9d7a9d5380e8d40e140a1133f306ee911036 Mon Sep 17 00:00:00 2001 From: Rahul Joshi Date: Thu, 5 Sep 2024 18:35:55 -0700 Subject: [PATCH] [TableGen] Add const variants of accessors for backend (#106658) Split RecordKeeper `getAllDerivedDefinitions` family of functions into two variants: (a) non-const ones that return vectors of `Record *` and (b) const ones, that return vector/ArrayRef of `const Record *`. This will help gradual migration of TableGen backends to use `const RecordKeeper` and by implication change code to work with const pointers and better const correctness. Existing backends are not yet compatible with the const family of functions, so change them to use a non-constant `RecordKeeper` reference, till they are migrated. --- clang/utils/TableGen/ClangAttrEmitter.cpp | 4 +- clang/utils/TableGen/ClangSyntaxEmitter.cpp | 2 +- llvm/include/llvm/TableGen/DirectiveEmitter.h | 5 +- llvm/include/llvm/TableGen/Record.h | 34 ++++++++-- llvm/lib/TableGen/Record.cpp | 65 +++++++++++++++---- .../TableGen/Basic/CodeGenIntrinsics.cpp | 2 +- .../TableGen/Common/SubtargetFeatureInfo.cpp | 2 +- .../TableGen/Common/SubtargetFeatureInfo.h | 2 +- llvm/utils/TableGen/ExegesisEmitter.cpp | 2 +- llvm/utils/TableGen/GlobalISelEmitter.cpp | 2 +- llvm/utils/TableGen/SubtargetEmitter.cpp | 2 +- llvm/utils/TableGen/TableGen.cpp | 2 +- mlir/include/mlir/TableGen/GenInfo.h | 6 +- mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp | 28 ++++---- mlir/tools/mlir-tblgen/OmpOpGen.cpp | 2 +- mlir/tools/mlir-tblgen/OpDocGen.cpp | 20 +++--- mlir/tools/mlir-tblgen/OpInterfacesGen.cpp | 15 ++--- mlir/tools/mlir-tblgen/RewriterGen.cpp | 8 +-- .../tools/tblgen-to-irdl/OpDefinitionsGen.cpp | 7 +- 19 files changed, 134 insertions(+), 76 deletions(-) diff --git a/clang/utils/TableGen/ClangAttrEmitter.cpp b/clang/utils/TableGen/ClangAttrEmitter.cpp index adbe6af62d5cbe3..d24215d10f17c70 100644 --- a/clang/utils/TableGen/ClangAttrEmitter.cpp +++ b/clang/utils/TableGen/ClangAttrEmitter.cpp @@ -189,7 +189,7 @@ static StringRef NormalizeGNUAttrSpelling(StringRef AttrSpelling) { typedef std::vector> ParsedAttrMap; -static ParsedAttrMap getParsedAttrList(const RecordKeeper &Records, +static ParsedAttrMap getParsedAttrList(RecordKeeper &Records, ParsedAttrMap *Dupes = nullptr, bool SemaOnly = true) { std::vector Attrs = Records.getAllDerivedDefinitions("Attr"); @@ -4344,7 +4344,7 @@ static void GenerateAppertainsTo(const Record &Attr, raw_ostream &OS) { // written into OS and the checks for merging declaration attributes are // written into MergeOS. static void GenerateMutualExclusionsChecks(const Record &Attr, - const RecordKeeper &Records, + RecordKeeper &Records, raw_ostream &OS, raw_ostream &MergeDeclOS, raw_ostream &MergeStmtOS) { diff --git a/clang/utils/TableGen/ClangSyntaxEmitter.cpp b/clang/utils/TableGen/ClangSyntaxEmitter.cpp index 9720d587318432e..2a69e4c353b6b44 100644 --- a/clang/utils/TableGen/ClangSyntaxEmitter.cpp +++ b/clang/utils/TableGen/ClangSyntaxEmitter.cpp @@ -41,7 +41,7 @@ using llvm::formatv; // stable and useful way, where abstract Node subclasses correspond to ranges. class Hierarchy { public: - Hierarchy(const llvm::RecordKeeper &Records) { + Hierarchy(llvm::RecordKeeper &Records) { for (llvm::Record *T : Records.getAllDerivedDefinitions("NodeType")) add(T); for (llvm::Record *Derived : Records.getAllDerivedDefinitions("NodeType")) diff --git a/llvm/include/llvm/TableGen/DirectiveEmitter.h b/llvm/include/llvm/TableGen/DirectiveEmitter.h index 1121459be6ce7d1..ca21c8fc1014503 100644 --- a/llvm/include/llvm/TableGen/DirectiveEmitter.h +++ b/llvm/include/llvm/TableGen/DirectiveEmitter.h @@ -15,8 +15,7 @@ namespace llvm { // DirectiveBase.td and provides helper methods for accessing it. class DirectiveLanguage { public: - explicit DirectiveLanguage(const llvm::RecordKeeper &Records) - : Records(Records) { + explicit DirectiveLanguage(llvm::RecordKeeper &Records) : Records(Records) { const auto &DirectiveLanguages = getDirectiveLanguages(); Def = DirectiveLanguages[0]; } @@ -71,7 +70,7 @@ class DirectiveLanguage { private: const llvm::Record *Def; - const llvm::RecordKeeper &Records; + llvm::RecordKeeper &Records; std::vector getDirectiveLanguages() const { return Records.getAllDerivedDefinitions("DirectiveLanguage"); diff --git a/llvm/include/llvm/TableGen/Record.h b/llvm/include/llvm/TableGen/Record.h index ff596df94e4f5a1..5d36fcf57e23e35 100644 --- a/llvm/include/llvm/TableGen/Record.h +++ b/llvm/include/llvm/TableGen/Record.h @@ -2057,19 +2057,28 @@ class RecordKeeper { //===--------------------------------------------------------------------===// // High-level helper methods, useful for tablegen backends. + // Non-const methods return std::vector by value or reference. + // Const methods return std::vector by value or + // ArrayRef. + /// Get all the concrete records that inherit from the one specified /// class. The class must be defined. - std::vector getAllDerivedDefinitions(StringRef ClassName) const; + ArrayRef getAllDerivedDefinitions(StringRef ClassName) const; + const std::vector &getAllDerivedDefinitions(StringRef ClassName); /// Get all the concrete records that inherit from all the specified /// classes. The classes must be defined. - std::vector getAllDerivedDefinitions( - ArrayRef ClassNames) const; + std::vector + getAllDerivedDefinitions(ArrayRef ClassNames) const; + std::vector + getAllDerivedDefinitions(ArrayRef ClassNames); /// Get all the concrete records that inherit from specified class, if the /// class is defined. Returns an empty vector if the class is not defined. - std::vector + ArrayRef getAllDerivedDefinitionsIfDefined(StringRef ClassName) const; + const std::vector & + getAllDerivedDefinitionsIfDefined(StringRef ClassName); void dump() const; @@ -2081,9 +2090,24 @@ class RecordKeeper { RecordKeeper &operator=(RecordKeeper &&) = delete; RecordKeeper &operator=(const RecordKeeper &) = delete; + // Helper template functions for backend accessors. + template + const VecTy & + getAllDerivedDefinitionsImpl(StringRef ClassName, + std::map &Cache) const; + + template + VecTy getAllDerivedDefinitionsImpl(ArrayRef ClassNames) const; + + template + const VecTy &getAllDerivedDefinitionsIfDefinedImpl( + StringRef ClassName, std::map &Cache) const; + std::string InputFilename; RecordMap Classes, Defs; - mutable StringMap> ClassRecordsMap; + mutable std::map> + ClassRecordsMapConst; + mutable std::map> ClassRecordsMap; GlobalMap ExtraGlobals; // These members are for the phase timing feature. We need a timer group, diff --git a/llvm/lib/TableGen/Record.cpp b/llvm/lib/TableGen/Record.cpp index cead8f865a60795..17afa2f7eb1b999 100644 --- a/llvm/lib/TableGen/Record.cpp +++ b/llvm/lib/TableGen/Record.cpp @@ -3248,25 +3248,28 @@ void RecordKeeper::stopBackendTimer() { } } -std::vector -RecordKeeper::getAllDerivedDefinitions(StringRef ClassName) const { +template +const VecTy &RecordKeeper::getAllDerivedDefinitionsImpl( + StringRef ClassName, std::map &Cache) const { // We cache the record vectors for single classes. Many backends request // the same vectors multiple times. - auto Pair = ClassRecordsMap.try_emplace(ClassName); + auto Pair = Cache.try_emplace(ClassName.str()); if (Pair.second) - Pair.first->second = getAllDerivedDefinitions(ArrayRef(ClassName)); + Pair.first->second = + getAllDerivedDefinitionsImpl(ArrayRef(ClassName)); return Pair.first->second; } -std::vector RecordKeeper::getAllDerivedDefinitions( +template +VecTy RecordKeeper::getAllDerivedDefinitionsImpl( ArrayRef ClassNames) const { - SmallVector ClassRecs; - std::vector Defs; + SmallVector ClassRecs; + VecTy Defs; assert(ClassNames.size() > 0 && "At least one class must be passed."); for (const auto &ClassName : ClassNames) { - Record *Class = getClass(ClassName); + const Record *Class = getClass(ClassName); if (!Class) PrintFatalError("The class '" + ClassName + "' is not defined\n"); ClassRecs.push_back(Class); @@ -3274,20 +3277,54 @@ std::vector RecordKeeper::getAllDerivedDefinitions( for (const auto &OneDef : getDefs()) { if (all_of(ClassRecs, [&OneDef](const Record *Class) { - return OneDef.second->isSubClassOf(Class); - })) + return OneDef.second->isSubClassOf(Class); + })) Defs.push_back(OneDef.second.get()); } - llvm::sort(Defs, LessRecord()); - return Defs; } +template +const VecTy &RecordKeeper::getAllDerivedDefinitionsIfDefinedImpl( + StringRef ClassName, std::map &Cache) const { + return getClass(ClassName) + ? getAllDerivedDefinitionsImpl(ClassName, Cache) + : Cache[""]; +} + +ArrayRef +RecordKeeper::getAllDerivedDefinitions(StringRef ClassName) const { + return getAllDerivedDefinitionsImpl>( + ClassName, ClassRecordsMapConst); +} + +const std::vector & +RecordKeeper::getAllDerivedDefinitions(StringRef ClassName) { + return getAllDerivedDefinitionsImpl>(ClassName, + ClassRecordsMap); +} + +std::vector +RecordKeeper::getAllDerivedDefinitions(ArrayRef ClassNames) const { + return getAllDerivedDefinitionsImpl>(ClassNames); +} + std::vector +RecordKeeper::getAllDerivedDefinitions(ArrayRef ClassNames) { + return getAllDerivedDefinitionsImpl>(ClassNames); +} + +ArrayRef RecordKeeper::getAllDerivedDefinitionsIfDefined(StringRef ClassName) const { - return getClass(ClassName) ? getAllDerivedDefinitions(ClassName) - : std::vector(); + return getAllDerivedDefinitionsIfDefinedImpl>( + ClassName, ClassRecordsMapConst); +} + +const std::vector & +RecordKeeper::getAllDerivedDefinitionsIfDefined(StringRef ClassName) { + return getAllDerivedDefinitionsIfDefinedImpl>( + ClassName, ClassRecordsMap); } void RecordKeeper::dumpAllocationStats(raw_ostream &OS) const { diff --git a/llvm/utils/TableGen/Basic/CodeGenIntrinsics.cpp b/llvm/utils/TableGen/Basic/CodeGenIntrinsics.cpp index 23c64912c780f3c..05104e938b84867 100644 --- a/llvm/utils/TableGen/Basic/CodeGenIntrinsics.cpp +++ b/llvm/utils/TableGen/Basic/CodeGenIntrinsics.cpp @@ -43,7 +43,7 @@ CodeGenIntrinsicContext::CodeGenIntrinsicContext(const RecordKeeper &RC) { CodeGenIntrinsicTable::CodeGenIntrinsicTable(const RecordKeeper &RC) { CodeGenIntrinsicContext Ctx(RC); - std::vector Defs = RC.getAllDerivedDefinitions("Intrinsic"); + ArrayRef Defs = RC.getAllDerivedDefinitions("Intrinsic"); Intrinsics.reserve(Defs.size()); for (const Record *Def : Defs) diff --git a/llvm/utils/TableGen/Common/SubtargetFeatureInfo.cpp b/llvm/utils/TableGen/Common/SubtargetFeatureInfo.cpp index 4f57234d6fe2754..a4d6d8d21b3562e 100644 --- a/llvm/utils/TableGen/Common/SubtargetFeatureInfo.cpp +++ b/llvm/utils/TableGen/Common/SubtargetFeatureInfo.cpp @@ -21,7 +21,7 @@ LLVM_DUMP_METHOD void SubtargetFeatureInfo::dump() const { #endif std::vector> -SubtargetFeatureInfo::getAll(const RecordKeeper &Records) { +SubtargetFeatureInfo::getAll(RecordKeeper &Records) { std::vector> SubtargetFeatures; std::vector AllPredicates = Records.getAllDerivedDefinitions("Predicate"); diff --git a/llvm/utils/TableGen/Common/SubtargetFeatureInfo.h b/llvm/utils/TableGen/Common/SubtargetFeatureInfo.h index 2635e4b733e1a35..fee2c0263c4960a 100644 --- a/llvm/utils/TableGen/Common/SubtargetFeatureInfo.h +++ b/llvm/utils/TableGen/Common/SubtargetFeatureInfo.h @@ -49,7 +49,7 @@ struct SubtargetFeatureInfo { void dump() const; static std::vector> - getAll(const RecordKeeper &Records); + getAll(RecordKeeper &Records); /// Emit the subtarget feature flag definitions. /// diff --git a/llvm/utils/TableGen/ExegesisEmitter.cpp b/llvm/utils/TableGen/ExegesisEmitter.cpp index d48c7f3a480f24d..0de7cb423374816 100644 --- a/llvm/utils/TableGen/ExegesisEmitter.cpp +++ b/llvm/utils/TableGen/ExegesisEmitter.cpp @@ -59,7 +59,7 @@ class ExegesisEmitter { }; static std::map -collectPfmCounters(const RecordKeeper &Records) { +collectPfmCounters(RecordKeeper &Records) { std::map PfmCounterNameTable; const auto AddPfmCounterName = [&PfmCounterNameTable]( const Record *PfmCounterDef) { diff --git a/llvm/utils/TableGen/GlobalISelEmitter.cpp b/llvm/utils/TableGen/GlobalISelEmitter.cpp index a491a049e7c8124..2606768c0c582cd 100644 --- a/llvm/utils/TableGen/GlobalISelEmitter.cpp +++ b/llvm/utils/TableGen/GlobalISelEmitter.cpp @@ -335,7 +335,7 @@ class GlobalISelEmitter final : public GlobalISelMatchTableExecutorEmitter { private: std::string ClassName; - const RecordKeeper &RK; + RecordKeeper &RK; const CodeGenDAGPatterns CGP; const CodeGenTarget &Target; CodeGenRegBank &CGRegs; diff --git a/llvm/utils/TableGen/SubtargetEmitter.cpp b/llvm/utils/TableGen/SubtargetEmitter.cpp index 66ca38ee5ae2f87..7ae61cb7c446b17 100644 --- a/llvm/utils/TableGen/SubtargetEmitter.cpp +++ b/llvm/utils/TableGen/SubtargetEmitter.cpp @@ -1545,7 +1545,7 @@ void SubtargetEmitter::EmitSchedModel(raw_ostream &OS) { EmitProcessorModels(OS); } -static void emitPredicateProlog(const RecordKeeper &Records, raw_ostream &OS) { +static void emitPredicateProlog(RecordKeeper &Records, raw_ostream &OS) { std::string Buffer; raw_string_ostream Stream(Buffer); diff --git a/llvm/utils/TableGen/TableGen.cpp b/llvm/utils/TableGen/TableGen.cpp index 7ee6fa5c832114b..882410bac081b44 100644 --- a/llvm/utils/TableGen/TableGen.cpp +++ b/llvm/utils/TableGen/TableGen.cpp @@ -52,7 +52,7 @@ static void PrintEnums(RecordKeeper &Records, raw_ostream &OS) { static void PrintSets(const RecordKeeper &Records, raw_ostream &OS) { SetTheory Sets; Sets.addFieldExpander("Set", "Elements"); - for (Record *Rec : Records.getAllDerivedDefinitions("Set")) { + for (const Record *Rec : Records.getAllDerivedDefinitions("Set")) { OS << Rec->getName() << " = ["; const std::vector *Elts = Sets.expand(Rec); assert(Elts && "Couldn't expand Set instance"); diff --git a/mlir/include/mlir/TableGen/GenInfo.h b/mlir/include/mlir/TableGen/GenInfo.h index ef2e12f07df16d3..d59d64223827bdd 100644 --- a/mlir/include/mlir/TableGen/GenInfo.h +++ b/mlir/include/mlir/TableGen/GenInfo.h @@ -21,8 +21,8 @@ class RecordKeeper; namespace mlir { /// Generator function to invoke. -using GenFunction = std::function; +using GenFunction = + std::function; /// Structure to group information about a generator (argument to invoke via /// mlir-tblgen, description, and generator function). @@ -34,7 +34,7 @@ class GenInfo { : arg(arg), description(description), generator(std::move(generator)) {} /// Invokes the generator and returns whether the generator failed. - bool invoke(const llvm::RecordKeeper &recordKeeper, raw_ostream &os) const { + bool invoke(llvm::RecordKeeper &recordKeeper, raw_ostream &os) const { assert(generator && "Cannot call generator with null generator"); return generator(recordKeeper, os); } diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp index eccd8029d950ff4..feca04bff643d5f 100644 --- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp @@ -690,10 +690,10 @@ class DefGenerator { bool emitDefs(StringRef selectedDialect); protected: - DefGenerator(std::vector &&defs, raw_ostream &os, + DefGenerator(const std::vector &defs, raw_ostream &os, StringRef defType, StringRef valueType, bool isAttrGenerator) - : defRecords(std::move(defs)), os(os), defType(defType), - valueType(valueType), isAttrGenerator(isAttrGenerator) { + : defRecords(defs), os(os), defType(defType), valueType(valueType), + isAttrGenerator(isAttrGenerator) { // Sort by occurrence in file. llvm::sort(defRecords, [](llvm::Record *lhs, llvm::Record *rhs) { return lhs->getID() < rhs->getID(); @@ -721,13 +721,13 @@ class DefGenerator { /// A specialized generator for AttrDefs. struct AttrDefGenerator : public DefGenerator { - AttrDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os) + AttrDefGenerator(llvm::RecordKeeper &records, raw_ostream &os) : DefGenerator(records.getAllDerivedDefinitionsIfDefined("AttrDef"), os, "Attr", "Attribute", /*isAttrGenerator=*/true) {} }; /// A specialized generator for TypeDefs. struct TypeDefGenerator : public DefGenerator { - TypeDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os) + TypeDefGenerator(llvm::RecordKeeper &records, raw_ostream &os) : DefGenerator(records.getAllDerivedDefinitionsIfDefined("TypeDef"), os, "Type", "Type", /*isAttrGenerator=*/false) {} }; @@ -1029,7 +1029,7 @@ bool DefGenerator::emitDefs(StringRef selectedDialect) { /// Find all type constraints for which a C++ function should be generated. static std::vector -getAllTypeConstraints(const llvm::RecordKeeper &records) { +getAllTypeConstraints(llvm::RecordKeeper &records) { std::vector result; for (llvm::Record *def : records.getAllDerivedDefinitionsIfDefined("TypeConstraint")) { @@ -1046,7 +1046,7 @@ getAllTypeConstraints(const llvm::RecordKeeper &records) { return result; } -static void emitTypeConstraintDecls(const llvm::RecordKeeper &records, +static void emitTypeConstraintDecls(llvm::RecordKeeper &records, raw_ostream &os) { static const char *const typeConstraintDecl = R"( bool {0}(::mlir::Type type); @@ -1056,7 +1056,7 @@ bool {0}(::mlir::Type type); os << strfmt(typeConstraintDecl, *constr.getCppFunctionName()); } -static void emitTypeConstraintDefs(const llvm::RecordKeeper &records, +static void emitTypeConstraintDefs(llvm::RecordKeeper &records, raw_ostream &os) { static const char *const typeConstraintDef = R"( bool {0}(::mlir::Type type) { @@ -1087,13 +1087,13 @@ static llvm::cl::opt static mlir::GenRegistration genAttrDefs("gen-attrdef-defs", "Generate AttrDef definitions", - [](const llvm::RecordKeeper &records, raw_ostream &os) { + [](llvm::RecordKeeper &records, raw_ostream &os) { AttrDefGenerator generator(records, os); return generator.emitDefs(attrDialect); }); static mlir::GenRegistration genAttrDecls("gen-attrdef-decls", "Generate AttrDef declarations", - [](const llvm::RecordKeeper &records, raw_ostream &os) { + [](llvm::RecordKeeper &records, raw_ostream &os) { AttrDefGenerator generator(records, os); return generator.emitDecls(attrDialect); }); @@ -1109,13 +1109,13 @@ static llvm::cl::opt static mlir::GenRegistration genTypeDefs("gen-typedef-defs", "Generate TypeDef definitions", - [](const llvm::RecordKeeper &records, raw_ostream &os) { + [](llvm::RecordKeeper &records, raw_ostream &os) { TypeDefGenerator generator(records, os); return generator.emitDefs(typeDialect); }); static mlir::GenRegistration genTypeDecls("gen-typedef-decls", "Generate TypeDef declarations", - [](const llvm::RecordKeeper &records, raw_ostream &os) { + [](llvm::RecordKeeper &records, raw_ostream &os) { TypeDefGenerator generator(records, os); return generator.emitDecls(typeDialect); }); @@ -1123,14 +1123,14 @@ static mlir::GenRegistration static mlir::GenRegistration genTypeConstrDefs("gen-type-constraint-defs", "Generate type constraint definitions", - [](const llvm::RecordKeeper &records, raw_ostream &os) { + [](llvm::RecordKeeper &records, raw_ostream &os) { emitTypeConstraintDefs(records, os); return false; }); static mlir::GenRegistration genTypeConstrDecls("gen-type-constraint-decls", "Generate type constraint declarations", - [](const llvm::RecordKeeper &records, raw_ostream &os) { + [](llvm::RecordKeeper &records, raw_ostream &os) { emitTypeConstraintDecls(records, os); return false; }); diff --git a/mlir/tools/mlir-tblgen/OmpOpGen.cpp b/mlir/tools/mlir-tblgen/OmpOpGen.cpp index ffa2e17cc8f9168..b7f6ca975a9a34d 100644 --- a/mlir/tools/mlir-tblgen/OmpOpGen.cpp +++ b/mlir/tools/mlir-tblgen/OmpOpGen.cpp @@ -149,7 +149,7 @@ static void verifyClause(Record *op, Record *clause) { /// Verify that all properties of `OpenMP_Clause`s of records deriving from /// `OpenMP_Op`s have been inherited by the latter. -static bool verifyDecls(const RecordKeeper &recordKeeper, raw_ostream &) { +static bool verifyDecls(RecordKeeper &recordKeeper, raw_ostream &) { for (Record *op : recordKeeper.getAllDerivedDefinitions("OpenMP_Op")) { for (Record *clause : op->getValueAsListOfDefs("clauseList")) verifyClause(op, clause); diff --git a/mlir/tools/mlir-tblgen/OpDocGen.cpp b/mlir/tools/mlir-tblgen/OpDocGen.cpp index 71df80cd110f151..066e5b24f5a3c17 100644 --- a/mlir/tools/mlir-tblgen/OpDocGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDocGen.cpp @@ -282,7 +282,7 @@ static void emitSourceLink(StringRef inputFilename, raw_ostream &os) { << inputFromMlirInclude << ")\n\n"; } -static void emitOpDoc(const RecordKeeper &recordKeeper, raw_ostream &os) { +static void emitOpDoc(RecordKeeper &recordKeeper, raw_ostream &os) { auto opDefs = getRequestedOpDefinitions(recordKeeper); os << "\n"; @@ -371,8 +371,8 @@ static void emitAttrOrTypeDefDoc(const AttrOrTypeDef &def, raw_ostream &os) { os << "\n"; } -static void emitAttrOrTypeDefDoc(const RecordKeeper &recordKeeper, - raw_ostream &os, StringRef recordTypeName) { +static void emitAttrOrTypeDefDoc(RecordKeeper &recordKeeper, raw_ostream &os, + StringRef recordTypeName) { std::vector defs = recordKeeper.getAllDerivedDefinitions(recordTypeName); @@ -405,7 +405,7 @@ static void emitEnumDoc(const EnumAttr &def, raw_ostream &os) { os << "\n"; } -static void emitEnumDoc(const RecordKeeper &recordKeeper, raw_ostream &os) { +static void emitEnumDoc(RecordKeeper &recordKeeper, raw_ostream &os) { std::vector defs = recordKeeper.getAllDerivedDefinitions("EnumAttr"); @@ -518,7 +518,7 @@ static void emitDialectDoc(const Dialect &dialect, StringRef inputFilename, os); } -static bool emitDialectDoc(const RecordKeeper &recordKeeper, raw_ostream &os) { +static bool emitDialectDoc(RecordKeeper &recordKeeper, raw_ostream &os) { std::vector dialectDefs = recordKeeper.getAllDerivedDefinitionsIfDefined("Dialect"); SmallVector dialects(dialectDefs.begin(), dialectDefs.end()); @@ -617,34 +617,34 @@ static bool emitDialectDoc(const RecordKeeper &recordKeeper, raw_ostream &os) { static mlir::GenRegistration genAttrRegister("gen-attrdef-doc", "Generate dialect attribute documentation", - [](const RecordKeeper &records, raw_ostream &os) { + [](RecordKeeper &records, raw_ostream &os) { emitAttrOrTypeDefDoc(records, os, "AttrDef"); return false; }); static mlir::GenRegistration genOpRegister("gen-op-doc", "Generate dialect documentation", - [](const RecordKeeper &records, raw_ostream &os) { + [](RecordKeeper &records, raw_ostream &os) { emitOpDoc(records, os); return false; }); static mlir::GenRegistration genTypeRegister("gen-typedef-doc", "Generate dialect type documentation", - [](const RecordKeeper &records, raw_ostream &os) { + [](RecordKeeper &records, raw_ostream &os) { emitAttrOrTypeDefDoc(records, os, "TypeDef"); return false; }); static mlir::GenRegistration genEnumRegister("gen-enum-doc", "Generate dialect enum documentation", - [](const RecordKeeper &records, raw_ostream &os) { + [](RecordKeeper &records, raw_ostream &os) { emitEnumDoc(records, os); return false; }); static mlir::GenRegistration genRegister("gen-dialect-doc", "Generate dialect documentation", - [](const RecordKeeper &records, raw_ostream &os) { + [](RecordKeeper &records, raw_ostream &os) { return emitDialectDoc(records, os); }); diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp index 4b06b92fbc8a8e0..00f21a1cefbdd8e 100644 --- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp +++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp @@ -62,8 +62,7 @@ static void emitMethodNameAndArgs(const InterfaceMethod &method, /// Get an array of all OpInterface definitions but exclude those subclassing /// "DeclareOpInterfaceMethods". static std::vector -getAllInterfaceDefinitions(const llvm::RecordKeeper &recordKeeper, - StringRef name) { +getAllInterfaceDefinitions(llvm::RecordKeeper &recordKeeper, StringRef name) { std::vector defs = recordKeeper.getAllDerivedDefinitions((name + "Interface").str()); @@ -118,7 +117,7 @@ class InterfaceGenerator { /// A specialized generator for attribute interfaces. struct AttrInterfaceGenerator : public InterfaceGenerator { - AttrInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os) + AttrInterfaceGenerator(llvm::RecordKeeper &records, raw_ostream &os) : InterfaceGenerator(getAllInterfaceDefinitions(records, "Attr"), os) { valueType = "::mlir::Attribute"; interfaceBaseType = "AttributeInterface"; @@ -133,7 +132,7 @@ struct AttrInterfaceGenerator : public InterfaceGenerator { }; /// A specialized generator for operation interfaces. struct OpInterfaceGenerator : public InterfaceGenerator { - OpInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os) + OpInterfaceGenerator(llvm::RecordKeeper &records, raw_ostream &os) : InterfaceGenerator(getAllInterfaceDefinitions(records, "Op"), os) { valueType = "::mlir::Operation *"; interfaceBaseType = "OpInterface"; @@ -149,7 +148,7 @@ struct OpInterfaceGenerator : public InterfaceGenerator { }; /// A specialized generator for type interfaces. struct TypeInterfaceGenerator : public InterfaceGenerator { - TypeInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os) + TypeInterfaceGenerator(llvm::RecordKeeper &records, raw_ostream &os) : InterfaceGenerator(getAllInterfaceDefinitions(records, "Type"), os) { valueType = "::mlir::Type"; interfaceBaseType = "TypeInterface"; @@ -684,15 +683,15 @@ struct InterfaceGenRegistration { genDefDesc(("Generate " + genDesc + " interface definitions").str()), genDocDesc(("Generate " + genDesc + " interface documentation").str()), genDecls(genDeclArg, genDeclDesc, - [](const llvm::RecordKeeper &records, raw_ostream &os) { + [](llvm::RecordKeeper &records, raw_ostream &os) { return GeneratorT(records, os).emitInterfaceDecls(); }), genDefs(genDefArg, genDefDesc, - [](const llvm::RecordKeeper &records, raw_ostream &os) { + [](llvm::RecordKeeper &records, raw_ostream &os) { return GeneratorT(records, os).emitInterfaceDefs(); }), genDocs(genDocArg, genDocDesc, - [](const llvm::RecordKeeper &records, raw_ostream &os) { + [](llvm::RecordKeeper &records, raw_ostream &os) { return GeneratorT(records, os).emitInterfaceDocs(); }) {} diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index 2c79ba2cd6353ee..401f02246ed2356 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -268,7 +268,7 @@ class PatternEmitter { // inlining them. class StaticMatcherHelper { public: - StaticMatcherHelper(raw_ostream &os, const RecordKeeper &recordKeeper, + StaticMatcherHelper(raw_ostream &os, RecordKeeper &recordKeeper, RecordOperatorMap &mapper); // Determine if we should inline the match logic or delegate to a static @@ -1886,7 +1886,7 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs( } StaticMatcherHelper::StaticMatcherHelper(raw_ostream &os, - const RecordKeeper &recordKeeper, + RecordKeeper &recordKeeper, RecordOperatorMap &mapper) : opMap(mapper), staticVerifierEmitter(os, recordKeeper) {} @@ -1951,7 +1951,7 @@ StringRef StaticMatcherHelper::getVerifierName(DagLeaf leaf) { return staticVerifierEmitter.getTypeConstraintFn(leaf.getAsConstraint()); } -static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) { +static void emitRewriters(RecordKeeper &recordKeeper, raw_ostream &os) { emitSourceFileHeader("Rewriters", os, recordKeeper); const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern"); @@ -2001,7 +2001,7 @@ static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) { static mlir::GenRegistration genRewriters("gen-rewriters", "Generate pattern rewriters", - [](const RecordKeeper &records, raw_ostream &os) { + [](RecordKeeper &records, raw_ostream &os) { emitRewriters(records, os); return false; }); diff --git a/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp b/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp index a55f3539f31db00..0957a5d55db9596 100644 --- a/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp +++ b/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp @@ -146,14 +146,13 @@ static irdl::DialectOp createIRDLDialect(OpBuilder &builder) { } static std::vector -getOpDefinitions(const RecordKeeper &recordKeeper) { +getOpDefinitions(RecordKeeper &recordKeeper) { if (!recordKeeper.getClass("Op")) return {}; return recordKeeper.getAllDerivedDefinitions("Op"); } -static bool emitDialectIRDLDefs(const RecordKeeper &recordKeeper, - raw_ostream &os) { +static bool emitDialectIRDLDefs(RecordKeeper &recordKeeper, raw_ostream &os) { // Initialize. MLIRContext ctx; ctx.getOrLoadDialect(); @@ -185,6 +184,6 @@ static bool emitDialectIRDLDefs(const RecordKeeper &recordKeeper, static mlir::GenRegistration genOpDefs("gen-dialect-irdl-defs", "Generate IRDL dialect definitions", - [](const RecordKeeper &records, raw_ostream &os) { + [](RecordKeeper &records, raw_ostream &os) { return emitDialectIRDLDefs(records, os); });