Skip to content

Commit 4780226

Browse files
committed
Add new constraints to NL formulation, fix bugs
1 parent 482f3ca commit 4780226

File tree

6 files changed

+56
-64
lines changed

6 files changed

+56
-64
lines changed

demo_callbacks.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from dash import ALL, MATCH, Input, Output, State, ctx
2323
from dash.exceptions import PreventUpdate
2424

25-
from demo_enums import SolverType
25+
from src.demo_enums import SolverType
2626
import src.employee_scheduling as employee_scheduling
2727
import src.utils as utils
2828
from demo_configs import (
@@ -322,8 +322,6 @@ def run_optimization(
322322
availability = utils.availability_to_dict(sched_df["props"]["data"])
323323
employees = list(availability.keys())
324324

325-
isolated_days_allowed = True if 0 in checklist else False
326-
327325
forecast = [
328326
val if isinstance(val, int)
329327
else forecast_placeholder[i]
@@ -335,10 +333,10 @@ def run_optimization(
335333
shifts=shifts,
336334
min_shifts=min(shifts_per_employee),
337335
max_shifts=max(shifts_per_employee),
338-
forecast,
339-
allow_isolated_days_off=isolated_days_allowed,
336+
shift_forecast=forecast,
337+
allow_isolated_days_off=0 in checklist,
340338
max_consecutive_shifts=consecutive_shifts,
341-
num_full_time,
339+
num_full_time=num_full_time,
342340
)
343341

344342
if solver_type is SolverType.NL:

demo_configs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@
5858
"value": 12,
5959
}
6060

61-
# number of full time employees slider (value means default)
61+
# number of full-time employees slider (value means default)
6262
NUM_FULL_TIME = {
6363
"min": 0,
6464
"max": 9,

demo_interface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
THUMBNAIL,
3030
UNAVAILABLE_ICON,
3131
)
32-
from demo_enums import SolverType
32+
from src.demo_enums import SolverType
3333
from src.utils import COL_IDS
3434

3535

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
dash[diskcache]==2.16.1
22
dash-bootstrap-components==1.6.0
33
dwave-ocean-sdk>=7.0.0
4-
dwave-optimization>=0.3.0
4+
dwave-optimization>=0.4.0
55
Faker==21.0.0
66
pandas>=2.0

src/employee_scheduling.py

Lines changed: 38 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from dwave.optimization.symbols import BinaryVariable
2121
from dwave.system import LeapHybridCQMSampler, LeapHybridNLSampler
2222

23-
from utils import DAYS, FULL_TIME_SHIFTS, SHIFTS, validate_nl_schedule
23+
from src.utils import DAYS, FULL_TIME_SHIFTS, SHIFTS, validate_nl_schedule
2424

2525

2626
MSGS = {
@@ -42,7 +42,7 @@
4242
}
4343

4444

45-
def build_cqm(#params: ModelParams
45+
def build_cqm( # params: ModelParams
4646
availability: dict[str, list[int]],
4747
shifts: list[str],
4848
min_shifts: int,
@@ -62,7 +62,7 @@ def build_cqm(#params: ModelParams
6262
shift_forecast: A list of the number of expected employees needed per shift.
6363
allow_isolated_days_off: Whether on-off-on should be allowed in the schedule.
6464
max_consecutive_shifts: The maximum consectutive shifts to schedule a part-time employee for.
65-
num_full_time: The number of full time employees.
65+
num_full_time: The number of full-time employees.
6666
6767
Returns:
6868
cqm: A Constrained Quadratic Model representing the problem.
@@ -118,7 +118,7 @@ def build_cqm(#params: ModelParams
118118
)
119119

120120
for employee in employees_ft:
121-
# Schedule employees for at most max_shifts
121+
# Schedule full-time employees for all their shifts
122122
cqm.add_constraint(
123123
quicksum(x[employee, shift] for shift in shifts) <= FULL_TIME_SHIFTS,
124124
label=f"overtime,{employee},",
@@ -129,7 +129,7 @@ def build_cqm(#params: ModelParams
129129
label=f"insufficient,{employee},",
130130
)
131131

132-
# Every shift needs shift_min and shift_max employees working
132+
# Every shift needs shift_forecast employees working
133133
for i, shift in enumerate(shifts):
134134
cqm.add_constraint(
135135
sum(x[employee, shift] for employee in employees) >= shift_forecast[i],
@@ -234,16 +234,15 @@ def run_cqm(cqm: ConstrainedQuadraticModel):
234234
return feasible_sampleset, None
235235

236236

237-
def build_nl(
237+
def build_nl( # params: ModelParams
238238
availability: dict[str, list[int]],
239239
shifts: list[str],
240240
min_shifts: int,
241241
max_shifts: int,
242-
shift_min: int,
243-
shift_max: int,
244-
requires_manager: bool,
242+
shift_forecast: list,
245243
allow_isolated_days_off: bool,
246244
max_consecutive_shifts: int,
245+
num_full_time: int,
247246
) -> tuple[Model, BinaryVariable]:
248247
"""Builds an employee scheduling nonlinear model.
249248
@@ -252,11 +251,10 @@ def build_nl(
252251
shifts (list[str]): Shift labels.
253252
min_shifts (int): Minimum shifts per employee.
254253
max_shifts (int): Maximum shifts per employee.
255-
shift_min (int): Minimum employees per shift.
256-
shift_max (int): Maximum employees per shift.
257-
requires_manager (bool): Whether to require exactly one manager on every shift.
254+
shift_forecast (list[int]): A list of the number of expected employees needed per shift.
258255
allow_isolated_days_off (bool): Whether to allow isolated days off.
259256
max_consecutive_shifts (int): Maximum consecutive shifts per employee.
257+
num_full_time (int): The number of full-time employees.
260258
261259
Returns:
262260
tuple[Model, BinaryVariable]: the NL model and assignments decision variable
@@ -281,8 +279,8 @@ def build_nl(
281279
# Initialize model constants
282280
min_shifts_constant = model.constant(min_shifts)
283281
max_shifts_constant = model.constant(max_shifts)
284-
shift_min_constant = model.constant(shift_min)
285-
shift_max_constant = model.constant(shift_max)
282+
full_time_shifts_constant = model.constant(FULL_TIME_SHIFTS)
283+
shift_forecast_constant = model.constant(shift_forecast)
286284
max_consecutive_shifts_c = model.constant(max_consecutive_shifts)
287285
one_c = model.constant(1)
288286

@@ -292,28 +290,32 @@ def build_nl(
292290

293291
# Objective: for infeasible solutions, focus on right number of shifts for employees
294292
target_shifts = model.constant((min_shifts + max_shifts) / 2)
295-
shift_difference_list = [
296-
(assignments[e, :].sum() - target_shifts) ** 2 for e in range(num_employees)
293+
shift_difference_list_pt = [
294+
(assignments[e, :].sum() - target_shifts) ** 2 for e in range(num_full_time, num_employees)
297295
]
298-
obj += add(*shift_difference_list)
296+
shift_difference_list_ft = [
297+
(assignments[e, :].sum() - full_time_shifts_constant) ** 2 for e in range(num_full_time)
298+
]
299+
obj += add(*shift_difference_list_pt, *shift_difference_list_ft)
299300

300301
model.minimize(-obj)
301302

302303
# CONSTRAINTS:
303304
# Only schedule employees when they're available
304305
model.add_constraint((availability_const >= assignments).all())
305306

306-
for e in range(len(employees)):
307-
# Schedule employees for at most max_shifts
308-
model.add_constraint(assignments[e, :].sum() <= max_shifts_constant)
307+
# Schedule part-time employees for at most max_shifts
308+
model.add_constraint((assignments[num_full_time:, :].sum(axis=1) <= max_shifts_constant).all())
309309

310-
# Schedule employees for at least min_shifts
311-
model.add_constraint(assignments[e, :].sum() >= min_shifts_constant)
310+
# Schedule part-time employees for at least min_shifts
311+
model.add_constraint((assignments[num_full_time:, :].sum(axis=1) >= min_shifts_constant).all())
312+
313+
if num_full_time:
314+
# Schedule full-time employees for all their shifts
315+
model.add_constraint((assignments[:num_full_time, :].sum(axis=1) == full_time_shifts_constant).all())
312316

313-
# Every shift needs shift_min and shift_max employees working
314-
for s in range(num_shifts):
315-
model.add_constraint(assignments[:, s].sum() <= shift_max_constant)
316-
model.add_constraint(assignments[:, s].sum() >= shift_min_constant)
317+
# shft_fcst = model.constant(shift_forecast)
318+
model.add_constraint((assignments.sum(axis=0) == shift_forecast_constant).all())
317319

318320
managers_c = model.constant(
319321
[employees.index(e) for e in employees if e[-3:] == "Mgr"]
@@ -326,7 +328,7 @@ def build_nl(
326328
negthree_c = model.constant(-3)
327329
zero_c = model.constant(0)
328330
# Adding many small constraints greatly improves feasibility
329-
for e in range(len(employees)):
331+
for e in range(num_full_time, num_employees): # for part-time employees
330332
for s1 in range(len(shifts) - 2):
331333
s2, s3 = s1 + 1, s1 + 2
332334
model.add_constraint(
@@ -337,12 +339,11 @@ def build_nl(
337339
<= zero_c
338340
)
339341

340-
if requires_manager:
341-
for shift in range(len(shifts)):
342-
model.add_constraint(assignments[managers_c][:, shift].sum() == one_c)
342+
# At least 1 manager per shift
343+
model.add_constraint((assignments[managers_c].sum(axis=0) >= one_c).all())
343344

344-
# Don't exceed max_consecutive_shifts
345-
for e in range(num_employees):
345+
# Don't exceed max_consecutive_shifts for part-time employees
346+
for e in range(num_full_time, num_employees):
346347
for s in range(num_shifts - max_consecutive_shifts + 1):
347348
s_window = s + max_consecutive_shifts + 1
348349
model.add_constraint(
@@ -368,12 +369,11 @@ def run_nl(
368369
shifts: list[str],
369370
min_shifts: int,
370371
max_shifts: int,
371-
shift_min: int,
372-
shift_max: int,
373-
requires_manager: bool,
372+
shift_forecast: list[int],
374373
allow_isolated_days_off: bool,
375374
max_consecutive_shifts: int,
376-
time_limit: int | None = None,
375+
num_full_time: int,
376+
time_limit: Optional[int] = None,
377377
msgs: dict[str, tuple[str, str]] = MSGS,
378378
) -> Optional[defaultdict[str, list[str]]]:
379379
"""Solves the NL scheduling model and detects any errors.
@@ -395,11 +395,10 @@ def run_nl(
395395
shifts,
396396
min_shifts,
397397
max_shifts,
398-
shift_min,
399-
shift_max,
400-
requires_manager,
398+
shift_forecast,
401399
allow_isolated_days_off,
402400
max_consecutive_shifts,
401+
num_full_time,
403402
)
404403

405404
# Return errors if any error message list is populated

src/utils.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -59,23 +59,21 @@ class ModelParams:
5959
shifts (list[str]): List of shift labels.
6060
min_shifts (int): Min shifts per employee.
6161
max_shifts (int): Max shifts per employee.
62-
shift_min (int): Min employees per shift.
63-
shift_max (int): Max employees per shift.
64-
requires_manager (bool): Whether a manager is required on every shift.
62+
shift_forecast (list[int]): The forecasted employees per shift requirements.
6563
allow_isolated_days_off (bool): Whether isolated shifts off are allowed
6664
(pattern of on-off-on).
6765
max_consecutive_shifts (int): Max consecutive shifts for each employee.
66+
num_full_time: The number of full-time employees.
6867
"""
6968

7069
availability: dict[str, list[int]]
7170
shifts: list[str]
7271
min_shifts: int
7372
max_shifts: int
74-
shift_min: int
75-
shift_max: int
76-
requires_manager: bool
73+
shift_forecast: list[int]
7774
allow_isolated_days_off: bool
7875
max_consecutive_shifts: int
76+
num_full_time: int
7977

8078

8179
def get_random_string(length):
@@ -380,11 +378,10 @@ def validate_nl_schedule(
380378
shifts: list[str],
381379
min_shifts: int,
382380
max_shifts: int,
383-
shift_min: int,
384-
shift_max: int,
385-
requires_manager: bool,
381+
shift_forecast: list[int],
386382
allow_isolated_days_off: bool,
387383
max_consecutive_shifts: int,
384+
num_full_time: int,
388385
) -> defaultdict[str, list[str]]:
389386
"""Detect any errors in a solved NL scheduling model.
390387
@@ -438,11 +435,10 @@ def validate_nl_schedule(
438435

439436
_validate_availability(result, availability, employees, shift_labels, errors, msgs)
440437
_validate_shifts_per_employee(result, employees, min_shifts, max_shifts, errors, msgs)
441-
_validate_employees_per_shift(result, shift_min, shift_max, shift_labels, errors, msgs)
438+
_validate_employees_per_shift(result, shift_forecast, shift_labels, errors, msgs)
442439
_validate_max_consecutive_shifts(result, max_consecutive_shifts, employees, shift_labels, errors, msgs)
443440
_validate_trainee_shifts(result, employees, shift_labels, errors, msgs)
444-
if requires_manager:
445-
_validate_requires_manager(result, employees, shift_labels, errors, msgs)
441+
_validate_requires_manager(result, employees, shift_labels, errors, msgs)
446442
if not allow_isolated_days_off:
447443
_validate_isolated_days_off(result, employees, shift_labels, errors, msgs)
448444

@@ -496,8 +492,7 @@ def _validate_shifts_per_employee(
496492

497493
def _validate_employees_per_shift(
498494
results: np.ndarray,
499-
shift_min: int,
500-
shift_max: int,
495+
shift_forecast: list[int],
501496
shift_labels: list[int],
502497
errors: defaultdict[str, list[str]],
503498
msgs: dict[str, tuple[str, str]],
@@ -510,9 +505,9 @@ def _validate_employees_per_shift(
510505

511506
for s, day in enumerate(shift_labels):
512507
num_employees = results[:, s].sum()
513-
if num_employees < shift_min:
508+
if num_employees < shift_forecast[s]:
514509
errors[understaffed_key].append(understaffed_template.format(day=day))
515-
elif num_employees > shift_max:
510+
elif num_employees > shift_forecast[s]:
516511
errors[overstaffed_key].append(overstaffed_template.format(day=day))
517512
return errors
518513

0 commit comments

Comments
 (0)