1
1
/* -----------------------------------------------------------------------------
2
- * SUNDIALS Copyright Start
3
- * Copyright (c) 2002-2024, Lawrence Livermore National Security
4
- * and Southern Methodist University.
5
- * All rights reserved.
6
- *
7
- * See the top-level LICENSE and NOTICE files for details.
8
- *
9
- * SPDX-License-Identifier: BSD-3-Clause
10
- * SUNDIALS Copyright End
11
- * -----------------------------------------------------------------------------
12
- * This example solves the Lotka-Volterra ODE with four parameters,
13
- *
14
- * u = [dx/dt] = [ p_0*x - p_1*x*y ]
15
- * [dy/dt] [ -p_2*y + p_3*x*y ].
16
- *
17
- * The initial condition is u(t_0) = 1.0 and we use the parameters
18
- * p = [1.5, 1.0, 3.0, 1.0]. The integration interval can be controlled via
19
- * the --tf command line argument, but by default it is t \in [0, 10.].
20
- * An explicit Runge--Kutta method is employed via the ARKStep time stepper
21
- * provided by ARKODE. After solving the forward problem, adjoint sensitivity
22
- * analysis (ASA) is performed using the discrete adjoint method available with
23
- * with ARKStep in order to obtain the gradient of the cost function,
24
- *
25
- * g(u,p,t) = (sum(u)^2) / 2,
26
- *
27
- * with respect to the initial condition and the parameters.
28
- *
29
- * ./ark_lotka_volterra_adj options:
30
- * --tf <real> the final simulation time
31
- * --dt <real> the timestep size
32
- * --order <int> the order of the RK method
33
- * --check-freq <int> how often to checkpoint (in steps)
34
- * --no-stages don't checkpoint stages
35
- * --dont-keep don't keep checkpoints around after loading
36
- * --help print these options
37
- * ---------------------------------------------------------------------------*/
2
+ * SUNDIALS Copyright Start
3
+ * Copyright (c) 2002-2024, Lawrence Livermore National Security
4
+ * and Southern Methodist University.
5
+ * All rights reserved.
6
+ *
7
+ * See the top-level LICENSE and NOTICE files for details.
8
+ *
9
+ * SPDX-License-Identifier: BSD-3-Clause
10
+ * SUNDIALS Copyright End
11
+ * -----------------------------------------------------------------------------
12
+ * This example solves the Lotka-Volterra ODE with four parameters,
13
+ *
14
+ * u = [dx/dt] = [ p_0*x - p_1*x*y ]
15
+ * [dy/dt] [ -p_2*y + p_3*x*y ].
16
+ *
17
+ * The initial condition is u(t_0) = 1.0 and we use the parameters
18
+ * p = [1.5, 1.0, 3.0, 1.0]. The integration interval can be controlled via
19
+ * the --tf command line argument, but by default it is t \in [0, 10.].
20
+ * An explicit Runge--Kutta method is employed via the ARKStep time stepper
21
+ * provided by ARKODE. After solving the forward problem, adjoint sensitivity
22
+ * analysis (ASA) is performed using the discrete adjoint method available with
23
+ * with ARKStep in order to obtain the gradient of the scalar cost function,
24
+ *
25
+ * g(u,p,t) = (sum(u)^2) / 2,
26
+ *
27
+ * with respect to the initial condition and the parameters.
28
+ *
29
+ * ./ark_lotka_volterra_adj options:
30
+ * --tf <real> the final simulation time
31
+ * --dt <real> the timestep size
32
+ * --order <int> the order of the RK method
33
+ * --check-freq <int> how often to checkpoint (in steps)
34
+ * --no-stages don't checkpoint stages
35
+ * --dont-keep don't keep checkpoints around after loading
36
+ * --help print these options
37
+ * ---------------------------------------------------------------------------*/
38
38
39
39
#include <stdio.h>
40
40
#include <stdlib.h>
51
51
52
52
#include <arkode/arkode.h>
53
53
#include <arkode/arkode_arkstep.h>
54
+ #include "sundials/sundials_nvector.h"
54
55
55
56
typedef struct
56
57
{
@@ -65,6 +66,7 @@ typedef struct
65
66
static const sunrealtype params [4 ] = {1.5 , 1.0 , 3.0 , 1.0 };
66
67
static void parse_args (int argc , char * argv [], ProgramArgs * args );
67
68
static void print_help (int argc , char * argv [], int exit_code );
69
+ static int check_retval (void * retval_ptr , const char * funcname , int opt );
68
70
69
71
int lotka_volterra (sunrealtype t , N_Vector uvec , N_Vector udotvec , void * user_data )
70
72
{
@@ -110,7 +112,7 @@ int parameter_vjp(N_Vector vvec, N_Vector Jvvec, sunrealtype t, N_Vector uvec,
110
112
return 0 ;
111
113
}
112
114
113
- sunrealtype g (N_Vector u , const sunrealtype * p , sunrealtype t )
115
+ sunrealtype g (N_Vector u , const sunrealtype * p )
114
116
{
115
117
/* (sum(u) .^ 2) ./ 2 */
116
118
sunrealtype * uarr = N_VGetArrayPointer (u );
@@ -119,7 +121,7 @@ sunrealtype g(N_Vector u, const sunrealtype* p, sunrealtype t)
119
121
return (sum * sum ) / SUN_RCONST (2.0 );
120
122
}
121
123
122
- void dgdu (N_Vector uvec , N_Vector dgvec , const sunrealtype * p , sunrealtype t )
124
+ void dgdu (N_Vector uvec , N_Vector dgvec , const sunrealtype * p )
123
125
{
124
126
sunrealtype * u = N_VGetArrayPointer (uvec );
125
127
sunrealtype * dg = N_VGetArrayPointer (dgvec );
@@ -128,7 +130,7 @@ void dgdu(N_Vector uvec, N_Vector dgvec, const sunrealtype* p, sunrealtype t)
128
130
dg [1 ] = u [0 ] + u [1 ];
129
131
}
130
132
131
- void dgdp (N_Vector uvec , N_Vector dgvec , const sunrealtype * p , sunrealtype t )
133
+ void dgdp (N_Vector uvec , N_Vector dgvec , const sunrealtype * p )
132
134
{
133
135
sunrealtype * u = N_VGetArrayPointer (uvec );
134
136
sunrealtype * dg = N_VGetArrayPointer (dgvec );
@@ -139,58 +141,9 @@ void dgdp(N_Vector uvec, N_Vector dgvec, const sunrealtype* p, sunrealtype t)
139
141
dg [3 ] = SUN_RCONST (0.0 );
140
142
}
141
143
142
- int forward_solution (SUNContext sunctx , void * arkode_mem ,
143
- SUNAdjointCheckpointScheme checkpoint_scheme ,
144
- const sunrealtype t0 , const sunrealtype tf ,
145
- const sunrealtype dt , N_Vector u )
146
- {
147
- int retval = 0 ;
148
-
149
- retval = ARKodeSetUserData (arkode_mem , (void * )params );
150
- retval = ARKodeSetFixedStep (arkode_mem , dt );
151
-
152
- sunrealtype t = t0 ;
153
- while (t < tf )
154
- {
155
- retval = ARKodeEvolve (arkode_mem , tf , u , & t , ARK_NORMAL );
156
- if (retval < 0 )
157
- {
158
- fprintf (stderr , ">>> ERROR: ARKodeEvolve returned %d\n" , retval );
159
- return -1 ;
160
- }
161
- }
162
-
163
- printf ("Forward Solution:\n" );
164
- N_VPrint (u );
165
-
166
- printf ("ARKODE Stats for Forward Solution:\n" );
167
- ARKodePrintAllStats (arkode_mem , stdout , SUN_OUTPUTFORMAT_TABLE );
168
- printf ("\n" );
169
-
170
- return 0 ;
171
- }
172
-
173
- int adjoint_solution (SUNContext sunctx , SUNAdjointStepper adj_stepper ,
174
- SUNAdjointCheckpointScheme checkpoint_scheme ,
175
- const sunrealtype tf , const sunrealtype tout , N_Vector sf )
176
- {
177
- int retval = 0 ;
178
- int stop_reason = 0 ;
179
- sunrealtype t = tf ;
180
- retval = SUNAdjointStepper_Evolve (adj_stepper , tout , sf , & t , & stop_reason );
181
-
182
- printf ("Adjoint Solution:\n" );
183
- N_VPrint (sf );
184
-
185
- printf ("\nSUNAdjointStepper Stats:\n" );
186
- SUNAdjointStepper_PrintAllStats (adj_stepper , stdout , SUN_OUTPUTFORMAT_TABLE );
187
- printf ("\n" );
188
-
189
- return 0 ;
190
- }
191
-
192
144
int main (int argc , char * argv [])
193
145
{
146
+ int retval = 0 ;
194
147
SUNContext sunctx = NULL ;
195
148
SUNContext_Create (SUN_COMM_NULL , & sunctx );
196
149
@@ -209,6 +162,8 @@ int main(int argc, char* argv[])
209
162
210
163
sunindextype neq = 2 ;
211
164
N_Vector u = N_VNew_Serial (neq , sunctx );
165
+ N_Vector u0 = N_VClone (u );
166
+ N_VConst (1.0 , u0 );
212
167
N_VConst (1.0 , u );
213
168
214
169
//
@@ -222,8 +177,11 @@ int main(int argc, char* argv[])
222
177
const int order = args .order ;
223
178
void * arkode_mem = ARKStepCreate (lotka_volterra , NULL , t0 , u , sunctx );
224
179
225
- ARKodeSetOrder (arkode_mem , order );
226
- ARKodeSetMaxNumSteps (arkode_mem , nsteps * 2 );
180
+ retval = ARKodeSetOrder (arkode_mem , order );
181
+ if (check_retval (& retval , "ARKodeSetOrder" , 1 )) { return 1 ; }
182
+
183
+ retval = ARKodeSetMaxNumSteps (arkode_mem , nsteps * 2 );
184
+ if (check_retval (& retval , "ARKodeSetMaxNumSteps" , 1 )) { return 1 ; }
227
185
228
186
// Enable checkpointing during the forward solution.
229
187
const int check_interval = args .check_freq ;
@@ -232,10 +190,22 @@ int main(int argc, char* argv[])
232
190
const sunbooleantype keep_check = args .keep_checks ;
233
191
SUNAdjointCheckpointScheme checkpoint_scheme = NULL ;
234
192
SUNMemoryHelper mem_helper = SUNMemoryHelper_Sys (sunctx );
235
- SUNAdjointCheckpointScheme_Create_Basic (SUNDATAIOMODE_INMEM , mem_helper ,
236
- check_interval , ncheck , save_stages ,
237
- keep_check , sunctx , & checkpoint_scheme );
238
- ARKodeSetAdjointCheckpointScheme (arkode_mem , checkpoint_scheme );
193
+
194
+ retval = SUNAdjointCheckpointScheme_Create_Basic (SUNDATAIOMODE_INMEM ,
195
+ mem_helper , check_interval ,
196
+ ncheck , save_stages ,
197
+ keep_check , sunctx ,
198
+ & checkpoint_scheme );
199
+ if (check_retval (& retval , "SUNAdjointCheckpointScheme_Create_Basic" , 1 ))
200
+ {
201
+ return 1 ;
202
+ }
203
+
204
+ retval = ARKodeSetAdjointCheckpointScheme (arkode_mem , checkpoint_scheme );
205
+ if (check_retval (& retval , "ARKodeSetAdjointCheckpointScheme" , 1 ))
206
+ {
207
+ return 1 ;
208
+ }
239
209
240
210
//
241
211
// Compute the forward solution
@@ -244,7 +214,30 @@ int main(int argc, char* argv[])
244
214
printf ("Initial condition:\n" );
245
215
N_VPrint (u );
246
216
247
- forward_solution (sunctx , arkode_mem , checkpoint_scheme , t0 , tf , dt , u );
217
+ retval = ARKodeSetUserData (arkode_mem , (void * )params );
218
+ if (check_retval (& retval , "ARKodeSetUserData" , 1 )) { return 1 ; }
219
+
220
+ retval = ARKodeSetFixedStep (arkode_mem , dt );
221
+ if (check_retval (& retval , "ARKodeSetFixedStep" , 1 )) { return 1 ; }
222
+
223
+ sunrealtype tret = t0 ;
224
+ while (tret < tf )
225
+ {
226
+ retval = ARKodeEvolve (arkode_mem , tf , u , & tret , ARK_NORMAL );
227
+ if (retval < 0 )
228
+ {
229
+ fprintf (stderr , ">>> ERROR: ARKodeEvolve returned %d\n" , retval );
230
+ return -1 ;
231
+ }
232
+ }
233
+
234
+ printf ("Forward Solution:\n" );
235
+ N_VPrint (u );
236
+
237
+ printf ("ARKODE Stats for Forward Solution:\n" );
238
+ retval = ARKodePrintAllStats (arkode_mem , stdout , SUN_OUTPUTFORMAT_TABLE );
239
+ if (check_retval (& retval , "ARKodePrintAllStats" , 1 )) { return 1 ; }
240
+ printf ("\n" );
248
241
249
242
//
250
243
// Create the adjoint stepper
@@ -258,21 +251,44 @@ int main(int argc, char* argv[])
258
251
259
252
// Set the terminal condition for the adjoint system, which
260
253
// should be the the gradient of our cost function at tf.
261
- dgdu (u , sensu0 , params , tf );
262
- dgdp (u , sensp , params , tf );
254
+ dgdu (u , sensu0 , params );
255
+ dgdp (u , sensp , params );
263
256
264
257
printf ("Adjoint terminal condition:\n" );
265
258
N_VPrint (sf );
266
259
267
260
SUNAdjointStepper adj_stepper ;
268
- ARKStepCreateAdjointStepper (arkode_mem , sf , & adj_stepper );
261
+ retval = ARKStepCreateAdjointStepper (arkode_mem , sf , & adj_stepper );
262
+ if (check_retval (& retval , "ARKStepCreateAdjointStepper" , 1 )) { return 1 ; }
263
+
264
+ retval = SUNAdjointStepper_SetVecTimesJacFn (adj_stepper , vjp , parameter_vjp );
265
+ if (check_retval (& retval , "SUNAdjointStepper_SetVecTimesJacFn" , 1 ))
266
+ {
267
+ return 1 ;
268
+ }
269
269
270
270
//
271
271
// Now compute the adjoint solution
272
272
//
273
273
274
- SUNAdjointStepper_SetVecTimesJacFn (adj_stepper , vjp , parameter_vjp );
275
- adjoint_solution (sunctx , adj_stepper , checkpoint_scheme , tf , t0 , sf );
274
+ int stop_reason = 0 ;
275
+ retval = SUNAdjointStepper_Evolve (adj_stepper , t0 , sf , & tret , & stop_reason );
276
+ if (check_retval (& retval , "SUNAdjointStepper_Evolve" , 1 )) { return 1 ; }
277
+
278
+ // Compute gradient w.r.t. parameters:
279
+ N_Vector tmp = N_VClone (sensp );
280
+ parameter_vjp (sensp , tmp , tret , u0 , NULL , (void * )params , NULL );
281
+ N_VLinearSum (1.0 , sensp , 1.0 , tmp , sensp );
282
+ N_VDestroy (tmp );
283
+
284
+ printf ("Adjoint Solution:\n" );
285
+ N_VPrint (sf );
286
+
287
+ printf ("\nSUNAdjointStepper Stats:\n" );
288
+ retval = SUNAdjointStepper_PrintAllStats (adj_stepper , stdout ,
289
+ SUN_OUTPUTFORMAT_TABLE );
290
+ if (check_retval (& retval , "SUNAdjointStepper_PrintAllStats" , 1 )) { return 1 ; }
291
+ printf ("\n" );
276
292
277
293
//
278
294
// Cleanup
@@ -321,3 +337,30 @@ void parse_args(int argc, char* argv[], ProgramArgs* args)
321
337
else { print_help (argc , argv , 1 ); }
322
338
}
323
339
}
340
+
341
+ int check_retval (void * retval_ptr , const char * funcname , int opt )
342
+ {
343
+ int * retval ;
344
+
345
+ /* Check if SUNDIALS function returned NULL pointer - no memory allocated */
346
+ if (opt == 0 && retval_ptr == NULL )
347
+ {
348
+ fprintf (stderr , "\nSUNDIALS_ERROR: %s() failed - returned NULL pointer\n\n" ,
349
+ funcname );
350
+ return 1 ;
351
+ }
352
+
353
+ /* Check if retval < 0 */
354
+ else if (opt == 1 )
355
+ {
356
+ retval = (int * )retval_ptr ;
357
+ if (* retval < 0 )
358
+ {
359
+ fprintf (stderr , "\nSUNDIALS_ERROR: %s() failed with retval = %d\n\n" ,
360
+ funcname , * retval );
361
+ return 1 ;
362
+ }
363
+ }
364
+
365
+ return (0 );
366
+ }
0 commit comments