diff --git a/xla/service/gpu/model/BUILD b/xla/service/gpu/model/BUILD index 0e5790f34acd7..cb2f333b4156a 100644 --- a/xla/service/gpu/model/BUILD +++ b/xla/service/gpu/model/BUILD @@ -506,6 +506,7 @@ xla_cc_test( deps = [ ":affine_map_printer", ":indexing_analysis", + ":indexing_map_serialization", ":indexing_test_utils", "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", @@ -513,6 +514,7 @@ xla_cc_test( "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/hash:hash_testing", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest", "@llvm-project//mlir:IR", diff --git a/xla/service/gpu/model/indexing_map_serialization.cc b/xla/service/gpu/model/indexing_map_serialization.cc index af48bcb1f318d..31e7bfaa53ec2 100644 --- a/xla/service/gpu/model/indexing_map_serialization.cc +++ b/xla/service/gpu/model/indexing_map_serialization.cc @@ -209,20 +209,25 @@ bool Parser::ParseInterval(Interval* interval) { bool Parser::ParseAffineExprString(std::string* affine_expr_str) { unsigned num_unmatched_parens = 0; while (true) { - if (!IsPartOfAffineExpr(current_token_)) { - if (ConsumeToken(Token::Kind::kLParen)) { - ++num_unmatched_parens; - } else if (current_token_.kind == Token::Kind::kRParen && - num_unmatched_parens > 0) { - --num_unmatched_parens; - Advance(); - } else { - break; - } + if (IsPartOfAffineExpr(current_token_)) { + affine_expr_str->append(current_token_.spelling); + affine_expr_str->push_back(' '); + Advance(); + continue; } - affine_expr_str->append(current_token_.spelling); - affine_expr_str->push_back(' '); - Advance(); + if (ConsumeToken(Token::Kind::kLParen)) { + affine_expr_str->push_back('('); + ++num_unmatched_parens; + continue; + } + if (current_token_.kind == Token::Kind::kRParen && + num_unmatched_parens > 0) { + affine_expr_str->push_back(')'); + --num_unmatched_parens; + Advance(); + continue; + } + break; } return current_token_.kind != Token::Kind::kError; } @@ -302,11 +307,20 @@ Token Parser::GetNextTokenImpl() { } if (*it_ == '-') { ++it_; - if (it_ != input_.end() && *it_ == '>') { - ++it_; - return Token{"->", Token::Kind::kArrow}; - } else { - return Token{"-", Token::Kind::kMinus}; + if (it_ != input_.end()) { + if (*it_ == '>') { + ++it_; + return Token{"->", Token::Kind::kArrow}; + } else if (std::isdigit(*it_)) { + auto start = it_ - 1; + while (it_ != input_.end() && std::isdigit(*it_)) { + ++it_; + } + StringRef spelling = input_.substr(start - input_.data(), it_ - start); + return Token{spelling, Token::Kind::kIntLiteral}; + } else { + return Token{"-", Token::Kind::kMinus}; + } } } StringRef spelling = input_.substr(start - input_.data(), 1); @@ -407,6 +421,7 @@ std::optional ParseIndexingMap(llvm::StringRef input, if (!parser.ConsumeToken(Token::Kind::kComma) || !parser.ConsumeToken(Token::Kind::kKeywordDomain) || !parser.ConsumeToken(Token::Kind::kColon)) { + llvm::errs() << "Failed to parse domain keyword\n"; return std::nullopt; } // Parse dimension variables. @@ -418,9 +433,11 @@ std::optional ParseIndexingMap(llvm::StringRef input, !parser.ConsumeToken(Token::Kind::kKeywordIn) || !parser.ParseInterval(&interval) || !parser.ConsumeToken(Token::Kind::kComma)) { + llvm::errs() << "Failed to parse DimVar\n"; return std::nullopt; } if (var_name != dim_name) { + llvm::errs() << "Dimension name mismatch\n"; return std::nullopt; } dim_vars.push_back(DimVar{interval}); @@ -434,9 +451,11 @@ std::optional ParseIndexingMap(llvm::StringRef input, !parser.ConsumeToken(Token::Kind::kKeywordIn) || !parser.ParseInterval(&interval) || !parser.ConsumeToken(Token::Kind::kComma)) { + llvm::errs() << "Failed to parse RangeVar\n"; return std::nullopt; } if (var_name != symbol_var) { + llvm::errs() << "Symbol name mismatch\n"; return std::nullopt; } range_vars.push_back(RangeVar{interval}); @@ -450,6 +469,7 @@ std::optional ParseIndexingMap(llvm::StringRef input, !parser.ConsumeToken(Token::Kind::kKeywordIn) || !parser.ParseInterval(&interval) || !parser.ConsumeToken(Token::Kind::kComma)) { + llvm::errs() << "Failed to parse constraint\n"; return std::nullopt; } affine_expr_strs.push_back(affine_expr_str); @@ -459,6 +479,7 @@ std::optional ParseIndexingMap(llvm::StringRef input, bool is_simplified; if (!parser.ConsumeToken(Token::Kind::kColon) || !parser.ParseBool(&is_simplified)) { + llvm::errs() << "Failed to parse is_simplified\n"; return std::nullopt; } // Check that the input is consumed. diff --git a/xla/service/gpu/model/indexing_map_serialization_test.cc b/xla/service/gpu/model/indexing_map_serialization_test.cc index 7efd04b544280..c7d39e8c8690f 100644 --- a/xla/service/gpu/model/indexing_map_serialization_test.cc +++ b/xla/service/gpu/model/indexing_map_serialization_test.cc @@ -45,7 +45,7 @@ TEST_F(IndexingMapSerializationTest, DimsOnly) { (d0, d1) -> (d0 mod 2 + d1), domain: d0 in [0, 3], - d1 in [0, 4], + d1 in [-4, 4], is_simplified: true )"); } @@ -88,6 +88,19 @@ TEST_F(IndexingMapSerializationTest, DimsAndSymbolsAndConstraints) { )"); } +TEST_F(IndexingMapSerializationTest, AffineExprsWithParens) { + ParseAndCheck(R"( + (d0, d1)[s0, s1] -> ((d0 + d0 mod 3) floordiv 3 + + s0 + (s0 * 2) mod 3 + (d0 + s0) mod 3), + domain: + d0 in [0, 9], + d1 in [0, 19], + s0 in [0, 29], + s1 in [0, 39], + is_simplified: false + )"); +} + // This test will be updated when the printing uses types of variables. TEST_F(IndexingMapSerializationTest, CustomNames) { auto indexing_map_str = R"( diff --git a/xla/service/gpu/model/indexing_map_test.cc b/xla/service/gpu/model/indexing_map_test.cc index c7bd056b072a5..c3b8669f1b46c 100644 --- a/xla/service/gpu/model/indexing_map_test.cc +++ b/xla/service/gpu/model/indexing_map_test.cc @@ -27,12 +27,14 @@ limitations under the License. #include #include "absl/hash/hash_testing.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/MLIRContext.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/gpu/model/affine_map_printer.h" +#include "xla/service/gpu/model/indexing_map_serialization.h" #include "xla/service/gpu/model/indexing_test_utils.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/verified_hlo_module.h" @@ -49,6 +51,12 @@ using ::testing::ElementsAre; class IndexingMapTest : public HloTestBase { public: + IndexingMap Parse(absl::string_view indexing_map_str) { + auto indexing_map = ParseIndexingMap(indexing_map_str, &mlir_context_); + EXPECT_TRUE(indexing_map.has_value()); + return *indexing_map; + } + mlir::MLIRContext mlir_context_; AffineMapPrinter printer_; }; @@ -112,10 +120,15 @@ TEST_F(IndexingMapTest, RTVar) { } TEST_F(IndexingMapTest, Evaluation) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1, s0)", &mlir_context_), - {4, 4}, {2, 2}); - + IndexingMap indexing_map = Parse(R"( + (d0, d1)[s0, s1] -> (d1, d0, s1, s0), + domain: + d0 in [0, 3], + d1 in [0, 3], + s0 in [0, 1], + s1 in [0, 1], + is_simplified: false + )"); auto results = indexing_map.Evaluate( mlir::getAffineConstantExprs({1, 2}, &mlir_context_), mlir::getAffineConstantExprs({3, 4}, &mlir_context_)); @@ -136,12 +149,23 @@ TEST_F(IndexingMapTest, Evaluation) { } TEST_F(IndexingMapTest, Composition_Permutation) { - IndexingMap producer = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1, s0)", &mlir_context_), - {4, 4}, {2, 2}); - - IndexingMap consumer = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0)[s0] -> (d0, s0)", &mlir_context_), {4}, {4}); + IndexingMap producer = Parse(R"( + (d0, d1)[s0, s1] -> (d1, d0, s1, s0), + domain: + d0 in [0, 3], + d1 in [0, 3], + s0 in [0, 1], + s1 in [0, 1], + is_simplified: false + )"); + + IndexingMap consumer = Parse(R"( + (d0)[s0] -> (d0, s0), + domain: + d0 in [0, 3], + s0 in [0, 3], + is_simplified: false + )"); auto composed = ComposeIndexingMaps(consumer, producer); EXPECT_THAT(composed, MatchIndexingMap(R"( @@ -156,12 +180,23 @@ TEST_F(IndexingMapTest, Composition_Permutation) { } TEST_F(IndexingMapTest, Composition_RestrictedInterval) { - IndexingMap producer = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1, s0)", &mlir_context_), - {5, 6}, {7, 2}); - - IndexingMap consumer = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0)[s0] -> (d0, s0)", &mlir_context_), {10}, {8}); + IndexingMap producer = Parse(R"( + (d0, d1)[s0, s1] -> (d1, d0, s1, s0), + domain: + d0 in [0, 4], + d1 in [0, 5], + s0 in [0, 6], + s1 in [0, 1], + is_simplified: false + )"); + + IndexingMap consumer = Parse(R"( + (d0)[s0] -> (d0, s0), + domain: + d0 in [0, 9], + s0 in [0, 7], + is_simplified: false + )"); auto composed = ComposeIndexingMaps(consumer, producer); EXPECT_THAT(composed, MatchIndexingMap(R"( @@ -176,20 +211,27 @@ TEST_F(IndexingMapTest, Composition_RestrictedInterval) { } TEST_F(IndexingMapTest, Composition_ProducerAndConsumerHaveConstraints) { - IndexingMap producer = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1, s0)", &mlir_context_), - {50, 60}, {70, 20}); - producer.AddConstraint(ParseAffineExpr("d0 mod 8", &mlir_context_), - Interval{0, 0}); - producer.AddConstraint(ParseAffineExpr("s0 mod 3", &mlir_context_), - Interval{1, 1}); - - IndexingMap consumer = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0)[s0] -> (d0, s0)", &mlir_context_), {10}, {8}); - consumer.AddConstraint(ParseAffineExpr("d0 + s0", &mlir_context_), - Interval{0, 20}); - consumer.AddConstraint(ParseAffineExpr("s0 mod 4", &mlir_context_), - Interval{0, 0}); + IndexingMap producer = Parse(R"( + (d0, d1)[s0, s1] -> (d1, d0, s1, s0), + domain: + d0 in [0, 49], + d1 in [0, 59], + s0 in [0, 69], + s1 in [0, 19], + d0 mod 8 in [0, 0], + s0 mod 3 in [1, 1], + is_simplified: false + )"); + + IndexingMap consumer = Parse(R"( + (d0)[s0] -> (d0, s0), + domain: + d0 in [0, 9], + s0 in [0, 7], + d0 + s0 in [0, 20], + s0 mod 4 in [0, 0], + is_simplified: false + )"); auto composed = ComposeIndexingMaps(consumer, producer); EXPECT_THAT(composed, MatchIndexingMap(R"( @@ -311,14 +353,18 @@ TEST_F(IndexingMapTest, Composition_OnlyRTVars) { } TEST_F(IndexingMapTest, RemoveUnusedVars_ConstraintUsesDim) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (d1, s0, s1)", &mlir_context_), - {50, 60}, {70, 20}); // This constraint cannot be removed, because it contains a dimension. - indexing_map.AddConstraint(ParseAffineExpr("s0 + d0", &mlir_context_), - Interval{1, 100}); - indexing_map.AddConstraint(ParseAffineExpr("s0 mod 3", &mlir_context_), - Interval{0, 0}); + auto indexing_map = Parse(R"( + (d0, d1)[s0, s1] -> (d1, s0, s1), + domain: + d0 in [0, 49], + d1 in [0, 59], + s0 in [0, 69], + s1 in [0, 19], + d0 + s0 in [1, 100], + s0 mod 3 in [0, 0], + is_simplified: false + )"); indexing_map.RemoveUnusedVars(); EXPECT_THAT(indexing_map, MatchIndexingMap(R"( (d0, d1)[s0, s1] -> (d1, s0, s1), @@ -334,12 +380,17 @@ TEST_F(IndexingMapTest, RemoveUnusedVars_ConstraintUsesDim) { } TEST_F(IndexingMapTest, RemoveUnusedVars_ConstraintUsesUnusedDim) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (s0, d1, s1)", &mlir_context_), - {50, 60}, {70, 20}); // This constraint can be removed, because it contains only the unused dim. - indexing_map.AddConstraint(ParseAffineExpr("d0 mod 3", &mlir_context_), - Interval{0, 0}); + auto indexing_map = Parse(R"( + (d0, d1)[s0, s1] -> (s0, d1, s1), + domain: + d0 in [0, 49], + d1 in [0, 59], + s0 in [0, 69], + s1 in [0, 19], + d0 mod 3 in [0, 0], + is_simplified: false + )"); indexing_map.RemoveUnusedVars(); EXPECT_THAT(indexing_map, MatchIndexingMap(R"( (d0)[s0, s1] -> (s0, d0, s1), @@ -352,12 +403,17 @@ TEST_F(IndexingMapTest, RemoveUnusedVars_ConstraintUsesUnusedDim) { } TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintUsesOnlyUnusedSym) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (d0, d1, s1)", &mlir_context_), - {50, 60}, {70, 20}); // This constraint can be removed, because it contains only the unused symbol. - indexing_map.AddConstraint(ParseAffineExpr("s0 mod 3", &mlir_context_), - Interval{0, 0}); + auto indexing_map = Parse(R"( + (d0, d1)[s0, s1] -> (d0, d1, s1), + domain: + d0 in [0, 49], + d1 in [0, 59], + s0 in [0, 69], + s1 in [0, 19], + s0 mod 3 in [0, 0], + is_simplified: false + )"); indexing_map.RemoveUnusedSymbols(); EXPECT_THAT(indexing_map, MatchIndexingMap(R"( (d0, d1)[s0] -> (d0, d1, s0), @@ -370,17 +426,23 @@ TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintUsesOnlyUnusedSym) { } TEST_F(IndexingMapTest, RemoveUnusedVars_ConstraintsWithManyDims) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap( - "(d0, d1, d2, d3, d4)[s0, s1, s2] -> (s0 * 4 + d1 + d3 - 42)", - &mlir_context_), - {1, 2, 3, 4, 5}, {32, 64, 96}); - indexing_map.AddConstraint( - ParseAffineExpr("s0 * 4 + d1 + d3", &mlir_context_), Interval{24, 459}); - indexing_map.AddConstraint(ParseAffineExpr("s0 + s2", &mlir_context_), - Interval{0, 512}); - auto unused_vars = indexing_map.RemoveUnusedVars(); + auto indexing_map = Parse(R"( + (d0, d1, d2, d3, d4)[s0, s1, s2] -> (s0 * 4 + d1 + d3 - 42), + domain: + d0 in [0, 0], + d1 in [0, 1], + d2 in [0, 2], + d3 in [0, 3], + d4 in [0, 4], + s0 in [0, 31], + s1 in [0, 63], + s2 in [0, 95], + s0 * 4 + d1 + d3 in [24, 459], + s0 + s2 in [0, 512], + is_simplified: false + )"); // dimensions d0, d2, d4 and symbol s1 will be removed. + auto unused_vars = indexing_map.RemoveUnusedVars(); EXPECT_THAT(indexing_map, MatchIndexingMap(R"( (d0, d1)[s0, s1] -> (d0 + s0 * 4 + d1 - 42), domain: @@ -398,14 +460,18 @@ TEST_F(IndexingMapTest, RemoveUnusedVars_ConstraintsWithManyDims) { } TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintUsesSymbol) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1)", &mlir_context_), - {50, 60}, {70, 20}); + auto indexing_map = Parse(R"( + (d0, d1)[s0, s1] -> (d1, d0, s1), + domain: + d0 in [0, 49], + d1 in [0, 59], + s0 in [0, 69], + s1 in [0, 19], + s0 + s1 in [1, 100], + s0 mod 3 in [0, 0], + is_simplified: false + )"); // This constraint cannot be removed, because it contains a "used symbol". - indexing_map.AddConstraint(ParseAffineExpr("s0 + s1", &mlir_context_), - Interval{1, 100}); - indexing_map.AddConstraint(ParseAffineExpr("s0 mod 3", &mlir_context_), - Interval{0, 0}); indexing_map.RemoveUnusedSymbols(); EXPECT_THAT(indexing_map, MatchIndexingMap(R"( (d0, d1)[s0, s1] -> (d1, d0, s1), @@ -421,12 +487,17 @@ TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintUsesSymbol) { } TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintUsesOnlyUnusedSymbols) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1)", &mlir_context_), - {50, 60}, {70, 20}); + auto indexing_map = Parse(R"( + (d0, d1)[s0, s1] -> (d1, d0, s1), + domain: + d0 in [0, 49], + d1 in [0, 59], + s0 in [0, 69], + s1 in [0, 19], + s0 mod 3 in [0, 0], + is_simplified: false + )"); // This constraint can be removed, because it contains only the unused symbol. - indexing_map.AddConstraint(ParseAffineExpr("s0 mod 3", &mlir_context_), - Interval{0, 0}); indexing_map.RemoveUnusedSymbols(); EXPECT_THAT(indexing_map, MatchIndexingMap(R"( (d0, d1)[s0] -> (d1, d0, s0), @@ -439,10 +510,13 @@ TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintUsesOnlyUnusedSymbols) { } TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintIsAConstantWithinRange) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0) -> (d0)", &mlir_context_), {50}, {}); - indexing_map.AddConstraint(ParseAffineExpr("0", &mlir_context_), - Interval{-10, 5}); + auto indexing_map = Parse(R"( + (d0) -> (d0), + domain: + d0 in [0, 49], + 0 in [-10, 5], + is_simplified: false + )"); EXPECT_THAT(indexing_map, MatchIndexingMap(R"( (d0) -> (d0), domain: @@ -452,25 +526,34 @@ TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintIsAConstantWithinRange) { } TEST_F(IndexingMapTest, KnownEmpty_CreatingIndexingMapWithInfeasibleRange) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0) -> (d0)", &mlir_context_), {-1}, {}); + auto indexing_map = Parse(R"( + (d0) -> (d0), + domain: + d0 in [0, -2], + is_simplified: false + )"); EXPECT_THAT(indexing_map, MatchIndexingMap("KNOWN EMPTY")); } TEST_F(IndexingMapTest, KnownEmpty_AddingConstraintOutOfRange) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0) -> (d0)", &mlir_context_), {50}, {}); + auto indexing_map = Parse(R"( + (d0) -> (d0), + domain: + d0 in [0, 49], + 0 in [10, 15], + is_simplified: false + )"); // Addition of this constraint makes the domain empty. - indexing_map.AddConstraint(ParseAffineExpr("0", &mlir_context_), - Interval{10, 15}); EXPECT_THAT(indexing_map, MatchIndexingMap("KNOWN EMPTY")); } TEST_F(IndexingMapTest, KnownEmpty_Composition) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0) -> (d0)", &mlir_context_), {50}, {}); - IndexingMap known_empty = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0) -> (0)", &mlir_context_), {0}, {}); + auto indexing_map = Parse(R"( + (d0) -> (d0), domain: d0 in [0, 49], is_simplified: false + )"); + auto known_empty = Parse(R"( + (d0) -> (d0), domain: d0 in [0, -1], is_simplified: false + )"); EXPECT_THAT(known_empty, MatchIndexingMap("KNOWN EMPTY")); EXPECT_THAT(indexing_map * known_empty, MatchIndexingMap("KNOWN EMPTY")); EXPECT_THAT(known_empty * indexing_map, MatchIndexingMap("KNOWN EMPTY")); @@ -480,22 +563,33 @@ TEST_F(IndexingMapTest, KnownEmpty_Composition) { TEST_F(IndexingMapTest, KnownEmpty_AddingConstraintOutOfRangeAfterSimplification) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1)", &mlir_context_), - {50, 60}, {70, 20}); - indexing_map.AddConstraint(ParseAffineExpr("s1 floordiv 20", &mlir_context_), - Interval{2, 2}); + auto indexing_map = Parse(R"( + (d0, d1)[s0, s1] -> (d1, d0, s1), + domain: + d0 in [0, 49], + d1 in [0, 59], + s0 in [0, 69], + s1 in [0, 19], + s1 floordiv 20 in [2, 2], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map, MatchIndexingMap("KNOWN EMPTY")); } TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintsWithManySymbols) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0)[s0, s1, s2, s3, s4] -> (d0 * 4 + s1 + s3 - 42)", - &mlir_context_), - {32}, {1, 2, 3, 4, 5}); - indexing_map.AddConstraint( - ParseAffineExpr("d0 * 4 + s1 + s3", &mlir_context_), Interval{24, 459}); + auto indexing_map = Parse(R"( + (d0)[s0, s1, s2, s3, s4] -> (d0 * 4 + s1 + s3 - 42), + domain: + d0 in [0, 31], + s0 in [0, 0], + s1 in [0, 1], + s2 in [0, 2], + s3 in [0, 3], + s4 in [0, 4], + d0 * 4 + s1 + s3 in [24, 459], + is_simplified: false + )"); indexing_map.RemoveUnusedSymbols(); // Symbols s0, s2, s4 will be removed and s1 and s3 will become s0 and s1. EXPECT_THAT(indexing_map, MatchIndexingMap(R"( @@ -562,11 +656,13 @@ TEST_F(IndexingMapTest, ConvertSymbolsToDimensions) { } TEST_F(IndexingMapTest, ConstraintIntervalSimplification_Sum) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0) -> (d0)", &mlir_context_), {100}, {}); - - indexing_map.AddConstraint(ParseAffineExpr("(d0 mod 8) + 5", &mlir_context_), - Interval{50, 54}); + auto indexing_map = Parse(R"( + (d0) -> (d0), + domain: + d0 in [0, 99], + d0 mod 8 + 5 in [50, 54], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( (d0) -> (d0), @@ -579,13 +675,15 @@ TEST_F(IndexingMapTest, ConstraintIntervalSimplification_Sum) { TEST_F(IndexingMapTest, ConstraintIntervalSimplification_Sum_IndependentOfSymbol) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0)[s0, s1] -> (d0 * 6 + s0 * 3 + s1)", &mlir_context_), - {2000}, {2, 3}); - - indexing_map.AddConstraint( - ParseAffineExpr("d0 * 6 + s0 * 3 + s1", &mlir_context_), - Interval{0, 599}); + auto indexing_map = Parse(R"( + (d0)[s0, s1] -> (d0 * 6 + s0 * 3 + s1), + domain: + d0 in [0, 1999], + s0 in [0, 1], + s1 in [0, 2], + d0 * 6 + s0 * 3 + s1 in [0, 599], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( (d0)[s0, s1] -> (d0 * 6 + s0 * 3 + s1), @@ -599,23 +697,27 @@ TEST_F(IndexingMapTest, TEST_F(IndexingMapTest, ConstraintIntervalSimplification_Sum_NotIndependentOfSymbol) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0)[s0, s1] -> (d0 * 6 + s0 * 3 + s1)", &mlir_context_), - {2000}, {2, 3}); - - indexing_map.AddConstraint( - ParseAffineExpr("d0 * 6 + s0 * 3 + s1", &mlir_context_), - Interval{0, 598}); + auto indexing_map = Parse(R"( + (d0)[s0, s1] -> (d0 * 6 + s0 * 3 + s1), + domain: + d0 in [0, 1999], + s0 in [0, 1], + s1 in [0, 2], + d0 * 6 + s0 * 3 + s1 in [0, 598], + is_simplified: false + )"); EXPECT_FALSE(indexing_map.Simplify()); } TEST_F(IndexingMapTest, ConstraintIntervalSimplification_Sum_GcdGreaterOne) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0)[s0] -> (d0 * 6 + s0 * 3)", &mlir_context_), {2000}, - {2}); - - indexing_map.AddConstraint(ParseAffineExpr("d0 * 6 + s0 * 3", &mlir_context_), - Interval{0, 599}); + auto indexing_map = Parse(R"( + (d0)[s0] -> (d0 * 6 + s0 * 3), + domain: + d0 in [0, 1999], + s0 in [0, 1], + d0 * 6 + s0 * 3 in [0, 599], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( (d0)[s0] -> (d0 * 6 + s0 * 3), @@ -628,11 +730,13 @@ TEST_F(IndexingMapTest, ConstraintIntervalSimplification_Sum_GcdGreaterOne) { TEST_F(IndexingMapTest, ConstraintIntervalSimplification_FloorDivPositiveDivisorPositiveBounds) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0) -> (d0)", &mlir_context_), {100}, {}); - - indexing_map.AddConstraint(ParseAffineExpr("d0 floordiv 8", &mlir_context_), - Interval{5, 11}); + auto indexing_map = Parse(R"( + (d0) -> (d0), + domain: + d0 in [0, 99], + d0 floordiv 8 in [5, 11], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( (d0) -> (d0), @@ -644,12 +748,14 @@ TEST_F(IndexingMapTest, TEST_F(IndexingMapTest, ConstraintIntervalSimplification_FloorDivPositiveDivisorNegativeBounds) { - IndexingMap indexing_map = - IndexingMap(ParseAffineMap("(d0)[s0] -> (d0)", &mlir_context_), - {DimVar{{0, 99}}}, {RangeVar{{-99, 99}}}, /*rt_vars=*/{}); - - indexing_map.AddConstraint(ParseAffineExpr("s0 floordiv 3", &mlir_context_), - Interval{-11, -5}); + auto indexing_map = Parse(R"( + (d0)[s0] -> (d0), + domain: + d0 in [0, 99], + s0 in [-99, 99], + s0 floordiv 3 in [-11, -5], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( (d0)[s0] -> (d0), @@ -662,12 +768,14 @@ TEST_F(IndexingMapTest, TEST_F(IndexingMapTest, ConstraintIntervalSimplification_FloorDivNegativeDivisorNegativeBounds) { - IndexingMap indexing_map = - IndexingMap(ParseAffineMap("(d0)[s0] -> (d0)", &mlir_context_), - {DimVar{{0, 99}}}, {RangeVar{{-99, 99}}}, /*rt_vars=*/{}); - - indexing_map.AddConstraint(ParseAffineExpr("s0 floordiv -3", &mlir_context_), - Interval{-11, -5}); + auto indexing_map = Parse(R"( + (d0)[s0] -> (d0), + domain: + d0 in [0, 99], + s0 in [-99, 99], + s0 floordiv -3 in [-11, -5], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( (d0)[s0] -> (d0), @@ -680,11 +788,13 @@ TEST_F(IndexingMapTest, TEST_F(IndexingMapTest, ConstraintIntervalSimplification_MulPositiveMultiplierPositiveBounds) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0) -> (d0)", &mlir_context_), {100}, {}); - - indexing_map.AddConstraint(ParseAffineExpr("d0 * 8", &mlir_context_), - Interval{14, 33}); + auto indexing_map = Parse(R"( + (d0) -> (d0), + domain: + d0 in [0, 99], + d0 * 8 in [14, 33], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( (d0) -> (d0), @@ -696,12 +806,14 @@ TEST_F(IndexingMapTest, TEST_F(IndexingMapTest, ConstraintIntervalSimplification_MulPositiveMultiplierNegativeBounds) { - IndexingMap indexing_map = - IndexingMap(ParseAffineMap("(d0)[s0] -> (d0)", &mlir_context_), - {DimVar{{0, 99}}}, {RangeVar{{-99, 99}}}, /*rt_vars=*/{}); - - indexing_map.AddConstraint(ParseAffineExpr("s0 * 3", &mlir_context_), - Interval{-11, -5}); + auto indexing_map = Parse(R"( + (d0)[s0] -> (d0), + domain: + d0 in [0, 99], + s0 in [-99, 99], + s0 * 3 in [-11, -5], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( (d0)[s0] -> (d0), @@ -714,12 +826,14 @@ TEST_F(IndexingMapTest, TEST_F(IndexingMapTest, ConstraintIntervalSimplification_MulNegativeMultiplierNegativeBounds) { - IndexingMap indexing_map = - IndexingMap(ParseAffineMap("(d0)[s0] -> (d0)", &mlir_context_), - {DimVar{{0, 99}}}, {RangeVar{{-99, 99}}}, /*rt_vars=*/{}); - - indexing_map.AddConstraint(ParseAffineExpr("s0 * -3", &mlir_context_), - Interval{-11, -5}); + auto indexing_map = Parse(R"( + (d0)[s0] -> (d0), + domain: + d0 in [0, 99], + s0 in [-99, 99], + s0 * -3 in [-11, -5], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( (d0)[s0] -> (d0), @@ -731,20 +845,19 @@ TEST_F(IndexingMapTest, } TEST_F(IndexingMapTest, ConstraintMerge_Mod) { - IndexingMap indexing_map( - ParseAffineMap("(d0)[s0, s1] -> (d0, s1, s0)", &mlir_context_), - {DimVar{{0, 4}}}, {RangeVar{{-21, -1}}, RangeVar{{0, 10}}}, - /*rt_vars=*/{}); - indexing_map.AddConstraint(ParseAffineExpr("d0 mod 3", &mlir_context_), - Interval{0, 0}); - indexing_map.AddConstraint(ParseAffineExpr("s0 mod 2", &mlir_context_), - Interval{0, 0}); - indexing_map.AddConstraint(ParseAffineExpr("s0 mod 3", &mlir_context_), - Interval{0, 0}); - indexing_map.AddConstraint(ParseAffineExpr("s1 mod 5", &mlir_context_), - Interval{1, 1}); + auto indexing_map = Parse(R"( + (d0)[s0, s1] -> (d0, s1, s0), + domain: + d0 in [0, 3], + s0 in [-21, -2], + s1 in [0, 10], + d0 mod 3 in [0, 0], + s0 mod 2 in [0, 0], + s0 mod 3 in [0, 0], + s1 mod 5 in [1, 1], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( (d0)[s0, s1] -> (d0, s1, s0), domain: @@ -759,9 +872,12 @@ TEST_F(IndexingMapTest, ConstraintMerge_Mod) { } TEST_F(IndexingMapTest, AffineMapSimplification_ConstantDims) { - IndexingMap indexing_map = - IndexingMap(ParseAffineMap("(d0) -> (d0)", &mlir_context_), - {DimVar{{5, 5}}}, /*range_vars=*/{}, /*rt_vars=*/{}); + auto indexing_map = Parse(R"( + (d0) -> (d0), + domain: + d0 in [5, 5], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0) -> (5), @@ -774,11 +890,16 @@ TEST_F(IndexingMapTest, AffineMapSimplification_ConstantDims) { TEST_F(IndexingMapTest, AffineMapSimplification_SumOrderRegression) { // This is a regression test for a bug where we didn't canonicalize the order // of summands correctly, leading to `Simplify` not being idempotent. - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (((((d0 + (d0 mod 3)) floordiv 3) + " - "(s0 + ((s0 + s0) mod 3))) + (((d0 + s0) mod 3) + 0)))", - &mlir_context_), - {10, 20}, {30, 40}); + auto indexing_map = Parse(R"( + (d0, d1)[s0, s1] -> (((((d0 + (d0 mod 3)) floordiv 3) + + (s0 + ((s0 + s0) mod 3))) + (((d0 + s0) mod 3) + 0))), + domain: + d0 in [0, 9], + d1 in [0, 19], + s0 in [0, 29], + s1 in [0, 39], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_FALSE(indexing_map.Simplify()); } @@ -786,20 +907,25 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SumOrderRegression) { TEST_F(IndexingMapTest, AffineMapSimplification_SumOrderRegression2) { // This is a regression test for a bug where we didn't simplify the affine // expression fully after a single iteration. - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0)[s0] -> ((((s0 + d0) + d0) floordiv 2))", - &mlir_context_), - {10, 20}, {30, 40}); + auto indexing_map = Parse(R"( + (d0)[s0] -> ((((s0 + d0) + d0) floordiv 2)), + domain: + d0 in [0, 9], + s0 in [0, 19], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_FALSE(indexing_map.Simplify()); } TEST_F(IndexingMapTest, AffineMapSimplification_FloorDivRegression) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap( - "(d0, d1) -> (((d0 floordiv 3) * 3 + d1 floordiv 2) floordiv 6)", - &mlir_context_), - {12, 6}, {}); + auto indexing_map = Parse(R"( + (d0, d1) -> (((d0 floordiv 3) * 3 + d1 floordiv 2) floordiv 6), + domain: + d0 in [0, 11], + d1 in [0, 5], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0, d1) -> (d0 floordiv 6), @@ -811,9 +937,12 @@ TEST_F(IndexingMapTest, AffineMapSimplification_FloorDivRegression) { } TEST_F(IndexingMapTest, AffineMapSimplification_ModIsSub) { - IndexingMap indexing_map( - ParseAffineMap("(d0) -> (d0 mod 42)", &mlir_context_), {{53, 71}}, {}, - {}); + auto indexing_map = Parse(R"( + (d0) -> (d0 mod 42), + domain: + d0 in [53, 71], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0) -> (d0 - 42), @@ -824,8 +953,12 @@ TEST_F(IndexingMapTest, AffineMapSimplification_ModIsSub) { } TEST_F(IndexingMapTest, AffineMapSimplification_ModIsAdd) { - IndexingMap indexing_map(ParseAffineMap("(d0) -> (d0 mod 5)", &mlir_context_), - {{-5, -1}}, {}, {}); + auto indexing_map = Parse(R"( + (d0) -> (d0 mod 5), + domain: + d0 in [-5, -1], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0) -> (d0 + 5), @@ -836,19 +969,22 @@ TEST_F(IndexingMapTest, AffineMapSimplification_ModIsAdd) { } TEST_F(IndexingMapTest, AffineMapSimplification_ModIsNotAdd) { - IndexingMap indexing_map1( - ParseAffineMap("(d0) -> (d0 mod 5)", &mlir_context_), {{-4, 0}}, {}, {}); + auto indexing_map1 = + Parse("(d0) -> (d0 mod 5), domain: d0 in [-4, 0], is_simplified: false"); EXPECT_FALSE(indexing_map1.Simplify()); - IndexingMap indexing_map2( - ParseAffineMap("(d0) -> (d0 mod 5)", &mlir_context_), {{-6, -1}}, {}, {}); + auto indexing_map2 = + Parse("(d0) -> (d0 mod 5), domain: d0 in [-6, -1], is_simplified: false"); EXPECT_FALSE(indexing_map2.Simplify()); } TEST_F(IndexingMapTest, AffineMapSimplification_SubIsMod) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0)[s0] -> (d0 - (s0 floordiv 3) * 3 + s0)", - &mlir_context_), - {2}, {4}); + auto indexing_map = Parse(R"( + (d0)[s0] -> (d0 - (s0 floordiv 3) * 3 + s0), + domain: + d0 in [0, 1], + s0 in [0, 3], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0)[s0] -> (d0 + s0 mod 3), @@ -860,10 +996,13 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SubIsMod) { } TEST_F(IndexingMapTest, AffineMapSimplification_SubIsModMultiplied) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0)[s0] -> (d0 - (s0 floordiv 3) * 12 + s0 * 7)", - &mlir_context_), - {2}, {4}); + auto indexing_map = Parse(R"( + (d0)[s0] -> (d0 - (s0 floordiv 3) * 12 + s0 * 7), + domain: + d0 in [0, 1], + s0 in [0, 3], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0)[s0] -> (d0 + (s0 mod 3) * 4 + s0 * 3), @@ -875,10 +1014,13 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SubIsModMultiplied) { } TEST_F(IndexingMapTest, AffineMapSimplification_SubIsModSum) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0)[s0] -> (1 + d0 - ((s0 + 1) floordiv 3) * 3 + s0)", - &mlir_context_), - {2}, {4}); + auto indexing_map = Parse(R"( + (d0)[s0] -> (1 + d0 - ((s0 + 1) floordiv 3) * 3 + s0), + domain: + d0 in [0, 1], + s0 in [0, 3], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0)[s0] -> (d0 + (s0 + 1) mod 3), @@ -891,9 +1033,13 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SubIsModSum) { TEST_F(IndexingMapTest, AffineMapSimplification_DivsAndModsIfSmallerThanDivisor) { - auto serialized_map = "(d0, d1) -> (d0 + d1 floordiv 16, d1 mod 16)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {8, 16}, {}); + auto indexing_map = Parse(R"( + (d0, d1) -> (d0 + d1 floordiv 16, d1 mod 16), + domain: + d0 in [0, 7], + d1 in [0, 15], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0, d1) -> (d0, d1), @@ -905,15 +1051,17 @@ TEST_F(IndexingMapTest, } TEST_F(IndexingMapTest, AffineMapSimplification_DivsAndModsWithMultipliers) { - auto serialized_map = - "(d0, d1, d2) -> ((d0 * 100 + d1 * 10 + d2) floordiv 100, " - "((d0 * 100 + d1 * 10 + d2) mod 100) floordiv 10, " - "d2 mod 10)"; - - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {9, 9, 9}, {}); + auto indexing_map = Parse(R"( + (d0, d1, d2) -> ((d0 * 100 + d1 * 10 + d2) floordiv 100, + ((d0 * 100 + d1 * 10 + d2) mod 100) floordiv 10, + d2 mod 10), + domain: + d0 in [0, 8], + d1 in [0, 8], + d2 in [0, 8], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0, d1, d2) -> (d0, d1, d2), domain: @@ -926,12 +1074,15 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivsAndModsWithMultipliers) { TEST_F(IndexingMapTest, AffineMapSimplification_DivsAndModsWithDivisibleMultipliers) { - auto serialized_map = - "(d0, d1, d2) -> ((d0 * 16 + d1 * 4 + d2) floordiv 8, " - " (d0 * 16 + d1 * 4 + d2) mod 8)"; - - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {10, 10, 10}, {}); + auto indexing_map = Parse(R"( + (d0, d1, d2) -> ((d0 * 16 + d1 * 4 + d2) floordiv 8, + (d0 * 16 + d1 * 4 + d2) mod 8), + domain: + d0 in [0, 9], + d1 in [0, 9], + d2 in [0, 9], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0, d1, d2) -> (d0 * 2 + (d1 * 4 + d2) floordiv 8, @@ -945,11 +1096,14 @@ TEST_F(IndexingMapTest, } TEST_F(IndexingMapTest, AffineMapSimplification_DivsAndModsWithReverse) { - auto serialized_map = - "(d0, d1) -> (-((d0 * -11 - d1 + 109) floordiv 11) + 9, " - "d0 * 11 + d1 + ((d0 * -11 - d1 + 109) floordiv 11) * 11 - 99)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {8, 9}, {}); + auto indexing_map = Parse(R"( + (d0, d1) -> (-((d0 * -11 - d1 + 109) floordiv 11) + 9, + d0 * 11 + d1 + ((d0 * -11 - d1 + 109) floordiv 11) * 11 - 99), + domain: + d0 in [0, 7], + d1 in [0, 8], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0, d1) -> (d0, d1), @@ -961,10 +1115,12 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivsAndModsWithReverse) { } TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape) { - auto serialized_map = - "()[s0] -> ((s0 * 128) mod 715 + ((s0 * 128) floordiv 715) * 715)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {}, {128}); + auto indexing_map = Parse(R"( + ()[s0] -> ((s0 * 128) mod 715 + ((s0 * 128) floordiv 715) * 715), + domain: + s0 in [0, 127], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( ()[s0] -> (s0 * 128), @@ -975,10 +1131,14 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape) { } TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape2) { - auto serialized_map = - "(d0, d1) -> ((d0 mod 8) * 128 + d1 + (d0 floordiv 8) * 1024)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {1024, 128}, {}); + auto indexing_map = Parse(R"( + (d0, d1) -> ((d0 mod 8) * 128 + d1 + (d0 floordiv 8) * 1024), + domain: + d0 in [0, 1023], + d1 in [0, 127], + is_simplified: false + )"); + ; EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0, d1) -> (d0 * 128 + d1), @@ -990,11 +1150,14 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape2) { } TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape3) { - auto serialized_map = - "(d0, d1) -> (((d1 * 2 + d0 floordiv 64) mod 3) * 256 + (d0 mod 64) * 4 " - "+ ((d1 * 128 + d0) floordiv 192) * 768)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {128, 3072}, {}); + auto indexing_map = Parse(R"( + (d0, d1) -> (((d1 * 2 + d0 floordiv 64) mod 3) * 256 + (d0 mod 64) * 4 + + ((d1 * 128 + d0) floordiv 192) * 768), + domain: + d0 in [0, 127], + d1 in [0, 3071], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0, d1) -> (d0 * 4 + d1 * 512), @@ -1007,9 +1170,12 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape3) { TEST_F(IndexingMapTest, AffineMapSimplification_ModWithNegativeMultiplerDoesNotGetSimplified) { - auto serialized_map = "(d0) -> ((-d0) mod 2)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {128}, {}); + auto indexing_map = Parse(R"( + (d0) -> ((-d0) mod 2), + domain: + d0 in [0, 127], + is_simplified: false + )"); EXPECT_FALSE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0) -> ((-d0) mod 2), @@ -1024,12 +1190,15 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyBitcastAndBack) { // `((d0 * 2 + d1 floordiv 64) floordiv 3) floordiv 1024`. // This test verifies that we can still simplify the map after the // simplification of the floordiv. - auto serialized_map = - "(d0, d1) -> ((d0 floordiv 1536) * 786432 + (((d0 * 2 + d1 floordiv " - "64) floordiv 3) mod 1024) * 768 + ((d0 * 2 + d1 floordiv 64) mod 3) * " - "256 + (d1 mod 64) * 4)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {3072, 128}, {}); + auto indexing_map = Parse(R"( + (d0, d1) -> ((d0 floordiv 1536) * 786432 + + (((d0 * 2 + d1 floordiv 64) floordiv 3) mod 1024) * 768 + + ((d0 * 2 + d1 floordiv 64) mod 3) * 256 + (d1 mod 64) * 4), + domain: + d0 in [0, 3071], + d1 in [0, 127], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0, d1) -> (d0 * 512 + d1 * 4), @@ -1042,10 +1211,12 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyBitcastAndBack) { TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape_Regression) { // We have s0 * 128 in the mod, but s0 * 64 in the floordiv *. - auto serialized_map = - "()[s0] -> ((s0 * 128) mod 715 + ((s0 * 64) floordiv 715) * 715)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {}, {128}); + auto indexing_map = Parse(R"( + ()[s0] -> ((s0 * 128) mod 715 + ((s0 * 64) floordiv 715) * 715), + domain: + s0 in [0, 127], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( ()[s0] -> (((s0 * 64) floordiv 715) * 715 + (s0 * 128) mod 715), @@ -1056,11 +1227,12 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape_Regression) { } TEST_F(IndexingMapTest, AffineMapSimplification_DivsInSequence) { - auto serialized_map = - "()[s0] -> (s0 - ((s0 floordiv 2) floordiv 7) * 14 + (s0 floordiv 14) * " - "14)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {}, {1234}); + auto indexing_map = Parse(R"( + ()[s0] -> (s0 - ((s0 floordiv 2) floordiv 7) * 14 + (s0 floordiv 14) * 14), + domain: + s0 in [0, 1233], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( ()[s0] -> (s0), @@ -1071,9 +1243,13 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivsInSequence) { } TEST_F(IndexingMapTest, AffineMapSimplification_DivDiv) { - auto serialized_map = "()[s0, s1] -> ((s0 * 2 + s1 floordiv 64) floordiv 3)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {}, {1234, 128}); + auto indexing_map = Parse(R"( + ()[s0, s1] -> ((s0 * 2 + s1 floordiv 64) floordiv 3), + domain: + s0 in [0, 1233], + s1 in [0, 127], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( ()[s0, s1] -> ((s0 * 128 + s1) floordiv 192), @@ -1085,9 +1261,12 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivDiv) { } TEST_F(IndexingMapTest, AffineMapSimplification_DivSumConstant) { - auto serialized_map = "()[s0] -> ((s0 * 6 + 9) floordiv 18)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {}, {1234}); + auto indexing_map = Parse(R"( + ()[s0] -> ((s0 * 6 + 9) floordiv 18), + domain: + s0 in [0, 1233], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( ()[s0] -> ((s0 * 2 + 3) floordiv 6), @@ -1098,10 +1277,13 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivSumConstant) { } TEST_F(IndexingMapTest, AffineMapSimplification_DivSumDiv) { - auto serialized_map = - "()[s0, s1] -> ((s0 floordiv 3 + s1 floordiv 3) floordiv 6)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {}, {1234, 128}); + auto indexing_map = Parse(R"( + ()[s0, s1] -> ((s0 floordiv 3 + s1 floordiv 3) floordiv 6), + domain: + s0 in [0, 1233], + s1 in [0, 127], + is_simplified: false + )"); // The rewrite tested in AffineMapSimplification_DivDiv must not trigger here. EXPECT_FALSE(indexing_map.Simplify()); } @@ -1110,18 +1292,25 @@ TEST_F(IndexingMapTest, AffineMapSimplification_NegativeDiv) { // (s0 floordiv 2) floordiv -7 is not s0 floordiv -14: // 15 // 2 // -7 = -1 // 15 // -14 = -2 - auto serialized_map = "()[s0] -> ((s0 floordiv 2) floordiv -7)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {}, {1234}); + auto indexing_map = Parse(R"( + ()[s0] -> ((s0 floordiv 2) floordiv -7), + domain: + s0 in [0, 1233], + is_simplified: false + )"); EXPECT_FALSE(indexing_map.Simplify()); } TEST_F(IndexingMapTest, AffineMapSimplification_ExtractFromMod) { - auto serialized_map = - "()[s0, s1, s2, s3] -> ((s0 * 458752 + s1 + s2 * 4 + s3 * 512) mod " - "20000)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {}, {872, 4, 128, 896}); + auto indexing_map = Parse(R"( + ()[s0, s1, s2, s3] -> ((s0 * 458752 + s1 + s2 * 4 + s3 * 512) mod 20000), + domain: + s0 in [0, 871], + s1 in [0, 3], + s2 in [0, 127], + s3 in [0, 895], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( ()[s0, s1, s2, s3] -> ( @@ -1138,11 +1327,14 @@ TEST_F(IndexingMapTest, AffineMapSimplification_ExtractFromMod) { TEST_F(IndexingMapTest, AffineMapSimplification_ExtractFromDiv_NegativeMultiplier) { - auto serialized_map = - "()[s0, s1] -> ((s0 * 16 - (s1 floordiv 4) floordiv 2 + (s1 floordiv 8) " - "* 2) floordiv 4)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {}, {2, 128}); + auto indexing_map = Parse(R"( + ()[s0, s1] -> ((s0 * 16 - (s1 floordiv 4) floordiv 2 + (s1 floordiv 8) * 2) + floordiv 4), + domain: + s0 in [0, 1], + s1 in [0, 127], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( ()[s0, s1] -> ( @@ -1156,12 +1348,16 @@ TEST_F(IndexingMapTest, } TEST_F(IndexingMapTest, RescaleSymbols_Simple) { - auto serialized_map = "(d0)[s0, s1, s2] -> (s2, d0, s1, s0 floordiv 6)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {4}, {7, 2, 6}); - indexing_map.AddConstraint(ParseAffineExpr("s0 mod 6", &mlir_context_), - Interval{0, 0}); - + auto indexing_map = Parse(R"( + (d0)[s0, s1, s2] -> (s2, d0, s1, s0 floordiv 6), + domain: + d0 in [0, 3], + s0 in [0, 6], + s1 in [0, 1], + s2 in [0, 5], + s0 mod 6 in [0, 0], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.RescaleSymbols()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0)[s0, s1, s2] -> (s2, d0, s1, s0), @@ -1175,12 +1371,16 @@ TEST_F(IndexingMapTest, RescaleSymbols_Simple) { } TEST_F(IndexingMapTest, RescaleSymbols_WithShift) { - auto serialized_map = "(d0)[s0, s1, s2] -> (s2, d0, s1, s0)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {4}, {42, 2, 6}); - indexing_map.AddConstraint(ParseAffineExpr("s0 mod 6", &mlir_context_), - Interval{3, 3}); - + auto indexing_map = Parse(R"( + (d0)[s0, s1, s2] -> (s2, d0, s1, s0), + domain: + d0 in [0, 3], + s0 in [0, 41], + s1 in [0, 1], + s2 in [0, 5], + s0 mod 6 in [3, 3], + is_simplified: false + )"); // [BEFORE] Allowed values for s0: 3, 9, 15, ..., 39 = (6 * 6 + 3) // [AFTER] Allowed values for s0: 0, 1, 2, ..., 6 EXPECT_TRUE(indexing_map.RescaleSymbols()); @@ -1196,14 +1396,17 @@ TEST_F(IndexingMapTest, RescaleSymbols_WithShift) { } TEST_F(IndexingMapTest, RescaleSymbols_TwoModConstraints) { - auto serialized_map = "(d0)[s0, s1, s2] -> (s2, d0, s1, s0 floordiv 6)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {4}, {7, 2, 6}); - indexing_map.AddConstraint(ParseAffineExpr("s0 mod 2", &mlir_context_), - Interval{0, 0}); - indexing_map.AddConstraint(ParseAffineExpr("s0 mod 3", &mlir_context_), - Interval{0, 0}); - + auto indexing_map = Parse(R"( + (d0)[s0, s1, s2] -> (s2, d0, s1, s0 floordiv 6), + domain: + d0 in [0, 3], + s0 in [0, 7], + s1 in [0, 1], + s2 in [0, 5], + s0 mod 2 in [0, 0], + s0 mod 3 in [0, 0], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.RescaleSymbols()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0)[s0, s1, s2] -> (s2, d0, s1, s0), @@ -1217,14 +1420,17 @@ TEST_F(IndexingMapTest, RescaleSymbols_TwoModConstraints) { } TEST_F(IndexingMapTest, RescaleSymbols_RescaledSymbolInOtherNonModConstraint) { - auto serialized_map = "(d0)[s0, s1, s2] -> (s2, d0, s1, s0)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {4}, {10, 2, 6}); - indexing_map.AddConstraint(ParseAffineExpr("s0 mod 6", &mlir_context_), - Interval{3, 3}); - indexing_map.AddConstraint(ParseAffineExpr("s0 * s2", &mlir_context_), - Interval{0, 28}); - + auto indexing_map = Parse(R"( + (d0)[s0, s1, s2] -> (s2, d0, s1, s0), + domain: + d0 in [0, 3], + s0 in [0, 9], + s1 in [0, 1], + s2 in [0, 5], + s0 * s2 in [0, 28], + s0 mod 6 in [3, 3], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.RescaleSymbols()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0)[s0, s1, s2] -> (s2, d0, s1, s0 * 6 + 3), @@ -1240,14 +1446,17 @@ TEST_F(IndexingMapTest, RescaleSymbols_RescaledSymbolInOtherNonModConstraint) { TEST_F(IndexingMapTest, RescaleSymbols_TwoModConstraintsForTheSameSymbolWhichCannotBeMerged) { - auto serialized_map = "(d0)[s0, s1, s2] -> (s2, d0, s1, s0)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {4}, {100, 2, 6}); - indexing_map.AddConstraint(ParseAffineExpr("s0 mod 6", &mlir_context_), - Interval{3, 3}); - indexing_map.AddConstraint(ParseAffineExpr("s0 mod 7", &mlir_context_), - Interval{5, 5}); - + auto indexing_map = Parse(R"( + (d0)[s0, s1, s2] -> (s2, d0, s1, s0), + domain: + d0 in [0, 3], + s0 in [0, 99], + s1 in [0, 1], + s2 in [0, 5], + s0 mod 6 in [3, 3], + s0 mod 7 in [5, 5], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.RescaleSymbols()); const mlir::AffineExpr result3 = indexing_map.GetAffineMap().getResult(3); @@ -1274,14 +1483,17 @@ TEST_F(IndexingMapTest, } TEST_F(IndexingMapTest, RescaleSymbolsKeepsHashmapConsistent) { - auto serialized_map = "(d0)[s0, s1, s2] -> (s2, d0, s0, s0 floordiv 6)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {4}, {7, 2, 6}); - indexing_map.AddConstraint(ParseAffineExpr("s0 mod 6", &mlir_context_), - Interval{0, 0}); - indexing_map.AddConstraint(ParseAffineExpr("s0 * s1", &mlir_context_), - Interval{0, 100}); - + auto indexing_map = Parse(R"( + (d0)[s0, s1, s2] -> (s2, d0, s0, s0 floordiv 6), + domain: + d0 in [0, 3], + s0 in [0, 6], + s1 in [0, 1], + s2 in [0, 5], + s0 mod 6 in [0, 0], + s0 * s1 in [0, 100], + is_simplified: false + )"); EXPECT_TRUE(indexing_map.RescaleSymbols()); for (auto& [expr, interval] : indexing_map.GetConstraints()) { @@ -1291,13 +1503,15 @@ TEST_F(IndexingMapTest, RescaleSymbolsKeepsHashmapConsistent) { } TEST_F(IndexingMapTest, RangeEvaluatorTest) { - auto serialized_map = "(d0, d1, d2, d3)[] -> (0)"; - IndexingMap indexing_map(ParseAffineMap(serialized_map, &mlir_context_), - {{Interval{0, 9}}, - {Interval{-10, -1}}, - {Interval{-1, 2}}, - {Interval{0, 0}}}, - {}, {}); + auto indexing_map = Parse(R"( + (d0, d1, d2, d3)[] -> (0), + domain: + d0 in [0, 9], + d1 in [-10, -1], + d2 in [-1, 2], + d3 in [0, 0], + is_simplified: false + )"); RangeEvaluator range_evaluator(indexing_map, &mlir_context_); mlir::AffineExpr d0, d1, d2, d3; bindDims(&mlir_context_, d0, d1, d2, d3); @@ -1936,44 +2150,64 @@ ENTRY e { TEST_F(IndexingMapTest, IndexingMapSupportsAbslHashAndEqAndNe) { auto zero_dim_map = AffineMap::get(&mlir_context_); ExpectSupportsAbslHashAndEqAndNe( - {IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1, s0)", - &mlir_context_), - {50, 60}, {70, 80}), - IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (d1 * 2, d0, s1, s0)", - &mlir_context_), - {50, 60}, {70, 80}), - IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1, s0)", - &mlir_context_), - {51, 60}, {70, 80}), - IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1, s0)", - &mlir_context_), - {50, 60}, {71, 80}), - [&] { - auto m = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1, s0)", - &mlir_context_), - {50, 60}, {70, 80}); - m.AddConstraint(ParseAffineExpr("d0 mod 8", &mlir_context_), - Interval{0, 0}); - m.AddConstraint(ParseAffineExpr("d0 mod 16", &mlir_context_), - Interval{0, 0}); - return m; - }(), - [&] { - auto m = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1, s0)", - &mlir_context_), - {50, 60}, {70, 80}); - m.AddConstraint(ParseAffineExpr("d0 mod 8", &mlir_context_), - Interval{0, 0}); - m.AddConstraint(ParseAffineExpr("d0 mod 32", &mlir_context_), - Interval{0, 0}); - return m; - }(), + {Parse(R"( + (d0, d1)[s0, s1] -> (d1, d0, s1, s0), + domain: + d0 in [0, 49], + d1 in [0, 59], + s0 in [0, 69], + s1 in [0, 79], + is_simplified: false + )"), + Parse(R"( + (d0, d1)[s0, s1] -> (d1 * 2, d0, s1, s0), + domain: + d0 in [0, 49], + d1 in [0, 59], + s0 in [0, 69], + s1 in [0, 79], + is_simplified: false + )"), + Parse(R"( + (d0, d1)[s0, s1] -> (d1, d0, s1, s0), + domain: + d0 in [0, 50], + d1 in [0, 59], + s0 in [0, 69], + s1 in [0, 79], + is_simplified: false + )"), + Parse(R"( + (d0, d1)[s0, s1] -> (d1, d0, s1, s0), + domain: + d0 in [0, 49], + d1 in [0, 59], + s0 in [0, 69], + s1 in [0, 79], + is_simplified: false + )"), + Parse(R"( + (d0, d1)[s0, s1] -> (d1, d0, s1, s0), + domain: + d0 in [0, 49], + d1 in [0, 59], + s0 in [0, 69], + s1 in [0, 79], + d0 mod 8 in [0, 0], + d0 mod 16 in [0, 0], + is_simplified: false + )"), + Parse(R"( + (d0, d1)[s0, s1] -> (d1, d0, s1, s0), + domain: + d0 in [0, 49], + d1 in [0, 59], + s0 in [0, 69], + s1 in [0, 79], + d0 mod 8 in [0, 0], + d0 mod 32 in [0, 0], + is_simplified: false + )"), IndexingMap( ParseAffineMap("(d0)[s0, s1, s2, s3, s4] -> (d0 * 4 + s1 + s3 - 42)", &mlir_context_),