Skip to content

Commit

Permalink
[XLA:GPU] Factor out the newly introduced set_to_default_entry_comput…
Browse files Browse the repository at this point in the history
…ation_layout parser option to an options class.

I'm going to introduce more options.
Particularly, currently the parser also resets the layouts of all individual instructions. While it's not a big deal for most of them, for entry computation `parameter` it does matter.

PiperOrigin-RevId: 678615879
  • Loading branch information
mooskagh authored and Google-ML-Automation committed Sep 25, 2024
1 parent ccdb368 commit d2726de
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 34 deletions.
22 changes: 7 additions & 15 deletions xla/service/hlo_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -249,10 +249,8 @@ class HloParserImpl : public HloParser {
using BoolList = absl::InlinedVector<bool, 1>;

explicit HloParserImpl(absl::string_view str,
bool set_to_default_entry_computation_layout = true)
: lexer_(str),
set_to_default_entry_computation_layout_(
set_to_default_entry_computation_layout) {}
const HloParserOptions& options = HloParserOptions())
: lexer_(str), options_(options) {}

// Runs the parser and constructs the resulting HLO in the given (empty)
// HloModule. Returns the error status in case an error occurred.
Expand Down Expand Up @@ -673,7 +671,7 @@ class HloParserImpl : public HloParser {
// Used to generate names for anonymous instructions.
NameUniquer name_uniquer_{/*separator=*/"."};

const bool set_to_default_entry_computation_layout_;
const HloParserOptions options_;
};

bool SplitToInt64s(absl::string_view s, char delim, std::vector<int64_t>* out) {
Expand Down Expand Up @@ -917,7 +915,7 @@ bool HloParserImpl::ParseComputationLayout(
}
while (lexer_.GetKind() != TokKind::kRparen) {
Shape param;
if (!ParseShape(&param, set_to_default_entry_computation_layout_)) {
if (!ParseShape(&param, options_.fill_missing_module_parameter_layouts())) {
return false;
}
computation_layout->add_parameter_layout(ShapeLayout(param));
Expand All @@ -937,7 +935,7 @@ bool HloParserImpl::ParseComputationLayout(
return false;
}
Shape result;
if (!ParseShape(&result, set_to_default_entry_computation_layout_)) {
if (!ParseShape(&result, options_.fill_missing_module_parameter_layouts())) {
return false;
}
*computation_layout->mutable_result_layout() = ShapeLayout(result);
Expand Down Expand Up @@ -6990,19 +6988,13 @@ bool HloParserImpl::ParseSingleInstruction(HloModule* module) {

absl::StatusOr<std::unique_ptr<HloModule>> ParseAndReturnUnverifiedModule(
absl::string_view str, const HloModuleConfig& config,
bool set_to_default_entry_computation_layout) {
const HloParserOptions& options) {
auto module = std::make_unique<HloModule>(/*name=*/"_", config);
HloParserImpl parser(str, set_to_default_entry_computation_layout);
HloParserImpl parser(str, options);
TF_RETURN_IF_ERROR(parser.Run(module.get()));
return std::move(module);
}

absl::StatusOr<std::unique_ptr<HloModule>> ParseAndReturnUnverifiedModule(
absl::string_view str, bool set_to_default_entry_computation_layout) {
return ParseAndReturnUnverifiedModule(
str, HloModuleConfig(), set_to_default_entry_computation_layout);
}

absl::StatusOr<HloSharding> ParseSharding(absl::string_view str) {
HloParserImpl parser(str);
return parser.ParseShardingOnly();
Expand Down
29 changes: 19 additions & 10 deletions xla/service/hlo_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,25 +24,34 @@ limitations under the License.
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/hlo_lexer.h"
#include "xla/xla_data.pb.h"

namespace xla {

// Given a string in the HloModule::ToString() format, parses the string and
// creates a HloModule with the given config.
// Note: Tests derived from HloTestBase should use
// ParseAndReturnVerifiedModule() instead!
absl::StatusOr<std::unique_ptr<HloModule>> ParseAndReturnUnverifiedModule(
absl::string_view str, const HloModuleConfig& config,
bool set_to_default_entry_computation_layout = true);
class HloParserOptions {
public:
// If the entry computation parameter layout is not set, set the layout to be
// the default (e.g. {3,2,1,0}).
HloParserOptions& set_fill_missing_module_parameter_layouts(bool value) {
fill_missing_module_parameter_layouts_ = value;
return *this;
}

bool fill_missing_module_parameter_layouts() const {
return fill_missing_module_parameter_layouts_;
}

private:
bool fill_missing_module_parameter_layouts_ = true;
};

// Given a string in the HloModule::ToString() format, parses the string and
// creates a HloModule with default config.
// creates a HloModule with the given config.
// Note: Tests derived from HloTestBase should use
// ParseAndReturnVerifiedModule() instead!
absl::StatusOr<std::unique_ptr<HloModule>> ParseAndReturnUnverifiedModule(
absl::string_view str, bool set_to_default_entry_computation_layout = true);
absl::string_view str, const HloModuleConfig& config = HloModuleConfig(),
const HloParserOptions& options = HloParserOptions());

// Parses sharding from str. str is supposed to contain the body of the
// sharding, i.e. just the rhs of the "sharding={...}" attribute string, e.g.,
Expand Down
62 changes: 61 additions & 1 deletion xla/service/hlo_parser_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ limitations under the License.
#include <utility>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/log/log.h"
#include "absl/status/status.h"
Expand Down Expand Up @@ -3434,12 +3435,71 @@ ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] {

absl::StatusOr<std::unique_ptr<HloModule>> module =
ParseAndReturnUnverifiedModule(
original, /*set_to_default_entry_computation_layout=*/false);
original, {},
HloParserOptions().set_fill_missing_module_parameter_layouts(false));
TF_ASSERT_OK(module.status());
// Do not set the default layout.
EXPECT_FALSE(module.value()->entry_computation_layout().AnyLayoutSet());
}

TEST_F(HloParserTest, DoNotSetEntryComputationLayoutIfSet) {
const std::string original = R"(
HloModule layout_defined, entry_computation_layout={(f32[8,16,256]{1,2,0}) -> f32[8,16]}
add_F32.v3 {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] {
input = f32[8,16,256]{0,1,2} parameter(0)
constant = f32[] constant(0)
ROOT reduce = f32[8,16]{0,1} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3
})";

absl::StatusOr<std::unique_ptr<HloModule>> module =
ParseAndReturnUnverifiedModule(
original, {},
HloParserOptions().set_fill_missing_module_parameter_layouts(true));
TF_ASSERT_OK(module.status());
EXPECT_THAT(module.value()
->entry_computation_layout()
.parameter_layout(0)
.layout()
.minor_to_major(),
ElementsAre(1, 2, 0));
}

TEST_F(HloParserTest, SetEntryComputationLayoutIfNotSet) {
const std::string original = R"(
HloModule layout_defined, entry_computation_layout={(f32[8,16,256]) -> f32[8,16]}
add_F32.v3 {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] {
input = f32[8,16,256]{0,1,2} parameter(0)
constant = f32[] constant(0)
ROOT reduce = f32[8,16]{0,1} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3
})";

absl::StatusOr<std::unique_ptr<HloModule>> module =
ParseAndReturnUnverifiedModule(
original, {},
HloParserOptions().set_fill_missing_module_parameter_layouts(true));
TF_ASSERT_OK(module.status());
EXPECT_THAT(module.value()
->entry_computation_layout()
.parameter_layout(0)
.layout()
.minor_to_major(),
ElementsAre(2, 1, 0));
}

TEST_F(HloParserTest, NoEntry) {
const std::string original = R"(HloModule no_entry:
c1 {
Expand Down
2 changes: 2 additions & 0 deletions xla/tools/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -477,10 +477,12 @@ cc_library(
"//xla/service:hlo_parser",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_googlesource_code_re2//:re2",
"@tsl//tsl/platform:env",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:path",
"@tsl//tsl/platform:protobuf",
"@tsl//tsl/platform:statusor",
],
)

Expand Down
16 changes: 10 additions & 6 deletions xla/tools/hlo_module_loader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "re2/re2.h"
#include "xla/debug_options_flags.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
Expand All @@ -35,6 +36,7 @@ limitations under the License.
#include "tsl/platform/logging.h"
#include "tsl/platform/path.h"
#include "tsl/platform/protobuf.h"
#include "tsl/platform/statusor.h"

namespace xla {
namespace {
Expand Down Expand Up @@ -71,7 +73,7 @@ absl::StatusOr<std::unique_ptr<HloModule>> LoadModuleFromData(
const hlo_module_loader_details::Config& ovr_config,
const std::function<void(HloModuleConfig*)>& config_modifier_hook,
BufferAssignmentProto* buffer_assignment_proto,
bool set_to_default_entry_computation_layout) {
bool fill_missing_module_parameter_layouts) {
DebugOptions debug_options = GetDebugOptionsFromFlags();
std::unique_ptr<HloModule> module;
if (format == "hlo" || format == "txt") {
Expand All @@ -82,9 +84,11 @@ absl::StatusOr<std::unique_ptr<HloModule>> LoadModuleFromData(
if (config_modifier_hook) {
config_modifier_hook(&config);
}
TF_ASSIGN_OR_RETURN(module, ParseAndReturnUnverifiedModule(
hlo_string, config,
set_to_default_entry_computation_layout));
HloParserOptions options;
options.set_fill_missing_module_parameter_layouts(
fill_missing_module_parameter_layouts);
TF_ASSIGN_OR_RETURN(
module, ParseAndReturnUnverifiedModule(hlo_string, config, options));
} else {
HloSnapshot proto;
if (format == "pb") {
Expand Down Expand Up @@ -133,15 +137,15 @@ absl::StatusOr<std::unique_ptr<HloModule>> LoadModuleFromFile(
const hlo_module_loader_details::Config& ovr_config,
const std::function<void(HloModuleConfig*)>& config_modifier_hook,
BufferAssignmentProto* buffer_assignment_proto,
bool set_to_default_entry_computation_layout) {
bool fill_missing_module_parameter_layouts) {
std::string data;
if (format.empty()) {
format = std::string(tsl::io::Extension(path));
}
TF_RETURN_IF_ERROR(tsl::ReadFileToString(tsl::Env::Default(), path, &data));
return LoadModuleFromData(data, format, ovr_config, config_modifier_hook,
buffer_assignment_proto,
set_to_default_entry_computation_layout);
fill_missing_module_parameter_layouts);
}

absl::StatusOr<std::unique_ptr<RunHloModuleIterationLiterals>>
Expand Down
4 changes: 2 additions & 2 deletions xla/tools/hlo_module_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ absl::StatusOr<std::unique_ptr<HloModule>> LoadModuleFromData(
hlo_module_loader_details::Config(),
const std::function<void(HloModuleConfig*)>& config_modifier_hook = {},
BufferAssignmentProto* buffer_assignment_proto = nullptr,
bool set_to_default_entry_computation_layout = true);
bool fill_missing_module_parameter_layouts = true);

// Loads an HLO module from file.
// The file can be one of the followings:
Expand All @@ -84,7 +84,7 @@ absl::StatusOr<std::unique_ptr<HloModule>> LoadModuleFromFile(
hlo_module_loader_details::Config(),
const std::function<void(HloModuleConfig*)>& config_modifier_hook = {},
BufferAssignmentProto* buffer_assignment_proto = nullptr,
bool set_to_default_entry_computation_layout = true);
bool fill_missing_module_parameter_layouts = true);

// Loads an HLO snapshot from a string, only for its inputs
// The data format must be one of the following:
Expand Down

0 comments on commit d2726de

Please sign in to comment.