Skip to content

Commit 7b10df9

Browse files
committed
fix CVode SetPreconditioner and SetJtimes nullables
1 parent c7faa9d commit 7b10df9

File tree

2 files changed

+62
-38
lines changed

2 files changed

+62
-38
lines changed

bindings/pysundials/cvodes/pysundials_cvodes.cpp

Lines changed: 60 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -50,21 +50,24 @@ using namespace sundials::experimental;
5050
})
5151

5252
#define BIND_CVODE_CALLBACK2(NAME, FN_TYPE1, MEMBER1, WRAPPER1, FN_TYPE2, \
53-
MEMBER2, WRAPPER2) \
54-
m.def(#NAME, \
55-
[](void* cv_mem, std::function<std::remove_pointer_t<FN_TYPE1>> fn1, \
56-
std::function<std::remove_pointer_t<FN_TYPE2>> fn2) \
57-
{ \
58-
void* user_data = nullptr; \
59-
CVodeGetUserData(cv_mem, &user_data); \
60-
if (!user_data) \
61-
throw std::runtime_error( \
62-
"Failed to get Python function table from CVODE memory"); \
63-
auto fntable = static_cast<cvode_user_supplied_fn_table*>(user_data); \
64-
fntable->MEMBER1 = nb::cast(fn1); \
65-
fntable->MEMBER2 = nb::cast(fn2); \
66-
return NAME(cv_mem, &WRAPPER1, &WRAPPER2); \
67-
})
53+
MEMBER2, WRAPPER2, ...) \
54+
m.def( \
55+
#NAME, \
56+
[](void* cv_mem, std::function<std::remove_pointer_t<FN_TYPE1>> fn1, \
57+
std::function<std::remove_pointer_t<FN_TYPE2>> fn2) \
58+
{ \
59+
void* user_data = nullptr; \
60+
CVodeGetUserData(cv_mem, &user_data); \
61+
if (!user_data) \
62+
throw std::runtime_error( \
63+
"Failed to get Python function table from CVODE memory"); \
64+
auto fntable = static_cast<cvode_user_supplied_fn_table*>(user_data); \
65+
fntable->MEMBER1 = nb::cast(fn1); \
66+
fntable->MEMBER2 = nb::cast(fn2); \
67+
if (fn1) { return NAME(cv_mem, WRAPPER1, WRAPPER2); } \
68+
else { return NAME(cv_mem, nullptr, WRAPPER2); } \
69+
}, \
70+
__VA_ARGS__)
6871

6972
#define BIND_CVODEB_CALLBACK(NAME, FN_TYPE, MEMBER, WRAPPER) \
7073
m.def(#NAME, \
@@ -82,22 +85,25 @@ using namespace sundials::experimental;
8285
})
8386

8487
#define BIND_CVODEB_CALLBACK2(NAME, FN_TYPE1, MEMBER1, WRAPPER1, FN_TYPE2, \
85-
MEMBER2, WRAPPER2) \
86-
m.def(#NAME, \
87-
[](void* cv_mem, int which, \
88-
std::function<std::remove_pointer_t<FN_TYPE1>> fn1, \
89-
std::function<std::remove_pointer_t<FN_TYPE2>> fn2) \
90-
{ \
91-
void* user_data = nullptr; \
92-
CVodeGetUserDataB(cv_mem, which, &user_data); \
93-
if (!user_data) \
94-
throw std::runtime_error( \
95-
"Failed to get Python function table from CVODE memory"); \
96-
auto fntable = static_cast<cvodea_user_supplied_fn_table*>(user_data); \
97-
fntable->MEMBER1 = nb::cast(fn1); \
98-
fntable->MEMBER2 = nb::cast(fn2); \
99-
return NAME(cv_mem, which, &WRAPPER1, &WRAPPER2); \
100-
})
88+
MEMBER2, WRAPPER2, ...) \
89+
m.def( \
90+
#NAME, \
91+
[](void* cv_mem, int which, \
92+
std::function<std::remove_pointer_t<FN_TYPE1>> fn1, \
93+
std::function<std::remove_pointer_t<FN_TYPE2>> fn2) \
94+
{ \
95+
void* user_data = nullptr; \
96+
CVodeGetUserDataB(cv_mem, which, &user_data); \
97+
if (!user_data) \
98+
throw std::runtime_error( \
99+
"Failed to get Python function table from CVODE memory"); \
100+
auto fntable = static_cast<cvodea_user_supplied_fn_table*>(user_data); \
101+
fntable->MEMBER1 = nb::cast(fn1); \
102+
fntable->MEMBER2 = nb::cast(fn2); \
103+
if (fn1) { return NAME(cv_mem, which, WRAPPER1, WRAPPER2); } \
104+
else { return NAME(cv_mem, which, nullptr, WRAPPER2); } \
105+
}, \
106+
__VA_ARGS__)
101107

102108
void bind_cvodes(nb::module_& m)
103109
{
@@ -176,11 +182,15 @@ void bind_cvodes(nb::module_& m)
176182

177183
BIND_CVODE_CALLBACK2(CVodeSetPreconditioner, CVLsPrecSetupFn, lsprecsetupfn,
178184
cvode_lsprecsetupfn_wrapper, CVLsPrecSolveFn,
179-
lsprecsolvefn, cvode_lsprecsolvefn_wrapper);
185+
lsprecsolvefn, cvode_lsprecsolvefn_wrapper,
186+
nb::arg("cv_mem"), nb::arg("psetup").none(),
187+
nb::arg("psolve"));
180188

181189
BIND_CVODE_CALLBACK2(CVodeSetJacTimes, CVLsJacTimesSetupFn, lsjactimessetupfn,
182190
cvode_lsjactimessetupfn_wrapper, CVLsJacTimesVecFn,
183-
lsjactimesvecfn, cvode_lsjactimesvecfn_wrapper);
191+
lsjactimesvecfn, cvode_lsjactimesvecfn_wrapper,
192+
nb::arg("cv_mem"), nb::arg("jsetup").none(),
193+
nb::arg("jtimes"));
184194

185195
BIND_CVODE_CALLBACK(CVodeSetLinSysFn, CVLsLinSysFn, lslinsysfn,
186196
cvode_lslinsysfn_wrapper);
@@ -314,13 +324,19 @@ void bind_cvodes(nb::module_& m)
314324

315325
BIND_CVODEB_CALLBACK(CVodeSetJacFnB, CVLsJacFnB, lsjacfnB,
316326
cvode_lsjacfnB_wrapper);
327+
317328
BIND_CVODEB_CALLBACK2(CVodeSetPreconditionerB, CVLsPrecSetupFnB, lsprecsetupfnB,
318329
cvode_lsprecsetupfnB_wrapper, CVLsPrecSolveFnB,
319-
lsprecsolvefnB, cvode_lsprecsolvefnB_wrapper);
330+
lsprecsolvefnB, cvode_lsprecsolvefnB_wrapper,
331+
nb::arg("cv_mem"), nb::arg("which"),
332+
nb::arg("psetupB").none(), nb::arg("psolveB"));
333+
320334
BIND_CVODEB_CALLBACK2(CVodeSetJacTimesB, CVLsJacTimesSetupFnB,
321335
lsjactimessetupfnB, cvode_lsjactimessetupfnB_wrapper,
322336
CVLsJacTimesVecFnB, lsjactimesvecfnB,
323-
cvode_lsjactimesvecfnB_wrapper);
337+
cvode_lsjactimesvecfnB_wrapper, nb::arg("cv_mem"),
338+
nb::arg("which"), nb::arg("jsetupB").none(),
339+
nb::arg("jtimesB"));
324340

325341
BIND_CVODEB_CALLBACK(CVodeSetLinSysFnB, CVLsLinSysFnB, lslinsysfnB,
326342
cvode_lslinsysfnB_wrapper);
@@ -331,6 +347,7 @@ void bind_cvodes(nb::module_& m)
331347
N_Vector tmp1B, N_Vector tmp2B, N_Vector tmp3B);
332348
BIND_CVODEB_CALLBACK(CVodeSetJacFnBS, CVLsJacStdFnBS, lsjacfnBS,
333349
cvode_lsjacfnBS_wrapper);
350+
334351
using CVLsPrecSetupStdFnBS = int(sunrealtype t, N_Vector y, N_Vector yB,
335352
N_Vector fyB, sunbooleantype jokB,
336353
sunbooleantype * jcurPtrB,
@@ -342,7 +359,10 @@ void bind_cvodes(nb::module_& m)
342359
BIND_CVODEB_CALLBACK2(CVodeSetPreconditionerBS, CVLsPrecSetupStdFnBS,
343360
lsprecsetupfnBS, cvode_lsprecsetupfnBS_wrapper,
344361
CVLsPrecSolveStdFnBS, lsprecsolvefnBS,
345-
cvode_lsprecsolvefnBS_wrapper);
362+
cvode_lsprecsolvefnBS_wrapper, nb::arg("cv_mem"),
363+
nb::arg("which"), nb::arg("psetupBS").none(),
364+
nb::arg("psolveBS"));
365+
346366
using CVLsJacTimesSetupStdFnBS = int(sunrealtype t, N_Vector y,
347367
std::vector<N_Vector> yS, N_Vector yB,
348368
N_Vector fyB, void* jac_dataB);
@@ -353,7 +373,10 @@ void bind_cvodes(nb::module_& m)
353373
BIND_CVODEB_CALLBACK2(CVodeSetJacTimesBS, CVLsJacTimesSetupStdFnBS,
354374
lsjactimessetupfnBS, cvode_lsjactimessetupfnBS_wrapper,
355375
CVLsJacTimesVecStdFnBS, lsjactimesvecfnBS,
356-
cvode_lsjactimesvecfnBS_wrapper);
376+
cvode_lsjactimesvecfnBS_wrapper, nb::arg("cv_mem"),
377+
nb::arg("which"), nb::arg("jsetupBS").none(),
378+
nb::arg("jtimesBS"));
379+
357380
using CVLsLinSysStdFnBS =
358381
int(sunrealtype t, N_Vector y, std::vector<N_Vector> yS, N_Vector yB,
359382
N_Vector fyB, SUNMatrix AB, sunbooleantype jokB, sunbooleantype * jcurB,

bindings/pysundials/idas/pysundials_idas.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,8 @@ using namespace sundials::experimental;
9797
auto fntable = static_cast<idasa_user_supplied_fn_table*>(user_data); \
9898
fntable->MEMBER1 = nb::cast(fn1); \
9999
fntable->MEMBER2 = nb::cast(fn2); \
100-
return NAME(ida_mem, which, &WRAPPER1, &WRAPPER2); \
100+
if (fn1) { return NAME(ida_mem, which, WRAPPER1, WRAPPER2); } \
101+
else { return NAME(ida_mem, which, nullptr, WRAPPER2); } \
101102
}, \
102103
__VA_ARGS__)
103104

0 commit comments

Comments
 (0)