diff --git a/pycode/memilio-plot/memilio/plot/plotAbmInfectionStates.py b/pycode/memilio-plot/memilio/plot/plotAbmInfectionStates.py new file mode 100644 index 0000000000..cf2c0414d4 --- /dev/null +++ b/pycode/memilio-plot/memilio/plot/plotAbmInfectionStates.py @@ -0,0 +1,380 @@ +############################################################################# +# Copyright (C) 2020-2024 MEmilio +# +# Authors: Sascha Korf +# +# Contact: Martin J. Kuehn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################# + +import sys +import argparse +import os +import pandas as pd +import numpy as np +import matplotlib.pyplot as plt +import matplotlib +import h5py +from datetime import datetime +from scipy.ndimage import gaussian_filter1d + + +# Module for plotting infection states and location types from ABM results. +# This module provides functions to load and visualize infection states and +# location types from simulation results stored in HDF5 format and are output +# by the MEmilio agent-based model (ABM). +# The used Loggers are: +# struct LogInfectionStatePerAgeGroup : mio::LogAlways { +# using Type = std::pair; +# /** +# * @brief Log the TimeSeries of the number of Person%s in an #InfectionState. +# * @param[in] sim The simulation of the abm. +# * @return A pair of the TimePoint and the TimeSeries of the number of Person%s in an #InfectionState. +# */ +# static Type log(const mio::abm::Simulation& sim) +# { +# +# Eigen::VectorXd sum = Eigen::VectorXd::Zero( +# Eigen::Index((size_t)mio::abm::InfectionState::Count * sim.get_world().parameters.get_num_groups())); +# const auto curr_time = sim.get_time(); +# const auto persons = sim.get_world().get_persons(); +# +# // PRAGMA_OMP(parallel for) +# for (auto i = size_t(0); i < persons.size(); ++i) { +# auto& p = persons[i]; +# if (p.get_should_be_logged()) { +# auto index = (((size_t)(mio::abm::InfectionState::Count)) * ((uint32_t)p.get_age().get())) + +# ((uint32_t)p.get_infection_state(curr_time)); +# // PRAGMA_OMP(atomic) +# sum[index] += 1; +# } +# } +# return std::make_pair(curr_time, sum); +# } +# }; +# +# struct LogInfectionPerLocationTypePerAgeGroup : mio::LogAlways { +# using Type = std::pair; +# /** +# * @brief Log the TimeSeries of the number of Person%s in an #InfectionState. +# * @param[in] sim The simulation of the abm. +# * @return A pair of the TimePoint and the TimeSeries of the number of Person%s in an #InfectionState. +# */ +# static Type log(const mio::abm::Simulation& sim) +# { +# +# Eigen::VectorXd sum = Eigen::VectorXd::Zero( +# Eigen::Index((size_t)mio::abm::LocationType::Count * sim.get_world().parameters.get_num_groups())); +# auto curr_time = sim.get_time(); +# auto prev_time = sim.get_prev_time(); +# const auto persons = sim.get_world().get_persons(); +# +# // PRAGMA_OMP(parallel for) +# for (auto i = size_t(0); i < persons.size(); ++i) { +# auto& p = persons[i]; +# if (p.get_should_be_logged()) { +# // PRAGMA_OMP(atomic) +# if ((p.get_infection_state(prev_time) != mio::abm::InfectionState::Exposed) && +# (p.get_infection_state(curr_time) == mio::abm::InfectionState::Exposed)) { +# auto index = (((size_t)(mio::abm::LocationType::Count)) * ((uint32_t)p.get_age().get())) + +# ((uint32_t)p.get_location().get_type()); +# sum[index] += 1; +# } +# } +# } +# return std::make_pair(curr_time, sum); +# } +# }; +# +# The output of the loggers of several runs is stored in HDF5 files, with the memilio funciton mio::save_results in mio/io/result_io.h. + +# Adjust these as needed. +# States and location type numbers need to match the infection states used in your simulation. +state_labels = { + 1: 'Exposed', + 2: 'I_Asymp', + 3: 'I_Symp', + 4: 'I_Severe', + 5: 'I_Critical', + 7: 'Dead' +} + +age_groups = ['Group1', 'Group2', 'Group3', 'Group4', + 'Group5', 'Group6', 'Total'] + +age_groups_dict = { + 'Group1': 'Ages 0-4', + 'Group2': 'Ages 5-14', + 'Group3': 'Ages 15-34', + 'Group4': 'Ages 35-59', + 'Group5': 'Ages 60-79', + 'Group6': 'Ages 80+', + 'Total': 'All Ages' +} + +location_type_labels = { + 0: 'Home', + 1: 'School', + 2: 'Work', + 3: 'SocialEvent', + 4: 'BasicsShop', + 5: 'Hospital', + 6: 'ICU' +} + + +def load_h5_results(base_path, percentile): + """ Reads HDF5 results for a given group and percentile. + + @param[in] base_path Path to results directory. + @param[in] percentile Subdirectory for percentile (e.g. 'p50'). + @return Dictionary with data arrays. + """ + file_path = os.path.join(base_path, percentile, "Results.h5") + with h5py.File(file_path, 'r') as f: + data = {k: v[()] for k, v in f['0'].items()} + return data + + +def plot_infections_loc_types_average( + path_to_loc_types, + start_date='2021-03-01', + colormap='Set1', + smooth_sigma=1, + rolling_window=24, + xtick_step=150): + """ Plots rolling average infections per location type for the median run. + + @param[in] base_path Path to results directory. + @param[in] start_date Start date as string. + @param[in] colormap Matplotlib colormap. + @param[in] smooth_sigma Sigma for Gaussian smoothing. + @param[in] rolling_window Window size for rolling sum. + @param[in] xtick_step Step size for x-axis ticks. + """ + # Load data + p50 = load_h5_results(path_to_loc_types, "p50") + time = p50['Time'] + total_50 = p50['Total'] + + plt.figure('Infection_location_types') + plt.title( + 'Infection per location type for the median run, rolling sum over 24 hours') + color_plot = matplotlib.colormaps.get_cmap(colormap).colors + + for idx, i in enumerate(location_type_labels.keys()): + color = color_plot[i % len(color_plot)] if i < len( + color_plot) else "black" + # Sum up every 24 hours, then smooth + indexer = pd.api.indexers.FixedForwardWindowIndexer( + window_size=rolling_window) + y = pd.DataFrame(total_50[:, i]).rolling( + window=indexer, min_periods=1).sum().to_numpy() + y = y[0::rolling_window].flatten() + y = gaussian_filter1d(y, sigma=smooth_sigma, mode='nearest') + plt.plot(time[0::rolling_window], y, color=color, linewidth=2.5) + + plt.legend(list(location_type_labels.values())) + _format_x_axis(time, start_date, xtick_step) + plt.xlabel('Date') + plt.ylabel('Number of individuals') + plt.show() + + +def plot_infection_states_results( + path_to_infection_states, + start_date='2021-03-01', + colormap='Set1', + xtick_step=150, + show90=False +): + """ Loads and plots infection state results. """ + # Load data + p50 = load_h5_results(path_to_infection_states, "p50") + p25 = load_h5_results(path_to_infection_states, "p25") + p75 = load_h5_results(path_to_infection_states, "p75") + time = p50['Time'] + total_50 = p50['Total'] + total_25 = p25['Total'] + total_75 = p75['Total'] + p05 = p95 = None + total_05 = total_95 = None + if show90: + total_95 = load_h5_results(path_to_infection_states, "p95") + total_05 = load_h5_results(path_to_infection_states, "p05") + p95 = total_95['Total'] + p05 = total_05['Total'] + + plot_infection_states_individual( + time, p50, p25, p75, colormap, + p05_bs=total_05 if show90 else None, + p95_bs=total_95 if show90 else None, + show90=show90 + ) + plot_infection_states(time, total_50, total_25, + total_75, start_date, colormap, xtick_step, + y05=p05, y95=p95, show_90=show90) + + +def plot_infection_states( + x, y50, y25, y75, + start_date='2021-03-01', + colormap='Set1', + xtick_step=150, + y05=None, y95=None, show_90=False): + """ Plots infection states with percentiles and improved styling. """ + plt.figure('Infection_states') + + plt.title('Infection states with 50% percentile') + if show_90: + plt.title('Infection states with 50% and 90% percentiles') + + color_plot = matplotlib.colormaps.get_cmap(colormap).colors + + states_plot = list(state_labels.keys()) + + for i in states_plot: + plt.plot(x, y50[:, i], color=color_plot[i], + linewidth=2.5, label=state_labels[i]) + # needs to be after the plot calls + plt.legend([state_labels[i] for i in states_plot]) + for i in states_plot: + plt.plot(x, y25[:, i], color=color_plot[i], + linestyle='dashdot', linewidth=1.2, alpha=0.7) + plt.plot(x, y75[:, i], color=color_plot[i], + linestyle='dashdot', linewidth=1.2, alpha=0.7) + plt.fill_between(x, y50[:, i], y25[:, i], + alpha=0.2, color=color_plot[i]) + plt.fill_between(x, y50[:, i], y75[:, i], + alpha=0.2, color=color_plot[i]) + # Optional: 90% percentile + if show_90 and y05 is not None and y95 is not None: + plt.plot(x, y05[:, i], color=color_plot[i], + linestyle='dashdot', linewidth=1.0, alpha=0.4) + plt.plot(x, y95[:, i], color=color_plot[i], + linestyle='dashdot', linewidth=1.0, alpha=0.4) + plt.fill_between(x, y05[:, i], y95[:, i], + # More transparent + alpha=0.25, color=color_plot[i]) + + _format_x_axis(x, start_date, xtick_step) + plt.xlabel('Date') + plt.ylabel('Number of individuals') + plt.show() + + +def plot_infection_states_individual( + x, p50_bs, p25_bs, p75_bs, colormap='Set1', + p05_bs=None, p95_bs=None, show90=False +): + """ Plots infection states for each age group, with optional 90% percentile. """ + + color_plot = matplotlib.colormaps.get_cmap(colormap).colors + n_states = len(state_labels) + fig, ax = plt.subplots( + n_states, len(age_groups), constrained_layout=True, figsize=(20, 3 * n_states)) + + for col_idx, group in enumerate(age_groups): + y50 = p50_bs[group] + y25 = p25_bs[group] + y75 = p75_bs[group] + y05 = p05_bs[group] if (show90 and p05_bs is not None) else None + y95 = p95_bs[group] if (show90 and p95_bs is not None) else None + for row_idx, (state_idx, label) in enumerate(state_labels.items()): + _plot_state( + ax[row_idx, col_idx], x, y50[:, state_idx], y25[:, + state_idx], y75[:, state_idx], + color_plot[col_idx], f'#{label}, {age_groups_dict[group]}', + y05=y05[:, state_idx] if y05 is not None else None, + y95=y95[:, state_idx] if y95 is not None else None, + show90=show90 + ) + # The legend should say: solid line = median, dashed line = 25% and 75% perc. and if show90 is True, dotted line = 5%, 25%, 75%, 95% perc. + perc_string = '25/75%' if not show90 else '5/25/75/95%' + ax[row_idx, col_idx].legend( + ['Median', f'{perc_string} perc.'], + loc='upper left', fontsize=8) + + string_short = ' and 90%' if show90 else '' + fig.suptitle( + 'Infection states per age group with 50' + string_short + ' percentile', + fontsize=16) + + plt.show() + + +def _plot_state(ax, x, y50, y25, y75, color, title, y05=None, y95=None, show90=False): + """ Helper to plot a single state with fill_between and optional 90% percentile. """ + ax.set_xlabel('time (days)') + ax.plot(x, y50, color=color, label='Median') + ax.fill_between(x, y50, y25, alpha=0.5, color=color) + ax.fill_between(x, y50, y75, alpha=0.5, color=color) + if show90 and y05 is not None and y95 is not None: + ax.plot(x, y05, color=color, linestyle='dotted', + linewidth=1.0, alpha=0.4) + ax.plot(x, y95, color=color, linestyle='dotted', + linewidth=1.0, alpha=0.4) + ax.fill_between(x, y05, y95, alpha=0.15, color=color) + ax.tick_params(axis='y') + ax.set_title(title) + + +def _format_x_axis(x, start_date, xtick_step): + """ Helper to format x-axis as dates. """ + start = datetime.strptime(start_date, '%Y-%m-%d') + xx = [start + pd.Timedelta(days=int(i)) for i in x] + xx_str = [dt.strftime('%Y-%m-%d') for dt in xx] + plt.gca().set_xticks(x[::xtick_step]) + plt.gca().set_xticklabels(xx_str[::xtick_step]) + plt.gcf().autofmt_xdate() + + +def main(): + """ Main function for CLI usage. """ + parser = argparse.ArgumentParser( + description="Plot infection state and location type results.") + parser.add_argument("--path-to-infection-states", + help="Path to infection states results") + parser.add_argument("--path-to-loc-types", + help="Path to location types results") + parser.add_argument("--start-date", type=str, default='2021-03-01', + help="Simulation start date (YYYY-MM-DD)") + parser.add_argument("--colormap", type=str, + default='Set1', help="Matplotlib colormap") + parser.add_argument("--xtick-step", type=int, + default=150, help="Step for x-axis ticks") + parser.add_argument("--90percentile", action="store_true", + help="If set, plot 90% percentile as well") + args = parser.parse_args() + plot_infection_states_results( + args.path_to_infection_states, + start_date=args.start_date, + colormap=args.colormap, + xtick_step=args.xtick_step, + show90=True + ) + plot_infections_loc_types_average( + args.path_to_loc_types, + start_date=args.start_date, + colormap=args.colormap, + xtick_step=args.xtick_step) + + if not args.path_to_infection_states and not args.path_to_loc_types: + print("Please provide a path to infection states or location types results.") + sys.exit(1) + plt.show() + + +if __name__ == "__main__": + main() diff --git a/pycode/memilio-plot/memilio/plot_test/test_plot_plotAbmInfectionStates.py b/pycode/memilio-plot/memilio/plot_test/test_plot_plotAbmInfectionStates.py new file mode 100644 index 0000000000..fbdba48535 --- /dev/null +++ b/pycode/memilio-plot/memilio/plot_test/test_plot_plotAbmInfectionStates.py @@ -0,0 +1,154 @@ +############################################################################# +# Copyright (C) 2020-2024 MEmilio +# +# Authors: Sascha Korf +# +# Contact: Martin J. Kuehn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################# + +import unittest +from unittest.mock import patch, MagicMock +import numpy as np +import pandas as pd + +import memilio.plot.plotAbmInfectionStates as abm + + +class TestPlotAbmInfectionStates(unittest.TestCase): + + @patch('memilio.plot.plotAbmInfectionStates.h5py.File') + def test_load_h5_results(self, mock_h5file): + mock_group = {'Time': np.arange(10), 'Total': np.ones((10, 8))} + mock_h5file().__enter__().get.return_value = {'0': mock_group} + mock_h5file().__enter__().items.return_value = [('0', mock_group)] + mock_h5file().__enter__().__getitem__.return_value = mock_group + mock_h5file().__enter__().items.return_value = [ + ('Time', np.arange(10)), ('Total', np.ones((10, 8)))] + with patch('memilio.plot.plotAbmInfectionStates.h5py.File', mock_h5file): + result = abm.load_h5_results('dummy_path', 'p50') + assert 'Time' in result + assert 'Total' in result + np.testing.assert_array_equal(result['Time'], np.arange(10)) + np.testing.assert_array_equal(result['Total'], np.ones((10, 8))) + + @patch('memilio.plot.plotAbmInfectionStates.load_h5_results') + @patch('memilio.plot.plotAbmInfectionStates.matplotlib') + @patch('memilio.plot.plotAbmInfectionStates.gaussian_filter1d', side_effect=lambda x, sigma, mode: x) + @patch('memilio.plot.plotAbmInfectionStates.pd.DataFrame') + def test_plot_infections_loc_types_average(self, mock_df, mock_gauss, mock_matplotlib, mock_load): + mock_load.return_value = { + 'Time': np.arange(48), 'Total': np.ones((48, 7))} + mock_df.return_value.rolling.return_value.sum.return_value.to_numpy.return_value = np.ones( + (48, 1)) + mock_matplotlib.colormaps.get_cmap.return_value.colors = [(1, 0, 0)]*7 + + # Patch plt.gca().plot to a MagicMock + with patch.object(abm.plt, 'gca') as mock_gca: + mock_ax = MagicMock() + mock_gca.return_value = mock_ax + abm.plot_infections_loc_types_average('dummy_path') + assert mock_ax.plot.called + assert mock_ax.set_xticks.called + assert mock_ax.set_xticklabels.called + + @patch('memilio.plot.plotAbmInfectionStates.load_h5_results') + @patch('memilio.plot.plotAbmInfectionStates.plot_infection_states') + @patch('memilio.plot.plotAbmInfectionStates.plot_infection_states_individual') + def test_plot_infection_states_results(self, mock_indiv, mock_states, mock_load): + mock_load.side_effect = [ + {'Time': np.arange(10), 'Total': np.ones((10, 8)), 'Group1': np.ones((10, 8)), 'Group2': np.ones((10, 8)), 'Group3': np.ones( + (10, 8)), 'Group4': np.ones((10, 8)), 'Group5': np.ones((10, 8)), 'Group6': np.ones((10, 8)), 'Total': np.ones((10, 8))}, + {'Time': np.arange(10), 'Total': np.ones((10, 8)), 'Group1': np.ones((10, 8)), 'Group2': np.ones((10, 8)), 'Group3': np.ones( + (10, 8)), 'Group4': np.ones((10, 8)), 'Group5': np.ones((10, 8)), 'Group6': np.ones((10, 8)), 'Total': np.ones((10, 8))}, + {'Time': np.arange(10), 'Total': np.ones((10, 8)), 'Group1': np.ones((10, 8)), 'Group2': np.ones((10, 8)), 'Group3': np.ones( + (10, 8)), 'Group4': np.ones((10, 8)), 'Group5': np.ones((10, 8)), 'Group6': np.ones((10, 8)), 'Total': np.ones((10, 8))} + ] + abm.plot_infection_states_results('dummy_path') + assert mock_indiv.called + assert mock_states.called + + @patch('memilio.plot.plotAbmInfectionStates.matplotlib') + def test_plot_infection_states(self, mock_matplotlib): + x = np.arange(10) + y50 = np.ones((10, 8)) + y25 = np.zeros((10, 8)) + y75 = np.ones((10, 8))*2 + y05 = np.ones((10, 8))*-1 + y95 = np.ones((10, 8))*3 + mock_matplotlib.colormaps.get_cmap.return_value.colors = [(1, 0, 0)]*8 + + # Patch plt.gca().plot and fill_between + with patch.object(abm.plt, 'gca') as mock_gca: + mock_ax = MagicMock() + mock_gca.return_value = mock_ax + abm.plot_infection_states( + x, y50, y25, y75, + start_date='2021-03-01', + colormap='Set1', + xtick_step=2, + y05=y05, + y95=y95, + show_90=True + ) + assert mock_ax.plot.called + assert mock_ax.fill_between.called + assert mock_ax.set_xticks.called + assert mock_ax.set_xticklabels.called + + @patch('memilio.plot.plotAbmInfectionStates.matplotlib') + def test_plot_infection_states_individual(self, mock_matplotlib): + x = np.arange(10) + group_data = np.ones((10, 8)) + p50_bs = {g: group_data for g in [ + 'Group1', 'Group2', 'Group3', 'Group4', 'Group5', 'Group6', 'Total']} + p25_bs = {g: group_data for g in [ + 'Group1', 'Group2', 'Group3', 'Group4', 'Group5', 'Group6', 'Total']} + p75_bs = {g: group_data for g in [ + 'Group1', 'Group2', 'Group3', 'Group4', 'Group5', 'Group6', 'Total']} + p05_bs = {g: group_data*-1 for g in [ + 'Group1', 'Group2', 'Group3', 'Group4', 'Group5', 'Group6', 'Total']} + p95_bs = {g: group_data*3 for g in [ + 'Group1', 'Group2', 'Group3', 'Group4', 'Group5', 'Group6', 'Total']} + mock_matplotlib.colormaps.get_cmap.return_value.colors = [(1, 0, 0)]*8 + + # Patch plt.subplots to return a grid of MagicMock axes (as np.array with dtype=object) + with patch.object(abm.plt, 'subplots') as mock_subplots: + fig_mock = MagicMock() + ax_mock = np.empty((6, 7), dtype=object) + for i in range(6): + for j in range(7): + ax_mock[i, j] = MagicMock() + mock_subplots.return_value = (fig_mock, ax_mock) + abm.plot_infection_states_individual( + x, p50_bs, p25_bs, p75_bs, + colormap='Set1', + p05_bs=p05_bs, + p95_bs=p95_bs, + show90=True + ) + # Check that at least one ax's plot was called + assert any(ax_mock[i, j].plot.called for i in range(6) + for j in range(7)) + assert fig_mock.suptitle.called + + def test__format_x_axis(self): + with patch('memilio.plot.plotAbmInfectionStates.plt') as mock_plt: + abm._format_x_axis(np.arange(10), '2021-03-01', 2) + assert mock_plt.gca.called + assert mock_plt.gcf.called + + +if __name__ == '__main__': + unittest.main() diff --git a/pycode/memilio-plot/setup.py b/pycode/memilio-plot/setup.py index 9c6a85908a..130996ac6a 100644 --- a/pycode/memilio-plot/setup.py +++ b/pycode/memilio-plot/setup.py @@ -8,12 +8,13 @@ class PylintCommand(Command): - """Custom command to run pylint and get a report as html.""" + """ + Custom command to run pylint and get a report as html. + """ description = "Runs pylint and outputs the report as html." user_options = [] def initialize_options(self): - """ """ from pylint.reporters.json_reporter import JSONReporter from pylint.reporters.text import ParseableTextReporter, TextReporter from pylint_json2html import JsonExtendedReporter @@ -29,12 +30,10 @@ def initialize_options(self): } def finalize_options(self): - """ """ self.reporter, self.out_file = self.REPORTERS.get( self.out_format) # , self.REPORTERS.get("parseable")) def run(self): - """ """ os.makedirs("build_pylint", exist_ok=True) # Run pylint @@ -74,6 +73,7 @@ def run(self): 'pyxlsb', 'wget', 'folium', + 'scipy.ndimage', 'matplotlib', 'mapclassify', 'geopandas',