diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp index 1a248c3a16647d6..6a39424bd463fd7 100644 --- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp @@ -22,6 +22,8 @@ using namespace mlir; using namespace mlir::tblgen; +using llvm::Record; +using llvm::RecordKeeper; //===----------------------------------------------------------------------===// // Utility Functions @@ -30,14 +32,14 @@ using namespace mlir::tblgen; /// Find all the AttrOrTypeDef for the specified dialect. If no dialect /// specified and can only find one dialect's defs, use that. static void collectAllDefs(StringRef selectedDialect, - ArrayRef records, + ArrayRef records, SmallVectorImpl &resultDefs) { // Nothing to do if no defs were found. if (records.empty()) return; auto defs = llvm::map_range( - records, [&](const llvm::Record *rec) { return AttrOrTypeDef(rec); }); + records, [&](const Record *rec) { return AttrOrTypeDef(rec); }); if (selectedDialect.empty()) { // If a dialect was not specified, ensure that all found defs belong to the // same dialect. @@ -690,15 +692,14 @@ class DefGenerator { bool emitDefs(StringRef selectedDialect); protected: - DefGenerator(ArrayRef defs, raw_ostream &os, + DefGenerator(ArrayRef defs, raw_ostream &os, StringRef defType, StringRef valueType, bool isAttrGenerator) : defRecords(defs), os(os), defType(defType), valueType(valueType), isAttrGenerator(isAttrGenerator) { // Sort by occurrence in file. - llvm::sort(defRecords, - [](const llvm::Record *lhs, const llvm::Record *rhs) { - return lhs->getID() < rhs->getID(); - }); + llvm::sort(defRecords, [](const Record *lhs, const Record *rhs) { + return lhs->getID() < rhs->getID(); + }); } /// Emit the list of def type names. @@ -707,7 +708,7 @@ class DefGenerator { void emitParsePrintDispatch(ArrayRef defs); /// The set of def records to emit. - std::vector defRecords; + std::vector defRecords; /// The attribute or type class to emit. /// The stream to emit to. raw_ostream &os; @@ -722,13 +723,13 @@ class DefGenerator { /// A specialized generator for AttrDefs. struct AttrDefGenerator : public DefGenerator { - AttrDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os) + AttrDefGenerator(const RecordKeeper &records, raw_ostream &os) : DefGenerator(records.getAllDerivedDefinitionsIfDefined("AttrDef"), os, "Attr", "Attribute", /*isAttrGenerator=*/true) {} }; /// A specialized generator for TypeDefs. struct TypeDefGenerator : public DefGenerator { - TypeDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os) + TypeDefGenerator(const RecordKeeper &records, raw_ostream &os) : DefGenerator(records.getAllDerivedDefinitionsIfDefined("TypeDef"), os, "Type", "Type", /*isAttrGenerator=*/false) {} }; @@ -1030,9 +1031,9 @@ bool DefGenerator::emitDefs(StringRef selectedDialect) { /// Find all type constraints for which a C++ function should be generated. static std::vector -getAllTypeConstraints(const llvm::RecordKeeper &records) { +getAllTypeConstraints(const RecordKeeper &records) { std::vector result; - for (const llvm::Record *def : + for (const Record *def : records.getAllDerivedDefinitionsIfDefined("TypeConstraint")) { // Ignore constraints defined outside of the top-level file. if (llvm::SrcMgr.FindBufferContainingLoc(def->getLoc()[0]) != @@ -1047,7 +1048,7 @@ getAllTypeConstraints(const llvm::RecordKeeper &records) { return result; } -static void emitTypeConstraintDecls(const llvm::RecordKeeper &records, +static void emitTypeConstraintDecls(const RecordKeeper &records, raw_ostream &os) { static const char *const typeConstraintDecl = R"( bool {0}(::mlir::Type type); @@ -1057,7 +1058,7 @@ bool {0}(::mlir::Type type); os << strfmt(typeConstraintDecl, *constr.getCppFunctionName()); } -static void emitTypeConstraintDefs(const llvm::RecordKeeper &records, +static void emitTypeConstraintDefs(const RecordKeeper &records, raw_ostream &os) { static const char *const typeConstraintDef = R"( bool {0}(::mlir::Type type) { @@ -1088,13 +1089,13 @@ static llvm::cl::opt static mlir::GenRegistration genAttrDefs("gen-attrdef-defs", "Generate AttrDef definitions", - [](const llvm::RecordKeeper &records, raw_ostream &os) { + [](const RecordKeeper &records, raw_ostream &os) { AttrDefGenerator generator(records, os); return generator.emitDefs(attrDialect); }); static mlir::GenRegistration genAttrDecls("gen-attrdef-decls", "Generate AttrDef declarations", - [](const llvm::RecordKeeper &records, raw_ostream &os) { + [](const RecordKeeper &records, raw_ostream &os) { AttrDefGenerator generator(records, os); return generator.emitDecls(attrDialect); }); @@ -1110,13 +1111,13 @@ static llvm::cl::opt static mlir::GenRegistration genTypeDefs("gen-typedef-defs", "Generate TypeDef definitions", - [](const llvm::RecordKeeper &records, raw_ostream &os) { + [](const RecordKeeper &records, raw_ostream &os) { TypeDefGenerator generator(records, os); return generator.emitDefs(typeDialect); }); static mlir::GenRegistration genTypeDecls("gen-typedef-decls", "Generate TypeDef declarations", - [](const llvm::RecordKeeper &records, raw_ostream &os) { + [](const RecordKeeper &records, raw_ostream &os) { TypeDefGenerator generator(records, os); return generator.emitDecls(typeDialect); }); @@ -1124,14 +1125,14 @@ static mlir::GenRegistration static mlir::GenRegistration genTypeConstrDefs("gen-type-constraint-defs", "Generate type constraint definitions", - [](const llvm::RecordKeeper &records, raw_ostream &os) { + [](const RecordKeeper &records, raw_ostream &os) { emitTypeConstraintDefs(records, os); return false; }); static mlir::GenRegistration genTypeConstrDecls("gen-type-constraint-decls", "Generate type constraint declarations", - [](const llvm::RecordKeeper &records, raw_ostream &os) { + [](const RecordKeeper &records, raw_ostream &os) { emitTypeConstraintDecls(records, os); return false; }); diff --git a/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp b/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp index 1474bd8c149ff60..86ebaf2cf27dfeb 100644 --- a/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp +++ b/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp @@ -18,11 +18,10 @@ using namespace llvm; -static llvm::cl::OptionCategory dialectGenCat("Options for -gen-bytecode"); -static llvm::cl::opt - selectedBcDialect("bytecode-dialect", - llvm::cl::desc("The dialect to gen for"), - llvm::cl::cat(dialectGenCat), llvm::cl::CommaSeparated); +static cl::OptionCategory dialectGenCat("Options for -gen-bytecode"); +static cl::opt + selectedBcDialect("bytecode-dialect", cl::desc("The dialect to gen for"), + cl::cat(dialectGenCat), cl::CommaSeparated); namespace { @@ -306,7 +305,7 @@ void Generator::emitPrint(StringRef kind, StringRef type, auto funScope = os.scope("{\n", "}\n\n"); // Check that predicates specified if multiple bytecode instances. - for (const llvm::Record *rec : make_second_range(vec)) { + for (const Record *rec : make_second_range(vec)) { StringRef pred = rec->getValueAsString("printerPredicate"); if (vec.size() > 1 && pred.empty()) { for (auto [index, rec] : vec) { diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp index 2412876958a0c98..76da9d7cea4e8d6 100644 --- a/mlir/tools/mlir-tblgen/DialectGen.cpp +++ b/mlir/tools/mlir-tblgen/DialectGen.cpp @@ -30,6 +30,8 @@ using namespace mlir; using namespace mlir::tblgen; +using llvm::Record; +using llvm::RecordKeeper; static llvm::cl::OptionCategory dialectGenCat("Options for -gen-dialect-*"); llvm::cl::opt @@ -39,8 +41,8 @@ llvm::cl::opt /// Utility iterator used for filtering records for a specific dialect. namespace { using DialectFilterIterator = - llvm::filter_iterator::iterator, - std::function>; + llvm::filter_iterator::iterator, + std::function>; } // namespace static void populateDiscardableAttributes( @@ -62,8 +64,8 @@ static void populateDiscardableAttributes( /// the given dialect. template static iterator_range -filterForDialect(ArrayRef records, Dialect &dialect) { - auto filterFn = [&](const llvm::Record *record) { +filterForDialect(ArrayRef records, Dialect &dialect) { + auto filterFn = [&](const Record *record) { return T(record).getDialect() == dialect; }; return {DialectFilterIterator(records.begin(), records.end(), filterFn), @@ -295,7 +297,7 @@ static void emitDialectDecl(Dialect &dialect, raw_ostream &os) { << "::" << dialect.getCppClassName() << ")\n"; } -static bool emitDialectDecls(const llvm::RecordKeeper &recordKeeper, +static bool emitDialectDecls(const RecordKeeper &recordKeeper, raw_ostream &os) { emitSourceFileHeader("Dialect Declarations", os, recordKeeper); @@ -340,8 +342,7 @@ static const char *const dialectDestructorStr = R"( )"; -static void emitDialectDef(Dialect &dialect, - const llvm::RecordKeeper &recordKeeper, +static void emitDialectDef(Dialect &dialect, const RecordKeeper &recordKeeper, raw_ostream &os) { std::string cppClassName = dialect.getCppClassName(); @@ -389,8 +390,7 @@ static void emitDialectDef(Dialect &dialect, os << llvm::formatv(dialectDestructorStr, cppClassName); } -static bool emitDialectDefs(const llvm::RecordKeeper &recordKeeper, - raw_ostream &os) { +static bool emitDialectDefs(const RecordKeeper &recordKeeper, raw_ostream &os) { emitSourceFileHeader("Dialect Definitions", os, recordKeeper); auto dialectDefs = recordKeeper.getAllDerivedDefinitions("Dialect"); @@ -411,12 +411,12 @@ static bool emitDialectDefs(const llvm::RecordKeeper &recordKeeper, static mlir::GenRegistration genDialectDecls("gen-dialect-decls", "Generate dialect declarations", - [](const llvm::RecordKeeper &records, raw_ostream &os) { + [](const RecordKeeper &records, raw_ostream &os) { return emitDialectDecls(records, os); }); static mlir::GenRegistration genDialectDefs("gen-dialect-defs", "Generate dialect definitions", - [](const llvm::RecordKeeper &records, raw_ostream &os) { + [](const RecordKeeper &records, raw_ostream &os) { return emitDialectDefs(records, os); }); diff --git a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp index 79249944e484f72..189487794f8f7ca 100644 --- a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp @@ -21,6 +21,9 @@ using namespace mlir; using namespace mlir::tblgen; +using llvm::formatv; +using llvm::Record; +using llvm::RecordKeeper; /// File header and includes. constexpr const char *fileHeader = R"Py( @@ -42,44 +45,42 @@ static std::string makePythonEnumCaseName(StringRef name) { /// Emits the Python class for the given enum. static void emitEnumClass(EnumAttr enumAttr, raw_ostream &os) { - os << llvm::formatv("class {0}({1}):\n", enumAttr.getEnumClassName(), - enumAttr.isBitEnum() ? "IntFlag" : "IntEnum"); + os << formatv("class {0}({1}):\n", enumAttr.getEnumClassName(), + enumAttr.isBitEnum() ? "IntFlag" : "IntEnum"); if (!enumAttr.getSummary().empty()) - os << llvm::formatv(" \"\"\"{0}\"\"\"\n", enumAttr.getSummary()); + os << formatv(" \"\"\"{0}\"\"\"\n", enumAttr.getSummary()); os << "\n"; for (const EnumAttrCase &enumCase : enumAttr.getAllCases()) { - os << llvm::formatv( - " {0} = {1}\n", makePythonEnumCaseName(enumCase.getSymbol()), - enumCase.getValue() >= 0 ? std::to_string(enumCase.getValue()) - : "auto()"); + os << formatv(" {0} = {1}\n", + makePythonEnumCaseName(enumCase.getSymbol()), + enumCase.getValue() >= 0 ? std::to_string(enumCase.getValue()) + : "auto()"); } os << "\n"; if (enumAttr.isBitEnum()) { - os << llvm::formatv(" def __iter__(self):\n" - " return iter([case for case in type(self) if " - "(self & case) is case])\n"); - os << llvm::formatv(" def __len__(self):\n" - " return bin(self).count(\"1\")\n"); + os << formatv(" def __iter__(self):\n" + " return iter([case for case in type(self) if " + "(self & case) is case])\n"); + os << formatv(" def __len__(self):\n" + " return bin(self).count(\"1\")\n"); os << "\n"; } - os << llvm::formatv(" def __str__(self):\n"); + os << formatv(" def __str__(self):\n"); if (enumAttr.isBitEnum()) - os << llvm::formatv(" if len(self) > 1:\n" - " return \"{0}\".join(map(str, self))\n", - enumAttr.getDef().getValueAsString("separator")); + os << formatv(" if len(self) > 1:\n" + " return \"{0}\".join(map(str, self))\n", + enumAttr.getDef().getValueAsString("separator")); for (const EnumAttrCase &enumCase : enumAttr.getAllCases()) { - os << llvm::formatv(" if self is {0}.{1}:\n", - enumAttr.getEnumClassName(), - makePythonEnumCaseName(enumCase.getSymbol())); - os << llvm::formatv(" return \"{0}\"\n", enumCase.getStr()); + os << formatv(" if self is {0}.{1}:\n", enumAttr.getEnumClassName(), + makePythonEnumCaseName(enumCase.getSymbol())); + os << formatv(" return \"{0}\"\n", enumCase.getStr()); } - os << llvm::formatv( - " raise ValueError(\"Unknown {0} enum entry.\")\n\n\n", - enumAttr.getEnumClassName()); + os << formatv(" raise ValueError(\"Unknown {0} enum entry.\")\n\n\n", + enumAttr.getEnumClassName()); os << "\n"; } @@ -105,15 +106,13 @@ static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) { return true; } - os << llvm::formatv("@register_attribute_builder(\"{0}\")\n", - enumAttr.getAttrDefName()); - os << llvm::formatv("def _{0}(x, context):\n", - enumAttr.getAttrDefName().lower()); - os << llvm::formatv( - " return " - "_ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless({0}, " - "context=context), int(x))\n\n", - bitwidth); + os << formatv("@register_attribute_builder(\"{0}\")\n", + enumAttr.getAttrDefName()); + os << formatv("def _{0}(x, context):\n", enumAttr.getAttrDefName().lower()); + os << formatv(" return " + "_ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless({0}, " + "context=context), int(x))\n\n", + bitwidth); return false; } @@ -123,26 +122,25 @@ static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) { static bool emitDialectEnumAttributeBuilder(StringRef attrDefName, StringRef formatString, raw_ostream &os) { - os << llvm::formatv("@register_attribute_builder(\"{0}\")\n", attrDefName); - os << llvm::formatv("def _{0}(x, context):\n", attrDefName.lower()); - os << llvm::formatv(" return " - "_ods_ir.Attribute.parse(f'{0}', context=context)\n\n", - formatString); + os << formatv("@register_attribute_builder(\"{0}\")\n", attrDefName); + os << formatv("def _{0}(x, context):\n", attrDefName.lower()); + os << formatv(" return " + "_ods_ir.Attribute.parse(f'{0}', context=context)\n\n", + formatString); return false; } /// Emits Python bindings for all enums in the record keeper. Returns /// `false` on success, `true` on failure. -static bool emitPythonEnums(const llvm::RecordKeeper &recordKeeper, - raw_ostream &os) { +static bool emitPythonEnums(const RecordKeeper &recordKeeper, raw_ostream &os) { os << fileHeader; - for (const llvm::Record *it : + for (const Record *it : recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttrInfo")) { EnumAttr enumAttr(*it); emitEnumClass(enumAttr, os); emitAttributeBuilder(enumAttr, os); } - for (const llvm::Record *it : + for (const Record *it : recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttr")) { AttrOrTypeDef attr(&*it); if (!attr.getMnemonic()) { @@ -156,11 +154,11 @@ static bool emitPythonEnums(const llvm::RecordKeeper &recordKeeper, if (assemblyFormat == "`<` $value `>`") { emitDialectEnumAttributeBuilder( attr.getName(), - llvm::formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str(), os); + formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str(), os); } else if (assemblyFormat == "$value") { emitDialectEnumAttributeBuilder( attr.getName(), - llvm::formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str(), os); + formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str(), os); } else { llvm::errs() << "unsupported assembly format for python enum bindings generation"; diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp index 863463bd920bffc..5f2008818e3eb72 100644 --- a/mlir/tools/mlir-tblgen/EnumsGen.cpp +++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp @@ -26,10 +26,9 @@ using llvm::formatv; using llvm::isDigit; using llvm::PrintFatalError; -using llvm::raw_ostream; using llvm::Record; using llvm::RecordKeeper; -using llvm::StringRef; +using namespace mlir; using mlir::tblgen::Attribute; using mlir::tblgen::EnumAttr; using mlir::tblgen::EnumAttrCase; @@ -139,7 +138,7 @@ inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, {0} value) {{ // is not a power of two (i.e. not a single bit case) and not a known case. } else if (enumAttr.isBitEnum()) { // Process the known multi-bit cases that use valid keywords. - llvm::SmallVector validMultiBitCases; + SmallVector validMultiBitCases; for (auto [index, caseVal] : llvm::enumerate(cases)) { uint64_t value = caseVal.getValue(); if (value && !llvm::has_single_bit(value) && !nonKeywordCases.test(index)) @@ -476,7 +475,7 @@ static void emitSpecializedAttrDef(const Record &enumDef, raw_ostream &os) { EnumAttr enumAttr(enumDef); StringRef enumName = enumAttr.getEnumClassName(); StringRef attrClassName = enumAttr.getSpecializedAttrClassName(); - const llvm::Record *baseAttrDef = enumAttr.getBaseAttrClass(); + const Record *baseAttrDef = enumAttr.getBaseAttrClass(); Attribute baseAttr(baseAttrDef); // Emit classof method @@ -565,7 +564,7 @@ static void emitEnumDecl(const Record &enumDef, raw_ostream &os) { StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName(); auto enumerants = enumAttr.getAllCases(); - llvm::SmallVector namespaces; + SmallVector namespaces; llvm::SplitString(cppNamespace, namespaces, "::"); for (auto ns : namespaces) @@ -656,7 +655,7 @@ static void emitEnumDef(const Record &enumDef, raw_ostream &os) { EnumAttr enumAttr(enumDef); StringRef cppNamespace = enumAttr.getCppNamespace(); - llvm::SmallVector namespaces; + SmallVector namespaces; llvm::SplitString(cppNamespace, namespaces, "::"); for (auto ns : namespaces) diff --git a/mlir/tools/mlir-tblgen/FormatGen.cpp b/mlir/tools/mlir-tblgen/FormatGen.cpp index 7540e584b8fac5d..d145f3e5a23ddb6 100644 --- a/mlir/tools/mlir-tblgen/FormatGen.cpp +++ b/mlir/tools/mlir-tblgen/FormatGen.cpp @@ -13,6 +13,7 @@ using namespace mlir; using namespace mlir::tblgen; +using llvm::SourceMgr; //===----------------------------------------------------------------------===// // FormatToken @@ -26,14 +27,14 @@ SMLoc FormatToken::getLoc() const { // FormatLexer //===----------------------------------------------------------------------===// -FormatLexer::FormatLexer(llvm::SourceMgr &mgr, SMLoc loc) +FormatLexer::FormatLexer(SourceMgr &mgr, SMLoc loc) : mgr(mgr), loc(loc), curBuffer(mgr.getMemoryBuffer(mgr.getMainFileID())->getBuffer()), curPtr(curBuffer.begin()) {} FormatToken FormatLexer::emitError(SMLoc loc, const Twine &msg) { - mgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg); - llvm::SrcMgr.PrintMessage(this->loc, llvm::SourceMgr::DK_Note, + mgr.PrintMessage(loc, SourceMgr::DK_Error, msg); + llvm::SrcMgr.PrintMessage(this->loc, SourceMgr::DK_Note, "in custom assembly format for this operation"); return formToken(FormatToken::error, loc.getPointer()); } @@ -44,10 +45,10 @@ FormatToken FormatLexer::emitError(const char *loc, const Twine &msg) { FormatToken FormatLexer::emitErrorAndNote(SMLoc loc, const Twine &msg, const Twine ¬e) { - mgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg); - llvm::SrcMgr.PrintMessage(this->loc, llvm::SourceMgr::DK_Note, + mgr.PrintMessage(loc, SourceMgr::DK_Error, msg); + llvm::SrcMgr.PrintMessage(this->loc, SourceMgr::DK_Note, "in custom assembly format for this operation"); - mgr.PrintMessage(loc, llvm::SourceMgr::DK_Note, note); + mgr.PrintMessage(loc, SourceMgr::DK_Note, note); return formToken(FormatToken::error, loc.getPointer()); } diff --git a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp index 5560298831865f7..4e7ddab75fc1dd2 100644 --- a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp +++ b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp @@ -411,8 +411,7 @@ class LLVMCEnumAttr : public tblgen::EnumAttr { // Emits conversion function "LLVMClass convertEnumToLLVM(Enum)" and containing // switch-based logic to convert from the MLIR LLVM dialect enum attribute case // (Enum) to the corresponding LLVM API enumerant -static void emitOneEnumToConversion(const llvm::Record *record, - raw_ostream &os) { +static void emitOneEnumToConversion(const Record *record, raw_ostream &os) { LLVMEnumAttr enumAttr(record); StringRef llvmClass = enumAttr.getLLVMClassName(); StringRef cppClassName = enumAttr.getEnumClassName(); @@ -441,8 +440,7 @@ static void emitOneEnumToConversion(const llvm::Record *record, // Emits conversion function "LLVMClass convertEnumToLLVM(Enum)" and containing // switch-based logic to convert from the MLIR LLVM dialect enum attribute case // (Enum) to the corresponding LLVM API C-style enumerant -static void emitOneCEnumToConversion(const llvm::Record *record, - raw_ostream &os) { +static void emitOneCEnumToConversion(const Record *record, raw_ostream &os) { LLVMCEnumAttr enumAttr(record); StringRef llvmClass = enumAttr.getLLVMClassName(); StringRef cppClassName = enumAttr.getEnumClassName(); @@ -472,8 +470,7 @@ static void emitOneCEnumToConversion(const llvm::Record *record, // Emits conversion function "Enum convertEnumFromLLVM(LLVMClass)" and // containing switch-based logic to convert from the LLVM API enumerant to MLIR // LLVM dialect enum attribute (Enum). -static void emitOneEnumFromConversion(const llvm::Record *record, - raw_ostream &os) { +static void emitOneEnumFromConversion(const Record *record, raw_ostream &os) { LLVMEnumAttr enumAttr(record); StringRef llvmClass = enumAttr.getLLVMClassName(); StringRef cppClassName = enumAttr.getEnumClassName(); @@ -508,8 +505,7 @@ static void emitOneEnumFromConversion(const llvm::Record *record, // Emits conversion function "Enum convertEnumFromLLVM(LLVMEnum)" and // containing switch-based logic to convert from the LLVM API C-style enumerant // to MLIR LLVM dialect enum attribute (Enum). -static void emitOneCEnumFromConversion(const llvm::Record *record, - raw_ostream &os) { +static void emitOneCEnumFromConversion(const Record *record, raw_ostream &os) { LLVMCEnumAttr enumAttr(record); StringRef llvmClass = enumAttr.getLLVMClassName(); StringRef cppClassName = enumAttr.getEnumClassName(); diff --git a/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp index 1e3cd8b86d5679b..411a98a48bfb28b 100644 --- a/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp +++ b/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp @@ -24,6 +24,11 @@ #include "llvm/TableGen/Record.h" #include "llvm/TableGen/TableGenBackend.h" +using llvm::Record; +using llvm::RecordKeeper; +using llvm::Regex; +using namespace mlir; + static llvm::cl::OptionCategory intrinsicGenCat("Intrinsics Generator Options"); static llvm::cl::opt @@ -54,14 +59,14 @@ static llvm::cl::opt aliasAnalysisRegexp( using IndicesTy = llvm::SmallBitVector; /// Return a CodeGen value type entry from a type record. -static llvm::MVT::SimpleValueType getValueType(const llvm::Record *rec) { +static llvm::MVT::SimpleValueType getValueType(const Record *rec) { return (llvm::MVT::SimpleValueType)rec->getValueAsDef("VT")->getValueAsInt( "Value"); } /// Return the indices of the definitions in a list of definitions that /// represent overloadable types -static IndicesTy getOverloadableTypeIdxs(const llvm::Record &record, +static IndicesTy getOverloadableTypeIdxs(const Record &record, const char *listName) { auto results = record.getValueAsListOfDefs(listName); IndicesTy overloadedOps(results.size()); @@ -87,13 +92,13 @@ namespace { /// the fields of the record. class LLVMIntrinsic { public: - LLVMIntrinsic(const llvm::Record &record) : record(record) {} + LLVMIntrinsic(const Record &record) : record(record) {} /// Get the name of the operation to be used in MLIR. Uses the appropriate /// field if not empty, constructs a name by replacing underscores with dots /// in the record name otherwise. std::string getOperationName() const { - llvm::StringRef name = record.getValueAsString(fieldName); + StringRef name = record.getValueAsString(fieldName); if (!name.empty()) return name.str(); @@ -101,8 +106,8 @@ class LLVMIntrinsic { assert(name.starts_with("int_") && "LLVM intrinsic names are expected to start with 'int_'"); name = name.drop_front(4); - llvm::SmallVector chunks; - llvm::StringRef targetPrefix = record.getValueAsString("TargetPrefix"); + SmallVector chunks; + StringRef targetPrefix = record.getValueAsString("TargetPrefix"); name.split(chunks, '_'); auto *chunksBegin = chunks.begin(); // Remove the target prefix from target specific intrinsics. @@ -119,8 +124,8 @@ class LLVMIntrinsic { } /// Get the name of the record without the "intrinsic" prefix. - llvm::StringRef getProperRecordName() const { - llvm::StringRef name = record.getName(); + StringRef getProperRecordName() const { + StringRef name = record.getName(); assert(name.starts_with("int_") && "LLVM intrinsic names are expected to start with 'int_'"); return name.drop_front(4); @@ -129,10 +134,9 @@ class LLVMIntrinsic { /// Get the number of operands. unsigned getNumOperands() const { auto operands = record.getValueAsListOfDefs(fieldOperands); - assert(llvm::all_of(operands, - [](const llvm::Record *r) { - return r->isSubClassOf("LLVMType"); - }) && + assert(llvm::all_of( + operands, + [](const Record *r) { return r->isSubClassOf("LLVMType"); }) && "expected operands to be of LLVM type"); return operands.size(); } @@ -142,7 +146,7 @@ class LLVMIntrinsic { /// structure type. unsigned getNumResults() const { auto results = record.getValueAsListOfDefs(fieldResults); - for (const llvm::Record *r : results) { + for (const Record *r : results) { (void)r; assert(r->isSubClassOf("LLVMType") && "expected operands to be of LLVM type"); @@ -155,7 +159,7 @@ class LLVMIntrinsic { bool hasSideEffects() const { return llvm::none_of( record.getValueAsListOfDefs(fieldTraits), - [](const llvm::Record *r) { return r->getName() == "IntrNoMem"; }); + [](const Record *r) { return r->getName() == "IntrNoMem"; }); } /// Return true if the intrinsic is commutative, i.e. has the respective @@ -163,7 +167,7 @@ class LLVMIntrinsic { bool isCommutative() const { return llvm::any_of( record.getValueAsListOfDefs(fieldTraits), - [](const llvm::Record *r) { return r->getName() == "Commutative"; }); + [](const Record *r) { return r->getName() == "Commutative"; }); } IndicesTy getOverloadableOperandsIdxs() const { @@ -181,7 +185,7 @@ class LLVMIntrinsic { const char *fieldResults = "RetTypes"; const char *fieldTraits = "IntrProperties"; - const llvm::Record &record; + const Record &record; }; } // namespace @@ -195,27 +199,26 @@ void printBracketedRange(const Range &range, llvm::raw_ostream &os) { /// Emits ODS (TableGen-based) code for `record` representing an LLVM intrinsic. /// Returns true on error, false on success. -static bool emitIntrinsic(const llvm::Record &record, llvm::raw_ostream &os) { +static bool emitIntrinsic(const Record &record, llvm::raw_ostream &os) { LLVMIntrinsic intr(record); - llvm::Regex accessGroupMatcher(accessGroupRegexp); + Regex accessGroupMatcher(accessGroupRegexp); bool requiresAccessGroup = !accessGroupRegexp.empty() && accessGroupMatcher.match(record.getName()); - llvm::Regex aliasAnalysisMatcher(aliasAnalysisRegexp); + Regex aliasAnalysisMatcher(aliasAnalysisRegexp); bool requiresAliasAnalysis = !aliasAnalysisRegexp.empty() && aliasAnalysisMatcher.match(record.getName()); // Prepare strings for traits, if any. - llvm::SmallVector traits; + SmallVector traits; if (intr.isCommutative()) traits.push_back("Commutative"); if (!intr.hasSideEffects()) traits.push_back("NoMemoryEffect"); // Prepare strings for operands. - llvm::SmallVector operands(intr.getNumOperands(), - "LLVM_Type"); + SmallVector operands(intr.getNumOperands(), "LLVM_Type"); if (requiresAccessGroup) operands.push_back( "OptionalAttr:$access_groups"); @@ -247,14 +250,13 @@ static bool emitIntrinsic(const llvm::Record &record, llvm::raw_ostream &os) { /// Traverses the list of TableGen definitions derived from the "Intrinsic" /// class and generates MLIR ODS definitions for those intrinsics that have /// the name matching the filter. -static bool emitIntrinsics(const llvm::RecordKeeper &records, - llvm::raw_ostream &os) { +static bool emitIntrinsics(const RecordKeeper &records, llvm::raw_ostream &os) { llvm::emitSourceFileHeader("Operations for LLVM intrinsics", os, records); os << "include \"mlir/Dialect/LLVMIR/LLVMOpBase.td\"\n"; os << "include \"mlir/Interfaces/SideEffectInterfaces.td\"\n\n"; auto defs = records.getAllDerivedDefinitions("Intrinsic"); - for (const llvm::Record *r : defs) { + for (const Record *r : defs) { if (!nameFilter.empty() && !r->getName().contains(nameFilter)) continue; if (emitIntrinsic(*r, os)) diff --git a/mlir/tools/mlir-tblgen/OpDocGen.cpp b/mlir/tools/mlir-tblgen/OpDocGen.cpp index ed9d90a25625fc4..5171e3fad9e84b8 100644 --- a/mlir/tools/mlir-tblgen/OpDocGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDocGen.cpp @@ -34,30 +34,30 @@ #include #include +using namespace llvm; +using namespace mlir; +using namespace mlir::tblgen; +using mlir::tblgen::Operator; + //===----------------------------------------------------------------------===// // Commandline Options //===----------------------------------------------------------------------===// -static llvm::cl::OptionCategory +static cl::OptionCategory docCat("Options for -gen-(attrdef|typedef|enum|op|dialect)-doc"); -llvm::cl::opt +cl::opt stripPrefix("strip-prefix", - llvm::cl::desc("Strip prefix of the fully qualified names"), - llvm::cl::init("::mlir::"), llvm::cl::cat(docCat)); -llvm::cl::opt allowHugoSpecificFeatures( + cl::desc("Strip prefix of the fully qualified names"), + cl::init("::mlir::"), cl::cat(docCat)); +cl::opt allowHugoSpecificFeatures( "allow-hugo-specific-features", - llvm::cl::desc("Allows using features specific to Hugo"), - llvm::cl::init(false), llvm::cl::cat(docCat)); - -using namespace llvm; -using namespace mlir; -using namespace mlir::tblgen; -using mlir::tblgen::Operator; + cl::desc("Allows using features specific to Hugo"), cl::init(false), + cl::cat(docCat)); void mlir::tblgen::emitSummary(StringRef summary, raw_ostream &os) { if (!summary.empty()) { - llvm::StringRef trimmed = summary.trim(); + StringRef trimmed = summary.trim(); char first = std::toupper(trimmed.front()); - llvm::StringRef rest = trimmed.drop_front(); + StringRef rest = trimmed.drop_front(); os << "\n_" << first << rest << "_\n\n"; } } @@ -152,10 +152,10 @@ static void emitOpTraitsDoc(const Operator &op, raw_ostream &os) { effectName.consume_front("::"); effectName.consume_front("mlir::"); std::string effectStr; - llvm::raw_string_ostream os(effectStr); + raw_string_ostream os(effectStr); os << effectName << "{"; auto list = trait.getDef().getValueAsListOfDefs("effects"); - llvm::interleaveComma(list, os, [&](const Record *rec) { + interleaveComma(list, os, [&](const Record *rec) { StringRef effect = rec->getValueAsString("effect"); effect.consume_front("::"); effect.consume_front("mlir::"); @@ -163,7 +163,7 @@ static void emitOpTraitsDoc(const Operator &op, raw_ostream &os) { }); os << "}"; effects.insert(backticks(effectStr)); - name.append(llvm::formatv(" ({0})", traitName).str()); + name.append(formatv(" ({0})", traitName).str()); } interfaces.insert(backticks(name)); continue; @@ -172,15 +172,15 @@ static void emitOpTraitsDoc(const Operator &op, raw_ostream &os) { traits.insert(backticks(name)); } if (!traits.empty()) { - llvm::interleaveComma(traits, os << "\nTraits: "); + interleaveComma(traits, os << "\nTraits: "); os << "\n"; } if (!interfaces.empty()) { - llvm::interleaveComma(interfaces, os << "\nInterfaces: "); + interleaveComma(interfaces, os << "\nInterfaces: "); os << "\n"; } if (!effects.empty()) { - llvm::interleaveComma(effects, os << "\nEffects: "); + interleaveComma(effects, os << "\nEffects: "); os << "\n"; } } @@ -196,7 +196,7 @@ static void emitOpDoc(const Operator &op, raw_ostream &os) { std::string classNameStr = op.getQualCppClassName(); StringRef className = classNameStr; (void)className.consume_front(stripPrefix); - os << llvm::formatv("### `{0}` ({1})\n", op.getOperationName(), className); + os << formatv("### `{0}` ({1})\n", op.getOperationName(), className); // Emit the summary, syntax, and description if present. if (op.hasSummary()) @@ -287,7 +287,7 @@ static void emitOpDoc(const RecordKeeper &recordKeeper, raw_ostream &os) { os << "\n"; emitSourceLink(recordKeeper.getInputFilename(), os); - for (const llvm::Record *opDef : opDefs) + for (const Record *opDef : opDefs) emitOpDoc(Operator(opDef), os); } @@ -339,7 +339,7 @@ static void emitAttrOrTypeDefAssemblyFormat(const AttrOrTypeDef &def, } static void emitAttrOrTypeDefDoc(const AttrOrTypeDef &def, raw_ostream &os) { - os << llvm::formatv("### {0}\n", def.getCppClassName()); + os << formatv("### {0}\n", def.getCppClassName()); // Emit the summary if present. if (def.hasSummary()) @@ -376,7 +376,7 @@ static void emitAttrOrTypeDefDoc(const RecordKeeper &recordKeeper, auto defs = recordKeeper.getAllDerivedDefinitions(recordTypeName); os << "\n"; - for (const llvm::Record *def : defs) + for (const Record *def : defs) emitAttrOrTypeDefDoc(AttrOrTypeDef(def), os); } @@ -385,7 +385,7 @@ static void emitAttrOrTypeDefDoc(const RecordKeeper &recordKeeper, //===----------------------------------------------------------------------===// static void emitEnumDoc(const EnumAttr &def, raw_ostream &os) { - os << llvm::formatv("### {0}\n", def.getEnumClassName()); + os << formatv("### {0}\n", def.getEnumClassName()); // Emit the summary if present. if (!def.getSummary().empty()) @@ -406,8 +406,7 @@ static void emitEnumDoc(const EnumAttr &def, raw_ostream &os) { static void emitEnumDoc(const RecordKeeper &recordKeeper, raw_ostream &os) { os << "\n"; - for (const llvm::Record *def : - recordKeeper.getAllDerivedDefinitions("EnumAttr")) + for (const Record *def : recordKeeper.getAllDerivedDefinitions("EnumAttr")) emitEnumDoc(EnumAttr(def), os); } @@ -431,7 +430,7 @@ struct OpDocGroup { static void maybeNest(bool nest, llvm::function_ref fn, raw_ostream &os) { std::string str; - llvm::raw_string_ostream ss(str); + raw_string_ostream ss(str); fn(ss); for (StringRef x : llvm::split(str, "\n")) { if (nest && x.starts_with("#")) @@ -507,7 +506,7 @@ static void emitDialectDoc(const Dialect &dialect, StringRef inputFilename, emitIfNotEmpty(dialect.getDescription(), os); // Generate a TOC marker except if description already contains one. - llvm::Regex r("^[[:space:]]*\\[TOC\\]$", llvm::Regex::RegexFlags::Newline); + Regex r("^[[:space:]]*\\[TOC\\]$", Regex::RegexFlags::Newline); if (!r.match(dialect.getDescription())) os << "[TOC]\n\n"; @@ -537,17 +536,15 @@ static bool emitDialectDoc(const RecordKeeper &recordKeeper, raw_ostream &os) { std::vector dialectTypeDefs; std::vector dialectEnums; - llvm::SmallDenseSet seen; - auto addIfNotSeen = [&](const llvm::Record *record, const auto &def, - auto &vec) { + SmallDenseSet seen; + auto addIfNotSeen = [&](const Record *record, const auto &def, auto &vec) { if (seen.insert(record).second) { vec.push_back(def); return true; } return false; }; - auto addIfInDialect = [&](const llvm::Record *record, const auto &def, - auto &vec) { + auto addIfInDialect = [&](const Record *record, const auto &def, auto &vec) { return def.getDialect() == *dialect && addIfNotSeen(record, def, vec); }; diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp index 7016fe41ca75d03..c99c71572bec232 100644 --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -28,6 +28,9 @@ using namespace mlir; using namespace mlir::tblgen; +using llvm::formatv; +using llvm::Record; +using llvm::StringMap; //===----------------------------------------------------------------------===// // VariableElement @@ -404,7 +407,7 @@ struct OperationFormat { StringRef opCppClassName; /// A map of buildable types to indices. - llvm::MapVector> buildableTypes; + llvm::MapVector> buildableTypes; /// The index of the buildable type, if valid, for every operand and result. std::vector operandTypes, resultTypes; @@ -891,8 +894,7 @@ static void genElementParserStorage(FormatElement *element, const Operator &op, } else if (auto *attr = dyn_cast(element)) { const NamedAttribute *var = attr->getVar(); - body << llvm::formatv(" {0} {1}Attr;\n", var->attr.getStorageType(), - var->name); + body << formatv(" {0} {1}Attr;\n", var->attr.getStorageType(), var->name); } else if (auto *operand = dyn_cast(element)) { StringRef name = operand->getVar()->name; @@ -910,31 +912,31 @@ static void genElementParserStorage(FormatElement *element, const Operator &op, << " ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> " << name << "Operands(&" << name << "RawOperand, 1);"; } - body << llvm::formatv(" ::llvm::SMLoc {0}OperandsLoc;\n" - " (void){0}OperandsLoc;\n", - name); + body << formatv(" ::llvm::SMLoc {0}OperandsLoc;\n" + " (void){0}OperandsLoc;\n", + name); } else if (auto *region = dyn_cast(element)) { StringRef name = region->getVar()->name; if (region->getVar()->isVariadic()) { - body << llvm::formatv( + body << formatv( " ::llvm::SmallVector, 2> " "{0}Regions;\n", name); } else { - body << llvm::formatv(" std::unique_ptr<::mlir::Region> {0}Region = " - "std::make_unique<::mlir::Region>();\n", - name); + body << formatv(" std::unique_ptr<::mlir::Region> {0}Region = " + "std::make_unique<::mlir::Region>();\n", + name); } } else if (auto *successor = dyn_cast(element)) { StringRef name = successor->getVar()->name; if (successor->getVar()->isVariadic()) { - body << llvm::formatv(" ::llvm::SmallVector<::mlir::Block *, 2> " - "{0}Successors;\n", - name); + body << formatv(" ::llvm::SmallVector<::mlir::Block *, 2> " + "{0}Successors;\n", + name); } else { - body << llvm::formatv(" ::mlir::Block *{0}Successor = nullptr;\n", name); + body << formatv(" ::mlir::Block *{0}Successor = nullptr;\n", name); } } else if (auto *dir = dyn_cast(element)) { @@ -944,8 +946,8 @@ static void genElementParserStorage(FormatElement *element, const Operator &op, body << " ::llvm::SmallVector<::mlir::Type, 1> " << name << "Types;\n"; else body - << llvm::formatv(" ::mlir::Type {0}RawType{{};\n", name) - << llvm::formatv( + << formatv(" ::mlir::Type {0}RawType{{};\n", name) + << formatv( " ::llvm::ArrayRef<::mlir::Type> {0}Types(&{0}RawType, 1);\n", name); } else if (auto *dir = dyn_cast(element)) { @@ -969,27 +971,27 @@ static void genCustomParameterParser(FormatElement *param, MethodBody &body) { StringRef name = operand->getVar()->name; ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar()); if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) - body << llvm::formatv("{0}OperandGroups", name); + body << formatv("{0}OperandGroups", name); else if (lengthKind == ArgumentLengthKind::Variadic) - body << llvm::formatv("{0}Operands", name); + body << formatv("{0}Operands", name); else if (lengthKind == ArgumentLengthKind::Optional) - body << llvm::formatv("{0}Operand", name); + body << formatv("{0}Operand", name); else body << formatv("{0}RawOperand", name); } else if (auto *region = dyn_cast(param)) { StringRef name = region->getVar()->name; if (region->getVar()->isVariadic()) - body << llvm::formatv("{0}Regions", name); + body << formatv("{0}Regions", name); else - body << llvm::formatv("*{0}Region", name); + body << formatv("*{0}Region", name); } else if (auto *successor = dyn_cast(param)) { StringRef name = successor->getVar()->name; if (successor->getVar()->isVariadic()) - body << llvm::formatv("{0}Successors", name); + body << formatv("{0}Successors", name); else - body << llvm::formatv("{0}Successor", name); + body << formatv("{0}Successor", name); } else if (auto *dir = dyn_cast(param)) { genCustomParameterParser(dir->getArg(), body); @@ -998,11 +1000,11 @@ static void genCustomParameterParser(FormatElement *param, MethodBody &body) { ArgumentLengthKind lengthKind; StringRef listName = getTypeListName(dir->getArg(), lengthKind); if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) - body << llvm::formatv("{0}TypeGroups", listName); + body << formatv("{0}TypeGroups", listName); else if (lengthKind == ArgumentLengthKind::Variadic) - body << llvm::formatv("{0}Types", listName); + body << formatv("{0}Types", listName); else if (lengthKind == ArgumentLengthKind::Optional) - body << llvm::formatv("{0}Type", listName); + body << formatv("{0}Type", listName); else body << formatv("{0}RawType", listName); @@ -1013,8 +1015,8 @@ static void genCustomParameterParser(FormatElement *param, MethodBody &body) { body << tgfmt(string->getValue(), &ctx); } else if (auto *property = dyn_cast(param)) { - body << llvm::formatv("result.getOrAddProperties().{0}", - property->getVar()->name); + body << formatv("result.getOrAddProperties().{0}", + property->getVar()->name); } else { llvm_unreachable("unknown custom directive parameter"); } @@ -1037,24 +1039,24 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body, body << " " << var->name << "OperandsLoc = parser.getCurrentLocation();\n"; if (var->isOptional()) { - body << llvm::formatv( + body << formatv( " ::std::optional<::mlir::OpAsmParser::UnresolvedOperand> " "{0}Operand;\n", var->name); } else if (var->isVariadicOfVariadic()) { - body << llvm::formatv(" " - "::llvm::SmallVector<::llvm::SmallVector<::mlir::" - "OpAsmParser::UnresolvedOperand>> " - "{0}OperandGroups;\n", - var->name); + body << formatv(" " + "::llvm::SmallVector<::llvm::SmallVector<::mlir::" + "OpAsmParser::UnresolvedOperand>> " + "{0}OperandGroups;\n", + var->name); } } else if (auto *dir = dyn_cast(param)) { ArgumentLengthKind lengthKind; StringRef listName = getTypeListName(dir->getArg(), lengthKind); if (lengthKind == ArgumentLengthKind::Optional) { - body << llvm::formatv(" ::mlir::Type {0}Type;\n", listName); + body << formatv(" ::mlir::Type {0}Type;\n", listName); } else if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) { - body << llvm::formatv( + body << formatv( " ::llvm::SmallVector> " "{0}TypeGroups;\n", listName); @@ -1064,7 +1066,7 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body, if (auto *operand = dyn_cast(input)) { if (!operand->getVar()->isOptional()) continue; - body << llvm::formatv( + body << formatv( " {0} {1}Operand = {1}Operands.empty() ? {0}() : " "{1}Operands[0];\n", "::std::optional<::mlir::OpAsmParser::UnresolvedOperand>", @@ -1074,9 +1076,9 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body, ArgumentLengthKind lengthKind; StringRef listName = getTypeListName(type->getArg(), lengthKind); if (lengthKind == ArgumentLengthKind::Optional) { - body << llvm::formatv(" ::mlir::Type {0}Type = {0}Types.empty() ? " - "::mlir::Type() : {0}Types[0];\n", - listName); + body << formatv(" ::mlir::Type {0}Type = {0}Types.empty() ? " + "::mlir::Type() : {0}Types[0];\n", + listName); } } } @@ -1101,23 +1103,23 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body, if (auto *attr = dyn_cast(param)) { const NamedAttribute *var = attr->getVar(); if (var->attr.isOptional() || var->attr.hasDefaultValue()) - body << llvm::formatv(" if ({0}Attr)\n ", var->name); + body << formatv(" if ({0}Attr)\n ", var->name); if (useProperties) { body << formatv( " result.getOrAddProperties<{1}::Properties>().{0} = {0}Attr;\n", var->name, opCppClassName); } else { - body << llvm::formatv(" result.addAttribute(\"{0}\", {0}Attr);\n", - var->name); + body << formatv(" result.addAttribute(\"{0}\", {0}Attr);\n", + var->name); } } else if (auto *operand = dyn_cast(param)) { const NamedTypeConstraint *var = operand->getVar(); if (var->isOptional()) { - body << llvm::formatv(" if ({0}Operand.has_value())\n" - " {0}Operands.push_back(*{0}Operand);\n", - var->name); + body << formatv(" if ({0}Operand.has_value())\n" + " {0}Operands.push_back(*{0}Operand);\n", + var->name); } else if (var->isVariadicOfVariadic()) { - body << llvm::formatv( + body << formatv( " for (const auto &subRange : {0}OperandGroups) {{\n" " {0}Operands.append(subRange.begin(), subRange.end());\n" " {0}OperandGroupSizes.push_back(subRange.size());\n" @@ -1128,11 +1130,11 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body, ArgumentLengthKind lengthKind; StringRef listName = getTypeListName(dir->getArg(), lengthKind); if (lengthKind == ArgumentLengthKind::Optional) { - body << llvm::formatv(" if ({0}Type)\n" - " {0}Types.push_back({0}Type);\n", - listName); + body << formatv(" if ({0}Type)\n" + " {0}Types.push_back({0}Type);\n", + listName); } else if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) { - body << llvm::formatv( + body << formatv( " for (const auto &subRange : {0}TypeGroups)\n" " {0}Types.append(subRange.begin(), subRange.end());\n", listName); @@ -1460,9 +1462,9 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body, body << " if (" << attrVar->getVar()->name << "Attr) {\n"; } else if (auto *propVar = dyn_cast(firstElement)) { genPropertyParser(propVar, body, opCppClassName, /*requireParse=*/false); - body << llvm::formatv("if ({0}PropParseResult.has_value() && " - "succeeded(*{0}PropParseResult)) ", - propVar->getVar()->name) + body << formatv("if ({0}PropParseResult.has_value() && " + "succeeded(*{0}PropParseResult)) ", + propVar->getVar()->name) << " {\n"; } else if (auto *literal = dyn_cast(firstElement)) { body << " if (::mlir::succeeded(parser.parseOptional"; @@ -1477,13 +1479,12 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body, genElementParser(regionVar, body, attrTypeCtx); body << " if (!" << region->name << "Regions.empty()) {\n"; } else { - body << llvm::formatv(optionalRegionParserCode, region->name); + body << formatv(optionalRegionParserCode, region->name); body << " if (!" << region->name << "Region->empty()) {\n "; if (hasImplicitTermTrait) - body << llvm::formatv(regionEnsureTerminatorParserCode, region->name); + body << formatv(regionEnsureTerminatorParserCode, region->name); else if (hasSingleBlockTrait) - body << llvm::formatv(regionEnsureSingleBlockParserCode, - region->name); + body << formatv(regionEnsureSingleBlockParserCode, region->name); } } else if (auto *custom = dyn_cast(firstElement)) { body << " if (auto optResult = [&]() -> ::mlir::OptionalParseResult {\n"; @@ -1575,26 +1576,26 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body, ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar()); StringRef name = operand->getVar()->name; if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) - body << llvm::formatv(variadicOfVariadicOperandParserCode, name); + body << formatv(variadicOfVariadicOperandParserCode, name); else if (lengthKind == ArgumentLengthKind::Variadic) - body << llvm::formatv(variadicOperandParserCode, name); + body << formatv(variadicOperandParserCode, name); else if (lengthKind == ArgumentLengthKind::Optional) - body << llvm::formatv(optionalOperandParserCode, name); + body << formatv(optionalOperandParserCode, name); else body << formatv(operandParserCode, name); } else if (auto *region = dyn_cast(element)) { bool isVariadic = region->getVar()->isVariadic(); - body << llvm::formatv(isVariadic ? regionListParserCode : regionParserCode, - region->getVar()->name); + body << formatv(isVariadic ? regionListParserCode : regionParserCode, + region->getVar()->name); if (hasImplicitTermTrait) - body << llvm::formatv(isVariadic ? regionListEnsureTerminatorParserCode - : regionEnsureTerminatorParserCode, - region->getVar()->name); + body << formatv(isVariadic ? regionListEnsureTerminatorParserCode + : regionEnsureTerminatorParserCode, + region->getVar()->name); else if (hasSingleBlockTrait) - body << llvm::formatv(isVariadic ? regionListEnsureSingleBlockParserCode - : regionEnsureSingleBlockParserCode, - region->getVar()->name); + body << formatv(isVariadic ? regionListEnsureSingleBlockParserCode + : regionEnsureSingleBlockParserCode, + region->getVar()->name); } else if (auto *successor = dyn_cast(element)) { bool isVariadic = successor->getVar()->isVariadic(); @@ -1631,24 +1632,24 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body, << " return ::mlir::failure();\n"; } else if (isa(element)) { - body << llvm::formatv(regionListParserCode, "full"); + body << formatv(regionListParserCode, "full"); if (hasImplicitTermTrait) - body << llvm::formatv(regionListEnsureTerminatorParserCode, "full"); + body << formatv(regionListEnsureTerminatorParserCode, "full"); else if (hasSingleBlockTrait) - body << llvm::formatv(regionListEnsureSingleBlockParserCode, "full"); + body << formatv(regionListEnsureSingleBlockParserCode, "full"); } else if (isa(element)) { - body << llvm::formatv(successorListParserCode, "full"); + body << formatv(successorListParserCode, "full"); } else if (auto *dir = dyn_cast(element)) { ArgumentLengthKind lengthKind; StringRef listName = getTypeListName(dir->getArg(), lengthKind); if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) { - body << llvm::formatv(variadicOfVariadicTypeParserCode, listName); + body << formatv(variadicOfVariadicTypeParserCode, listName); } else if (lengthKind == ArgumentLengthKind::Variadic) { - body << llvm::formatv(variadicTypeParserCode, listName); + body << formatv(variadicTypeParserCode, listName); } else if (lengthKind == ArgumentLengthKind::Optional) { - body << llvm::formatv(optionalTypeParserCode, listName); + body << formatv(optionalTypeParserCode, listName); } else { const char *parserCode = dir->shouldBeQualified() ? qualifiedTypeParserCode : typeParserCode; @@ -1903,14 +1904,14 @@ void OperationFormat::genParserVariadicSegmentResolution(Operator &op, if (!operand.isVariadicOfVariadic()) continue; if (op.getDialect().usePropertiesForAttributes()) { - body << llvm::formatv( + body << formatv( " result.getOrAddProperties<{0}::Properties>().{1} = " "parser.getBuilder().getDenseI32ArrayAttr({2}OperandGroupSizes);\n", op.getCppClassName(), operand.constraint.getVariadicOfVariadicSegmentSizeAttr(), operand.name); } else { - body << llvm::formatv( + body << formatv( " result.addAttribute(\"{0}\", " "parser.getBuilder().getDenseI32ArrayAttr({1}OperandGroupSizes));" "\n", @@ -2160,7 +2161,7 @@ static void genCustomDirectiveParameterPrinter(FormatElement *element, if (var->isVariadic()) body << name << "().getTypes()"; else if (var->isOptional()) - body << llvm::formatv("({0}() ? {0}().getType() : ::mlir::Type())", name); + body << formatv("({0}() ? {0}().getType() : ::mlir::Type())", name); else body << name << "().getType()"; @@ -2195,8 +2196,7 @@ static void genCustomDirectivePrinter(CustomDirective *customDir, static void genRegionPrinter(const Twine ®ionName, MethodBody &body, bool hasImplicitTermTrait) { if (hasImplicitTermTrait) - body << llvm::formatv(regionSingleBlockImplicitTerminatorPrinterCode, - regionName); + body << formatv(regionSingleBlockImplicitTerminatorPrinterCode, regionName); else body << " _odsPrinter.printRegion(" << regionName << ");\n"; } @@ -2220,12 +2220,12 @@ static MethodBody &genTypeOperandPrinter(FormatElement *arg, const Operator &op, auto *operand = dyn_cast(arg); auto *var = operand ? operand->getVar() : cast(arg)->getVar(); if (var->isVariadicOfVariadic()) - return body << llvm::formatv("{0}().join().getTypes()", - op.getGetterName(var->name)); + return body << formatv("{0}().join().getTypes()", + op.getGetterName(var->name)); if (var->isVariadic()) return body << op.getGetterName(var->name) << "().getTypes()"; if (var->isOptional()) - return body << llvm::formatv( + return body << formatv( "({0}() ? ::llvm::ArrayRef<::mlir::Type>({0}().getType()) : " "::llvm::ArrayRef<::mlir::Type>())", op.getGetterName(var->name)); @@ -2242,10 +2242,10 @@ static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op, const EnumAttr &enumAttr = cast(baseAttr); std::vector cases = enumAttr.getAllCases(); - body << llvm::formatv(enumAttrBeginPrinterCode, - (var->attr.isOptional() ? "*" : "") + - op.getGetterName(var->name), - enumAttr.getSymbolToStringFnName()); + body << formatv(enumAttrBeginPrinterCode, + (var->attr.isOptional() ? "*" : "") + + op.getGetterName(var->name), + enumAttr.getSymbolToStringFnName()); // Get a string containing all of the cases that can't be represented with a // keyword. @@ -2276,9 +2276,8 @@ static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op, if (nonKeywordCases.test(it.index())) continue; StringRef symbol = it.value().getSymbol(); - body << llvm::formatv(" case {0}::{1}::{2}:\n", cppNamespace, enumName, - llvm::isDigit(symbol.front()) ? ("_" + symbol) - : symbol); + body << formatv(" case {0}::{1}::{2}:\n", cppNamespace, enumName, + llvm::isDigit(symbol.front()) ? ("_" + symbol) : symbol); } body << " _odsPrinter << caseValueStr;\n" " break;\n" @@ -2584,7 +2583,7 @@ void OperationFormat::genElementPrinter(FormatElement *element, } else if (auto *dir = dyn_cast(element)) { if (auto *operand = dyn_cast(dir->getArg())) { if (operand->getVar()->isVariadicOfVariadic()) { - body << llvm::formatv( + body << formatv( " ::llvm::interleaveComma({0}().getTypes(), _odsPrinter, " "[&](::mlir::TypeRange types) {{ _odsPrinter << \"(\" << " "types << \")\"; });\n", @@ -2710,7 +2709,7 @@ class OpFormatParser : public FormatParser { /// Verify the state of operation operands within the format. LogicalResult verifyOperands(SMLoc loc, - llvm::StringMap &variableTyResolver); + StringMap &variableTyResolver); /// Verify the state of operation regions within the format. LogicalResult verifyRegions(SMLoc loc); @@ -2718,7 +2717,7 @@ class OpFormatParser : public FormatParser { /// Verify the state of operation results within the format. LogicalResult verifyResults(SMLoc loc, - llvm::StringMap &variableTyResolver); + StringMap &variableTyResolver); /// Verify the state of operation successors within the format. LogicalResult verifySuccessors(SMLoc loc); @@ -2730,18 +2729,17 @@ class OpFormatParser : public FormatParser { /// resolution. void handleAllTypesMatchConstraint( ArrayRef values, - llvm::StringMap &variableTyResolver); + StringMap &variableTyResolver); /// Check for inferable type resolution given all operands, and or results, /// have the same type. If 'includeResults' is true, the results also have the /// same type as all of the operands. void handleSameTypesConstraint( - llvm::StringMap &variableTyResolver, + StringMap &variableTyResolver, bool includeResults); /// Check for inferable type resolution based on another operand, result, or /// attribute. void handleTypesMatchConstraint( - llvm::StringMap &variableTyResolver, - const llvm::Record &def); + StringMap &variableTyResolver, const Record &def); /// Returns an argument or attribute with the given name that has been seen /// within the format. @@ -2794,9 +2792,9 @@ LogicalResult OpFormatParser::verify(SMLoc loc, "custom assembly format"); // Check for any type traits that we can use for inferring types. - llvm::StringMap variableTyResolver; + StringMap variableTyResolver; for (const Trait &trait : op.getTraits()) { - const llvm::Record &def = trait.getDef(); + const Record &def = trait.getDef(); if (def.isSubClassOf("AllTypesMatch")) { handleAllTypesMatchConstraint(def.getValueAsListOfStrings("values"), variableTyResolver); @@ -2995,10 +2993,9 @@ OpFormatParser::verifyAttributeColonType(SMLoc loc, return false; // If we encounter `:`, the range is known to be invalid. (void)emitError( - loc, - llvm::formatv("format ambiguity caused by `:` literal found after " - "attribute `{0}` which does not have a buildable type", - cast(base)->getVar()->name)); + loc, formatv("format ambiguity caused by `:` literal found after " + "attribute `{0}` which does not have a buildable type", + cast(base)->getVar()->name)); return true; }; return verifyAdjacentElements(isBase, isInvalid, elements); @@ -3018,9 +3015,9 @@ OpFormatParser::verifyAttrDictRegion(SMLoc loc, return false; (void)emitErrorAndNote( loc, - llvm::formatv("format ambiguity caused by `attr-dict` directive " - "followed by region `{0}`", - region->getVar()->name), + formatv("format ambiguity caused by `attr-dict` directive " + "followed by region `{0}`", + region->getVar()->name), "try using `attr-dict-with-keyword` instead"); return true; }; @@ -3028,7 +3025,7 @@ OpFormatParser::verifyAttrDictRegion(SMLoc loc, } LogicalResult OpFormatParser::verifyOperands( - SMLoc loc, llvm::StringMap &variableTyResolver) { + SMLoc loc, StringMap &variableTyResolver) { // Check that all of the operands are within the format, and their types can // be inferred. auto &buildableTypes = fmt.buildableTypes; @@ -3093,7 +3090,7 @@ LogicalResult OpFormatParser::verifyRegions(SMLoc loc) { } LogicalResult OpFormatParser::verifyResults( - SMLoc loc, llvm::StringMap &variableTyResolver) { + SMLoc loc, StringMap &variableTyResolver) { // If we format all of the types together, there is nothing to check. if (fmt.allResultTypes) return success(); @@ -3197,7 +3194,7 @@ OpFormatParser::verifyOIListElements(SMLoc loc, void OpFormatParser::handleAllTypesMatchConstraint( ArrayRef values, - llvm::StringMap &variableTyResolver) { + StringMap &variableTyResolver) { for (unsigned i = 0, e = values.size(); i != e; ++i) { // Check to see if this value matches a resolved operand or result type. ConstArgument arg = findSeenArg(values[i]); @@ -3213,7 +3210,7 @@ void OpFormatParser::handleAllTypesMatchConstraint( } void OpFormatParser::handleSameTypesConstraint( - llvm::StringMap &variableTyResolver, + StringMap &variableTyResolver, bool includeResults) { const NamedTypeConstraint *resolver = nullptr; int resolvedIt = -1; @@ -3238,8 +3235,7 @@ void OpFormatParser::handleSameTypesConstraint( } void OpFormatParser::handleTypesMatchConstraint( - llvm::StringMap &variableTyResolver, - const llvm::Record &def) { + StringMap &variableTyResolver, const Record &def) { StringRef lhsName = def.getValueAsString("lhs"); StringRef rhsName = def.getValueAsString("rhs"); StringRef transformer = def.getValueAsString("transformer"); diff --git a/mlir/tools/mlir-tblgen/OpGenHelpers.cpp b/mlir/tools/mlir-tblgen/OpGenHelpers.cpp index 702ea6643245548..18e42285bdde91b 100644 --- a/mlir/tools/mlir-tblgen/OpGenHelpers.cpp +++ b/mlir/tools/mlir-tblgen/OpGenHelpers.cpp @@ -41,7 +41,7 @@ static std::string getOperationName(const Record &def) { auto opName = def.getValueAsString("opName"); if (prefix.empty()) return std::string(opName); - return std::string(llvm::formatv("{0}.{1}", prefix, opName)); + return std::string(formatv("{0}.{1}", prefix, opName)); } std::vector @@ -50,7 +50,7 @@ mlir::tblgen::getRequestedOpDefinitions(const RecordKeeper &recordKeeper) { if (!classDef) PrintFatalError("ERROR: Couldn't find the 'Op' class!\n"); - llvm::Regex includeRegex(opIncFilter), excludeRegex(opExcFilter); + Regex includeRegex(opIncFilter), excludeRegex(opExcFilter); std::vector defs; for (const auto &def : recordKeeper.getDefs()) { if (!def.second->isSubClassOf(classDef)) @@ -70,7 +70,7 @@ mlir::tblgen::getRequestedOpDefinitions(const RecordKeeper &recordKeeper) { } bool mlir::tblgen::isPythonReserved(StringRef str) { - static llvm::StringSet<> reserved({ + static StringSet<> reserved({ "False", "None", "True", "and", "as", "assert", "async", "await", "break", "class", "continue", "def", "del", "elif", "else", "except", "finally", "for", "from", "global", "if", @@ -86,8 +86,8 @@ bool mlir::tblgen::isPythonReserved(StringRef str) { } void mlir::tblgen::shardOpDefinitions( - ArrayRef defs, - SmallVectorImpl> &shardedDefs) { + ArrayRef defs, + SmallVectorImpl> &shardedDefs) { assert(opShardCount > 0 && "expected a positive shard count"); if (opShardCount == 1) { shardedDefs.push_back(defs); diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp index 7c32c2549d788f7..8fa286f14dc5ed8 100644 --- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp +++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp @@ -23,6 +23,8 @@ #include "llvm/TableGen/TableGenBackend.h" using namespace mlir; +using llvm::Record; +using llvm::RecordKeeper; using mlir::tblgen::Interface; using mlir::tblgen::InterfaceMethod; using mlir::tblgen::OpInterface; @@ -61,14 +63,13 @@ static void emitMethodNameAndArgs(const InterfaceMethod &method, /// Get an array of all OpInterface definitions but exclude those subclassing /// "DeclareOpInterfaceMethods". -static std::vector -getAllInterfaceDefinitions(const llvm::RecordKeeper &recordKeeper, - StringRef name) { - std::vector defs = +static std::vector +getAllInterfaceDefinitions(const RecordKeeper &recordKeeper, StringRef name) { + std::vector defs = recordKeeper.getAllDerivedDefinitions((name + "Interface").str()); std::string declareName = ("Declare" + name + "InterfaceMethods").str(); - llvm::erase_if(defs, [&](const llvm::Record *def) { + llvm::erase_if(defs, [&](const Record *def) { // Ignore any "declare methods" interfaces. if (def->isSubClassOf(declareName)) return true; @@ -88,7 +89,7 @@ class InterfaceGenerator { bool emitInterfaceDocs(); protected: - InterfaceGenerator(std::vector &&defs, raw_ostream &os) + InterfaceGenerator(std::vector &&defs, raw_ostream &os) : defs(std::move(defs)), os(os) {} void emitConceptDecl(const Interface &interface); @@ -99,7 +100,7 @@ class InterfaceGenerator { void emitInterfaceDecl(const Interface &interface); /// The set of interface records to emit. - std::vector defs; + std::vector defs; // The stream to emit to. raw_ostream &os; /// The C++ value type of the interface, e.g. Operation*. @@ -118,7 +119,7 @@ class InterfaceGenerator { /// A specialized generator for attribute interfaces. struct AttrInterfaceGenerator : public InterfaceGenerator { - AttrInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os) + AttrInterfaceGenerator(const RecordKeeper &records, raw_ostream &os) : InterfaceGenerator(getAllInterfaceDefinitions(records, "Attr"), os) { valueType = "::mlir::Attribute"; interfaceBaseType = "AttributeInterface"; @@ -133,7 +134,7 @@ struct AttrInterfaceGenerator : public InterfaceGenerator { }; /// A specialized generator for operation interfaces. struct OpInterfaceGenerator : public InterfaceGenerator { - OpInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os) + OpInterfaceGenerator(const RecordKeeper &records, raw_ostream &os) : InterfaceGenerator(getAllInterfaceDefinitions(records, "Op"), os) { valueType = "::mlir::Operation *"; interfaceBaseType = "OpInterface"; @@ -149,7 +150,7 @@ struct OpInterfaceGenerator : public InterfaceGenerator { }; /// A specialized generator for type interfaces. struct TypeInterfaceGenerator : public InterfaceGenerator { - TypeInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os) + TypeInterfaceGenerator(const RecordKeeper &records, raw_ostream &os) : InterfaceGenerator(getAllInterfaceDefinitions(records, "Type"), os) { valueType = "::mlir::Type"; interfaceBaseType = "TypeInterface"; @@ -607,13 +608,13 @@ bool InterfaceGenerator::emitInterfaceDecls() { llvm::emitSourceFileHeader("Interface Declarations", os); // Sort according to ID, so defs are emitted in the order in which they appear // in the Tablegen file. - std::vector sortedDefs(defs); - llvm::sort(sortedDefs, [](const llvm::Record *lhs, const llvm::Record *rhs) { + std::vector sortedDefs(defs); + llvm::sort(sortedDefs, [](const Record *lhs, const Record *rhs) { return lhs->getID() < rhs->getID(); }); - for (const llvm::Record *def : sortedDefs) + for (const Record *def : sortedDefs) emitInterfaceDecl(Interface(def)); - for (const llvm::Record *def : sortedDefs) + for (const Record *def : sortedDefs) emitModelMethodsDef(Interface(def)); return false; } @@ -622,8 +623,7 @@ bool InterfaceGenerator::emitInterfaceDecls() { // GEN: Interface documentation //===----------------------------------------------------------------------===// -static void emitInterfaceDoc(const llvm::Record &interfaceDef, - raw_ostream &os) { +static void emitInterfaceDoc(const Record &interfaceDef, raw_ostream &os) { Interface interface(&interfaceDef); // Emit the interface name followed by the description. @@ -684,15 +684,15 @@ struct InterfaceGenRegistration { genDefDesc(("Generate " + genDesc + " interface definitions").str()), genDocDesc(("Generate " + genDesc + " interface documentation").str()), genDecls(genDeclArg, genDeclDesc, - [](const llvm::RecordKeeper &records, raw_ostream &os) { + [](const RecordKeeper &records, raw_ostream &os) { return GeneratorT(records, os).emitInterfaceDecls(); }), genDefs(genDefArg, genDefDesc, - [](const llvm::RecordKeeper &records, raw_ostream &os) { + [](const RecordKeeper &records, raw_ostream &os) { return GeneratorT(records, os).emitInterfaceDefs(); }), genDocs(genDocArg, genDocDesc, - [](const llvm::RecordKeeper &records, raw_ostream &os) { + [](const RecordKeeper &records, raw_ostream &os) { return GeneratorT(records, os).emitInterfaceDocs(); }) {} diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp index 553ab6adc65b061..0c5c936f5addeea 100644 --- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp @@ -23,6 +23,9 @@ using namespace mlir; using namespace mlir::tblgen; +using llvm::formatv; +using llvm::Record; +using llvm::RecordKeeper; /// File header and includes. /// {0} is the dialect namespace. @@ -315,9 +318,9 @@ static std::string sanitizeName(StringRef name) { } static std::string attrSizedTraitForKind(const char *kind) { - return llvm::formatv("::mlir::OpTrait::AttrSized{0}{1}Segments", - llvm::StringRef(kind).take_front().upper(), - llvm::StringRef(kind).drop_front()); + return formatv("::mlir::OpTrait::AttrSized{0}{1}Segments", + StringRef(kind).take_front().upper(), + StringRef(kind).drop_front()); } /// Emits accessors to "elements" of an Op definition. Currently, the supported @@ -328,15 +331,14 @@ static void emitElementAccessors( unsigned numVariadicGroups, unsigned numElements, llvm::function_ref getElement) { - assert(llvm::is_contained( - llvm::SmallVector{"operand", "result"}, kind) && + assert(llvm::is_contained(SmallVector{"operand", "result"}, + kind) && "unsupported kind"); // Traits indicating how to process variadic elements. - std::string sameSizeTrait = - llvm::formatv("::mlir::OpTrait::SameVariadic{0}{1}Size", - llvm::StringRef(kind).take_front().upper(), - llvm::StringRef(kind).drop_front()); + std::string sameSizeTrait = formatv("::mlir::OpTrait::SameVariadic{0}{1}Size", + StringRef(kind).take_front().upper(), + StringRef(kind).drop_front()); std::string attrSizedTrait = attrSizedTraitForKind(kind); // If there is only one variable-length element group, its size can be @@ -351,15 +353,14 @@ static void emitElementAccessors( if (element.name.empty()) continue; if (element.isVariableLength()) { - os << llvm::formatv(element.isOptional() ? opOneOptionalTemplate - : opOneVariadicTemplate, - sanitizeName(element.name), kind, numElements, i); + os << formatv(element.isOptional() ? opOneOptionalTemplate + : opOneVariadicTemplate, + sanitizeName(element.name), kind, numElements, i); } else if (seenVariableLength) { - os << llvm::formatv(opSingleAfterVariableTemplate, - sanitizeName(element.name), kind, numElements, i); + os << formatv(opSingleAfterVariableTemplate, sanitizeName(element.name), + kind, numElements, i); } else { - os << llvm::formatv(opSingleTemplate, sanitizeName(element.name), kind, - i); + os << formatv(opSingleTemplate, sanitizeName(element.name), kind, i); } } return; @@ -382,14 +383,13 @@ static void emitElementAccessors( for (unsigned i = 0; i < numElements; ++i) { const NamedTypeConstraint &element = getElement(op, i); if (!element.name.empty()) { - os << llvm::formatv(opVariadicEqualPrefixTemplate, - sanitizeName(element.name), kind, numSimpleLength, - numVariadicGroups, numPrecedingSimple, - numPrecedingVariadic); - os << llvm::formatv(element.isVariableLength() - ? opVariadicEqualVariadicTemplate - : opVariadicEqualSimpleTemplate, - kind); + os << formatv(opVariadicEqualPrefixTemplate, sanitizeName(element.name), + kind, numSimpleLength, numVariadicGroups, + numPrecedingSimple, numPrecedingVariadic); + os << formatv(element.isVariableLength() + ? opVariadicEqualVariadicTemplate + : opVariadicEqualSimpleTemplate, + kind); } if (element.isVariableLength()) ++numPrecedingVariadic; @@ -412,9 +412,9 @@ static void emitElementAccessors( trailing = "[0]"; else if (element.isOptional()) trailing = std::string( - llvm::formatv(opVariadicSegmentOptionalTrailingTemplate, kind)); - os << llvm::formatv(opVariadicSegmentTemplate, sanitizeName(element.name), - kind, i, trailing); + formatv(opVariadicSegmentOptionalTrailingTemplate, kind)); + os << formatv(opVariadicSegmentTemplate, sanitizeName(element.name), kind, + i, trailing); } return; } @@ -459,27 +459,21 @@ static void emitAttributeAccessors(const Operator &op, raw_ostream &os) { // Unit attributes are handled specially. if (namedAttr.attr.getStorageType().trim() == "::mlir::UnitAttr") { - os << llvm::formatv(unitAttributeGetterTemplate, sanitizedName, - namedAttr.name); - os << llvm::formatv(unitAttributeSetterTemplate, sanitizedName, - namedAttr.name); - os << llvm::formatv(attributeDeleterTemplate, sanitizedName, - namedAttr.name); + os << formatv(unitAttributeGetterTemplate, sanitizedName, namedAttr.name); + os << formatv(unitAttributeSetterTemplate, sanitizedName, namedAttr.name); + os << formatv(attributeDeleterTemplate, sanitizedName, namedAttr.name); continue; } if (namedAttr.attr.isOptional()) { - os << llvm::formatv(optionalAttributeGetterTemplate, sanitizedName, - namedAttr.name); - os << llvm::formatv(optionalAttributeSetterTemplate, sanitizedName, - namedAttr.name); - os << llvm::formatv(attributeDeleterTemplate, sanitizedName, - namedAttr.name); + os << formatv(optionalAttributeGetterTemplate, sanitizedName, + namedAttr.name); + os << formatv(optionalAttributeSetterTemplate, sanitizedName, + namedAttr.name); + os << formatv(attributeDeleterTemplate, sanitizedName, namedAttr.name); } else { - os << llvm::formatv(attributeGetterTemplate, sanitizedName, - namedAttr.name); - os << llvm::formatv(attributeSetterTemplate, sanitizedName, - namedAttr.name); + os << formatv(attributeGetterTemplate, sanitizedName, namedAttr.name); + os << formatv(attributeSetterTemplate, sanitizedName, namedAttr.name); // Non-optional attributes cannot be deleted. } } @@ -595,7 +589,7 @@ static bool canInferType(const Operator &op) { /// accept them as arguments. static void populateBuilderArgsResults(const Operator &op, - llvm::SmallVectorImpl &builderArgs) { + SmallVectorImpl &builderArgs) { if (canInferType(op)) return; @@ -607,7 +601,7 @@ populateBuilderArgsResults(const Operator &op, // to properly match the built-in result accessor. name = "result"; } else { - name = llvm::formatv("_gen_res_{0}", i); + name = formatv("_gen_res_{0}", i); } } name = sanitizeName(name); @@ -620,14 +614,13 @@ populateBuilderArgsResults(const Operator &op, /// appear in the `arguments` field of the op definition. Additionally, /// `operandNames` is populated with names of operands in their order of /// appearance. -static void -populateBuilderArgs(const Operator &op, - llvm::SmallVectorImpl &builderArgs, - llvm::SmallVectorImpl &operandNames) { +static void populateBuilderArgs(const Operator &op, + SmallVectorImpl &builderArgs, + SmallVectorImpl &operandNames) { for (int i = 0, e = op.getNumArgs(); i < e; ++i) { std::string name = op.getArgName(i).str(); if (name.empty()) - name = llvm::formatv("_gen_arg_{0}", i); + name = formatv("_gen_arg_{0}", i); name = sanitizeName(name); builderArgs.push_back(name); if (!op.getArg(i).is()) @@ -637,15 +630,16 @@ populateBuilderArgs(const Operator &op, /// Populates `builderArgs` with the Python-compatible names of builder function /// successor arguments. Additionally, `successorArgNames` is also populated. -static void populateBuilderArgsSuccessors( - const Operator &op, llvm::SmallVectorImpl &builderArgs, - llvm::SmallVectorImpl &successorArgNames) { +static void +populateBuilderArgsSuccessors(const Operator &op, + SmallVectorImpl &builderArgs, + SmallVectorImpl &successorArgNames) { for (int i = 0, e = op.getNumSuccessors(); i < e; ++i) { NamedSuccessor successor = op.getSuccessor(i); std::string name = std::string(successor.name); if (name.empty()) - name = llvm::formatv("_gen_successor_{0}", i); + name = formatv("_gen_successor_{0}", i); name = sanitizeName(name); builderArgs.push_back(name); successorArgNames.push_back(name); @@ -658,9 +652,8 @@ static void populateBuilderArgsSuccessors( /// operands and attributes in the same order as they appear in the `arguments` /// field. static void -populateBuilderLinesAttr(const Operator &op, - llvm::ArrayRef argNames, - llvm::SmallVectorImpl &builderLines) { +populateBuilderLinesAttr(const Operator &op, ArrayRef argNames, + SmallVectorImpl &builderLines) { builderLines.push_back("_ods_context = _ods_get_default_loc_context(loc)"); for (int i = 0, e = op.getNumArgs(); i < e; ++i) { Argument arg = op.getArg(i); @@ -670,12 +663,12 @@ populateBuilderLinesAttr(const Operator &op, // Unit attributes are handled specially. if (attribute->attr.getStorageType().trim() == "::mlir::UnitAttr") { - builderLines.push_back(llvm::formatv(initUnitAttributeTemplate, - attribute->name, argNames[i])); + builderLines.push_back( + formatv(initUnitAttributeTemplate, attribute->name, argNames[i])); continue; } - builderLines.push_back(llvm::formatv( + builderLines.push_back(formatv( attribute->attr.isOptional() || attribute->attr.hasDefaultValue() ? initOptionalAttributeWithBuilderTemplate : initAttributeWithBuilderTemplate, @@ -686,30 +679,30 @@ populateBuilderLinesAttr(const Operator &op, /// Populates `builderLines` with additional lines that are required in the /// builder to set up successors. successorArgNames is expected to correspond /// to the Python argument name for each successor on the op. -static void populateBuilderLinesSuccessors( - const Operator &op, llvm::ArrayRef successorArgNames, - llvm::SmallVectorImpl &builderLines) { +static void +populateBuilderLinesSuccessors(const Operator &op, + ArrayRef successorArgNames, + SmallVectorImpl &builderLines) { if (successorArgNames.empty()) { - builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "None")); + builderLines.push_back(formatv(initSuccessorsTemplate, "None")); return; } - builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "[]")); + builderLines.push_back(formatv(initSuccessorsTemplate, "[]")); for (int i = 0, e = successorArgNames.size(); i < e; ++i) { auto &argName = successorArgNames[i]; const NamedSuccessor &successor = op.getSuccessor(i); - builderLines.push_back( - llvm::formatv(addSuccessorTemplate, - successor.isVariadic() ? "extend" : "append", argName)); + builderLines.push_back(formatv(addSuccessorTemplate, + successor.isVariadic() ? "extend" : "append", + argName)); } } /// Populates `builderLines` with additional lines that are required in the /// builder to set up op operands. static void -populateBuilderLinesOperand(const Operator &op, - llvm::ArrayRef names, - llvm::SmallVectorImpl &builderLines) { +populateBuilderLinesOperand(const Operator &op, ArrayRef names, + SmallVectorImpl &builderLines) { bool sizedSegments = op.getTrait(attrSizedTraitForKind("operand")) != nullptr; // For each element, find or generate a name. @@ -718,7 +711,7 @@ populateBuilderLinesOperand(const Operator &op, std::string name = names[i]; // Choose the formatting string based on the element kind. - llvm::StringRef formatString; + StringRef formatString; if (!element.isVariableLength()) { formatString = singleOperandAppendTemplate; } else if (element.isOptional()) { @@ -738,7 +731,7 @@ populateBuilderLinesOperand(const Operator &op, } } - builderLines.push_back(llvm::formatv(formatString.data(), name)); + builderLines.push_back(formatv(formatString.data(), name)); } } @@ -758,7 +751,7 @@ constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})"; /// Appends the given multiline string as individual strings into /// `builderLines`. static void appendLineByLine(StringRef string, - llvm::SmallVectorImpl &builderLines) { + SmallVectorImpl &builderLines) { std::pair split = std::make_pair(string, string); do { @@ -770,14 +763,13 @@ static void appendLineByLine(StringRef string, /// Populates `builderLines` with additional lines that are required in the /// builder to set up op results. static void -populateBuilderLinesResult(const Operator &op, - llvm::ArrayRef names, - llvm::SmallVectorImpl &builderLines) { +populateBuilderLinesResult(const Operator &op, ArrayRef names, + SmallVectorImpl &builderLines) { bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr; if (hasSameArgumentAndResultTypes(op)) { - builderLines.push_back(llvm::formatv( - appendSameResultsTemplate, "operands[0].type", op.getNumResults())); + builderLines.push_back(formatv(appendSameResultsTemplate, + "operands[0].type", op.getNumResults())); return; } @@ -785,12 +777,11 @@ populateBuilderLinesResult(const Operator &op, const NamedAttribute &firstAttr = op.getAttribute(0); assert(!firstAttr.name.empty() && "unexpected empty name for the attribute " "from which the type is derived"); - appendLineByLine( - llvm::formatv(deriveTypeFromAttrTemplate, firstAttr.name).str(), - builderLines); - builderLines.push_back(llvm::formatv(appendSameResultsTemplate, - "_ods_derived_result_type", - op.getNumResults())); + appendLineByLine(formatv(deriveTypeFromAttrTemplate, firstAttr.name).str(), + builderLines); + builderLines.push_back(formatv(appendSameResultsTemplate, + "_ods_derived_result_type", + op.getNumResults())); return; } @@ -803,7 +794,7 @@ populateBuilderLinesResult(const Operator &op, std::string name = names[i]; // Choose the formatting string based on the element kind. - llvm::StringRef formatString; + StringRef formatString; if (!element.isVariableLength()) { formatString = singleResultAppendTemplate; } else if (element.isOptional()) { @@ -819,17 +810,16 @@ populateBuilderLinesResult(const Operator &op, } } - builderLines.push_back(llvm::formatv(formatString.data(), name)); + builderLines.push_back(formatv(formatString.data(), name)); } } /// If the operation has variadic regions, adds a builder argument to specify /// the number of those regions and builder lines to forward it to the generic /// constructor. -static void -populateBuilderRegions(const Operator &op, - llvm::SmallVectorImpl &builderArgs, - llvm::SmallVectorImpl &builderLines) { +static void populateBuilderRegions(const Operator &op, + SmallVectorImpl &builderArgs, + SmallVectorImpl &builderLines) { if (op.hasNoVariadicRegions()) return; @@ -844,19 +834,19 @@ populateBuilderRegions(const Operator &op, .str(); builderArgs.push_back(name); builderLines.push_back( - llvm::formatv("regions = {0} + {1}", op.getNumRegions() - 1, name)); + formatv("regions = {0} + {1}", op.getNumRegions() - 1, name)); } /// Emits a default builder constructing an operation from the list of its /// result types, followed by a list of its operands. Returns vector /// of fully built functionArgs for downstream users (to save having to /// rebuild anew). -static llvm::SmallVector emitDefaultOpBuilder(const Operator &op, - raw_ostream &os) { - llvm::SmallVector builderArgs; - llvm::SmallVector builderLines; - llvm::SmallVector operandArgNames; - llvm::SmallVector successorArgNames; +static SmallVector emitDefaultOpBuilder(const Operator &op, + raw_ostream &os) { + SmallVector builderArgs; + SmallVector builderLines; + SmallVector operandArgNames; + SmallVector successorArgNames; builderArgs.reserve(op.getNumOperands() + op.getNumResults() + op.getNumNativeAttributes() + op.getNumSuccessors()); populateBuilderArgsResults(op, builderArgs); @@ -866,10 +856,10 @@ static llvm::SmallVector emitDefaultOpBuilder(const Operator &op, populateBuilderArgsSuccessors(op, builderArgs, successorArgNames); populateBuilderLinesOperand(op, operandArgNames, builderLines); - populateBuilderLinesAttr( - op, llvm::ArrayRef(builderArgs).drop_front(numResultArgs), builderLines); + populateBuilderLinesAttr(op, ArrayRef(builderArgs).drop_front(numResultArgs), + builderLines); populateBuilderLinesResult( - op, llvm::ArrayRef(builderArgs).take_front(numResultArgs), builderLines); + op, ArrayRef(builderArgs).take_front(numResultArgs), builderLines); populateBuilderLinesSuccessors(op, successorArgNames, builderLines); populateBuilderRegions(op, builderArgs, builderLines); @@ -896,7 +886,7 @@ static llvm::SmallVector emitDefaultOpBuilder(const Operator &op, }; // StringRefs in functionArgs refer to strings allocated by builderArgs. - llvm::SmallVector functionArgs; + SmallVector functionArgs; // Add positional arguments. for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) { @@ -929,11 +919,10 @@ static llvm::SmallVector emitDefaultOpBuilder(const Operator &op, initArgs.push_back("loc=loc"); initArgs.push_back("ip=ip"); - os << llvm::formatv(initTemplate, llvm::join(functionArgs, ", "), - llvm::join(builderLines, "\n "), - llvm::join(initArgs, ", ")); + os << formatv(initTemplate, llvm::join(functionArgs, ", "), + llvm::join(builderLines, "\n "), llvm::join(initArgs, ", ")); return llvm::to_vector<8>( - llvm::map_range(functionArgs, [](llvm::StringRef s) { return s.str(); })); + llvm::map_range(functionArgs, [](StringRef s) { return s.str(); })); } static void emitSegmentSpec( @@ -955,15 +944,15 @@ static void emitSegmentSpec( } segmentSpec.append("]"); - os << llvm::formatv(opClassSizedSegmentsTemplate, kind, segmentSpec); + os << formatv(opClassSizedSegmentsTemplate, kind, segmentSpec); } static void emitRegionAttributes(const Operator &op, raw_ostream &os) { // Emit _ODS_REGIONS = (min_region_count, has_no_variadic_regions). // Note that the base OpView class defines this as (0, True). unsigned minRegionCount = op.getNumRegions() - op.getNumVariadicRegions(); - os << llvm::formatv(opClassRegionSpecTemplate, minRegionCount, - op.hasNoVariadicRegions() ? "True" : "False"); + os << formatv(opClassRegionSpecTemplate, minRegionCount, + op.hasNoVariadicRegions() ? "True" : "False"); } /// Emits named accessors to regions. @@ -975,20 +964,20 @@ static void emitRegionAccessors(const Operator &op, raw_ostream &os) { assert((!region.isVariadic() || en.index() == op.getNumRegions() - 1) && "expected only the last region to be variadic"); - os << llvm::formatv(regionAccessorTemplate, sanitizeName(region.name), - std::to_string(en.index()) + - (region.isVariadic() ? ":" : "")); + os << formatv(regionAccessorTemplate, sanitizeName(region.name), + std::to_string(en.index()) + + (region.isVariadic() ? ":" : "")); } } /// Emits builder that extracts results from op static void emitValueBuilder(const Operator &op, - llvm::SmallVector functionArgs, + SmallVector functionArgs, raw_ostream &os) { // Params with (possibly) default args. auto valueBuilderParams = llvm::map_range(functionArgs, [](const std::string &argAndMaybeDefault) { - llvm::SmallVector argMaybeDefault = + SmallVector argMaybeDefault = llvm::to_vector<2>(llvm::split(argAndMaybeDefault, "=")); auto arg = llvm::convertToSnakeFromCamelCase(argMaybeDefault[0]); if (argMaybeDefault.size() == 2) @@ -1005,7 +994,7 @@ static void emitValueBuilder(const Operator &op, }); std::string nameWithoutDialect = op.getOperationName().substr(op.getOperationName().find('.') + 1); - os << llvm::formatv( + os << formatv( valueBuilderTemplate, sanitizeName(nameWithoutDialect), op.getCppClassName(), llvm::join(valueBuilderParams, ", "), llvm::join(opBuilderArgs, ", "), @@ -1016,8 +1005,7 @@ static void emitValueBuilder(const Operator &op, /// Emits bindings for a specific Op to the given output stream. static void emitOpBindings(const Operator &op, raw_ostream &os) { - os << llvm::formatv(opClassTemplate, op.getCppClassName(), - op.getOperationName()); + os << formatv(opClassTemplate, op.getCppClassName(), op.getOperationName()); // Sized segments. if (op.getTrait(attrSizedTraitForKind("operand")) != nullptr) { @@ -1028,7 +1016,7 @@ static void emitOpBindings(const Operator &op, raw_ostream &os) { } emitRegionAttributes(op, os); - llvm::SmallVector functionArgs = emitDefaultOpBuilder(op, os); + SmallVector functionArgs = emitDefaultOpBuilder(op, os); emitOperandAccessors(op, os); emitAttributeAccessors(op, os); emitResultAccessors(op, os); @@ -1039,17 +1027,17 @@ static void emitOpBindings(const Operator &op, raw_ostream &os) { /// Emits bindings for the dialect specified in the command line, including file /// headers and utilities. Returns `false` on success to comply with Tablegen /// registration requirements. -static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) { +static bool emitAllOps(const RecordKeeper &records, raw_ostream &os) { if (clDialectName.empty()) llvm::PrintFatalError("dialect name not provided"); os << fileHeader; if (!clDialectExtensionName.empty()) - os << llvm::formatv(dialectExtensionTemplate, clDialectName.getValue()); + os << formatv(dialectExtensionTemplate, clDialectName.getValue()); else - os << llvm::formatv(dialectClassTemplate, clDialectName.getValue()); + os << formatv(dialectClassTemplate, clDialectName.getValue()); - for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) { + for (const Record *rec : records.getAllDerivedDefinitions("Op")) { Operator op(rec); if (op.getDialectName() == clDialectName.getValue()) emitOpBindings(op, os); diff --git a/mlir/tools/mlir-tblgen/PassCAPIGen.cpp b/mlir/tools/mlir-tblgen/PassCAPIGen.cpp index 9f33a4129aaccaa..8c13c9b03133521 100644 --- a/mlir/tools/mlir-tblgen/PassCAPIGen.cpp +++ b/mlir/tools/mlir-tblgen/PassCAPIGen.cpp @@ -20,6 +20,8 @@ using namespace mlir; using namespace mlir::tblgen; +using llvm::formatv; +using llvm::RecordKeeper; static llvm::cl::OptionCategory passGenCat("Options for -gen-pass-capi-header and -gen-pass-capi-impl"); @@ -56,7 +58,7 @@ const char *const fileFooter = R"( )"; /// Emit TODO -static bool emitCAPIHeader(const llvm::RecordKeeper &records, raw_ostream &os) { +static bool emitCAPIHeader(const RecordKeeper &records, raw_ostream &os) { os << fileHeader; os << "// Registration for the entire group\n"; os << "MLIR_CAPI_EXPORTED void mlirRegister" << groupName @@ -64,7 +66,7 @@ static bool emitCAPIHeader(const llvm::RecordKeeper &records, raw_ostream &os) { for (const auto *def : records.getAllDerivedDefinitions("PassBase")) { Pass pass(def); StringRef defName = pass.getDef()->getName(); - os << llvm::formatv(passDecl, groupName, defName); + os << formatv(passDecl, groupName, defName); } os << fileFooter; return false; @@ -91,9 +93,9 @@ void mlirRegister{0}Passes(void) {{ } )"; -static bool emitCAPIImpl(const llvm::RecordKeeper &records, raw_ostream &os) { +static bool emitCAPIImpl(const RecordKeeper &records, raw_ostream &os) { os << "/* Autogenerated by mlir-tblgen; don't manually edit. */"; - os << llvm::formatv(passGroupRegistrationCode, groupName); + os << formatv(passGroupRegistrationCode, groupName); for (const auto *def : records.getAllDerivedDefinitions("PassBase")) { Pass pass(def); @@ -103,10 +105,9 @@ static bool emitCAPIImpl(const llvm::RecordKeeper &records, raw_ostream &os) { if (StringRef constructor = pass.getConstructor(); !constructor.empty()) constructorCall = constructor.str(); else - constructorCall = - llvm::formatv("create{0}()", pass.getDef()->getName()).str(); + constructorCall = formatv("create{0}()", pass.getDef()->getName()).str(); - os << llvm::formatv(passCreateDef, groupName, defName, constructorCall); + os << formatv(passCreateDef, groupName, defName, constructorCall); } return false; } diff --git a/mlir/tools/mlir-tblgen/PassDocGen.cpp b/mlir/tools/mlir-tblgen/PassDocGen.cpp index 8febba191562511..914112e926d8b8e 100644 --- a/mlir/tools/mlir-tblgen/PassDocGen.cpp +++ b/mlir/tools/mlir-tblgen/PassDocGen.cpp @@ -18,6 +18,7 @@ using namespace mlir; using namespace mlir::tblgen; +using llvm::RecordKeeper; /// Emit the documentation for the given pass. static void emitDoc(const Pass &pass, raw_ostream &os) { @@ -56,7 +57,7 @@ static void emitDoc(const Pass &pass, raw_ostream &os) { } } -static void emitDocs(const llvm::RecordKeeper &recordKeeper, raw_ostream &os) { +static void emitDocs(const RecordKeeper &recordKeeper, raw_ostream &os) { os << "\n"; auto passDefs = recordKeeper.getAllDerivedDefinitions("PassBase"); @@ -74,7 +75,7 @@ static void emitDocs(const llvm::RecordKeeper &recordKeeper, raw_ostream &os) { static mlir::GenRegistration genRegister("gen-pass-doc", "Generate pass documentation", - [](const llvm::RecordKeeper &records, raw_ostream &os) { + [](const RecordKeeper &records, raw_ostream &os) { emitDocs(records, os); return false; }); diff --git a/mlir/tools/mlir-tblgen/PassGen.cpp b/mlir/tools/mlir-tblgen/PassGen.cpp index 655843f26201ac2..295f01b08754692 100644 --- a/mlir/tools/mlir-tblgen/PassGen.cpp +++ b/mlir/tools/mlir-tblgen/PassGen.cpp @@ -21,6 +21,8 @@ using namespace mlir; using namespace mlir::tblgen; +using llvm::formatv; +using llvm::RecordKeeper; static llvm::cl::OptionCategory passGenCat("Options for -gen-pass-decls"); static llvm::cl::opt @@ -28,7 +30,7 @@ static llvm::cl::opt llvm::cl::cat(passGenCat)); /// Extract the list of passes from the TableGen records. -static std::vector getPasses(const llvm::RecordKeeper &recordKeeper) { +static std::vector getPasses(const RecordKeeper &recordKeeper) { std::vector passes; for (const auto *def : recordKeeper.getAllDerivedDefinitions("PassBase")) @@ -91,7 +93,7 @@ static void emitPassOptionsStruct(const Pass &pass, raw_ostream &os) { if (options.empty()) return; - os << llvm::formatv("struct {0}Options {{\n", passName); + os << formatv("struct {0}Options {{\n", passName); for (const PassOption &opt : options) { std::string type = opt.getType().str(); @@ -99,7 +101,7 @@ static void emitPassOptionsStruct(const Pass &pass, raw_ostream &os) { if (opt.isListOption()) type = "::llvm::SmallVector<" + type + ">"; - os.indent(2) << llvm::formatv("{0} {1}", type, opt.getCppVariableName()); + os.indent(2) << formatv("{0} {1}", type, opt.getCppVariableName()); if (std::optional defaultVal = opt.getDefaultValue()) os << " = " << defaultVal; @@ -128,9 +130,9 @@ static void emitPassDecls(const Pass &pass, raw_ostream &os) { // Declaration of the constructor with options. if (ArrayRef options = pass.getOptions(); !options.empty()) - os << llvm::formatv("std::unique_ptr<::mlir::Pass> create{0}(" - "{0}Options options);\n", - passName); + os << formatv("std::unique_ptr<::mlir::Pass> create{0}(" + "{0}Options options);\n", + passName); } os << "#undef " << enableVarName << "\n"; @@ -147,14 +149,13 @@ static void emitRegistrations(llvm::ArrayRef passes, raw_ostream &os) { if (StringRef constructor = pass.getConstructor(); !constructor.empty()) constructorCall = constructor.str(); else - constructorCall = - llvm::formatv("create{0}()", pass.getDef()->getName()).str(); + constructorCall = formatv("create{0}()", pass.getDef()->getName()).str(); - os << llvm::formatv(passRegistrationCode, pass.getDef()->getName(), - constructorCall); + os << formatv(passRegistrationCode, pass.getDef()->getName(), + constructorCall); } - os << llvm::formatv(passGroupRegistrationCode, groupName); + os << formatv(passGroupRegistrationCode, groupName); for (const Pass &pass : passes) os << " register" << pass.getDef()->getName() << "();\n"; @@ -270,9 +271,9 @@ static void emitPassOptionDecls(const Pass &pass, raw_ostream &os) { os.indent(2) << "::mlir::Pass::" << (opt.isListOption() ? "ListOption" : "Option"); - os << llvm::formatv(R"(<{0}> {1}{{*this, "{2}", ::llvm::cl::desc("{3}"))", - opt.getType(), opt.getCppVariableName(), - opt.getArgument(), opt.getDescription()); + os << formatv(R"(<{0}> {1}{{*this, "{2}", ::llvm::cl::desc("{3}"))", + opt.getType(), opt.getCppVariableName(), opt.getArgument(), + opt.getDescription()); if (std::optional defaultVal = opt.getDefaultValue()) os << ", ::llvm::cl::init(" << defaultVal << ")"; if (std::optional additionalFlags = opt.getAdditionalFlags()) @@ -284,9 +285,9 @@ static void emitPassOptionDecls(const Pass &pass, raw_ostream &os) { /// Emit the declarations for each of the pass statistics. static void emitPassStatisticDecls(const Pass &pass, raw_ostream &os) { for (const PassStatistic &stat : pass.getStatistics()) { - os << llvm::formatv( - " ::mlir::Pass::Statistic {0}{{this, \"{1}\", \"{2}\"};\n", - stat.getCppVariableName(), stat.getName(), stat.getDescription()); + os << formatv(" ::mlir::Pass::Statistic {0}{{this, \"{1}\", \"{2}\"};\n", + stat.getCppVariableName(), stat.getName(), + stat.getDescription()); } } @@ -300,11 +301,10 @@ static void emitPassDefs(const Pass &pass, raw_ostream &os) { os << "#ifdef " << enableVarName << "\n"; if (emitDefaultConstructors) { - os << llvm::formatv(friendDefaultConstructorDeclTemplate, passName); + os << formatv(friendDefaultConstructorDeclTemplate, passName); if (emitDefaultConstructorWithOptions) - os << llvm::formatv(friendDefaultConstructorWithOptionsDeclTemplate, - passName); + os << formatv(friendDefaultConstructorWithOptionsDeclTemplate, passName); } std::string dependentDialectRegistrations; @@ -313,24 +313,23 @@ static void emitPassDefs(const Pass &pass, raw_ostream &os) { llvm::interleave( pass.getDependentDialects(), dialectsOs, [&](StringRef dependentDialect) { - dialectsOs << llvm::formatv(dialectRegistrationTemplate, - dependentDialect); + dialectsOs << formatv(dialectRegistrationTemplate, dependentDialect); }, "\n "); } os << "namespace impl {\n"; - os << llvm::formatv(baseClassBegin, passName, pass.getBaseClass(), - pass.getArgument(), pass.getSummary(), - dependentDialectRegistrations); + os << formatv(baseClassBegin, passName, pass.getBaseClass(), + pass.getArgument(), pass.getSummary(), + dependentDialectRegistrations); if (ArrayRef options = pass.getOptions(); !options.empty()) { - os.indent(2) << llvm::formatv( - "{0}Base({0}Options options) : {0}Base() {{\n", passName); + os.indent(2) << formatv("{0}Base({0}Options options) : {0}Base() {{\n", + passName); for (const PassOption &opt : pass.getOptions()) - os.indent(4) << llvm::formatv("{0} = std::move(options.{0});\n", - opt.getCppVariableName()); + os.indent(4) << formatv("{0} = std::move(options.{0});\n", + opt.getCppVariableName()); os.indent(2) << "}\n"; } @@ -344,21 +343,20 @@ static void emitPassDefs(const Pass &pass, raw_ostream &os) { os << "private:\n"; if (emitDefaultConstructors) { - os << llvm::formatv(friendDefaultConstructorDefTemplate, passName); + os << formatv(friendDefaultConstructorDefTemplate, passName); if (!pass.getOptions().empty()) - os << llvm::formatv(friendDefaultConstructorWithOptionsDefTemplate, - passName); + os << formatv(friendDefaultConstructorWithOptionsDefTemplate, passName); } os << "};\n"; os << "} // namespace impl\n"; if (emitDefaultConstructors) { - os << llvm::formatv(defaultConstructorDefTemplate, passName); + os << formatv(defaultConstructorDefTemplate, passName); if (emitDefaultConstructorWithOptions) - os << llvm::formatv(defaultConstructorWithOptionsDefTemplate, passName); + os << formatv(defaultConstructorWithOptionsDefTemplate, passName); } os << "#undef " << enableVarName << "\n"; @@ -367,7 +365,7 @@ static void emitPassDefs(const Pass &pass, raw_ostream &os) { static void emitPass(const Pass &pass, raw_ostream &os) { StringRef passName = pass.getDef()->getName(); - os << llvm::formatv(passHeader, passName); + os << formatv(passHeader, passName); emitPassDecls(pass, os); emitPassDefs(pass, os); @@ -436,21 +434,19 @@ static void emitOldPassDecl(const Pass &pass, raw_ostream &os) { llvm::interleave( pass.getDependentDialects(), dialectsOs, [&](StringRef dependentDialect) { - dialectsOs << llvm::formatv(dialectRegistrationTemplate, - dependentDialect); + dialectsOs << formatv(dialectRegistrationTemplate, dependentDialect); }, "\n "); } - os << llvm::formatv(oldPassDeclBegin, defName, pass.getBaseClass(), - pass.getArgument(), pass.getSummary(), - dependentDialectRegistrations); + os << formatv(oldPassDeclBegin, defName, pass.getBaseClass(), + pass.getArgument(), pass.getSummary(), + dependentDialectRegistrations); emitPassOptionDecls(pass, os); emitPassStatisticDecls(pass, os); os << "};\n"; } -static void emitPasses(const llvm::RecordKeeper &recordKeeper, - raw_ostream &os) { +static void emitPasses(const RecordKeeper &recordKeeper, raw_ostream &os) { std::vector passes = getPasses(recordKeeper); os << "/* Autogenerated by mlir-tblgen; don't manually edit */\n"; @@ -479,7 +475,7 @@ static void emitPasses(const llvm::RecordKeeper &recordKeeper, static mlir::GenRegistration genPassDecls("gen-pass-decls", "Generate pass declarations", - [](const llvm::RecordKeeper &records, raw_ostream &os) { + [](const RecordKeeper &records, raw_ostream &os) { emitPasses(records, os); return false; }); diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp index 220e039ac48f4f1..a92f05455be1b29 100644 --- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -98,15 +98,15 @@ class Availability { StringRef getMergeInstance() const; // Returns the underlying LLVM TableGen Record. - const llvm::Record *getDef() const { return def; } + const Record *getDef() const { return def; } private: // The TableGen definition of this availability. - const llvm::Record *def; + const Record *def; }; } // namespace -Availability::Availability(const llvm::Record *def) : def(def) { +Availability::Availability(const Record *def) : def(def) { assert(def->isSubClassOf("Availability") && "must be subclass of TableGen 'Availability' class"); }