Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir] Add support for parsing nested PassPipelineOptions #101118

Merged
merged 4 commits into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 35 additions & 10 deletions mlir/include/mlir/Pass/PassOptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,25 @@ class PassOptions : protected llvm::cl::SubCommand {
}
};

/// This is the parser that is used by pass options that wrap PassOptions
/// instances. Like GenericOptionParser, this is a thin wrapper around
/// llvm::cl::basic_parser.
template <typename PassOptionsT>
struct PassOptionsParser : public llvm::cl::basic_parser<PassOptionsT> {
using llvm::cl::basic_parser<PassOptionsT>::basic_parser;
// Parse the options object by delegating to
// `PassOptionsT::parseFromString`.
bool parse(llvm::cl::Option &, StringRef, StringRef arg,
PassOptionsT &value) {
return failed(value.parseFromString(arg));
}

// Print the options object by delegating to `PassOptionsT::print`.
static void print(llvm::raw_ostream &os, const PassOptionsT &value) {
value.print(os);
}
};

/// Utility methods for printing option values.
template <typename DataT>
static void printValue(raw_ostream &os, GenericOptionParser<DataT> &parser,
Expand All @@ -154,19 +173,24 @@ class PassOptions : protected llvm::cl::SubCommand {
}

public:
/// The specific parser to use depending on llvm::cl parser used. This is only
/// necessary because we need to provide additional methods for certain data
/// type parsers.
/// TODO: We should upstream the methods in GenericOptionParser to avoid the
/// need to do this.
/// The specific parser to use. This is necessary because we need to provide
/// additional methods for certain data type parsers.
template <typename DataType>
using OptionParser =
using OptionParser = std::conditional_t<
// If the data type is derived from PassOptions, use the
// PassOptionsParser.
std::is_base_of_v<PassOptions, DataType>, PassOptionsParser<DataType>,
// Otherwise, use GenericOptionParser where it is well formed, and fall
// back to llvm::cl::parser otherwise.
// TODO: We should upstream the methods in GenericOptionParser to avoid
// the need to do this.
std::conditional_t<std::is_base_of<llvm::cl::generic_parser_base,
llvm::cl::parser<DataType>>::value,
GenericOptionParser<DataType>,
llvm::cl::parser<DataType>>;
llvm::cl::parser<DataType>>>;

/// This class represents a specific pass option, with a provided data type.
/// This class represents a specific pass option, with a provided
/// data type.
template <typename DataType, typename OptionParser = OptionParser<DataType>>
class Option
: public llvm::cl::opt<DataType, /*ExternalStorage=*/false, OptionParser>,
Expand Down Expand Up @@ -278,11 +302,12 @@ class PassOptions : protected llvm::cl::SubCommand {
if ((**this).empty())
return;

os << this->ArgStr << '=';
os << this->ArgStr << "={";
auto printElementFn = [&](const DataType &value) {
printValue(os, this->getParser(), value);
};
llvm::interleave(*this, os, printElementFn, ",");
os << "}";
}

/// Copy the value from the given option into this one.
Expand Down Expand Up @@ -311,7 +336,7 @@ class PassOptions : protected llvm::cl::SubCommand {

/// Print the options held by this struct in a form that can be parsed via
/// 'parseFromString'.
void print(raw_ostream &os);
void print(raw_ostream &os) const;

/// Print the help string for the options held by this struct. `descIndent` is
/// the indent that the descriptions should be aligned.
Expand Down
64 changes: 36 additions & 28 deletions mlir/lib/Pass/PassRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "mlir/Pass/PassManager.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/MemoryBuffer.h"
Expand Down Expand Up @@ -159,6 +160,31 @@ const PassPipelineInfo *mlir::PassPipelineInfo::lookup(StringRef pipelineArg) {
// PassOptions
//===----------------------------------------------------------------------===//

/// Extract an argument from 'options' and update it to point after the arg.
/// Returns the cleaned argument string.
static StringRef extractArgAndUpdateOptions(StringRef &options,
size_t argSize) {
StringRef str = options.take_front(argSize).trim();
options = options.drop_front(argSize).ltrim();

// Early exit if there's no escape sequence.
if (str.size() <= 2)
return str;

const auto escapePairs = {std::make_pair('\'', '\''),
std::make_pair('"', '"'), std::make_pair('{', '}')};
for (const auto &escape : escapePairs) {
if (str.front() == escape.first && str.back() == escape.second) {
// Drop the escape characters and trim.
str = str.drop_front().drop_back().trim();
// Don't process additional escape sequences.
break;
}
}

return str;
}

LogicalResult detail::pass_options::parseCommaSeparatedList(
llvm::cl::Option &opt, StringRef argName, StringRef optionStr,
function_ref<LogicalResult(StringRef)> elementParseFn) {
Expand Down Expand Up @@ -187,13 +213,16 @@ LogicalResult detail::pass_options::parseCommaSeparatedList(
size_t nextElePos = findChar(optionStr, 0, ',');
while (nextElePos != StringRef::npos) {
// Process the portion before the comma.
if (failed(elementParseFn(optionStr.substr(0, nextElePos))))
if (failed(
elementParseFn(extractArgAndUpdateOptions(optionStr, nextElePos))))
return failure();

optionStr = optionStr.substr(nextElePos + 1);
// Drop the leading ','
optionStr = optionStr.drop_front();
nextElePos = findChar(optionStr, 0, ',');
}
return elementParseFn(optionStr.substr(0, nextElePos));
return elementParseFn(
extractArgAndUpdateOptions(optionStr, optionStr.size()));
}

/// Out of line virtual function to provide home for the class.
Expand All @@ -213,27 +242,6 @@ void detail::PassOptions::copyOptionValuesFrom(const PassOptions &other) {
/// `options` string pointing after the parsed option].
static std::tuple<StringRef, StringRef, StringRef>
parseNextArg(StringRef options) {
// Functor used to extract an argument from 'options' and update it to point
// after the arg.
auto extractArgAndUpdateOptions = [&](size_t argSize) {
StringRef str = options.take_front(argSize).trim();
options = options.drop_front(argSize).ltrim();
// Handle escape sequences
if (str.size() > 2) {
const auto escapePairs = {std::make_pair('\'', '\''),
std::make_pair('"', '"'),
std::make_pair('{', '}')};
for (const auto &escape : escapePairs) {
if (str.front() == escape.first && str.back() == escape.second) {
// Drop the escape characters and trim.
str = str.drop_front().drop_back().trim();
// Don't process additional escape sequences.
break;
}
}
}
return str;
};
// Try to process the given punctuation, properly escaping any contained
// characters.
auto tryProcessPunct = [&](size_t &currentPos, char punct) {
Expand All @@ -250,13 +258,13 @@ parseNextArg(StringRef options) {
for (size_t argEndIt = 0, optionsE = options.size();; ++argEndIt) {
// Check for the end of the full option.
if (argEndIt == optionsE || options[argEndIt] == ' ') {
argName = extractArgAndUpdateOptions(argEndIt);
argName = extractArgAndUpdateOptions(options, argEndIt);
return std::make_tuple(argName, StringRef(), options);
}

// Check for the end of the name and the start of the value.
if (options[argEndIt] == '=') {
argName = extractArgAndUpdateOptions(argEndIt);
argName = extractArgAndUpdateOptions(options, argEndIt);
options = options.drop_front();
break;
}
Expand All @@ -266,7 +274,7 @@ parseNextArg(StringRef options) {
for (size_t argEndIt = 0, optionsE = options.size();; ++argEndIt) {
// Handle the end of the options string.
if (argEndIt == optionsE || options[argEndIt] == ' ') {
StringRef value = extractArgAndUpdateOptions(argEndIt);
StringRef value = extractArgAndUpdateOptions(options, argEndIt);
return std::make_tuple(argName, value, options);
}

Expand Down Expand Up @@ -318,7 +326,7 @@ LogicalResult detail::PassOptions::parseFromString(StringRef options,

/// Print the options held by this struct in a form that can be parsed via
/// 'parseFromString'.
void detail::PassOptions::print(raw_ostream &os) {
void detail::PassOptions::print(raw_ostream &os) const {
// If there are no options, there is nothing left to do.
if (OptionsMap.empty())
return;
Expand Down
15 changes: 9 additions & 6 deletions mlir/test/Pass/pipeline-options-parsing.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,19 @@
// RUN: mlir-opt %s -verify-each=false -pass-pipeline='builtin.module(builtin.module(func.func(test-options-pass{list=3}), func.func(test-options-pass{enum=one list=1,2,3,4 string="foo bar baz"})))' -dump-pass-pipeline 2>&1 | FileCheck --check-prefix=CHECK_5 %s
// RUN: mlir-opt %s -verify-each=false -pass-pipeline='builtin.module(builtin.module(func.func(test-options-pass{list=3}), func.func(test-options-pass{enum=one list=1,2,3,4 string={foo bar baz}})))' -dump-pass-pipeline 2>&1 | FileCheck --check-prefix=CHECK_5 %s
// RUN: mlir-opt %s -verify-each=false -pass-pipeline='builtin.module(builtin.module(func.func(test-options-pass{list=3}), func.func(test-options-pass{enum=one list=1,2,3,4 string=foo"bar"baz})))' -dump-pass-pipeline 2>&1 | FileCheck --check-prefix=CHECK_6 %s
// RUN: mlir-opt %s -verify-each=false '-test-options-super-pass-pipeline=super-list={{enum=zero list=1 string=foo},{enum=one list=2 string="bar"},{enum=two list=3 string={baz}}}' -dump-pass-pipeline 2>&1 | FileCheck --check-prefix=CHECK_7 %s
// RUN: mlir-opt %s -verify-each=false -pass-pipeline='builtin.module(func.func(test-options-super-pass{list={{enum=zero list={1} string=foo },{enum=one list={2} string=bar },{enum=two list={3} string=baz }}}))' -dump-pass-pipeline 2>&1 | FileCheck --check-prefix=CHECK_7 %s

// CHECK_ERROR_1: missing closing '}' while processing pass options
// CHECK_ERROR_2: no such option test-option
// CHECK_ERROR_3: no such option invalid-option
// CHECK_ERROR_4: 'notaninteger' value invalid for integer argument
// CHECK_ERROR_5: for the --enum option: Cannot find option named 'invalid'!

// CHECK_1: test-options-pass{enum=zero list=1,2,3,4,5 string=nested_pipeline{arg1=10 arg2=" {} " arg3=true} string-list=a,b,c,d}
// CHECK_2: test-options-pass{enum=one list=1 string= string-list=a,b}
// CHECK_3: builtin.module(builtin.module(func.func(test-options-pass{enum=zero list=3 string= }),func.func(test-options-pass{enum=one list=1,2,3,4 string= })))
// CHECK_4: builtin.module(builtin.module(func.func(test-options-pass{enum=zero list=3 string= }),func.func(test-options-pass{enum=one list=1,2,3,4 string=foobar })))
// CHECK_5: builtin.module(builtin.module(func.func(test-options-pass{enum=zero list=3 string= }),func.func(test-options-pass{enum=one list=1,2,3,4 string={foo bar baz} })))
// CHECK_6: builtin.module(builtin.module(func.func(test-options-pass{enum=zero list=3 string= }),func.func(test-options-pass{enum=one list=1,2,3,4 string=foo"bar"baz })))
// CHECK_1: test-options-pass{enum=zero list={1,2,3,4,5} string=nested_pipeline{arg1=10 arg2=" {} " arg3=true} string-list={a,b,c,d}}
// CHECK_2: test-options-pass{enum=one list={1} string= string-list={a,b}}
// CHECK_3: builtin.module(builtin.module(func.func(test-options-pass{enum=zero list={3} string= }),func.func(test-options-pass{enum=one list={1,2,3,4} string= })))
// CHECK_4: builtin.module(builtin.module(func.func(test-options-pass{enum=zero list={3} string= }),func.func(test-options-pass{enum=one list={1,2,3,4} string=foobar })))
// CHECK_5: builtin.module(builtin.module(func.func(test-options-pass{enum=zero list={3} string= }),func.func(test-options-pass{enum=one list={1,2,3,4} string={foo bar baz} })))
// CHECK_6: builtin.module(builtin.module(func.func(test-options-pass{enum=zero list={3} string= }),func.func(test-options-pass{enum=one list={1,2,3,4} string=foo"bar"baz })))
// CHECK_7{LITERAL}: builtin.module(func.func(test-options-super-pass{list={{enum=zero list={1} string=foo },{enum=one list={2} string=bar },{enum=two list={3} string=baz }}}))
54 changes: 51 additions & 3 deletions mlir/test/lib/Pass/TestPassManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ struct TestOptionsPass
: public PassWrapper<TestOptionsPass, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOptionsPass)

enum Enum { One, Two };
enum Enum { Zero, One, Two };

struct Options : public PassPipelineOptions<Options> {
ListOption<int> listOption{*this, "list",
Expand All @@ -66,7 +66,15 @@ struct TestOptionsPass
Option<Enum> enumOption{
*this, "enum", llvm::cl::desc("Example enum option"),
llvm::cl::values(clEnumValN(0, "zero", "Example zero value"),
clEnumValN(1, "one", "Example one value"))};
clEnumValN(1, "one", "Example one value"),
clEnumValN(2, "two", "Example two value"))};

Options() = default;
Options(const Options &rhs) { *this = rhs; }
Options &operator=(const Options &rhs) {
copyOptionValuesFrom(rhs);
return *this;
}
Comment on lines +73 to +77
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add these to the PassPipelineOptions base class?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ListOption has an implicitly deleted copy constructor so we're forced to add this boilerplate to the derived classes.

};
TestOptionsPass() = default;
TestOptionsPass(const TestOptionsPass &) : PassWrapper() {}
Expand All @@ -92,7 +100,37 @@ struct TestOptionsPass
Option<Enum> enumOption{
*this, "enum", llvm::cl::desc("Example enum option"),
llvm::cl::values(clEnumValN(0, "zero", "Example zero value"),
clEnumValN(1, "one", "Example one value"))};
clEnumValN(1, "one", "Example one value"),
clEnumValN(2, "two", "Example two value"))};
};

struct TestOptionsSuperPass
: public PassWrapper<TestOptionsSuperPass, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOptionsSuperPass)

struct Options : public PassPipelineOptions<Options> {
ListOption<TestOptionsPass::Options> listOption{
*this, "super-list",
llvm::cl::desc("Example list of PassPipelineOptions option")};

Options() = default;
};

TestOptionsSuperPass() = default;
TestOptionsSuperPass(const TestOptionsSuperPass &) : PassWrapper() {}
TestOptionsSuperPass(const Options &options) {
listOption = options.listOption;
}

void runOnOperation() final {}
StringRef getArgument() const final { return "test-options-super-pass"; }
StringRef getDescription() const final {
return "Test options of options parsing capabilities";
}

ListOption<TestOptionsPass::Options> listOption{
*this, "list",
llvm::cl::desc("Example list of PassPipelineOptions option")};
};

/// A test pass that always aborts to enable testing the crash recovery
Expand Down Expand Up @@ -220,6 +258,7 @@ static void testNestedPipelineTextual(OpPassManager &pm) {
namespace mlir {
void registerPassManagerTestPass() {
PassRegistration<TestOptionsPass>();
PassRegistration<TestOptionsSuperPass>();

PassRegistration<TestModulePass>();

Expand Down Expand Up @@ -248,5 +287,14 @@ void registerPassManagerTestPass() {
[](OpPassManager &pm, const TestOptionsPass::Options &options) {
pm.addPass(std::make_unique<TestOptionsPass>(options));
});

PassPipelineRegistration<TestOptionsSuperPass::Options>
registerOptionsSuperPassPipeline(
"test-options-super-pass-pipeline",
"Parses options of PassPipelineOptions using pass pipeline "
"registration",
[](OpPassManager &pm, const TestOptionsSuperPass::Options &options) {
pm.addPass(std::make_unique<TestOptionsSuperPass>(options));
});
}
} // namespace mlir
Loading