diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/xla/hlo/experimental/auto_sharding/auto_sharding.cc index 06141018c5b8a3..d69764435604cc 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -1747,7 +1747,7 @@ std::unique_ptr CreateReshapeStrategies( return strategy_group; } -AutoShardingSolverResult CallSolver( +AutoShardingSolverResult CreateAutoShardingSolverRequestAndCallSolver( const HloModule& hlo_module, const HloLiveRange& hlo_live_range, const StrategyMap& strategy_map, const StrategyGroups& strategy_groups, const CostGraph& cost_graph, const AliasSet& alias_set, @@ -1969,7 +1969,7 @@ AutoShardingSolverResult CallSolver( PopulateTemporalValues(cost_graph, request); - return CallORToolsSolver(request); + return FormulateAndSolveMIPFromSolverRequest(request); } void CheckHloSharding( diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc index e0fdd6ad71bbf8..7a92ac5715039a 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc @@ -48,13 +48,13 @@ AutoShardingSolverResult Solve( const AutoShardingOption& option, absl::string_view request_prefix, const absl::flat_hash_map& sharding_propagation_solution) { - return CallSolver(hlo_module, hlo_live_range, strategy_map, strategy_groups, - cost_graph, alias_set, node_intervals, edge_intervals, - node_groups, edge_groups, /*s_hint*/ {}, - /*compute_iis*/ true, option.solver_timeout_in_seconds, - option, /*max_cost*/ std::nullopt, request_prefix, - sharding_propagation_solution, - /*deterministic mode*/ true); + return CreateAutoShardingSolverRequestAndCallSolver( + hlo_module, hlo_live_range, strategy_map, strategy_groups, cost_graph, + alias_set, node_intervals, edge_intervals, node_groups, edge_groups, + /*s_hint*/ {}, + /*compute_iis*/ true, option.solver_timeout_in_seconds, option, + /*max_cost*/ std::nullopt, request_prefix, sharding_propagation_solution, + /*deterministic mode*/ true); } void PopulateTemporalValues(const CostGraph& cost_graph, diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc index 8fca1bc7b81ab7..cf18e7a5c56c6e 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc @@ -399,7 +399,7 @@ void AddMemoryTerms( // can be a few (usually < 10) edges in the problem with negative costs. This // is guaranteed to never produce a negative overall cost for the graph, // however. -AutoShardingSolverResult CallORToolsSolver( +AutoShardingSolverResult FormulateAndSolveMIPFromSolverRequest( const AutoShardingSolverRequest& unscaled_request) { const absl::Time start_time = absl::Now(); const AutoShardingSolverRequest& request = ScaleRequest(unscaled_request); diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h b/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h index cb051f7718fd44..88884f7286d0b6 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h @@ -47,7 +47,7 @@ struct AutoShardingSolverResult { bool skip_auto_sharding; }; -AutoShardingSolverResult CallORToolsSolver( +AutoShardingSolverResult FormulateAndSolveMIPFromSolverRequest( const AutoShardingSolverRequest& request); enum AutoShardingViolationCode { diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc index 81c02acd354bd5..3e0c82d3b75510 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc @@ -250,10 +250,11 @@ AutoShardingSolverRequest AutoShardingSolverRequestWithEquivalences() { return request; } -TEST(CallORToolsSolverTest, SolvesOptimally) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, SolvesOptimally) { const AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); - const AutoShardingSolverResult result = CallORToolsSolver(request); + const AutoShardingSolverResult result = + FormulateAndSolveMIPFromSolverRequest(request); const std::vector s_val = {0, 0, 0, 0, 0}; const double objective_value = 7650.0; @@ -262,12 +263,13 @@ TEST(CallORToolsSolverTest, SolvesOptimally) { EXPECT_EQ(result, expected_result); } -TEST(CallORToolsSolverTest, SolvesOverbudget) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, SolvesOverbudget) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); request.set_memory_budget(100000); request.mutable_overbudget_coeff()->set_coeff(10.0); - const AutoShardingSolverResult result = CallORToolsSolver(request); + const AutoShardingSolverResult result = + FormulateAndSolveMIPFromSolverRequest(request); const std::vector s_val = {0, 0, 0, 0, 0}; const double objective_value = 9007650.0; @@ -276,11 +278,12 @@ TEST(CallORToolsSolverTest, SolvesOverbudget) { EXPECT_EQ(result, expected_result); } -TEST(CallORToolsSolverTest, SolvesMaxDepartures) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, SolvesMaxDepartures) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); request.mutable_max_departures()->set_coeff(3.0); - const AutoShardingSolverResult result = CallORToolsSolver(request); + const AutoShardingSolverResult result = + FormulateAndSolveMIPFromSolverRequest(request); const std::vector s_val = {0, 0, 1, 1, 0}; const double objective_value = 7872.0; @@ -289,11 +292,12 @@ TEST(CallORToolsSolverTest, SolvesMaxDepartures) { EXPECT_EQ(result, expected_result); } -TEST(CallORToolsSolverTest, MinimizesDepartures) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, MinimizesDepartures) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); request.set_minimize_departures(true); - const AutoShardingSolverResult result = CallORToolsSolver(request); + const AutoShardingSolverResult result = + FormulateAndSolveMIPFromSolverRequest(request); const std::vector s_val = {0, 1, 0, 0, 1}; const double objective_value = 3.0; @@ -302,13 +306,14 @@ TEST(CallORToolsSolverTest, MinimizesDepartures) { EXPECT_EQ(result, expected_result); } -TEST(CallORToolsSolverTest, AvoidsInfiniteNodeCosts) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, AvoidsInfiniteNodeCosts) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); request.mutable_computation_costs(0)->set_costs(0, kInfinityCost); request.mutable_computation_costs(0)->set_costs(1, kInfinityCost); request.mutable_computation_costs(0)->set_costs(2, kInfinityCost); - const AutoShardingSolverResult result = CallORToolsSolver(request); + const AutoShardingSolverResult result = + FormulateAndSolveMIPFromSolverRequest(request); const std::vector s_val = {3, 0, 0, 0, 0}; const double objective_value = 10683.0; @@ -317,11 +322,12 @@ TEST(CallORToolsSolverTest, AvoidsInfiniteNodeCosts) { EXPECT_EQ(result, expected_result); } -TEST(CallORToolsSolverTest, AvoidsInfiniteEdgeCosts) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, AvoidsInfiniteEdgeCosts) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); request.mutable_resharding_costs(0)->set_costs(0, kInfinityCost); - const AutoShardingSolverResult result = CallORToolsSolver(request); + const AutoShardingSolverResult result = + FormulateAndSolveMIPFromSolverRequest(request); const std::vector s_val = {0, 0, 1, 1, 0}; const double objective_value = 7872.0; @@ -330,7 +336,7 @@ TEST(CallORToolsSolverTest, AvoidsInfiniteEdgeCosts) { EXPECT_EQ(result, expected_result); } -TEST(CallORToolsSolverTest, HandlesFollowedEdges) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesFollowedEdges) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); AutoShardingSolverRequest_Pair edge; edge.set_first(1); @@ -346,7 +352,8 @@ TEST(CallORToolsSolverTest, HandlesFollowedEdges) { 70000, 71000, 72000, 73000}}; AddCosts(request.mutable_duration_costs(), t); - const AutoShardingSolverResult result = CallORToolsSolver(request); + const AutoShardingSolverResult result = + FormulateAndSolveMIPFromSolverRequest(request); const std::vector s_val = {0, 0, 0, 0, 0}; const double objective_value = 12650.0; @@ -355,7 +362,7 @@ TEST(CallORToolsSolverTest, HandlesFollowedEdges) { EXPECT_EQ(result, expected_result); } -TEST(CallORToolsSolverTest, HandlesCollapsedEdge) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesCollapsedEdge) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); AutoShardingSolverRequest_Pair edge; edge.set_first(2); @@ -373,7 +380,8 @@ TEST(CallORToolsSolverTest, HandlesCollapsedEdge) { 80000, 81000, 82000, 83000}}; AddCosts(request.mutable_duration_costs(), t); - const AutoShardingSolverResult result = CallORToolsSolver(request); + const AutoShardingSolverResult result = + FormulateAndSolveMIPFromSolverRequest(request); const std::vector s_val = {0, 0, 1, 1, 0}; const double objective_value = 13972.0; @@ -382,12 +390,13 @@ TEST(CallORToolsSolverTest, HandlesCollapsedEdge) { EXPECT_EQ(result, expected_result); } -TEST(CallORToolsSolverTest, UsesHint) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, UsesHint) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); const auto s_hint = {1, 0, 0, 0, 0}; // Not optimal, but close. request.mutable_s_hint()->Add(s_hint.begin(), s_hint.end()); - const AutoShardingSolverResult result = CallORToolsSolver(request); + const AutoShardingSolverResult result = + FormulateAndSolveMIPFromSolverRequest(request); const std::vector s_val = {0, 0, 0, 0, 0}; const double objective_value = 7650.0; @@ -396,20 +405,22 @@ TEST(CallORToolsSolverTest, UsesHint) { EXPECT_EQ(result, expected_result); } -TEST(CallORToolsSolverTest, HonorsMaxCost) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, HonorsMaxCost) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); request.mutable_max_cost()->set_coeff(7600.0); // Best possible is 7650.0 - const AutoShardingSolverResult result = CallORToolsSolver(request); + const AutoShardingSolverResult result = + FormulateAndSolveMIPFromSolverRequest(request); EXPECT_TRUE(absl::IsInternal(result.status.status())); } -TEST(CallORToolsSolverTest, HandlesExtremelyHighMaxCost) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesExtremelyHighMaxCost) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); request.mutable_max_cost()->set_coeff(1e19); - const AutoShardingSolverResult result = CallORToolsSolver(request); + const AutoShardingSolverResult result = + FormulateAndSolveMIPFromSolverRequest(request); const std::vector s_val = {0, 0, 0, 0, 0}; const double objective_value = 7650.0; @@ -418,7 +429,7 @@ TEST(CallORToolsSolverTest, HandlesExtremelyHighMaxCost) { EXPECT_EQ(result, expected_result); } -TEST(CallORToolsSolverTest, HandlesMemoryEdgeCosts) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesMemoryEdgeCosts) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); const EdgeMatrix live_edges = {{}, {0}, {0, 1}, {1}, {}}; const CostMatrix memory_edge_costs = {{1000000, 1100, 1200, 1300, @@ -432,7 +443,8 @@ TEST(CallORToolsSolverTest, HandlesMemoryEdgeCosts) { AddCosts(request.mutable_memory_edge_costs(), memory_edge_costs); request.set_enable_memory_edge_costs(true); - const AutoShardingSolverResult result = CallORToolsSolver(request); + const AutoShardingSolverResult result = + FormulateAndSolveMIPFromSolverRequest(request); const std::vector s_val = {0, 0, 1, 1, 0}; const double objective_value = 7872.0; @@ -441,7 +453,7 @@ TEST(CallORToolsSolverTest, HandlesMemoryEdgeCosts) { EXPECT_EQ(result, expected_result); } -TEST(CallORToolsSolverTest, HandlesIntervals) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesIntervals) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); const std::vector> node_intervals = {{0, 4}, {0, 4}, {2, 3}, {3, 4}, {100, -1}}; @@ -460,7 +472,8 @@ TEST(CallORToolsSolverTest, HandlesIntervals) { AddCosts(request.mutable_memory_edge_costs(), memory_edge_costs); request.set_enable_memory_edge_costs(true); - const AutoShardingSolverResult result = CallORToolsSolver(request); + const AutoShardingSolverResult result = + FormulateAndSolveMIPFromSolverRequest(request); const std::vector s_val = {0, 0, 1, 1, 0}; const double objective_value = 7872.0; @@ -469,7 +482,8 @@ TEST(CallORToolsSolverTest, HandlesIntervals) { EXPECT_EQ(result, expected_result); } -TEST(CallORToolsSolverTest, HandlesReducedIntervalsAndGroups) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, + HandlesReducedIntervalsAndGroups) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); const std::vector> node_intervals = {{5, -1}, {5, -1}, {2, 3}, {3, 4}, {100, -1}, {0, 4}}; @@ -492,7 +506,8 @@ TEST(CallORToolsSolverTest, HandlesReducedIntervalsAndGroups) { AddCosts(request.mutable_memory_edge_costs(), memory_edge_costs); request.set_enable_memory_edge_costs(true); - const AutoShardingSolverResult result = CallORToolsSolver(request); + const AutoShardingSolverResult result = + FormulateAndSolveMIPFromSolverRequest(request); const std::vector s_val = {0, 0, 1, 1, 0}; const double objective_value = 7872.0; @@ -501,7 +516,8 @@ TEST(CallORToolsSolverTest, HandlesReducedIntervalsAndGroups) { EXPECT_EQ(result, expected_result); } -TEST(CallORToolsSolverTest, HandlesReducedIntervalsAndGroupsNoMemoryEdgeCosts) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, + HandlesReducedIntervalsAndGroupsNoMemoryEdgeCosts) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); const std::vector> node_intervals = {{5, -1}, {5, -1}, {2, 3}, {3, 4}, {100, -1}, {0, 4}}; @@ -511,7 +527,8 @@ TEST(CallORToolsSolverTest, HandlesReducedIntervalsAndGroupsNoMemoryEdgeCosts) { AddGroups(request.mutable_node_groups(), node_groups); request.set_enable_memory_edge_costs(false); - const AutoShardingSolverResult result = CallORToolsSolver(request); + const AutoShardingSolverResult result = + FormulateAndSolveMIPFromSolverRequest(request); const std::vector s_val = {0, 0, 0, 0, 0}; const double objective_value = 7650.0; @@ -520,7 +537,8 @@ TEST(CallORToolsSolverTest, HandlesReducedIntervalsAndGroupsNoMemoryEdgeCosts) { EXPECT_EQ(result, expected_result); } -TEST(CallORToolsSolverTest, HandlesGroupsWithTinyMemoryCosts) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, + HandlesGroupsWithTinyMemoryCosts) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); const std::vector> node_intervals = {{5, -1}, {5, -1}, {2, 3}, {3, 4}, {100, -1}, {0, 4}}; @@ -551,7 +569,8 @@ TEST(CallORToolsSolverTest, HandlesGroupsWithTinyMemoryCosts) { request.set_enable_memory_edge_costs(true); request.set_memory_budget(4321); - const AutoShardingSolverResult result = CallORToolsSolver(request); + const AutoShardingSolverResult result = + FormulateAndSolveMIPFromSolverRequest(request); const std::vector s_val = {0, 0, 0, 0, 0}; const double objective_value = 7650.0; @@ -560,11 +579,12 @@ TEST(CallORToolsSolverTest, HandlesGroupsWithTinyMemoryCosts) { EXPECT_EQ(result, expected_result); } -TEST(CallORToolsSolverTest, SolvesWithEquivalences) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, SolvesWithEquivalences) { const AutoShardingSolverRequest request = AutoShardingSolverRequestWithEquivalences(); - const AutoShardingSolverResult result = CallORToolsSolver(request); + const AutoShardingSolverResult result = + FormulateAndSolveMIPFromSolverRequest(request); const std::vector s_val = {0, 0, 5, 5, 1}; const double objective_value = 7650.0; diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h b/xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h index 069fde4e14c580..f9058802eea52d 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h @@ -41,7 +41,7 @@ namespace spmd { // A wrapper around the solver that converts the given objects into a // combinatorial optimization problem & solves it. -AutoShardingSolverResult CallSolver( +AutoShardingSolverResult CreateAutoShardingSolverRequestAndCallSolver( const HloModule& hlo_module, const HloLiveRange& hlo_live_range, const StrategyMap& strategy_map, const StrategyGroups& strategy_groups, const CostGraph& cost_graph, const AliasSet& alias_set,