From 840303dde091a6544a6609a1e23a1cc4a95bf297 Mon Sep 17 00:00:00 2001 From: Sascha <51127093+xsaschako@users.noreply.github.com> Date: Tue, 27 May 2025 21:02:44 +0200 Subject: [PATCH 01/11] Add scripts for analyzing infection states and ICU data from simulation runs --- .../memilio/plot/plotAbmICUAndDeadComp.py | 268 ++++++++++++++++++ .../memilio/plot/plotAbmInfectionStates.py | 258 +++++++++++++++++ 2 files changed, 526 insertions(+) create mode 100644 pycode/memilio-plot/memilio/plot/plotAbmICUAndDeadComp.py create mode 100644 pycode/memilio-plot/memilio/plot/plotAbmInfectionStates.py diff --git a/pycode/memilio-plot/memilio/plot/plotAbmICUAndDeadComp.py b/pycode/memilio-plot/memilio/plot/plotAbmICUAndDeadComp.py new file mode 100644 index 0000000000..6a1af97ed5 --- /dev/null +++ b/pycode/memilio-plot/memilio/plot/plotAbmICUAndDeadComp.py @@ -0,0 +1,268 @@ +# Python script to analyze bs runs +# input is a bs run folder with the following structure: +# bs_run_folder has a txt file for each bs run +# each txt file has a line for each time step +# each line has a column for each compartment as well as the timestep +# each column has the number of individuals in that compartment +# the first line of each txt file is the header + +import sys +import argparse +import os +import pandas as pd +import numpy as np +import matplotlib.pyplot as plt +import matplotlib +import matplotlib.colors as colors +import matplotlib.cm as cmx +import matplotlib.patches as mpatches +import matplotlib.lines as mlines +import h5py +from datetime import datetime +from matplotlib.dates import DateFormatter +from scipy.ndimage import gaussian_filter1d +from scipy.signal import savgol_filter + +fontsize = 20 + + +def plot_dead(path): + # we will have a seperate plot the cumulative infected individuals, cumulative symptomatic individuals and cumulative dead individual + # we need to load the data + f_p50 = h5py.File( + path+"/infection_state_per_age_group/0/p50/Results.h5", 'r') + p50_bs = f_p50['0'] + + # do the same for 25 and 75 percentile + f_p25 = h5py.File( + path+"/infection_state_per_age_group/0/p25/Results.h5", 'r') + p25_bs = f_p25['0'] + + f_p75 = h5py.File( + path+"/infection_state_per_age_group/0/p75/Results.h5", 'r') + p75_bs = f_p75['0'] + + # do the same for 05 and 95 percentile + f_p05 = h5py.File( + path+"/infection_state_per_age_group/0/p05/Results.h5", 'r') + p05_bs = f_p05['0'] + + f_p95 = h5py.File( + path+"/infection_state_per_age_group/0/p95/Results.h5", 'r') + p95_bs = f_p95['0'] + + age_group_access = ['Group1', 'Group2', 'Group3', + 'Group4', 'Group5', 'Group6', 'Total'] + + # we need the real data json file cases_all_county_age + df_abb = pd.read_json( + path+"/../../../pydata/Germany/cases_all_county_age_ma1.json") + + # we just need the columns cases and date + # we need to offset the dates by 19 day + df_abb['Date'] = df_abb['Date'] + pd.DateOffset(days=18) + # we need just the dates bewteen 2021-03-01 and 2021-06-01 + df_abb = df_abb[(df_abb['Date'] >= '2021-03-01') & + (df_abb['Date'] <= '2021-06-01')] + # we just need the cases with id 3101 + df_abb = df_abb[df_abb['ID_County'] == 3101] + # df_abb['Deaths'] = np.round(df_abb[['Deaths']].to_numpy()) + + # we need the amount of dead persons for each age group: These are A00-A04, A05-A14, A15-A34, A35-A59, A60-A79, A80+ + age_groups = ['A00-A04', 'A05-A14', 'A15-A34', 'A35-A59', 'A60-A79', 'A80+'] + age_grous_string = ['Age 0-4', 'Age 5-14', 'Age 15-34', 'Age 35-59', 'Age 60-79', 'Age 80+'] + # we need to sum up the amount of dead persons for each age group + + # we want the deaths for the age groups + df_abb = df_abb[['Date', 'Deaths', 'Age_RKI']] + # we want a plot with 2 rows. Second row has a plot with each age group and the simulated and real dead persons + # First row has the cumulative dead persons + fig = plt.figure('Deaths') + fig.set_figwidth(20) + fig.set_figheight(9) + gs = fig.add_gridspec(2,6) + + # we need the cumulative dead persons + ax = fig.add_subplot(gs[0, :]) + df_total_dead = df_abb.groupby('Date').sum()[0:90] + y_real = df_total_dead['Deaths'].to_numpy() + # we need to substract the first value from the rest + y_real = y_real - y_real[0] + + y_sim = p50_bs['Total'][()][:, 7][::24][0:90] + y_sim = y_sim - y_sim[0] + + y_sim25 = p25_bs['Total'][()][:, 7][::24][0:90] + y_sim25 = y_sim25 - y_sim25[0] + + y_sim75 = p75_bs['Total'][()][:, 7][::24][0:90] + y_sim75 = y_sim75 - y_sim75[0] + + y_sim05 = p05_bs['Total'][()][:, 7][::24][0:90] + y_sim05 = y_sim05 - y_sim05[0] + + y_sim95 = p95_bs['Total'][()][:, 7][::24][0:90] + y_sim95 = y_sim95 - y_sim95[0] + + + + # we calculate the RMSE + rmse_dead = np.sqrt(((y_real- y_sim)**2).mean()) + # we need to plot the cumulative dead persons from the real world and from the simulation + + ax.plot(df_total_dead.index, y_sim, color='tab:blue',label='Simulated deaths') + ax.plot(df_total_dead.index, y_real, 'v',color='tab:red', linewidth=4, label='Extrapolated deaths from reported infection case data') + ax.fill_between(df_total_dead.index, y_sim75, y_sim25, + alpha=0.5, color='tab:blue', label='50% Confidence interval') + ax.fill_between(df_total_dead.index,y_sim95, y_sim05, + alpha=0.25, color='tab:blue', label='90% Confidence interval') + # ax.text(0.25, 0.8, 'RMSE: '+str(float("{:.2f}".format(rmse_dead))), horizontalalignment='center', + # verticalalignment='center', transform=plt.gca().transAxes, color='pink', fontsize=15) + ax.set_label('Number of individuals') + ax.set_title('Cumulative Deaths', fontsize=fontsize) + ax.set_ylabel('Number of individuals', fontsize=fontsize-8) + ax.legend(fontsize=fontsize-8) + + # now for each age group + for i, age_group in zip(range(6), age_group_access): + ax = fig.add_subplot(gs[1, i]) + # we need the amount of dead persons for each age group + df_abb_age_group = df_abb[df_abb['Age_RKI'] == age_groups[i]][0:90] + y_real = np.round(df_abb_age_group['Deaths'].to_numpy()) + # we need to plot the dead persons from the real world and from the simulation + ax.plot(df_abb_age_group['Date'], y_real-y_real[0], color='tab:red') + ax.plot(df_abb_age_group['Date'], p50_bs[age_group_access[i]][()][:, 7][::24][0:90]-p50_bs[age_group_access[i]][()][:, 7][::24][0], color='tab:blue') + ax.fill_between(df_abb_age_group['Date'], p75_bs[age_group_access[i]][()][:, 7][::24][0:90]-p75_bs[age_group_access[i]][()][:, 7][::24][0], p25_bs[age_group_access[i]][()][:, 7][::24][0:90]-p25_bs[age_group_access[i]][()][:, 7][::24][0], + alpha=0.5, color='tab:blue') + ax.set_title('Deaths, '+age_grous_string[i]) + ax.set_ybound(lower=0) + ax.set_xticks(df_abb_age_group['Date'][::50]) + ax.tick_params(axis='both', which='major', labelsize=fontsize-10) + ax.tick_params(axis='both', which='minor', labelsize=fontsize-10) + if i == 0: + ax.set_ylabel('Number of individuals',fontsize=fontsize-8) + ax.set_ybound(upper=1) + + plt.show() + +def plot_icu(path): + + df_abb = pd.read_json(path+"/../../../pydata/Germany/county_divi.json") + + perc_of_critical_in_icu_age = [0.55,0.55,0.55,0.56,0.54,0.46] + perc_of_critical_in_icu=0.55 + + age_group_access = ['Group1', 'Group2', 'Group3', + 'Group4', 'Group5', 'Group6', 'Total'] + + + # we just need the columns ICU_low and ICU_hig + df_abb = df_abb[['ID_County', 'ICU', 'Date']] + + df_abb = df_abb[df_abb['ID_County'] == 3101] + # we need just the dates bewteen 2021-03-01 and 2021-06-01 + df_abb = df_abb[(df_abb['Date'] >= '2021-03-01') & + (df_abb['Date'] <= '2021-06-01')] + + # we plot this against this the Amount of persons in the ICU from our model + f_p50 = h5py.File( + path+"/infection_state_per_age_group/0/p50/Results.h5", 'r') + total_50 = f_p50['0']['Total'][()][::24][0:90] + + total_50_age = f_p50['0'][age_group_access[0]][()] + for i in range(6): + total_50_age += f_p50['0'][age_group_access[i]][()]*perc_of_critical_in_icu_age[i] + total_50_age = total_50_age[::24][0:90] + + + # we plot this against this the Amount of persons in the ICU from our model + f_p75 = h5py.File( + path+"/infection_state_per_age_group/0/p75/Results.h5", 'r') + # total_75 = f_p75['0']['Total'][()][::24][0:90] + total_75_age = f_p75['0'][age_group_access[0]][()] + for i in range(6): + total_75_age += f_p75['0'][age_group_access[i]][()]*perc_of_critical_in_icu_age[i] + total_75_age = total_75_age[::24][0:90] + + # same with 25 percentile + f_p25 = h5py.File( + path+"/infection_state_per_age_group/0/p25/Results.h5", 'r') + # total_25 = f_p25['0']['Total'][()][::24][0:90] + total_25_age = f_p25['0'][age_group_access[0]][()] + for i in range(6): + total_25_age += f_p25['0'][age_group_access[i]][()]*perc_of_critical_in_icu_age[i] + total_25_age = total_25_age[::24][0:90] + + # same with 05 and 95 percentile + f_p05 = h5py.File( + path+"/infection_state_per_age_group/0/p05/Results.h5", 'r') + # total_05 = f_p05['0']['Total'][()][::24][0:90] + total_05_age = f_p05['0'][age_group_access[0]][()] + for i in range(6): + total_05_age += f_p05['0'][age_group_access[i]][()]*perc_of_critical_in_icu_age[i] + total_05_age = total_05_age[::24][0:90] + + f_p95 = h5py.File( + path+"/infection_state_per_age_group/0/p95/Results.h5", 'r') + # total_95 = f_p95['0']['Total'][()][::24][0:90] + total_95_age = f_p95['0'][age_group_access[0]][()] + for i in range(6): + total_95_age += f_p95['0'][age_group_access[i]][()]*perc_of_critical_in_icu_age[i] + total_95_age = total_95_age[::24][0:90] + + + ICU_Simulation_one_percentile = np.floor(total_50[:, 5]*perc_of_critical_in_icu) + ICU_Simulation = np.round(total_50_age[:, 5]) + ICU_Simulation75 = np.round(total_75_age[:, 5]) + ICU_Simulation25 = np.round(total_25_age[:, 5]) + ICU_Simulation05 = np.round(total_05_age[:, 5]) + ICU_Simulation95 = np.round(total_95_age[:, 5]) + ICU_Real = df_abb['ICU'][0:90] + + #smooth the data + # ICU_Real = gaussian_filter1d(ICU_Real, sigma=1, mode='nearest') + # ICU_Simulation = gaussian_filter1d(ICU_Simulation, sigma=1, mode='nearest') + + + + # we calculate the RMSE + rmse_ICU = np.sqrt(((ICU_Real - ICU_Simulation_one_percentile)**2).mean()) + + # plot the ICU beds and the ICU beds taken + fig, ax = plt.subplots(1, 1, constrained_layout=True) + fig.set_figwidth(12) + fig.set_figheight(9) + # we plot the ICU_low and the ICU_high + ax.plot(df_abb['Date'][0:90], ICU_Real,'x', color='tab:red', linewidth=10, label='Data') + ax.plot(df_abb['Date'][0:90], ICU_Simulation, color='tab:blue', label='Simulation') + # ax.plot(df_abb['Date'][0:90], ICU_Simulation_one_percentile, color='tab:green', label='Simulated ICU beds') + ax.fill_between(df_abb['Date'][0:90],ICU_Simulation75, ICU_Simulation25, + alpha=0.5, color='tab:blue', label='50% Confidence interval') + ax.fill_between(df_abb['Date'][0:90],ICU_Simulation05, ICU_Simulation95, + alpha=0.25, color='tab:blue', label='90% Confidence interval') + + + # we also write the rmse + # ax.text(0.25, 0.8, 'RMSE: '+str(float("{:.2f}".format(rmse_ICU))), horizontalalignment='center', + # verticalalignment='center', transform=plt.gca().transAxes, color='pink', fontsize=15) + ax.tick_params(axis='both', which='major', labelsize=fontsize-4) + ax.tick_params(axis='both', which='minor', labelsize=fontsize-4) + ax.set_ylabel('Occupied ICU beds', fontsize=fontsize) + ax.set_title('ICU beds', fontsize=fontsize+4) + ax.legend(fontsize=fontsize-4) + plt.show() + + + + +if __name__ == "__main__": + path = "" + + if (len(sys.argv) > 1): + n_runs = sys.argv[1] + else: + n_runs = len([entry for entry in os.listdir(path) + if os.path.isfile(os.path.join(path, entry))]) + + # plot_icu(path) + # plot_dead(path) diff --git a/pycode/memilio-plot/memilio/plot/plotAbmInfectionStates.py b/pycode/memilio-plot/memilio/plot/plotAbmInfectionStates.py new file mode 100644 index 0000000000..cf33fb8d3d --- /dev/null +++ b/pycode/memilio-plot/memilio/plot/plotAbmInfectionStates.py @@ -0,0 +1,258 @@ +# Python script to analyze bs runs +# input is a bs run folder with the following structure: +# bs_run_folder has a txt file for each bs run +# each txt file has a line for each time step +# each line has a column for each compartment as well as the timestep +# each column has the number of individuals in that compartment +# the first line of each txt file is the header + +import sys +import argparse +import os +import pandas as pd +import numpy as np +import matplotlib.pyplot as plt +import matplotlib +import matplotlib.colors as colors +import matplotlib.cm as cmx +import matplotlib.patches as mpatches +import matplotlib.lines as mlines +import h5py +from datetime import datetime +from matplotlib.dates import DateFormatter +from scipy.ndimage import gaussian_filter1d +from scipy.signal import savgol_filter + +fontsize = 20 + + +def plot_infections_loc_types_avarage(path): + # 50-percentile + f_p50 = h5py.File( + path+"/infection_per_location_type_per_age_group/0/p50/Results.h5", 'r') + p50_bs = f_p50['0'] + total_50 = p50_bs['Total'][()] + + # 25-percentile + f_p25 = h5py.File( + path+"/infection_per_location_type_per_age_group/0/p05/Results.h5", 'r') + p25_bs = f_p25['0'] + total_25 = p25_bs['Total'][()] + + # 75-percentile + f_p75 = h5py.File( + path + "/infection_per_location_type_per_age_group/0/p95/Results.h5", 'r') + p75_bs = f_p75['0'] + total_75 = p75_bs['Total'][()] + + time = p50_bs['Time'][()] + + plot_infection_per_location_type_mean( + time, total_50, total_25, total_75) + +def plot_infection_per_location_type_mean(x, y50, y25, y75): + + plt.figure('Infection_location_types') + plt.title('At which location type an infection happened, avaraged over all runs') + + color_plot = matplotlib.colormaps.get_cmap('Set1').colors + + states_plot = [0, 1, 2, 3, 4, 10] + legend_plot = ['Home', 'School', 'Work', + 'SocialEvent', 'BasicsShop','Event'] + + for i in states_plot: + # rolling average# + ac_color = color_plot[i%len(color_plot)] + if(i > len(color_plot)): + ac_color = "black" + + # we need to sum up every 24 hours + indexer = pd.api.indexers.FixedForwardWindowIndexer(window_size=24) + np_y50 = pd.DataFrame(y50[:, i]).rolling(window=indexer, min_periods=1).sum().to_numpy() + np_y50=np_y50[0::24].flatten() + # now smoothen this with a gaussian filter + np_y50 = gaussian_filter1d(np_y50, sigma=1, mode='nearest') + + plt.plot(x[0::24], np_y50, color=ac_color) + + plt.legend(legend_plot) + + # currently the x axis has the values of the time steps, we need to convert them to dates and set the x axis to dates + start_date = datetime.strptime('2021-03-01', '%Y-%m-%d') + xx = [start_date + pd.Timedelta(days=int(i)) for i in x] + xx = [xx[i].strftime('%Y-%m-%d') for i in range(len(xx))] + # but just take every 10th date to make it more readable + plt.gca().set_xticks(x[::150]) + plt.gca().set_xticklabels(xx[::150]) + plt.gcf().autofmt_xdate() + + plt.xlabel('Date') + plt.ylabel('Number of individuals') + plt.show() + +def plot_infection_states_results(path): + # 50-percentile + f_p50 = h5py.File( + path+"/infection_state_per_age_group/0/p50/Results.h5", 'r') + p50_bs = f_p50['0'] + total_50 = p50_bs['Total'][()] + # 25-percentile + f_p25 = h5py.File( + path+"/infection_state_per_age_group/0/p05/Results.h5", 'r') + p25_bs = f_p25['0'] + total_25 = p25_bs['Total'][()] + # 75-percentile + f_p75 = h5py.File( + path + "/infection_state_per_age_group/0/p95/Results.h5", 'r') + p75_bs = f_p75['0'] + total_75 = p75_bs['Total'][()] + + time = p50_bs['Time'][()] + + plot_infection_states_individual( + time, p50_bs, p25_bs, p75_bs) + plot_infection_states(time, total_50, total_25, total_75) + +def plot_infection_states(x, y50, y25, y75, y_real=None): + plt.figure('Infection_states') + plt.title('Infection states') + + color_plot = matplotlib.colormaps.get_cmap('Set1').colors + + states_plot = [1, 2, 3, 4, 5, 7] + legend_plot = ['E', 'I_NSymp', 'I_Symp', + 'I_Sev', 'I_Crit', 'Dead', 'Sm. re. pos.'] + + for i in states_plot: + plt.plot(x, y50[:, i], color=color_plot[i]) + + + plt.legend(legend_plot) + for i in states_plot: + plt.fill_between(x, y50[:, i], y25[:, i], + alpha=0.5, color=color_plot[i]) + plt.fill_between(x, y50[:, i], y75[:, i], + alpha=0.5, color=color_plot[i]) + + # currently the x axis has the values of the time steps, we need to convert them to dates and set the x axis to dates + start_date = datetime.strptime('2021-03-01', '%Y-%m-%d') + xx = [start_date + pd.Timedelta(days=int(i)) for i in x] + xx = [xx[i].strftime('%Y-%m-%d') for i in range(len(xx))] + # but just take every 10th date to make it more readable + plt.gca().set_xticks(x[::150]) + plt.gca().set_xticklabels(xx[::150]) + plt.gcf().autofmt_xdate() + + plt.xlabel('Time') + plt.ylabel('Number of individuals') + plt.show() + +def plot_infection_states_individual(x, p50_bs, p25_bs, p75_bs): + + + age_group_access = ['Group1', 'Group2', 'Group3', + 'Group4', 'Group5', 'Group6', 'Total'] + + color_plot = matplotlib.colormaps.get_cmap('Set1').colors + + fig, ax = plt.subplots(6, len(age_group_access), constrained_layout=True) + fig.set_figwidth(20) + fig.set_figheight(9) + for j, count in zip(age_group_access, range(len(age_group_access))): + y50 = p50_bs[j][()] + y25 = p25_bs[j][()] + y75 = p75_bs[j][()] + + + # infeced no symptoms + ax_infected_no_symptoms = ax[0, count] + ax_infected_no_symptoms.set_xlabel('time (days)') + ax_infected_no_symptoms.plot( + x, y50[:, 1], color=color_plot[count], label='y50') + ax_infected_no_symptoms.fill_between( + x, y50[:, 1], y25[:, 1], alpha=0.5, color=color_plot[count]) + ax_infected_no_symptoms.fill_between( + x, y50[:, 1], y75[:, 1], alpha=0.5, color=color_plot[count]) + ax_infected_no_symptoms.tick_params(axis='y') + ax_infected_no_symptoms.title.set_text( + '#Infected_no_symptoms, Age{}'.format(j)) + ax_infected_no_symptoms.legend(['Simulation']) + + # Infected_symptoms + ax_infected_symptoms = ax[1, count] + ax_infected_symptoms.set_xlabel('time (days)') + ax_infected_symptoms.plot( + x, y50[:, 2], color=color_plot[count], label='y50') + ax_infected_symptoms.fill_between( + x, y50[:, 2], y25[:, 2], alpha=0.5, color=color_plot[count]) + ax_infected_symptoms.fill_between( + x, y50[:, 2], y75[:, 2], alpha=0.5, color=color_plot[count]) + ax_infected_symptoms.tick_params(axis='y') + ax_infected_symptoms.title.set_text( + '#Infected_symptoms, Age{}'.format(j)) + ax_infected_symptoms.legend(['Simulation']) + + # Severe + ax_severe = ax[2, count] + ax_severe.set_xlabel('time (days)') + ax_severe.plot(x, y50[:, 4], color=color_plot[count], label='y50') + ax_severe.fill_between( + x, y50[:, 4], y25[:, 4], alpha=0.5, color=color_plot[count]) + ax_severe.fill_between( + x, y50[:, 4], y75[:, 4], alpha=0.5, color=color_plot[count]) + ax_severe.tick_params(axis='y') + ax_severe.title.set_text('#Severe, Age{}'.format(j)) + ax_severe.legend(['Simulation']) + + # Critical + ax_critical = ax[3, count] + ax_critical.set_xlabel('time (days)') + ax_critical.plot(x, y50[:, [5]], color=color_plot[count], label='y50') + ax_critical.fill_between( + x, y50[:, 5], y25[:, 5], alpha=0.5, color=color_plot[count]) + ax_critical.fill_between( + x, y50[:, 5], y75[:, 5], alpha=0.5, color=color_plot[count]) + ax_critical.tick_params(axis='y') + ax_critical.title.set_text('#Critical, Age{}'.format(j)) + ax_critical.legend(['Simulation']) + + # Dead + ax_dead = ax[4, count] + ax_dead.set_xlabel('time (days)') + ax_dead.plot(x, y50[:, [7]], color=color_plot[count], label='y50') + ax_dead.fill_between(x, y50[:, 7], y25[:, 7], + alpha=0.5, color=color_plot[count]) + ax_dead.fill_between(x, y50[:, 7], y75[:, 7], + alpha=0.5, color=color_plot[count]) + ax_dead.tick_params(axis='y') + ax_dead.title.set_text('#Dead, Age{}'.format(j)) + ax_dead.legend(['Simulation']) + + # Recovered + ax_dead = ax[5, count] + ax_dead.set_xlabel('time (days)') + ax_dead.plot(x, y50[:, [6]], color=color_plot[count], label='y50') + ax_dead.fill_between(x, y50[:, 6], y25[:, 6], + alpha=0.5, color=color_plot[count]) + ax_dead.fill_between(x, y50[:, 6], y75[:, 6], + alpha=0.5, color=color_plot[count]) + ax_dead.tick_params(axis='y') + ax_dead.title.set_text('#Recovered, Age{}'.format(j)) + ax_dead.legend(['Simulation']) + + # fig.tight_layout() # otherwise the right y-label is slightly clipped + plt.show() + + +if __name__ == "__main__": + path = "" + + if (len(sys.argv) > 1): + n_runs = sys.argv[1] + else: + n_runs = len([entry for entry in os.listdir(path) + if os.path.isfile(os.path.join(path, entry))]) + + # plot_infection_states_results(path) + # plot_infections_loc_types_avarage(path) \ No newline at end of file From c7ffab3663b49f2d14c8f44773aeba2df7769402 Mon Sep 17 00:00:00 2001 From: Sascha <51127093+xsaschako@users.noreply.github.com> Date: Tue, 27 May 2025 22:32:04 +0200 Subject: [PATCH 02/11] Remove unused imports and enable plotting functions in infection states script --- pycode/memilio-plot/memilio/plot/plotAbmICUAndDeadComp.py | 2 -- pycode/memilio-plot/memilio/plot/plotAbmInfectionStates.py | 6 ++---- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/pycode/memilio-plot/memilio/plot/plotAbmICUAndDeadComp.py b/pycode/memilio-plot/memilio/plot/plotAbmICUAndDeadComp.py index 6a1af97ed5..510a88104e 100644 --- a/pycode/memilio-plot/memilio/plot/plotAbmICUAndDeadComp.py +++ b/pycode/memilio-plot/memilio/plot/plotAbmICUAndDeadComp.py @@ -20,8 +20,6 @@ import h5py from datetime import datetime from matplotlib.dates import DateFormatter -from scipy.ndimage import gaussian_filter1d -from scipy.signal import savgol_filter fontsize = 20 diff --git a/pycode/memilio-plot/memilio/plot/plotAbmInfectionStates.py b/pycode/memilio-plot/memilio/plot/plotAbmInfectionStates.py index cf33fb8d3d..3f92bf4c0b 100644 --- a/pycode/memilio-plot/memilio/plot/plotAbmInfectionStates.py +++ b/pycode/memilio-plot/memilio/plot/plotAbmInfectionStates.py @@ -20,8 +20,6 @@ import h5py from datetime import datetime from matplotlib.dates import DateFormatter -from scipy.ndimage import gaussian_filter1d -from scipy.signal import savgol_filter fontsize = 20 @@ -254,5 +252,5 @@ def plot_infection_states_individual(x, p50_bs, p25_bs, p75_bs): n_runs = len([entry for entry in os.listdir(path) if os.path.isfile(os.path.join(path, entry))]) - # plot_infection_states_results(path) - # plot_infections_loc_types_avarage(path) \ No newline at end of file + plot_infection_states_results(path) + plot_infections_loc_types_avarage(path) \ No newline at end of file From e986d96c25db8cbe0bc8d63f1d9a9070d94f7800 Mon Sep 17 00:00:00 2001 From: Sascha Korf <51127093+xsaschako@users.noreply.github.com> Date: Tue, 27 May 2025 23:49:21 +0200 Subject: [PATCH 03/11] Refactor infection states plotting script: enhance documentation, improve function structure, and add new plotting capabilities for location types. --- .../memilio/plot/plotAbmInfectionStates.py | 472 ++++++++++-------- 1 file changed, 253 insertions(+), 219 deletions(-) diff --git a/pycode/memilio-plot/memilio/plot/plotAbmInfectionStates.py b/pycode/memilio-plot/memilio/plot/plotAbmInfectionStates.py index 3f92bf4c0b..352c5f5998 100644 --- a/pycode/memilio-plot/memilio/plot/plotAbmInfectionStates.py +++ b/pycode/memilio-plot/memilio/plot/plotAbmInfectionStates.py @@ -1,10 +1,22 @@ -# Python script to analyze bs runs -# input is a bs run folder with the following structure: -# bs_run_folder has a txt file for each bs run -# each txt file has a line for each time step -# each line has a column for each compartment as well as the timestep -# each column has the number of individuals in that compartment -# the first line of each txt file is the header +############################################################################# +# Copyright (C) 2020-2024 MEmilio +# +# Authors: Daniel Abele, Martin J. Kuehn +# +# Contact: Martin J. Kuehn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################# import sys import argparse @@ -13,244 +25,266 @@ import numpy as np import matplotlib.pyplot as plt import matplotlib -import matplotlib.colors as colors -import matplotlib.cm as cmx -import matplotlib.patches as mpatches -import matplotlib.lines as mlines import h5py from datetime import datetime -from matplotlib.dates import DateFormatter - -fontsize = 20 - - -def plot_infections_loc_types_avarage(path): - # 50-percentile - f_p50 = h5py.File( - path+"/infection_per_location_type_per_age_group/0/p50/Results.h5", 'r') - p50_bs = f_p50['0'] - total_50 = p50_bs['Total'][()] - - # 25-percentile - f_p25 = h5py.File( - path+"/infection_per_location_type_per_age_group/0/p05/Results.h5", 'r') - p25_bs = f_p25['0'] - total_25 = p25_bs['Total'][()] - - # 75-percentile - f_p75 = h5py.File( - path + "/infection_per_location_type_per_age_group/0/p95/Results.h5", 'r') - p75_bs = f_p75['0'] - total_75 = p75_bs['Total'][()] - - time = p50_bs['Time'][()] - - plot_infection_per_location_type_mean( - time, total_50, total_25, total_75) - -def plot_infection_per_location_type_mean(x, y50, y25, y75): +from scipy.ndimage import gaussian_filter1d + + +""" Module for plotting infection states and location types from ABM results. +This module provides functions to load and visualize infection states and +location types from simulation results stored in HDF5 format and are output +by the MEmilio agent-based model (ABM). +The used Loggers are: +struct LogInfectionStatePerAgeGroup : mio::LogAlways { + using Type = std::pair; + /** + * @brief Log the TimeSeries of the number of Person%s in an #InfectionState. + * @param[in] sim The simulation of the abm. + * @return A pair of the TimePoint and the TimeSeries of the number of Person%s in an #InfectionState. + */ + static Type log(const mio::abm::Simulation& sim) + { + + Eigen::VectorXd sum = Eigen::VectorXd::Zero( + Eigen::Index((size_t)mio::abm::InfectionState::Count * sim.get_world().parameters.get_num_groups())); + const auto curr_time = sim.get_time(); + const auto persons = sim.get_world().get_persons(); + + // PRAGMA_OMP(parallel for) + for (auto i = size_t(0); i < persons.size(); ++i) { + auto& p = persons[i]; + if (p.get_should_be_logged()) { + auto index = (((size_t)(mio::abm::InfectionState::Count)) * ((uint32_t)p.get_age().get())) + + ((uint32_t)p.get_infection_state(curr_time)); + // PRAGMA_OMP(atomic) + sum[index] += 1; + } + } + return std::make_pair(curr_time, sum); + } +}; + +struct LogInfectionPerLocationTypePerAgeGroup : mio::LogAlways { + using Type = std::pair; + /** + * @brief Log the TimeSeries of the number of Person%s in an #InfectionState. + * @param[in] sim The simulation of the abm. + * @return A pair of the TimePoint and the TimeSeries of the number of Person%s in an #InfectionState. + */ + static Type log(const mio::abm::Simulation& sim) + { + + Eigen::VectorXd sum = Eigen::VectorXd::Zero( + Eigen::Index((size_t)mio::abm::LocationType::Count * sim.get_world().parameters.get_num_groups())); + auto curr_time = sim.get_time(); + auto prev_time = sim.get_prev_time(); + const auto persons = sim.get_world().get_persons(); + + // PRAGMA_OMP(parallel for) + for (auto i = size_t(0); i < persons.size(); ++i) { + auto& p = persons[i]; + if (p.get_should_be_logged()) { + // PRAGMA_OMP(atomic) + if ((p.get_infection_state(prev_time) != mio::abm::InfectionState::Exposed) && + (p.get_infection_state(curr_time) == mio::abm::InfectionState::Exposed)) { + auto index = (((size_t)(mio::abm::LocationType::Count)) * ((uint32_t)p.get_age().get())) + + ((uint32_t)p.get_location().get_type()); + sum[index] += 1; + } + } + } + return std::make_pair(curr_time, sum); + } +}; + +The output of several Loggers is stored in HDF5 files, with the memilio funciton mio::save_results in mio/io/result_io.h. +""" + + + + +def load_h5_results(base_path, percentile): + """ Reads HDF5 results for a given group and percentile. + + @param[in] base_path Path to results directory. + @param[in] percentile Subdirectory for percentile (e.g. 'p50'). + @return Dictionary with data arrays. + """ + file_path = os.path.join(base_path, percentile, "Results.h5") + with h5py.File(file_path, 'r') as f: + data = {k: v[()] for k, v in f['0'].items()} + return data + +def plot_infections_loc_types_average( + path_to_loc_types, + start_date='2021-03-01', + colormap='Set1', + smooth_sigma=1, + rolling_window=24, + xtick_step=150): + """ Plots average infections per location type. + + @param[in] base_path Path to results directory. + @param[in] start_date Start date as string. + @param[in] colormap Matplotlib colormap. + @param[in] smooth_sigma Sigma for Gaussian smoothing. + @param[in] rolling_window Window size for rolling sum. + @param[in] xtick_step Step size for x-axis ticks. + """ + # Load data + p50 = load_h5_results(path_to_loc_types, "p50") + p25 = load_h5_results(path_to_loc_types, "p05") + p75 = load_h5_results(path_to_loc_types, "p95") + time = p50['Time'] + total_50 = p50['Total'] plt.figure('Infection_location_types') - plt.title('At which location type an infection happened, avaraged over all runs') - - color_plot = matplotlib.colormaps.get_cmap('Set1').colors - - states_plot = [0, 1, 2, 3, 4, 10] - legend_plot = ['Home', 'School', 'Work', - 'SocialEvent', 'BasicsShop','Event'] - - for i in states_plot: - # rolling average# - ac_color = color_plot[i%len(color_plot)] - if(i > len(color_plot)): - ac_color = "black" - - # we need to sum up every 24 hours - indexer = pd.api.indexers.FixedForwardWindowIndexer(window_size=24) - np_y50 = pd.DataFrame(y50[:, i]).rolling(window=indexer, min_periods=1).sum().to_numpy() - np_y50=np_y50[0::24].flatten() - # now smoothen this with a gaussian filter - np_y50 = gaussian_filter1d(np_y50, sigma=1, mode='nearest') - - plt.plot(x[0::24], np_y50, color=ac_color) + plt.title('At which location type an infection happened, averaged over all runs') + color_plot = matplotlib.colormaps.get_cmap(colormap).colors + # If you define further location types, you need to adjust this list + states_plot = [0, 1, 2, 3, 4, 5 , 6] + legend_plot = ['Home', 'School', 'Work', 'SocialEvent', 'BasicsShop', 'Hospital', 'ICU'] + + for idx, i in enumerate(states_plot): + color = color_plot[i % len(color_plot)] if i < len(color_plot) else "black" + # Sum up every 24 hours, then smooth + indexer = pd.api.indexers.FixedForwardWindowIndexer(window_size=rolling_window) + y = pd.DataFrame(total_50[:, i]).rolling(window=indexer, min_periods=1).sum().to_numpy() + y = y[0::rolling_window].flatten() + y = gaussian_filter1d(y, sigma=smooth_sigma, mode='nearest') + plt.plot(time[0::rolling_window], y, color=color) plt.legend(legend_plot) - - # currently the x axis has the values of the time steps, we need to convert them to dates and set the x axis to dates - start_date = datetime.strptime('2021-03-01', '%Y-%m-%d') - xx = [start_date + pd.Timedelta(days=int(i)) for i in x] - xx = [xx[i].strftime('%Y-%m-%d') for i in range(len(xx))] - # but just take every 10th date to make it more readable - plt.gca().set_xticks(x[::150]) - plt.gca().set_xticklabels(xx[::150]) - plt.gcf().autofmt_xdate() - + _format_x_axis(time, start_date, xtick_step) plt.xlabel('Date') plt.ylabel('Number of individuals') plt.show() -def plot_infection_states_results(path): - # 50-percentile - f_p50 = h5py.File( - path+"/infection_state_per_age_group/0/p50/Results.h5", 'r') - p50_bs = f_p50['0'] - total_50 = p50_bs['Total'][()] - # 25-percentile - f_p25 = h5py.File( - path+"/infection_state_per_age_group/0/p05/Results.h5", 'r') - p25_bs = f_p25['0'] - total_25 = p25_bs['Total'][()] - # 75-percentile - f_p75 = h5py.File( - path + "/infection_state_per_age_group/0/p95/Results.h5", 'r') - p75_bs = f_p75['0'] - total_75 = p75_bs['Total'][()] - - time = p50_bs['Time'][()] - - plot_infection_states_individual( - time, p50_bs, p25_bs, p75_bs) - plot_infection_states(time, total_50, total_25, total_75) - -def plot_infection_states(x, y50, y25, y75, y_real=None): - plt.figure('Infection_states') +def plot_infection_states_results( + path_to_infection_states, + start_date='2021-03-01', + colormap='Set1', + xtick_step=150): + """ Loads and plots infection state results. + + @param[in] base_path Path to results directory. + @param[in] start_date Start date as string. + @param[in] colormap Matplotlib colormap. + @param[in] xtick_step Step size for x-axis ticks. + """ + # Load data + p50 = load_h5_results(path_to_infection_states, "p50") + p25 = load_h5_results(path_to_infection_states, "p25") + p75 = load_h5_results(path_to_infection_states, "p75") + time = p50['Time'] + total_50 = p50['Total'] + total_25 = p25['Total'] + total_75 = p75['Total'] + + plot_infection_states_individual(time, p50, p25, p75, colormap) + plot_infection_states(time, total_50, total_25, total_75, start_date, colormap, xtick_step) + +def plot_infection_states( + x, y50, y25, y75, + start_date='2021-03-01', + colormap='Set1', + xtick_step=150): + """ Plots infection states with percentiles. + + @param[in] x Time array. + @param[in] y50 Median values. + @param[in] y25 25th percentile values. + @param[in] y75 75th percentile values. + @param[in] start_date Start date as string. + @param[in] colormap Matplotlib colormap. + @param[in] xtick_step Step size for x-axis ticks. + """ + plt.figure('Infection_states with 50% percentile') plt.title('Infection states') - - color_plot = matplotlib.colormaps.get_cmap('Set1').colors - + color_plot = matplotlib.colormaps.get_cmap(colormap).colors states_plot = [1, 2, 3, 4, 5, 7] - legend_plot = ['E', 'I_NSymp', 'I_Symp', - 'I_Sev', 'I_Crit', 'Dead', 'Sm. re. pos.'] + legend_plot = ['E', 'I_NSymp', 'I_Symp', 'I_Sev', 'I_Crit', 'Dead'] for i in states_plot: plt.plot(x, y50[:, i], color=color_plot[i]) - + plt.fill_between(x, y50[:, i], y25[:, i], alpha=0.5, color=color_plot[i]) + plt.fill_between(x, y50[:, i], y75[:, i], alpha=0.5, color=color_plot[i]) plt.legend(legend_plot) - for i in states_plot: - plt.fill_between(x, y50[:, i], y25[:, i], - alpha=0.5, color=color_plot[i]) - plt.fill_between(x, y50[:, i], y75[:, i], - alpha=0.5, color=color_plot[i]) - - # currently the x axis has the values of the time steps, we need to convert them to dates and set the x axis to dates - start_date = datetime.strptime('2021-03-01', '%Y-%m-%d') - xx = [start_date + pd.Timedelta(days=int(i)) for i in x] - xx = [xx[i].strftime('%Y-%m-%d') for i in range(len(xx))] - # but just take every 10th date to make it more readable - plt.gca().set_xticks(x[::150]) - plt.gca().set_xticklabels(xx[::150]) - plt.gcf().autofmt_xdate() - + _format_x_axis(x, start_date, xtick_step) plt.xlabel('Time') plt.ylabel('Number of individuals') plt.show() -def plot_infection_states_individual(x, p50_bs, p25_bs, p75_bs): - - - age_group_access = ['Group1', 'Group2', 'Group3', - 'Group4', 'Group5', 'Group6', 'Total'] - - color_plot = matplotlib.colormaps.get_cmap('Set1').colors - - fig, ax = plt.subplots(6, len(age_group_access), constrained_layout=True) - fig.set_figwidth(20) - fig.set_figheight(9) - for j, count in zip(age_group_access, range(len(age_group_access))): - y50 = p50_bs[j][()] - y25 = p25_bs[j][()] - y75 = p75_bs[j][()] - - - # infeced no symptoms - ax_infected_no_symptoms = ax[0, count] - ax_infected_no_symptoms.set_xlabel('time (days)') - ax_infected_no_symptoms.plot( - x, y50[:, 1], color=color_plot[count], label='y50') - ax_infected_no_symptoms.fill_between( - x, y50[:, 1], y25[:, 1], alpha=0.5, color=color_plot[count]) - ax_infected_no_symptoms.fill_between( - x, y50[:, 1], y75[:, 1], alpha=0.5, color=color_plot[count]) - ax_infected_no_symptoms.tick_params(axis='y') - ax_infected_no_symptoms.title.set_text( - '#Infected_no_symptoms, Age{}'.format(j)) - ax_infected_no_symptoms.legend(['Simulation']) - - # Infected_symptoms - ax_infected_symptoms = ax[1, count] - ax_infected_symptoms.set_xlabel('time (days)') - ax_infected_symptoms.plot( - x, y50[:, 2], color=color_plot[count], label='y50') - ax_infected_symptoms.fill_between( - x, y50[:, 2], y25[:, 2], alpha=0.5, color=color_plot[count]) - ax_infected_symptoms.fill_between( - x, y50[:, 2], y75[:, 2], alpha=0.5, color=color_plot[count]) - ax_infected_symptoms.tick_params(axis='y') - ax_infected_symptoms.title.set_text( - '#Infected_symptoms, Age{}'.format(j)) - ax_infected_symptoms.legend(['Simulation']) - +def plot_infection_states_individual(x, p50_bs, p25_bs, p75_bs, colormap='Set1'): + """ Plots infection states for each age group. + + @param[in] x Time array. + @param[in] p50_bs Median values by group. + @param[in] p25_bs 25th percentile values by group. + @param[in] p75_bs 75th percentile values by group. + @param[in] colormap Matplotlib colormap. + """ + age_groups = ['Group1', 'Group2', 'Group3', 'Group4', 'Group5', 'Group6', 'Total'] # Adjust as needed + color_plot = matplotlib.colormaps.get_cmap(colormap).colors + fig, ax = plt.subplots(6, len(age_groups), constrained_layout=True, figsize=(20, 9)) + + for col_idx, group in enumerate(age_groups): + y50 = p50_bs[group] + y25 = p25_bs[group] + y75 = p75_bs[group] + # Infected no symptoms + _plot_state(ax[0, col_idx], x, y50[:, 1], y25[:, 1], y75[:, 1], color_plot[col_idx], '#Infected_no_symptoms, Age' + str(group)) + # Infected symptoms + _plot_state(ax[1, col_idx], x, y50[:, 2], y25[:, 2], y75[:, 2], color_plot[col_idx], '#Infected_symptoms, Age' + str(group)) # Severe - ax_severe = ax[2, count] - ax_severe.set_xlabel('time (days)') - ax_severe.plot(x, y50[:, 4], color=color_plot[count], label='y50') - ax_severe.fill_between( - x, y50[:, 4], y25[:, 4], alpha=0.5, color=color_plot[count]) - ax_severe.fill_between( - x, y50[:, 4], y75[:, 4], alpha=0.5, color=color_plot[count]) - ax_severe.tick_params(axis='y') - ax_severe.title.set_text('#Severe, Age{}'.format(j)) - ax_severe.legend(['Simulation']) - + _plot_state(ax[2, col_idx], x, y50[:, 4], y25[:, 4], y75[:, 4], color_plot[col_idx], '#Severe, Age' + str(group)) # Critical - ax_critical = ax[3, count] - ax_critical.set_xlabel('time (days)') - ax_critical.plot(x, y50[:, [5]], color=color_plot[count], label='y50') - ax_critical.fill_between( - x, y50[:, 5], y25[:, 5], alpha=0.5, color=color_plot[count]) - ax_critical.fill_between( - x, y50[:, 5], y75[:, 5], alpha=0.5, color=color_plot[count]) - ax_critical.tick_params(axis='y') - ax_critical.title.set_text('#Critical, Age{}'.format(j)) - ax_critical.legend(['Simulation']) - + _plot_state(ax[3, col_idx], x, y50[:, 5], y25[:, 5], y75[:, 5], color_plot[col_idx], '#Critical, Age' + str(group)) # Dead - ax_dead = ax[4, count] - ax_dead.set_xlabel('time (days)') - ax_dead.plot(x, y50[:, [7]], color=color_plot[count], label='y50') - ax_dead.fill_between(x, y50[:, 7], y25[:, 7], - alpha=0.5, color=color_plot[count]) - ax_dead.fill_between(x, y50[:, 7], y75[:, 7], - alpha=0.5, color=color_plot[count]) - ax_dead.tick_params(axis='y') - ax_dead.title.set_text('#Dead, Age{}'.format(j)) - ax_dead.legend(['Simulation']) - + _plot_state(ax[4, col_idx], x, y50[:, 7], y25[:, 7], y75[:, 7], color_plot[col_idx], '#Dead, Age' + str(group)) # Recovered - ax_dead = ax[5, count] - ax_dead.set_xlabel('time (days)') - ax_dead.plot(x, y50[:, [6]], color=color_plot[count], label='y50') - ax_dead.fill_between(x, y50[:, 6], y25[:, 6], - alpha=0.5, color=color_plot[count]) - ax_dead.fill_between(x, y50[:, 6], y75[:, 6], - alpha=0.5, color=color_plot[count]) - ax_dead.tick_params(axis='y') - ax_dead.title.set_text('#Recovered, Age{}'.format(j)) - ax_dead.legend(['Simulation']) - - # fig.tight_layout() # otherwise the right y-label is slightly clipped + _plot_state(ax[5, col_idx], x, y50[:, 6], y25[:, 6], y75[:, 6], color_plot[col_idx], '#Recovered, Age' + str(group)) + + fig.suptitle('Infection states per age group with 50% percentile', fontsize=16) + + # We hide the Legend for the individual plots as it is too cluttered + for ax_row in ax: + for ax_col in ax_row: + ax_col.legend().set_visible(False) + plt.show() +def _plot_state(ax, x, y50, y25, y75, color, title): + """ Helper to plot a single state with fill_between. """ + ax.set_xlabel('time (days)') + ax.plot(x, y50, color=color, label='Median') + ax.fill_between(x, y50, y25, alpha=0.5, color=color) + ax.fill_between(x, y50, y75, alpha=0.5, color=color) + ax.tick_params(axis='y') + ax.set_title(title) + ax.legend(['Simulation']) + +def _format_x_axis(x, start_date, xtick_step): + """ Helper to format x-axis as dates. """ + start = datetime.strptime(start_date, '%Y-%m-%d') + xx = [start + pd.Timedelta(days=int(i)) for i in x] + xx_str = [dt.strftime('%Y-%m-%d') for dt in xx] + plt.gca().set_xticks(x[::xtick_step]) + plt.gca().set_xticklabels(xx_str[::xtick_step]) + plt.gcf().autofmt_xdate() + +def main(): + """ Main function for CLI usage. """ + parser = argparse.ArgumentParser(description="Plot infection state and location type results.") + # parser.add_argument("path", type=str, help="Path to the results folder") + parser.add_argument("--start-date", type=str, default='2021-03-01', help="Simulation start date (YYYY-MM-DD)") + parser.add_argument("--colormap", type=str, default='Set1', help="Matplotlib colormap") + parser.add_argument("--xtick-step", type=int, default=150, help="Step for x-axis ticks") + args = parser.parse_args() -if __name__ == "__main__": - path = "" + plot_infection_states_results("/Users/saschakorf/Nosynch/Arbeit/memilio/memilio/data/cluster_results/final_results/results_2024-09-20192904_best/infection_state_per_age_group/0", args.start_date, args.colormap, args.xtick_step) + # plot_infections_loc_types_average("/Users/saschakorf/Nosynch/Arbeit/memilio/memilio/data/cluster_results/final_results/results_2024-09-20192904_best/infection_per_location_type_per_age_group/0", args.start_date, args.colormap, xtick_step=args.xtick_step) - if (len(sys.argv) > 1): - n_runs = sys.argv[1] - else: - n_runs = len([entry for entry in os.listdir(path) - if os.path.isfile(os.path.join(path, entry))]) - - plot_infection_states_results(path) - plot_infections_loc_types_avarage(path) \ No newline at end of file +if __name__ == "__main__": + main() From 6e4c6fe7c02a54b3471ae9c77b1480b84661512f Mon Sep 17 00:00:00 2001 From: Sascha Korf <51127093+xsaschako@users.noreply.github.com> Date: Wed, 28 May 2025 00:23:11 +0200 Subject: [PATCH 04/11] Refactor plotAbmInfectionStates.py: update authorship, enhance module documentation, and improve function argument handling for better usability --- .../memilio/plot/plotAbmInfectionStates.py | 175 ++++++++++-------- 1 file changed, 95 insertions(+), 80 deletions(-) diff --git a/pycode/memilio-plot/memilio/plot/plotAbmInfectionStates.py b/pycode/memilio-plot/memilio/plot/plotAbmInfectionStates.py index 352c5f5998..91a6926c5e 100644 --- a/pycode/memilio-plot/memilio/plot/plotAbmInfectionStates.py +++ b/pycode/memilio-plot/memilio/plot/plotAbmInfectionStates.py @@ -1,7 +1,7 @@ ############################################################################# # Copyright (C) 2020-2024 MEmilio # -# Authors: Daniel Abele, Martin J. Kuehn +# Authors: Sascha Korf # # Contact: Martin J. Kuehn # @@ -30,77 +30,74 @@ from scipy.ndimage import gaussian_filter1d -""" Module for plotting infection states and location types from ABM results. -This module provides functions to load and visualize infection states and -location types from simulation results stored in HDF5 format and are output -by the MEmilio agent-based model (ABM). -The used Loggers are: -struct LogInfectionStatePerAgeGroup : mio::LogAlways { - using Type = std::pair; - /** - * @brief Log the TimeSeries of the number of Person%s in an #InfectionState. - * @param[in] sim The simulation of the abm. - * @return A pair of the TimePoint and the TimeSeries of the number of Person%s in an #InfectionState. - */ - static Type log(const mio::abm::Simulation& sim) - { - - Eigen::VectorXd sum = Eigen::VectorXd::Zero( - Eigen::Index((size_t)mio::abm::InfectionState::Count * sim.get_world().parameters.get_num_groups())); - const auto curr_time = sim.get_time(); - const auto persons = sim.get_world().get_persons(); - - // PRAGMA_OMP(parallel for) - for (auto i = size_t(0); i < persons.size(); ++i) { - auto& p = persons[i]; - if (p.get_should_be_logged()) { - auto index = (((size_t)(mio::abm::InfectionState::Count)) * ((uint32_t)p.get_age().get())) + - ((uint32_t)p.get_infection_state(curr_time)); - // PRAGMA_OMP(atomic) - sum[index] += 1; - } - } - return std::make_pair(curr_time, sum); - } -}; - -struct LogInfectionPerLocationTypePerAgeGroup : mio::LogAlways { - using Type = std::pair; - /** - * @brief Log the TimeSeries of the number of Person%s in an #InfectionState. - * @param[in] sim The simulation of the abm. - * @return A pair of the TimePoint and the TimeSeries of the number of Person%s in an #InfectionState. - */ - static Type log(const mio::abm::Simulation& sim) - { - - Eigen::VectorXd sum = Eigen::VectorXd::Zero( - Eigen::Index((size_t)mio::abm::LocationType::Count * sim.get_world().parameters.get_num_groups())); - auto curr_time = sim.get_time(); - auto prev_time = sim.get_prev_time(); - const auto persons = sim.get_world().get_persons(); - - // PRAGMA_OMP(parallel for) - for (auto i = size_t(0); i < persons.size(); ++i) { - auto& p = persons[i]; - if (p.get_should_be_logged()) { - // PRAGMA_OMP(atomic) - if ((p.get_infection_state(prev_time) != mio::abm::InfectionState::Exposed) && - (p.get_infection_state(curr_time) == mio::abm::InfectionState::Exposed)) { - auto index = (((size_t)(mio::abm::LocationType::Count)) * ((uint32_t)p.get_age().get())) + - ((uint32_t)p.get_location().get_type()); - sum[index] += 1; - } - } - } - return std::make_pair(curr_time, sum); - } -}; - -The output of several Loggers is stored in HDF5 files, with the memilio funciton mio::save_results in mio/io/result_io.h. -""" - - +# Module for plotting infection states and location types from ABM results. +# This module provides functions to load and visualize infection states and +# location types from simulation results stored in HDF5 format and are output +# by the MEmilio agent-based model (ABM). +# The used Loggers are: +# struct LogInfectionStatePerAgeGroup : mio::LogAlways { +# using Type = std::pair; +# /** +# * @brief Log the TimeSeries of the number of Person%s in an #InfectionState. +# * @param[in] sim The simulation of the abm. +# * @return A pair of the TimePoint and the TimeSeries of the number of Person%s in an #InfectionState. +# */ +# static Type log(const mio::abm::Simulation& sim) +# { +# +# Eigen::VectorXd sum = Eigen::VectorXd::Zero( +# Eigen::Index((size_t)mio::abm::InfectionState::Count * sim.get_world().parameters.get_num_groups())); +# const auto curr_time = sim.get_time(); +# const auto persons = sim.get_world().get_persons(); +# +# // PRAGMA_OMP(parallel for) +# for (auto i = size_t(0); i < persons.size(); ++i) { +# auto& p = persons[i]; +# if (p.get_should_be_logged()) { +# auto index = (((size_t)(mio::abm::InfectionState::Count)) * ((uint32_t)p.get_age().get())) + +# ((uint32_t)p.get_infection_state(curr_time)); +# // PRAGMA_OMP(atomic) +# sum[index] += 1; +# } +# } +# return std::make_pair(curr_time, sum); +# } +# }; +# +# struct LogInfectionPerLocationTypePerAgeGroup : mio::LogAlways { +# using Type = std::pair; +# /** +# * @brief Log the TimeSeries of the number of Person%s in an #InfectionState. +# * @param[in] sim The simulation of the abm. +# * @return A pair of the TimePoint and the TimeSeries of the number of Person%s in an #InfectionState. +# */ +# static Type log(const mio::abm::Simulation& sim) +# { +# +# Eigen::VectorXd sum = Eigen::VectorXd::Zero( +# Eigen::Index((size_t)mio::abm::LocationType::Count * sim.get_world().parameters.get_num_groups())); +# auto curr_time = sim.get_time(); +# auto prev_time = sim.get_prev_time(); +# const auto persons = sim.get_world().get_persons(); +# +# // PRAGMA_OMP(parallel for) +# for (auto i = size_t(0); i < persons.size(); ++i) { +# auto& p = persons[i]; +# if (p.get_should_be_logged()) { +# // PRAGMA_OMP(atomic) +# if ((p.get_infection_state(prev_time) != mio::abm::InfectionState::Exposed) && +# (p.get_infection_state(curr_time) == mio::abm::InfectionState::Exposed)) { +# auto index = (((size_t)(mio::abm::LocationType::Count)) * ((uint32_t)p.get_age().get())) + +# ((uint32_t)p.get_location().get_type()); +# sum[index] += 1; +# } +# } +# } +# return std::make_pair(curr_time, sum); +# } +# }; +# +# The output of the loggers of several runs is stored in HDF5 files, with the memilio funciton mio::save_results in mio/io/result_io.h. def load_h5_results(base_path, percentile): @@ -122,7 +119,7 @@ def plot_infections_loc_types_average( smooth_sigma=1, rolling_window=24, xtick_step=150): - """ Plots average infections per location type. + """ Plots rolling average infections per location type for the median run. @param[in] base_path Path to results directory. @param[in] start_date Start date as string. @@ -133,13 +130,11 @@ def plot_infections_loc_types_average( """ # Load data p50 = load_h5_results(path_to_loc_types, "p50") - p25 = load_h5_results(path_to_loc_types, "p05") - p75 = load_h5_results(path_to_loc_types, "p95") time = p50['Time'] total_50 = p50['Total'] plt.figure('Infection_location_types') - plt.title('At which location type an infection happened, averaged over all runs') + plt.title('Infection per location type for the median run, rolling sum over 24 hours') color_plot = matplotlib.colormaps.get_cmap(colormap).colors # If you define further location types, you need to adjust this list states_plot = [0, 1, 2, 3, 4, 5 , 6] @@ -200,13 +195,16 @@ def plot_infection_states( @param[in] xtick_step Step size for x-axis ticks. """ plt.figure('Infection_states with 50% percentile') - plt.title('Infection states') + plt.title('Infection states with 50% percentile') color_plot = matplotlib.colormaps.get_cmap(colormap).colors states_plot = [1, 2, 3, 4, 5, 7] legend_plot = ['E', 'I_NSymp', 'I_Symp', 'I_Sev', 'I_Crit', 'Dead'] for i in states_plot: plt.plot(x, y50[:, i], color=color_plot[i]) + plt.legend(legend_plot) # Needs to be done here, otherwise the percentage fill_between will not work correctly + + for i in states_plot: plt.fill_between(x, y50[:, i], y25[:, i], alpha=0.5, color=color_plot[i]) plt.fill_between(x, y50[:, i], y75[:, i], alpha=0.5, color=color_plot[i]) @@ -277,14 +275,31 @@ def _format_x_axis(x, start_date, xtick_step): def main(): """ Main function for CLI usage. """ parser = argparse.ArgumentParser(description="Plot infection state and location type results.") - # parser.add_argument("path", type=str, help="Path to the results folder") + parser.add_argument("--path-to-infection-states", help="Path to infection states results") + parser.add_argument("--path-to-loc-types", help="Path to location types results") parser.add_argument("--start-date", type=str, default='2021-03-01', help="Simulation start date (YYYY-MM-DD)") parser.add_argument("--colormap", type=str, default='Set1', help="Matplotlib colormap") parser.add_argument("--xtick-step", type=int, default=150, help="Step for x-axis ticks") args = parser.parse_args() - plot_infection_states_results("/Users/saschakorf/Nosynch/Arbeit/memilio/memilio/data/cluster_results/final_results/results_2024-09-20192904_best/infection_state_per_age_group/0", args.start_date, args.colormap, args.xtick_step) - # plot_infections_loc_types_average("/Users/saschakorf/Nosynch/Arbeit/memilio/memilio/data/cluster_results/final_results/results_2024-09-20192904_best/infection_per_location_type_per_age_group/0", args.start_date, args.colormap, xtick_step=args.xtick_step) + if args.path_to_infection_states: + plot_infection_states_results( + args.path_to_infection_states, + start_date=args.start_date, + colormap=args.colormap, + xtick_step=args.xtick_step) + if args.path_to_loc_types: + plot_infections_loc_types_average( + args.path_to_loc_types, + start_date=args.start_date, + colormap=args.colormap, + xtick_step=args.xtick_step) + + if not args.path_to_infection_states and not args.path_to_loc_types: + print("Please provide a path to infection states or location types results.") + sys.exit(1) + plt.show() + if __name__ == "__main__": main() From 663a189fa2ef24c6a784b0db11258a1fe35edbd0 Mon Sep 17 00:00:00 2001 From: Sascha Korf <51127093+xsaschako@users.noreply.github.com> Date: Wed, 28 May 2025 00:24:18 +0200 Subject: [PATCH 05/11] Add unit tests for plotAbmInfectionStates: implement comprehensive test cases for loading H5 results and plotting functions --- .../test_plot_plotAbmInfectionStates.py | 122 ++++++++++++++++++ 1 file changed, 122 insertions(+) create mode 100644 pycode/memilio-plot/memilio/plot_test/test_plot_plotAbmInfectionStates.py diff --git a/pycode/memilio-plot/memilio/plot_test/test_plot_plotAbmInfectionStates.py b/pycode/memilio-plot/memilio/plot_test/test_plot_plotAbmInfectionStates.py new file mode 100644 index 0000000000..41d5a31da5 --- /dev/null +++ b/pycode/memilio-plot/memilio/plot_test/test_plot_plotAbmInfectionStates.py @@ -0,0 +1,122 @@ +############################################################################# +# Copyright (C) 2020-2024 MEmilio +# +# Authors: Sascha Korf +# +# Contact: Martin J. Kuehn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################# + +import unittest +from unittest.mock import patch, MagicMock +import numpy as np +import pandas as pd + +import memilio.plot.plotAbmInfectionStates as abm + +class TestPlotAbmInfectionStates(unittest.TestCase): + + @patch('memilio.plot.plotAbmInfectionStates.h5py.File') + def test_load_h5_results(self, mock_h5file): + mock_group = {'Time': np.arange(10), 'Total': np.ones((10, 8))} + mock_h5file().__enter__().get.return_value = {'0': mock_group} + mock_h5file().__enter__().items.return_value = [('0', mock_group)] + mock_h5file().__enter__().__getitem__.return_value = mock_group + mock_h5file().__enter__().items.return_value = [('Time', np.arange(10)), ('Total', np.ones((10, 8)))] + with patch('memilio.plot.plotAbmInfectionStates.h5py.File', mock_h5file): + result = abm.load_h5_results('dummy_path', 'p50') + assert 'Time' in result + assert 'Total' in result + np.testing.assert_array_equal(result['Time'], np.arange(10)) + np.testing.assert_array_equal(result['Total'], np.ones((10, 8))) + + @patch('memilio.plot.plotAbmInfectionStates.load_h5_results') + @patch('memilio.plot.plotAbmInfectionStates.matplotlib') + @patch('memilio.plot.plotAbmInfectionStates.gaussian_filter1d', side_effect=lambda x, sigma, mode: x) + @patch('memilio.plot.plotAbmInfectionStates.pd.DataFrame') + def test_plot_infections_loc_types_average(self, mock_df, mock_gauss, mock_matplotlib, mock_load): + mock_load.return_value = {'Time': np.arange(48), 'Total': np.ones((48, 7))} + mock_df.return_value.rolling.return_value.sum.return_value.to_numpy.return_value = np.ones((48, 1)) + mock_matplotlib.colormaps.get_cmap.return_value.colors = [(1,0,0)]*7 + + # Patch plt.gca().plot to a MagicMock + with patch.object(abm.plt, 'gca') as mock_gca: + mock_ax = MagicMock() + mock_gca.return_value = mock_ax + abm.plot_infections_loc_types_average('dummy_path') + assert mock_ax.plot.called + assert mock_ax.set_xticks.called + assert mock_ax.set_xticklabels.called + + @patch('memilio.plot.plotAbmInfectionStates.load_h5_results') + @patch('memilio.plot.plotAbmInfectionStates.plot_infection_states') + @patch('memilio.plot.plotAbmInfectionStates.plot_infection_states_individual') + def test_plot_infection_states_results(self, mock_indiv, mock_states, mock_load): + mock_load.side_effect = [ + {'Time': np.arange(10), 'Total': np.ones((10, 8)), 'Group1': np.ones((10, 8)), 'Group2': np.ones((10, 8)), 'Group3': np.ones((10, 8)), 'Group4': np.ones((10, 8)), 'Group5': np.ones((10, 8)), 'Group6': np.ones((10, 8)), 'Total': np.ones((10, 8))}, + {'Time': np.arange(10), 'Total': np.ones((10, 8)), 'Group1': np.ones((10, 8)), 'Group2': np.ones((10, 8)), 'Group3': np.ones((10, 8)), 'Group4': np.ones((10, 8)), 'Group5': np.ones((10, 8)), 'Group6': np.ones((10, 8)), 'Total': np.ones((10, 8))}, + {'Time': np.arange(10), 'Total': np.ones((10, 8)), 'Group1': np.ones((10, 8)), 'Group2': np.ones((10, 8)), 'Group3': np.ones((10, 8)), 'Group4': np.ones((10, 8)), 'Group5': np.ones((10, 8)), 'Group6': np.ones((10, 8)), 'Total': np.ones((10, 8))} + ] + abm.plot_infection_states_results('dummy_path') + assert mock_indiv.called + assert mock_states.called + + @patch('memilio.plot.plotAbmInfectionStates.matplotlib') + def test_plot_infection_states(self, mock_matplotlib): + x = np.arange(10) + y50 = np.ones((10, 8)) + y25 = np.zeros((10, 8)) + y75 = np.ones((10, 8))*2 + mock_matplotlib.colormaps.get_cmap.return_value.colors = [(1,0,0)]*8 + + # Patch plt.gca().plot and fill_between + with patch.object(abm.plt, 'gca') as mock_gca: + mock_ax = MagicMock() + mock_gca.return_value = mock_ax + abm.plot_infection_states(x, y50, y25, y75) + assert mock_ax.plot.called + assert mock_ax.fill_between.called + assert mock_ax.set_xticks.called + assert mock_ax.set_xticklabels.called + + @patch('memilio.plot.plotAbmInfectionStates.matplotlib') + def test_plot_infection_states_individual(self, mock_matplotlib): + x = np.arange(10) + group_data = np.ones((10, 8)) + p50_bs = {g: group_data for g in ['Group1', 'Group2', 'Group3', 'Group4', 'Group5', 'Group6', 'Total']} + p25_bs = {g: group_data for g in ['Group1', 'Group2', 'Group3', 'Group4', 'Group5', 'Group6', 'Total']} + p75_bs = {g: group_data for g in ['Group1', 'Group2', 'Group3', 'Group4', 'Group5', 'Group6', 'Total']} + mock_matplotlib.colormaps.get_cmap.return_value.colors = [(1,0,0)]*8 + + # Patch plt.subplots to return a grid of MagicMock axes (as np.array with dtype=object) + with patch.object(abm.plt, 'subplots') as mock_subplots: + fig_mock = MagicMock() + ax_mock = np.empty((6, 7), dtype=object) + for i in range(6): + for j in range(7): + ax_mock[i, j] = MagicMock() + mock_subplots.return_value = (fig_mock, ax_mock) + abm.plot_infection_states_individual(x, p50_bs, p25_bs, p75_bs) + # Check that at least one ax's plot was called + assert any(ax_mock[i, j].plot.called for i in range(6) for j in range(7)) + assert fig_mock.suptitle.called + + def test__format_x_axis(self): + with patch('memilio.plot.plotAbmInfectionStates.plt') as mock_plt: + abm._format_x_axis(np.arange(10), '2021-03-01', 2) + assert mock_plt.gca.called + assert mock_plt.gcf.called + +if __name__ == '__main__': + unittest.main() \ No newline at end of file From 28c3187da683e884bc131415f114e306b34eaa8f Mon Sep 17 00:00:00 2001 From: Sascha Korf <51127093+xsaschako@users.noreply.github.com> Date: Wed, 28 May 2025 00:25:02 +0200 Subject: [PATCH 06/11] Refactor setup.py: enhance PylintCommand documentation and update comments for clarity --- pycode/memilio-plot/setup.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pycode/memilio-plot/setup.py b/pycode/memilio-plot/setup.py index 9c6a85908a..b13449bd91 100644 --- a/pycode/memilio-plot/setup.py +++ b/pycode/memilio-plot/setup.py @@ -8,12 +8,13 @@ class PylintCommand(Command): - """Custom command to run pylint and get a report as html.""" + """ + Custom command to run pylint and get a report as html. + """ description = "Runs pylint and outputs the report as html." user_options = [] def initialize_options(self): - """ """ from pylint.reporters.json_reporter import JSONReporter from pylint.reporters.text import ParseableTextReporter, TextReporter from pylint_json2html import JsonExtendedReporter @@ -29,12 +30,10 @@ def initialize_options(self): } def finalize_options(self): - """ """ self.reporter, self.out_file = self.REPORTERS.get( self.out_format) # , self.REPORTERS.get("parseable")) def run(self): - """ """ os.makedirs("build_pylint", exist_ok=True) # Run pylint @@ -63,7 +62,7 @@ def run(self): test_suite='memilio.plot_test', install_requires=[ # smaller pandas versions contain a bug that sometimes prevents reading - # some excel files (e.g. population or mobility data) + # some excel files (e.g. population or twitter data) 'pandas>=1.2.2', 'matplotlib', # smaller numpy versions cause a security issue, 1.25 breaks testing with pyfakefs @@ -74,6 +73,7 @@ def run(self): 'pyxlsb', 'wget', 'folium', + 'scipy.ndimage', 'matplotlib', 'mapclassify', 'geopandas', From 2ba672e8459d087fdd5edeef7c859a91ce65f7b8 Mon Sep 17 00:00:00 2001 From: Sascha Korf <51127093+xsaschako@users.noreply.github.com> Date: Wed, 28 May 2025 00:26:06 +0200 Subject: [PATCH 07/11] Fix comment in setup.py: update example for Excel file types in install_requires --- pycode/memilio-plot/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pycode/memilio-plot/setup.py b/pycode/memilio-plot/setup.py index b13449bd91..130996ac6a 100644 --- a/pycode/memilio-plot/setup.py +++ b/pycode/memilio-plot/setup.py @@ -62,7 +62,7 @@ def run(self): test_suite='memilio.plot_test', install_requires=[ # smaller pandas versions contain a bug that sometimes prevents reading - # some excel files (e.g. population or twitter data) + # some excel files (e.g. population or mobility data) 'pandas>=1.2.2', 'matplotlib', # smaller numpy versions cause a security issue, 1.25 breaks testing with pyfakefs From baa272b8ed10db30548423a2edf808731b182f42 Mon Sep 17 00:00:00 2001 From: Sascha Korf <51127093+xsaschako@users.noreply.github.com> Date: Wed, 28 May 2025 00:27:52 +0200 Subject: [PATCH 08/11] Remove plotAbmICUAndDeadComp.py: delete unused script for ICU and death comparison plots --- .../memilio/plot/plotAbmICUAndDeadComp.py | 266 ------------------ 1 file changed, 266 deletions(-) delete mode 100644 pycode/memilio-plot/memilio/plot/plotAbmICUAndDeadComp.py diff --git a/pycode/memilio-plot/memilio/plot/plotAbmICUAndDeadComp.py b/pycode/memilio-plot/memilio/plot/plotAbmICUAndDeadComp.py deleted file mode 100644 index 510a88104e..0000000000 --- a/pycode/memilio-plot/memilio/plot/plotAbmICUAndDeadComp.py +++ /dev/null @@ -1,266 +0,0 @@ -# Python script to analyze bs runs -# input is a bs run folder with the following structure: -# bs_run_folder has a txt file for each bs run -# each txt file has a line for each time step -# each line has a column for each compartment as well as the timestep -# each column has the number of individuals in that compartment -# the first line of each txt file is the header - -import sys -import argparse -import os -import pandas as pd -import numpy as np -import matplotlib.pyplot as plt -import matplotlib -import matplotlib.colors as colors -import matplotlib.cm as cmx -import matplotlib.patches as mpatches -import matplotlib.lines as mlines -import h5py -from datetime import datetime -from matplotlib.dates import DateFormatter - -fontsize = 20 - - -def plot_dead(path): - # we will have a seperate plot the cumulative infected individuals, cumulative symptomatic individuals and cumulative dead individual - # we need to load the data - f_p50 = h5py.File( - path+"/infection_state_per_age_group/0/p50/Results.h5", 'r') - p50_bs = f_p50['0'] - - # do the same for 25 and 75 percentile - f_p25 = h5py.File( - path+"/infection_state_per_age_group/0/p25/Results.h5", 'r') - p25_bs = f_p25['0'] - - f_p75 = h5py.File( - path+"/infection_state_per_age_group/0/p75/Results.h5", 'r') - p75_bs = f_p75['0'] - - # do the same for 05 and 95 percentile - f_p05 = h5py.File( - path+"/infection_state_per_age_group/0/p05/Results.h5", 'r') - p05_bs = f_p05['0'] - - f_p95 = h5py.File( - path+"/infection_state_per_age_group/0/p95/Results.h5", 'r') - p95_bs = f_p95['0'] - - age_group_access = ['Group1', 'Group2', 'Group3', - 'Group4', 'Group5', 'Group6', 'Total'] - - # we need the real data json file cases_all_county_age - df_abb = pd.read_json( - path+"/../../../pydata/Germany/cases_all_county_age_ma1.json") - - # we just need the columns cases and date - # we need to offset the dates by 19 day - df_abb['Date'] = df_abb['Date'] + pd.DateOffset(days=18) - # we need just the dates bewteen 2021-03-01 and 2021-06-01 - df_abb = df_abb[(df_abb['Date'] >= '2021-03-01') & - (df_abb['Date'] <= '2021-06-01')] - # we just need the cases with id 3101 - df_abb = df_abb[df_abb['ID_County'] == 3101] - # df_abb['Deaths'] = np.round(df_abb[['Deaths']].to_numpy()) - - # we need the amount of dead persons for each age group: These are A00-A04, A05-A14, A15-A34, A35-A59, A60-A79, A80+ - age_groups = ['A00-A04', 'A05-A14', 'A15-A34', 'A35-A59', 'A60-A79', 'A80+'] - age_grous_string = ['Age 0-4', 'Age 5-14', 'Age 15-34', 'Age 35-59', 'Age 60-79', 'Age 80+'] - # we need to sum up the amount of dead persons for each age group - - # we want the deaths for the age groups - df_abb = df_abb[['Date', 'Deaths', 'Age_RKI']] - # we want a plot with 2 rows. Second row has a plot with each age group and the simulated and real dead persons - # First row has the cumulative dead persons - fig = plt.figure('Deaths') - fig.set_figwidth(20) - fig.set_figheight(9) - gs = fig.add_gridspec(2,6) - - # we need the cumulative dead persons - ax = fig.add_subplot(gs[0, :]) - df_total_dead = df_abb.groupby('Date').sum()[0:90] - y_real = df_total_dead['Deaths'].to_numpy() - # we need to substract the first value from the rest - y_real = y_real - y_real[0] - - y_sim = p50_bs['Total'][()][:, 7][::24][0:90] - y_sim = y_sim - y_sim[0] - - y_sim25 = p25_bs['Total'][()][:, 7][::24][0:90] - y_sim25 = y_sim25 - y_sim25[0] - - y_sim75 = p75_bs['Total'][()][:, 7][::24][0:90] - y_sim75 = y_sim75 - y_sim75[0] - - y_sim05 = p05_bs['Total'][()][:, 7][::24][0:90] - y_sim05 = y_sim05 - y_sim05[0] - - y_sim95 = p95_bs['Total'][()][:, 7][::24][0:90] - y_sim95 = y_sim95 - y_sim95[0] - - - - # we calculate the RMSE - rmse_dead = np.sqrt(((y_real- y_sim)**2).mean()) - # we need to plot the cumulative dead persons from the real world and from the simulation - - ax.plot(df_total_dead.index, y_sim, color='tab:blue',label='Simulated deaths') - ax.plot(df_total_dead.index, y_real, 'v',color='tab:red', linewidth=4, label='Extrapolated deaths from reported infection case data') - ax.fill_between(df_total_dead.index, y_sim75, y_sim25, - alpha=0.5, color='tab:blue', label='50% Confidence interval') - ax.fill_between(df_total_dead.index,y_sim95, y_sim05, - alpha=0.25, color='tab:blue', label='90% Confidence interval') - # ax.text(0.25, 0.8, 'RMSE: '+str(float("{:.2f}".format(rmse_dead))), horizontalalignment='center', - # verticalalignment='center', transform=plt.gca().transAxes, color='pink', fontsize=15) - ax.set_label('Number of individuals') - ax.set_title('Cumulative Deaths', fontsize=fontsize) - ax.set_ylabel('Number of individuals', fontsize=fontsize-8) - ax.legend(fontsize=fontsize-8) - - # now for each age group - for i, age_group in zip(range(6), age_group_access): - ax = fig.add_subplot(gs[1, i]) - # we need the amount of dead persons for each age group - df_abb_age_group = df_abb[df_abb['Age_RKI'] == age_groups[i]][0:90] - y_real = np.round(df_abb_age_group['Deaths'].to_numpy()) - # we need to plot the dead persons from the real world and from the simulation - ax.plot(df_abb_age_group['Date'], y_real-y_real[0], color='tab:red') - ax.plot(df_abb_age_group['Date'], p50_bs[age_group_access[i]][()][:, 7][::24][0:90]-p50_bs[age_group_access[i]][()][:, 7][::24][0], color='tab:blue') - ax.fill_between(df_abb_age_group['Date'], p75_bs[age_group_access[i]][()][:, 7][::24][0:90]-p75_bs[age_group_access[i]][()][:, 7][::24][0], p25_bs[age_group_access[i]][()][:, 7][::24][0:90]-p25_bs[age_group_access[i]][()][:, 7][::24][0], - alpha=0.5, color='tab:blue') - ax.set_title('Deaths, '+age_grous_string[i]) - ax.set_ybound(lower=0) - ax.set_xticks(df_abb_age_group['Date'][::50]) - ax.tick_params(axis='both', which='major', labelsize=fontsize-10) - ax.tick_params(axis='both', which='minor', labelsize=fontsize-10) - if i == 0: - ax.set_ylabel('Number of individuals',fontsize=fontsize-8) - ax.set_ybound(upper=1) - - plt.show() - -def plot_icu(path): - - df_abb = pd.read_json(path+"/../../../pydata/Germany/county_divi.json") - - perc_of_critical_in_icu_age = [0.55,0.55,0.55,0.56,0.54,0.46] - perc_of_critical_in_icu=0.55 - - age_group_access = ['Group1', 'Group2', 'Group3', - 'Group4', 'Group5', 'Group6', 'Total'] - - - # we just need the columns ICU_low and ICU_hig - df_abb = df_abb[['ID_County', 'ICU', 'Date']] - - df_abb = df_abb[df_abb['ID_County'] == 3101] - # we need just the dates bewteen 2021-03-01 and 2021-06-01 - df_abb = df_abb[(df_abb['Date'] >= '2021-03-01') & - (df_abb['Date'] <= '2021-06-01')] - - # we plot this against this the Amount of persons in the ICU from our model - f_p50 = h5py.File( - path+"/infection_state_per_age_group/0/p50/Results.h5", 'r') - total_50 = f_p50['0']['Total'][()][::24][0:90] - - total_50_age = f_p50['0'][age_group_access[0]][()] - for i in range(6): - total_50_age += f_p50['0'][age_group_access[i]][()]*perc_of_critical_in_icu_age[i] - total_50_age = total_50_age[::24][0:90] - - - # we plot this against this the Amount of persons in the ICU from our model - f_p75 = h5py.File( - path+"/infection_state_per_age_group/0/p75/Results.h5", 'r') - # total_75 = f_p75['0']['Total'][()][::24][0:90] - total_75_age = f_p75['0'][age_group_access[0]][()] - for i in range(6): - total_75_age += f_p75['0'][age_group_access[i]][()]*perc_of_critical_in_icu_age[i] - total_75_age = total_75_age[::24][0:90] - - # same with 25 percentile - f_p25 = h5py.File( - path+"/infection_state_per_age_group/0/p25/Results.h5", 'r') - # total_25 = f_p25['0']['Total'][()][::24][0:90] - total_25_age = f_p25['0'][age_group_access[0]][()] - for i in range(6): - total_25_age += f_p25['0'][age_group_access[i]][()]*perc_of_critical_in_icu_age[i] - total_25_age = total_25_age[::24][0:90] - - # same with 05 and 95 percentile - f_p05 = h5py.File( - path+"/infection_state_per_age_group/0/p05/Results.h5", 'r') - # total_05 = f_p05['0']['Total'][()][::24][0:90] - total_05_age = f_p05['0'][age_group_access[0]][()] - for i in range(6): - total_05_age += f_p05['0'][age_group_access[i]][()]*perc_of_critical_in_icu_age[i] - total_05_age = total_05_age[::24][0:90] - - f_p95 = h5py.File( - path+"/infection_state_per_age_group/0/p95/Results.h5", 'r') - # total_95 = f_p95['0']['Total'][()][::24][0:90] - total_95_age = f_p95['0'][age_group_access[0]][()] - for i in range(6): - total_95_age += f_p95['0'][age_group_access[i]][()]*perc_of_critical_in_icu_age[i] - total_95_age = total_95_age[::24][0:90] - - - ICU_Simulation_one_percentile = np.floor(total_50[:, 5]*perc_of_critical_in_icu) - ICU_Simulation = np.round(total_50_age[:, 5]) - ICU_Simulation75 = np.round(total_75_age[:, 5]) - ICU_Simulation25 = np.round(total_25_age[:, 5]) - ICU_Simulation05 = np.round(total_05_age[:, 5]) - ICU_Simulation95 = np.round(total_95_age[:, 5]) - ICU_Real = df_abb['ICU'][0:90] - - #smooth the data - # ICU_Real = gaussian_filter1d(ICU_Real, sigma=1, mode='nearest') - # ICU_Simulation = gaussian_filter1d(ICU_Simulation, sigma=1, mode='nearest') - - - - # we calculate the RMSE - rmse_ICU = np.sqrt(((ICU_Real - ICU_Simulation_one_percentile)**2).mean()) - - # plot the ICU beds and the ICU beds taken - fig, ax = plt.subplots(1, 1, constrained_layout=True) - fig.set_figwidth(12) - fig.set_figheight(9) - # we plot the ICU_low and the ICU_high - ax.plot(df_abb['Date'][0:90], ICU_Real,'x', color='tab:red', linewidth=10, label='Data') - ax.plot(df_abb['Date'][0:90], ICU_Simulation, color='tab:blue', label='Simulation') - # ax.plot(df_abb['Date'][0:90], ICU_Simulation_one_percentile, color='tab:green', label='Simulated ICU beds') - ax.fill_between(df_abb['Date'][0:90],ICU_Simulation75, ICU_Simulation25, - alpha=0.5, color='tab:blue', label='50% Confidence interval') - ax.fill_between(df_abb['Date'][0:90],ICU_Simulation05, ICU_Simulation95, - alpha=0.25, color='tab:blue', label='90% Confidence interval') - - - # we also write the rmse - # ax.text(0.25, 0.8, 'RMSE: '+str(float("{:.2f}".format(rmse_ICU))), horizontalalignment='center', - # verticalalignment='center', transform=plt.gca().transAxes, color='pink', fontsize=15) - ax.tick_params(axis='both', which='major', labelsize=fontsize-4) - ax.tick_params(axis='both', which='minor', labelsize=fontsize-4) - ax.set_ylabel('Occupied ICU beds', fontsize=fontsize) - ax.set_title('ICU beds', fontsize=fontsize+4) - ax.legend(fontsize=fontsize-4) - plt.show() - - - - -if __name__ == "__main__": - path = "" - - if (len(sys.argv) > 1): - n_runs = sys.argv[1] - else: - n_runs = len([entry for entry in os.listdir(path) - if os.path.isfile(os.path.join(path, entry))]) - - # plot_icu(path) - # plot_dead(path) From 20b651e8a99098a327fdf2550572951855ac933d Mon Sep 17 00:00:00 2001 From: Sascha Korf <51127093+xsaschako@users.noreply.github.com> Date: Wed, 28 May 2025 00:34:25 +0200 Subject: [PATCH 09/11] formatting --- .../memilio/plot/plotAbmInfectionStates.py | 105 ++++++++++++------ .../test_plot_plotAbmInfectionStates.py | 40 ++++--- 2 files changed, 94 insertions(+), 51 deletions(-) diff --git a/pycode/memilio-plot/memilio/plot/plotAbmInfectionStates.py b/pycode/memilio-plot/memilio/plot/plotAbmInfectionStates.py index 91a6926c5e..00f2d749b3 100644 --- a/pycode/memilio-plot/memilio/plot/plotAbmInfectionStates.py +++ b/pycode/memilio-plot/memilio/plot/plotAbmInfectionStates.py @@ -37,19 +37,19 @@ # The used Loggers are: # struct LogInfectionStatePerAgeGroup : mio::LogAlways { # using Type = std::pair; -# /** +# /** # * @brief Log the TimeSeries of the number of Person%s in an #InfectionState. # * @param[in] sim The simulation of the abm. # * @return A pair of the TimePoint and the TimeSeries of the number of Person%s in an #InfectionState. # */ # static Type log(const mio::abm::Simulation& sim) # { -# +# # Eigen::VectorXd sum = Eigen::VectorXd::Zero( # Eigen::Index((size_t)mio::abm::InfectionState::Count * sim.get_world().parameters.get_num_groups())); # const auto curr_time = sim.get_time(); # const auto persons = sim.get_world().get_persons(); -# +# # // PRAGMA_OMP(parallel for) # for (auto i = size_t(0); i < persons.size(); ++i) { # auto& p = persons[i]; @@ -63,23 +63,23 @@ # return std::make_pair(curr_time, sum); # } # }; -# +# # struct LogInfectionPerLocationTypePerAgeGroup : mio::LogAlways { # using Type = std::pair; -# /** +# /** # * @brief Log the TimeSeries of the number of Person%s in an #InfectionState. # * @param[in] sim The simulation of the abm. # * @return A pair of the TimePoint and the TimeSeries of the number of Person%s in an #InfectionState. # */ # static Type log(const mio::abm::Simulation& sim) # { -# +# # Eigen::VectorXd sum = Eigen::VectorXd::Zero( # Eigen::Index((size_t)mio::abm::LocationType::Count * sim.get_world().parameters.get_num_groups())); # auto curr_time = sim.get_time(); # auto prev_time = sim.get_prev_time(); # const auto persons = sim.get_world().get_persons(); -# +# # // PRAGMA_OMP(parallel for) # for (auto i = size_t(0); i < persons.size(); ++i) { # auto& p = persons[i]; @@ -96,7 +96,7 @@ # return std::make_pair(curr_time, sum); # } # }; -# +# # The output of the loggers of several runs is stored in HDF5 files, with the memilio funciton mio::save_results in mio/io/result_io.h. @@ -112,6 +112,7 @@ def load_h5_results(base_path, percentile): data = {k: v[()] for k, v in f['0'].items()} return data + def plot_infections_loc_types_average( path_to_loc_types, start_date='2021-03-01', @@ -134,17 +135,22 @@ def plot_infections_loc_types_average( total_50 = p50['Total'] plt.figure('Infection_location_types') - plt.title('Infection per location type for the median run, rolling sum over 24 hours') + plt.title( + 'Infection per location type for the median run, rolling sum over 24 hours') color_plot = matplotlib.colormaps.get_cmap(colormap).colors # If you define further location types, you need to adjust this list - states_plot = [0, 1, 2, 3, 4, 5 , 6] - legend_plot = ['Home', 'School', 'Work', 'SocialEvent', 'BasicsShop', 'Hospital', 'ICU'] + states_plot = [0, 1, 2, 3, 4, 5, 6] + legend_plot = ['Home', 'School', 'Work', + 'SocialEvent', 'BasicsShop', 'Hospital', 'ICU'] for idx, i in enumerate(states_plot): - color = color_plot[i % len(color_plot)] if i < len(color_plot) else "black" + color = color_plot[i % len(color_plot)] if i < len( + color_plot) else "black" # Sum up every 24 hours, then smooth - indexer = pd.api.indexers.FixedForwardWindowIndexer(window_size=rolling_window) - y = pd.DataFrame(total_50[:, i]).rolling(window=indexer, min_periods=1).sum().to_numpy() + indexer = pd.api.indexers.FixedForwardWindowIndexer( + window_size=rolling_window) + y = pd.DataFrame(total_50[:, i]).rolling( + window=indexer, min_periods=1).sum().to_numpy() y = y[0::rolling_window].flatten() y = gaussian_filter1d(y, sigma=smooth_sigma, mode='nearest') plt.plot(time[0::rolling_window], y, color=color) @@ -155,6 +161,7 @@ def plot_infections_loc_types_average( plt.ylabel('Number of individuals') plt.show() + def plot_infection_states_results( path_to_infection_states, start_date='2021-03-01', @@ -177,7 +184,9 @@ def plot_infection_states_results( total_75 = p75['Total'] plot_infection_states_individual(time, p50, p25, p75, colormap) - plot_infection_states(time, total_50, total_25, total_75, start_date, colormap, xtick_step) + plot_infection_states(time, total_50, total_25, + total_75, start_date, colormap, xtick_step) + def plot_infection_states( x, y50, y25, y75, @@ -202,11 +211,14 @@ def plot_infection_states( for i in states_plot: plt.plot(x, y50[:, i], color=color_plot[i]) - plt.legend(legend_plot) # Needs to be done here, otherwise the percentage fill_between will not work correctly + # Needs to be done here, otherwise the percentage fill_between will not work correctly + plt.legend(legend_plot) for i in states_plot: - plt.fill_between(x, y50[:, i], y25[:, i], alpha=0.5, color=color_plot[i]) - plt.fill_between(x, y50[:, i], y75[:, i], alpha=0.5, color=color_plot[i]) + plt.fill_between(x, y50[:, i], y25[:, i], + alpha=0.5, color=color_plot[i]) + plt.fill_between(x, y50[:, i], y75[:, i], + alpha=0.5, color=color_plot[i]) plt.legend(legend_plot) _format_x_axis(x, start_date, xtick_step) @@ -214,6 +226,7 @@ def plot_infection_states( plt.ylabel('Number of individuals') plt.show() + def plot_infection_states_individual(x, p50_bs, p25_bs, p75_bs, colormap='Set1'): """ Plots infection states for each age group. @@ -223,28 +236,37 @@ def plot_infection_states_individual(x, p50_bs, p25_bs, p75_bs, colormap='Set1') @param[in] p75_bs 75th percentile values by group. @param[in] colormap Matplotlib colormap. """ - age_groups = ['Group1', 'Group2', 'Group3', 'Group4', 'Group5', 'Group6', 'Total'] # Adjust as needed + age_groups = ['Group1', 'Group2', 'Group3', 'Group4', + 'Group5', 'Group6', 'Total'] # Adjust as needed color_plot = matplotlib.colormaps.get_cmap(colormap).colors - fig, ax = plt.subplots(6, len(age_groups), constrained_layout=True, figsize=(20, 9)) + fig, ax = plt.subplots( + 6, len(age_groups), constrained_layout=True, figsize=(20, 9)) for col_idx, group in enumerate(age_groups): y50 = p50_bs[group] y25 = p25_bs[group] y75 = p75_bs[group] # Infected no symptoms - _plot_state(ax[0, col_idx], x, y50[:, 1], y25[:, 1], y75[:, 1], color_plot[col_idx], '#Infected_no_symptoms, Age' + str(group)) + _plot_state(ax[0, col_idx], x, y50[:, 1], y25[:, 1], y75[:, 1], + color_plot[col_idx], '#Infected_no_symptoms, Age' + str(group)) # Infected symptoms - _plot_state(ax[1, col_idx], x, y50[:, 2], y25[:, 2], y75[:, 2], color_plot[col_idx], '#Infected_symptoms, Age' + str(group)) + _plot_state(ax[1, col_idx], x, y50[:, 2], y25[:, 2], y75[:, 2], + color_plot[col_idx], '#Infected_symptoms, Age' + str(group)) # Severe - _plot_state(ax[2, col_idx], x, y50[:, 4], y25[:, 4], y75[:, 4], color_plot[col_idx], '#Severe, Age' + str(group)) + _plot_state(ax[2, col_idx], x, y50[:, 4], y25[:, 4], y75[:, 4], + color_plot[col_idx], '#Severe, Age' + str(group)) # Critical - _plot_state(ax[3, col_idx], x, y50[:, 5], y25[:, 5], y75[:, 5], color_plot[col_idx], '#Critical, Age' + str(group)) + _plot_state(ax[3, col_idx], x, y50[:, 5], y25[:, 5], y75[:, 5], + color_plot[col_idx], '#Critical, Age' + str(group)) # Dead - _plot_state(ax[4, col_idx], x, y50[:, 7], y25[:, 7], y75[:, 7], color_plot[col_idx], '#Dead, Age' + str(group)) + _plot_state(ax[4, col_idx], x, y50[:, 7], y25[:, 7], + y75[:, 7], color_plot[col_idx], '#Dead, Age' + str(group)) # Recovered - _plot_state(ax[5, col_idx], x, y50[:, 6], y25[:, 6], y75[:, 6], color_plot[col_idx], '#Recovered, Age' + str(group)) - - fig.suptitle('Infection states per age group with 50% percentile', fontsize=16) + _plot_state(ax[5, col_idx], x, y50[:, 6], y25[:, 6], y75[:, 6], + color_plot[col_idx], '#Recovered, Age' + str(group)) + + fig.suptitle( + 'Infection states per age group with 50% percentile', fontsize=16) # We hide the Legend for the individual plots as it is too cluttered for ax_row in ax: @@ -253,6 +275,7 @@ def plot_infection_states_individual(x, p50_bs, p25_bs, p75_bs, colormap='Set1') plt.show() + def _plot_state(ax, x, y50, y25, y75, color, title): """ Helper to plot a single state with fill_between. """ ax.set_xlabel('time (days)') @@ -263,6 +286,7 @@ def _plot_state(ax, x, y50, y25, y75, color, title): ax.set_title(title) ax.legend(['Simulation']) + def _format_x_axis(x, start_date, xtick_step): """ Helper to format x-axis as dates. """ start = datetime.strptime(start_date, '%Y-%m-%d') @@ -272,17 +296,24 @@ def _format_x_axis(x, start_date, xtick_step): plt.gca().set_xticklabels(xx_str[::xtick_step]) plt.gcf().autofmt_xdate() + def main(): """ Main function for CLI usage. """ - parser = argparse.ArgumentParser(description="Plot infection state and location type results.") - parser.add_argument("--path-to-infection-states", help="Path to infection states results") - parser.add_argument("--path-to-loc-types", help="Path to location types results") - parser.add_argument("--start-date", type=str, default='2021-03-01', help="Simulation start date (YYYY-MM-DD)") - parser.add_argument("--colormap", type=str, default='Set1', help="Matplotlib colormap") - parser.add_argument("--xtick-step", type=int, default=150, help="Step for x-axis ticks") + parser = argparse.ArgumentParser( + description="Plot infection state and location type results.") + parser.add_argument("--path-to-infection-states", + help="Path to infection states results") + parser.add_argument("--path-to-loc-types", + help="Path to location types results") + parser.add_argument("--start-date", type=str, default='2021-03-01', + help="Simulation start date (YYYY-MM-DD)") + parser.add_argument("--colormap", type=str, + default='Set1', help="Matplotlib colormap") + parser.add_argument("--xtick-step", type=int, + default=150, help="Step for x-axis ticks") args = parser.parse_args() - if args.path_to_infection_states: + if args.path_to_infection_states: plot_infection_states_results( args.path_to_infection_states, start_date=args.start_date, @@ -294,12 +325,12 @@ def main(): start_date=args.start_date, colormap=args.colormap, xtick_step=args.xtick_step) - + if not args.path_to_infection_states and not args.path_to_loc_types: print("Please provide a path to infection states or location types results.") sys.exit(1) plt.show() - + if __name__ == "__main__": main() diff --git a/pycode/memilio-plot/memilio/plot_test/test_plot_plotAbmInfectionStates.py b/pycode/memilio-plot/memilio/plot_test/test_plot_plotAbmInfectionStates.py index 41d5a31da5..bb04cc54a9 100644 --- a/pycode/memilio-plot/memilio/plot_test/test_plot_plotAbmInfectionStates.py +++ b/pycode/memilio-plot/memilio/plot_test/test_plot_plotAbmInfectionStates.py @@ -25,6 +25,7 @@ import memilio.plot.plotAbmInfectionStates as abm + class TestPlotAbmInfectionStates(unittest.TestCase): @patch('memilio.plot.plotAbmInfectionStates.h5py.File') @@ -33,7 +34,8 @@ def test_load_h5_results(self, mock_h5file): mock_h5file().__enter__().get.return_value = {'0': mock_group} mock_h5file().__enter__().items.return_value = [('0', mock_group)] mock_h5file().__enter__().__getitem__.return_value = mock_group - mock_h5file().__enter__().items.return_value = [('Time', np.arange(10)), ('Total', np.ones((10, 8)))] + mock_h5file().__enter__().items.return_value = [ + ('Time', np.arange(10)), ('Total', np.ones((10, 8)))] with patch('memilio.plot.plotAbmInfectionStates.h5py.File', mock_h5file): result = abm.load_h5_results('dummy_path', 'p50') assert 'Time' in result @@ -46,9 +48,11 @@ def test_load_h5_results(self, mock_h5file): @patch('memilio.plot.plotAbmInfectionStates.gaussian_filter1d', side_effect=lambda x, sigma, mode: x) @patch('memilio.plot.plotAbmInfectionStates.pd.DataFrame') def test_plot_infections_loc_types_average(self, mock_df, mock_gauss, mock_matplotlib, mock_load): - mock_load.return_value = {'Time': np.arange(48), 'Total': np.ones((48, 7))} - mock_df.return_value.rolling.return_value.sum.return_value.to_numpy.return_value = np.ones((48, 1)) - mock_matplotlib.colormaps.get_cmap.return_value.colors = [(1,0,0)]*7 + mock_load.return_value = { + 'Time': np.arange(48), 'Total': np.ones((48, 7))} + mock_df.return_value.rolling.return_value.sum.return_value.to_numpy.return_value = np.ones( + (48, 1)) + mock_matplotlib.colormaps.get_cmap.return_value.colors = [(1, 0, 0)]*7 # Patch plt.gca().plot to a MagicMock with patch.object(abm.plt, 'gca') as mock_gca: @@ -64,9 +68,12 @@ def test_plot_infections_loc_types_average(self, mock_df, mock_gauss, mock_matpl @patch('memilio.plot.plotAbmInfectionStates.plot_infection_states_individual') def test_plot_infection_states_results(self, mock_indiv, mock_states, mock_load): mock_load.side_effect = [ - {'Time': np.arange(10), 'Total': np.ones((10, 8)), 'Group1': np.ones((10, 8)), 'Group2': np.ones((10, 8)), 'Group3': np.ones((10, 8)), 'Group4': np.ones((10, 8)), 'Group5': np.ones((10, 8)), 'Group6': np.ones((10, 8)), 'Total': np.ones((10, 8))}, - {'Time': np.arange(10), 'Total': np.ones((10, 8)), 'Group1': np.ones((10, 8)), 'Group2': np.ones((10, 8)), 'Group3': np.ones((10, 8)), 'Group4': np.ones((10, 8)), 'Group5': np.ones((10, 8)), 'Group6': np.ones((10, 8)), 'Total': np.ones((10, 8))}, - {'Time': np.arange(10), 'Total': np.ones((10, 8)), 'Group1': np.ones((10, 8)), 'Group2': np.ones((10, 8)), 'Group3': np.ones((10, 8)), 'Group4': np.ones((10, 8)), 'Group5': np.ones((10, 8)), 'Group6': np.ones((10, 8)), 'Total': np.ones((10, 8))} + {'Time': np.arange(10), 'Total': np.ones((10, 8)), 'Group1': np.ones((10, 8)), 'Group2': np.ones((10, 8)), 'Group3': np.ones( + (10, 8)), 'Group4': np.ones((10, 8)), 'Group5': np.ones((10, 8)), 'Group6': np.ones((10, 8)), 'Total': np.ones((10, 8))}, + {'Time': np.arange(10), 'Total': np.ones((10, 8)), 'Group1': np.ones((10, 8)), 'Group2': np.ones((10, 8)), 'Group3': np.ones( + (10, 8)), 'Group4': np.ones((10, 8)), 'Group5': np.ones((10, 8)), 'Group6': np.ones((10, 8)), 'Total': np.ones((10, 8))}, + {'Time': np.arange(10), 'Total': np.ones((10, 8)), 'Group1': np.ones((10, 8)), 'Group2': np.ones((10, 8)), 'Group3': np.ones( + (10, 8)), 'Group4': np.ones((10, 8)), 'Group5': np.ones((10, 8)), 'Group6': np.ones((10, 8)), 'Total': np.ones((10, 8))} ] abm.plot_infection_states_results('dummy_path') assert mock_indiv.called @@ -78,7 +85,7 @@ def test_plot_infection_states(self, mock_matplotlib): y50 = np.ones((10, 8)) y25 = np.zeros((10, 8)) y75 = np.ones((10, 8))*2 - mock_matplotlib.colormaps.get_cmap.return_value.colors = [(1,0,0)]*8 + mock_matplotlib.colormaps.get_cmap.return_value.colors = [(1, 0, 0)]*8 # Patch plt.gca().plot and fill_between with patch.object(abm.plt, 'gca') as mock_gca: @@ -94,10 +101,13 @@ def test_plot_infection_states(self, mock_matplotlib): def test_plot_infection_states_individual(self, mock_matplotlib): x = np.arange(10) group_data = np.ones((10, 8)) - p50_bs = {g: group_data for g in ['Group1', 'Group2', 'Group3', 'Group4', 'Group5', 'Group6', 'Total']} - p25_bs = {g: group_data for g in ['Group1', 'Group2', 'Group3', 'Group4', 'Group5', 'Group6', 'Total']} - p75_bs = {g: group_data for g in ['Group1', 'Group2', 'Group3', 'Group4', 'Group5', 'Group6', 'Total']} - mock_matplotlib.colormaps.get_cmap.return_value.colors = [(1,0,0)]*8 + p50_bs = {g: group_data for g in [ + 'Group1', 'Group2', 'Group3', 'Group4', 'Group5', 'Group6', 'Total']} + p25_bs = {g: group_data for g in [ + 'Group1', 'Group2', 'Group3', 'Group4', 'Group5', 'Group6', 'Total']} + p75_bs = {g: group_data for g in [ + 'Group1', 'Group2', 'Group3', 'Group4', 'Group5', 'Group6', 'Total']} + mock_matplotlib.colormaps.get_cmap.return_value.colors = [(1, 0, 0)]*8 # Patch plt.subplots to return a grid of MagicMock axes (as np.array with dtype=object) with patch.object(abm.plt, 'subplots') as mock_subplots: @@ -109,7 +119,8 @@ def test_plot_infection_states_individual(self, mock_matplotlib): mock_subplots.return_value = (fig_mock, ax_mock) abm.plot_infection_states_individual(x, p50_bs, p25_bs, p75_bs) # Check that at least one ax's plot was called - assert any(ax_mock[i, j].plot.called for i in range(6) for j in range(7)) + assert any(ax_mock[i, j].plot.called for i in range(6) + for j in range(7)) assert fig_mock.suptitle.called def test__format_x_axis(self): @@ -118,5 +129,6 @@ def test__format_x_axis(self): assert mock_plt.gca.called assert mock_plt.gcf.called + if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() From 706aea060ca0519a05c6c62adbcaf384c793eae7 Mon Sep 17 00:00:00 2001 From: Sascha Korf <51127093+xsaschako@users.noreply.github.com> Date: Wed, 28 May 2025 20:47:21 +0200 Subject: [PATCH 10/11] Implement martins suggestions --- .../memilio/plot/plotAbmInfectionStates.py | 228 +++++++++++------- .../test_plot_plotAbmInfectionStates.py | 24 +- 2 files changed, 160 insertions(+), 92 deletions(-) diff --git a/pycode/memilio-plot/memilio/plot/plotAbmInfectionStates.py b/pycode/memilio-plot/memilio/plot/plotAbmInfectionStates.py index 00f2d749b3..0ff71347dd 100644 --- a/pycode/memilio-plot/memilio/plot/plotAbmInfectionStates.py +++ b/pycode/memilio-plot/memilio/plot/plotAbmInfectionStates.py @@ -99,6 +99,40 @@ # # The output of the loggers of several runs is stored in HDF5 files, with the memilio funciton mio::save_results in mio/io/result_io.h. +# Adjust these as needed. +# States and location type numbers need to match the infection states used in your simulation. +state_labels = { + 1: 'Exposed', + 2: 'I_Asymp', + 3: 'I_Symp', + 4: 'I_Severe', + 5: 'I_Critical', + 7: 'Dead' +} + +age_groups = ['Group1', 'Group2', 'Group3', 'Group4', + 'Group5', 'Group6', 'Total'] + +age_groups_dict = { + 'Group1': 'Ages 0-4', + 'Group2': 'Ages 5-14', + 'Group3': 'Ages 15-34', + 'Group4': 'Ages 35-59', + 'Group5': 'Ages 60-79', + 'Group6': 'Ages 80+', + 'Total': 'All Ages' +} + +location_type_labels = { + 0: 'Home', + 1: 'School', + 2: 'Work', + 3: 'SocialEvent', + 4: 'BasicsShop', + 5: 'Hospital', + 6: 'ICU' +} + def load_h5_results(base_path, percentile): """ Reads HDF5 results for a given group and percentile. @@ -138,12 +172,8 @@ def plot_infections_loc_types_average( plt.title( 'Infection per location type for the median run, rolling sum over 24 hours') color_plot = matplotlib.colormaps.get_cmap(colormap).colors - # If you define further location types, you need to adjust this list - states_plot = [0, 1, 2, 3, 4, 5, 6] - legend_plot = ['Home', 'School', 'Work', - 'SocialEvent', 'BasicsShop', 'Hospital', 'ICU'] - for idx, i in enumerate(states_plot): + for idx, i in enumerate(location_type_labels.keys()): color = color_plot[i % len(color_plot)] if i < len( color_plot) else "black" # Sum up every 24 hours, then smooth @@ -153,9 +183,9 @@ def plot_infections_loc_types_average( window=indexer, min_periods=1).sum().to_numpy() y = y[0::rolling_window].flatten() y = gaussian_filter1d(y, sigma=smooth_sigma, mode='nearest') - plt.plot(time[0::rolling_window], y, color=color) + plt.plot(time[0::rolling_window], y, color=color, linewidth=2.5) - plt.legend(legend_plot) + plt.legend(list(location_type_labels.values())) _format_x_axis(time, start_date, xtick_step) plt.xlabel('Date') plt.ylabel('Number of individuals') @@ -166,14 +196,10 @@ def plot_infection_states_results( path_to_infection_states, start_date='2021-03-01', colormap='Set1', - xtick_step=150): - """ Loads and plots infection state results. - - @param[in] base_path Path to results directory. - @param[in] start_date Start date as string. - @param[in] colormap Matplotlib colormap. - @param[in] xtick_step Step size for x-axis ticks. - """ + xtick_step=150, + show90=False +): + """ Loads and plots infection state results. """ # Load data p50 = load_h5_results(path_to_infection_states, "p50") p25 = load_h5_results(path_to_infection_states, "p25") @@ -182,109 +208,126 @@ def plot_infection_states_results( total_50 = p50['Total'] total_25 = p25['Total'] total_75 = p75['Total'] - - plot_infection_states_individual(time, p50, p25, p75, colormap) + p05 = p95 = None + total_05 = total_95 = None + if show90: + total_95 = load_h5_results(path_to_infection_states, "p95") + total_05 = load_h5_results(path_to_infection_states, "p05") + p95 = total_95['Total'] + p05 = total_05['Total'] + + plot_infection_states_individual( + time, p50, p25, p75, colormap, + p05_bs=total_05 if show90 else None, + p95_bs=total_95 if show90 else None, + show90=show90 + ) plot_infection_states(time, total_50, total_25, - total_75, start_date, colormap, xtick_step) + total_75, start_date, colormap, xtick_step, + y05=p05, y95=p95, show_90=show90) def plot_infection_states( x, y50, y25, y75, start_date='2021-03-01', colormap='Set1', - xtick_step=150): - """ Plots infection states with percentiles. + xtick_step=150, + y05=None, y95=None, show_90=False): + """ Plots infection states with percentiles and improved styling. """ + plt.figure('Infection_states') - @param[in] x Time array. - @param[in] y50 Median values. - @param[in] y25 25th percentile values. - @param[in] y75 75th percentile values. - @param[in] start_date Start date as string. - @param[in] colormap Matplotlib colormap. - @param[in] xtick_step Step size for x-axis ticks. - """ - plt.figure('Infection_states with 50% percentile') plt.title('Infection states with 50% percentile') + if show_90: + plt.title('Infection states with 50% and 90% percentiles') + color_plot = matplotlib.colormaps.get_cmap(colormap).colors - states_plot = [1, 2, 3, 4, 5, 7] - legend_plot = ['E', 'I_NSymp', 'I_Symp', 'I_Sev', 'I_Crit', 'Dead'] - for i in states_plot: - plt.plot(x, y50[:, i], color=color_plot[i]) - # Needs to be done here, otherwise the percentage fill_between will not work correctly - plt.legend(legend_plot) + states_plot = list(state_labels.keys()) for i in states_plot: + plt.plot(x, y50[:, i], color=color_plot[i], + linewidth=2.5, label=state_labels[i]) + # needs to be after the plot calls + plt.legend([state_labels[i] for i in states_plot]) + for i in states_plot: + plt.plot(x, y25[:, i], color=color_plot[i], + linestyle='dashdot', linewidth=1.2, alpha=0.7) + plt.plot(x, y75[:, i], color=color_plot[i], + linestyle='dashdot', linewidth=1.2, alpha=0.7) plt.fill_between(x, y50[:, i], y25[:, i], - alpha=0.5, color=color_plot[i]) + alpha=0.2, color=color_plot[i]) plt.fill_between(x, y50[:, i], y75[:, i], - alpha=0.5, color=color_plot[i]) + alpha=0.2, color=color_plot[i]) + # Optional: 90% percentile + if show_90 and y05 is not None and y95 is not None: + plt.plot(x, y05[:, i], color=color_plot[i], + linestyle='dashdot', linewidth=1.0, alpha=0.4) + plt.plot(x, y95[:, i], color=color_plot[i], + linestyle='dashdot', linewidth=1.0, alpha=0.4) + plt.fill_between(x, y05[:, i], y95[:, i], + # More transparent + alpha=0.25, color=color_plot[i]) - plt.legend(legend_plot) _format_x_axis(x, start_date, xtick_step) - plt.xlabel('Time') + plt.xlabel('Date') plt.ylabel('Number of individuals') plt.show() -def plot_infection_states_individual(x, p50_bs, p25_bs, p75_bs, colormap='Set1'): - """ Plots infection states for each age group. +def plot_infection_states_individual( + x, p50_bs, p25_bs, p75_bs, colormap='Set1', + p05_bs=None, p95_bs=None, show90=False +): + """ Plots infection states for each age group, with optional 90% percentile. """ - @param[in] x Time array. - @param[in] p50_bs Median values by group. - @param[in] p25_bs 25th percentile values by group. - @param[in] p75_bs 75th percentile values by group. - @param[in] colormap Matplotlib colormap. - """ - age_groups = ['Group1', 'Group2', 'Group3', 'Group4', - 'Group5', 'Group6', 'Total'] # Adjust as needed color_plot = matplotlib.colormaps.get_cmap(colormap).colors + n_states = len(state_labels) fig, ax = plt.subplots( - 6, len(age_groups), constrained_layout=True, figsize=(20, 9)) + n_states, len(age_groups), constrained_layout=True, figsize=(20, 3 * n_states)) for col_idx, group in enumerate(age_groups): y50 = p50_bs[group] y25 = p25_bs[group] y75 = p75_bs[group] - # Infected no symptoms - _plot_state(ax[0, col_idx], x, y50[:, 1], y25[:, 1], y75[:, 1], - color_plot[col_idx], '#Infected_no_symptoms, Age' + str(group)) - # Infected symptoms - _plot_state(ax[1, col_idx], x, y50[:, 2], y25[:, 2], y75[:, 2], - color_plot[col_idx], '#Infected_symptoms, Age' + str(group)) - # Severe - _plot_state(ax[2, col_idx], x, y50[:, 4], y25[:, 4], y75[:, 4], - color_plot[col_idx], '#Severe, Age' + str(group)) - # Critical - _plot_state(ax[3, col_idx], x, y50[:, 5], y25[:, 5], y75[:, 5], - color_plot[col_idx], '#Critical, Age' + str(group)) - # Dead - _plot_state(ax[4, col_idx], x, y50[:, 7], y25[:, 7], - y75[:, 7], color_plot[col_idx], '#Dead, Age' + str(group)) - # Recovered - _plot_state(ax[5, col_idx], x, y50[:, 6], y25[:, 6], y75[:, 6], - color_plot[col_idx], '#Recovered, Age' + str(group)) - + y05 = p05_bs[group] if (show90 and p05_bs is not None) else None + y95 = p95_bs[group] if (show90 and p95_bs is not None) else None + for row_idx, (state_idx, label) in enumerate(state_labels.items()): + _plot_state( + ax[row_idx, col_idx], x, y50[:, state_idx], y25[:, + state_idx], y75[:, state_idx], + color_plot[col_idx], f'#{label}, {age_groups_dict[group]}', + y05=y05[:, state_idx] if y05 is not None else None, + y95=y95[:, state_idx] if y95 is not None else None, + show90=show90 + ) + # The legend should say: solid line = median, dashed line = 25% and 75% perc. and if show90 is True, dotted line = 5%, 25%, 75%, 95% perc. + perc_string = '25/75%' if not show90 else '5/25/75/95%' + ax[row_idx, col_idx].legend( + ['Median', f'{perc_string} perc.'], + loc='upper left', fontsize=8) + + string_short = ' and 90%' if show90 else '' fig.suptitle( - 'Infection states per age group with 50% percentile', fontsize=16) - - # We hide the Legend for the individual plots as it is too cluttered - for ax_row in ax: - for ax_col in ax_row: - ax_col.legend().set_visible(False) + 'Infection states per age group with 50' + string_short + ' percentile', + fontsize=16) plt.show() -def _plot_state(ax, x, y50, y25, y75, color, title): - """ Helper to plot a single state with fill_between. """ +def _plot_state(ax, x, y50, y25, y75, color, title, y05=None, y95=None, show90=False): + """ Helper to plot a single state with fill_between and optional 90% percentile. """ ax.set_xlabel('time (days)') ax.plot(x, y50, color=color, label='Median') ax.fill_between(x, y50, y25, alpha=0.5, color=color) ax.fill_between(x, y50, y75, alpha=0.5, color=color) + if show90 and y05 is not None and y95 is not None: + ax.plot(x, y05, color=color, linestyle='dotted', + linewidth=1.0, alpha=0.4) + ax.plot(x, y95, color=color, linestyle='dotted', + linewidth=1.0, alpha=0.4) + ax.fill_between(x, y05, y95, alpha=0.15, color=color) ax.tick_params(axis='y') ax.set_title(title) - ax.legend(['Simulation']) def _format_x_axis(x, start_date, xtick_step): @@ -311,20 +354,25 @@ def main(): default='Set1', help="Matplotlib colormap") parser.add_argument("--xtick-step", type=int, default=150, help="Step for x-axis ticks") + parser.add_argument("--90percentile", action="store_true", + help="If set, plot 90% percentile as well") args = parser.parse_args() - if args.path_to_infection_states: - plot_infection_states_results( - args.path_to_infection_states, - start_date=args.start_date, - colormap=args.colormap, - xtick_step=args.xtick_step) - if args.path_to_loc_types: - plot_infections_loc_types_average( - args.path_to_loc_types, - start_date=args.start_date, - colormap=args.colormap, - xtick_step=args.xtick_step) + path_to_infection_states = "/Users/saschakorf/Nosynch/Arbeit/memilio/memilio/data/cluster_results/final_results/results_2024-09-20192904_best/infection_per_location_type_per_age_group/0" + path_to_loc_types = "/Users/saschakorf/Nosynch/Arbeit/memilio/memilio/data/cluster_results/final_results/results_2024-09-20192904_best/infection_state_per_age_group/0" + + plot_infection_states_results( + path_to_loc_types, + start_date=args.start_date, + colormap=args.colormap, + xtick_step=args.xtick_step, + show90=True + ) + plot_infections_loc_types_average( + path_to_infection_states, + start_date=args.start_date, + colormap=args.colormap, + xtick_step=args.xtick_step) if not args.path_to_infection_states and not args.path_to_loc_types: print("Please provide a path to infection states or location types results.") diff --git a/pycode/memilio-plot/memilio/plot_test/test_plot_plotAbmInfectionStates.py b/pycode/memilio-plot/memilio/plot_test/test_plot_plotAbmInfectionStates.py index bb04cc54a9..fbdba48535 100644 --- a/pycode/memilio-plot/memilio/plot_test/test_plot_plotAbmInfectionStates.py +++ b/pycode/memilio-plot/memilio/plot_test/test_plot_plotAbmInfectionStates.py @@ -85,13 +85,23 @@ def test_plot_infection_states(self, mock_matplotlib): y50 = np.ones((10, 8)) y25 = np.zeros((10, 8)) y75 = np.ones((10, 8))*2 + y05 = np.ones((10, 8))*-1 + y95 = np.ones((10, 8))*3 mock_matplotlib.colormaps.get_cmap.return_value.colors = [(1, 0, 0)]*8 # Patch plt.gca().plot and fill_between with patch.object(abm.plt, 'gca') as mock_gca: mock_ax = MagicMock() mock_gca.return_value = mock_ax - abm.plot_infection_states(x, y50, y25, y75) + abm.plot_infection_states( + x, y50, y25, y75, + start_date='2021-03-01', + colormap='Set1', + xtick_step=2, + y05=y05, + y95=y95, + show_90=True + ) assert mock_ax.plot.called assert mock_ax.fill_between.called assert mock_ax.set_xticks.called @@ -107,6 +117,10 @@ def test_plot_infection_states_individual(self, mock_matplotlib): 'Group1', 'Group2', 'Group3', 'Group4', 'Group5', 'Group6', 'Total']} p75_bs = {g: group_data for g in [ 'Group1', 'Group2', 'Group3', 'Group4', 'Group5', 'Group6', 'Total']} + p05_bs = {g: group_data*-1 for g in [ + 'Group1', 'Group2', 'Group3', 'Group4', 'Group5', 'Group6', 'Total']} + p95_bs = {g: group_data*3 for g in [ + 'Group1', 'Group2', 'Group3', 'Group4', 'Group5', 'Group6', 'Total']} mock_matplotlib.colormaps.get_cmap.return_value.colors = [(1, 0, 0)]*8 # Patch plt.subplots to return a grid of MagicMock axes (as np.array with dtype=object) @@ -117,7 +131,13 @@ def test_plot_infection_states_individual(self, mock_matplotlib): for j in range(7): ax_mock[i, j] = MagicMock() mock_subplots.return_value = (fig_mock, ax_mock) - abm.plot_infection_states_individual(x, p50_bs, p25_bs, p75_bs) + abm.plot_infection_states_individual( + x, p50_bs, p25_bs, p75_bs, + colormap='Set1', + p05_bs=p05_bs, + p95_bs=p95_bs, + show90=True + ) # Check that at least one ax's plot was called assert any(ax_mock[i, j].plot.called for i in range(6) for j in range(7)) From ab2e0e8a572f7c4517f35031d3db4217ecfa1162 Mon Sep 17 00:00:00 2001 From: Sascha Korf <51127093+xsaschako@users.noreply.github.com> Date: Wed, 28 May 2025 20:48:23 +0200 Subject: [PATCH 11/11] Refactor plotAbmInfectionStates.py: update paths to use command line arguments for better flexibility --- .../memilio-plot/memilio/plot/plotAbmInfectionStates.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/pycode/memilio-plot/memilio/plot/plotAbmInfectionStates.py b/pycode/memilio-plot/memilio/plot/plotAbmInfectionStates.py index 0ff71347dd..cf2c0414d4 100644 --- a/pycode/memilio-plot/memilio/plot/plotAbmInfectionStates.py +++ b/pycode/memilio-plot/memilio/plot/plotAbmInfectionStates.py @@ -357,19 +357,15 @@ def main(): parser.add_argument("--90percentile", action="store_true", help="If set, plot 90% percentile as well") args = parser.parse_args() - - path_to_infection_states = "/Users/saschakorf/Nosynch/Arbeit/memilio/memilio/data/cluster_results/final_results/results_2024-09-20192904_best/infection_per_location_type_per_age_group/0" - path_to_loc_types = "/Users/saschakorf/Nosynch/Arbeit/memilio/memilio/data/cluster_results/final_results/results_2024-09-20192904_best/infection_state_per_age_group/0" - plot_infection_states_results( - path_to_loc_types, + args.path_to_infection_states, start_date=args.start_date, colormap=args.colormap, xtick_step=args.xtick_step, show90=True ) plot_infections_loc_types_average( - path_to_infection_states, + args.path_to_loc_types, start_date=args.start_date, colormap=args.colormap, xtick_step=args.xtick_step)