diff --git a/xla/service/hlo_parser.cc b/xla/service/hlo_parser.cc index 9befbac1b48fc..abf31f1767d7f 100644 --- a/xla/service/hlo_parser.cc +++ b/xla/service/hlo_parser.cc @@ -249,10 +249,8 @@ class HloParserImpl : public HloParser { using BoolList = absl::InlinedVector; explicit HloParserImpl(absl::string_view str, - bool set_to_default_entry_computation_layout = true) - : lexer_(str), - set_to_default_entry_computation_layout_( - set_to_default_entry_computation_layout) {} + const HloParserOptions& options = HloParserOptions()) + : lexer_(str), options_(options) {} // Runs the parser and constructs the resulting HLO in the given (empty) // HloModule. Returns the error status in case an error occurred. @@ -673,7 +671,7 @@ class HloParserImpl : public HloParser { // Used to generate names for anonymous instructions. NameUniquer name_uniquer_{/*separator=*/"."}; - const bool set_to_default_entry_computation_layout_; + const HloParserOptions options_; }; bool SplitToInt64s(absl::string_view s, char delim, std::vector* out) { @@ -917,7 +915,7 @@ bool HloParserImpl::ParseComputationLayout( } while (lexer_.GetKind() != TokKind::kRparen) { Shape param; - if (!ParseShape(¶m, set_to_default_entry_computation_layout_)) { + if (!ParseShape(¶m, options_.fill_missing_module_parameter_layouts())) { return false; } computation_layout->add_parameter_layout(ShapeLayout(param)); @@ -937,7 +935,7 @@ bool HloParserImpl::ParseComputationLayout( return false; } Shape result; - if (!ParseShape(&result, set_to_default_entry_computation_layout_)) { + if (!ParseShape(&result, options_.fill_missing_module_parameter_layouts())) { return false; } *computation_layout->mutable_result_layout() = ShapeLayout(result); @@ -6990,19 +6988,13 @@ bool HloParserImpl::ParseSingleInstruction(HloModule* module) { absl::StatusOr> ParseAndReturnUnverifiedModule( absl::string_view str, const HloModuleConfig& config, - bool set_to_default_entry_computation_layout) { + const HloParserOptions& options) { auto module = std::make_unique(/*name=*/"_", config); - HloParserImpl parser(str, set_to_default_entry_computation_layout); + HloParserImpl parser(str, options); TF_RETURN_IF_ERROR(parser.Run(module.get())); return std::move(module); } -absl::StatusOr> ParseAndReturnUnverifiedModule( - absl::string_view str, bool set_to_default_entry_computation_layout) { - return ParseAndReturnUnverifiedModule( - str, HloModuleConfig(), set_to_default_entry_computation_layout); -} - absl::StatusOr ParseSharding(absl::string_view str) { HloParserImpl parser(str); return parser.ParseShardingOnly(); diff --git a/xla/service/hlo_parser.h b/xla/service/hlo_parser.h index 2628c15eb00db..c6b5f545c54cd 100644 --- a/xla/service/hlo_parser.h +++ b/xla/service/hlo_parser.h @@ -24,25 +24,34 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/service/hlo_lexer.h" #include "xla/xla_data.pb.h" namespace xla { -// Given a string in the HloModule::ToString() format, parses the string and -// creates a HloModule with the given config. -// Note: Tests derived from HloTestBase should use -// ParseAndReturnVerifiedModule() instead! -absl::StatusOr> ParseAndReturnUnverifiedModule( - absl::string_view str, const HloModuleConfig& config, - bool set_to_default_entry_computation_layout = true); +class HloParserOptions { + public: + // If the entry computation parameter layout is not set, set the layout to be + // the default (e.g. {3,2,1,0}). + HloParserOptions& set_fill_missing_module_parameter_layouts(bool value) { + fill_missing_module_parameter_layouts_ = value; + return *this; + } + + bool fill_missing_module_parameter_layouts() const { + return fill_missing_module_parameter_layouts_; + } + + private: + bool fill_missing_module_parameter_layouts_ = true; +}; // Given a string in the HloModule::ToString() format, parses the string and -// creates a HloModule with default config. +// creates a HloModule with the given config. // Note: Tests derived from HloTestBase should use // ParseAndReturnVerifiedModule() instead! absl::StatusOr> ParseAndReturnUnverifiedModule( - absl::string_view str, bool set_to_default_entry_computation_layout = true); + absl::string_view str, const HloModuleConfig& config = HloModuleConfig(), + const HloParserOptions& options = HloParserOptions()); // Parses sharding from str. str is supposed to contain the body of the // sharding, i.e. just the rhs of the "sharding={...}" attribute string, e.g., diff --git a/xla/service/hlo_parser_test.cc b/xla/service/hlo_parser_test.cc index e7c52987492fe..0035b317fb418 100644 --- a/xla/service/hlo_parser_test.cc +++ b/xla/service/hlo_parser_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include #include #include "absl/log/log.h" #include "absl/status/status.h" @@ -3434,12 +3435,71 @@ ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] { absl::StatusOr> module = ParseAndReturnUnverifiedModule( - original, /*set_to_default_entry_computation_layout=*/false); + original, {}, + HloParserOptions().set_fill_missing_module_parameter_layouts(false)); TF_ASSERT_OK(module.status()); // Do not set the default layout. EXPECT_FALSE(module.value()->entry_computation_layout().AnyLayoutSet()); } +TEST_F(HloParserTest, DoNotSetEntryComputationLayoutIfSet) { + const std::string original = R"( +HloModule layout_defined, entry_computation_layout={(f32[8,16,256]{1,2,0}) -> f32[8,16]} + +add_F32.v3 { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] { + input = f32[8,16,256]{0,1,2} parameter(0) + constant = f32[] constant(0) + ROOT reduce = f32[8,16]{0,1} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3 +})"; + + absl::StatusOr> module = + ParseAndReturnUnverifiedModule( + original, {}, + HloParserOptions().set_fill_missing_module_parameter_layouts(true)); + TF_ASSERT_OK(module.status()); + EXPECT_THAT(module.value() + ->entry_computation_layout() + .parameter_layout(0) + .layout() + .minor_to_major(), + ElementsAre(1, 2, 0)); +} + +TEST_F(HloParserTest, SetEntryComputationLayoutIfNotSet) { + const std::string original = R"( +HloModule layout_defined, entry_computation_layout={(f32[8,16,256]) -> f32[8,16]} + +add_F32.v3 { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] { + input = f32[8,16,256]{0,1,2} parameter(0) + constant = f32[] constant(0) + ROOT reduce = f32[8,16]{0,1} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3 +})"; + + absl::StatusOr> module = + ParseAndReturnUnverifiedModule( + original, {}, + HloParserOptions().set_fill_missing_module_parameter_layouts(true)); + TF_ASSERT_OK(module.status()); + EXPECT_THAT(module.value() + ->entry_computation_layout() + .parameter_layout(0) + .layout() + .minor_to_major(), + ElementsAre(2, 1, 0)); +} + TEST_F(HloParserTest, NoEntry) { const std::string original = R"(HloModule no_entry: c1 { diff --git a/xla/tools/BUILD b/xla/tools/BUILD index 765e9e9c1e588..a0a30eea4e652 100644 --- a/xla/tools/BUILD +++ b/xla/tools/BUILD @@ -477,10 +477,12 @@ cc_library( "//xla/service:hlo_parser", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_googlesource_code_re2//:re2", "@tsl//tsl/platform:env", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:path", "@tsl//tsl/platform:protobuf", + "@tsl//tsl/platform:statusor", ], ) diff --git a/xla/tools/hlo_module_loader.cc b/xla/tools/hlo_module_loader.cc index f6a685435825c..f765acaeeef8c 100644 --- a/xla/tools/hlo_module_loader.cc +++ b/xla/tools/hlo_module_loader.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" +#include "re2/re2.h" #include "xla/debug_options_flags.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -35,6 +36,7 @@ limitations under the License. #include "tsl/platform/logging.h" #include "tsl/platform/path.h" #include "tsl/platform/protobuf.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -71,7 +73,7 @@ absl::StatusOr> LoadModuleFromData( const hlo_module_loader_details::Config& ovr_config, const std::function& config_modifier_hook, BufferAssignmentProto* buffer_assignment_proto, - bool set_to_default_entry_computation_layout) { + bool fill_missing_module_parameter_layouts) { DebugOptions debug_options = GetDebugOptionsFromFlags(); std::unique_ptr module; if (format == "hlo" || format == "txt") { @@ -82,9 +84,11 @@ absl::StatusOr> LoadModuleFromData( if (config_modifier_hook) { config_modifier_hook(&config); } - TF_ASSIGN_OR_RETURN(module, ParseAndReturnUnverifiedModule( - hlo_string, config, - set_to_default_entry_computation_layout)); + HloParserOptions options; + options.set_fill_missing_module_parameter_layouts( + fill_missing_module_parameter_layouts); + TF_ASSIGN_OR_RETURN( + module, ParseAndReturnUnverifiedModule(hlo_string, config, options)); } else { HloSnapshot proto; if (format == "pb") { @@ -133,7 +137,7 @@ absl::StatusOr> LoadModuleFromFile( const hlo_module_loader_details::Config& ovr_config, const std::function& config_modifier_hook, BufferAssignmentProto* buffer_assignment_proto, - bool set_to_default_entry_computation_layout) { + bool fill_missing_module_parameter_layouts) { std::string data; if (format.empty()) { format = std::string(tsl::io::Extension(path)); @@ -141,7 +145,7 @@ absl::StatusOr> LoadModuleFromFile( TF_RETURN_IF_ERROR(tsl::ReadFileToString(tsl::Env::Default(), path, &data)); return LoadModuleFromData(data, format, ovr_config, config_modifier_hook, buffer_assignment_proto, - set_to_default_entry_computation_layout); + fill_missing_module_parameter_layouts); } absl::StatusOr> diff --git a/xla/tools/hlo_module_loader.h b/xla/tools/hlo_module_loader.h index a841bceed512d..9000422fd8d4f 100644 --- a/xla/tools/hlo_module_loader.h +++ b/xla/tools/hlo_module_loader.h @@ -61,7 +61,7 @@ absl::StatusOr> LoadModuleFromData( hlo_module_loader_details::Config(), const std::function& config_modifier_hook = {}, BufferAssignmentProto* buffer_assignment_proto = nullptr, - bool set_to_default_entry_computation_layout = true); + bool fill_missing_module_parameter_layouts = true); // Loads an HLO module from file. // The file can be one of the followings: @@ -84,7 +84,7 @@ absl::StatusOr> LoadModuleFromFile( hlo_module_loader_details::Config(), const std::function& config_modifier_hook = {}, BufferAssignmentProto* buffer_assignment_proto = nullptr, - bool set_to_default_entry_computation_layout = true); + bool fill_missing_module_parameter_layouts = true); // Loads an HLO snapshot from a string, only for its inputs // The data format must be one of the following: