Skip to content

Commit 95aeacd

Browse files
committed
update tests to use asdict with ModelParams
1 parent 1a024bc commit 95aeacd

File tree

1 file changed

+14
-10
lines changed

1 file changed

+14
-10
lines changed

tests/test_inputs.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import unittest
15+
from dataclasses import asdict
1516

1617
from dash import dash_table
1718
from numpy import asarray
@@ -47,14 +48,14 @@ def test_initial_sched(self):
4748

4849
# Check that CQM created has the right number of variables
4950
def test_cqm(self):
50-
cqm = employee_scheduling.build_cqm(self.test_params)
51+
cqm = employee_scheduling.build_cqm(**asdict(self.test_params))
5152

5253
self.assertEqual(len(cqm.variables),
5354
self.num_employees * len(self.shifts))
5455

5556
# Check that NL assignments variable is the correct shape
5657
def test_nl(self):
57-
_, assignments = employee_scheduling.build_nl(self.test_params)
58+
_, assignments = employee_scheduling.build_nl(**asdict(self.test_params))
5859

5960
self.assertEqual(assignments.shape(),
6061
(self.num_employees, len(self.shifts)))
@@ -84,7 +85,7 @@ def test_samples_cqm(self):
8485
max_consecutive_shifts=6
8586
)
8687

87-
cqm = employee_scheduling.build_cqm(test_params)
88+
cqm = employee_scheduling.build_cqm(**asdict(test_params))
8889

8990
feasible_sample = {
9091
"A-Mgr_1": 0.0,
@@ -180,7 +181,7 @@ def test_states_nl(self):
180181
max_consecutive_shifts=6
181182
)
182183

183-
model, assignments = employee_scheduling.build_nl(test_params)
184+
model, assignments = employee_scheduling.build_nl(**asdict(test_params))
184185

185186
if not model.is_locked():
186187
model.lock()
@@ -274,6 +275,7 @@ def test_build_from_sample(self):
274275

275276
def test_build_from_state(self):
276277
employees = ["A-Mgr", "B-Mgr", "C", "D", "E", "E-Tr"]
278+
shifts = [str(i+1) for i in range(14)]
277279

278280
# Make every employee available for every shift
279281
availability = {
@@ -286,15 +288,17 @@ def test_build_from_state(self):
286288
}
287289

288290
state = asarray([
289-
[0, 0, 1, 1, 1],
290-
[1, 1, 0, 0, 0],
291-
[1, 1, 1, 1, 1],
292-
[1, 1, 1, 1, 1],
293-
[1, 1, 1, 1, 1]
291+
# Give managers alternating shifts
292+
[i % 2 for i in range(14)],
293+
[(i+1) % 2 for i in range(14)],
294+
[1 for _ in range(14)],
295+
[1 for _ in range(14)],
296+
[1 for _ in range(14)],
297+
[1 for _ in range(14)],
294298
])
295299

296300
disp_datatable = utils.display_schedule(
297-
utils.build_schedule_from_state(state, employees),
301+
utils.build_schedule_from_state(state, employees, shifts),
298302
availability
299303
)
300304

0 commit comments

Comments
 (0)