Skip to content

Commit

Permalink
Revert "[MLIR][TableGen] Use const pointers for various Init objects (
Browse files Browse the repository at this point in the history
#112316)"

This reverts commit 1ae9fe5.
  • Loading branch information
joker-eph authored Oct 16, 2024
1 parent 7c5d5c0 commit de3695a
Show file tree
Hide file tree
Showing 14 changed files with 68 additions and 80 deletions.
2 changes: 1 addition & 1 deletion mlir/include/mlir/TableGen/AttrOrTypeDef.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class AttrOrTypeParameter {
std::optional<StringRef> getDefaultValue() const;

/// Return the underlying def of this parameter.
const llvm::Init *getDef() const;
llvm::Init *getDef() const;

/// The parameter is pointer-comparable.
bool operator==(const AttrOrTypeParameter &other) const {
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/TableGen/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class Dialect {
/// dialect.
bool usePropertiesForAttributes() const;

const llvm::DagInit *getDiscardableAttributes() const;
llvm::DagInit *getDiscardableAttributes() const;

const llvm::Record *getDef() const { return def; }

Expand Down
15 changes: 7 additions & 8 deletions mlir/include/mlir/TableGen/Operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,15 +119,14 @@ class Operator {

/// A utility iterator over a list of variable decorators.
struct VariableDecoratorIterator
: public llvm::mapped_iterator<const llvm::Init *const *,
VariableDecorator (*)(
const llvm::Init *)> {
: public llvm::mapped_iterator<llvm::Init *const *,
VariableDecorator (*)(llvm::Init *)> {
/// Initializes the iterator to the specified iterator.
VariableDecoratorIterator(const llvm::Init *const *it)
: llvm::mapped_iterator<const llvm::Init *const *,
VariableDecorator (*)(const llvm::Init *)>(
it, &unwrap) {}
static VariableDecorator unwrap(const llvm::Init *init);
VariableDecoratorIterator(llvm::Init *const *it)
: llvm::mapped_iterator<llvm::Init *const *,
VariableDecorator (*)(llvm::Init *)>(it,
&unwrap) {}
static VariableDecorator unwrap(llvm::Init *init);
};
using var_decorator_iterator = VariableDecoratorIterator;
using var_decorator_range = llvm::iterator_range<VariableDecoratorIterator>;
Expand Down
12 changes: 5 additions & 7 deletions mlir/lib/TableGen/AttrOrTypeDef.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ AttrOrTypeDef::AttrOrTypeDef(const llvm::Record *def) : def(def) {
auto *builderList =
dyn_cast_or_null<llvm::ListInit>(def->getValueInit("builders"));
if (builderList && !builderList->empty()) {
for (const llvm::Init *init : builderList->getValues()) {
for (llvm::Init *init : builderList->getValues()) {
AttrOrTypeBuilder builder(cast<llvm::DefInit>(init)->getDef(),
def->getLoc());

Expand All @@ -58,8 +58,8 @@ AttrOrTypeDef::AttrOrTypeDef(const llvm::Record *def) : def(def) {
if (auto *traitList = def->getValueAsListInit("traits")) {
SmallPtrSet<const llvm::Init *, 32> traitSet;
traits.reserve(traitSet.size());
llvm::unique_function<void(const llvm::ListInit *)> processTraitList =
[&](const llvm::ListInit *traitList) {
llvm::unique_function<void(llvm::ListInit *)> processTraitList =
[&](llvm::ListInit *traitList) {
for (auto *traitInit : *traitList) {
if (!traitSet.insert(traitInit).second)
continue;
Expand Down Expand Up @@ -335,9 +335,7 @@ std::optional<StringRef> AttrOrTypeParameter::getDefaultValue() const {
return result && !result->empty() ? result : std::nullopt;
}

const llvm::Init *AttrOrTypeParameter::getDef() const {
return def->getArg(index);
}
llvm::Init *AttrOrTypeParameter::getDef() const { return def->getArg(index); }

std::optional<Constraint> AttrOrTypeParameter::getConstraint() const {
if (auto *param = dyn_cast<llvm::DefInit>(getDef()))
Expand All @@ -351,7 +349,7 @@ std::optional<Constraint> AttrOrTypeParameter::getConstraint() const {
//===----------------------------------------------------------------------===//

bool AttributeSelfTypeParameter::classof(const AttrOrTypeParameter *param) {
const llvm::Init *paramDef = param->getDef();
llvm::Init *paramDef = param->getDef();
if (auto *paramDefInit = dyn_cast<llvm::DefInit>(paramDef))
return paramDefInit->getDef()->isSubClassOf("AttributeSelfTypeParameter");
return false;
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/TableGen/Attribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ StringRef Attribute::getDerivedCodeBody() const {
Dialect Attribute::getDialect() const {
const llvm::RecordVal *record = def->getValue("dialect");
if (record && record->getValue()) {
if (const DefInit *init = dyn_cast<DefInit>(record->getValue()))
if (DefInit *init = dyn_cast<DefInit>(record->getValue()))
return Dialect(init->getDef());
}
return Dialect(nullptr);
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/TableGen/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ bool Dialect::usePropertiesForAttributes() const {
return def->getValueAsBit("usePropertiesForAttributes");
}

const llvm::DagInit *Dialect::getDiscardableAttributes() const {
llvm::DagInit *Dialect::getDiscardableAttributes() const {
return def->getValueAsDag("discardableAttrs");
}

Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/TableGen/Interfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ using namespace mlir::tblgen;
//===----------------------------------------------------------------------===//

InterfaceMethod::InterfaceMethod(const llvm::Record *def) : def(def) {
const llvm::DagInit *args = def->getValueAsDag("arguments");
llvm::DagInit *args = def->getValueAsDag("arguments");
for (unsigned i = 0, e = args->getNumArgs(); i != e; ++i) {
arguments.push_back(
{llvm::cast<llvm::StringInit>(args->getArg(i))->getValue(),
Expand Down Expand Up @@ -78,7 +78,7 @@ Interface::Interface(const llvm::Record *def) : def(def) {

// Initialize the interface methods.
auto *listInit = dyn_cast<llvm::ListInit>(def->getValueInit("methods"));
for (const llvm::Init *init : listInit->getValues())
for (llvm::Init *init : listInit->getValues())
methods.emplace_back(cast<llvm::DefInit>(init)->getDef());

// Initialize the interface base classes.
Expand All @@ -98,7 +98,7 @@ Interface::Interface(const llvm::Record *def) : def(def) {
baseInterfaces.push_back(std::make_unique<Interface>(baseInterface));
basesAdded.insert(baseInterface.getName());
};
for (const llvm::Init *init : basesInit->getValues())
for (llvm::Init *init : basesInit->getValues())
addBaseInterfaceFn(Interface(cast<llvm::DefInit>(init)->getDef()));
}

Expand Down
21 changes: 10 additions & 11 deletions mlir/lib/TableGen/Operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ std::string Operator::getQualCppClassName() const {
StringRef Operator::getCppNamespace() const { return cppNamespace; }

int Operator::getNumResults() const {
const DagInit *results = def.getValueAsDag("results");
DagInit *results = def.getValueAsDag("results");
return results->getNumArgs();
}

Expand Down Expand Up @@ -198,12 +198,12 @@ auto Operator::getResults() const -> const_value_range {
}

TypeConstraint Operator::getResultTypeConstraint(int index) const {
const DagInit *results = def.getValueAsDag("results");
DagInit *results = def.getValueAsDag("results");
return TypeConstraint(cast<DefInit>(results->getArg(index)));
}

StringRef Operator::getResultName(int index) const {
const DagInit *results = def.getValueAsDag("results");
DagInit *results = def.getValueAsDag("results");
return results->getArgNameStr(index);
}

Expand Down Expand Up @@ -241,7 +241,7 @@ Operator::arg_range Operator::getArgs() const {
}

StringRef Operator::getArgName(int index) const {
const DagInit *argumentValues = def.getValueAsDag("arguments");
DagInit *argumentValues = def.getValueAsDag("arguments");
return argumentValues->getArgNameStr(index);
}

Expand Down Expand Up @@ -557,7 +557,7 @@ void Operator::populateOpStructure() {
auto *opVarClass = recordKeeper.getClass("OpVariable");
numNativeAttributes = 0;

const DagInit *argumentValues = def.getValueAsDag("arguments");
DagInit *argumentValues = def.getValueAsDag("arguments");
unsigned numArgs = argumentValues->getNumArgs();

// Mapping from name of to argument or result index. Arguments are indexed
Expand Down Expand Up @@ -721,8 +721,8 @@ void Operator::populateOpStructure() {
" to precede it in traits list");
};

std::function<void(const llvm::ListInit *)> insert;
insert = [&](const llvm::ListInit *traitList) {
std::function<void(llvm::ListInit *)> insert;
insert = [&](llvm::ListInit *traitList) {
for (auto *traitInit : *traitList) {
auto *def = cast<DefInit>(traitInit)->getDef();
if (def->isSubClassOf("TraitList")) {
Expand Down Expand Up @@ -780,7 +780,7 @@ void Operator::populateOpStructure() {
auto *builderList =
dyn_cast_or_null<llvm::ListInit>(def.getValueInit("builders"));
if (builderList && !builderList->empty()) {
for (const llvm::Init *init : builderList->getValues())
for (llvm::Init *init : builderList->getValues())
builders.emplace_back(cast<llvm::DefInit>(init)->getDef(), def.getLoc());
} else if (skipDefaultBuilders()) {
PrintFatalError(
Expand Down Expand Up @@ -818,8 +818,7 @@ bool Operator::hasAssemblyFormat() const {
}

StringRef Operator::getAssemblyFormat() const {
return TypeSwitch<const llvm::Init *, StringRef>(
def.getValueInit("assemblyFormat"))
return TypeSwitch<llvm::Init *, StringRef>(def.getValueInit("assemblyFormat"))
.Case<llvm::StringInit>([&](auto *init) { return init->getValue(); });
}

Expand All @@ -833,7 +832,7 @@ void Operator::print(llvm::raw_ostream &os) const {
}
}

auto Operator::VariableDecoratorIterator::unwrap(const llvm::Init *init)
auto Operator::VariableDecoratorIterator::unwrap(llvm::Init *init)
-> VariableDecorator {
return VariableDecorator(cast<llvm::DefInit>(init)->getDef());
}
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/TableGen/Pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -700,7 +700,7 @@ int Pattern::getBenefit() const {
// The initial benefit value is a heuristic with number of ops in the source
// pattern.
int initBenefit = getSourcePattern().getNumOps();
const llvm::DagInit *delta = def.getValueAsDag("benefitDelta");
llvm::DagInit *delta = def.getValueAsDag("benefitDelta");
if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) {
PrintFatalError(&def,
"The 'addBenefit' takes and only takes one integer value");
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/TableGen/Type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ std::optional<StringRef> TypeConstraint::getBuilderCall() const {
const llvm::RecordVal *builderCall = baseType->getValue("builderCall");
if (!builderCall || !builderCall->getValue())
return std::nullopt;
return TypeSwitch<const llvm::Init *, std::optional<StringRef>>(
return TypeSwitch<llvm::Init *, std::optional<StringRef>>(
builderCall->getValue())
.Case<llvm::StringInit>([&](auto *init) {
StringRef value = init->getValue();
Expand Down
16 changes: 8 additions & 8 deletions mlir/lib/Tools/mlir-tblgen/MlirTblgenMain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ enum DeprecatedAction { None, Warn, Error };
static DeprecatedAction actionOnDeprecatedValue;

// Returns if there is a use of `deprecatedInit` in `field`.
static bool findUse(const Init *field, const Init *deprecatedInit,
llvm::DenseMap<const Init *, bool> &known) {
static bool findUse(Init *field, Init *deprecatedInit,
llvm::DenseMap<Init *, bool> &known) {
if (field == deprecatedInit)
return true;

Expand Down Expand Up @@ -64,13 +64,13 @@ static bool findUse(const Init *field, const Init *deprecatedInit,
if (findUse(dagInit->getOperator(), deprecatedInit, known))
return memoize(true);

return memoize(llvm::any_of(dagInit->getArgs(), [&](const Init *arg) {
return memoize(llvm::any_of(dagInit->getArgs(), [&](Init *arg) {
return findUse(arg, deprecatedInit, known);
}));
}

if (const ListInit *li = dyn_cast<ListInit>(field)) {
return memoize(llvm::any_of(li->getValues(), [&](const Init *jt) {
if (ListInit *li = dyn_cast<ListInit>(field)) {
return memoize(llvm::any_of(li->getValues(), [&](Init *jt) {
return findUse(jt, deprecatedInit, known);
}));
}
Expand All @@ -83,8 +83,8 @@ static bool findUse(const Init *field, const Init *deprecatedInit,
}

// Returns if there is a use of `deprecatedInit` in `record`.
static bool findUse(Record &record, const Init *deprecatedInit,
llvm::DenseMap<const Init *, bool> &known) {
static bool findUse(Record &record, Init *deprecatedInit,
llvm::DenseMap<Init *, bool> &known) {
return llvm::any_of(record.getValues(), [&](const RecordVal &val) {
return findUse(val.getValue(), deprecatedInit, known);
});
Expand All @@ -100,7 +100,7 @@ static void warnOfDeprecatedUses(const RecordKeeper &records) {
if (!r || !r->getValue())
continue;

llvm::DenseMap<const Init *, bool> hasUse;
llvm::DenseMap<Init *, bool> hasUse;
if (auto *si = dyn_cast<StringInit>(r->getValue())) {
for (auto &jt : records.getDefs()) {
// Skip anonymous defs.
Expand Down
38 changes: 17 additions & 21 deletions mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,8 @@ class Generator {
private:
/// Emits parse calls to construct given kind.
void emitParseHelper(StringRef kind, StringRef returnType, StringRef builder,
ArrayRef<const Init *> args,
ArrayRef<std::string> argNames, StringRef failure,
mlir::raw_indented_ostream &ios);
ArrayRef<Init *> args, ArrayRef<std::string> argNames,
StringRef failure, mlir::raw_indented_ostream &ios);

/// Emits print instructions.
void emitPrintHelper(const Record *memberRec, StringRef kind,
Expand Down Expand Up @@ -136,12 +135,10 @@ void Generator::emitParse(StringRef kind, const Record &x) {
R"(static {0} read{1}(MLIRContext* context, DialectBytecodeReader &reader) )";
mlir::raw_indented_ostream os(output);
std::string returnType = getCType(&x);
os << formatv(head,
kind == "attribute" ? "::mlir::Attribute" : "::mlir::Type",
x.getName());
const DagInit *members = x.getValueAsDag("members");
SmallVector<std::string> argNames = llvm::to_vector(
map_range(members->getArgNames(), [](const StringInit *init) {
os << formatv(head, kind == "attribute" ? "::mlir::Attribute" : "::mlir::Type", x.getName());
DagInit *members = x.getValueAsDag("members");
SmallVector<std::string> argNames =
llvm::to_vector(map_range(members->getArgNames(), [](StringInit *init) {
return init->getAsUnquotedString();
}));
StringRef builder = x.getValueAsString("cBuilder").trim();
Expand All @@ -151,7 +148,7 @@ void Generator::emitParse(StringRef kind, const Record &x) {
}

void printParseConditional(mlir::raw_indented_ostream &ios,
ArrayRef<const Init *> args,
ArrayRef<Init *> args,
ArrayRef<std::string> argNames) {
ios << "if ";
auto parenScope = ios.scope("(", ") {");
Expand All @@ -162,7 +159,7 @@ void printParseConditional(mlir::raw_indented_ostream &ios,
};

auto parsedArgs =
llvm::to_vector(make_filter_range(args, [](const Init *const attr) {
llvm::to_vector(make_filter_range(args, [](Init *const attr) {
const Record *def = cast<DefInit>(attr)->getDef();
if (def->isSubClassOf("Array"))
return true;
Expand All @@ -171,7 +168,7 @@ void printParseConditional(mlir::raw_indented_ostream &ios,

interleave(
zip(parsedArgs, argNames),
[&](std::tuple<const Init *&, const std::string &> it) {
[&](std::tuple<llvm::Init *&, const std::string &> it) {
const Record *attr = cast<DefInit>(std::get<0>(it))->getDef();
std::string parser;
if (auto optParser = attr->getValueAsOptionalString("cParser")) {
Expand Down Expand Up @@ -199,7 +196,7 @@ void printParseConditional(mlir::raw_indented_ostream &ios,
}

void Generator::emitParseHelper(StringRef kind, StringRef returnType,
StringRef builder, ArrayRef<const Init *> args,
StringRef builder, ArrayRef<Init *> args,
ArrayRef<std::string> argNames,
StringRef failure,
mlir::raw_indented_ostream &ios) {
Expand All @@ -213,7 +210,7 @@ void Generator::emitParseHelper(StringRef kind, StringRef returnType,
// Print decls.
std::string lastCType = "";
for (auto [arg, name] : zip(args, argNames)) {
const DefInit *first = dyn_cast<DefInit>(arg);
DefInit *first = dyn_cast<DefInit>(arg);
if (!first)
PrintFatalError("Unexpected type for " + name);
const Record *def = first->getDef();
Expand Down Expand Up @@ -254,14 +251,13 @@ void Generator::emitParseHelper(StringRef kind, StringRef returnType,
std::string returnType = getCType(def);
ios << "auto " << listHelperName(name) << " = [&]() -> FailureOr<"
<< returnType << "> ";
SmallVector<const Init *> args;
SmallVector<Init *> args;
SmallVector<std::string> argNames;
if (def->isSubClassOf("CompositeBytecode")) {
const DagInit *members = def->getValueAsDag("members");
args = llvm::to_vector(map_range(
members->getArgs(), [](Init *init) { return (const Init *)init; }));
DagInit *members = def->getValueAsDag("members");
args = llvm::to_vector(members->getArgs());
argNames = llvm::to_vector(
map_range(members->getArgNames(), [](const StringInit *init) {
map_range(members->getArgNames(), [](StringInit *init) {
return init->getAsUnquotedString();
}));
} else {
Expand Down Expand Up @@ -336,7 +332,7 @@ void Generator::emitPrint(StringRef kind, StringRef type,
auto *members = rec->getValueAsDag("members");
for (auto [arg, name] :
llvm::zip(members->getArgs(), members->getArgNames())) {
const DefInit *def = dyn_cast<DefInit>(arg);
DefInit *def = dyn_cast<DefInit>(arg);
assert(def);
const Record *memberRec = def->getDef();
emitPrintHelper(memberRec, kind, kind, name->getAsUnquotedString(), os);
Expand Down Expand Up @@ -389,7 +385,7 @@ void Generator::emitPrintHelper(const Record *memberRec, StringRef kind,
auto *members = memberRec->getValueAsDag("members");
for (auto [arg, argName] :
zip(members->getArgs(), members->getArgNames())) {
const DefInit *def = dyn_cast<DefInit>(arg);
DefInit *def = dyn_cast<DefInit>(arg);
assert(def);
emitPrintHelper(def->getDef(), kind, parent,
argName->getAsUnquotedString(), ios);
Expand Down
Loading

0 comments on commit de3695a

Please sign in to comment.