Skip to content

Commit d1b66f3

Browse files
mhhegerHenrZu
andauthored
1242 update simple surrogate model (#1252)
- Update scripts regarding the surrogate model with the secir model with one age group - Add grid search and hyperparameter tuning for these models - Optimization/rework of existing code Co-authored-by: Henrik Zunker <69154294+HenrZu@users.noreply.github.com>
1 parent add315f commit d1b66f3

File tree

10 files changed

+1838
-641
lines changed

10 files changed

+1838
-641
lines changed

pycode/memilio-surrogatemodel/memilio/surrogatemodel/ode_secir_groups/data_generation.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ def generate_data(
241241
data['inputs'].append(data_run[:input_width])
242242
data['labels'].append(data_run[input_width:])
243243
data['contact_matrix'].append(np.array(damped_contact_matrix))
244-
data['damping_day'].append(damping_day)
244+
data['damping_day'].append([damping_day])
245245
bar.next()
246246
bar.finish()
247247

@@ -271,13 +271,13 @@ def getBaselineMatrix():
271271
""" loads the baselinematrix"""
272272

273273
baseline_contact_matrix0 = os.path.join(
274-
"./data/contacts/baseline_home.txt")
274+
"./data/Germany/contacts/baseline_home.txt")
275275
baseline_contact_matrix1 = os.path.join(
276-
"./data/contacts/baseline_school_pf_eig.txt")
276+
"./data/Germany/contacts/baseline_school_pf_eig.txt")
277277
baseline_contact_matrix2 = os.path.join(
278-
"./data/contacts/baseline_work.txt")
278+
"./data/Germany/contacts/baseline_work.txt")
279279
baseline_contact_matrix3 = os.path.join(
280-
"./data/contacts/baseline_other.txt")
280+
"./data/Germany/contacts/baseline_other.txt")
281281

282282
baseline = np.loadtxt(baseline_contact_matrix0) \
283283
+ np.loadtxt(baseline_contact_matrix1) + \
@@ -329,7 +329,7 @@ def get_population(path):
329329
os.path.dirname(os.path.realpath(path)))), 'data')
330330

331331
path_population = os.path.abspath(
332-
r"data//pydata//Germany//county_population.json")
332+
r"data//Germany//pydata//county_current_population.json")
333333

334334
input_width = 5
335335
label_width = 30

0 commit comments

Comments
 (0)