From 5c599cb6dfb6fbf5b081b9df164e3bfe869fe824 Mon Sep 17 00:00:00 2001 From: Nikhil Kalra Date: Mon, 29 Jul 2024 14:32:50 -0700 Subject: [PATCH 1/3] add parsing impl --- mlir/include/mlir/Pass/PassOptions.h | 52 +++++++++++++++++++++++----- mlir/lib/Pass/PassRegistry.cpp | 2 +- 2 files changed, 45 insertions(+), 9 deletions(-) diff --git a/mlir/include/mlir/Pass/PassOptions.h b/mlir/include/mlir/Pass/PassOptions.h index 6bffa84f7b16b43..7a090f6b3c2df33 100644 --- a/mlir/include/mlir/Pass/PassOptions.h +++ b/mlir/include/mlir/Pass/PassOptions.h @@ -139,6 +139,26 @@ class PassOptions : protected llvm::cl::SubCommand { } }; + /// This is the parser that is used by pass options that wrap PassOptions + /// instances. Like GenericOptionParser, this is a thin wrapper around + /// llvm::cl::basic_parser. + template > * = 0> + struct PassOptionsParser + : public llvm::cl::basic_parser> { + // Parse the options object by delegating to + // `PassOptionsT::parseFromString`. + bool parse(llvm::cl::Option &, StringRef, StringRef arg, + PassOptionsT &value) { + return succeeded(value.parseFromString(arg)); + } + + // Print the options object by delegating to `PassOptionsT::print`. + static void print(llvm::raw_ostream &os, const PassOptionsT &value) { + value.print(os); + } + }; + /// Utility methods for printing option values. template static void printValue(raw_ostream &os, GenericOptionParser &parser, @@ -153,20 +173,36 @@ class PassOptions : protected llvm::cl::SubCommand { detail::pass_options::printOptionValue(os, value); } -public: /// The specific parser to use depending on llvm::cl parser used. This is only /// necessary because we need to provide additional methods for certain data /// type parsers. /// TODO: We should upstream the methods in GenericOptionParser to avoid the /// need to do this. + + template + struct OptionParserFrom { + using type = llvm::cl::parser; + }; + + template + struct OptionParserFrom< + DataType, + typename std::enable_if_t>>> { + using type = GenericOptionParser; + }; + + template + struct OptionParserFrom{})> { + using type = PassOptionsParser; + }; + +public: template - using OptionParser = - std::conditional_t>::value, - GenericOptionParser, - llvm::cl::parser>; + using OptionParser = typename OptionParserFrom::type; - /// This class represents a specific pass option, with a provided data type. + /// This class represents a specific pass option, with a provided + /// data type. template > class Option : public llvm::cl::opt, @@ -311,7 +347,7 @@ class PassOptions : protected llvm::cl::SubCommand { /// Print the options held by this struct in a form that can be parsed via /// 'parseFromString'. - void print(raw_ostream &os); + void print(raw_ostream &os) const; /// Print the help string for the options held by this struct. `descIndent` is /// the indent that the descriptions should be aligned. diff --git a/mlir/lib/Pass/PassRegistry.cpp b/mlir/lib/Pass/PassRegistry.cpp index 483cbe80faba6a7..1914985911b4d62 100644 --- a/mlir/lib/Pass/PassRegistry.cpp +++ b/mlir/lib/Pass/PassRegistry.cpp @@ -318,7 +318,7 @@ LogicalResult detail::PassOptions::parseFromString(StringRef options, /// Print the options held by this struct in a form that can be parsed via /// 'parseFromString'. -void detail::PassOptions::print(raw_ostream &os) { +void detail::PassOptions::print(raw_ostream &os) const { // If there are no options, there is nothing left to do. if (OptionsMap.empty()) return; From 1133ee9de2722c59b062ddd272e614bcc24562ca Mon Sep 17 00:00:00 2001 From: Nikhil Kalra Date: Mon, 29 Jul 2024 20:45:41 -0700 Subject: [PATCH 2/3] tests --- mlir/include/mlir/Pass/PassOptions.h | 44 ++++++-------- mlir/lib/Pass/PassRegistry.cpp | 61 +++++++++++--------- mlir/test/Pass/pipeline-options-parsing.mlir | 15 +++-- mlir/test/lib/Pass/TestPassManager.cpp | 54 ++++++++++++++++- 4 files changed, 110 insertions(+), 64 deletions(-) diff --git a/mlir/include/mlir/Pass/PassOptions.h b/mlir/include/mlir/Pass/PassOptions.h index 7a090f6b3c2df33..5b844f4a1cc127a 100644 --- a/mlir/include/mlir/Pass/PassOptions.h +++ b/mlir/include/mlir/Pass/PassOptions.h @@ -142,15 +142,14 @@ class PassOptions : protected llvm::cl::SubCommand { /// This is the parser that is used by pass options that wrap PassOptions /// instances. Like GenericOptionParser, this is a thin wrapper around /// llvm::cl::basic_parser. - template > * = 0> - struct PassOptionsParser - : public llvm::cl::basic_parser> { + template + struct PassOptionsParser : public llvm::cl::basic_parser { + using llvm::cl::basic_parser::basic_parser; // Parse the options object by delegating to // `PassOptionsT::parseFromString`. bool parse(llvm::cl::Option &, StringRef, StringRef arg, PassOptionsT &value) { - return succeeded(value.parseFromString(arg)); + return failed(value.parseFromString(arg)); } // Print the options object by delegating to `PassOptionsT::print`. @@ -173,33 +172,23 @@ class PassOptions : protected llvm::cl::SubCommand { detail::pass_options::printOptionValue(os, value); } +public: /// The specific parser to use depending on llvm::cl parser used. This is only /// necessary because we need to provide additional methods for certain data /// type parsers. /// TODO: We should upstream the methods in GenericOptionParser to avoid the /// need to do this. - - template - struct OptionParserFrom { - using type = llvm::cl::parser; - }; - - template - struct OptionParserFrom< - DataType, - typename std::enable_if_t>>> { - using type = GenericOptionParser; - }; - - template - struct OptionParserFrom{})> { - using type = PassOptionsParser; - }; - -public: template - using OptionParser = typename OptionParserFrom::type; + using OptionParser = std::conditional_t< + // If the data type is derived from PassOptions, use the + // PassOptionsParser. + std::is_base_of_v, PassOptionsParser, + // Otherwise, use GenericOptionParser where it is well formed, and fall + // back to llvm::cl::parser otherwise. + std::conditional_t>::value, + GenericOptionParser, + llvm::cl::parser>>; /// This class represents a specific pass option, with a provided /// data type. @@ -314,11 +303,12 @@ class PassOptions : protected llvm::cl::SubCommand { if ((**this).empty()) return; - os << this->ArgStr << '='; + os << this->ArgStr << "={"; auto printElementFn = [&](const DataType &value) { printValue(os, this->getParser(), value); }; llvm::interleave(*this, os, printElementFn, ","); + os << "}"; } /// Copy the value from the given option into this one. diff --git a/mlir/lib/Pass/PassRegistry.cpp b/mlir/lib/Pass/PassRegistry.cpp index 1914985911b4d62..387c5340bcc8a04 100644 --- a/mlir/lib/Pass/PassRegistry.cpp +++ b/mlir/lib/Pass/PassRegistry.cpp @@ -12,6 +12,7 @@ #include "mlir/Pass/PassManager.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/ScopeExit.h" +#include "llvm/ADT/StringRef.h" #include "llvm/Support/Format.h" #include "llvm/Support/ManagedStatic.h" #include "llvm/Support/MemoryBuffer.h" @@ -159,13 +160,36 @@ const PassPipelineInfo *mlir::PassPipelineInfo::lookup(StringRef pipelineArg) { // PassOptions //===----------------------------------------------------------------------===// +/// Extract an argument from 'options' and update it to point after the arg. +/// Returns the cleaned argument string. +static StringRef extractArgAndUpdateOptions(StringRef &options, + size_t argSize) { + StringRef str = options.take_front(argSize).trim(); + options = options.drop_front(argSize).ltrim(); + // Handle escape sequences + if (str.size() > 2) { + const auto escapePairs = {std::make_pair('\'', '\''), + std::make_pair('"', '"'), + std::make_pair('{', '}')}; + for (const auto &escape : escapePairs) { + if (str.front() == escape.first && str.back() == escape.second) { + // Drop the escape characters and trim. + str = str.drop_front().drop_back().trim(); + // Don't process additional escape sequences. + break; + } + } + } + return str; +} + LogicalResult detail::pass_options::parseCommaSeparatedList( llvm::cl::Option &opt, StringRef argName, StringRef optionStr, function_ref elementParseFn) { // Functor used for finding a character in a string, and skipping over // various "range" characters. llvm::unique_function findChar = - [&](StringRef str, size_t index, char c) -> size_t { + [&findChar](StringRef str, size_t index, char c) -> size_t { for (size_t i = index, e = str.size(); i < e; ++i) { if (str[i] == c) return i; @@ -187,13 +211,15 @@ LogicalResult detail::pass_options::parseCommaSeparatedList( size_t nextElePos = findChar(optionStr, 0, ','); while (nextElePos != StringRef::npos) { // Process the portion before the comma. - if (failed(elementParseFn(optionStr.substr(0, nextElePos)))) + if (failed( + elementParseFn(extractArgAndUpdateOptions(optionStr, nextElePos)))) return failure(); - optionStr = optionStr.substr(nextElePos + 1); + optionStr = optionStr.drop_front(); // drop the leading ',' nextElePos = findChar(optionStr, 0, ','); } - return elementParseFn(optionStr.substr(0, nextElePos)); + return elementParseFn( + extractArgAndUpdateOptions(optionStr, optionStr.size())); } /// Out of line virtual function to provide home for the class. @@ -213,27 +239,6 @@ void detail::PassOptions::copyOptionValuesFrom(const PassOptions &other) { /// `options` string pointing after the parsed option]. static std::tuple parseNextArg(StringRef options) { - // Functor used to extract an argument from 'options' and update it to point - // after the arg. - auto extractArgAndUpdateOptions = [&](size_t argSize) { - StringRef str = options.take_front(argSize).trim(); - options = options.drop_front(argSize).ltrim(); - // Handle escape sequences - if (str.size() > 2) { - const auto escapePairs = {std::make_pair('\'', '\''), - std::make_pair('"', '"'), - std::make_pair('{', '}')}; - for (const auto &escape : escapePairs) { - if (str.front() == escape.first && str.back() == escape.second) { - // Drop the escape characters and trim. - str = str.drop_front().drop_back().trim(); - // Don't process additional escape sequences. - break; - } - } - } - return str; - }; // Try to process the given punctuation, properly escaping any contained // characters. auto tryProcessPunct = [&](size_t ¤tPos, char punct) { @@ -250,13 +255,13 @@ parseNextArg(StringRef options) { for (size_t argEndIt = 0, optionsE = options.size();; ++argEndIt) { // Check for the end of the full option. if (argEndIt == optionsE || options[argEndIt] == ' ') { - argName = extractArgAndUpdateOptions(argEndIt); + argName = extractArgAndUpdateOptions(options, argEndIt); return std::make_tuple(argName, StringRef(), options); } // Check for the end of the name and the start of the value. if (options[argEndIt] == '=') { - argName = extractArgAndUpdateOptions(argEndIt); + argName = extractArgAndUpdateOptions(options, argEndIt); options = options.drop_front(); break; } @@ -266,7 +271,7 @@ parseNextArg(StringRef options) { for (size_t argEndIt = 0, optionsE = options.size();; ++argEndIt) { // Handle the end of the options string. if (argEndIt == optionsE || options[argEndIt] == ' ') { - StringRef value = extractArgAndUpdateOptions(argEndIt); + StringRef value = extractArgAndUpdateOptions(options, argEndIt); return std::make_tuple(argName, value, options); } diff --git a/mlir/test/Pass/pipeline-options-parsing.mlir b/mlir/test/Pass/pipeline-options-parsing.mlir index 50f08241ee5cfac..b6c2b688b7cfb3e 100644 --- a/mlir/test/Pass/pipeline-options-parsing.mlir +++ b/mlir/test/Pass/pipeline-options-parsing.mlir @@ -11,6 +11,8 @@ // RUN: mlir-opt %s -verify-each=false -pass-pipeline='builtin.module(builtin.module(func.func(test-options-pass{list=3}), func.func(test-options-pass{enum=one list=1,2,3,4 string="foo bar baz"})))' -dump-pass-pipeline 2>&1 | FileCheck --check-prefix=CHECK_5 %s // RUN: mlir-opt %s -verify-each=false -pass-pipeline='builtin.module(builtin.module(func.func(test-options-pass{list=3}), func.func(test-options-pass{enum=one list=1,2,3,4 string={foo bar baz}})))' -dump-pass-pipeline 2>&1 | FileCheck --check-prefix=CHECK_5 %s // RUN: mlir-opt %s -verify-each=false -pass-pipeline='builtin.module(builtin.module(func.func(test-options-pass{list=3}), func.func(test-options-pass{enum=one list=1,2,3,4 string=foo"bar"baz})))' -dump-pass-pipeline 2>&1 | FileCheck --check-prefix=CHECK_6 %s +// RUN: mlir-opt %s -verify-each=false '-test-options-super-pass-pipeline=super-list={{enum=zero list=1 string=foo},{enum=one list=2 string="bar"},{enum=two list=3 string={baz}}}' -dump-pass-pipeline 2>&1 | FileCheck --check-prefix=CHECK_7 %s +// RUN: mlir-opt %s -verify-each=false -pass-pipeline='builtin.module(func.func(test-options-super-pass{list={{enum=zero list={1} string=foo },{enum=one list={2} string=bar },{enum=two list={3} string=baz }}}))' -dump-pass-pipeline 2>&1 | FileCheck --check-prefix=CHECK_7 %s // CHECK_ERROR_1: missing closing '}' while processing pass options // CHECK_ERROR_2: no such option test-option @@ -18,9 +20,10 @@ // CHECK_ERROR_4: 'notaninteger' value invalid for integer argument // CHECK_ERROR_5: for the --enum option: Cannot find option named 'invalid'! -// CHECK_1: test-options-pass{enum=zero list=1,2,3,4,5 string=nested_pipeline{arg1=10 arg2=" {} " arg3=true} string-list=a,b,c,d} -// CHECK_2: test-options-pass{enum=one list=1 string= string-list=a,b} -// CHECK_3: builtin.module(builtin.module(func.func(test-options-pass{enum=zero list=3 string= }),func.func(test-options-pass{enum=one list=1,2,3,4 string= }))) -// CHECK_4: builtin.module(builtin.module(func.func(test-options-pass{enum=zero list=3 string= }),func.func(test-options-pass{enum=one list=1,2,3,4 string=foobar }))) -// CHECK_5: builtin.module(builtin.module(func.func(test-options-pass{enum=zero list=3 string= }),func.func(test-options-pass{enum=one list=1,2,3,4 string={foo bar baz} }))) -// CHECK_6: builtin.module(builtin.module(func.func(test-options-pass{enum=zero list=3 string= }),func.func(test-options-pass{enum=one list=1,2,3,4 string=foo"bar"baz }))) +// CHECK_1: test-options-pass{enum=zero list={1,2,3,4,5} string=nested_pipeline{arg1=10 arg2=" {} " arg3=true} string-list={a,b,c,d}} +// CHECK_2: test-options-pass{enum=one list={1} string= string-list={a,b}} +// CHECK_3: builtin.module(builtin.module(func.func(test-options-pass{enum=zero list={3} string= }),func.func(test-options-pass{enum=one list={1,2,3,4} string= }))) +// CHECK_4: builtin.module(builtin.module(func.func(test-options-pass{enum=zero list={3} string= }),func.func(test-options-pass{enum=one list={1,2,3,4} string=foobar }))) +// CHECK_5: builtin.module(builtin.module(func.func(test-options-pass{enum=zero list={3} string= }),func.func(test-options-pass{enum=one list={1,2,3,4} string={foo bar baz} }))) +// CHECK_6: builtin.module(builtin.module(func.func(test-options-pass{enum=zero list={3} string= }),func.func(test-options-pass{enum=one list={1,2,3,4} string=foo"bar"baz }))) +// CHECK_7{LITERAL}: builtin.module(func.func(test-options-super-pass{list={{enum=zero list={1} string=foo },{enum=one list={2} string=bar },{enum=two list={3} string=baz }}})) diff --git a/mlir/test/lib/Pass/TestPassManager.cpp b/mlir/test/lib/Pass/TestPassManager.cpp index 2762e2549032459..ee32bec0c79bd42 100644 --- a/mlir/test/lib/Pass/TestPassManager.cpp +++ b/mlir/test/lib/Pass/TestPassManager.cpp @@ -54,7 +54,7 @@ struct TestOptionsPass : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOptionsPass) - enum Enum { One, Two }; + enum Enum { Zero, One, Two }; struct Options : public PassPipelineOptions { ListOption listOption{*this, "list", @@ -66,7 +66,15 @@ struct TestOptionsPass Option enumOption{ *this, "enum", llvm::cl::desc("Example enum option"), llvm::cl::values(clEnumValN(0, "zero", "Example zero value"), - clEnumValN(1, "one", "Example one value"))}; + clEnumValN(1, "one", "Example one value"), + clEnumValN(2, "two", "Example two value"))}; + + Options() = default; + Options(const Options &rhs) { *this = rhs; } + Options &operator=(const Options &rhs) { + copyOptionValuesFrom(rhs); + return *this; + } }; TestOptionsPass() = default; TestOptionsPass(const TestOptionsPass &) : PassWrapper() {} @@ -92,7 +100,37 @@ struct TestOptionsPass Option enumOption{ *this, "enum", llvm::cl::desc("Example enum option"), llvm::cl::values(clEnumValN(0, "zero", "Example zero value"), - clEnumValN(1, "one", "Example one value"))}; + clEnumValN(1, "one", "Example one value"), + clEnumValN(2, "two", "Example two value"))}; +}; + +struct TestOptionsSuperPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOptionsSuperPass) + + struct Options : public PassPipelineOptions { + ListOption listOption{ + *this, "super-list", + llvm::cl::desc("Example list of PassPipelineOptions option")}; + + Options() = default; + }; + + TestOptionsSuperPass() = default; + TestOptionsSuperPass(const TestOptionsSuperPass &) : PassWrapper() {} + TestOptionsSuperPass(const Options &options) { + listOption = options.listOption; + } + + void runOnOperation() final {} + StringRef getArgument() const final { return "test-options-super-pass"; } + StringRef getDescription() const final { + return "Test options of options parsing capabilities"; + } + + ListOption listOption{ + *this, "list", + llvm::cl::desc("Example list of PassPipelineOptions option")}; }; /// A test pass that always aborts to enable testing the crash recovery @@ -220,6 +258,7 @@ static void testNestedPipelineTextual(OpPassManager &pm) { namespace mlir { void registerPassManagerTestPass() { PassRegistration(); + PassRegistration(); PassRegistration(); @@ -248,5 +287,14 @@ void registerPassManagerTestPass() { [](OpPassManager &pm, const TestOptionsPass::Options &options) { pm.addPass(std::make_unique(options)); }); + + PassPipelineRegistration + registerOptionsSuperPassPipeline( + "test-options-super-pass-pipeline", + "Parses options of PassPipelineOptions using pass pipeline " + "registration", + [](OpPassManager &pm, const TestOptionsSuperPass::Options &options) { + pm.addPass(std::make_unique(options)); + }); } } // namespace mlir From a18fd3a97ebdff4b242b68c8b7ea720e7641156f Mon Sep 17 00:00:00 2001 From: Nikhil Kalra Date: Tue, 30 Jul 2024 18:30:00 -0700 Subject: [PATCH 3/3] address code review comments --- mlir/include/mlir/Pass/PassOptions.h | 9 ++++---- mlir/lib/Pass/PassRegistry.cpp | 31 +++++++++++++++------------- 2 files changed, 21 insertions(+), 19 deletions(-) diff --git a/mlir/include/mlir/Pass/PassOptions.h b/mlir/include/mlir/Pass/PassOptions.h index 5b844f4a1cc127a..a5a3f1c1c196520 100644 --- a/mlir/include/mlir/Pass/PassOptions.h +++ b/mlir/include/mlir/Pass/PassOptions.h @@ -173,11 +173,8 @@ class PassOptions : protected llvm::cl::SubCommand { } public: - /// The specific parser to use depending on llvm::cl parser used. This is only - /// necessary because we need to provide additional methods for certain data - /// type parsers. - /// TODO: We should upstream the methods in GenericOptionParser to avoid the - /// need to do this. + /// The specific parser to use. This is necessary because we need to provide + /// additional methods for certain data type parsers. template using OptionParser = std::conditional_t< // If the data type is derived from PassOptions, use the @@ -185,6 +182,8 @@ class PassOptions : protected llvm::cl::SubCommand { std::is_base_of_v, PassOptionsParser, // Otherwise, use GenericOptionParser where it is well formed, and fall // back to llvm::cl::parser otherwise. + // TODO: We should upstream the methods in GenericOptionParser to avoid + // the need to do this. std::conditional_t>::value, GenericOptionParser, diff --git a/mlir/lib/Pass/PassRegistry.cpp b/mlir/lib/Pass/PassRegistry.cpp index 387c5340bcc8a04..4ac425b410cc3f5 100644 --- a/mlir/lib/Pass/PassRegistry.cpp +++ b/mlir/lib/Pass/PassRegistry.cpp @@ -166,20 +166,22 @@ static StringRef extractArgAndUpdateOptions(StringRef &options, size_t argSize) { StringRef str = options.take_front(argSize).trim(); options = options.drop_front(argSize).ltrim(); - // Handle escape sequences - if (str.size() > 2) { - const auto escapePairs = {std::make_pair('\'', '\''), - std::make_pair('"', '"'), - std::make_pair('{', '}')}; - for (const auto &escape : escapePairs) { - if (str.front() == escape.first && str.back() == escape.second) { - // Drop the escape characters and trim. - str = str.drop_front().drop_back().trim(); - // Don't process additional escape sequences. - break; - } + + // Early exit if there's no escape sequence. + if (str.size() <= 2) + return str; + + const auto escapePairs = {std::make_pair('\'', '\''), + std::make_pair('"', '"'), std::make_pair('{', '}')}; + for (const auto &escape : escapePairs) { + if (str.front() == escape.first && str.back() == escape.second) { + // Drop the escape characters and trim. + str = str.drop_front().drop_back().trim(); + // Don't process additional escape sequences. + break; } } + return str; } @@ -189,7 +191,7 @@ LogicalResult detail::pass_options::parseCommaSeparatedList( // Functor used for finding a character in a string, and skipping over // various "range" characters. llvm::unique_function findChar = - [&findChar](StringRef str, size_t index, char c) -> size_t { + [&](StringRef str, size_t index, char c) -> size_t { for (size_t i = index, e = str.size(); i < e; ++i) { if (str[i] == c) return i; @@ -215,7 +217,8 @@ LogicalResult detail::pass_options::parseCommaSeparatedList( elementParseFn(extractArgAndUpdateOptions(optionStr, nextElePos)))) return failure(); - optionStr = optionStr.drop_front(); // drop the leading ',' + // Drop the leading ',' + optionStr = optionStr.drop_front(); nextElePos = findChar(optionStr, 0, ','); } return elementParseFn(