Skip to content

Commit

Permalink
[XLA:GPU][IndexAnalysis] Use the parser in indexing_map_test.
Browse files Browse the repository at this point in the history
Add support for parsing of negative int literals and affine expressions with ().

PiperOrigin-RevId: 678601020
  • Loading branch information
pifon2a authored and Google-ML-Automation committed Sep 25, 2024
1 parent 8edbaf4 commit b5216ad
Show file tree
Hide file tree
Showing 4 changed files with 658 additions and 388 deletions.
2 changes: 2 additions & 0 deletions xla/service/gpu/model/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -506,13 +506,15 @@ xla_cc_test(
deps = [
":affine_map_printer",
":indexing_analysis",
":indexing_map_serialization",
":indexing_test_utils",
"//xla/hlo/ir:hlo",
"//xla/tests:hlo_test_base",
"//xla/tests:verified_hlo_module",
"//xla/tests:xla_internal_test_main",
"@com_google_absl//absl/hash:hash_testing",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
"@com_google_googletest//:gtest",
"@llvm-project//mlir:IR",
Expand Down
57 changes: 39 additions & 18 deletions xla/service/gpu/model/indexing_map_serialization.cc
Original file line number Diff line number Diff line change
Expand Up @@ -209,20 +209,25 @@ bool Parser::ParseInterval(Interval* interval) {
bool Parser::ParseAffineExprString(std::string* affine_expr_str) {
unsigned num_unmatched_parens = 0;
while (true) {
if (!IsPartOfAffineExpr(current_token_)) {
if (ConsumeToken(Token::Kind::kLParen)) {
++num_unmatched_parens;
} else if (current_token_.kind == Token::Kind::kRParen &&
num_unmatched_parens > 0) {
--num_unmatched_parens;
Advance();
} else {
break;
}
if (IsPartOfAffineExpr(current_token_)) {
affine_expr_str->append(current_token_.spelling);
affine_expr_str->push_back(' ');
Advance();
continue;
}
affine_expr_str->append(current_token_.spelling);
affine_expr_str->push_back(' ');
Advance();
if (ConsumeToken(Token::Kind::kLParen)) {
affine_expr_str->push_back('(');
++num_unmatched_parens;
continue;
}
if (current_token_.kind == Token::Kind::kRParen &&
num_unmatched_parens > 0) {
affine_expr_str->push_back(')');
--num_unmatched_parens;
Advance();
continue;
}
break;
}
return current_token_.kind != Token::Kind::kError;
}
Expand Down Expand Up @@ -302,11 +307,20 @@ Token Parser::GetNextTokenImpl() {
}
if (*it_ == '-') {
++it_;
if (it_ != input_.end() && *it_ == '>') {
++it_;
return Token{"->", Token::Kind::kArrow};
} else {
return Token{"-", Token::Kind::kMinus};
if (it_ != input_.end()) {
if (*it_ == '>') {
++it_;
return Token{"->", Token::Kind::kArrow};
} else if (std::isdigit(*it_)) {
auto start = it_ - 1;
while (it_ != input_.end() && std::isdigit(*it_)) {
++it_;
}
StringRef spelling = input_.substr(start - input_.data(), it_ - start);
return Token{spelling, Token::Kind::kIntLiteral};
} else {
return Token{"-", Token::Kind::kMinus};
}
}
}
StringRef spelling = input_.substr(start - input_.data(), 1);
Expand Down Expand Up @@ -407,6 +421,7 @@ std::optional<IndexingMap> ParseIndexingMap(llvm::StringRef input,
if (!parser.ConsumeToken(Token::Kind::kComma) ||
!parser.ConsumeToken(Token::Kind::kKeywordDomain) ||
!parser.ConsumeToken(Token::Kind::kColon)) {
llvm::errs() << "Failed to parse domain keyword\n";
return std::nullopt;
}
// Parse dimension variables.
Expand All @@ -418,9 +433,11 @@ std::optional<IndexingMap> ParseIndexingMap(llvm::StringRef input,
!parser.ConsumeToken(Token::Kind::kKeywordIn) ||
!parser.ParseInterval(&interval) ||
!parser.ConsumeToken(Token::Kind::kComma)) {
llvm::errs() << "Failed to parse DimVar\n";
return std::nullopt;
}
if (var_name != dim_name) {
llvm::errs() << "Dimension name mismatch\n";
return std::nullopt;
}
dim_vars.push_back(DimVar{interval});
Expand All @@ -434,9 +451,11 @@ std::optional<IndexingMap> ParseIndexingMap(llvm::StringRef input,
!parser.ConsumeToken(Token::Kind::kKeywordIn) ||
!parser.ParseInterval(&interval) ||
!parser.ConsumeToken(Token::Kind::kComma)) {
llvm::errs() << "Failed to parse RangeVar\n";
return std::nullopt;
}
if (var_name != symbol_var) {
llvm::errs() << "Symbol name mismatch\n";
return std::nullopt;
}
range_vars.push_back(RangeVar{interval});
Expand All @@ -450,6 +469,7 @@ std::optional<IndexingMap> ParseIndexingMap(llvm::StringRef input,
!parser.ConsumeToken(Token::Kind::kKeywordIn) ||
!parser.ParseInterval(&interval) ||
!parser.ConsumeToken(Token::Kind::kComma)) {
llvm::errs() << "Failed to parse constraint\n";
return std::nullopt;
}
affine_expr_strs.push_back(affine_expr_str);
Expand All @@ -459,6 +479,7 @@ std::optional<IndexingMap> ParseIndexingMap(llvm::StringRef input,
bool is_simplified;
if (!parser.ConsumeToken(Token::Kind::kColon) ||
!parser.ParseBool(&is_simplified)) {
llvm::errs() << "Failed to parse is_simplified\n";
return std::nullopt;
}
// Check that the input is consumed.
Expand Down
15 changes: 14 additions & 1 deletion xla/service/gpu/model/indexing_map_serialization_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ TEST_F(IndexingMapSerializationTest, DimsOnly) {
(d0, d1) -> (d0 mod 2 + d1),
domain:
d0 in [0, 3],
d1 in [0, 4],
d1 in [-4, 4],
is_simplified: true
)");
}
Expand Down Expand Up @@ -88,6 +88,19 @@ TEST_F(IndexingMapSerializationTest, DimsAndSymbolsAndConstraints) {
)");
}

TEST_F(IndexingMapSerializationTest, AffineExprsWithParens) {
ParseAndCheck(R"(
(d0, d1)[s0, s1] -> ((d0 + d0 mod 3) floordiv 3
+ s0 + (s0 * 2) mod 3 + (d0 + s0) mod 3),
domain:
d0 in [0, 9],
d1 in [0, 19],
s0 in [0, 29],
s1 in [0, 39],
is_simplified: false
)");
}

// This test will be updated when the printing uses types of variables.
TEST_F(IndexingMapSerializationTest, CustomNames) {
auto indexing_map_str = R"(
Expand Down
Loading

0 comments on commit b5216ad

Please sign in to comment.