Skip to content

Commit

Permalink
Swiss scheduling improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
evroon committed Nov 21, 2024
1 parent 51f1edf commit 0eb6907
Show file tree
Hide file tree
Showing 12 changed files with 97 additions and 64 deletions.
57 changes: 38 additions & 19 deletions backend/bracket/logic/scheduling/ladder_teams.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import itertools
import random
from collections import defaultdict

Expand Down Expand Up @@ -78,14 +79,27 @@ def get_possible_upcoming_matches_for_swiss(
if input_.id not in times_played_per_input:
times_played_per_input[input_.id] = 0

min_times_played = (
min(times_played_per_input.values()) if len(times_played_per_input) > 0 else 0
)

inputs1_random = random.choices(inputs_to_schedule, k=filter_.iterations)
inputs2_random = random.choices(inputs_to_schedule, k=filter_.iterations)

for i1, i2 in zip(inputs1_random, inputs2_random):
# If there are more possible matches to schedule (N * (N - 1)) than iteration count, then
# pick random combinations.
# Otherwise, when there's not too many inputs, just take all possible combinations.
# Note: `itertools.product` creates N * N results, so we look at N * N instead of N * (N - 1).
# For example: iteration count: 2_000, number of inputs: 20. Then N * N = 380,
# 380 is less than 2_000, so we just loop over all possible combinations.
N = len(inputs_to_schedule)
Item = tuple[StageItemInput, StageItemInput]
inputs_iter: itertools.product[Item] | zip[Item]
if N * N <= filter_.iterations:
inputs1 = inputs_to_schedule.copy()
inputs2 = inputs_to_schedule.copy()
random.shuffle(inputs1)
random.shuffle(inputs2)
inputs_iter = itertools.product(inputs1, inputs2)
else:
inputs1 = random.choices(inputs_to_schedule, k=filter_.iterations)
inputs2 = random.choices(inputs_to_schedule, k=filter_.iterations)
inputs_iter = zip(inputs1, inputs2)

for i1, i2 in inputs_iter:
if assert_some(i1.id) > assert_some(i2.id):
input2, input1 = i1, i2
elif assert_some(i1.id) < assert_some(i2.id):
Expand All @@ -97,22 +111,27 @@ def get_possible_upcoming_matches_for_swiss(
if get_match_hash(input1.id, input2.id) in previous_match_input_hashes:
continue

times_played_min = min(
times_played_per_input[input1.id],
times_played_per_input[input2.id],
)
suggested_match = check_input_combination_adheres_to_filter(
input1, input2, filter_, is_recommended=times_played_min <= min_times_played
input1,
input2,
filter_,
times_played_per_input[input1.id] + times_played_per_input[input2.id],
)
if (
suggested_match
and match_hash not in scheduled_hashes
and (not filter_.only_recommended or suggested_match.is_recommended)
):
if suggested_match and match_hash not in scheduled_hashes:
suggestions.append(suggested_match)
scheduled_hashes.append(match_hash)
scheduled_hashes.append(get_match_hash(input2.id, input1.id))

if len(suggestions) < 1:
return []

lowest_times_played_sum = min(sug.times_played_sum for sug in suggestions)
for sug in suggestions:
sug.is_recommended = sug.times_played_sum == lowest_times_played_sum

if filter_.only_recommended:
suggestions = [sug for sug in suggestions if sug.is_recommended]

sorted_by_elo = sorted(suggestions, key=lambda x: x.elo_diff)
sorted_by_times_played = sorted(sorted_by_elo, key=lambda x: x.is_recommended, reverse=True)
sorted_by_times_played = sorted(sorted_by_elo, key=lambda x: x.times_played_sum)
return sorted_by_times_played[: filter_.limit]
13 changes: 9 additions & 4 deletions backend/bracket/logic/scheduling/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
def get_suggested_match(
stage_item_input1: StageItemInput,
stage_item_input2: StageItemInput,
is_recommended: bool,
times_played_sum: int,
) -> SuggestedMatch:
elo_diff = abs(stage_item_input1.elo - stage_item_input2.elo)
swiss_diff = abs(stage_item_input1.points - stage_item_input2.points)
Expand All @@ -15,7 +15,8 @@ def get_suggested_match(
stage_item_input2=stage_item_input2,
elo_diff=elo_diff,
swiss_diff=swiss_diff,
is_recommended=is_recommended,
is_recommended=False,
times_played_sum=times_played_sum,
player_behind_schedule_count=0,
)

Expand All @@ -24,9 +25,13 @@ def check_input_combination_adheres_to_filter(
stage_item_input1: StageItemInput,
stage_item_input2: StageItemInput,
filter_: MatchFilter,
is_recommended: bool,
times_played_sum: int,
) -> SuggestedMatch | None:
suggested_match = get_suggested_match(stage_item_input1, stage_item_input2, is_recommended)
suggested_match = get_suggested_match(
stage_item_input1,
stage_item_input2,
times_played_sum,
)

if suggested_match.elo_diff <= filter_.elo_diff_threshold:
return suggested_match
Expand Down
7 changes: 2 additions & 5 deletions backend/bracket/logic/scheduling/upcoming_matches.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from bracket.models.db.match import MatchFilter, SuggestedMatch
from bracket.models.db.stage_item import StageType
from bracket.models.db.util import RoundWithMatches, StageItemWithRounds
from bracket.sql.rounds import get_rounds_for_stage_item
from bracket.sql.stages import get_full_tournament_details
from bracket.utils.id_types import StageItemId, TournamentId

Expand All @@ -28,10 +27,9 @@ async def get_draft_round_in_stage_item(
return draft_round, stage_item


async def get_upcoming_matches_for_swiss(
def get_upcoming_matches_for_swiss(
match_filter: MatchFilter,
stage_item: StageItemWithRounds,
tournament_id: TournamentId,
draft_round: RoundWithMatches | None = None,
) -> list[SuggestedMatch]:
if stage_item.type is not StageType.SWISS:
Expand All @@ -40,7 +38,6 @@ async def get_upcoming_matches_for_swiss(
if draft_round is not None and not draft_round.is_draft:
raise HTTPException(400, "There is no draft round, so no matches can be scheduled.")

rounds = await get_rounds_for_stage_item(tournament_id, stage_item.id)
return get_possible_upcoming_matches_for_swiss(
match_filter, rounds, stage_item.inputs, draft_round
match_filter, stage_item.rounds, stage_item.inputs, draft_round
)
1 change: 1 addition & 0 deletions backend/bracket/models/db/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ class SuggestedMatch(BaseModel):
elo_diff: Decimal
swiss_diff: Decimal
is_recommended: bool
times_played_sum: int
player_behind_schedule_count: int

@property
Expand Down
6 changes: 2 additions & 4 deletions backend/bracket/routes/matches.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ async def get_matches_to_schedule(
tournament_id: TournamentId,
stage_item_id: StageItemId,
elo_diff_threshold: int = 200,
iterations: int = 200,
iterations: int = 2_000,
only_recommended: bool = False,
limit: int = 50,
_: UserPublic = Depends(user_authenticated_for_tournament),
Expand All @@ -68,9 +68,7 @@ async def get_matches_to_schedule(
return UpcomingMatchesResponse(data=[])

return UpcomingMatchesResponse(
data=await get_upcoming_matches_for_swiss(
match_filter, stage_item, tournament_id, draft_round
)
data=get_upcoming_matches_for_swiss(match_filter, stage_item, draft_round)
)


Expand Down
22 changes: 10 additions & 12 deletions backend/bracket/routes/rounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,13 @@
round_dependency,
round_with_matches_dependency,
)
from bracket.schema import rounds
from bracket.sql.rounds import get_next_round_name, set_round_active_or_draft, sql_create_round
from bracket.sql.matches import sql_delete_match
from bracket.sql.rounds import (
get_next_round_name,
set_round_active_or_draft,
sql_create_round,
sql_delete_round,
)
from bracket.sql.stage_items import get_stage_item
from bracket.sql.stages import get_full_tournament_details
from bracket.sql.validation import check_foreign_keys_belong_to_tournament
Expand All @@ -38,17 +43,10 @@ async def delete_round(
_: UserPublic = Depends(user_authenticated_for_tournament),
round_with_matches: RoundWithMatches = Depends(round_with_matches_dependency),
) -> SuccessResponse:
if len(round_with_matches.matches) > 0:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Round contains matches, delete those first",
)
for match in round_with_matches.matches:
await sql_delete_match(match.id)

await database.execute(
query=rounds.delete().where(
rounds.c.id == round_id and rounds.c.tournament_id == tournament_id
),
)
await sql_delete_round(round_id)

stage_item = await get_stage_item(tournament_id, round_with_matches.stage_item_id)
await recalculate_ranking_for_stage_item(tournament_id, stage_item)
Expand Down
20 changes: 12 additions & 8 deletions backend/bracket/routes/stage_items.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from starlette import status

from bracket.database import database
from bracket.logic.planning.conflicts import handle_conflicts
from bracket.logic.planning.matches import update_start_times_of_matches
from bracket.logic.planning.rounds import (
MatchTimingAdjustmentInfeasible,
get_draft_round,
Expand Down Expand Up @@ -42,6 +44,7 @@
)
from bracket.sql.shared import sql_delete_stage_item_with_foreign_keys
from bracket.sql.stage_items import (
get_stage_item,
sql_create_stage_item_with_empty_inputs,
)
from bracket.sql.stages import get_full_tournament_details
Expand Down Expand Up @@ -69,6 +72,7 @@ async def delete_stage_item(
{ForeignKey.matches_stage_item_input1_id_fkey, ForeignKey.matches_stage_item_input2_id_fkey}
):
await sql_delete_stage_item_with_foreign_keys(stage_item_id)
await update_start_times_of_matches(tournament_id)
return SuccessResponse()


Expand Down Expand Up @@ -130,8 +134,8 @@ async def start_next_round(
active_next_body: StageItemActivateNextBody,
stage_item: StageItemWithRounds = Depends(stage_item_dependency),
user: UserPublic = Depends(user_authenticated_for_tournament),
elo_diff_threshold: int = 100,
iterations: int = 200,
elo_diff_threshold: int = 200,
iterations: int = 2_000,
only_recommended: bool = False,
) -> SuccessResponse:
draft_round = get_draft_round(stage_item)
Expand All @@ -147,9 +151,7 @@ async def start_next_round(
limit=1,
iterations=iterations,
)
all_matches_to_schedule = await get_upcoming_matches_for_swiss(
match_filter, stage_item, tournament_id
)
all_matches_to_schedule = get_upcoming_matches_for_swiss(match_filter, stage_item)
if len(all_matches_to_schedule) < 1:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
Expand Down Expand Up @@ -179,8 +181,10 @@ async def start_next_round(

limit = len(courts) - len(draft_round.matches)
for ___ in range(limit):
all_matches_to_schedule = await get_upcoming_matches_for_swiss(
match_filter, stage_item, tournament_id
stage_item = await get_stage_item(tournament_id, stage_item_id)
draft_round = next(round_ for round_ in stage_item.rounds if round_.is_draft)
all_matches_to_schedule = get_upcoming_matches_for_swiss(
match_filter, stage_item, draft_round
)
if len(all_matches_to_schedule) < 1:
break
Expand Down Expand Up @@ -216,5 +220,5 @@ async def start_next_round(
) from exc

await set_round_active_or_draft(draft_round.id, tournament_id, is_draft=False)

await handle_conflicts(await get_full_tournament_details(tournament_id))
return SuccessResponse()
8 changes: 8 additions & 0 deletions backend/bracket/sql/rounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,14 @@ async def sql_delete_rounds_for_stage_item_id(stage_item_id: StageItemId) -> Non
await database.execute(query=query, values={"stage_item_id": stage_item_id})


async def sql_delete_round(round_id: RoundId) -> None:
query = """
DELETE FROM rounds
WHERE rounds.id = :round_id
"""
await database.execute(query=query, values={"round_id": round_id})


async def set_round_active_or_draft(
round_id: RoundId, tournament_id: TournamentId, *, is_draft: bool
) -> None:
Expand Down
1 change: 1 addition & 0 deletions backend/tests/integration_tests/api/matches_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,7 @@ async def test_upcoming_matches_endpoint(
},
"elo_diff": "0",
"swiss_diff": "0",
"times_played_sum": 0,
"is_recommended": True,
"player_behind_schedule_count": 0,
}
Expand Down
20 changes: 11 additions & 9 deletions backend/tests/unit_tests/swiss_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,19 +89,21 @@ def test_constraints() -> None:
# is recommended.
assert result == [
SuggestedMatch(
stage_item_input1=input3,
stage_item_input2=input2,
elo_diff=Decimal("25"),
swiss_diff=Decimal("25"),
stage_item_input1=input4,
stage_item_input2=input3,
elo_diff=Decimal("50.0"),
swiss_diff=Decimal("50.0"),
is_recommended=True,
times_played_sum=0,
player_behind_schedule_count=0,
),
SuggestedMatch(
stage_item_input1=input4,
stage_item_input2=input3,
elo_diff=Decimal("50"),
swiss_diff=Decimal("50"),
is_recommended=True,
stage_item_input1=input3,
stage_item_input2=input2,
elo_diff=Decimal("25.0"),
swiss_diff=Decimal("25.0"),
is_recommended=False,
times_played_sum=1,
player_behind_schedule_count=0,
),
]
2 changes: 1 addition & 1 deletion frontend/src/pages/tournaments/[id]/dashboard/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ export function Schedule({
.filter((m1: any) => m1.match.start_time != null)
.sort(
(m1: any, m2: any) =>
formatTime(m1.match.start_time).localeCompare(formatTime(m2.match.start_time)) ||
-formatTime(m1.match.start_time).localeCompare(formatTime(m2.match.start_time)) ||
m1.match.court?.name.localeCompare(m2.match.court?.name)
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ export default function TournamentPage() {
const swrStagesResponse: SWRResponse = getStages(id);
const swrCourtsResponse = getCourts(tournamentData.id);
const [onlyRecommended, setOnlyRecommended] = useRouterQueryState('only-recommended', 'true');
const [eloThreshold, setEloThreshold] = useRouterQueryState('max-elo-diff', 100);
const [iterations, setIterations] = useRouterQueryState('iterations', 1000);
const [eloThreshold, setEloThreshold] = useRouterQueryState('max-elo-diff', 200);
const [iterations, setIterations] = useRouterQueryState('iterations', 2_000);
const [limit, setLimit] = useRouterQueryState('limit', 50);
const [matchVisibility, setMatchVisibility] = useRouterQueryState('match-visibility', 'all');
const [teamNamesDisplay, setTeamNamesDisplay] = useRouterQueryState('which-names', 'team-names');
Expand Down

0 comments on commit 0eb6907

Please sign in to comment.