From 773bc26fca81550e5124ef3adbe5cfad78367365 Mon Sep 17 00:00:00 2001 From: Thomas Baumann <39156931+brownbaerchen@users.noreply.github.com> Date: Fri, 14 Feb 2025 14:32:07 +0000 Subject: [PATCH 1/9] Added serial implementation for ParaDiag for collocation methods --- docs/source/index.rst | 1 + docs/source/tutorial/doc_step_9_A.rst | 7 + docs/source/tutorial/doc_step_9_B.rst | 7 + docs/source/tutorial/doc_step_9_C.rst | 7 + docs/source/tutorial/step_9.rst | 1 + pySDC/core/controller.py | 67 +++ pySDC/core/level.py | 1 + pySDC/helpers/ParaDiagHelper.py | 131 +++++ .../controller_ParaDiag_nonMPI.py | 482 ++++++++++++++++++ pySDC/implementations/hooks/log_timings.py | 3 +- .../problem_classes/TestEquation_0D.py | 117 ++++- .../sweeper_classes/ParaDiagSweepers.py | 158 ++++++ .../test_controller_ParaDiag_nonMPI.py | 230 +++++++++ .../tests/test_helpers/test_ParaDiagHelper.py | 20 + .../test_problems/test_Dahlquist_IMEX.py | 32 ++ .../test_sweepers/test_ParaDiag_sweepers.py | 103 ++++ pySDC/tests/test_tutorials/test_step_9.py | 20 + .../step_9/A_paradiag_for_linear_problems.py | 277 ++++++++++ .../B_paradiag_for_nonlinear_problems.py | 205 ++++++++ pySDC/tutorial/step_9/C_paradiag_in_pySDC.py | 182 +++++++ 20 files changed, 2049 insertions(+), 2 deletions(-) create mode 100644 docs/source/tutorial/doc_step_9_A.rst create mode 100644 docs/source/tutorial/doc_step_9_B.rst create mode 100644 docs/source/tutorial/doc_step_9_C.rst create mode 100644 docs/source/tutorial/step_9.rst create mode 100644 pySDC/helpers/ParaDiagHelper.py create mode 100644 pySDC/implementations/controller_classes/controller_ParaDiag_nonMPI.py create mode 100644 pySDC/implementations/sweeper_classes/ParaDiagSweepers.py create mode 100644 pySDC/tests/test_controllers/test_controller_ParaDiag_nonMPI.py create mode 100644 pySDC/tests/test_helpers/test_ParaDiagHelper.py create mode 100644 pySDC/tests/test_problems/test_Dahlquist_IMEX.py create mode 100644 pySDC/tests/test_sweepers/test_ParaDiag_sweepers.py create mode 100644 pySDC/tests/test_tutorials/test_step_9.py create mode 100644 pySDC/tutorial/step_9/A_paradiag_for_linear_problems.py create mode 100644 pySDC/tutorial/step_9/B_paradiag_for_nonlinear_problems.py create mode 100644 pySDC/tutorial/step_9/C_paradiag_in_pySDC.py diff --git a/docs/source/index.rst b/docs/source/index.rst index 5fa8c304be..3803edd0c0 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -25,6 +25,7 @@ Tutorial tutorial/step_6.rst tutorial/step_7.rst tutorial/step_8.rst + tutorial/step_9.rst Playgrounds ----------- diff --git a/docs/source/tutorial/doc_step_9_A.rst b/docs/source/tutorial/doc_step_9_A.rst new file mode 100644 index 0000000000..64638dd562 --- /dev/null +++ b/docs/source/tutorial/doc_step_9_A.rst @@ -0,0 +1,7 @@ +Full code: `pySDC/tutorial/step_9/A_paradiag_for_linear_problems.py `_ + +.. literalinclude:: ../../../pySDC/tutorial/step_9/A_paradiag_for_linear_problems.py + +Results: + +.. literalinclude:: ../../../data/step_9_A_out.txt diff --git a/docs/source/tutorial/doc_step_9_B.rst b/docs/source/tutorial/doc_step_9_B.rst new file mode 100644 index 0000000000..004410e924 --- /dev/null +++ b/docs/source/tutorial/doc_step_9_B.rst @@ -0,0 +1,7 @@ +Full code: `pySDC/tutorial/step_9/B_paradiag_for_nonlinear_problems.py `_ + +.. literalinclude:: ../../../pySDC/tutorial/step_9/B_paradiag_for_nonlinear_problems.py + +Results: + +.. literalinclude:: ../../../data/step_9_B_out.txt diff --git a/docs/source/tutorial/doc_step_9_C.rst b/docs/source/tutorial/doc_step_9_C.rst new file mode 100644 index 0000000000..5067bf60d4 --- /dev/null +++ b/docs/source/tutorial/doc_step_9_C.rst @@ -0,0 +1,7 @@ +Full code: `pySDC/tutorial/step_9/C_paradiag_in_pySDC.py `_ + +.. literalinclude:: ../../../pySDC/tutorial/step_9/C_paradiag_in_pySDC.py + +Results: + +.. literalinclude:: ../../../data/step_9_C_out.txt diff --git a/docs/source/tutorial/step_9.rst b/docs/source/tutorial/step_9.rst new file mode 100644 index 0000000000..37b5c62811 --- /dev/null +++ b/docs/source/tutorial/step_9.rst @@ -0,0 +1 @@ +.. include:: /../../pySDC/tutorial/step_9/README.rst diff --git a/pySDC/core/controller.py b/pySDC/core/controller.py index e5b2748a36..fbc88057fe 100644 --- a/pySDC/core/controller.py +++ b/pySDC/core/controller.py @@ -6,6 +6,7 @@ from pySDC.core.base_transfer import BaseTransfer from pySDC.helpers.pysdc_helper import FrozenClass from pySDC.implementations.convergence_controller_classes.check_convergence import CheckConvergence +from pySDC.implementations.convergence_controller_classes.store_uold import StoreUOld from pySDC.implementations.hooks.default_hook import DefaultHooks from pySDC.implementations.hooks.log_timings import CPUTimings @@ -41,6 +42,7 @@ def __init__(self, controller_params, description, useMPI=None): controller_params (dict): parameter set for the controller and the steps """ self.useMPI = useMPI + self.description = description # check if we have a hook on this list. If not, use default class. self.__hooks = [] @@ -341,3 +343,68 @@ def return_stats(self): for hook in self.hooks: stats = {**stats, **hook.return_stats()} return stats + + +class ParaDiagController(Controller): + + def __init__(self, controller_params, description, n_steps, useMPI=None): + """ + Initialization routine for ParaDiag controllers + + Args: + num_procs: number of parallel time steps (still serial, though), can be 1 + controller_params: parameter set for the controller and the steps + description: all the parameters to set up the rest (levels, problems, transfer, ...) + n_steps (int): Number of parallel steps + alpha (float): alpha parameter for ParaDiag + """ + # TODO: where should I put alpha? When I want to adapt it, maybe it shouldn't be in the controller? + from pySDC.implementations.sweeper_classes.ParaDiagSweepers import QDiagonalization + + if QDiagonalization in description['sweeper_class'].__mro__: + description['sweeper_params']['ignore_ic'] = True + description['sweeper_params']['update_f_evals'] = False + else: + logging.getLogger('controller').warning( + f'Warning: Your sweeper class {description["sweeper_class"]} is not derived from {QDiagonalization}. You probably want to use another sweeper class.' + ) + + if controller_params.get('all_to_done', False): + raise NotImplementedError('ParaDiag only implemented with option `all_to_done=True`') + if 'alpha' not in controller_params.keys(): + from pySDC.core.errors import ParameterError + + raise ParameterError('Please supply alpha as a parameter to the ParaDiag controller!') + controller_params['average_jacobian'] = controller_params.get('average_jacobian', True) + + controller_params['all_to_done'] = True + super().__init__(controller_params=controller_params, description=description, useMPI=useMPI) + self.base_convergence_controllers += [StoreUOld] + + self.ParaDiag_block_u0 = None + self.n_steps = n_steps + + def FFT_in_time(self): + """ + Compute weighted forward FFT in time. The weighting is determined by the alpha parameter in ParaDiag + + Note: The implementation via matrix-vector multiplication may be inefficient and less stable compared to an FFT + with transposes! + """ + if not hasattr(self, '__FFT_matrix'): + from pySDC.helpers.ParaDiagHelper import get_weighted_FFT_matrix + + self.__FFT_matrix = get_weighted_FFT_matrix(self.n_steps, self.params.alpha) + + self.apply_matrix(self.__FFT_matrix) + + def iFFT_in_time(self): + """ + Compute weighted backward FFT in time. The weighting is determined by the alpha parameter in ParaDiag + """ + if not hasattr(self, '__iFFT_matrix'): + from pySDC.helpers.ParaDiagHelper import get_weighted_iFFT_matrix + + self.__iFFT_matrix = get_weighted_iFFT_matrix(self.n_steps, self.params.alpha) + + self.apply_matrix(self.__iFFT_matrix) diff --git a/pySDC/core/level.py b/pySDC/core/level.py index 66d3d3083e..76c415af54 100644 --- a/pySDC/core/level.py +++ b/pySDC/core/level.py @@ -82,6 +82,7 @@ def __init__(self, problem_class, problem_params, sweeper_class, sweeper_params, self.uend = None self.u = [None] * (self.sweep.coll.num_nodes + 1) self.uold = [None] * (self.sweep.coll.num_nodes + 1) + self.u_avg = [None] * self.sweep.coll.num_nodes self.f = [None] * (self.sweep.coll.num_nodes + 1) self.fold = [None] * (self.sweep.coll.num_nodes + 1) diff --git a/pySDC/helpers/ParaDiagHelper.py b/pySDC/helpers/ParaDiagHelper.py new file mode 100644 index 0000000000..7c4e118c36 --- /dev/null +++ b/pySDC/helpers/ParaDiagHelper.py @@ -0,0 +1,131 @@ +import numpy as np +import scipy.sparse as sp + + +def get_FFT_matrix(N): + """ + Get matrix for computing FFT of size N. Normalization is like "ortho" in numpy. + Compute inverse FFT by multiplying by the complex conjugate (numpy.conjugate) of this matrix + + Args: + N (int): Size of the data to be transformed + + Returns: + numpy.ndarray: Dense square matrix to compute forward transform + """ + idx_1d = np.arange(N, dtype=complex) + i1, i2 = np.meshgrid(idx_1d, idx_1d) + + return np.exp(-2 * np.pi * 1j * i1 * i2 / N) / np.sqrt(N) + + +def get_E_matrix(N, alpha=0): + """ + Get NxN matrix with -1 on the lower subdiagonal, -alpha in the top right and 0 elsewhere + + Args: + N (int): Size of the matrix + alpha (float): Negative of value in the top right + + Returns: + sparse E matrix + """ + E = sp.diags( + [ + -1.0, + ] + * (N - 1), + offsets=-1, + ).tolil() + E[0, -1] = -alpha + return E + + +def get_J_matrix(N, alpha): + """ + Get matrix for weights in the weighted inverse FFT + + Args: + N (int): Size of the matrix + alpha (float): alpha parameter in ParaDiag + + Returns: + sparse J matrix + """ + gamma = alpha ** (-np.arange(N) / N) + return sp.diags(gamma) + + +def get_J_inv_matrix(N, alpha): + """ + Get matrix for weights in the weighted FFT + + Args: + N (int): Size of the matrix + alpha (float): alpha parameter in ParaDiag + + Returns: + sparse J_inv matrix + """ + gamma = alpha ** (-np.arange(N) / N) + return sp.diags(1 / gamma) + + +def get_weighted_FFT_matrix(N, alpha): + """ + Get matrix for the weighted FFT + + Args: + N (int): Size of the matrix + alpha (float): alpha parameter in ParaDiag + + Returns: + Dense weighted FFT matrix + """ + return get_FFT_matrix(N) @ get_J_inv_matrix(N, alpha) + + +def get_weighted_iFFT_matrix(N, alpha): + """ + Get matrix for the weighted inverse FFT + + Args: + N (int): Size of the matrix + alpha (float): alpha parameter in ParaDiag + + Returns: + Dense weighted FFT matrix + """ + return get_J_matrix(N, alpha) @ np.conjugate(get_FFT_matrix(N)) + + +def get_H_matrix(N, sweeper_params): + """ + Get sparse matrix for computing the collocation update. Requires not to do a collocation update! + + Args: + N (int): Number of collocation nodes + sweeper_params (dict): Parameters for the sweeper + + Returns: + Sparse matrix for collocation update + """ + assert sweeper_params['quad_type'] == 'RADAU-RIGHT' + H = sp.eye(N).tolil() * 0 + H[:, -1] = 1 + return H + + +def get_G_inv_matrix(l, L, alpha, sweeper_params): + M = sweeper_params['num_nodes'] + I_M = sp.eye(M) + E_alpha = get_E_matrix(L, alpha) + H = get_H_matrix(M, sweeper_params) + + gamma = alpha ** (-np.arange(L) / L) + diags = np.fft.fft(1 / gamma * E_alpha[:, 0].toarray().flatten(), norm='backward') + G = (diags[l] * H + I_M).tocsc() + if M > 1: + return sp.linalg.inv(G).toarray() + else: + return 1 / G.toarray() diff --git a/pySDC/implementations/controller_classes/controller_ParaDiag_nonMPI.py b/pySDC/implementations/controller_classes/controller_ParaDiag_nonMPI.py new file mode 100644 index 0000000000..e5c5deca3a --- /dev/null +++ b/pySDC/implementations/controller_classes/controller_ParaDiag_nonMPI.py @@ -0,0 +1,482 @@ +import itertools +import copy as cp +import numpy as np +import dill + +from pySDC.core.controller import ParaDiagController +from pySDC.core import step as stepclass +from pySDC.core.errors import ControllerError, CommunicationError +from pySDC.implementations.convergence_controller_classes.basic_restarting import BasicRestarting +from pySDC.helpers.ParaDiagHelper import get_G_inv_matrix + + +class controller_ParaDiag_nonMPI(ParaDiagController): + """ + + ParaDiag controller, running serialized version. + + This controller uses the increment formulation. That is to say, we setup the residual of the all at once problem, + put it on the right hand side, invert the ParaDiag preconditioner on the left-hand side to compute the increment + and then add the increment onto the solution. For this reason, we need to replace the solution values in the steps + with the residual values before the solves and then put the solution plus increment back into the steps. This is a + bit counter to what you expect when you access the `u` variable in the levels, but it is mathematically advantageous. + """ + + def __init__(self, num_procs, controller_params, description): + """ + Initialization routine for ParaDiag controller + + Args: + num_procs: number of parallel time steps (still serial, though), can be 1 + controller_params: parameter set for the controller and the steps + description: all the parameters to set up the rest (levels, problems, transfer, ...) + """ + super().__init__(controller_params, description, useMPI=False, n_steps=num_procs) + + self.MS = [] + + for l in range(num_procs): + G_inv = get_G_inv_matrix(l, num_procs, self.params.alpha, description['sweeper_params']) + description['sweeper_params']['G_inv'] = G_inv + + self.MS.append(stepclass.Step(description)) + + self.base_convergence_controllers += [BasicRestarting.get_implementation(useMPI=False)] + for convergence_controller in self.base_convergence_controllers: + self.add_convergence_controller(convergence_controller, description) + + if self.params.dump_setup: + self.dump_setup(step=self.MS[0], controller_params=controller_params, description=description) + + if len(self.MS[0].levels) > 1: + raise NotImplementedError('This controller does not support multiple levels') + + for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]: + C.reset_buffers_nonMPI(self) + C.setup_status_variables(self, MS=self.MS) + + def ParaDiag(self, local_MS_active): + """ + Main function for ParaDiag + + For the workflow of this controller, see https://arxiv.org/abs/2103.12571 + + This method changes self.MS directly by accessing active steps through local_MS_active. + + Args: + local_MS_active (list): all active steps + + Returns: + boot: Whether all steps are done + """ + + # if all stages are the same (or DONE), continue, otherwise abort + stages = [S.status.stage for S in local_MS_active if S.status.stage != 'DONE'] + if stages[1:] == stages[:-1]: + stage = stages[0] + else: + raise ControllerError('not all stages are equal') + + self.logger.debug(stage) + + MS_running = [S for S in local_MS_active if S.status.stage != 'DONE'] + + switcher = { + 'SPREAD': self.spread, + 'IT_CHECK': self.it_check, + 'IT_PARADIAG': self.it_ParaDiag, + } + + assert stage in switcher.keys(), f'Got unexpected stage {stage!r}' + switcher[stage](MS_running) + + return all(S.status.done for S in local_MS_active) + + def apply_matrix(self, mat): + """ + Apply a matrix on the step level. Needs to be square. Puts the result back into the controller. + + Args: + mat: square LxL matrix with L number of steps + """ + L = len(self.MS) + assert np.allclose(mat.shape, L) + assert len(mat.shape) == 2 + + level = self.MS[0].levels[0] + M = level.sweep.params.num_nodes + prob = level.prob + + # buffer for storing the result + res = [ + None, + ] * L + + # compute matrix-vector product + for i in range(mat.shape[0]): + res[i] = [prob.u_init for _ in range(M + 1)] + for j in range(mat.shape[1]): + for m in range(M + 1): + res[i][m] += mat[i, j] * self.MS[j].levels[0].u[m] + + # put the result in the "output" + for i in range(mat.shape[0]): + for m in range(M + 1): + self.MS[i].levels[0].u[m] = res[i][m] + + def swap_solution_for_all_at_once_residual(self, local_MS_running): + """ + Replace the solution values in the steps with the all-at-once residual. + + This requires to communicate the solutions at the end of the steps to be the initial conditions for the next + steps. Afterwards, the residual can be computed locally on the steps. + + Args: + local_MS_running (list): list of currently running steps + """ + prob = self.MS[0].levels[0].prob + + for S in local_MS_running: + # communicate initial conditions + S.levels[0].sweep.compute_end_point() + + for hook in self.hooks: + hook.pre_comm(step=S, level_number=0) + + if S.status.first: + S.levels[0].u[0] = prob.dtype_u(self.ParaDiag_block_u0) + else: + S.levels[0].u[0] = S.prev.levels[0].uend + + for hook in self.hooks: + hook.post_comm(step=S, level_number=0, add_to_stats=True) + + # compute residuals locally + residual = S.levels[0].sweep.get_residual() + S.levels[0].status.residual = max(abs(me) for me in residual) + + # put residual in the solution variables + for m in range(S.levels[0].sweep.coll.num_nodes): + S.levels[0].u[m + 1] = residual[m] + + def swap_increment_for_solution(self, local_MS_running): + """ + After inversion of the preconditioner, the values stored in the steps are the increment. This function adds the + solution after the previous iteration to arrive at the solution after the current iteration. + Note that we also need to put in the initial conditions back in the first step because they will be perturbed by + the circular preconditioner. + + Args: + local_MS_running (list): list of currently running steps + """ + for S in local_MS_running: + for m in range(S.levels[0].sweep.coll.num_nodes + 1): + S.levels[0].u[m] = S.levels[0].uold[m] + S.levels[0].u[m] + if S.status.first: + S.levels[0].u[0] = self.ParaDiag_block_u0 + + def prepare_Jacobians(self, local_MS_running): + # get solutions for constructing average Jacobians + if self.params.average_jacobian: + level = local_MS_running[0].levels[0] + M = level.sweep.coll.num_nodes + + u_avg = [level.prob.dtype_u(level.prob.init, val=0)] * M + + # communicate average solution + for S in local_MS_running: + for m in range(M): + u_avg[m] += S.levels[0].u[m + 1] / self.n_steps + + # store the averaged solution in the steps + for S in local_MS_running: + S.levels[0].u_avg = u_avg + + def it_ParaDiag(self, local_MS_running): + """ + Do a single ParaDiag iteration. Does the following steps + - (1) Compute the residual of the all-at-once / composite collocation problem + - (2) Compute an FFT in time to diagonalize the preconditioner + - (3) Solve the collocation problems locally on the steps for the increment + - (4) Compute iFFT in time to go back to the original base + - (5) Update the solution by adding increment + + Note that this is the only place where we compute the all-at-once residual because it requires communication and + swaps the solution values for the residuals. So after the residual tolerance is reached, one more ParaDiag + iteration will be done. + + Args: + local_MS_running (list): list of currently running steps + """ + + for S in local_MS_running: + for hook in self.hooks: + hook.pre_sweep(step=S, level_number=0) + + # replace the values stored in the steps with the residuals in order to compute the increment + self.swap_solution_for_all_at_once_residual(local_MS_running) + + # communicate average residual for setting up Jacobians for non-linear problems + self.prepare_Jacobians(local_MS_running) + + # weighted FFT in time + self.FFT_in_time() + + # perform local solves of "collocation problems" on the steps (can be done in parallel) + for S in local_MS_running: + assert len(S.levels) == 1, 'Multi-level SDC not implemented in ParaDiag' + S.levels[0].sweep.update_nodes() + + # inverse FFT in time + self.iFFT_in_time() + + # replace the values stored in the steps with the previous solution plus the increment + self.swap_increment_for_solution(local_MS_running) + + for S in local_MS_running: + for hook in self.hooks: + hook.post_sweep(step=S, level_number=0) + + # update stage + for S in local_MS_running: + S.status.stage = 'IT_CHECK' + + def it_check(self, local_MS_running): + """ + Key routine to check for convergence/termination + + Args: + local_MS_running (list): list of currently running steps + """ + + for S in local_MS_running: + if S.status.iter > 0: + for hook in self.hooks: + hook.post_iteration(step=S, level_number=0) + + # decide if the step is done, needs to be restarted and other things convergence related + for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]: + C.post_iteration_processing(self, S, MS=local_MS_running) + C.convergence_control(self, S, MS=local_MS_running) + + for S in local_MS_running: + if not S.status.first: + for hook in self.hooks: + hook.pre_comm(step=S, level_number=0) + S.status.prev_done = S.prev.status.done # "communicate" + for hook in self.hooks: + hook.post_comm(step=S, level_number=0, add_to_stats=True) + S.status.done = S.status.done and S.status.prev_done + + if self.params.all_to_done: + for hook in self.hooks: + hook.pre_comm(step=S, level_number=0) + S.status.done = all(T.status.done for T in local_MS_running) + for hook in self.hooks: + hook.post_comm(step=S, level_number=0, add_to_stats=True) + + if not S.status.done: + # increment iteration count here (and only here) + S.status.iter += 1 + for hook in self.hooks: + hook.pre_iteration(step=S, level_number=0) + for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]: + C.pre_iteration_processing(self, S, MS=local_MS_running) + + # Do another ParaDiag iteration + S.status.stage = 'IT_PARADIAG' + else: + S.levels[0].sweep.compute_end_point() + for hook in self.hooks: + hook.post_step(step=S, level_number=0) + S.status.stage = 'DONE' + + for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]: + C.reset_buffers_nonMPI(self) + + def spread(self, local_MS_running): + """ + Spreading phase + + Args: + local_MS_running (list): list of currently running steps + """ + + for S in local_MS_running: + + # first stage: spread values + for hook in self.hooks: + hook.pre_step(step=S, level_number=0) + + # call predictor from sweeper + S.levels[0].sweep.predict() + + # compute the residual + S.levels[0].sweep.compute_residual() + + # update stage + S.status.stage = 'IT_CHECK' + + for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]: + C.post_spread_processing(self, S, MS=local_MS_running) + + def run(self, u0, t0, Tend): + """ + Main driver for running the serial version of ParaDiag + + Args: + u0: initial values + t0: starting time + Tend: ending time + + Returns: + end values on the last step + stats object containing statistics for each step, each level and each iteration + """ + + # some initializations and reset of statistics + uend = None + num_procs = len(self.MS) + for hook in self.hooks: + hook.reset_stats() + + # initial ordering of the steps: 0,1,...,Np-1 + slots = list(range(num_procs)) + + # initialize time variables of each step + time = [t0 + sum(self.MS[j].dt for j in range(p)) for p in slots] + + # determine which steps are still active (time < Tend) + active = [time[p] < Tend - 10 * np.finfo(float).eps for p in slots] + if not all(active) and any(active): + self.logger.warning( + 'Warning: This controller will solve past your desired end time until the end of its block!' + ) + active = [ + True, + ] * len(active) + + if not any(active): + raise ControllerError('Nothing to do, check t0, dt and Tend.') + + # compress slots according to active steps, i.e. remove all steps which have times above Tend + active_slots = list(itertools.compress(slots, active)) + + # initialize block of steps with u0 + self.restart_block(active_slots, time, u0) + + for hook in self.hooks: + hook.post_setup(step=None, level_number=None) + + # call pre-run hook + for S in self.MS: + for hook in self.hooks: + hook.pre_run(step=S, level_number=0) + + # main loop: as long as at least one step is still active (time < Tend), do something + while any(active): + MS_active = [self.MS[p] for p in active_slots] + done = False + while not done: + done = self.ParaDiag(MS_active) + + restarts = [S.status.restart for S in MS_active] + restart_at = np.where(restarts)[0][0] if True in restarts else len(MS_active) + if True in restarts: # restart part of the block + # initial condition to next block is initial condition of step that needs restarting + uend = self.MS[restart_at].levels[0].u[0] + time[active_slots[0]] = time[restart_at] + self.logger.info(f'Starting next block with initial conditions from step {restart_at}') + + else: # move on to next block + # initial condition for next block is last solution of current block + uend = self.MS[active_slots[-1]].levels[0].uend + time[active_slots[0]] = time[active_slots[-1]] + self.MS[active_slots[-1]].dt + + for S in MS_active[:restart_at]: + for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]: + C.post_step_processing(self, S, MS=MS_active) + + for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]: + [C.prepare_next_block(self, S, len(active_slots), time, Tend, MS=MS_active) for S in self.MS] + + # setup the times of the steps for the next block + for i in range(1, len(active_slots)): + time[active_slots[i]] = time[active_slots[i] - 1] + self.MS[active_slots[i] - 1].dt + + # determine new set of active steps and compress slots accordingly + active = [time[p] < Tend - 10 * np.finfo(float).eps for p in slots] + if not all(active) and any(active): + self.logger.warning( + 'Warning: This controller will solve past your desired end time until the end of its block!' + ) + active = [ + True, + ] * len(active) + active_slots = list(itertools.compress(slots, active)) + + # restart active steps (reset all values and pass uend to u0) + self.restart_block(active_slots, time, uend) + + # call post-run hook + for S in self.MS: + for hook in self.hooks: + hook.post_run(step=S, level_number=0) + + for S in self.MS: + for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]: + C.post_run_processing(self, S, MS=MS_active) + + return uend, self.return_stats() + + def restart_block(self, active_slots, time, u0): + """ + Helper routine to reset/restart block of (active) steps + + Args: + active_slots: list of active steps + time: list of new times + u0: initial value to distribute across the steps + + """ + self.ParaDiag_block_u0 = u0 # need this for computing residual + + for j in range(len(active_slots)): + # get slot number + p = active_slots[j] + + # store current slot number for diagnostics + self.MS[p].status.slot = p + # store link to previous step + self.MS[p].prev = self.MS[active_slots[j - 1]] + + self.MS[p].reset_step() + + # determine whether I am the first and/or last in line + self.MS[p].status.first = active_slots.index(p) == 0 + self.MS[p].status.last = active_slots.index(p) == len(active_slots) - 1 + + # initialize step with u0 + self.MS[p].init_step(u0) + + # setup G^{-1} for new number of active slots + # self.MS[j].levels[0].sweep.set_G_inv(get_G_inv_matrix(j, len(active_slots), self.params.alpha, self.description['sweeper_params'])) + + # reset some values + self.MS[p].status.done = False + self.MS[p].status.prev_done = False + self.MS[p].status.iter = 0 + self.MS[p].status.stage = 'SPREAD' + self.MS[p].status.force_done = False + self.MS[p].status.time_size = len(active_slots) + + for l in self.MS[p].levels: + l.tag = None + l.status.sweep = 1 + + for p in active_slots: + for lvl in self.MS[p].levels: + lvl.status.time = time[p] + + for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]: + C.reset_status_variables(self, active_slots=active_slots) diff --git a/pySDC/implementations/hooks/log_timings.py b/pySDC/implementations/hooks/log_timings.py index b7a2305bbc..9d9cedf41d 100644 --- a/pySDC/implementations/hooks/log_timings.py +++ b/pySDC/implementations/hooks/log_timings.py @@ -287,7 +287,8 @@ def post_run(self, step, level_number): type=f'{self.prefix}timing_run', value=t_run, ) - self.logger.info(f'Finished run after {t_run:.2e}s') + if step.status.first: + self.logger.info(f'Finished run after {t_run:.2e}s') def post_setup(self, step, level_number): """ diff --git a/pySDC/implementations/problem_classes/TestEquation_0D.py b/pySDC/implementations/problem_classes/TestEquation_0D.py index 811dcf60c0..4276e96e97 100644 --- a/pySDC/implementations/problem_classes/TestEquation_0D.py +++ b/pySDC/implementations/problem_classes/TestEquation_0D.py @@ -2,7 +2,7 @@ import scipy.sparse as nsp from pySDC.core.problem import Problem, WorkCounter -from pySDC.implementations.datatype_classes.mesh import mesh +from pySDC.implementations.datatype_classes.mesh import mesh, imex_mesh class testequation0d(Problem): @@ -145,3 +145,118 @@ def u_exact(self, t, u_init=None, t_init=None): me = self.dtype_u(self.init) me[:] = u_init * self.xp.exp((t - t_init) * self.lambdas) return me + + +class test_equation_IMEX(Problem): + dtype_f = imex_mesh + dtype_u = mesh + xp = np + xsp = nsp + + def __init__(self, lambdas_implicit=None, lambdas_explicit=None, u0=0.0): + """Initialization routine""" + + if lambdas_implicit is None: + re = self.xp.linspace(-30, 19, 50) + im = self.xp.linspace(-50, 49, 50) + lambdas_implicit = self.xp.array( + [[complex(re[i], im[j]) for i in range(len(re))] for j in range(len(im))] + ).reshape((len(re) * len(im))) + if lambdas_explicit is None: + re = self.xp.linspace(-30, 19, 50) + im = self.xp.linspace(-50, 49, 50) + lambdas_implicit = self.xp.array( + [[complex(re[i], im[j]) for i in range(len(re))] for j in range(len(im))] + ).reshape((len(re) * len(im))) + lambdas_implicit = self.xp.asarray(lambdas_implicit) + lambdas_explicit = self.xp.asarray(lambdas_explicit) + + assert lambdas_implicit.ndim == 1, f'expect flat list here, got {lambdas_implicit}' + assert lambdas_explicit.shape == lambdas_implicit.shape + nvars = lambdas_implicit.size + assert nvars > 0, 'expect at least one lambda parameter here' + + # invoke super init, passing number of dofs, dtype_u and dtype_f + super().__init__(init=(nvars, None, self.xp.dtype('complex128'))) + + self.A = self.xsp.diags(lambdas_implicit) + self._makeAttributeAndRegister( + 'nvars', 'lambdas_implicit', 'lambdas_explicit', 'u0', localVars=locals(), readOnly=True + ) + self.work_counters['rhs'] = WorkCounter() + + def eval_f(self, u, t): + """ + Routine to evaluate the right-hand side of the problem. + + Parameters + ---------- + u : dtype_u + Current values of the numerical solution. + t : float + Current time of the numerical solution is computed. + + Returns + ------- + f : dtype_f + The right-hand side of the problem. + """ + + f = self.dtype_f(self.init) + f.impl[:] = u * self.lambdas_implicit + f.expl[:] = u * self.lambdas_explicit + self.work_counters['rhs']() + return f + + def solve_system(self, rhs, factor, u0, t): + r""" + Simple linear solver for :math:`(I-factor\cdot A)\vec{u}=\vec{rhs}`. + + Parameters + ---------- + rhs : dtype_f + Right-hand side for the linear system. + factor : float + Abbrev. for the local stepsize (or any other factor required). + u0 : dtype_u + Initial guess for the iterative solver. + t : float + Current time (e.g. for time-dependent BCs). + + Returns + ------- + me : dtype_u + The solution as mesh. + """ + me = self.dtype_u(self.init) + L = 1 - factor * self.lambdas_implicit + L[L == 0] = 1 # to avoid potential divisions by zeros + me[:] = rhs + me /= L + return me + + def u_exact(self, t, u_init=None, t_init=None): + """ + Routine to compute the exact solution at time t. + + Parameters + ---------- + t : float + Time of the exact solution. + u_init : pySDC.problem.testequation0d.dtype_u + Initial solution. + t_init : float + The initial time. + + Returns + ------- + me : dtype_u + The exact solution. + """ + + u_init = (self.u0 if u_init is None else u_init) * 1.0 + t_init = 0.0 if t_init is None else t_init * 1.0 + + me = self.dtype_u(self.init) + me[:] = u_init * self.xp.exp((t - t_init) * (self.lambdas_implicit + self.lambdas_explicit)) + return me diff --git a/pySDC/implementations/sweeper_classes/ParaDiagSweepers.py b/pySDC/implementations/sweeper_classes/ParaDiagSweepers.py new file mode 100644 index 0000000000..01fd4c37ac --- /dev/null +++ b/pySDC/implementations/sweeper_classes/ParaDiagSweepers.py @@ -0,0 +1,158 @@ +""" +These sweepers are made for use with ParaDiag. They can be used to some degree with SDC as well, but unless you know what you are doing, you probably want another sweeper. +""" + +from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit +from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order +import numpy as np +import scipy.sparse as sp + + +class QDiagonalization(generic_implicit): + """ + Sweeper solving the collocation problem directly via diagonalization of Q. Mainly made for ParaDiag. + Can be reconfigured for use with SDC. + + Note that the initial conditions for the collocation problem are generally stored in node zero in pySDC. However, + this sweeper is intended for ParaDiag, where a node-local residual is needed as a right hand side for this sweeper + rather than a step local one. Therefore, this sweeper has an option `ignore_ic`. If true, the value in node zero + will only be used in computing the step-local residual, but not in the solves. If false, the values on the nodes + will be ignored in the solves and the node-zero value will be used as initial conditions. When using this as a time- + parallel algorithm outside ParaDiag, you should set this parameter to false, which is not the default! + + Similarly, in ParaDiag, the solution is in Fourier space right after the solve. It therefore makes little sense to + evaluate the right hand side directly after. By default, this is not done! Set `update_f_evals=True` in the + parameters if you want to use this sweeper in SDC. + """ + + def __init__(self, params): + """ + Initialization routine for the custom sweeper + + Args: + params: parameters for the sweeper + """ + if 'G_inv' not in params.keys(): + params['G_inv'] = np.eye(params['num_nodes']) + params['update_f_evals'] = params.get('update_f_evals', False) + params['ignore_ic'] = params.get('ignore_ic', True) + + super().__init__(params) + + self.set_G_inv(self.params.G_inv) + + def set_G_inv(self, G_inv): + """ + In ParaDiag, QG^{-1} is diagonalized. This function stores the G_inv matrix and computes and stores the diagonalization. + """ + self.params.G_inv = G_inv + self.w, self.S, self.S_inv = self.computeDiagonalization(A=self.coll.Qmat[1:, 1:] @ self.params.G_inv) + + @staticmethod + def computeDiagonalization(A): + """ + Compute diagonalization of dense matrix A = S diag(w) S^-1 + + Args: + A (numpy.ndarray): dense matrix to diagonalize + + Returns: + numpy.array: Diagonal entries of the diagonalized matrix w + numpy.ndarray: Matrix of eigenvectors S + numpy.ndarray: Inverse of S + """ + w, S = np.linalg.eig(A) + S_inv = np.linalg.inv(S) + assert np.allclose(S @ np.diag(w) @ S_inv, A) + return w, S, S_inv + + def mat_vec(self, mat, vec): + """ + Compute matrix-vector multiplication. Vector can be list. + + Args: + mat: Matrix + vec: Vector + + Returns: + list: mat @ vec + """ + assert mat.shape[1] == len(vec) + result = [] + for m in range(mat.shape[0]): + result.append(self.level.prob.u_init) + for j in range(mat.shape[1]): + result[-1] += mat[m, j] * vec[j] + return result + + def update_nodes(self): + """ + Update the u- and f-values at the collocation nodes -> corresponds to a single sweep over all nodes + + Returns: + None + """ + + L = self.level + P = L.prob + M = self.coll.num_nodes + + if L.tau[0] is not None: + raise NotImplementedError('This sweeper does not work with multi-level SDC') + + # perform local solves on the collocation nodes, can be parallelized! + if self.params.ignore_ic: + x1 = self.mat_vec(self.S_inv, [self.level.u[m + 1] for m in range(M)]) + else: + x1 = self.mat_vec(self.S_inv, [self.level.u[0] for _ in range(M)]) + x2 = [] + for m in range(M): + # TODO: need to put averaged x1 in u0 here for nonlinear problems + u0 = L.u_avg[m] if L.u_avg[m] is not None else x1[m] + x2.append(P.solve_system(x1[m], self.w[m] * L.dt, u0=u0, t=L.time + L.dt * self.coll.nodes[m])) + z = self.mat_vec(self.S, x2) + y = self.mat_vec(self.params.G_inv, z) + + # update solution and evaluate right hand side + for m in range(M): + L.u[m + 1] = y[m] + if self.params.update_f_evals: + raise + L.f[m + 1] = P.eval_f(L.u[m + 1], L.time + L.dt * self.coll.nodes[m]) + + L.status.updated = True + return None + + def eval_f_at_all_nodes(self): + L = self.level + P = self.level.prob + for m in range(self.coll.num_nodes): + L.f[m + 1] = P.eval_f(L.u[m + 1], L.time + L.dt * self.coll.nodes[m]) + + def get_residual(self): + """ + This function computes and returns the "spatially extended" residual, not the norm of the residual! + + Returns: + pySDC.datatype: Spatially extended residual + + """ + self.eval_f_at_all_nodes() + + # start with integral dt*Q*F + residual = self.integrate() + + # subtract u and add u0 to arrive at r = dt*Q*F - u + u0 + for m in range(self.coll.num_nodes): + residual[m] -= self.level.u[m + 1] + residual[m] += self.level.u[0] + + return residual + + +class QDiagonalizationIMEX(QDiagonalization): + """ + Use as sweeper class for ParaDiag with IMEX splitting. Note that it will not work with SDC. + """ + + integrate = imex_1st_order.integrate diff --git a/pySDC/tests/test_controllers/test_controller_ParaDiag_nonMPI.py b/pySDC/tests/test_controllers/test_controller_ParaDiag_nonMPI.py new file mode 100644 index 0000000000..5ade344032 --- /dev/null +++ b/pySDC/tests/test_controllers/test_controller_ParaDiag_nonMPI.py @@ -0,0 +1,230 @@ +import pytest + + +def get_composite_collocation_problem(L, M, N, alpha=0, dt=1e-1, problem='Dahlquist', ParaDiag=True): + import numpy as np + from pySDC.implementations.hooks.log_errors import ( + LogGlobalErrorPostRun, + LogGlobalErrorPostStep, + ) + + if ParaDiag: + from pySDC.implementations.controller_classes.controller_ParaDiag_nonMPI import ( + controller_ParaDiag_nonMPI as controller_class, + ) + else: + from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI as controller_class + + average_jacobian = False + restol = 1e-8 + + if problem == 'Dahlquist': + from pySDC.implementations.problem_classes.TestEquation_0D import testequation0d as problem_class + + if ParaDiag: + from pySDC.implementations.sweeper_classes.ParaDiagSweepers import QDiagonalization as sweeper_class + else: + from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit as sweeper_class + + problem_params = {'lambdas': -1.0 * np.ones(shape=(N)), 'u0': 1} + elif problem == 'Dahlquist_IMEX': + from pySDC.implementations.problem_classes.TestEquation_0D import test_equation_IMEX as problem_class + + if ParaDiag: + from pySDC.implementations.sweeper_classes.ParaDiagSweepers import QDiagonalizationIMEX as sweeper_class + else: + from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order as sweeper_class + + problem_params = { + 'lambdas_implicit': -1.0 * np.ones(shape=(N)), + 'lambdas_explicit': -1.0e-1 * np.ones(shape=(N)), + 'u0': 1.0, + } + elif problem == 'heat': + from pySDC.implementations.problem_classes.HeatEquation_ND_FD import heatNd_forced as problem_class + from pySDC.implementations.sweeper_classes.ParaDiagSweepers import QDiagonalizationIMEX as sweeper_class + + problem_params = {'nvars': N} + elif problem == 'vdp': + from pySDC.implementations.problem_classes.Van_der_Pol_implicit import vanderpol as problem_class + + if ParaDiag: + from pySDC.implementations.sweeper_classes.ParaDiagSweepers import QDiagonalization as sweeper_class + + problem_params = {'newton_maxiter': 1, 'mu': 1e0, 'crash_at_maxiter': False} + average_jacobian = True + else: + from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit as sweeper_class + + problem_params = {'newton_maxiter': 99, 'mu': 1e0, 'crash_at_maxiter': True} + else: + raise NotImplementedError() + + level_params = {} + level_params['dt'] = dt + level_params['restol'] = restol + + sweeper_params = {} + sweeper_params['quad_type'] = 'RADAU-RIGHT' + sweeper_params['num_nodes'] = M + sweeper_params['initial_guess'] = 'spread' + + step_params = {} + step_params['maxiter'] = 99 + + controller_params = {} + controller_params['logger_level'] = 15 + controller_params['hook_class'] = [LogGlobalErrorPostRun, LogGlobalErrorPostStep] + controller_params['mssdc_jac'] = False + controller_params['alpha'] = alpha + controller_params['average_jacobian'] = average_jacobian + + description = {} + description['problem_class'] = problem_class + description['problem_params'] = problem_params + description['sweeper_class'] = sweeper_class + description['sweeper_params'] = sweeper_params + description['level_params'] = level_params + description['step_params'] = step_params + + controller_args = { + 'controller_params': controller_params, + 'description': description, + } + controller = controller_class(**controller_args, num_procs=L) + P = controller.MS[0].levels[0].prob + + for prob in [S.levels[0].prob for S in controller.MS]: + prob.init = tuple([*prob.init[:2]] + [np.dtype('complex128')]) + + return controller, P + + +@pytest.mark.base +@pytest.mark.parametrize('L', [1, 4]) +@pytest.mark.parametrize('M', [2, 3]) +@pytest.mark.parametrize('N', [2]) +@pytest.mark.parametrize('alpha', [1e-4, 1e-2]) +@pytest.mark.parametrize('problem', ['Dahlquist', 'Dahlquist_IMEX', 'vdp']) +def test_ParaDiag_convergence(L, M, N, alpha, problem): + from pySDC.helpers.stats_helper import get_sorted + + controller, prob = get_composite_collocation_problem(L, M, N, alpha, problem=problem) + level = controller.MS[0].levels[0] + + # setup initial conditions + u0 = prob.u_exact(0) + + uend, stats = controller.run(u0=u0, t0=0, Tend=L * level.dt * 2) + + # make some tests + error = get_sorted(stats, type='e_global_post_step') + k = get_sorted(stats, type='niter') + assert max(me[1] for me in k) < 90, 'ParaDiag did not converge' + if problem in ['Dahlquist', 'Dahlquist_IMEX']: + assert max(me[1] for me in error) < 1e-5, 'Error with ParaDiag too large' + + +@pytest.mark.base +@pytest.mark.parametrize('L', [1, 4]) +@pytest.mark.parametrize('M', [2, 3]) +@pytest.mark.parametrize('N', [64]) +@pytest.mark.parametrize('alpha', [1e-4, 1e-2]) +def test_IMEX_ParaDiag_convergence(L, M, N, alpha): + from pySDC.helpers.stats_helper import get_sorted + + controller, prob = get_composite_collocation_problem(L, M, N, alpha, problem='heat', dt=1e-3) + level = controller.MS[0].levels[0] + + # setup initial conditions + u0 = prob.u_exact(0) + + uend, stats = controller.run(u0=u0, t0=0, Tend=L * level.dt * 2) + + # make some tests + error = get_sorted(stats, type='e_global_post_step') + k = get_sorted(stats, type='niter') + assert max(me[1] for me in k) < 9, 'ParaDiag did not converge' + assert max(me[1] for me in error) < 1e-4, 'Error with ParaDiag too large' + + +@pytest.mark.base +@pytest.mark.parametrize('L', [1, 4]) +@pytest.mark.parametrize('M', [1, 2]) +@pytest.mark.parametrize('N', [2]) +@pytest.mark.parametrize('problem', ['Dahlquist', 'Dahlquist_IMEX', 'vdp']) +def test_ParaDiag_vs_PFASST(L, M, N, problem): + import numpy as np + + alpha = 1e-4 + + # setup the same composite collocation problem with different solvers + controllerParaDiag, prob = get_composite_collocation_problem(L, M, N, alpha, problem=problem, ParaDiag=True) + controllerPFASST, _ = get_composite_collocation_problem(L, M, N, alpha, problem=problem, ParaDiag=False) + level = controllerParaDiag.MS[0].levels[0] + + # setup initial conditions + u0 = prob.u_exact(0) + Tend = L * 2 * level.dt + + # run the two different solvers for the composite collocation problem + uendParaDiag, _ = controllerParaDiag.run(u0=u0, t0=0, Tend=Tend) + uendPFASST, _ = controllerPFASST.run(u0=u0, t0=0, Tend=Tend) + + assert np.allclose( + uendParaDiag, uendPFASST + ), f'Got different solutions between single-level PFASST and ParaDiag with {problem=}' + # make sure we didn't trick ourselves with a bug in the test... + assert ( + abs(uendParaDiag - uendPFASST) > 0 + ), 'The solutions with PFASST and ParaDiag are unexpectedly exactly the same!' + + +@pytest.mark.base +@pytest.mark.parametrize('L', [4]) +@pytest.mark.parametrize('M', [2, 3]) +@pytest.mark.parametrize('N', [1]) +@pytest.mark.parametrize('alpha', [1e-4, 1e-2]) +def test_ParaDiag_order(L, M, N, alpha): + import numpy as np + from pySDC.helpers.stats_helper import get_sorted + + errors = [] + if M == 3: + dts = [0.8 * 2 ** (-x) for x in range(7, 9)] + elif M == 2: + dts = [2 ** (-x) for x in range(5, 9)] + else: + raise NotImplementedError + Tend = max(dts) * L * 2 + + for dt in dts: + controller, prob = get_composite_collocation_problem(L, M, N, alpha, dt=dt) + level = controller.MS[0].levels[0] + + # setup initial conditions + u0 = prob.u_init + u0[:] = 1 + + uend, stats = controller.run(u0=u0, t0=0, Tend=Tend) + + # make some tests + errors.append(get_sorted(stats, type='e_global_post_run')[-1][1]) + + expected_order = level.sweep.coll.order + + errors = np.array(errors) + dts = np.array(dts) + order = np.log(abs(errors[1:] - errors[:-1])) / np.log(abs(dts[1:] - dts[:-1])) + num_order = np.mean(order) + + assert ( + expected_order + 1 > num_order > expected_order + ), f'Got unexpected numerical order {num_order} instead of {expected_order} in ParaDiag' + + +if __name__ == '__main__': + test_ParaDiag_vs_PFASST(4, 3, 2, 'Dahlquist') + # test_ParaDiag_convergence(4, 3, 1, 1e-4, 'vdp') + # test_IMEX_ParaDiag_convergence(4, 3, 64, 1e-4) + # test_ParaDiag_order(3, 3, 1, 1e-4) diff --git a/pySDC/tests/test_helpers/test_ParaDiagHelper.py b/pySDC/tests/test_helpers/test_ParaDiagHelper.py new file mode 100644 index 0000000000..51d35d4907 --- /dev/null +++ b/pySDC/tests/test_helpers/test_ParaDiagHelper.py @@ -0,0 +1,20 @@ +import pytest + + +@pytest.mark.base +@pytest.mark.parametrize('N', [4, 69]) +def test_get_FFT_matrix(N): + import numpy as np + from pySDC.helpers.ParaDiagHelper import get_FFT_matrix + + fft_mat = get_FFT_matrix(N) + + data = np.random.random(N) + + fft1 = fft_mat @ data + fft2 = np.fft.fft(data, norm='ortho') + assert np.allclose(fft1, fft2), 'Forward transform incorrect' + + ifft1 = np.conjugate(fft_mat) @ data + ifft2 = np.fft.ifft(data, norm='ortho') + assert np.allclose(ifft1, ifft2), 'Backward transform incorrect' diff --git a/pySDC/tests/test_problems/test_Dahlquist_IMEX.py b/pySDC/tests/test_problems/test_Dahlquist_IMEX.py new file mode 100644 index 0000000000..ebc2fb73c8 --- /dev/null +++ b/pySDC/tests/test_problems/test_Dahlquist_IMEX.py @@ -0,0 +1,32 @@ +def test_Dahlquist_IMEX(): + from pySDC.implementations.problem_classes.TestEquation_0D import test_equation_IMEX + import numpy as np + + N = 1 + dt = 1e-2 + + lambdas_implicit = np.ones(N) * -10 + lambdas_explicit = np.ones(N) * -1e-3 + + prob = test_equation_IMEX(lambdas_explicit=lambdas_explicit, lambdas_implicit=lambdas_implicit, u0=1) + + u0 = prob.u_exact(0) + + # do IMEX Euler step forward + f0 = prob.eval_f(u0, 0) + u1 = prob.solve_system(u0 + dt * f0.expl, dt, u0, 0) + + exact = prob.u_exact(dt) + error = abs(u1 - exact) + error0 = abs(u0 - exact) + assert error < error0 * 1e-1 + + # do explicit Euler step backwards + f = prob.eval_f(u1, dt) + u02 = u1 - dt * (f.impl + f0.expl) + + assert np.allclose(u0, u02) + + +if __name__ == '__main__': + test_Dahlquist_IMEX() diff --git a/pySDC/tests/test_sweepers/test_ParaDiag_sweepers.py b/pySDC/tests/test_sweepers/test_ParaDiag_sweepers.py new file mode 100644 index 0000000000..aeffedaeb6 --- /dev/null +++ b/pySDC/tests/test_sweepers/test_ParaDiag_sweepers.py @@ -0,0 +1,103 @@ +import pytest + + +def get_composite_collocation_problem(L, M, N, alpha=0): + import numpy as np + from pySDC.implementations.problem_classes.TestEquation_0D import testequation0d + from pySDC.implementations.controller_classes.controller_ParaDiag_nonMPI import controller_ParaDiag_nonMPI + + from pySDC.implementations.problem_classes.TestEquation_0D import testequation0d as problem_class + from pySDC.implementations.sweeper_classes.ParaDiagSweepers import QDiagonalization as sweeper_class + + problem_params = {'lambdas': -1.0 * np.ones(shape=(N)), 'u0': 1} + + level_params = {} + level_params['dt'] = 1e-1 + level_params['restol'] = -1 + + sweeper_params = {} + sweeper_params['quad_type'] = 'RADAU-RIGHT' + sweeper_params['num_nodes'] = M + sweeper_params['initial_guess'] = 'copy' + sweeper_params['update_f_evals'] = True + + step_params = {} + step_params['maxiter'] = 1 + + controller_params = {} + controller_params['logger_level'] = 30 + controller_params['hook_class'] = [] + controller_params['mssdc_jac'] = False + controller_params['alpha'] = alpha + + description = {} + description['problem_class'] = problem_class + description['problem_params'] = problem_params + description['sweeper_class'] = sweeper_class + description['sweeper_params'] = sweeper_params + description['level_params'] = level_params + description['step_params'] = step_params + + controller_args = { + 'controller_params': controller_params, + 'description': description, + } + controller = controller_ParaDiag_nonMPI(**controller_args, num_procs=L) + P = controller.MS[0].levels[0].prob + + for prob in [S.levels[0].prob for S in controller.MS]: + prob.init = tuple([*prob.init[:2]] + [np.dtype('complex128')]) + + return controller, P + + +@pytest.mark.base +@pytest.mark.parametrize('M', [1, 3]) +@pytest.mark.parametrize('N', [2, 4]) +@pytest.mark.parametrize('ignore_ic', [True, False]) +def test_direct_solve(M, N, ignore_ic): + """ + Test that the diagonalization has the same result as a direct solve of the collocation problem + """ + import numpy as np + import scipy.sparse as sp + + controller, prob = get_composite_collocation_problem(1, M, N) + + controller.MS[0].levels[0].status.unlocked = True + level = controller.MS[0].levels[0] + level.status.time = 0 + sweep = level.sweep + sweep.params.ignore_ic = ignore_ic + + # initial conditions + for m in range(M + 1): + level.u[m] = prob.u_exact(0) + level.f[m] = prob.eval_f(level.u[m], 0) + + if ignore_ic: + level.u[0][:] = None + + sweep.update_nodes() + sweep.eval_f_at_all_nodes() + + # solve directly + I_MN = sp.eye((M) * N) + Q = sweep.coll.Qmat[1:, 1:] + C_coll = I_MN - level.dt * sp.kron(Q, prob.A) + + u0 = np.zeros(shape=(M, N), dtype=complex) + for m in range(M): + u0[m, ...] = prob.u_exact(0) + u = sp.linalg.spsolve(C_coll, u0.flatten()).reshape(u0.shape) + + for m in range(M): + assert np.allclose(u[m], level.u[m + 1]) + + if not ignore_ic: + sweep.compute_residual() + assert np.isclose(level.status.residual, 0), 'residual is non-zero' + + +if __name__ == '__main__': + test_direct_solve(2, 1, False) diff --git a/pySDC/tests/test_tutorials/test_step_9.py b/pySDC/tests/test_tutorials/test_step_9.py new file mode 100644 index 0000000000..cfe0d680bb --- /dev/null +++ b/pySDC/tests/test_tutorials/test_step_9.py @@ -0,0 +1,20 @@ +import pytest + + +@pytest.mark.base +def test_step_9_A(): + import pySDC.tutorial.step_9.A_paradiag_for_linear_problems + + +@pytest.mark.base +def test_step_9_B(): + import pySDC.tutorial.step_9.B_paradiag_for_nonlinear_problems + + +@pytest.mark.base +@pytest.mark.parametrize('problem', ['advection', 'vdp']) +def test_step_9_C(problem): + + from pySDC.tutorial.step_9.C_paradiag_in_pySDC import compare_ParaDiag_and_PFASST + + compare_ParaDiag_and_PFASST(n_steps=16, problem=problem) diff --git a/pySDC/tutorial/step_9/A_paradiag_for_linear_problems.py b/pySDC/tutorial/step_9/A_paradiag_for_linear_problems.py new file mode 100644 index 0000000000..bacfe8e5fd --- /dev/null +++ b/pySDC/tutorial/step_9/A_paradiag_for_linear_problems.py @@ -0,0 +1,277 @@ +""" +This script introduces ParaDiag for linear problems. +It is recommended to view this code side by side with `Gaya's paper on ParaDiag with collocation methods +`_ as the code is close to the equations presented there but offers no explanations +about them. +""" + +import numpy as np +import scipy.sparse as sp +import sys + +from pySDC.implementations.problem_classes.TestEquation_0D import testequation0d as problem_class +from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit +from pySDC.implementations.sweeper_classes.ParaDiagSweepers import QDiagonalization + +# setup output +out_file = open('data/step_9_A_out.txt', 'w') + + +def my_print(*args, **kwargs): + for output in [sys.stdout, out_file]: + print(*args, **kwargs, file=output) + + +# setup parameters +L = 4 # Number of parallel time steps +M = 3 # Number of collocation nodes +N = 2 # Number of spatial degrees of freedom +alpha = 1e-4 # Circular perturbation parameter +restol = 1e-10 # Residual tolerance for the composite collocation problem +dt = 0.1 # step size + +my_print(f'Running ParaDiag test script with {L} time steps, {M} collocation nodes and {N} spatial degrees of freedom') + +# setup pySDC infrastructure for Dahlquist problem and quadrature +prob = problem_class(lambdas=-1.0 * np.ones(shape=(N)), u0=1.0) +sweeper_params = params = {'num_nodes': M, 'quad_type': 'RADAU-RIGHT'} +sweep = generic_implicit(sweeper_params) + +# Setup a global NumPy array and insert initial conditions in the first step +u = np.zeros((L, M, N), dtype=complex) +u[0, :, :] = prob.u_exact(t=0) + +# setup matrices for composite collocation problem. We note the sizes of the matrices in comments after generating them. + +# Start with identity matrices (I) of various sizes +I_L = sp.eye(L) # LxL +I_MN = sp.eye((M) * N) # MNxMN +I_N = sp.eye(N) # NxN +I_M = sp.eye(M) # MxM + +# E matrix propagates the solution of the steps to be the initial condition for the next step +E = sp.diags( + [ + -1.0, + ] + * (L - 1), + offsets=-1, +) # LxL + +""" +The H matrix computes the solution at the of an individual step from the solutions at the collocation nodes. +For the RADAU-RIGHT rule we use here, the right node coincides with the end of the interval, so this is simple. +We start with building the MxM matrix H_M on the node level and then extend to the spatial dimension with a Kronecker product. +""" +H_M = sp.eye(M).tolil() * 0 # MxM +H_M[:, -1] = 1 +H = sp.kron(H_M, I_N) # MNxMN + +""" +Set up collocation problem. +Note that the Kronecker product from Q and A is only possible when there is an A, i.e. when the problem is linear. +We will discuss non-linear problems in later steps in this tutorial +""" +Q = sweep.coll.Qmat[1:, 1:] # MxM +C_coll = I_MN - dt * sp.kron(Q, prob.A) # MNxMN + +# Set up the composite collocation / all-at-once problem +C = (sp.kron(I_L, C_coll) + sp.kron(E, H)).tocsc() # LMNxLMN + +""" +Now that we have the full composite collocation problem as one large matrix, we can just solve it directly to get a reference solution. +Of course, this is prohibitively expensive for any actual application and we would never want to do this in practice. +""" +sol_direct = sp.linalg.spsolve(C, u.flatten()).reshape(u.shape) + +""" +The normal time-stepping approach is to solve the composite collocation problem with forward substitution +""" +sol_stepping = u.copy() +for l in range(L): + """ + Solve the current step (sol_stepping[l] currently contains the initial conditions at step l) + Here, we only solve MNxMN systems rather than LMNxLMN systems. This is still really expensive in practice, which is why there is SDC, for example. + """ + sol_stepping[l, :] = sp.linalg.spsolve(C_coll, sol_stepping[l].flatten()).reshape(sol_stepping[l].shape) + + # place the solution to the current step as the initial conditions to the next step + if l < L - 1: + sol_stepping[l + 1, ...] = sol_stepping[l, -1, :] + +assert np.allclose(sol_stepping, sol_direct) + + +""" +So far, so serial and boring. We will now parallelize this using ParaDiag. +We will solve the composite collocation problem using preconditioned Picard iterations: + C_alpha delta = u_0 - Cu^k = < residual of the composite collocation problem > + u^{k+1} = u^k + delta +The trick behind ParaDiag is to choose the preconditioner C_alpha to be a time-periodic approximation to C that can be diagonalized and therefore inverted in parallel. +What we change in C_alpha compared to C is the E matrix that propagates the solutions between steps, which we amend to feed the solution to the last step back into the first step. +""" +E_alpha = sp.diags( + [ + -1.0, + ] + * (L - 1), + offsets=-1, +).tolil() # LxL +E_alpha[0, -1] = -alpha # make the problem time-periodic + +""" +In order to diagonalize C_alpha, on the step level, we need to diagonalize I_L and E_alpha simultaneously. +I_L and E_alpha are alpha-circular matrices which can be simultaneously diagonalized by a weighted Fourier transform. +We start by setting the weighting matrices for the Fourier transforms and then compute the diagonal entries of the diagonal version D_alpha of E_alpha. +We refrain from actually setting up the preconditioner because we will not use the expanded version here. +""" +gamma = alpha ** (-np.arange(L) / L) +J = sp.diags(gamma) # LxL +J_inv = sp.diags(1 / gamma) # LxL + +# compute diagonal entries via Fourier transform +D_alpha_diag_vals = np.fft.fft(1 / gamma * E_alpha[:, 0].toarray().flatten(), norm='backward') + + +""" +We need some convenience functions for computing matrix vector multiplication and the composite collocation problem residual here +""" + + +def mat_vec(mat, vec): + """ + Matrix vector product + + Args: + mat (np.ndarray or scipy.sparse) : Matrix + vec (np.ndarray) : vector + + Returns: + np.ndarray: mat @ vec + """ + res = np.zeros_like(vec).astype(complex) + for l in range(vec.shape[0]): + for k in range(vec.shape[0]): + res[l] += mat[l, k] * vec[k] + return res + + +def residual(_u, u0): + """ + Compute the residual of the composite collocation problem + + Args: + _u (np.ndarray): Current iterate + u0 (np.ndarray): Initial conditions + + Returns: + np.ndarray: LMN size array with the residual + """ + res = _u * 0j + for l in range(L): + # build step local residual + + # communicate initial conditions for each step + if l == 0: + res[l, ...] = u0[l, ...] + else: + res[l, ...] = _u[l - 1, -1, ...] + + # evaluate and subtract integral over right hand side functions + f_evals = np.array([prob.eval_f(_u[l, m], 0) for m in range(M)]) + Qf = mat_vec(Q, f_evals) + for m in range(M): + # res[l, m, ...] -= (_u[l] - dt * Qf)[-1] + res[l, m, ...] -= (_u[l] - dt * Qf)[m] + # res[l, m, ...] -= np.mean((_u[l] - dt * Qf), axis=0) + + return res + + +""" +We will start with ParaDiag where we parallelize across the L steps but solve the collocation problems directly in serial. +""" +sol_ParaDiag_L = u.copy() +u0 = u.copy() +niter_ParaDiag_L = 0 + +res = residual(sol_ParaDiag_L, u0) +while np.linalg.norm(res) > restol: + # compute weighted FFT in time to go to diagonal base of C_alpha + x = np.fft.fft( + mat_vec(J_inv.tolil(), res), + axis=0, + norm='ortho', + ) + + # solve the collocation problems in parallel on the steps + y = np.empty_like(x) + for l in range(L): + # construct local matrix of "collocation problem" + local_matrix = (D_alpha_diag_vals[l] * H + C_coll).tocsc() + + # solve local "collocation problem" directly + y[l, ...] = sp.linalg.spsolve(local_matrix, x[l, ...].flatten()).reshape(x[l, ...].shape) + + # compute inverse weighted FFT in time to go back from diagonal base of C_alpha + sol_ParaDiag_L += mat_vec(J.tolil(), np.fft.ifft(y, axis=0, norm='ortho')) + + # update residual + res = residual(sol_ParaDiag_L, u0) + niter_ParaDiag_L += 1 +my_print( + f'Needed {niter_ParaDiag_L} iterations in parallel across the steps ParaDiag. Stopped at residual {np.linalg.norm(res):.2e}' +) +assert np.allclose(sol_ParaDiag_L, sol_direct) + +""" +While we have distributed the work across L tasks, we are still solving perturbed collocation problems directly on a single task here. +This is very expensive, and we will now additionally diagonalize the quadrature matrix Q in order to distribute the work on LM tasks, where we solve NxN systems each. +We rearrange the contribution of E_alpha to arrive at a problem (I - dtQG^{-1}A)u = u0. +After diagonalizing QG^{-1}, we can simply utilize the Euler solves that are implemented in pySDC, but need to keep in mind that complex valued "step sizes" are required. + +We start by setting up the G and G^{-1} matrices. Then we will setup pySDC sweepers that solve QG^{-1} with diagonalization. +Here, we will not use the sweepers, but just the diagonalization computed there in order to make more clear what is going on. +""" +G = [(D_alpha_diag_vals[l] * H_M + I_M).tocsc() for l in range(L)] # MxM +G_inv = [sp.linalg.inv(_G).toarray() for _G in G] # MxM +sweepers = [QDiagonalization(params={**sweeper_params, 'G_inv': _G_inv}) for _G_inv in G_inv] + + +sol_ParaDiag = u.copy().astype(complex) +res = residual(sol_ParaDiag, u0) +niter = 0 +while np.max(np.abs(residual(sol_ParaDiag, u0))) > restol: + + # weighted FFT in time + x = np.fft.fft( + mat_vec(J_inv.tolil(), res), + axis=0, + norm='ortho', + ) + + # perform local solves of "collocation problems" on the steps in parallel + y = np.empty_like(x) + for l in range(L): + + # diagonalize QG^-1 matrix + w, S, S_inv = sweepers[l].w, sweepers[l].S, sweepers[l].S_inv + + # perform local solves on the collocation nodes in parallel + x1 = S_inv @ x[l] + x2 = np.empty_like(x1) + for m in range(M): + x2[m, :] = prob.solve_system(rhs=x1[m], factor=w[m] * dt, u0=x1[m], t=0) + z = S @ x2 + y[l, ...] = G_inv[l] @ z + + # inverse weighted FFT in time + sol_ParaDiag += mat_vec(J.tolil(), np.fft.ifft(y, axis=0, norm='ortho')) + + res = residual(sol_ParaDiag, u0) + niter += 1 +my_print( + f'Needed {niter} iterations in parallel and local paradiag with increment formulation, stopped at residual {np.linalg.norm(res):.2e}' +) +assert np.allclose(sol_ParaDiag, sol_direct) +assert np.allclose(niter, niter_ParaDiag_L) diff --git a/pySDC/tutorial/step_9/B_paradiag_for_nonlinear_problems.py b/pySDC/tutorial/step_9/B_paradiag_for_nonlinear_problems.py new file mode 100644 index 0000000000..0abcbd76ff --- /dev/null +++ b/pySDC/tutorial/step_9/B_paradiag_for_nonlinear_problems.py @@ -0,0 +1,205 @@ +""" +This script introduces ParaDiag for nonlinear problems with the van der Pol oscillator as an example. + +ParaDiag works by diagonalizing the "top layer" of Kronecker products that make up the circularized composite +collocation problem. +However, in nonlinear problems, the problem cannot be written as a matrix and therefore we cannot write the composite +collocation problem as a matrix. +There are two approaches for dealing with this. We can do IMEX splitting, where we treat only the linear part implicitly. +The ParaDiag preconditioner is then only made up of the linear implicit part and we can again write this as a matrix and +do the diagonalization just like for linear problems. The non-linear part then comes in via the residual on the right +hand side. +The second approach is to average Jacobians. The non-linear problems are solved with a Newton scheme, where the Jacobian +matrix is computed based on the current solution and then inverted in each Newton iteration. In order to write the +ParaDiag preconditioner as a matrix with Kronecker products and then only diagonalize the outermost part, we need to +have the same Jacobian on all steps. +The ParaDiag iteration then proceeds as follows: + - (1) Compute residual of composite collocation problem + - (2) Average the residual across the steps as preparation for computing the average Jacobian + Note that we still have values for each collocation node and space position. + - (3) Weighted FFT in time to diagonalize E_alpha + - (4) Solve for the increment by perform a single Newton iterations on the subproblems on the different steps and + nodes. The Jacobian is based on the averaged residual from (2) + - (5) Weighted iFFT in time + - (6) Increment solution +As IMEX ParaDiag is a trivial extension of ParaDiag for linear problems, we focus on the second approach here. +""" + +import numpy as np +import scipy.sparse as sp +import sys + +from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit as sweeper_class +from pySDC.implementations.problem_classes.Van_der_Pol_implicit import vanderpol + +# setup output +out_file = open('data/step_9_B_out.txt', 'w') + + +def my_print(*args, **kwargs): + for output in [sys.stdout, out_file]: + print(*args, **kwargs, file=output) + + +# setup parameters +L = 4 +M = 3 +alpha = 1e-4 +restol = 1e-8 +dt = 0.1 + +# setup infrastructure +prob = vanderpol(newton_maxiter=1, mu=1e0, crash_at_maxiter=False) +N = prob.init[0] + +# make problem work on complex data +prob.init = tuple([*prob.init[:2]] + [np.dtype('complex128')]) + +# setup global solution array +u = np.zeros((L, M, N), dtype=complex) + +# setup collocation problem +sweep = sweeper_class({'num_nodes': M, 'quad_type': 'RADAU-RIGHT'}) + +# initial conditions +u[0, :, :] = prob.u_exact(t=0) + +my_print( + f'Running ParaDiag test script for van der Pol with mu={prob.mu} and {L} time steps and {M} collocation nodes.' +) + + +""" +Setup matrices that make up the composite collocation problem. We do not set up the full composite collocation problem +here, however. See https://arxiv.org/abs/2103.12571 for the meaning of the matrices. +""" +I_M = sp.eye(M) + +H_M = sp.eye(M).tolil() * 0 +H_M[:, -1] = 1 + +Q = sweep.coll.Qmat[1:, 1:] + +E_alpha = sp.diags( + [ + -1.0, + ] + * (L - 1), + offsets=-1, +).tolil() +E_alpha[0, -1] = -alpha + +gamma = alpha ** (-np.arange(L) / L) +D_alpha_diag_vals = np.fft.fft(1 / gamma * E_alpha[:, 0].toarray().flatten(), norm='backward') + +J = sp.diags(gamma) +J_inv = sp.diags(1 / gamma) + +G = [(D_alpha_diag_vals[l] * H_M + I_M).tocsc() for l in range(L)] # MxM + +# prepare diagonalization of QG^{-1} +w = [] +S = [] +S_inv = [] + +for l in range(L): + # diagonalize QG^-1 matrix + if M > 1: + _w, _S = np.linalg.eig(Q @ sp.linalg.inv(G[l]).toarray()) + else: + _w, _S = np.linalg.eig(Q / (G[l].toarray())) + _S_inv = np.linalg.inv(_S) + w.append(_w) + S.append(_S) + S_inv.append(_S_inv) + +""" +Setup functions for computing matrix-vector productions on the steps and for computing the residual of the composite +collocation problem +""" + + +def mat_vec(mat, vec): + """ + Matrix vector product + + Args: + mat (np.ndarray or scipy.sparse) : Matrix + vec (np.ndarray) : vector + + Returns: + np.ndarray: mat @ vec + """ + res = np.zeros_like(vec) + for l in range(vec.shape[0]): + for k in range(vec.shape[0]): + res[l] += mat[l, k] * vec[k] + return res + + +def residual(_u, u0): + """ + Compute the residual of the composite collocation problem + + Args: + _u (np.ndarray): Current iterate + u0 (np.ndarray): Initial conditions + + Returns: + np.ndarray: LMN size array with the residual + """ + res = _u * 0j + for l in range(L): + # build step local residual + + # communicate initial conditions for each step + if l == 0: + res[l, ...] = u0[l, ...] + else: + res[l, ...] = _u[l - 1, -1, ...] + + # evaluate and subtract integral over right hand side functions + f_evals = np.array([prob.eval_f(_u[l, m], 0) for m in range(M)]) + Qf = mat_vec(Q, f_evals) + res[l, ...] -= _u[l] - dt * Qf + + return res + + +# do ParaDiag +sol_paradiag = u.copy() * 0j +u0 = u.copy() + +buf = prob.u_init +niter = 0 +res = residual(sol_paradiag, u0) +while np.max(np.abs(res)) > restol: + # compute all-at-once residual + res = residual(sol_paradiag, u0) + + # compute residual averaged across the L steps. This is the difference to ParaDiag for linear problems. + res_avg = np.mean(res, axis=0) + + # weighted FFT in time + x = np.fft.fft(mat_vec(J_inv.toarray(), res), axis=0) + + # perform local solves of "collocation problems" on the steps in parallel + y = np.empty_like(x) + for l in range(L): + + # perform local solves on the collocation nodes in parallel + x1 = S_inv[l] @ x[l] + x2 = np.empty_like(x1) + for m in range(M): + buf[:] = res_avg[m] # set up averaged Jacobian by using averaged residual as initial guess + x2[m, :] = prob.solve_system(x1[m], w[l][m] * dt, u0=buf, t=l * dt) + z = S[l] @ x2 + y[l, ...] = sp.linalg.spsolve(G[l], z) + + # inverse FFT in time and increment + sol_paradiag += mat_vec(J.toarray(), np.fft.ifft(y, axis=0)) + + res = residual(sol_paradiag, u0) + niter += 1 + assert niter < 99, 'ParaDiag did not converge for nonlinear problem!' +my_print(f'Needed {niter} ParaDiag iterations, stopped at residual {np.max(np.abs(res)):.2e}') diff --git a/pySDC/tutorial/step_9/C_paradiag_in_pySDC.py b/pySDC/tutorial/step_9/C_paradiag_in_pySDC.py new file mode 100644 index 0000000000..f01b060054 --- /dev/null +++ b/pySDC/tutorial/step_9/C_paradiag_in_pySDC.py @@ -0,0 +1,182 @@ +""" +This script shows how to setup ParaDiag in pySDC for two examples and compares performance to single-level PFASST in +Jacobi mode and serial time stepping. +In PFASST, we use a diagonal preconditioner, which allows for the same amount of parallelism as ParaDiag. +We show iteration counts per step here, but both schemes have further concurrency across the nodes. + +We have a linear advection example, discretized with finite differences, where ParaDiag converges in very few iterations. +PFASST, on the hand, needs a lot more iterations for this hyperbolic problem. +Note that we did not optimize either setup. With different choice of alpha in ParaDiag, or inexactness and coarsening +in PFASST, both schemes could be improved significantly. + +Second is the nonlinear van der Pol oscillator. We choose the mu parameter such that the problem is not overly stiff. +Here, ParaDiag needs many iterations compared to PFASST, but remember that we only perform one Newton iteration per +ParaDiag iteration. So per node, the number of Newton iterations is equal to the number of ParaDiag iterations. +In PFASST, on the other hand, we solve the systems to some accuracy and allow more iterations. Here, ParaDiag needs +fewer Newton iterations per step in total, leaving it with greater speedup. Again, inexactness could improve PFASST. + +This script is not meant to show that one parallelization scheme is better than the other. It does, however, demonstrate +that both schemes, without optimization, need fewer iterations per task than serial time stepping. Kindly refrain from +computing parallel efficiency for these examples, however. ;) +""" + +import numpy as np +import sys +from pySDC.helpers.stats_helper import get_sorted + +# prepare output +out_file = open('data/step_9_C_out.txt', 'w') + + +def my_print(*args, **kwargs): + for output in [sys.stdout, out_file]: + print(*args, **kwargs, file=output) + + +def get_description(problem='advection', mode='ParaDiag'): + level_params = {} + level_params['dt'] = 0.1 + level_params['restol'] = 1e-6 + + sweeper_params = {} + sweeper_params['quad_type'] = 'RADAU-RIGHT' + sweeper_params['num_nodes'] = 3 + sweeper_params['initial_guess'] = 'copy' + + if mode == 'ParaDiag': + from pySDC.implementations.sweeper_classes.ParaDiagSweepers import QDiagonalization as sweeper_class + + # we only want to use the averaged Jacobian and do only one Newton iteration per ParaDiag iteration! + newton_maxiter = 1 + else: + from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit as sweeper_class + + newton_maxiter = 99 + # need diagonal preconditioner for same concurrency as ParaDiag + sweeper_params['QI'] = 'MIN-SR-S' + + if problem == 'advection': + from pySDC.implementations.problem_classes.AdvectionEquation_ND_FD import advectionNd as problem_class + + problem_params = {'nvars': 64, 'order': 8, 'c': 1, 'solver_type': 'GMRES', 'lintol': 1e-8} + elif problem == 'vdp': + from pySDC.implementations.problem_classes.Van_der_Pol_implicit import vanderpol as problem_class + + # need to not raise an error when Newton has not converged because we do only one iteration + problem_params = {'newton_maxiter': newton_maxiter, 'crash_at_maxiter': False, 'mu': 1, 'newton_tol': 1e-9} + + step_params = {} + step_params['maxiter'] = 99 + + description = {} + description['problem_class'] = problem_class + description['problem_params'] = problem_params + description['sweeper_class'] = sweeper_class + description['sweeper_params'] = sweeper_params + description['level_params'] = level_params + description['step_params'] = step_params + + return description + + +def get_controller_params(problem='advection', mode='ParaDiag'): + from pySDC.implementations.hooks.log_errors import LogGlobalErrorPostRun + from pySDC.implementations.hooks.log_work import LogWork, LogSDCIterations + + controller_params = {} + controller_params['logger_level'] = 30 + controller_params['hook_class'] = [LogGlobalErrorPostRun, LogWork, LogSDCIterations] + + if mode == 'ParaDiag': + controller_params['alpha'] = 1e-4 + + # For nonlinear problems, we need to communicate the average solution, which allows to compute the average + # Jacobian locally. For linear problems, we do not want the extra communication. + if problem == 'advection': + controller_params['average_jacobians'] = False + elif problem == 'vdp': + controller_params['average_jacobians'] = True + else: + # We do Block-Jacobi multi-step SDC here. It's a bit silly but it's better for comparing "speedup" + controller_params['mssdc_jac'] = True + + return controller_params + + +def run_problem( + n_steps=4, + problem='advection', + mode='ParaDiag', +): + if mode == 'ParaDiag': + from pySDC.implementations.controller_classes.controller_ParaDiag_nonMPI import ( + controller_ParaDiag_nonMPI as controller_class, + ) + else: + from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI as controller_class + + if mode == 'serial': + num_procs = 1 + else: + num_procs = n_steps + + description = get_description(problem, mode) + controller_params = get_controller_params(problem, mode) + + controller = controller_class(num_procs=num_procs, description=description, controller_params=controller_params) + + for S in controller.MS: + S.levels[0].prob.init = tuple([*S.levels[0].prob.init[:2]] + [np.dtype('complex128')]) + + P = controller.MS[0].levels[0].prob + + t0 = 0.0 + uinit = P.u_exact(t0) + + uend, stats = controller.run(u0=uinit, t0=t0, Tend=n_steps * controller.MS[0].levels[0].dt) + return uend, stats + + +def compare_ParaDiag_and_PFASST(n_steps, problem): + my_print(f'Running {problem} with {n_steps} steps') + + uend_PD, stats_PD = run_problem(n_steps, problem, mode='ParaDiag') + uend_PF, stats_PF = run_problem(n_steps, problem, mode='PFASST') + uend_S, stats_S = run_problem(n_steps, problem, mode='serial') + + assert np.allclose(uend_PD, uend_PF) + assert np.allclose(uend_S, uend_PD) + assert ( + abs(uend_PD - uend_PF) > 0 + ) # two different iterative methods should not give identical results for non-zero tolerance + + k_PD = get_sorted(stats_PD, type='k') + k_PF = get_sorted(stats_PF, type='k') + + my_print( + f'Needed {max(me[1] for me in k_PD)} ParaDiag iterations and {max(me[1] for me in k_PF)} single-level PFASST iterations' + ) + if problem == 'advection': + k_GMRES_PD = get_sorted(stats_PD, type='work_GMRES') + k_GMRES_PF = get_sorted(stats_PF, type='work_GMRES') + k_GMRES_S = get_sorted(stats_S, type='work_GMRES') + my_print( + f'Maximum GMRES iterations on each step: {max(me[1] for me in k_GMRES_PD)} in ParaDiag, {max(me[1] for me in k_GMRES_PF)} in single-level PFASST and {sum(me[1] for me in k_GMRES_S)} total GMRES iterations in serial' + ) + elif problem == 'vdp': + k_Newton_PD = get_sorted(stats_PD, type='work_newton') + k_Newton_PF = get_sorted(stats_PF, type='work_newton') + k_Newton_S = get_sorted(stats_S, type='work_newton') + my_print( + f'Maximum Newton iterations on each step: {max(me[1] for me in k_Newton_PD)} in ParaDiag, {max(me[1] for me in k_Newton_PF)} in single-level PFASST and {sum(me[1] for me in k_Newton_S)} total Newton iterations in serial' + ) + my_print() + + +if __name__ == '__main__': + out_file = open('data/step_9_C_out.txt', 'w') + params = { + 'n_steps': 16, + } + compare_ParaDiag_and_PFASST(**params, problem='advection') + compare_ParaDiag_and_PFASST(**params, problem='vdp') From bdfa5c97405fc2f0167f85ad7c3e7271a4175389 Mon Sep 17 00:00:00 2001 From: Thomas Baumann <39156931+brownbaerchen@users.noreply.github.com> Date: Sat, 15 Feb 2025 14:28:37 +0000 Subject: [PATCH 2/9] Removed some old comments --- pySDC/core/controller.py | 1 - pySDC/implementations/sweeper_classes/ParaDiagSweepers.py | 1 - 2 files changed, 2 deletions(-) diff --git a/pySDC/core/controller.py b/pySDC/core/controller.py index fbc88057fe..b35338f217 100644 --- a/pySDC/core/controller.py +++ b/pySDC/core/controller.py @@ -358,7 +358,6 @@ def __init__(self, controller_params, description, n_steps, useMPI=None): n_steps (int): Number of parallel steps alpha (float): alpha parameter for ParaDiag """ - # TODO: where should I put alpha? When I want to adapt it, maybe it shouldn't be in the controller? from pySDC.implementations.sweeper_classes.ParaDiagSweepers import QDiagonalization if QDiagonalization in description['sweeper_class'].__mro__: diff --git a/pySDC/implementations/sweeper_classes/ParaDiagSweepers.py b/pySDC/implementations/sweeper_classes/ParaDiagSweepers.py index 01fd4c37ac..1e4df2ad6e 100644 --- a/pySDC/implementations/sweeper_classes/ParaDiagSweepers.py +++ b/pySDC/implementations/sweeper_classes/ParaDiagSweepers.py @@ -107,7 +107,6 @@ def update_nodes(self): x1 = self.mat_vec(self.S_inv, [self.level.u[0] for _ in range(M)]) x2 = [] for m in range(M): - # TODO: need to put averaged x1 in u0 here for nonlinear problems u0 = L.u_avg[m] if L.u_avg[m] is not None else x1[m] x2.append(P.solve_system(x1[m], self.w[m] * L.dt, u0=u0, t=L.time + L.dt * self.coll.nodes[m])) z = self.mat_vec(self.S, x2) From d313aef7cf9166ed11fad596797cd208d8a0e9b4 Mon Sep 17 00:00:00 2001 From: Thomas Baumann <39156931+brownbaerchen@users.noreply.github.com> Date: Mon, 17 Feb 2025 11:42:44 +0000 Subject: [PATCH 3/9] Added test that ParaDiag converges within the bounds for Dahlquist problem --- .../test_controller_ParaDiag_nonMPI.py | 42 ++++++++++++++++++- 1 file changed, 40 insertions(+), 2 deletions(-) diff --git a/pySDC/tests/test_controllers/test_controller_ParaDiag_nonMPI.py b/pySDC/tests/test_controllers/test_controller_ParaDiag_nonMPI.py index 5ade344032..31861df4ee 100644 --- a/pySDC/tests/test_controllers/test_controller_ParaDiag_nonMPI.py +++ b/pySDC/tests/test_controllers/test_controller_ParaDiag_nonMPI.py @@ -6,6 +6,7 @@ def get_composite_collocation_problem(L, M, N, alpha=0, dt=1e-1, problem='Dahlqu from pySDC.implementations.hooks.log_errors import ( LogGlobalErrorPostRun, LogGlobalErrorPostStep, + LogGlobalErrorPostIter, ) if ParaDiag: @@ -74,7 +75,7 @@ def get_composite_collocation_problem(L, M, N, alpha=0, dt=1e-1, problem='Dahlqu controller_params = {} controller_params['logger_level'] = 15 - controller_params['hook_class'] = [LogGlobalErrorPostRun, LogGlobalErrorPostStep] + controller_params['hook_class'] = [LogGlobalErrorPostRun, LogGlobalErrorPostStep, LogGlobalErrorPostIter] controller_params['mssdc_jac'] = False controller_params['alpha'] = alpha controller_params['average_jacobian'] = average_jacobian @@ -223,8 +224,45 @@ def test_ParaDiag_order(L, M, N, alpha): ), f'Got unexpected numerical order {num_order} instead of {expected_order} in ParaDiag' +@pytest.mark.base +@pytest.mark.parametrize('L', [4, 12]) +@pytest.mark.parametrize('M', [2, 3]) +@pytest.mark.parametrize('N', [1]) +@pytest.mark.parametrize('alpha', [1e-4, 1e-2]) +def test_ParaDiag_convergence_rate(L, M, N, alpha): + r""" + Test that the error in ParaDiag contracts as fast as expected. + + The upper bound is \|u^{k+1} - u^*\| / \|u^k - u^*\| < \alpha / (1-\alpha). + Here, we compare to the exact solution to the continuous problem rather than the exact solution of the collocation + problem, which means the error stalls at time-discretization level. Therefore, we only check the contraction in the + first ParaDiag iteration. + """ + import numpy as np + from pySDC.helpers.stats_helper import get_sorted + + dt = 1e-2 + controller, prob = get_composite_collocation_problem(L, M, N, alpha, dt=dt, problem='Dahlquist') + + # setup initial conditions + u0 = prob.u_exact(0) + + uend, stats = controller.run(u0=u0, t0=0, Tend=L * dt) + + # test that the convergence rate in the first iteration is sufficiently small. + errors = get_sorted(stats, type='e_global_post_iteration', sortby='iter', time=(L - 1) * dt) + convergence_rates = [errors[i + 1][1] / errors[i][1] for i in range(len(errors) - 1)] + convergence_rate = convergence_rates[0] + convergence_bound = alpha / (1 - alpha) + + assert ( + convergence_rate < convergence_bound + ), f'Convergence rate {convergence_rate} exceeds upper bound of {convergence_bound}!' + + if __name__ == '__main__': - test_ParaDiag_vs_PFASST(4, 3, 2, 'Dahlquist') + test_ParaDiag_convergence_rate(4, 3, 1, 1e-4) + # test_ParaDiag_vs_PFASST(4, 3, 2, 'Dahlquist') # test_ParaDiag_convergence(4, 3, 1, 1e-4, 'vdp') # test_IMEX_ParaDiag_convergence(4, 3, 64, 1e-4) # test_ParaDiag_order(3, 3, 1, 1e-4) From 8254c831412ecb3643ef9b7d1736f6c86e217f69 Mon Sep 17 00:00:00 2001 From: Thomas Baumann <39156931+brownbaerchen@users.noreply.github.com> Date: Mon, 17 Feb 2025 12:14:44 +0000 Subject: [PATCH 4/9] Fix test on older Python versions --- .../tests/test_controllers/test_controller_ParaDiag_nonMPI.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pySDC/tests/test_controllers/test_controller_ParaDiag_nonMPI.py b/pySDC/tests/test_controllers/test_controller_ParaDiag_nonMPI.py index 31861df4ee..c0c5304899 100644 --- a/pySDC/tests/test_controllers/test_controller_ParaDiag_nonMPI.py +++ b/pySDC/tests/test_controllers/test_controller_ParaDiag_nonMPI.py @@ -250,7 +250,8 @@ def test_ParaDiag_convergence_rate(L, M, N, alpha): uend, stats = controller.run(u0=u0, t0=0, Tend=L * dt) # test that the convergence rate in the first iteration is sufficiently small. - errors = get_sorted(stats, type='e_global_post_iteration', sortby='iter', time=(L - 1) * dt) + t_last = max([me[0] for me in get_sorted(stats, type='e_global_post_iteration')]) + errors = get_sorted(stats, type='e_global_post_iteration', sortby='iter', time=t_last) convergence_rates = [errors[i + 1][1] / errors[i][1] for i in range(len(errors) - 1)] convergence_rate = convergence_rates[0] convergence_bound = alpha / (1 - alpha) From 9cddb175eef19389a69054b02ed09f3a1ee29f19 Mon Sep 17 00:00:00 2001 From: Thomas Baumann <39156931+brownbaerchen@users.noreply.github.com> Date: Tue, 18 Feb 2025 09:28:08 +0000 Subject: [PATCH 5/9] Added smaller values of alpha to the tests --- .../test_controller_ParaDiag_nonMPI.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/pySDC/tests/test_controllers/test_controller_ParaDiag_nonMPI.py b/pySDC/tests/test_controllers/test_controller_ParaDiag_nonMPI.py index c0c5304899..a2e8c7894c 100644 --- a/pySDC/tests/test_controllers/test_controller_ParaDiag_nonMPI.py +++ b/pySDC/tests/test_controllers/test_controller_ParaDiag_nonMPI.py @@ -105,7 +105,7 @@ def get_composite_collocation_problem(L, M, N, alpha=0, dt=1e-1, problem='Dahlqu @pytest.mark.parametrize('L', [1, 4]) @pytest.mark.parametrize('M', [2, 3]) @pytest.mark.parametrize('N', [2]) -@pytest.mark.parametrize('alpha', [1e-4, 1e-2]) +@pytest.mark.parametrize('alpha', [1e-8, 1e-4, 1e-2]) @pytest.mark.parametrize('problem', ['Dahlquist', 'Dahlquist_IMEX', 'vdp']) def test_ParaDiag_convergence(L, M, N, alpha, problem): from pySDC.helpers.stats_helper import get_sorted @@ -129,8 +129,8 @@ def test_ParaDiag_convergence(L, M, N, alpha, problem): @pytest.mark.base @pytest.mark.parametrize('L', [1, 4]) @pytest.mark.parametrize('M', [2, 3]) -@pytest.mark.parametrize('N', [64]) -@pytest.mark.parametrize('alpha', [1e-4, 1e-2]) +@pytest.mark.parametrize('N', [32]) +@pytest.mark.parametrize('alpha', [1e-8, 1e-4, 1e-2]) def test_IMEX_ParaDiag_convergence(L, M, N, alpha): from pySDC.helpers.stats_helper import get_sorted @@ -185,14 +185,14 @@ def test_ParaDiag_vs_PFASST(L, M, N, problem): @pytest.mark.parametrize('L', [4]) @pytest.mark.parametrize('M', [2, 3]) @pytest.mark.parametrize('N', [1]) -@pytest.mark.parametrize('alpha', [1e-4, 1e-2]) +@pytest.mark.parametrize('alpha', [1e-6, 1e-4]) def test_ParaDiag_order(L, M, N, alpha): import numpy as np from pySDC.helpers.stats_helper import get_sorted errors = [] if M == 3: - dts = [0.8 * 2 ** (-x) for x in range(7, 9)] + dts = [2 ** (-x) for x in range(9, 11)] elif M == 2: dts = [2 ** (-x) for x in range(5, 9)] else: @@ -217,18 +217,18 @@ def test_ParaDiag_order(L, M, N, alpha): errors = np.array(errors) dts = np.array(dts) order = np.log(abs(errors[1:] - errors[:-1])) / np.log(abs(dts[1:] - dts[:-1])) - num_order = np.mean(order) + num_order = np.median(order) assert ( expected_order + 1 > num_order > expected_order - ), f'Got unexpected numerical order {num_order} instead of {expected_order} in ParaDiag' + ), f'Got unexpected numerical order {num_order:2f} instead of {expected_order} in ParaDiag {order} {errors}' @pytest.mark.base @pytest.mark.parametrize('L', [4, 12]) -@pytest.mark.parametrize('M', [2, 3]) +@pytest.mark.parametrize('M', [3, 4]) @pytest.mark.parametrize('N', [1]) -@pytest.mark.parametrize('alpha', [1e-4, 1e-2]) +@pytest.mark.parametrize('alpha', [1e-6, 1e-4, 1e-2]) def test_ParaDiag_convergence_rate(L, M, N, alpha): r""" Test that the error in ParaDiag contracts as fast as expected. @@ -258,7 +258,7 @@ def test_ParaDiag_convergence_rate(L, M, N, alpha): assert ( convergence_rate < convergence_bound - ), f'Convergence rate {convergence_rate} exceeds upper bound of {convergence_bound}!' + ), f'Convergence rate {convergence_rate:.2e} exceeds upper bound of {convergence_bound:.2e}!' if __name__ == '__main__': From d316cde1af8681a1fd6672faef10acfd0671a235 Mon Sep 17 00:00:00 2001 From: Thomas Baumann <39156931+brownbaerchen@users.noreply.github.com> Date: Thu, 27 Feb 2025 12:15:31 +0000 Subject: [PATCH 6/9] Implemented separate Jacobian inversion to fix bug pointed out by @JHopeCollins --- pySDC/core/problem.py | 32 +++++++++++++++++++ .../controller_ParaDiag_nonMPI.py | 6 ++-- .../problem_classes/Van_der_Pol_implicit.py | 22 +++++++++---- .../sweeper_classes/ParaDiagSweepers.py | 10 ++++-- .../B_paradiag_for_nonlinear_problems.py | 16 ++++------ pySDC/tutorial/step_9/C_paradiag_in_pySDC.py | 12 +++---- 6 files changed, 70 insertions(+), 28 deletions(-) diff --git a/pySDC/core/problem.py b/pySDC/core/problem.py index cdfabeb897..146ecc8cef 100644 --- a/pySDC/core/problem.py +++ b/pySDC/core/problem.py @@ -164,3 +164,35 @@ def plot(self, u, t=None, fig=None): None """ raise NotImplementedError + + def solve_system(self, rhs, dt, u0, t): + """ + Perform an Euler step. + + Args: + rhs: Right hand side for the Euler step + dt (float): Step size for the Euler step + u0: Initial guess + t (float): Current time + + Returns: + solution to the Euler step + """ + raise NotImplementedError + + def solve_jacobian(self, rhs, dt, u=None, u0=None, t=0, **kwargs): + """ + Solve the Jacobian for an Euler step, linearized around u. + This defaults to an Euler step to accommodate linear problems. + + Args: + rhs: Right hand side for the Euler step + dt (float): Step size for the Euler step + u: Solution to linearize around + u0: Initial guess + t (float): Current time + + Returns: + Solution + """ + return self.solve_system(rhs, dt, u0=u, t=t, **kwargs) diff --git a/pySDC/implementations/controller_classes/controller_ParaDiag_nonMPI.py b/pySDC/implementations/controller_classes/controller_ParaDiag_nonMPI.py index e5c5deca3a..efbf161432 100644 --- a/pySDC/implementations/controller_classes/controller_ParaDiag_nonMPI.py +++ b/pySDC/implementations/controller_classes/controller_ParaDiag_nonMPI.py @@ -213,12 +213,12 @@ def it_ParaDiag(self, local_MS_running): for hook in self.hooks: hook.pre_sweep(step=S, level_number=0) - # replace the values stored in the steps with the residuals in order to compute the increment - self.swap_solution_for_all_at_once_residual(local_MS_running) - # communicate average residual for setting up Jacobians for non-linear problems self.prepare_Jacobians(local_MS_running) + # replace the values stored in the steps with the residuals in order to compute the increment + self.swap_solution_for_all_at_once_residual(local_MS_running) + # weighted FFT in time self.FFT_in_time() diff --git a/pySDC/implementations/problem_classes/Van_der_Pol_implicit.py b/pySDC/implementations/problem_classes/Van_der_Pol_implicit.py index 5cbc931315..130891d530 100755 --- a/pySDC/implementations/problem_classes/Van_der_Pol_implicit.py +++ b/pySDC/implementations/problem_classes/Van_der_Pol_implicit.py @@ -69,6 +69,7 @@ def __init__( localVars=locals(), ) self.work_counters['newton'] = WorkCounter() + self.work_counters['jacobian_solves'] = WorkCounter() self.work_counters['rhs'] = WorkCounter() def u_exact(self, t, u_init=None, t_init=None): @@ -167,13 +168,7 @@ def solve_system(self, rhs, dt, u0, t): if res < self.newton_tol or np.isnan(res): break - # prefactor for dg/du - c = 1.0 / (-2 * dt**2 * mu * x1 * x2 - dt**2 - 1 + dt * mu * (1 - x1**2)) - # assemble dg/du - dg = c * np.array([[dt * mu * (1 - x1**2) - 1, -dt], [2 * dt * mu * x1 * x2 + dt, -1]]) - - # newton update: u1 = u0 - g/dg - u -= np.dot(dg, g) + u -= self.solve_jacobian(g, dt, u) # set new values and increase iteration count x1 = u[0] @@ -191,3 +186,16 @@ def solve_system(self, rhs, dt, u0, t): raise ProblemError('Newton did not converge after %i iterations, error is %s' % (n, res)) return u + + def solve_jacobian(self, rhs, dt, u, **kwargs): + mu = self.mu + u1 = u[0] + u2 = u[1] + + # assemble prefactor + c = 1.0 / (-2 * dt**2 * mu * u1 * u2 - dt**2 - 1 + dt * mu * (1 - u1**2)) + # assemble (dg/du)^-1 + dg = c * np.array([[dt * mu * (1 - u1**2) - 1, -dt], [2 * dt * mu * u1 * u2 + dt, -1]]) + + self.work_counters['jacobian_solves']() + return np.dot(dg, rhs) diff --git a/pySDC/implementations/sweeper_classes/ParaDiagSweepers.py b/pySDC/implementations/sweeper_classes/ParaDiagSweepers.py index 1e4df2ad6e..ca66089259 100644 --- a/pySDC/implementations/sweeper_classes/ParaDiagSweepers.py +++ b/pySDC/implementations/sweeper_classes/ParaDiagSweepers.py @@ -105,10 +105,16 @@ def update_nodes(self): x1 = self.mat_vec(self.S_inv, [self.level.u[m + 1] for m in range(M)]) else: x1 = self.mat_vec(self.S_inv, [self.level.u[0] for _ in range(M)]) + + # get averaged state over all nodes for constructing the Jacobian + u_avg = P.u_init + if not any(me is None for me in L.u_avg): + for m in range(M): + u_avg += L.u_avg[m] / M + x2 = [] for m in range(M): - u0 = L.u_avg[m] if L.u_avg[m] is not None else x1[m] - x2.append(P.solve_system(x1[m], self.w[m] * L.dt, u0=u0, t=L.time + L.dt * self.coll.nodes[m])) + x2.append(P.solve_jacobian(x1[m], self.w[m] * L.dt, u=u_avg, t=L.time + L.dt * self.coll.nodes[m])) z = self.mat_vec(self.S, x2) y = self.mat_vec(self.params.G_inv, z) diff --git a/pySDC/tutorial/step_9/B_paradiag_for_nonlinear_problems.py b/pySDC/tutorial/step_9/B_paradiag_for_nonlinear_problems.py index 0abcbd76ff..120262e5a0 100644 --- a/pySDC/tutorial/step_9/B_paradiag_for_nonlinear_problems.py +++ b/pySDC/tutorial/step_9/B_paradiag_for_nonlinear_problems.py @@ -15,11 +15,10 @@ have the same Jacobian on all steps. The ParaDiag iteration then proceeds as follows: - (1) Compute residual of composite collocation problem - - (2) Average the residual across the steps as preparation for computing the average Jacobian - Note that we still have values for each collocation node and space position. + - (2) Average the solution across the steps and nodes as preparation for computing the average Jacobian - (3) Weighted FFT in time to diagonalize E_alpha - - (4) Solve for the increment by perform a single Newton iterations on the subproblems on the different steps and - nodes. The Jacobian is based on the averaged residual from (2) + - (4) Solve for the increment by inverting the averaged Jacobian from (2) on the subproblems on the different steps + and nodes. - (5) Weighted iFFT in time - (6) Increment solution As IMEX ParaDiag is a trivial extension of ParaDiag for linear problems, we focus on the second approach here. @@ -170,15 +169,15 @@ def residual(_u, u0): sol_paradiag = u.copy() * 0j u0 = u.copy() -buf = prob.u_init niter = 0 res = residual(sol_paradiag, u0) while np.max(np.abs(res)) > restol: # compute all-at-once residual res = residual(sol_paradiag, u0) - # compute residual averaged across the L steps. This is the difference to ParaDiag for linear problems. - res_avg = np.mean(res, axis=0) + # compute residual averaged across the L steps and M nodes. This is the difference to ParaDiag for linear problems. + u_avg = prob.u_init + u_avg[:] = np.mean(sol_paradiag, axis=(0, 1)) # weighted FFT in time x = np.fft.fft(mat_vec(J_inv.toarray(), res), axis=0) @@ -191,8 +190,7 @@ def residual(_u, u0): x1 = S_inv[l] @ x[l] x2 = np.empty_like(x1) for m in range(M): - buf[:] = res_avg[m] # set up averaged Jacobian by using averaged residual as initial guess - x2[m, :] = prob.solve_system(x1[m], w[l][m] * dt, u0=buf, t=l * dt) + x2[m, :] = prob.solve_jacobian(x1[m], w[l][m] * dt, u=u_avg, t=l * dt) z = S[l] @ x2 y[l, ...] = sp.linalg.spsolve(G[l], z) diff --git a/pySDC/tutorial/step_9/C_paradiag_in_pySDC.py b/pySDC/tutorial/step_9/C_paradiag_in_pySDC.py index f01b060054..74c0319b54 100644 --- a/pySDC/tutorial/step_9/C_paradiag_in_pySDC.py +++ b/pySDC/tutorial/step_9/C_paradiag_in_pySDC.py @@ -47,11 +47,9 @@ def get_description(problem='advection', mode='ParaDiag'): from pySDC.implementations.sweeper_classes.ParaDiagSweepers import QDiagonalization as sweeper_class # we only want to use the averaged Jacobian and do only one Newton iteration per ParaDiag iteration! - newton_maxiter = 1 else: from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit as sweeper_class - newton_maxiter = 99 # need diagonal preconditioner for same concurrency as ParaDiag sweeper_params['QI'] = 'MIN-SR-S' @@ -63,7 +61,7 @@ def get_description(problem='advection', mode='ParaDiag'): from pySDC.implementations.problem_classes.Van_der_Pol_implicit import vanderpol as problem_class # need to not raise an error when Newton has not converged because we do only one iteration - problem_params = {'newton_maxiter': newton_maxiter, 'crash_at_maxiter': False, 'mu': 1, 'newton_tol': 1e-9} + problem_params = {'newton_maxiter': 99, 'crash_at_maxiter': False, 'mu': 1, 'newton_tol': 1e-9} step_params = {} step_params['maxiter'] = 99 @@ -164,11 +162,11 @@ def compare_ParaDiag_and_PFASST(n_steps, problem): f'Maximum GMRES iterations on each step: {max(me[1] for me in k_GMRES_PD)} in ParaDiag, {max(me[1] for me in k_GMRES_PF)} in single-level PFASST and {sum(me[1] for me in k_GMRES_S)} total GMRES iterations in serial' ) elif problem == 'vdp': - k_Newton_PD = get_sorted(stats_PD, type='work_newton') - k_Newton_PF = get_sorted(stats_PF, type='work_newton') - k_Newton_S = get_sorted(stats_S, type='work_newton') + k_Jac_PD = get_sorted(stats_PD, type='work_jacobian_solves') + k_Jac_PF = get_sorted(stats_PF, type='work_jacobian_solves') + k_Jac_S = get_sorted(stats_S, type='work_jacobian_solves') my_print( - f'Maximum Newton iterations on each step: {max(me[1] for me in k_Newton_PD)} in ParaDiag, {max(me[1] for me in k_Newton_PF)} in single-level PFASST and {sum(me[1] for me in k_Newton_S)} total Newton iterations in serial' + f'Maximum Jacabian solves on each step: {max(me[1] for me in k_Jac_PD)} in ParaDiag, {max(me[1] for me in k_Jac_PF)} in single-level PFASST and {sum(me[1] for me in k_Jac_S)} total Jacobian solves in serial' ) my_print() From 3abdb6ae35fc6fe89ce1ea819fe3e81233cb8e11 Mon Sep 17 00:00:00 2001 From: Thomas Baumann <39156931+brownbaerchen@users.noreply.github.com> Date: Thu, 27 Feb 2025 13:59:44 +0000 Subject: [PATCH 7/9] Update pySDC/tutorial/step_9/B_paradiag_for_nonlinear_problems.py Co-authored-by: Josh Hope-Collins --- pySDC/tutorial/step_9/B_paradiag_for_nonlinear_problems.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pySDC/tutorial/step_9/B_paradiag_for_nonlinear_problems.py b/pySDC/tutorial/step_9/B_paradiag_for_nonlinear_problems.py index 120262e5a0..afb8fda139 100644 --- a/pySDC/tutorial/step_9/B_paradiag_for_nonlinear_problems.py +++ b/pySDC/tutorial/step_9/B_paradiag_for_nonlinear_problems.py @@ -175,7 +175,7 @@ def residual(_u, u0): # compute all-at-once residual res = residual(sol_paradiag, u0) - # compute residual averaged across the L steps and M nodes. This is the difference to ParaDiag for linear problems. + # compute solution averaged across the L steps and M nodes. This is the difference to ParaDiag for linear problems. u_avg = prob.u_init u_avg[:] = np.mean(sol_paradiag, axis=(0, 1)) From eb7cfe33dbbb88af90591763c9988967334129a2 Mon Sep 17 00:00:00 2001 From: Thomas Baumann <39156931+brownbaerchen@users.noreply.github.com> Date: Thu, 27 Feb 2025 17:33:41 +0000 Subject: [PATCH 8/9] Keep spatially extended residual in the levels after computation --- pySDC/core/level.py | 1 + pySDC/core/sweeper.py | 8 ++++---- .../controller_classes/controller_ParaDiag_nonMPI.py | 5 ++--- pySDC/implementations/sweeper_classes/ParaDiagSweepers.py | 5 ++++- .../tutorial/step_9/B_paradiag_for_nonlinear_problems.py | 1 - 5 files changed, 11 insertions(+), 9 deletions(-) diff --git a/pySDC/core/level.py b/pySDC/core/level.py index 76c415af54..a4ebbf191e 100644 --- a/pySDC/core/level.py +++ b/pySDC/core/level.py @@ -83,6 +83,7 @@ def __init__(self, problem_class, problem_params, sweeper_class, sweeper_params, self.u = [None] * (self.sweep.coll.num_nodes + 1) self.uold = [None] * (self.sweep.coll.num_nodes + 1) self.u_avg = [None] * self.sweep.coll.num_nodes + self.residual = [None] * self.sweep.coll.num_nodes self.f = [None] * (self.sweep.coll.num_nodes + 1) self.fold = [None] * (self.sweep.coll.num_nodes + 1) diff --git a/pySDC/core/sweeper.py b/pySDC/core/sweeper.py index ff8fe29c8c..ee0179b287 100644 --- a/pySDC/core/sweeper.py +++ b/pySDC/core/sweeper.py @@ -192,14 +192,14 @@ def compute_residual(self, stage=''): # build QF(u) res_norm = [] - res = self.integrate() + L.residual = self.integrate() for m in range(self.coll.num_nodes): - res[m] += L.u[0] - L.u[m + 1] + L.residual[m] += L.u[0] - L.u[m + 1] # add tau if associated if L.tau[m] is not None: - res[m] += L.tau[m] + L.residual[m] += L.tau[m] # use abs function from data type here - res_norm.append(abs(res[m])) + res_norm.append(abs(L.residual[m])) # find maximal residual over the nodes if L.params.residual_type == 'full_abs': diff --git a/pySDC/implementations/controller_classes/controller_ParaDiag_nonMPI.py b/pySDC/implementations/controller_classes/controller_ParaDiag_nonMPI.py index efbf161432..a4f517285b 100644 --- a/pySDC/implementations/controller_classes/controller_ParaDiag_nonMPI.py +++ b/pySDC/implementations/controller_classes/controller_ParaDiag_nonMPI.py @@ -152,12 +152,11 @@ def swap_solution_for_all_at_once_residual(self, local_MS_running): hook.post_comm(step=S, level_number=0, add_to_stats=True) # compute residuals locally - residual = S.levels[0].sweep.get_residual() - S.levels[0].status.residual = max(abs(me) for me in residual) + S.levels[0].sweep.compute_residual() # put residual in the solution variables for m in range(S.levels[0].sweep.coll.num_nodes): - S.levels[0].u[m + 1] = residual[m] + S.levels[0].u[m + 1] = S.levels[0].residual[m] def swap_increment_for_solution(self, local_MS_running): """ diff --git a/pySDC/implementations/sweeper_classes/ParaDiagSweepers.py b/pySDC/implementations/sweeper_classes/ParaDiagSweepers.py index ca66089259..5f3251ed88 100644 --- a/pySDC/implementations/sweeper_classes/ParaDiagSweepers.py +++ b/pySDC/implementations/sweeper_classes/ParaDiagSweepers.py @@ -122,7 +122,6 @@ def update_nodes(self): for m in range(M): L.u[m + 1] = y[m] if self.params.update_f_evals: - raise L.f[m + 1] = P.eval_f(L.u[m + 1], L.time + L.dt * self.coll.nodes[m]) L.status.updated = True @@ -154,6 +153,10 @@ def get_residual(self): return residual + def compute_residual(self, *args, **kwargs): + self.eval_f_at_all_nodes() + return super().compute_residual(*args, **kwargs) + class QDiagonalizationIMEX(QDiagonalization): """ diff --git a/pySDC/tutorial/step_9/B_paradiag_for_nonlinear_problems.py b/pySDC/tutorial/step_9/B_paradiag_for_nonlinear_problems.py index afb8fda139..70a24a8a02 100644 --- a/pySDC/tutorial/step_9/B_paradiag_for_nonlinear_problems.py +++ b/pySDC/tutorial/step_9/B_paradiag_for_nonlinear_problems.py @@ -168,7 +168,6 @@ def residual(_u, u0): # do ParaDiag sol_paradiag = u.copy() * 0j u0 = u.copy() - niter = 0 res = residual(sol_paradiag, u0) while np.max(np.abs(res)) > restol: From e4816ff431fca619efa63e22849b52d95d4959ae Mon Sep 17 00:00:00 2001 From: Thomas Baumann <39156931+brownbaerchen@users.noreply.github.com> Date: Fri, 28 Feb 2025 16:38:08 +0000 Subject: [PATCH 9/9] Refactored increment formulation --- pySDC/core/controller.py | 11 ++-- pySDC/core/level.py | 1 + .../controller_ParaDiag_nonMPI.py | 66 ++++++++----------- .../sweeper_classes/ParaDiagSweepers.py | 7 +- .../test_controller_ParaDiag_nonMPI.py | 37 ++++++++++- .../test_sweepers/test_ParaDiag_sweepers.py | 6 +- pySDC/tutorial/step_9/C_paradiag_in_pySDC.py | 2 +- 7 files changed, 80 insertions(+), 50 deletions(-) diff --git a/pySDC/core/controller.py b/pySDC/core/controller.py index b35338f217..b06d58e360 100644 --- a/pySDC/core/controller.py +++ b/pySDC/core/controller.py @@ -6,7 +6,6 @@ from pySDC.core.base_transfer import BaseTransfer from pySDC.helpers.pysdc_helper import FrozenClass from pySDC.implementations.convergence_controller_classes.check_convergence import CheckConvergence -from pySDC.implementations.convergence_controller_classes.store_uold import StoreUOld from pySDC.implementations.hooks.default_hook import DefaultHooks from pySDC.implementations.hooks.log_timings import CPUTimings @@ -378,12 +377,10 @@ def __init__(self, controller_params, description, n_steps, useMPI=None): controller_params['all_to_done'] = True super().__init__(controller_params=controller_params, description=description, useMPI=useMPI) - self.base_convergence_controllers += [StoreUOld] - self.ParaDiag_block_u0 = None self.n_steps = n_steps - def FFT_in_time(self): + def FFT_in_time(self, quantity): """ Compute weighted forward FFT in time. The weighting is determined by the alpha parameter in ParaDiag @@ -395,9 +392,9 @@ def FFT_in_time(self): self.__FFT_matrix = get_weighted_FFT_matrix(self.n_steps, self.params.alpha) - self.apply_matrix(self.__FFT_matrix) + self.apply_matrix(self.__FFT_matrix, quantity) - def iFFT_in_time(self): + def iFFT_in_time(self, quantity): """ Compute weighted backward FFT in time. The weighting is determined by the alpha parameter in ParaDiag """ @@ -406,4 +403,4 @@ def iFFT_in_time(self): self.__iFFT_matrix = get_weighted_iFFT_matrix(self.n_steps, self.params.alpha) - self.apply_matrix(self.__iFFT_matrix) + self.apply_matrix(self.__iFFT_matrix, quantity) diff --git a/pySDC/core/level.py b/pySDC/core/level.py index a4ebbf191e..84217d1faf 100644 --- a/pySDC/core/level.py +++ b/pySDC/core/level.py @@ -84,6 +84,7 @@ def __init__(self, problem_class, problem_params, sweeper_class, sweeper_params, self.uold = [None] * (self.sweep.coll.num_nodes + 1) self.u_avg = [None] * self.sweep.coll.num_nodes self.residual = [None] * self.sweep.coll.num_nodes + self.increment = [None] * self.sweep.coll.num_nodes self.f = [None] * (self.sweep.coll.num_nodes + 1) self.fold = [None] * (self.sweep.coll.num_nodes + 1) diff --git a/pySDC/implementations/controller_classes/controller_ParaDiag_nonMPI.py b/pySDC/implementations/controller_classes/controller_ParaDiag_nonMPI.py index a4f517285b..000ba846e8 100644 --- a/pySDC/implementations/controller_classes/controller_ParaDiag_nonMPI.py +++ b/pySDC/implementations/controller_classes/controller_ParaDiag_nonMPI.py @@ -1,11 +1,9 @@ import itertools -import copy as cp import numpy as np -import dill from pySDC.core.controller import ParaDiagController from pySDC.core import step as stepclass -from pySDC.core.errors import ControllerError, CommunicationError +from pySDC.core.errors import ControllerError from pySDC.implementations.convergence_controller_classes.basic_restarting import BasicRestarting from pySDC.helpers.ParaDiagHelper import get_G_inv_matrix @@ -92,7 +90,7 @@ def ParaDiag(self, local_MS_active): return all(S.status.done for S in local_MS_active) - def apply_matrix(self, mat): + def apply_matrix(self, mat, quantity): """ Apply a matrix on the step level. Needs to be square. Puts the result back into the controller. @@ -112,29 +110,33 @@ def apply_matrix(self, mat): None, ] * L + if quantity == 'residual': + me = [S.levels[0].residual for S in self.MS] + elif quantity == 'increment': + me = [S.levels[0].increment for S in self.MS] + else: + raise NotImplementedError + # compute matrix-vector product for i in range(mat.shape[0]): - res[i] = [prob.u_init for _ in range(M + 1)] + res[i] = [prob.u_init for _ in range(M)] for j in range(mat.shape[1]): - for m in range(M + 1): - res[i][m] += mat[i, j] * self.MS[j].levels[0].u[m] + for m in range(M): + res[i][m] += mat[i, j] * me[j][m] # put the result in the "output" for i in range(mat.shape[0]): - for m in range(M + 1): - self.MS[i].levels[0].u[m] = res[i][m] + for m in range(M): + me[i][m] = res[i][m] - def swap_solution_for_all_at_once_residual(self, local_MS_running): + def compute_all_at_once_residual(self, local_MS_running): """ - Replace the solution values in the steps with the all-at-once residual. - This requires to communicate the solutions at the end of the steps to be the initial conditions for the next steps. Afterwards, the residual can be computed locally on the steps. Args: local_MS_running (list): list of currently running steps """ - prob = self.MS[0].levels[0].prob for S in local_MS_running: # communicate initial conditions @@ -143,9 +145,7 @@ def swap_solution_for_all_at_once_residual(self, local_MS_running): for hook in self.hooks: hook.pre_comm(step=S, level_number=0) - if S.status.first: - S.levels[0].u[0] = prob.dtype_u(self.ParaDiag_block_u0) - else: + if not S.status.first: S.levels[0].u[0] = S.prev.levels[0].uend for hook in self.hooks: @@ -154,25 +154,16 @@ def swap_solution_for_all_at_once_residual(self, local_MS_running): # compute residuals locally S.levels[0].sweep.compute_residual() - # put residual in the solution variables - for m in range(S.levels[0].sweep.coll.num_nodes): - S.levels[0].u[m + 1] = S.levels[0].residual[m] - - def swap_increment_for_solution(self, local_MS_running): + def update_solution(self, local_MS_running): """ - After inversion of the preconditioner, the values stored in the steps are the increment. This function adds the - solution after the previous iteration to arrive at the solution after the current iteration. - Note that we also need to put in the initial conditions back in the first step because they will be perturbed by - the circular preconditioner. + Since we solve for the increment, we need to update the solution between iterations by adding the increment. Args: local_MS_running (list): list of currently running steps """ for S in local_MS_running: - for m in range(S.levels[0].sweep.coll.num_nodes + 1): - S.levels[0].u[m] = S.levels[0].uold[m] + S.levels[0].u[m] - if S.status.first: - S.levels[0].u[0] = self.ParaDiag_block_u0 + for m in range(S.levels[0].sweep.coll.num_nodes): + S.levels[0].u[m + 1] += S.levels[0].increment[m] def prepare_Jacobians(self, local_MS_running): # get solutions for constructing average Jacobians @@ -215,22 +206,22 @@ def it_ParaDiag(self, local_MS_running): # communicate average residual for setting up Jacobians for non-linear problems self.prepare_Jacobians(local_MS_running) - # replace the values stored in the steps with the residuals in order to compute the increment - self.swap_solution_for_all_at_once_residual(local_MS_running) + # compute the all-at-once residual to use as right hand side + self.compute_all_at_once_residual(local_MS_running) - # weighted FFT in time - self.FFT_in_time() + # weighted FFT of the residual in time + self.FFT_in_time(quantity='residual') # perform local solves of "collocation problems" on the steps (can be done in parallel) for S in local_MS_running: assert len(S.levels) == 1, 'Multi-level SDC not implemented in ParaDiag' S.levels[0].sweep.update_nodes() - # inverse FFT in time - self.iFFT_in_time() + # inverse FFT of the increment in time + self.iFFT_in_time(quantity='increment') - # replace the values stored in the steps with the previous solution plus the increment - self.swap_increment_for_solution(local_MS_running) + # get the next iterate by adding increment to previous iterate + self.update_solution(local_MS_running) for S in local_MS_running: for hook in self.hooks: @@ -438,7 +429,6 @@ def restart_block(self, active_slots, time, u0): u0: initial value to distribute across the steps """ - self.ParaDiag_block_u0 = u0 # need this for computing residual for j in range(len(active_slots)): # get slot number diff --git a/pySDC/implementations/sweeper_classes/ParaDiagSweepers.py b/pySDC/implementations/sweeper_classes/ParaDiagSweepers.py index 5f3251ed88..b9c1be53cb 100644 --- a/pySDC/implementations/sweeper_classes/ParaDiagSweepers.py +++ b/pySDC/implementations/sweeper_classes/ParaDiagSweepers.py @@ -102,7 +102,7 @@ def update_nodes(self): # perform local solves on the collocation nodes, can be parallelized! if self.params.ignore_ic: - x1 = self.mat_vec(self.S_inv, [self.level.u[m + 1] for m in range(M)]) + x1 = self.mat_vec(self.S_inv, [self.level.residual[m] for m in range(M)]) else: x1 = self.mat_vec(self.S_inv, [self.level.u[0] for _ in range(M)]) @@ -120,7 +120,10 @@ def update_nodes(self): # update solution and evaluate right hand side for m in range(M): - L.u[m + 1] = y[m] + if self.params.ignore_ic: + L.increment[m] = y[m] + else: + L.u[m + 1] = y[m] if self.params.update_f_evals: L.f[m + 1] = P.eval_f(L.u[m + 1], L.time + L.dt * self.coll.nodes[m]) diff --git a/pySDC/tests/test_controllers/test_controller_ParaDiag_nonMPI.py b/pySDC/tests/test_controllers/test_controller_ParaDiag_nonMPI.py index a2e8c7894c..3927e9ecea 100644 --- a/pySDC/tests/test_controllers/test_controller_ParaDiag_nonMPI.py +++ b/pySDC/tests/test_controllers/test_controller_ParaDiag_nonMPI.py @@ -261,9 +261,44 @@ def test_ParaDiag_convergence_rate(L, M, N, alpha): ), f'Convergence rate {convergence_rate:.2e} exceeds upper bound of {convergence_bound:.2e}!' +@pytest.mark.base +@pytest.mark.parametrize('L', [4, 12]) +@pytest.mark.parametrize('M', [3, 4]) +@pytest.mark.parametrize('N', [1, 2]) +def test_fft(L, M, N): + import numpy as np + from pySDC.helpers.ParaDiagHelper import get_FFT_matrix + + dt = 1e-2 + controller, prob = get_composite_collocation_problem(L, M, N, alpha=1e-1, dt=dt, problem='Dahlquist') + # generate random data + data = np.random.random((L, M, N)) + data = np.ones((L, M, N)) + + for l in range(L): + for m in range(M): + controller.MS[l].levels[0].residual[m] = prob.u_init + controller.MS[l].levels[0].residual[m][:] = data[l, m] + + fft_matrix = get_FFT_matrix(L) + controller.apply_matrix(fft_matrix, 'residual') + data_fft = np.fft.fft(data, axis=0, norm='ortho') + + for l in range(L): + for m in range(M): + assert np.allclose(controller.MS[l].levels[0].residual[m], data_fft[l, m]) + + controller.apply_matrix(np.conjugate(fft_matrix), 'residual') + for l in range(L): + for m in range(M): + assert np.allclose(controller.MS[l].levels[0].residual[m], data[l, m]) + + if __name__ == '__main__': - test_ParaDiag_convergence_rate(4, 3, 1, 1e-4) + test_fft(3, 2, 2) + # test_ParaDiag_convergence_rate(4, 3, 1, 1e-4) # test_ParaDiag_vs_PFASST(4, 3, 2, 'Dahlquist') # test_ParaDiag_convergence(4, 3, 1, 1e-4, 'vdp') # test_IMEX_ParaDiag_convergence(4, 3, 64, 1e-4) # test_ParaDiag_order(3, 3, 1, 1e-4) + print('done') diff --git a/pySDC/tests/test_sweepers/test_ParaDiag_sweepers.py b/pySDC/tests/test_sweepers/test_ParaDiag_sweepers.py index aeffedaeb6..d44f7fad04 100644 --- a/pySDC/tests/test_sweepers/test_ParaDiag_sweepers.py +++ b/pySDC/tests/test_sweepers/test_ParaDiag_sweepers.py @@ -75,6 +75,8 @@ def test_direct_solve(M, N, ignore_ic): level.u[m] = prob.u_exact(0) level.f[m] = prob.eval_f(level.u[m], 0) + level.sweep.compute_residual() + if ignore_ic: level.u[0][:] = None @@ -92,6 +94,8 @@ def test_direct_solve(M, N, ignore_ic): u = sp.linalg.spsolve(C_coll, u0.flatten()).reshape(u0.shape) for m in range(M): + if ignore_ic: + level.u[m + 1] = level.u[m + 1] + level.increment[m] assert np.allclose(u[m], level.u[m + 1]) if not ignore_ic: @@ -100,4 +104,4 @@ def test_direct_solve(M, N, ignore_ic): if __name__ == '__main__': - test_direct_solve(2, 1, False) + test_direct_solve(2, 1, True) diff --git a/pySDC/tutorial/step_9/C_paradiag_in_pySDC.py b/pySDC/tutorial/step_9/C_paradiag_in_pySDC.py index 74c0319b54..3f03da942d 100644 --- a/pySDC/tutorial/step_9/C_paradiag_in_pySDC.py +++ b/pySDC/tutorial/step_9/C_paradiag_in_pySDC.py @@ -176,5 +176,5 @@ def compare_ParaDiag_and_PFASST(n_steps, problem): params = { 'n_steps': 16, } - compare_ParaDiag_and_PFASST(**params, problem='advection') + # compare_ParaDiag_and_PFASST(**params, problem='advection') compare_ParaDiag_and_PFASST(**params, problem='vdp')