@@ -50,21 +50,24 @@ using namespace sundials::experimental;
50
50
})
51
51
52
52
#define BIND_CVODE_CALLBACK2 (NAME, FN_TYPE1, MEMBER1, WRAPPER1, FN_TYPE2, \
53
- MEMBER2, WRAPPER2) \
54
- m.def(#NAME, \
55
- [](void * cv_mem, std::function<std::remove_pointer_t <FN_TYPE1>> fn1, \
56
- std::function<std::remove_pointer_t <FN_TYPE2>> fn2) \
57
- { \
58
- void * user_data = nullptr ; \
59
- CVodeGetUserData (cv_mem, &user_data); \
60
- if (!user_data) \
61
- throw std::runtime_error ( \
62
- " Failed to get Python function table from CVODE memory" ); \
63
- auto fntable = static_cast <cvode_user_supplied_fn_table*>(user_data); \
64
- fntable->MEMBER1 = nb::cast (fn1); \
65
- fntable->MEMBER2 = nb::cast (fn2); \
66
- return NAME (cv_mem, &WRAPPER1, &WRAPPER2); \
67
- })
53
+ MEMBER2, WRAPPER2, ...) \
54
+ m.def( \
55
+ #NAME, \
56
+ [](void * cv_mem, std::function<std::remove_pointer_t <FN_TYPE1>> fn1, \
57
+ std::function<std::remove_pointer_t <FN_TYPE2>> fn2) \
58
+ { \
59
+ void * user_data = nullptr ; \
60
+ CVodeGetUserData (cv_mem, &user_data); \
61
+ if (!user_data) \
62
+ throw std::runtime_error ( \
63
+ " Failed to get Python function table from CVODE memory" ); \
64
+ auto fntable = static_cast <cvode_user_supplied_fn_table*>(user_data); \
65
+ fntable->MEMBER1 = nb::cast (fn1); \
66
+ fntable->MEMBER2 = nb::cast (fn2); \
67
+ if (fn1) { return NAME (cv_mem, WRAPPER1, WRAPPER2); } \
68
+ else { return NAME (cv_mem, nullptr , WRAPPER2); } \
69
+ }, \
70
+ __VA_ARGS__)
68
71
69
72
#define BIND_CVODEB_CALLBACK (NAME, FN_TYPE, MEMBER, WRAPPER ) \
70
73
m.def(#NAME, \
@@ -82,22 +85,25 @@ using namespace sundials::experimental;
82
85
})
83
86
84
87
#define BIND_CVODEB_CALLBACK2 (NAME, FN_TYPE1, MEMBER1, WRAPPER1, FN_TYPE2, \
85
- MEMBER2, WRAPPER2) \
86
- m.def(#NAME, \
87
- [](void * cv_mem, int which, \
88
- std::function<std::remove_pointer_t <FN_TYPE1>> fn1, \
89
- std::function<std::remove_pointer_t <FN_TYPE2>> fn2) \
90
- { \
91
- void * user_data = nullptr ; \
92
- CVodeGetUserDataB (cv_mem, which, &user_data); \
93
- if (!user_data) \
94
- throw std::runtime_error ( \
95
- " Failed to get Python function table from CVODE memory" ); \
96
- auto fntable = static_cast <cvodea_user_supplied_fn_table*>(user_data); \
97
- fntable->MEMBER1 = nb::cast (fn1); \
98
- fntable->MEMBER2 = nb::cast (fn2); \
99
- return NAME (cv_mem, which, &WRAPPER1, &WRAPPER2); \
100
- })
88
+ MEMBER2, WRAPPER2, ...) \
89
+ m.def( \
90
+ #NAME, \
91
+ [](void * cv_mem, int which, \
92
+ std::function<std::remove_pointer_t <FN_TYPE1>> fn1, \
93
+ std::function<std::remove_pointer_t <FN_TYPE2>> fn2) \
94
+ { \
95
+ void * user_data = nullptr ; \
96
+ CVodeGetUserDataB (cv_mem, which, &user_data); \
97
+ if (!user_data) \
98
+ throw std::runtime_error ( \
99
+ " Failed to get Python function table from CVODE memory" ); \
100
+ auto fntable = static_cast <cvodea_user_supplied_fn_table*>(user_data); \
101
+ fntable->MEMBER1 = nb::cast (fn1); \
102
+ fntable->MEMBER2 = nb::cast (fn2); \
103
+ if (fn1) { return NAME (cv_mem, which, WRAPPER1, WRAPPER2); } \
104
+ else { return NAME (cv_mem, which, nullptr , WRAPPER2); } \
105
+ }, \
106
+ __VA_ARGS__)
101
107
102
108
void bind_cvodes (nb::module_& m)
103
109
{
@@ -176,11 +182,15 @@ void bind_cvodes(nb::module_& m)
176
182
177
183
BIND_CVODE_CALLBACK2 (CVodeSetPreconditioner, CVLsPrecSetupFn, lsprecsetupfn,
178
184
cvode_lsprecsetupfn_wrapper, CVLsPrecSolveFn,
179
- lsprecsolvefn, cvode_lsprecsolvefn_wrapper);
185
+ lsprecsolvefn, cvode_lsprecsolvefn_wrapper,
186
+ nb::arg (" cv_mem" ), nb::arg (" psetup" ).none (),
187
+ nb::arg (" psolve" ));
180
188
181
189
BIND_CVODE_CALLBACK2 (CVodeSetJacTimes, CVLsJacTimesSetupFn, lsjactimessetupfn,
182
190
cvode_lsjactimessetupfn_wrapper, CVLsJacTimesVecFn,
183
- lsjactimesvecfn, cvode_lsjactimesvecfn_wrapper);
191
+ lsjactimesvecfn, cvode_lsjactimesvecfn_wrapper,
192
+ nb::arg (" cv_mem" ), nb::arg (" jsetup" ).none (),
193
+ nb::arg (" jtimes" ));
184
194
185
195
BIND_CVODE_CALLBACK (CVodeSetLinSysFn, CVLsLinSysFn, lslinsysfn,
186
196
cvode_lslinsysfn_wrapper);
@@ -314,13 +324,19 @@ void bind_cvodes(nb::module_& m)
314
324
315
325
BIND_CVODEB_CALLBACK (CVodeSetJacFnB, CVLsJacFnB, lsjacfnB,
316
326
cvode_lsjacfnB_wrapper);
327
+
317
328
BIND_CVODEB_CALLBACK2 (CVodeSetPreconditionerB, CVLsPrecSetupFnB, lsprecsetupfnB,
318
329
cvode_lsprecsetupfnB_wrapper, CVLsPrecSolveFnB,
319
- lsprecsolvefnB, cvode_lsprecsolvefnB_wrapper);
330
+ lsprecsolvefnB, cvode_lsprecsolvefnB_wrapper,
331
+ nb::arg (" cv_mem" ), nb::arg (" which" ),
332
+ nb::arg (" psetupB" ).none (), nb::arg (" psolveB" ));
333
+
320
334
BIND_CVODEB_CALLBACK2 (CVodeSetJacTimesB, CVLsJacTimesSetupFnB,
321
335
lsjactimessetupfnB, cvode_lsjactimessetupfnB_wrapper,
322
336
CVLsJacTimesVecFnB, lsjactimesvecfnB,
323
- cvode_lsjactimesvecfnB_wrapper);
337
+ cvode_lsjactimesvecfnB_wrapper, nb::arg (" cv_mem" ),
338
+ nb::arg (" which" ), nb::arg (" jsetupB" ).none (),
339
+ nb::arg (" jtimesB" ));
324
340
325
341
BIND_CVODEB_CALLBACK (CVodeSetLinSysFnB, CVLsLinSysFnB, lslinsysfnB,
326
342
cvode_lslinsysfnB_wrapper);
@@ -331,6 +347,7 @@ void bind_cvodes(nb::module_& m)
331
347
N_Vector tmp1B, N_Vector tmp2B, N_Vector tmp3B);
332
348
BIND_CVODEB_CALLBACK (CVodeSetJacFnBS, CVLsJacStdFnBS, lsjacfnBS,
333
349
cvode_lsjacfnBS_wrapper);
350
+
334
351
using CVLsPrecSetupStdFnBS = int (sunrealtype t, N_Vector y, N_Vector yB,
335
352
N_Vector fyB, sunbooleantype jokB,
336
353
sunbooleantype * jcurPtrB,
@@ -342,7 +359,10 @@ void bind_cvodes(nb::module_& m)
342
359
BIND_CVODEB_CALLBACK2 (CVodeSetPreconditionerBS, CVLsPrecSetupStdFnBS,
343
360
lsprecsetupfnBS, cvode_lsprecsetupfnBS_wrapper,
344
361
CVLsPrecSolveStdFnBS, lsprecsolvefnBS,
345
- cvode_lsprecsolvefnBS_wrapper);
362
+ cvode_lsprecsolvefnBS_wrapper, nb::arg (" cv_mem" ),
363
+ nb::arg (" which" ), nb::arg (" psetupBS" ).none (),
364
+ nb::arg (" psolveBS" ));
365
+
346
366
using CVLsJacTimesSetupStdFnBS = int (sunrealtype t, N_Vector y,
347
367
std::vector<N_Vector> yS, N_Vector yB,
348
368
N_Vector fyB, void * jac_dataB);
@@ -353,7 +373,10 @@ void bind_cvodes(nb::module_& m)
353
373
BIND_CVODEB_CALLBACK2 (CVodeSetJacTimesBS, CVLsJacTimesSetupStdFnBS,
354
374
lsjactimessetupfnBS, cvode_lsjactimessetupfnBS_wrapper,
355
375
CVLsJacTimesVecStdFnBS, lsjactimesvecfnBS,
356
- cvode_lsjactimesvecfnBS_wrapper);
376
+ cvode_lsjactimesvecfnBS_wrapper, nb::arg (" cv_mem" ),
377
+ nb::arg (" which" ), nb::arg (" jsetupBS" ).none (),
378
+ nb::arg (" jtimesBS" ));
379
+
357
380
using CVLsLinSysStdFnBS =
358
381
int (sunrealtype t, N_Vector y, std::vector<N_Vector> yS, N_Vector yB,
359
382
N_Vector fyB, SUNMatrix AB, sunbooleantype jokB, sunbooleantype * jcurB,
0 commit comments