Skip to content

Commit

Permalink
Rename CallSolver --> CreateAutoShardingSolverRequestAndCallSolver an…
Browse files Browse the repository at this point in the history
…d CallORToolsSolver --> FormulateAndSolveMIPFromSolverRequest to better capture the function implementation.

PiperOrigin-RevId: 679350487
  • Loading branch information
Google-ML-Automation committed Sep 27, 2024
1 parent 7c5cc91 commit 2e2ab19
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 46 deletions.
4 changes: 2 additions & 2 deletions xla/hlo/experimental/auto_sharding/auto_sharding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1747,7 +1747,7 @@ std::unique_ptr<StrategyGroup> CreateReshapeStrategies(
return strategy_group;
}

AutoShardingSolverResult CallSolver(
AutoShardingSolverResult CreateAutoShardingSolverRequestAndCallSolver(
const HloModule& hlo_module, const HloLiveRange& hlo_live_range,
const StrategyMap& strategy_map, const StrategyGroups& strategy_groups,
const CostGraph& cost_graph, const AliasSet& alias_set,
Expand Down Expand Up @@ -1969,7 +1969,7 @@ AutoShardingSolverResult CallSolver(

PopulateTemporalValues(cost_graph, request);

return CallORToolsSolver(request);
return FormulateAndSolveMIPFromSolverRequest(request);
}

void CheckHloSharding(
Expand Down
14 changes: 7 additions & 7 deletions xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,13 @@ AutoShardingSolverResult Solve(
const AutoShardingOption& option, absl::string_view request_prefix,
const absl::flat_hash_map<std::string, HloSharding>&
sharding_propagation_solution) {
return CallSolver(hlo_module, hlo_live_range, strategy_map, strategy_groups,
cost_graph, alias_set, node_intervals, edge_intervals,
node_groups, edge_groups, /*s_hint*/ {},
/*compute_iis*/ true, option.solver_timeout_in_seconds,
option, /*max_cost*/ std::nullopt, request_prefix,
sharding_propagation_solution,
/*deterministic mode*/ true);
return CreateAutoShardingSolverRequestAndCallSolver(
hlo_module, hlo_live_range, strategy_map, strategy_groups, cost_graph,
alias_set, node_intervals, edge_intervals, node_groups, edge_groups,
/*s_hint*/ {},
/*compute_iis*/ true, option.solver_timeout_in_seconds, option,
/*max_cost*/ std::nullopt, request_prefix, sharding_propagation_solution,
/*deterministic mode*/ true);
}

void PopulateTemporalValues(const CostGraph& cost_graph,
Expand Down
2 changes: 1 addition & 1 deletion xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ void AddMemoryTerms(
// can be a few (usually < 10) edges in the problem with negative costs. This
// is guaranteed to never produce a negative overall cost for the graph,
// however.
AutoShardingSolverResult CallORToolsSolver(
AutoShardingSolverResult FormulateAndSolveMIPFromSolverRequest(
const AutoShardingSolverRequest& unscaled_request) {
const absl::Time start_time = absl::Now();
const AutoShardingSolverRequest& request = ScaleRequest(unscaled_request);
Expand Down
2 changes: 1 addition & 1 deletion xla/hlo/experimental/auto_sharding/auto_sharding_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ struct AutoShardingSolverResult {
bool skip_auto_sharding;
};

AutoShardingSolverResult CallORToolsSolver(
AutoShardingSolverResult FormulateAndSolveMIPFromSolverRequest(
const AutoShardingSolverRequest& request);

enum AutoShardingViolationCode {
Expand Down
88 changes: 54 additions & 34 deletions xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -250,10 +250,11 @@ AutoShardingSolverRequest AutoShardingSolverRequestWithEquivalences() {
return request;
}

TEST(CallORToolsSolverTest, SolvesOptimally) {
TEST(FormulateAndSolveMIPFromSolverRequestTest, SolvesOptimally) {
const AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest();

const AutoShardingSolverResult result = CallORToolsSolver(request);
const AutoShardingSolverResult result =
FormulateAndSolveMIPFromSolverRequest(request);

const std::vector<NodeStrategyIdx> s_val = {0, 0, 0, 0, 0};
const double objective_value = 7650.0;
Expand All @@ -262,12 +263,13 @@ TEST(CallORToolsSolverTest, SolvesOptimally) {
EXPECT_EQ(result, expected_result);
}

TEST(CallORToolsSolverTest, SolvesOverbudget) {
TEST(FormulateAndSolveMIPFromSolverRequestTest, SolvesOverbudget) {
AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest();
request.set_memory_budget(100000);
request.mutable_overbudget_coeff()->set_coeff(10.0);

const AutoShardingSolverResult result = CallORToolsSolver(request);
const AutoShardingSolverResult result =
FormulateAndSolveMIPFromSolverRequest(request);

const std::vector<NodeStrategyIdx> s_val = {0, 0, 0, 0, 0};
const double objective_value = 9007650.0;
Expand All @@ -276,11 +278,12 @@ TEST(CallORToolsSolverTest, SolvesOverbudget) {
EXPECT_EQ(result, expected_result);
}

TEST(CallORToolsSolverTest, SolvesMaxDepartures) {
TEST(FormulateAndSolveMIPFromSolverRequestTest, SolvesMaxDepartures) {
AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest();
request.mutable_max_departures()->set_coeff(3.0);

const AutoShardingSolverResult result = CallORToolsSolver(request);
const AutoShardingSolverResult result =
FormulateAndSolveMIPFromSolverRequest(request);

const std::vector<NodeStrategyIdx> s_val = {0, 0, 1, 1, 0};
const double objective_value = 7872.0;
Expand All @@ -289,11 +292,12 @@ TEST(CallORToolsSolverTest, SolvesMaxDepartures) {
EXPECT_EQ(result, expected_result);
}

TEST(CallORToolsSolverTest, MinimizesDepartures) {
TEST(FormulateAndSolveMIPFromSolverRequestTest, MinimizesDepartures) {
AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest();
request.set_minimize_departures(true);

const AutoShardingSolverResult result = CallORToolsSolver(request);
const AutoShardingSolverResult result =
FormulateAndSolveMIPFromSolverRequest(request);

const std::vector<NodeStrategyIdx> s_val = {0, 1, 0, 0, 1};
const double objective_value = 3.0;
Expand All @@ -302,13 +306,14 @@ TEST(CallORToolsSolverTest, MinimizesDepartures) {
EXPECT_EQ(result, expected_result);
}

TEST(CallORToolsSolverTest, AvoidsInfiniteNodeCosts) {
TEST(FormulateAndSolveMIPFromSolverRequestTest, AvoidsInfiniteNodeCosts) {
AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest();
request.mutable_computation_costs(0)->set_costs(0, kInfinityCost);
request.mutable_computation_costs(0)->set_costs(1, kInfinityCost);
request.mutable_computation_costs(0)->set_costs(2, kInfinityCost);

const AutoShardingSolverResult result = CallORToolsSolver(request);
const AutoShardingSolverResult result =
FormulateAndSolveMIPFromSolverRequest(request);

const std::vector<NodeStrategyIdx> s_val = {3, 0, 0, 0, 0};
const double objective_value = 10683.0;
Expand All @@ -317,11 +322,12 @@ TEST(CallORToolsSolverTest, AvoidsInfiniteNodeCosts) {
EXPECT_EQ(result, expected_result);
}

TEST(CallORToolsSolverTest, AvoidsInfiniteEdgeCosts) {
TEST(FormulateAndSolveMIPFromSolverRequestTest, AvoidsInfiniteEdgeCosts) {
AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest();
request.mutable_resharding_costs(0)->set_costs(0, kInfinityCost);

const AutoShardingSolverResult result = CallORToolsSolver(request);
const AutoShardingSolverResult result =
FormulateAndSolveMIPFromSolverRequest(request);

const std::vector<NodeStrategyIdx> s_val = {0, 0, 1, 1, 0};
const double objective_value = 7872.0;
Expand All @@ -330,7 +336,7 @@ TEST(CallORToolsSolverTest, AvoidsInfiniteEdgeCosts) {
EXPECT_EQ(result, expected_result);
}

TEST(CallORToolsSolverTest, HandlesFollowedEdges) {
TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesFollowedEdges) {
AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest();
AutoShardingSolverRequest_Pair edge;
edge.set_first(1);
Expand All @@ -346,7 +352,8 @@ TEST(CallORToolsSolverTest, HandlesFollowedEdges) {
70000, 71000, 72000, 73000}};
AddCosts(request.mutable_duration_costs(), t);

const AutoShardingSolverResult result = CallORToolsSolver(request);
const AutoShardingSolverResult result =
FormulateAndSolveMIPFromSolverRequest(request);

const std::vector<NodeStrategyIdx> s_val = {0, 0, 0, 0, 0};
const double objective_value = 12650.0;
Expand All @@ -355,7 +362,7 @@ TEST(CallORToolsSolverTest, HandlesFollowedEdges) {
EXPECT_EQ(result, expected_result);
}

TEST(CallORToolsSolverTest, HandlesCollapsedEdge) {
TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesCollapsedEdge) {
AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest();
AutoShardingSolverRequest_Pair edge;
edge.set_first(2);
Expand All @@ -373,7 +380,8 @@ TEST(CallORToolsSolverTest, HandlesCollapsedEdge) {
80000, 81000, 82000, 83000}};
AddCosts(request.mutable_duration_costs(), t);

const AutoShardingSolverResult result = CallORToolsSolver(request);
const AutoShardingSolverResult result =
FormulateAndSolveMIPFromSolverRequest(request);

const std::vector<NodeStrategyIdx> s_val = {0, 0, 1, 1, 0};
const double objective_value = 13972.0;
Expand All @@ -382,12 +390,13 @@ TEST(CallORToolsSolverTest, HandlesCollapsedEdge) {
EXPECT_EQ(result, expected_result);
}

TEST(CallORToolsSolverTest, UsesHint) {
TEST(FormulateAndSolveMIPFromSolverRequestTest, UsesHint) {
AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest();
const auto s_hint = {1, 0, 0, 0, 0}; // Not optimal, but close.
request.mutable_s_hint()->Add(s_hint.begin(), s_hint.end());

const AutoShardingSolverResult result = CallORToolsSolver(request);
const AutoShardingSolverResult result =
FormulateAndSolveMIPFromSolverRequest(request);

const std::vector<NodeStrategyIdx> s_val = {0, 0, 0, 0, 0};
const double objective_value = 7650.0;
Expand All @@ -396,20 +405,22 @@ TEST(CallORToolsSolverTest, UsesHint) {
EXPECT_EQ(result, expected_result);
}

TEST(CallORToolsSolverTest, HonorsMaxCost) {
TEST(FormulateAndSolveMIPFromSolverRequestTest, HonorsMaxCost) {
AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest();
request.mutable_max_cost()->set_coeff(7600.0); // Best possible is 7650.0

const AutoShardingSolverResult result = CallORToolsSolver(request);
const AutoShardingSolverResult result =
FormulateAndSolveMIPFromSolverRequest(request);

EXPECT_TRUE(absl::IsInternal(result.status.status()));
}

TEST(CallORToolsSolverTest, HandlesExtremelyHighMaxCost) {
TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesExtremelyHighMaxCost) {
AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest();
request.mutable_max_cost()->set_coeff(1e19);

const AutoShardingSolverResult result = CallORToolsSolver(request);
const AutoShardingSolverResult result =
FormulateAndSolveMIPFromSolverRequest(request);

const std::vector<NodeStrategyIdx> s_val = {0, 0, 0, 0, 0};
const double objective_value = 7650.0;
Expand All @@ -418,7 +429,7 @@ TEST(CallORToolsSolverTest, HandlesExtremelyHighMaxCost) {
EXPECT_EQ(result, expected_result);
}

TEST(CallORToolsSolverTest, HandlesMemoryEdgeCosts) {
TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesMemoryEdgeCosts) {
AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest();
const EdgeMatrix live_edges = {{}, {0}, {0, 1}, {1}, {}};
const CostMatrix memory_edge_costs = {{1000000, 1100, 1200, 1300,
Expand All @@ -432,7 +443,8 @@ TEST(CallORToolsSolverTest, HandlesMemoryEdgeCosts) {
AddCosts(request.mutable_memory_edge_costs(), memory_edge_costs);
request.set_enable_memory_edge_costs(true);

const AutoShardingSolverResult result = CallORToolsSolver(request);
const AutoShardingSolverResult result =
FormulateAndSolveMIPFromSolverRequest(request);

const std::vector<NodeStrategyIdx> s_val = {0, 0, 1, 1, 0};
const double objective_value = 7872.0;
Expand All @@ -441,7 +453,7 @@ TEST(CallORToolsSolverTest, HandlesMemoryEdgeCosts) {
EXPECT_EQ(result, expected_result);
}

TEST(CallORToolsSolverTest, HandlesIntervals) {
TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesIntervals) {
AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest();
const std::vector<std::pair<int64_t, int64_t>> node_intervals =
{{0, 4}, {0, 4}, {2, 3}, {3, 4}, {100, -1}};
Expand All @@ -460,7 +472,8 @@ TEST(CallORToolsSolverTest, HandlesIntervals) {
AddCosts(request.mutable_memory_edge_costs(), memory_edge_costs);
request.set_enable_memory_edge_costs(true);

const AutoShardingSolverResult result = CallORToolsSolver(request);
const AutoShardingSolverResult result =
FormulateAndSolveMIPFromSolverRequest(request);

const std::vector<NodeStrategyIdx> s_val = {0, 0, 1, 1, 0};
const double objective_value = 7872.0;
Expand All @@ -469,7 +482,8 @@ TEST(CallORToolsSolverTest, HandlesIntervals) {
EXPECT_EQ(result, expected_result);
}

TEST(CallORToolsSolverTest, HandlesReducedIntervalsAndGroups) {
TEST(FormulateAndSolveMIPFromSolverRequestTest,
HandlesReducedIntervalsAndGroups) {
AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest();
const std::vector<std::pair<int64_t, int64_t>> node_intervals =
{{5, -1}, {5, -1}, {2, 3}, {3, 4}, {100, -1}, {0, 4}};
Expand All @@ -492,7 +506,8 @@ TEST(CallORToolsSolverTest, HandlesReducedIntervalsAndGroups) {
AddCosts(request.mutable_memory_edge_costs(), memory_edge_costs);
request.set_enable_memory_edge_costs(true);

const AutoShardingSolverResult result = CallORToolsSolver(request);
const AutoShardingSolverResult result =
FormulateAndSolveMIPFromSolverRequest(request);

const std::vector<NodeStrategyIdx> s_val = {0, 0, 1, 1, 0};
const double objective_value = 7872.0;
Expand All @@ -501,7 +516,8 @@ TEST(CallORToolsSolverTest, HandlesReducedIntervalsAndGroups) {
EXPECT_EQ(result, expected_result);
}

TEST(CallORToolsSolverTest, HandlesReducedIntervalsAndGroupsNoMemoryEdgeCosts) {
TEST(FormulateAndSolveMIPFromSolverRequestTest,
HandlesReducedIntervalsAndGroupsNoMemoryEdgeCosts) {
AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest();
const std::vector<std::pair<int64_t, int64_t>> node_intervals =
{{5, -1}, {5, -1}, {2, 3}, {3, 4}, {100, -1}, {0, 4}};
Expand All @@ -511,7 +527,8 @@ TEST(CallORToolsSolverTest, HandlesReducedIntervalsAndGroupsNoMemoryEdgeCosts) {
AddGroups(request.mutable_node_groups(), node_groups);
request.set_enable_memory_edge_costs(false);

const AutoShardingSolverResult result = CallORToolsSolver(request);
const AutoShardingSolverResult result =
FormulateAndSolveMIPFromSolverRequest(request);

const std::vector<NodeStrategyIdx> s_val = {0, 0, 0, 0, 0};
const double objective_value = 7650.0;
Expand All @@ -520,7 +537,8 @@ TEST(CallORToolsSolverTest, HandlesReducedIntervalsAndGroupsNoMemoryEdgeCosts) {
EXPECT_EQ(result, expected_result);
}

TEST(CallORToolsSolverTest, HandlesGroupsWithTinyMemoryCosts) {
TEST(FormulateAndSolveMIPFromSolverRequestTest,
HandlesGroupsWithTinyMemoryCosts) {
AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest();
const std::vector<std::pair<int64_t, int64_t>> node_intervals =
{{5, -1}, {5, -1}, {2, 3}, {3, 4}, {100, -1}, {0, 4}};
Expand Down Expand Up @@ -551,7 +569,8 @@ TEST(CallORToolsSolverTest, HandlesGroupsWithTinyMemoryCosts) {
request.set_enable_memory_edge_costs(true);
request.set_memory_budget(4321);

const AutoShardingSolverResult result = CallORToolsSolver(request);
const AutoShardingSolverResult result =
FormulateAndSolveMIPFromSolverRequest(request);

const std::vector<NodeStrategyIdx> s_val = {0, 0, 0, 0, 0};
const double objective_value = 7650.0;
Expand All @@ -560,11 +579,12 @@ TEST(CallORToolsSolverTest, HandlesGroupsWithTinyMemoryCosts) {
EXPECT_EQ(result, expected_result);
}

TEST(CallORToolsSolverTest, SolvesWithEquivalences) {
TEST(FormulateAndSolveMIPFromSolverRequestTest, SolvesWithEquivalences) {
const AutoShardingSolverRequest request =
AutoShardingSolverRequestWithEquivalences();

const AutoShardingSolverResult result = CallORToolsSolver(request);
const AutoShardingSolverResult result =
FormulateAndSolveMIPFromSolverRequest(request);

const std::vector<NodeStrategyIdx> s_val = {0, 0, 5, 5, 1};
const double objective_value = 7650.0;
Expand Down
2 changes: 1 addition & 1 deletion xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ namespace spmd {

// A wrapper around the solver that converts the given objects into a
// combinatorial optimization problem & solves it.
AutoShardingSolverResult CallSolver(
AutoShardingSolverResult CreateAutoShardingSolverRequestAndCallSolver(
const HloModule& hlo_module, const HloLiveRange& hlo_live_range,
const StrategyMap& strategy_map, const StrategyGroups& strategy_groups,
const CostGraph& cost_graph, const AliasSet& alias_set,
Expand Down

0 comments on commit 2e2ab19

Please sign in to comment.