Skip to content

Commit

Permalink
[XLA:GPU] Don't fall back to the default layout in all cases, not jus…
Browse files Browse the repository at this point in the history
…t entry computation layout.

Not resetting of the shapes of the entry computation's parameters has the same reasoning as entry_computation_layout.

Other ops are reset by the layout normalization pass anyway.

PiperOrigin-RevId: 678656058
  • Loading branch information
mooskagh authored and Google-ML-Automation committed Sep 25, 2024
1 parent f461486 commit 182111e
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 30 deletions.
12 changes: 6 additions & 6 deletions xla/service/hlo_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ class HloParserImpl : public HloParser {
bool ParseJsonDict(std::string* result);
bool ParseDimensionSizes(std::vector<int64_t>* dimension_sizes,
std::vector<bool>* dynamic_dimensions);
bool ParseShape(Shape* result, bool set_to_default_layout = true);
bool ParseShape(Shape* result);
bool ParseLayout(Layout* layout);
bool ParseLayoutIntAttribute(int64_t* attr_value,
absl::string_view attr_description);
Expand Down Expand Up @@ -915,7 +915,7 @@ bool HloParserImpl::ParseComputationLayout(
}
while (lexer_.GetKind() != TokKind::kRparen) {
Shape param;
if (!ParseShape(&param, options_.fill_missing_module_parameter_layouts())) {
if (!ParseShape(&param)) {
return false;
}
computation_layout->add_parameter_layout(ShapeLayout(param));
Expand All @@ -935,7 +935,7 @@ bool HloParserImpl::ParseComputationLayout(
return false;
}
Shape result;
if (!ParseShape(&result, options_.fill_missing_module_parameter_layouts())) {
if (!ParseShape(&result)) {
return false;
}
*computation_layout->mutable_result_layout() = ShapeLayout(result);
Expand Down Expand Up @@ -6097,7 +6097,7 @@ bool HloParserImpl::ParseLayout(Layout* layout) {
// tuple_elements
// ::= /*empty*/
// ::= shape (',' shape)*
bool HloParserImpl::ParseShape(Shape* result, bool set_to_default_layout) {
bool HloParserImpl::ParseShape(Shape* result) {
if (EatIfPresent(TokKind::kLparen)) { // Tuple
std::vector<Shape> shapes;
if (lexer_.GetKind() == TokKind::kRparen) {
Expand All @@ -6106,7 +6106,7 @@ bool HloParserImpl::ParseShape(Shape* result, bool set_to_default_layout) {
// shape (',' shape)*
do {
shapes.emplace_back();
if (!ParseShape(&shapes.back(), set_to_default_layout)) {
if (!ParseShape(&shapes.back())) {
return false;
}
} while (EatIfPresent(TokKind::kComma));
Expand All @@ -6132,7 +6132,7 @@ bool HloParserImpl::ParseShape(Shape* result, bool set_to_default_layout) {
result->add_dimensions(dimension_sizes[i]);
result->set_dynamic_dimension(i, dynamic_dimensions[i]);
}
if (set_to_default_layout || ShapeUtil::IsScalar(*result)) {
if (options_.fill_missing_layouts() || ShapeUtil::IsScalar(*result)) {
LayoutUtil::SetToDefaultLayout(result);
}
// We need to lookahead to see if a following open brace is the start of a
Expand Down
14 changes: 6 additions & 8 deletions xla/service/hlo_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,17 @@ namespace xla {

class HloParserOptions {
public:
// If the entry computation parameter layout is not set, set the layout to be
// the default (e.g. {3,2,1,0}).
HloParserOptions& set_fill_missing_module_parameter_layouts(bool value) {
fill_missing_module_parameter_layouts_ = value;
// When a shape layout is not set (e.g. in the entry computation layout or
// instruction layout), set the layout to be the default (e.g. {3,2,1,0}).
HloParserOptions& set_fill_missing_layouts(bool value) {
fill_missing_layouts_ = value;
return *this;
}

bool fill_missing_module_parameter_layouts() const {
return fill_missing_module_parameter_layouts_;
}
bool fill_missing_layouts() const { return fill_missing_layouts_; }

private:
bool fill_missing_module_parameter_layouts_ = true;
bool fill_missing_layouts_ = true;
};

// Given a string in the HloModule::ToString() format, parses the string and
Expand Down
76 changes: 70 additions & 6 deletions xla/service/hlo_parser_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3435,8 +3435,7 @@ ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] {

absl::StatusOr<std::unique_ptr<HloModule>> module =
ParseAndReturnUnverifiedModule(
original, {},
HloParserOptions().set_fill_missing_module_parameter_layouts(false));
original, {}, HloParserOptions().set_fill_missing_layouts(false));
TF_ASSERT_OK(module.status());
// Do not set the default layout.
EXPECT_FALSE(module.value()->entry_computation_layout().AnyLayoutSet());
Expand All @@ -3460,8 +3459,7 @@ ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] {

absl::StatusOr<std::unique_ptr<HloModule>> module =
ParseAndReturnUnverifiedModule(
original, {},
HloParserOptions().set_fill_missing_module_parameter_layouts(true));
original, {}, HloParserOptions().set_fill_missing_layouts(true));
TF_ASSERT_OK(module.status());
EXPECT_THAT(module.value()
->entry_computation_layout()
Expand Down Expand Up @@ -3489,8 +3487,7 @@ ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] {

absl::StatusOr<std::unique_ptr<HloModule>> module =
ParseAndReturnUnverifiedModule(
original, {},
HloParserOptions().set_fill_missing_module_parameter_layouts(true));
original, {}, HloParserOptions().set_fill_missing_layouts(true));
TF_ASSERT_OK(module.status());
EXPECT_THAT(module.value()
->entry_computation_layout()
Expand All @@ -3500,6 +3497,73 @@ ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] {
ElementsAre(2, 1, 0));
}

TEST_F(HloParserTest, DoNotFallBackToDefaultLayoutIfDisabled) {
const std::string original = R"(
HloModule t
ENTRY main {
p0 = f16[16,32,48,64]{3,2,1,0} parameter(0)
p1 = f16[80,64,48,32]{3,2,1,0} parameter(1)
ROOT dot = f16[64,32,16,80] dot(p0, p1), lhs_contracting_dims={2}, rhs_contracting_dims={2}, lhs_batch_dims={3,1}, rhs_batch_dims={1,3}
})";

absl::StatusOr<std::unique_ptr<HloModule>> module =
ParseAndReturnUnverifiedModule(
original, {}, HloParserOptions().set_fill_missing_layouts(false));
TF_ASSERT_OK(module.status());
EXPECT_FALSE(module.value()
->entry_computation()
->root_instruction()
->shape()
.has_layout());
}

TEST_F(HloParserTest, FallBackToDefaultLayoutIfEnabled) {
const std::string original = R"(
HloModule t
ENTRY main {
p0 = f16[16,32,48,64]{3,2,1,0} parameter(0)
p1 = f16[80,64,48,32]{3,2,1,0} parameter(1)
ROOT dot = f16[64,32,16,80] dot(p0, p1), lhs_contracting_dims={2}, rhs_contracting_dims={2}, lhs_batch_dims={3,1}, rhs_batch_dims={1,3}
})";

absl::StatusOr<std::unique_ptr<HloModule>> module =
ParseAndReturnUnverifiedModule(
original, {}, HloParserOptions().set_fill_missing_layouts(true));
TF_ASSERT_OK(module.status());
EXPECT_THAT(module.value()
->entry_computation()
->root_instruction()
->shape()
.layout()
.minor_to_major(),
ElementsAre(3, 2, 1, 0));
}

TEST_F(HloParserTest, FallBackToDefaultLayoutIfAlreadySet) {
const std::string original = R"(
HloModule t
ENTRY main {
p0 = f16[16,32,48,64]{3,2,1,0} parameter(0)
p1 = f16[80,64,48,32]{3,2,1,0} parameter(1)
ROOT dot = f16[64,32,16,80]{1,2,0,3} dot(p0, p1), lhs_contracting_dims={2}, rhs_contracting_dims={2}, lhs_batch_dims={3,1}, rhs_batch_dims={1,3}
})";

absl::StatusOr<std::unique_ptr<HloModule>> module =
ParseAndReturnUnverifiedModule(
original, {}, HloParserOptions().set_fill_missing_layouts(true));
TF_ASSERT_OK(module.status());
EXPECT_THAT(module.value()
->entry_computation()
->root_instruction()
->shape()
.layout()
.minor_to_major(),
ElementsAre(1, 2, 0, 3));
}

TEST_F(HloParserTest, NoEntry) {
const std::string original = R"(HloModule no_entry:
c1 {
Expand Down
12 changes: 4 additions & 8 deletions xla/tools/hlo_module_loader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,7 @@ absl::StatusOr<std::unique_ptr<HloModule>> LoadModuleFromData(
const std::string& data, std::string_view format,
const hlo_module_loader_details::Config& ovr_config,
const std::function<void(HloModuleConfig*)>& config_modifier_hook,
BufferAssignmentProto* buffer_assignment_proto,
bool fill_missing_module_parameter_layouts) {
BufferAssignmentProto* buffer_assignment_proto, bool fill_missing_layouts) {
DebugOptions debug_options = GetDebugOptionsFromFlags();
std::unique_ptr<HloModule> module;
if (format == "hlo" || format == "txt") {
Expand All @@ -85,8 +84,7 @@ absl::StatusOr<std::unique_ptr<HloModule>> LoadModuleFromData(
config_modifier_hook(&config);
}
HloParserOptions options;
options.set_fill_missing_module_parameter_layouts(
fill_missing_module_parameter_layouts);
options.set_fill_missing_layouts(fill_missing_layouts);
TF_ASSIGN_OR_RETURN(
module, ParseAndReturnUnverifiedModule(hlo_string, config, options));
} else {
Expand Down Expand Up @@ -136,16 +134,14 @@ absl::StatusOr<std::unique_ptr<HloModule>> LoadModuleFromFile(
const std::string& path, std::string format,
const hlo_module_loader_details::Config& ovr_config,
const std::function<void(HloModuleConfig*)>& config_modifier_hook,
BufferAssignmentProto* buffer_assignment_proto,
bool fill_missing_module_parameter_layouts) {
BufferAssignmentProto* buffer_assignment_proto, bool fill_missing_layouts) {
std::string data;
if (format.empty()) {
format = std::string(tsl::io::Extension(path));
}
TF_RETURN_IF_ERROR(tsl::ReadFileToString(tsl::Env::Default(), path, &data));
return LoadModuleFromData(data, format, ovr_config, config_modifier_hook,
buffer_assignment_proto,
fill_missing_module_parameter_layouts);
buffer_assignment_proto, fill_missing_layouts);
}

absl::StatusOr<std::unique_ptr<RunHloModuleIterationLiterals>>
Expand Down
4 changes: 2 additions & 2 deletions xla/tools/hlo_module_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ absl::StatusOr<std::unique_ptr<HloModule>> LoadModuleFromData(
hlo_module_loader_details::Config(),
const std::function<void(HloModuleConfig*)>& config_modifier_hook = {},
BufferAssignmentProto* buffer_assignment_proto = nullptr,
bool fill_missing_module_parameter_layouts = true);
bool fill_missing_layouts = true);

// Loads an HLO module from file.
// The file can be one of the followings:
Expand All @@ -84,7 +84,7 @@ absl::StatusOr<std::unique_ptr<HloModule>> LoadModuleFromFile(
hlo_module_loader_details::Config(),
const std::function<void(HloModuleConfig*)>& config_modifier_hook = {},
BufferAssignmentProto* buffer_assignment_proto = nullptr,
bool fill_missing_module_parameter_layouts = true);
bool fill_missing_layouts = true);

// Loads an HLO snapshot from a string, only for its inputs
// The data format must be one of the following:
Expand Down

0 comments on commit 182111e

Please sign in to comment.