From 017d08147c88fc07c1e7d031b8134be69aaa645b Mon Sep 17 00:00:00 2001 From: Johnson Sun Date: Sat, 27 Jul 2024 01:06:08 +0800 Subject: [PATCH] Add required objectives --- core/nurse_scheduling/context.py | 18 +++++++++++ core/nurse_scheduling/objective_types.py | 26 ++++++++++++++++ core/nurse_scheduling/scheduler.py | 38 +++++++++++------------- core/tests/testcases/example_1.yaml | 4 +++ 4 files changed, 66 insertions(+), 20 deletions(-) create mode 100644 core/nurse_scheduling/context.py create mode 100644 core/nurse_scheduling/objective_types.py diff --git a/core/nurse_scheduling/context.py b/core/nurse_scheduling/context.py new file mode 100644 index 0000000..62a19c5 --- /dev/null +++ b/core/nurse_scheduling/context.py @@ -0,0 +1,18 @@ +class Context: + def __init__(self) -> None: + self.startdate = None + self.enddate = None + self.requirements = None + self.people = None + self.objectives = None + self.dates = None + self.n_days = None + self.n_requirements = None + self.n_people = None + self.model = None + self.shifts = None + self.map_dr_p = None + self.map_dp_r = None + self.map_d_rp = None + self.map_r_dp = None + self.map_p_dr = None diff --git a/core/nurse_scheduling/objective_types.py b/core/nurse_scheduling/objective_types.py new file mode 100644 index 0000000..a42e5ce --- /dev/null +++ b/core/nurse_scheduling/objective_types.py @@ -0,0 +1,26 @@ +from . import utils + +def all_requirements_fulfilled(ctx, args): + # Hard constraint + # For all shifts, the requirements (# of people) must be fulfilled. + # Note that a shift is represented as (d, r) + # i.e., sum_{p}(shifts[(d, r, p)]) == required_n_people, for all (d, r) + for (d, r), ps in ctx.map_dr_p.items(): + actual_n_people = sum(ctx.shifts[(d, r, p)] for p in ps) + required_n_people = utils.required_n_people(ctx.requirements[r]) + ctx.model.Add(actual_n_people == required_n_people) + +def all_people_work_at_most_one_shift_per_day(ctx, args): + # Hard constraint + # For all people, for all days, only work at most one shift. + # Note that a shift in day `d` can be represented as `r` instead of (d, r). + # i.e., sum_{r}(shifts[(d, r, p)]) <= 1, for all (d, p) + for (d, p), rs in ctx.map_dp_r.items(): + actual_n_shifts = sum(ctx.shifts[(d, r, p)] for r in rs) + maximum_n_shifts = 1 + ctx.model.Add(actual_n_shifts <= maximum_n_shifts) + +OBJECTIVE_TYPES_TO_FUNC = { + "all requirements fulfilled": all_requirements_fulfilled, + "all people work at most one shift per day": all_people_work_at_most_one_shift_per_day, +} diff --git a/core/nurse_scheduling/scheduler.py b/core/nurse_scheduling/scheduler.py index 3598dd1..ee63518 100644 --- a/core/nurse_scheduling/scheduler.py +++ b/core/nurse_scheduling/scheduler.py @@ -4,7 +4,8 @@ from ortools.sat.python import cp_model -from . import export, utils +from . import export, objective_types +from .context import Context from .dataloader import load_data @@ -13,14 +14,17 @@ def schedule(filepath: str, validate=True, deterministic=False): scenario = load_data(filepath, validate) logging.info("Extracting scenario data...") + if scenario.apiVersion != "alpha": + raise NotImplementedError(f"Unsupported API version: {scenario.apiVersion}") startdate = scenario.startdate enddate = scenario.enddate requirements = scenario.requirements people = scenario.people + objectives = scenario.objectives del scenario n_days = (enddate - startdate).days + 1 - n_people = len(people) n_requirements = len(requirements) + n_people = len(people) dates = [startdate + timedelta(days=d) for d in range(n_days)] logging.info("Initializing solver model...") @@ -33,7 +37,9 @@ def schedule(filepath: str, validate=True, deterministic=False): # Ref: https://developers.google.com/optimization/scheduling/employee_scheduling for d in range(n_days): for r in range(n_requirements): + # TODO(Optimize): Skip if no people is required in that day for p in range(n_people): + # TODO(Optimize): Skip if the person does not qualify for the requirement shifts[(d, r, p)] = model.NewBoolVar(f"shift_d{d}_r{r}_p{p}") logging.info("Creating maps for faster lookup...") @@ -58,24 +64,16 @@ def schedule(filepath: str, validate=True, deterministic=False): for p in range(n_people) } - logging.info("Adding preferences and constraints...") - # Hard constraint - # For all shifts, the requirements (# of people) must be fulfilled. - # Note that a shift is represented as (d, r) - # i.e., sum_{p}(shifts[(d, r, p)]) == required_n_people, for all (d, r) - for (d, r), ps in map_dr_p.items(): - actual_n_people = sum(shifts[(d, r, p)] for p in ps) - required_n_people = utils.required_n_people(requirements[r]) - model.Add(actual_n_people == required_n_people) - - # Hard constraint - # For all people, for all days, only work at most one shift. - # Note that a shift in day `d` can be represented as `r` instead of (d, r). - # i.e., sum_{r}(shifts[(d, r, p)]) <= 1, for all (d, p) - for (d, p), rs in map_dp_r.items(): - actual_n_shifts = sum(shifts[(d, r, p)] for r in rs) - maximum_n_shifts = 1 - model.Add(actual_n_shifts <= maximum_n_shifts) + ctx = Context() + for k in vars(ctx): + setattr(ctx, k, locals()[k]) + + logging.info("Adding objectives (i.e., preferences and constraints)...") + # TODO: Check no duplicated objectives + # TODO: Check no overlapping objectives + # TODO: Check all required objectives are present + for objective in objectives: + objective_types.OBJECTIVE_TYPES_TO_FUNC[objective.type](ctx, objective.args) logging.info("Initializing solver...") solver = cp_model.CpSolver() diff --git a/core/tests/testcases/example_1.yaml b/core/tests/testcases/example_1.yaml index ddee775..41b9fd2 100644 --- a/core/tests/testcases/example_1.yaml +++ b/core/tests/testcases/example_1.yaml @@ -1,3 +1,4 @@ +apiVersion: alpha description: Simple Example 1 startdate: 2023-08-18 enddate: 2023-08-20 @@ -20,3 +21,6 @@ requirements: - id: N description: Night shift requirement required_people: 1 +objectives: + - type: all requirements fulfilled + - type: all people work at most one shift per day