Skip to content

Commit

Permalink
[TableGen] Add const variants of accessors for backend (llvm#106658)
Browse files Browse the repository at this point in the history
Split RecordKeeper `getAllDerivedDefinitions` family of functions into
two variants:
  (a) non-const ones that return vectors of `Record *` and 
  (b) const ones, that return vector/ArrayRef of `const Record *`.

This will help gradual migration of TableGen backends to use 
`const RecordKeeper` and by implication change code to work 
with const pointers and better const correctness.

Existing backends are not yet compatible with the const family of
functions, so change them to use a non-constant `RecordKeeper` 
reference, till they are migrated.
  • Loading branch information
jurahul authored and VitaNuo committed Sep 12, 2024
1 parent ca322cc commit e05b9d7
Show file tree
Hide file tree
Showing 19 changed files with 134 additions and 76 deletions.
4 changes: 2 additions & 2 deletions clang/utils/TableGen/ClangAttrEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ static StringRef NormalizeGNUAttrSpelling(StringRef AttrSpelling) {

typedef std::vector<std::pair<std::string, const Record *>> ParsedAttrMap;

static ParsedAttrMap getParsedAttrList(const RecordKeeper &Records,
static ParsedAttrMap getParsedAttrList(RecordKeeper &Records,
ParsedAttrMap *Dupes = nullptr,
bool SemaOnly = true) {
std::vector<Record *> Attrs = Records.getAllDerivedDefinitions("Attr");
Expand Down Expand Up @@ -4344,7 +4344,7 @@ static void GenerateAppertainsTo(const Record &Attr, raw_ostream &OS) {
// written into OS and the checks for merging declaration attributes are
// written into MergeOS.
static void GenerateMutualExclusionsChecks(const Record &Attr,
const RecordKeeper &Records,
RecordKeeper &Records,
raw_ostream &OS,
raw_ostream &MergeDeclOS,
raw_ostream &MergeStmtOS) {
Expand Down
2 changes: 1 addition & 1 deletion clang/utils/TableGen/ClangSyntaxEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ using llvm::formatv;
// stable and useful way, where abstract Node subclasses correspond to ranges.
class Hierarchy {
public:
Hierarchy(const llvm::RecordKeeper &Records) {
Hierarchy(llvm::RecordKeeper &Records) {
for (llvm::Record *T : Records.getAllDerivedDefinitions("NodeType"))
add(T);
for (llvm::Record *Derived : Records.getAllDerivedDefinitions("NodeType"))
Expand Down
5 changes: 2 additions & 3 deletions llvm/include/llvm/TableGen/DirectiveEmitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@ namespace llvm {
// DirectiveBase.td and provides helper methods for accessing it.
class DirectiveLanguage {
public:
explicit DirectiveLanguage(const llvm::RecordKeeper &Records)
: Records(Records) {
explicit DirectiveLanguage(llvm::RecordKeeper &Records) : Records(Records) {
const auto &DirectiveLanguages = getDirectiveLanguages();
Def = DirectiveLanguages[0];
}
Expand Down Expand Up @@ -71,7 +70,7 @@ class DirectiveLanguage {

private:
const llvm::Record *Def;
const llvm::RecordKeeper &Records;
llvm::RecordKeeper &Records;

std::vector<Record *> getDirectiveLanguages() const {
return Records.getAllDerivedDefinitions("DirectiveLanguage");
Expand Down
34 changes: 29 additions & 5 deletions llvm/include/llvm/TableGen/Record.h
Original file line number Diff line number Diff line change
Expand Up @@ -2057,19 +2057,28 @@ class RecordKeeper {
//===--------------------------------------------------------------------===//
// High-level helper methods, useful for tablegen backends.

// Non-const methods return std::vector<Record *> by value or reference.
// Const methods return std::vector<const Record *> by value or
// ArrayRef<const Record *>.

/// Get all the concrete records that inherit from the one specified
/// class. The class must be defined.
std::vector<Record *> getAllDerivedDefinitions(StringRef ClassName) const;
ArrayRef<const Record *> getAllDerivedDefinitions(StringRef ClassName) const;
const std::vector<Record *> &getAllDerivedDefinitions(StringRef ClassName);

/// Get all the concrete records that inherit from all the specified
/// classes. The classes must be defined.
std::vector<Record *> getAllDerivedDefinitions(
ArrayRef<StringRef> ClassNames) const;
std::vector<const Record *>
getAllDerivedDefinitions(ArrayRef<StringRef> ClassNames) const;
std::vector<Record *>
getAllDerivedDefinitions(ArrayRef<StringRef> ClassNames);

/// Get all the concrete records that inherit from specified class, if the
/// class is defined. Returns an empty vector if the class is not defined.
std::vector<Record *>
ArrayRef<const Record *>
getAllDerivedDefinitionsIfDefined(StringRef ClassName) const;
const std::vector<Record *> &
getAllDerivedDefinitionsIfDefined(StringRef ClassName);

void dump() const;

Expand All @@ -2081,9 +2090,24 @@ class RecordKeeper {
RecordKeeper &operator=(RecordKeeper &&) = delete;
RecordKeeper &operator=(const RecordKeeper &) = delete;

// Helper template functions for backend accessors.
template <typename VecTy>
const VecTy &
getAllDerivedDefinitionsImpl(StringRef ClassName,
std::map<std::string, VecTy> &Cache) const;

template <typename VecTy>
VecTy getAllDerivedDefinitionsImpl(ArrayRef<StringRef> ClassNames) const;

template <typename VecTy>
const VecTy &getAllDerivedDefinitionsIfDefinedImpl(
StringRef ClassName, std::map<std::string, VecTy> &Cache) const;

std::string InputFilename;
RecordMap Classes, Defs;
mutable StringMap<std::vector<Record *>> ClassRecordsMap;
mutable std::map<std::string, std::vector<const Record *>>
ClassRecordsMapConst;
mutable std::map<std::string, std::vector<Record *>> ClassRecordsMap;
GlobalMap ExtraGlobals;

// These members are for the phase timing feature. We need a timer group,
Expand Down
65 changes: 51 additions & 14 deletions llvm/lib/TableGen/Record.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3248,46 +3248,83 @@ void RecordKeeper::stopBackendTimer() {
}
}

std::vector<Record *>
RecordKeeper::getAllDerivedDefinitions(StringRef ClassName) const {
template <typename VecTy>
const VecTy &RecordKeeper::getAllDerivedDefinitionsImpl(
StringRef ClassName, std::map<std::string, VecTy> &Cache) const {
// We cache the record vectors for single classes. Many backends request
// the same vectors multiple times.
auto Pair = ClassRecordsMap.try_emplace(ClassName);
auto Pair = Cache.try_emplace(ClassName.str());
if (Pair.second)
Pair.first->second = getAllDerivedDefinitions(ArrayRef(ClassName));
Pair.first->second =
getAllDerivedDefinitionsImpl<VecTy>(ArrayRef(ClassName));

return Pair.first->second;
}

std::vector<Record *> RecordKeeper::getAllDerivedDefinitions(
template <typename VecTy>
VecTy RecordKeeper::getAllDerivedDefinitionsImpl(
ArrayRef<StringRef> ClassNames) const {
SmallVector<Record *, 2> ClassRecs;
std::vector<Record *> Defs;
SmallVector<const Record *, 2> ClassRecs;
VecTy Defs;

assert(ClassNames.size() > 0 && "At least one class must be passed.");
for (const auto &ClassName : ClassNames) {
Record *Class = getClass(ClassName);
const Record *Class = getClass(ClassName);
if (!Class)
PrintFatalError("The class '" + ClassName + "' is not defined\n");
ClassRecs.push_back(Class);
}

for (const auto &OneDef : getDefs()) {
if (all_of(ClassRecs, [&OneDef](const Record *Class) {
return OneDef.second->isSubClassOf(Class);
}))
return OneDef.second->isSubClassOf(Class);
}))
Defs.push_back(OneDef.second.get());
}

llvm::sort(Defs, LessRecord());

return Defs;
}

template <typename VecTy>
const VecTy &RecordKeeper::getAllDerivedDefinitionsIfDefinedImpl(
StringRef ClassName, std::map<std::string, VecTy> &Cache) const {
return getClass(ClassName)
? getAllDerivedDefinitionsImpl<VecTy>(ClassName, Cache)
: Cache[""];
}

ArrayRef<const Record *>
RecordKeeper::getAllDerivedDefinitions(StringRef ClassName) const {
return getAllDerivedDefinitionsImpl<std::vector<const Record *>>(
ClassName, ClassRecordsMapConst);
}

const std::vector<Record *> &
RecordKeeper::getAllDerivedDefinitions(StringRef ClassName) {
return getAllDerivedDefinitionsImpl<std::vector<Record *>>(ClassName,
ClassRecordsMap);
}

std::vector<const Record *>
RecordKeeper::getAllDerivedDefinitions(ArrayRef<StringRef> ClassNames) const {
return getAllDerivedDefinitionsImpl<std::vector<const Record *>>(ClassNames);
}

std::vector<Record *>
RecordKeeper::getAllDerivedDefinitions(ArrayRef<StringRef> ClassNames) {
return getAllDerivedDefinitionsImpl<std::vector<Record *>>(ClassNames);
}

ArrayRef<const Record *>
RecordKeeper::getAllDerivedDefinitionsIfDefined(StringRef ClassName) const {
return getClass(ClassName) ? getAllDerivedDefinitions(ClassName)
: std::vector<Record *>();
return getAllDerivedDefinitionsIfDefinedImpl<std::vector<const Record *>>(
ClassName, ClassRecordsMapConst);
}

const std::vector<Record *> &
RecordKeeper::getAllDerivedDefinitionsIfDefined(StringRef ClassName) {
return getAllDerivedDefinitionsIfDefinedImpl<std::vector<Record *>>(
ClassName, ClassRecordsMap);
}

void RecordKeeper::dumpAllocationStats(raw_ostream &OS) const {
Expand Down
2 changes: 1 addition & 1 deletion llvm/utils/TableGen/Basic/CodeGenIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ CodeGenIntrinsicContext::CodeGenIntrinsicContext(const RecordKeeper &RC) {
CodeGenIntrinsicTable::CodeGenIntrinsicTable(const RecordKeeper &RC) {
CodeGenIntrinsicContext Ctx(RC);

std::vector<Record *> Defs = RC.getAllDerivedDefinitions("Intrinsic");
ArrayRef<const Record *> Defs = RC.getAllDerivedDefinitions("Intrinsic");
Intrinsics.reserve(Defs.size());

for (const Record *Def : Defs)
Expand Down
2 changes: 1 addition & 1 deletion llvm/utils/TableGen/Common/SubtargetFeatureInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ LLVM_DUMP_METHOD void SubtargetFeatureInfo::dump() const {
#endif

std::vector<std::pair<Record *, SubtargetFeatureInfo>>
SubtargetFeatureInfo::getAll(const RecordKeeper &Records) {
SubtargetFeatureInfo::getAll(RecordKeeper &Records) {
std::vector<std::pair<Record *, SubtargetFeatureInfo>> SubtargetFeatures;
std::vector<Record *> AllPredicates =
Records.getAllDerivedDefinitions("Predicate");
Expand Down
2 changes: 1 addition & 1 deletion llvm/utils/TableGen/Common/SubtargetFeatureInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ struct SubtargetFeatureInfo {

void dump() const;
static std::vector<std::pair<Record *, SubtargetFeatureInfo>>
getAll(const RecordKeeper &Records);
getAll(RecordKeeper &Records);

/// Emit the subtarget feature flag definitions.
///
Expand Down
2 changes: 1 addition & 1 deletion llvm/utils/TableGen/ExegesisEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class ExegesisEmitter {
};

static std::map<llvm::StringRef, unsigned>
collectPfmCounters(const RecordKeeper &Records) {
collectPfmCounters(RecordKeeper &Records) {
std::map<llvm::StringRef, unsigned> PfmCounterNameTable;
const auto AddPfmCounterName = [&PfmCounterNameTable](
const Record *PfmCounterDef) {
Expand Down
2 changes: 1 addition & 1 deletion llvm/utils/TableGen/GlobalISelEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ class GlobalISelEmitter final : public GlobalISelMatchTableExecutorEmitter {
private:
std::string ClassName;

const RecordKeeper &RK;
RecordKeeper &RK;
const CodeGenDAGPatterns CGP;
const CodeGenTarget &Target;
CodeGenRegBank &CGRegs;
Expand Down
2 changes: 1 addition & 1 deletion llvm/utils/TableGen/SubtargetEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1545,7 +1545,7 @@ void SubtargetEmitter::EmitSchedModel(raw_ostream &OS) {
EmitProcessorModels(OS);
}

static void emitPredicateProlog(const RecordKeeper &Records, raw_ostream &OS) {
static void emitPredicateProlog(RecordKeeper &Records, raw_ostream &OS) {
std::string Buffer;
raw_string_ostream Stream(Buffer);

Expand Down
2 changes: 1 addition & 1 deletion llvm/utils/TableGen/TableGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ static void PrintEnums(RecordKeeper &Records, raw_ostream &OS) {
static void PrintSets(const RecordKeeper &Records, raw_ostream &OS) {
SetTheory Sets;
Sets.addFieldExpander("Set", "Elements");
for (Record *Rec : Records.getAllDerivedDefinitions("Set")) {
for (const Record *Rec : Records.getAllDerivedDefinitions("Set")) {
OS << Rec->getName() << " = [";
const std::vector<Record *> *Elts = Sets.expand(Rec);
assert(Elts && "Couldn't expand Set instance");
Expand Down
6 changes: 3 additions & 3 deletions mlir/include/mlir/TableGen/GenInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ class RecordKeeper;
namespace mlir {

/// Generator function to invoke.
using GenFunction = std::function<bool(const llvm::RecordKeeper &recordKeeper,
raw_ostream &os)>;
using GenFunction =
std::function<bool(llvm::RecordKeeper &recordKeeper, raw_ostream &os)>;

/// Structure to group information about a generator (argument to invoke via
/// mlir-tblgen, description, and generator function).
Expand All @@ -34,7 +34,7 @@ class GenInfo {
: arg(arg), description(description), generator(std::move(generator)) {}

/// Invokes the generator and returns whether the generator failed.
bool invoke(const llvm::RecordKeeper &recordKeeper, raw_ostream &os) const {
bool invoke(llvm::RecordKeeper &recordKeeper, raw_ostream &os) const {
assert(generator && "Cannot call generator with null generator");
return generator(recordKeeper, os);
}
Expand Down
28 changes: 14 additions & 14 deletions mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -690,10 +690,10 @@ class DefGenerator {
bool emitDefs(StringRef selectedDialect);

protected:
DefGenerator(std::vector<llvm::Record *> &&defs, raw_ostream &os,
DefGenerator(const std::vector<llvm::Record *> &defs, raw_ostream &os,
StringRef defType, StringRef valueType, bool isAttrGenerator)
: defRecords(std::move(defs)), os(os), defType(defType),
valueType(valueType), isAttrGenerator(isAttrGenerator) {
: defRecords(defs), os(os), defType(defType), valueType(valueType),
isAttrGenerator(isAttrGenerator) {
// Sort by occurrence in file.
llvm::sort(defRecords, [](llvm::Record *lhs, llvm::Record *rhs) {
return lhs->getID() < rhs->getID();
Expand Down Expand Up @@ -721,13 +721,13 @@ class DefGenerator {

/// A specialized generator for AttrDefs.
struct AttrDefGenerator : public DefGenerator {
AttrDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
AttrDefGenerator(llvm::RecordKeeper &records, raw_ostream &os)
: DefGenerator(records.getAllDerivedDefinitionsIfDefined("AttrDef"), os,
"Attr", "Attribute", /*isAttrGenerator=*/true) {}
};
/// A specialized generator for TypeDefs.
struct TypeDefGenerator : public DefGenerator {
TypeDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
TypeDefGenerator(llvm::RecordKeeper &records, raw_ostream &os)
: DefGenerator(records.getAllDerivedDefinitionsIfDefined("TypeDef"), os,
"Type", "Type", /*isAttrGenerator=*/false) {}
};
Expand Down Expand Up @@ -1029,7 +1029,7 @@ bool DefGenerator::emitDefs(StringRef selectedDialect) {

/// Find all type constraints for which a C++ function should be generated.
static std::vector<Constraint>
getAllTypeConstraints(const llvm::RecordKeeper &records) {
getAllTypeConstraints(llvm::RecordKeeper &records) {
std::vector<Constraint> result;
for (llvm::Record *def :
records.getAllDerivedDefinitionsIfDefined("TypeConstraint")) {
Expand All @@ -1046,7 +1046,7 @@ getAllTypeConstraints(const llvm::RecordKeeper &records) {
return result;
}

static void emitTypeConstraintDecls(const llvm::RecordKeeper &records,
static void emitTypeConstraintDecls(llvm::RecordKeeper &records,
raw_ostream &os) {
static const char *const typeConstraintDecl = R"(
bool {0}(::mlir::Type type);
Expand All @@ -1056,7 +1056,7 @@ bool {0}(::mlir::Type type);
os << strfmt(typeConstraintDecl, *constr.getCppFunctionName());
}

static void emitTypeConstraintDefs(const llvm::RecordKeeper &records,
static void emitTypeConstraintDefs(llvm::RecordKeeper &records,
raw_ostream &os) {
static const char *const typeConstraintDef = R"(
bool {0}(::mlir::Type type) {
Expand Down Expand Up @@ -1087,13 +1087,13 @@ static llvm::cl::opt<std::string>

static mlir::GenRegistration
genAttrDefs("gen-attrdef-defs", "Generate AttrDef definitions",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
[](llvm::RecordKeeper &records, raw_ostream &os) {
AttrDefGenerator generator(records, os);
return generator.emitDefs(attrDialect);
});
static mlir::GenRegistration
genAttrDecls("gen-attrdef-decls", "Generate AttrDef declarations",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
[](llvm::RecordKeeper &records, raw_ostream &os) {
AttrDefGenerator generator(records, os);
return generator.emitDecls(attrDialect);
});
Expand All @@ -1109,28 +1109,28 @@ static llvm::cl::opt<std::string>

static mlir::GenRegistration
genTypeDefs("gen-typedef-defs", "Generate TypeDef definitions",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
[](llvm::RecordKeeper &records, raw_ostream &os) {
TypeDefGenerator generator(records, os);
return generator.emitDefs(typeDialect);
});
static mlir::GenRegistration
genTypeDecls("gen-typedef-decls", "Generate TypeDef declarations",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
[](llvm::RecordKeeper &records, raw_ostream &os) {
TypeDefGenerator generator(records, os);
return generator.emitDecls(typeDialect);
});

static mlir::GenRegistration
genTypeConstrDefs("gen-type-constraint-defs",
"Generate type constraint definitions",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
[](llvm::RecordKeeper &records, raw_ostream &os) {
emitTypeConstraintDefs(records, os);
return false;
});
static mlir::GenRegistration
genTypeConstrDecls("gen-type-constraint-decls",
"Generate type constraint declarations",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
[](llvm::RecordKeeper &records, raw_ostream &os) {
emitTypeConstraintDecls(records, os);
return false;
});
Loading

0 comments on commit e05b9d7

Please sign in to comment.