Skip to content

Commit ac0adec

Browse files
committed
add cvodes example
1 parent 19827f5 commit ac0adec

File tree

4 files changed

+419
-107
lines changed

4 files changed

+419
-107
lines changed

doc/arkode/guide/source/Mathematics.rst

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2134,12 +2134,12 @@ Adjoint Sensitivity Analysis
21342134
============================
21352135

21362136
Consider :eq:`ARKODE_IVP_simple_explicit`, but where the ODE also depends on some parameters
2137-
:math:`p` (that is, we have :math:`f(t,y,p)`). Now, let :math:`g(y(t_f),p)` be a cost function for
2138-
which we would like to compute the gradients :math:`\partial g/\partial y(t_0)` and/or
2139-
:math:`\partial g/\partial p`. The adjoint method is one approach to obtaining the gradients that is
2140-
particularly efficient when there are relatively few cost functionals and a large number of parameters.
2141-
With the adjoint method we solve the adjoint ODEs for :math:`\lambda(t) \in \mathbb{R}^N` and
2142-
:math:`\mu(t) \in \mathbb{R}^{N_s}`:
2137+
:math:`p` (that is, we have :math:`f(t,y,p)`). Now, suppose we have a functional :math:`g(y(t_f),p)`
2138+
for which we would like to compute the gradients :math:`\partial g/\partial y(t_0)`
2139+
and/or :math:`\partial g/\partial p`. The adjoint method is one approach to obtaining the
2140+
gradients that is particularly efficient when there are relatively few functionals and a
2141+
large number of parameters. With the adjoint method we solve the adjoint ODEs for :math:`\lambda(t)
2142+
\in \mathbb{R}^N` and :math:`\mu(t) \in \mathbb{R}^{N_s}`:
21432143

21442144
.. math::
21452145
\lambda'(t) &= -f_y^T(t, y, p) \lambda,\quad \lambda(t_F) = g_y^T(y(t_f), p) \\

examples/arkode/C_serial/ark_lotka_volterra_adj.c

Lines changed: 144 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,40 @@
11
/* -----------------------------------------------------------------------------
2-
* SUNDIALS Copyright Start
3-
* Copyright (c) 2002-2024, Lawrence Livermore National Security
4-
* and Southern Methodist University.
5-
* All rights reserved.
6-
*
7-
* See the top-level LICENSE and NOTICE files for details.
8-
*
9-
* SPDX-License-Identifier: BSD-3-Clause
10-
* SUNDIALS Copyright End
11-
* -----------------------------------------------------------------------------
12-
* This example solves the Lotka-Volterra ODE with four parameters,
13-
*
14-
* u = [dx/dt] = [ p_0*x - p_1*x*y ]
15-
* [dy/dt] [ -p_2*y + p_3*x*y ].
16-
*
17-
* The initial condition is u(t_0) = 1.0 and we use the parameters
18-
* p = [1.5, 1.0, 3.0, 1.0]. The integration interval can be controlled via
19-
* the --tf command line argument, but by default it is t \in [0, 10.].
20-
* An explicit Runge--Kutta method is employed via the ARKStep time stepper
21-
* provided by ARKODE. After solving the forward problem, adjoint sensitivity
22-
* analysis (ASA) is performed using the discrete adjoint method available with
23-
* with ARKStep in order to obtain the gradient of the cost function,
24-
*
25-
* g(u,p,t) = (sum(u)^2) / 2,
26-
*
27-
* with respect to the initial condition and the parameters.
28-
*
29-
* ./ark_lotka_volterra_adj options:
30-
* --tf <real> the final simulation time
31-
* --dt <real> the timestep size
32-
* --order <int> the order of the RK method
33-
* --check-freq <int> how often to checkpoint (in steps)
34-
* --no-stages don't checkpoint stages
35-
* --dont-keep don't keep checkpoints around after loading
36-
* --help print these options
37-
* ---------------------------------------------------------------------------*/
2+
* SUNDIALS Copyright Start
3+
* Copyright (c) 2002-2024, Lawrence Livermore National Security
4+
* and Southern Methodist University.
5+
* All rights reserved.
6+
*
7+
* See the top-level LICENSE and NOTICE files for details.
8+
*
9+
* SPDX-License-Identifier: BSD-3-Clause
10+
* SUNDIALS Copyright End
11+
* -----------------------------------------------------------------------------
12+
* This example solves the Lotka-Volterra ODE with four parameters,
13+
*
14+
* u = [dx/dt] = [ p_0*x - p_1*x*y ]
15+
* [dy/dt] [ -p_2*y + p_3*x*y ].
16+
*
17+
* The initial condition is u(t_0) = 1.0 and we use the parameters
18+
* p = [1.5, 1.0, 3.0, 1.0]. The integration interval can be controlled via
19+
* the --tf command line argument, but by default it is t \in [0, 10.].
20+
* An explicit Runge--Kutta method is employed via the ARKStep time stepper
21+
* provided by ARKODE. After solving the forward problem, adjoint sensitivity
22+
* analysis (ASA) is performed using the discrete adjoint method available with
23+
* with ARKStep in order to obtain the gradient of the scalar cost function,
24+
*
25+
* g(u,p,t) = (sum(u)^2) / 2,
26+
*
27+
* with respect to the initial condition and the parameters.
28+
*
29+
* ./ark_lotka_volterra_adj options:
30+
* --tf <real> the final simulation time
31+
* --dt <real> the timestep size
32+
* --order <int> the order of the RK method
33+
* --check-freq <int> how often to checkpoint (in steps)
34+
* --no-stages don't checkpoint stages
35+
* --dont-keep don't keep checkpoints around after loading
36+
* --help print these options
37+
* ---------------------------------------------------------------------------*/
3838

3939
#include <stdio.h>
4040
#include <stdlib.h>
@@ -51,6 +51,7 @@
5151

5252
#include <arkode/arkode.h>
5353
#include <arkode/arkode_arkstep.h>
54+
#include "sundials/sundials_nvector.h"
5455

5556
typedef struct
5657
{
@@ -65,6 +66,7 @@ typedef struct
6566
static const sunrealtype params[4] = {1.5, 1.0, 3.0, 1.0};
6667
static void parse_args(int argc, char* argv[], ProgramArgs* args);
6768
static void print_help(int argc, char* argv[], int exit_code);
69+
static int check_retval(void* retval_ptr, const char* funcname, int opt);
6870

6971
int lotka_volterra(sunrealtype t, N_Vector uvec, N_Vector udotvec, void* user_data)
7072
{
@@ -110,7 +112,7 @@ int parameter_vjp(N_Vector vvec, N_Vector Jvvec, sunrealtype t, N_Vector uvec,
110112
return 0;
111113
}
112114

113-
sunrealtype g(N_Vector u, const sunrealtype* p, sunrealtype t)
115+
sunrealtype g(N_Vector u, const sunrealtype* p)
114116
{
115117
/* (sum(u) .^ 2) ./ 2 */
116118
sunrealtype* uarr = N_VGetArrayPointer(u);
@@ -119,7 +121,7 @@ sunrealtype g(N_Vector u, const sunrealtype* p, sunrealtype t)
119121
return (sum * sum) / SUN_RCONST(2.0);
120122
}
121123

122-
void dgdu(N_Vector uvec, N_Vector dgvec, const sunrealtype* p, sunrealtype t)
124+
void dgdu(N_Vector uvec, N_Vector dgvec, const sunrealtype* p)
123125
{
124126
sunrealtype* u = N_VGetArrayPointer(uvec);
125127
sunrealtype* dg = N_VGetArrayPointer(dgvec);
@@ -128,7 +130,7 @@ void dgdu(N_Vector uvec, N_Vector dgvec, const sunrealtype* p, sunrealtype t)
128130
dg[1] = u[0] + u[1];
129131
}
130132

131-
void dgdp(N_Vector uvec, N_Vector dgvec, const sunrealtype* p, sunrealtype t)
133+
void dgdp(N_Vector uvec, N_Vector dgvec, const sunrealtype* p)
132134
{
133135
sunrealtype* u = N_VGetArrayPointer(uvec);
134136
sunrealtype* dg = N_VGetArrayPointer(dgvec);
@@ -139,58 +141,9 @@ void dgdp(N_Vector uvec, N_Vector dgvec, const sunrealtype* p, sunrealtype t)
139141
dg[3] = SUN_RCONST(0.0);
140142
}
141143

142-
int forward_solution(SUNContext sunctx, void* arkode_mem,
143-
SUNAdjointCheckpointScheme checkpoint_scheme,
144-
const sunrealtype t0, const sunrealtype tf,
145-
const sunrealtype dt, N_Vector u)
146-
{
147-
int retval = 0;
148-
149-
retval = ARKodeSetUserData(arkode_mem, (void*)params);
150-
retval = ARKodeSetFixedStep(arkode_mem, dt);
151-
152-
sunrealtype t = t0;
153-
while (t < tf)
154-
{
155-
retval = ARKodeEvolve(arkode_mem, tf, u, &t, ARK_NORMAL);
156-
if (retval < 0)
157-
{
158-
fprintf(stderr, ">>> ERROR: ARKodeEvolve returned %d\n", retval);
159-
return -1;
160-
}
161-
}
162-
163-
printf("Forward Solution:\n");
164-
N_VPrint(u);
165-
166-
printf("ARKODE Stats for Forward Solution:\n");
167-
ARKodePrintAllStats(arkode_mem, stdout, SUN_OUTPUTFORMAT_TABLE);
168-
printf("\n");
169-
170-
return 0;
171-
}
172-
173-
int adjoint_solution(SUNContext sunctx, SUNAdjointStepper adj_stepper,
174-
SUNAdjointCheckpointScheme checkpoint_scheme,
175-
const sunrealtype tf, const sunrealtype tout, N_Vector sf)
176-
{
177-
int retval = 0;
178-
int stop_reason = 0;
179-
sunrealtype t = tf;
180-
retval = SUNAdjointStepper_Evolve(adj_stepper, tout, sf, &t, &stop_reason);
181-
182-
printf("Adjoint Solution:\n");
183-
N_VPrint(sf);
184-
185-
printf("\nSUNAdjointStepper Stats:\n");
186-
SUNAdjointStepper_PrintAllStats(adj_stepper, stdout, SUN_OUTPUTFORMAT_TABLE);
187-
printf("\n");
188-
189-
return 0;
190-
}
191-
192144
int main(int argc, char* argv[])
193145
{
146+
int retval = 0;
194147
SUNContext sunctx = NULL;
195148
SUNContext_Create(SUN_COMM_NULL, &sunctx);
196149

@@ -209,6 +162,8 @@ int main(int argc, char* argv[])
209162

210163
sunindextype neq = 2;
211164
N_Vector u = N_VNew_Serial(neq, sunctx);
165+
N_Vector u0 = N_VClone(u);
166+
N_VConst(1.0, u0);
212167
N_VConst(1.0, u);
213168

214169
//
@@ -222,8 +177,11 @@ int main(int argc, char* argv[])
222177
const int order = args.order;
223178
void* arkode_mem = ARKStepCreate(lotka_volterra, NULL, t0, u, sunctx);
224179

225-
ARKodeSetOrder(arkode_mem, order);
226-
ARKodeSetMaxNumSteps(arkode_mem, nsteps * 2);
180+
retval = ARKodeSetOrder(arkode_mem, order);
181+
if (check_retval(&retval, "ARKodeSetOrder", 1)) { return 1; }
182+
183+
retval = ARKodeSetMaxNumSteps(arkode_mem, nsteps * 2);
184+
if (check_retval(&retval, "ARKodeSetMaxNumSteps", 1)) { return 1; }
227185

228186
// Enable checkpointing during the forward solution.
229187
const int check_interval = args.check_freq;
@@ -232,10 +190,22 @@ int main(int argc, char* argv[])
232190
const sunbooleantype keep_check = args.keep_checks;
233191
SUNAdjointCheckpointScheme checkpoint_scheme = NULL;
234192
SUNMemoryHelper mem_helper = SUNMemoryHelper_Sys(sunctx);
235-
SUNAdjointCheckpointScheme_Create_Basic(SUNDATAIOMODE_INMEM, mem_helper,
236-
check_interval, ncheck, save_stages,
237-
keep_check, sunctx, &checkpoint_scheme);
238-
ARKodeSetAdjointCheckpointScheme(arkode_mem, checkpoint_scheme);
193+
194+
retval = SUNAdjointCheckpointScheme_Create_Basic(SUNDATAIOMODE_INMEM,
195+
mem_helper, check_interval,
196+
ncheck, save_stages,
197+
keep_check, sunctx,
198+
&checkpoint_scheme);
199+
if (check_retval(&retval, "SUNAdjointCheckpointScheme_Create_Basic", 1))
200+
{
201+
return 1;
202+
}
203+
204+
retval = ARKodeSetAdjointCheckpointScheme(arkode_mem, checkpoint_scheme);
205+
if (check_retval(&retval, "ARKodeSetAdjointCheckpointScheme", 1))
206+
{
207+
return 1;
208+
}
239209

240210
//
241211
// Compute the forward solution
@@ -244,7 +214,30 @@ int main(int argc, char* argv[])
244214
printf("Initial condition:\n");
245215
N_VPrint(u);
246216

247-
forward_solution(sunctx, arkode_mem, checkpoint_scheme, t0, tf, dt, u);
217+
retval = ARKodeSetUserData(arkode_mem, (void*)params);
218+
if (check_retval(&retval, "ARKodeSetUserData", 1)) { return 1; }
219+
220+
retval = ARKodeSetFixedStep(arkode_mem, dt);
221+
if (check_retval(&retval, "ARKodeSetFixedStep", 1)) { return 1; }
222+
223+
sunrealtype tret = t0;
224+
while (tret < tf)
225+
{
226+
retval = ARKodeEvolve(arkode_mem, tf, u, &tret, ARK_NORMAL);
227+
if (retval < 0)
228+
{
229+
fprintf(stderr, ">>> ERROR: ARKodeEvolve returned %d\n", retval);
230+
return -1;
231+
}
232+
}
233+
234+
printf("Forward Solution:\n");
235+
N_VPrint(u);
236+
237+
printf("ARKODE Stats for Forward Solution:\n");
238+
retval = ARKodePrintAllStats(arkode_mem, stdout, SUN_OUTPUTFORMAT_TABLE);
239+
if (check_retval(&retval, "ARKodePrintAllStats", 1)) { return 1; }
240+
printf("\n");
248241

249242
//
250243
// Create the adjoint stepper
@@ -258,21 +251,44 @@ int main(int argc, char* argv[])
258251

259252
// Set the terminal condition for the adjoint system, which
260253
// should be the the gradient of our cost function at tf.
261-
dgdu(u, sensu0, params, tf);
262-
dgdp(u, sensp, params, tf);
254+
dgdu(u, sensu0, params);
255+
dgdp(u, sensp, params);
263256

264257
printf("Adjoint terminal condition:\n");
265258
N_VPrint(sf);
266259

267260
SUNAdjointStepper adj_stepper;
268-
ARKStepCreateAdjointStepper(arkode_mem, sf, &adj_stepper);
261+
retval = ARKStepCreateAdjointStepper(arkode_mem, sf, &adj_stepper);
262+
if (check_retval(&retval, "ARKStepCreateAdjointStepper", 1)) { return 1; }
263+
264+
retval = SUNAdjointStepper_SetVecTimesJacFn(adj_stepper, vjp, parameter_vjp);
265+
if (check_retval(&retval, "SUNAdjointStepper_SetVecTimesJacFn", 1))
266+
{
267+
return 1;
268+
}
269269

270270
//
271271
// Now compute the adjoint solution
272272
//
273273

274-
SUNAdjointStepper_SetVecTimesJacFn(adj_stepper, vjp, parameter_vjp);
275-
adjoint_solution(sunctx, adj_stepper, checkpoint_scheme, tf, t0, sf);
274+
int stop_reason = 0;
275+
retval = SUNAdjointStepper_Evolve(adj_stepper, t0, sf, &tret, &stop_reason);
276+
if (check_retval(&retval, "SUNAdjointStepper_Evolve", 1)) { return 1; }
277+
278+
// Compute gradient w.r.t. parameters:
279+
N_Vector tmp = N_VClone(sensp);
280+
parameter_vjp(sensp, tmp, tret, u0, NULL, (void*)params, NULL);
281+
N_VLinearSum(1.0, sensp, 1.0, tmp, sensp);
282+
N_VDestroy(tmp);
283+
284+
printf("Adjoint Solution:\n");
285+
N_VPrint(sf);
286+
287+
printf("\nSUNAdjointStepper Stats:\n");
288+
retval = SUNAdjointStepper_PrintAllStats(adj_stepper, stdout,
289+
SUN_OUTPUTFORMAT_TABLE);
290+
if (check_retval(&retval, "SUNAdjointStepper_PrintAllStats", 1)) { return 1; }
291+
printf("\n");
276292

277293
//
278294
// Cleanup
@@ -321,3 +337,30 @@ void parse_args(int argc, char* argv[], ProgramArgs* args)
321337
else { print_help(argc, argv, 1); }
322338
}
323339
}
340+
341+
int check_retval(void* retval_ptr, const char* funcname, int opt)
342+
{
343+
int* retval;
344+
345+
/* Check if SUNDIALS function returned NULL pointer - no memory allocated */
346+
if (opt == 0 && retval_ptr == NULL)
347+
{
348+
fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed - returned NULL pointer\n\n",
349+
funcname);
350+
return 1;
351+
}
352+
353+
/* Check if retval < 0 */
354+
else if (opt == 1)
355+
{
356+
retval = (int*)retval_ptr;
357+
if (*retval < 0)
358+
{
359+
fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed with retval = %d\n\n",
360+
funcname, *retval);
361+
return 1;
362+
}
363+
}
364+
365+
return (0);
366+
}

examples/cvodes/serial/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ set(CVODES_examples
3737
"cvsKrylovDemo_ls\;1\;develop"
3838
"cvsKrylovDemo_ls\;2\;develop"
3939
"cvsKrylovDemo_prec\;\;develop"
40+
"cvsLotkaVolterra_ASA\;\;develop"
4041
"cvsParticle_dns\;\;develop"
4142
"cvsPendulum_dns\;\;exclude-single"
4243
"cvsRoberts_ASAi_dns\;\;exclude-single"

0 commit comments

Comments
 (0)