Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NFC][MLIR][TableGen] Eliminate llvm:: for commonly used types #110841

Merged
merged 1 commit into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 21 additions & 20 deletions mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

using namespace mlir;
using namespace mlir::tblgen;
using llvm::Record;
using llvm::RecordKeeper;

//===----------------------------------------------------------------------===//
// Utility Functions
Expand All @@ -30,14 +32,14 @@ using namespace mlir::tblgen;
/// Find all the AttrOrTypeDef for the specified dialect. If no dialect
/// specified and can only find one dialect's defs, use that.
static void collectAllDefs(StringRef selectedDialect,
ArrayRef<const llvm::Record *> records,
ArrayRef<const Record *> records,
SmallVectorImpl<AttrOrTypeDef> &resultDefs) {
// Nothing to do if no defs were found.
if (records.empty())
return;

auto defs = llvm::map_range(
records, [&](const llvm::Record *rec) { return AttrOrTypeDef(rec); });
records, [&](const Record *rec) { return AttrOrTypeDef(rec); });
if (selectedDialect.empty()) {
// If a dialect was not specified, ensure that all found defs belong to the
// same dialect.
Expand Down Expand Up @@ -690,15 +692,14 @@ class DefGenerator {
bool emitDefs(StringRef selectedDialect);

protected:
DefGenerator(ArrayRef<const llvm::Record *> defs, raw_ostream &os,
DefGenerator(ArrayRef<const Record *> defs, raw_ostream &os,
StringRef defType, StringRef valueType, bool isAttrGenerator)
: defRecords(defs), os(os), defType(defType), valueType(valueType),
isAttrGenerator(isAttrGenerator) {
// Sort by occurrence in file.
llvm::sort(defRecords,
[](const llvm::Record *lhs, const llvm::Record *rhs) {
return lhs->getID() < rhs->getID();
});
llvm::sort(defRecords, [](const Record *lhs, const Record *rhs) {
return lhs->getID() < rhs->getID();
});
}

/// Emit the list of def type names.
Expand All @@ -707,7 +708,7 @@ class DefGenerator {
void emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs);

/// The set of def records to emit.
std::vector<const llvm::Record *> defRecords;
std::vector<const Record *> defRecords;
jurahul marked this conversation as resolved.
Show resolved Hide resolved
/// The attribute or type class to emit.
/// The stream to emit to.
raw_ostream &os;
Expand All @@ -722,13 +723,13 @@ class DefGenerator {

/// A specialized generator for AttrDefs.
struct AttrDefGenerator : public DefGenerator {
AttrDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
AttrDefGenerator(const RecordKeeper &records, raw_ostream &os)
: DefGenerator(records.getAllDerivedDefinitionsIfDefined("AttrDef"), os,
"Attr", "Attribute", /*isAttrGenerator=*/true) {}
};
/// A specialized generator for TypeDefs.
struct TypeDefGenerator : public DefGenerator {
TypeDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
TypeDefGenerator(const RecordKeeper &records, raw_ostream &os)
: DefGenerator(records.getAllDerivedDefinitionsIfDefined("TypeDef"), os,
"Type", "Type", /*isAttrGenerator=*/false) {}
};
Expand Down Expand Up @@ -1030,9 +1031,9 @@ bool DefGenerator::emitDefs(StringRef selectedDialect) {

/// Find all type constraints for which a C++ function should be generated.
static std::vector<Constraint>
getAllTypeConstraints(const llvm::RecordKeeper &records) {
getAllTypeConstraints(const RecordKeeper &records) {
std::vector<Constraint> result;
for (const llvm::Record *def :
for (const Record *def :
records.getAllDerivedDefinitionsIfDefined("TypeConstraint")) {
// Ignore constraints defined outside of the top-level file.
if (llvm::SrcMgr.FindBufferContainingLoc(def->getLoc()[0]) !=
Expand All @@ -1047,7 +1048,7 @@ getAllTypeConstraints(const llvm::RecordKeeper &records) {
return result;
}

static void emitTypeConstraintDecls(const llvm::RecordKeeper &records,
static void emitTypeConstraintDecls(const RecordKeeper &records,
raw_ostream &os) {
static const char *const typeConstraintDecl = R"(
bool {0}(::mlir::Type type);
Expand All @@ -1057,7 +1058,7 @@ bool {0}(::mlir::Type type);
os << strfmt(typeConstraintDecl, *constr.getCppFunctionName());
}

static void emitTypeConstraintDefs(const llvm::RecordKeeper &records,
static void emitTypeConstraintDefs(const RecordKeeper &records,
raw_ostream &os) {
static const char *const typeConstraintDef = R"(
bool {0}(::mlir::Type type) {
Expand Down Expand Up @@ -1088,13 +1089,13 @@ static llvm::cl::opt<std::string>

static mlir::GenRegistration
genAttrDefs("gen-attrdef-defs", "Generate AttrDef definitions",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
[](const RecordKeeper &records, raw_ostream &os) {
AttrDefGenerator generator(records, os);
return generator.emitDefs(attrDialect);
});
static mlir::GenRegistration
genAttrDecls("gen-attrdef-decls", "Generate AttrDef declarations",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
[](const RecordKeeper &records, raw_ostream &os) {
AttrDefGenerator generator(records, os);
return generator.emitDecls(attrDialect);
});
Expand All @@ -1110,28 +1111,28 @@ static llvm::cl::opt<std::string>

static mlir::GenRegistration
genTypeDefs("gen-typedef-defs", "Generate TypeDef definitions",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
[](const RecordKeeper &records, raw_ostream &os) {
TypeDefGenerator generator(records, os);
return generator.emitDefs(typeDialect);
});
static mlir::GenRegistration
genTypeDecls("gen-typedef-decls", "Generate TypeDef declarations",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
[](const RecordKeeper &records, raw_ostream &os) {
TypeDefGenerator generator(records, os);
return generator.emitDecls(typeDialect);
});

static mlir::GenRegistration
genTypeConstrDefs("gen-type-constraint-defs",
"Generate type constraint definitions",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
[](const RecordKeeper &records, raw_ostream &os) {
emitTypeConstraintDefs(records, os);
return false;
});
static mlir::GenRegistration
genTypeConstrDecls("gen-type-constraint-decls",
"Generate type constraint declarations",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
[](const RecordKeeper &records, raw_ostream &os) {
emitTypeConstraintDecls(records, os);
return false;
});
11 changes: 5 additions & 6 deletions mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,10 @@

using namespace llvm;

static llvm::cl::OptionCategory dialectGenCat("Options for -gen-bytecode");
static llvm::cl::opt<std::string>
selectedBcDialect("bytecode-dialect",
llvm::cl::desc("The dialect to gen for"),
llvm::cl::cat(dialectGenCat), llvm::cl::CommaSeparated);
static cl::OptionCategory dialectGenCat("Options for -gen-bytecode");
static cl::opt<std::string>
selectedBcDialect("bytecode-dialect", cl::desc("The dialect to gen for"),
cl::cat(dialectGenCat), cl::CommaSeparated);

namespace {

Expand Down Expand Up @@ -306,7 +305,7 @@ void Generator::emitPrint(StringRef kind, StringRef type,
auto funScope = os.scope("{\n", "}\n\n");

// Check that predicates specified if multiple bytecode instances.
for (const llvm::Record *rec : make_second_range(vec)) {
for (const Record *rec : make_second_range(vec)) {
StringRef pred = rec->getValueAsString("printerPredicate");
if (vec.size() > 1 && pred.empty()) {
for (auto [index, rec] : vec) {
Expand Down
22 changes: 11 additions & 11 deletions mlir/tools/mlir-tblgen/DialectGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@

using namespace mlir;
using namespace mlir::tblgen;
using llvm::Record;
using llvm::RecordKeeper;

static llvm::cl::OptionCategory dialectGenCat("Options for -gen-dialect-*");
llvm::cl::opt<std::string>
Expand All @@ -39,8 +41,8 @@ llvm::cl::opt<std::string>
/// Utility iterator used for filtering records for a specific dialect.
namespace {
using DialectFilterIterator =
llvm::filter_iterator<ArrayRef<llvm::Record *>::iterator,
std::function<bool(const llvm::Record *)>>;
llvm::filter_iterator<ArrayRef<Record *>::iterator,
std::function<bool(const Record *)>>;
} // namespace

static void populateDiscardableAttributes(
Expand All @@ -62,8 +64,8 @@ static void populateDiscardableAttributes(
/// the given dialect.
template <typename T>
static iterator_range<DialectFilterIterator>
filterForDialect(ArrayRef<llvm::Record *> records, Dialect &dialect) {
auto filterFn = [&](const llvm::Record *record) {
filterForDialect(ArrayRef<Record *> records, Dialect &dialect) {
auto filterFn = [&](const Record *record) {
return T(record).getDialect() == dialect;
};
return {DialectFilterIterator(records.begin(), records.end(), filterFn),
Expand Down Expand Up @@ -295,7 +297,7 @@ static void emitDialectDecl(Dialect &dialect, raw_ostream &os) {
<< "::" << dialect.getCppClassName() << ")\n";
}

static bool emitDialectDecls(const llvm::RecordKeeper &recordKeeper,
static bool emitDialectDecls(const RecordKeeper &recordKeeper,
raw_ostream &os) {
emitSourceFileHeader("Dialect Declarations", os, recordKeeper);

Expand Down Expand Up @@ -340,8 +342,7 @@ static const char *const dialectDestructorStr = R"(

)";

static void emitDialectDef(Dialect &dialect,
const llvm::RecordKeeper &recordKeeper,
static void emitDialectDef(Dialect &dialect, const RecordKeeper &recordKeeper,
raw_ostream &os) {
std::string cppClassName = dialect.getCppClassName();

Expand Down Expand Up @@ -389,8 +390,7 @@ static void emitDialectDef(Dialect &dialect,
os << llvm::formatv(dialectDestructorStr, cppClassName);
}

static bool emitDialectDefs(const llvm::RecordKeeper &recordKeeper,
raw_ostream &os) {
static bool emitDialectDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
emitSourceFileHeader("Dialect Definitions", os, recordKeeper);

auto dialectDefs = recordKeeper.getAllDerivedDefinitions("Dialect");
Expand All @@ -411,12 +411,12 @@ static bool emitDialectDefs(const llvm::RecordKeeper &recordKeeper,

static mlir::GenRegistration
genDialectDecls("gen-dialect-decls", "Generate dialect declarations",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
[](const RecordKeeper &records, raw_ostream &os) {
return emitDialectDecls(records, os);
});

static mlir::GenRegistration
genDialectDefs("gen-dialect-defs", "Generate dialect definitions",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
[](const RecordKeeper &records, raw_ostream &os) {
return emitDialectDefs(records, os);
});
84 changes: 41 additions & 43 deletions mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@

using namespace mlir;
using namespace mlir::tblgen;
using llvm::formatv;
using llvm::Record;
using llvm::RecordKeeper;

/// File header and includes.
constexpr const char *fileHeader = R"Py(
Expand All @@ -42,44 +45,42 @@ static std::string makePythonEnumCaseName(StringRef name) {

/// Emits the Python class for the given enum.
static void emitEnumClass(EnumAttr enumAttr, raw_ostream &os) {
os << llvm::formatv("class {0}({1}):\n", enumAttr.getEnumClassName(),
enumAttr.isBitEnum() ? "IntFlag" : "IntEnum");
os << formatv("class {0}({1}):\n", enumAttr.getEnumClassName(),
enumAttr.isBitEnum() ? "IntFlag" : "IntEnum");
if (!enumAttr.getSummary().empty())
os << llvm::formatv(" \"\"\"{0}\"\"\"\n", enumAttr.getSummary());
os << formatv(" \"\"\"{0}\"\"\"\n", enumAttr.getSummary());
os << "\n";

for (const EnumAttrCase &enumCase : enumAttr.getAllCases()) {
os << llvm::formatv(
" {0} = {1}\n", makePythonEnumCaseName(enumCase.getSymbol()),
enumCase.getValue() >= 0 ? std::to_string(enumCase.getValue())
: "auto()");
os << formatv(" {0} = {1}\n",
makePythonEnumCaseName(enumCase.getSymbol()),
enumCase.getValue() >= 0 ? std::to_string(enumCase.getValue())
: "auto()");
}

os << "\n";

if (enumAttr.isBitEnum()) {
os << llvm::formatv(" def __iter__(self):\n"
" return iter([case for case in type(self) if "
"(self & case) is case])\n");
os << llvm::formatv(" def __len__(self):\n"
" return bin(self).count(\"1\")\n");
os << formatv(" def __iter__(self):\n"
" return iter([case for case in type(self) if "
"(self & case) is case])\n");
os << formatv(" def __len__(self):\n"
" return bin(self).count(\"1\")\n");
os << "\n";
}

os << llvm::formatv(" def __str__(self):\n");
os << formatv(" def __str__(self):\n");
if (enumAttr.isBitEnum())
os << llvm::formatv(" if len(self) > 1:\n"
" return \"{0}\".join(map(str, self))\n",
enumAttr.getDef().getValueAsString("separator"));
os << formatv(" if len(self) > 1:\n"
" return \"{0}\".join(map(str, self))\n",
enumAttr.getDef().getValueAsString("separator"));
for (const EnumAttrCase &enumCase : enumAttr.getAllCases()) {
os << llvm::formatv(" if self is {0}.{1}:\n",
enumAttr.getEnumClassName(),
makePythonEnumCaseName(enumCase.getSymbol()));
os << llvm::formatv(" return \"{0}\"\n", enumCase.getStr());
os << formatv(" if self is {0}.{1}:\n", enumAttr.getEnumClassName(),
makePythonEnumCaseName(enumCase.getSymbol()));
os << formatv(" return \"{0}\"\n", enumCase.getStr());
}
os << llvm::formatv(
" raise ValueError(\"Unknown {0} enum entry.\")\n\n\n",
enumAttr.getEnumClassName());
os << formatv(" raise ValueError(\"Unknown {0} enum entry.\")\n\n\n",
enumAttr.getEnumClassName());
os << "\n";
}

Expand All @@ -105,15 +106,13 @@ static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
return true;
}

os << llvm::formatv("@register_attribute_builder(\"{0}\")\n",
enumAttr.getAttrDefName());
os << llvm::formatv("def _{0}(x, context):\n",
enumAttr.getAttrDefName().lower());
os << llvm::formatv(
" return "
"_ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless({0}, "
"context=context), int(x))\n\n",
bitwidth);
os << formatv("@register_attribute_builder(\"{0}\")\n",
enumAttr.getAttrDefName());
os << formatv("def _{0}(x, context):\n", enumAttr.getAttrDefName().lower());
os << formatv(" return "
"_ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless({0}, "
"context=context), int(x))\n\n",
bitwidth);
return false;
}

Expand All @@ -123,26 +122,25 @@ static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
static bool emitDialectEnumAttributeBuilder(StringRef attrDefName,
StringRef formatString,
raw_ostream &os) {
os << llvm::formatv("@register_attribute_builder(\"{0}\")\n", attrDefName);
os << llvm::formatv("def _{0}(x, context):\n", attrDefName.lower());
os << llvm::formatv(" return "
"_ods_ir.Attribute.parse(f'{0}', context=context)\n\n",
formatString);
os << formatv("@register_attribute_builder(\"{0}\")\n", attrDefName);
os << formatv("def _{0}(x, context):\n", attrDefName.lower());
os << formatv(" return "
"_ods_ir.Attribute.parse(f'{0}', context=context)\n\n",
formatString);
return false;
}

/// Emits Python bindings for all enums in the record keeper. Returns
/// `false` on success, `true` on failure.
static bool emitPythonEnums(const llvm::RecordKeeper &recordKeeper,
raw_ostream &os) {
static bool emitPythonEnums(const RecordKeeper &recordKeeper, raw_ostream &os) {
os << fileHeader;
for (const llvm::Record *it :
for (const Record *it :
recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttrInfo")) {
EnumAttr enumAttr(*it);
emitEnumClass(enumAttr, os);
emitAttributeBuilder(enumAttr, os);
}
for (const llvm::Record *it :
for (const Record *it :
recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
AttrOrTypeDef attr(&*it);
if (!attr.getMnemonic()) {
Expand All @@ -156,11 +154,11 @@ static bool emitPythonEnums(const llvm::RecordKeeper &recordKeeper,
if (assemblyFormat == "`<` $value `>`") {
emitDialectEnumAttributeBuilder(
attr.getName(),
llvm::formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str(), os);
formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str(), os);
} else if (assemblyFormat == "$value") {
emitDialectEnumAttributeBuilder(
attr.getName(),
llvm::formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str(), os);
formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str(), os);
} else {
llvm::errs()
<< "unsupported assembly format for python enum bindings generation";
Expand Down
Loading
Loading