Skip to content

Commit 49815c3

Browse files
committed
Add new constraints to NL formulation, fix bugs
1 parent 482f3ca commit 49815c3

File tree

5 files changed

+54
-57
lines changed

5 files changed

+54
-57
lines changed

demo_callbacks.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from dash import ALL, MATCH, Input, Output, State, ctx
2323
from dash.exceptions import PreventUpdate
2424

25-
from demo_enums import SolverType
25+
from src.demo_enums import SolverType
2626
import src.employee_scheduling as employee_scheduling
2727
import src.utils as utils
2828
from demo_configs import (
@@ -322,8 +322,6 @@ def run_optimization(
322322
availability = utils.availability_to_dict(sched_df["props"]["data"])
323323
employees = list(availability.keys())
324324

325-
isolated_days_allowed = True if 0 in checklist else False
326-
327325
forecast = [
328326
val if isinstance(val, int)
329327
else forecast_placeholder[i]
@@ -335,10 +333,10 @@ def run_optimization(
335333
shifts=shifts,
336334
min_shifts=min(shifts_per_employee),
337335
max_shifts=max(shifts_per_employee),
338-
forecast,
339-
allow_isolated_days_off=isolated_days_allowed,
336+
shift_forecast=forecast,
337+
allow_isolated_days_off=0 in checklist,
340338
max_consecutive_shifts=consecutive_shifts,
341-
num_full_time,
339+
num_full_time=num_full_time,
342340
)
343341

344342
if solver_type is SolverType.NL:

demo_interface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
THUMBNAIL,
3030
UNAVAILABLE_ICON,
3131
)
32-
from demo_enums import SolverType
32+
from src.demo_enums import SolverType
3333
from src.utils import COL_IDS
3434

3535

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
dash[diskcache]==2.16.1
22
dash-bootstrap-components==1.6.0
33
dwave-ocean-sdk>=7.0.0
4-
dwave-optimization>=0.3.0
4+
dwave-optimization>=0.4.0
55
Faker==21.0.0
66
pandas>=2.0

src/employee_scheduling.py

Lines changed: 37 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from dwave.optimization.symbols import BinaryVariable
2121
from dwave.system import LeapHybridCQMSampler, LeapHybridNLSampler
2222

23-
from utils import DAYS, FULL_TIME_SHIFTS, SHIFTS, validate_nl_schedule
23+
from src.utils import DAYS, FULL_TIME_SHIFTS, SHIFTS, validate_nl_schedule
2424

2525

2626
MSGS = {
@@ -42,7 +42,7 @@
4242
}
4343

4444

45-
def build_cqm(#params: ModelParams
45+
def build_cqm( # params: ModelParams
4646
availability: dict[str, list[int]],
4747
shifts: list[str],
4848
min_shifts: int,
@@ -118,7 +118,7 @@ def build_cqm(#params: ModelParams
118118
)
119119

120120
for employee in employees_ft:
121-
# Schedule employees for at most max_shifts
121+
# Schedule full time employees for all their shifts
122122
cqm.add_constraint(
123123
quicksum(x[employee, shift] for shift in shifts) <= FULL_TIME_SHIFTS,
124124
label=f"overtime,{employee},",
@@ -129,7 +129,7 @@ def build_cqm(#params: ModelParams
129129
label=f"insufficient,{employee},",
130130
)
131131

132-
# Every shift needs shift_min and shift_max employees working
132+
# Every shift needs shift_forecast employees working
133133
for i, shift in enumerate(shifts):
134134
cqm.add_constraint(
135135
sum(x[employee, shift] for employee in employees) >= shift_forecast[i],
@@ -234,16 +234,15 @@ def run_cqm(cqm: ConstrainedQuadraticModel):
234234
return feasible_sampleset, None
235235

236236

237-
def build_nl(
237+
def build_nl( # params: ModelParams
238238
availability: dict[str, list[int]],
239239
shifts: list[str],
240240
min_shifts: int,
241241
max_shifts: int,
242-
shift_min: int,
243-
shift_max: int,
244-
requires_manager: bool,
242+
shift_forecast: list,
245243
allow_isolated_days_off: bool,
246244
max_consecutive_shifts: int,
245+
num_full_time: int,
247246
) -> tuple[Model, BinaryVariable]:
248247
"""Builds an employee scheduling nonlinear model.
249248
@@ -252,11 +251,10 @@ def build_nl(
252251
shifts (list[str]): Shift labels.
253252
min_shifts (int): Minimum shifts per employee.
254253
max_shifts (int): Maximum shifts per employee.
255-
shift_min (int): Minimum employees per shift.
256-
shift_max (int): Maximum employees per shift.
257-
requires_manager (bool): Whether to require exactly one manager on every shift.
254+
shift_forecast (list[int]): A list of the number of expected employees needed per shift.
258255
allow_isolated_days_off (bool): Whether to allow isolated days off.
259256
max_consecutive_shifts (int): Maximum consecutive shifts per employee.
257+
num_full_time (int): The number of full time employees.
260258
261259
Returns:
262260
tuple[Model, BinaryVariable]: the NL model and assignments decision variable
@@ -268,6 +266,7 @@ def build_nl(
268266
# Create a binary symbol representing the assignment of employees to shifts
269267
# i.e. assignments[employee][shift] = 1 if assigned, else 0
270268
num_employees = len(employees)
269+
num_part_time = num_employees - num_full_time
271270
num_shifts = len(shifts)
272271
assignments = model.binary((num_employees, num_shifts))
273272

@@ -281,8 +280,8 @@ def build_nl(
281280
# Initialize model constants
282281
min_shifts_constant = model.constant(min_shifts)
283282
max_shifts_constant = model.constant(max_shifts)
284-
shift_min_constant = model.constant(shift_min)
285-
shift_max_constant = model.constant(shift_max)
283+
full_time_shifts_constant = model.constant(FULL_TIME_SHIFTS)
284+
shift_forecast_constant = model.constant(shift_forecast)
286285
max_consecutive_shifts_c = model.constant(max_consecutive_shifts)
287286
one_c = model.constant(1)
288287

@@ -292,28 +291,36 @@ def build_nl(
292291

293292
# Objective: for infeasible solutions, focus on right number of shifts for employees
294293
target_shifts = model.constant((min_shifts + max_shifts) / 2)
295-
shift_difference_list = [
296-
(assignments[e, :].sum() - target_shifts) ** 2 for e in range(num_employees)
294+
shift_difference_list_pt = [
295+
(assignments[e, :].sum() - target_shifts) ** 2 for e in range(num_part_time)
296+
]
297+
shift_difference_list_ft = [
298+
(assignments[e, :].sum() - full_time_shifts_constant) ** 2 for e in range(num_full_time)
297299
]
298-
obj += add(*shift_difference_list)
300+
obj += add(*shift_difference_list_pt, *shift_difference_list_ft)
299301

300302
model.minimize(-obj)
301303

302304
# CONSTRAINTS:
303305
# Only schedule employees when they're available
304306
model.add_constraint((availability_const >= assignments).all())
305307

306-
for e in range(len(employees)):
308+
for e in range(num_part_time):
307309
# Schedule employees for at most max_shifts
308310
model.add_constraint(assignments[e, :].sum() <= max_shifts_constant)
309311

310312
# Schedule employees for at least min_shifts
311313
model.add_constraint(assignments[e, :].sum() >= min_shifts_constant)
312314

313-
# Every shift needs shift_min and shift_max employees working
314-
for s in range(num_shifts):
315-
model.add_constraint(assignments[:, s].sum() <= shift_max_constant)
316-
model.add_constraint(assignments[:, s].sum() >= shift_min_constant)
315+
for e in range(num_full_time):
316+
# Schedule full time employees for all their shifts
317+
model.add_constraint(assignments[e, :].sum() == full_time_shifts_constant)
318+
319+
# Every shift needs shift_forecast employees working
320+
# model.add_constraint(([assignments[:, s].sum() for s in range(num_shifts)] == shift_forecast_constant).all())
321+
322+
# shft_fcst = model.constant(shift_forecast)
323+
model.add_constraint((assignments.sum(axis=0) == shift_forecast_constant).all())
317324

318325
managers_c = model.constant(
319326
[employees.index(e) for e in employees if e[-3:] == "Mgr"]
@@ -326,7 +333,7 @@ def build_nl(
326333
negthree_c = model.constant(-3)
327334
zero_c = model.constant(0)
328335
# Adding many small constraints greatly improves feasibility
329-
for e in range(len(employees)):
336+
for e in range(num_part_time):
330337
for s1 in range(len(shifts) - 2):
331338
s2, s3 = s1 + 1, s1 + 2
332339
model.add_constraint(
@@ -337,12 +344,11 @@ def build_nl(
337344
<= zero_c
338345
)
339346

340-
if requires_manager:
341-
for shift in range(len(shifts)):
342-
model.add_constraint(assignments[managers_c][:, shift].sum() == one_c)
347+
for shift in range(len(shifts)):
348+
model.add_constraint(assignments[managers_c][:, shift].sum() >= one_c)
343349

344350
# Don't exceed max_consecutive_shifts
345-
for e in range(num_employees):
351+
for e in range(num_part_time):
346352
for s in range(num_shifts - max_consecutive_shifts + 1):
347353
s_window = s + max_consecutive_shifts + 1
348354
model.add_constraint(
@@ -368,12 +374,11 @@ def run_nl(
368374
shifts: list[str],
369375
min_shifts: int,
370376
max_shifts: int,
371-
shift_min: int,
372-
shift_max: int,
373-
requires_manager: bool,
377+
shift_forecast: list[int],
374378
allow_isolated_days_off: bool,
375379
max_consecutive_shifts: int,
376-
time_limit: int | None = None,
380+
num_full_time: int,
381+
time_limit: Optional[int] = None,
377382
msgs: dict[str, tuple[str, str]] = MSGS,
378383
) -> Optional[defaultdict[str, list[str]]]:
379384
"""Solves the NL scheduling model and detects any errors.
@@ -395,11 +400,10 @@ def run_nl(
395400
shifts,
396401
min_shifts,
397402
max_shifts,
398-
shift_min,
399-
shift_max,
400-
requires_manager,
403+
shift_forecast,
401404
allow_isolated_days_off,
402405
max_consecutive_shifts,
406+
num_full_time,
403407
)
404408

405409
# Return errors if any error message list is populated

src/utils.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -59,23 +59,21 @@ class ModelParams:
5959
shifts (list[str]): List of shift labels.
6060
min_shifts (int): Min shifts per employee.
6161
max_shifts (int): Max shifts per employee.
62-
shift_min (int): Min employees per shift.
63-
shift_max (int): Max employees per shift.
64-
requires_manager (bool): Whether a manager is required on every shift.
62+
shift_forecast (list[int]): The forecasted employees per shift requirements.
6563
allow_isolated_days_off (bool): Whether isolated shifts off are allowed
6664
(pattern of on-off-on).
6765
max_consecutive_shifts (int): Max consecutive shifts for each employee.
66+
num_full_time: The number of full-time employees.
6867
"""
6968

7069
availability: dict[str, list[int]]
7170
shifts: list[str]
7271
min_shifts: int
7372
max_shifts: int
74-
shift_min: int
75-
shift_max: int
76-
requires_manager: bool
73+
shift_forecast: list[int]
7774
allow_isolated_days_off: bool
7875
max_consecutive_shifts: int
76+
num_full_time: int
7977

8078

8179
def get_random_string(length):
@@ -380,11 +378,10 @@ def validate_nl_schedule(
380378
shifts: list[str],
381379
min_shifts: int,
382380
max_shifts: int,
383-
shift_min: int,
384-
shift_max: int,
385-
requires_manager: bool,
381+
shift_forecast: list[int],
386382
allow_isolated_days_off: bool,
387383
max_consecutive_shifts: int,
384+
num_full_time: int,
388385
) -> defaultdict[str, list[str]]:
389386
"""Detect any errors in a solved NL scheduling model.
390387
@@ -438,11 +435,10 @@ def validate_nl_schedule(
438435

439436
_validate_availability(result, availability, employees, shift_labels, errors, msgs)
440437
_validate_shifts_per_employee(result, employees, min_shifts, max_shifts, errors, msgs)
441-
_validate_employees_per_shift(result, shift_min, shift_max, shift_labels, errors, msgs)
438+
_validate_employees_per_shift(result, shift_forecast, shift_labels, errors, msgs)
442439
_validate_max_consecutive_shifts(result, max_consecutive_shifts, employees, shift_labels, errors, msgs)
443440
_validate_trainee_shifts(result, employees, shift_labels, errors, msgs)
444-
if requires_manager:
445-
_validate_requires_manager(result, employees, shift_labels, errors, msgs)
441+
_validate_requires_manager(result, employees, shift_labels, errors, msgs)
446442
if not allow_isolated_days_off:
447443
_validate_isolated_days_off(result, employees, shift_labels, errors, msgs)
448444

@@ -496,8 +492,7 @@ def _validate_shifts_per_employee(
496492

497493
def _validate_employees_per_shift(
498494
results: np.ndarray,
499-
shift_min: int,
500-
shift_max: int,
495+
shift_forecast: list[int],
501496
shift_labels: list[int],
502497
errors: defaultdict[str, list[str]],
503498
msgs: dict[str, tuple[str, str]],
@@ -510,9 +505,9 @@ def _validate_employees_per_shift(
510505

511506
for s, day in enumerate(shift_labels):
512507
num_employees = results[:, s].sum()
513-
if num_employees < shift_min:
508+
if num_employees < shift_forecast[s]:
514509
errors[understaffed_key].append(understaffed_template.format(day=day))
515-
elif num_employees > shift_max:
510+
elif num_employees > shift_forecast[s]:
516511
errors[overstaffed_key].append(overstaffed_template.format(day=day))
517512
return errors
518513

0 commit comments

Comments
 (0)