Skip to content

Commit 8f062c0

Browse files
committed
bind Get functions that use inplace modification
1 parent a8f7004 commit 8f062c0

File tree

6 files changed

+175
-23
lines changed

6 files changed

+175
-23
lines changed

bindings/pysundials/arkode/pysundials_arkode_generated.hpp

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,18 @@ m.def("ARKodeRootInit",
109109
ARKodeRootInit, nb::arg("arkode_mem"), nb::arg("nrtfn"), nb::arg("g"));
110110

111111
m.def("ARKodeSetRootDirection",
112-
ARKodeSetRootDirection, nb::arg("arkode_mem"), nb::arg("rootdir"));
112+
[](void * arkode_mem, int rootdir) -> std::tuple<int, int>
113+
{
114+
auto ARKodeSetRootDirection_adapt_modifiable_immutable_to_return = [](void * arkode_mem, int rootdir) -> std::tuple<int, int>
115+
{
116+
int * rootdir_adapt_modifiable = & rootdir;
117+
118+
int r = ARKodeSetRootDirection(arkode_mem, rootdir_adapt_modifiable);
119+
return std::make_tuple(r, rootdir);
120+
};
121+
122+
return ARKodeSetRootDirection_adapt_modifiable_immutable_to_return(arkode_mem, rootdir);
123+
}, nb::arg("arkode_mem"), nb::arg("rootdir"));
113124

114125
m.def("ARKodeSetNoInactiveRootWarn",
115126
ARKodeSetNoInactiveRootWarn, nb::arg("arkode_mem"));
@@ -292,7 +303,18 @@ m.def("ARKodeGetNumGEvals",
292303
ARKodeGetNumGEvals, nb::arg("arkode_mem"), nb::arg("ngevals"));
293304

294305
m.def("ARKodeGetRootInfo",
295-
ARKodeGetRootInfo, nb::arg("arkode_mem"), nb::arg("rootsfound"));
306+
[](void * arkode_mem, int rootsfound) -> std::tuple<int, int>
307+
{
308+
auto ARKodeGetRootInfo_adapt_modifiable_immutable_to_return = [](void * arkode_mem, int rootsfound) -> std::tuple<int, int>
309+
{
310+
int * rootsfound_adapt_modifiable = & rootsfound;
311+
312+
int r = ARKodeGetRootInfo(arkode_mem, rootsfound_adapt_modifiable);
313+
return std::make_tuple(r, rootsfound);
314+
};
315+
316+
return ARKodeGetRootInfo_adapt_modifiable_immutable_to_return(arkode_mem, rootsfound);
317+
}, nb::arg("arkode_mem"), nb::arg("rootsfound"));
296318

297319
m.def("ARKodePrintAllStats",
298320
ARKodePrintAllStats, nb::arg("arkode_mem"), nb::arg("outfile"), nb::arg("fmt"));

bindings/pysundials/arkode/pysundials_arkode_lsrkstep_generated.hpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,18 @@ m.def("LSRKStepGetNumDomEigUpdates",
3838
LSRKStepGetNumDomEigUpdates, nb::arg("arkode_mem"), nb::arg("dom_eig_num_evals"));
3939

4040
m.def("LSRKStepGetMaxNumStages",
41-
LSRKStepGetMaxNumStages, nb::arg("arkode_mem"), nb::arg("stage_max"));
41+
[](void * arkode_mem, int stage_max) -> std::tuple<int, int>
42+
{
43+
auto LSRKStepGetMaxNumStages_adapt_modifiable_immutable_to_return = [](void * arkode_mem, int stage_max) -> std::tuple<int, int>
44+
{
45+
int * stage_max_adapt_modifiable = & stage_max;
46+
47+
int r = LSRKStepGetMaxNumStages(arkode_mem, stage_max_adapt_modifiable);
48+
return std::make_tuple(r, stage_max);
49+
};
50+
51+
return LSRKStepGetMaxNumStages_adapt_modifiable_immutable_to_return(arkode_mem, stage_max);
52+
}, nb::arg("arkode_mem"), nb::arg("stage_max"));
4253
// #ifdef __cplusplus
4354
//
4455
// #endif

bindings/pysundials/cvodes/pysundials_cvodes.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <nanobind/nanobind.h>
1818
#include <nanobind/stl/function.h>
1919
#include <nanobind/stl/optional.h>
20+
#include <nanobind/stl/tuple.h>
2021

2122
#include <sundials/sundials_core.hpp>
2223

bindings/pysundials/cvodes/pysundials_cvodes_generated.hpp

Lines changed: 121 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,18 @@ m.def("CVodeRootInit",
182182
CVodeRootInit, nb::arg("cvode_mem"), nb::arg("nrtfn"), nb::arg("g"));
183183

184184
m.def("CVodeSetRootDirection",
185-
CVodeSetRootDirection, nb::arg("cvode_mem"), nb::arg("rootdir"));
185+
[](void * cvode_mem, int rootdir) -> std::tuple<int, int>
186+
{
187+
auto CVodeSetRootDirection_adapt_modifiable_immutable_to_return = [](void * cvode_mem, int rootdir) -> std::tuple<int, int>
188+
{
189+
int * rootdir_adapt_modifiable = & rootdir;
190+
191+
int r = CVodeSetRootDirection(cvode_mem, rootdir_adapt_modifiable);
192+
return std::make_tuple(r, rootdir);
193+
};
194+
195+
return CVodeSetRootDirection_adapt_modifiable_immutable_to_return(cvode_mem, rootdir);
196+
}, nb::arg("cvode_mem"), nb::arg("rootdir"));
186197

187198
m.def("CVodeSetNoInactiveRootWarn",
188199
CVodeSetNoInactiveRootWarn, nb::arg("cvode_mem"));
@@ -209,10 +220,32 @@ m.def("CVodeGetNumErrTestFails",
209220
CVodeGetNumErrTestFails, nb::arg("cvode_mem"), nb::arg("netfails"));
210221

211222
m.def("CVodeGetLastOrder",
212-
CVodeGetLastOrder, nb::arg("cvode_mem"), nb::arg("qlast"));
223+
[](void * cvode_mem, int qlast) -> std::tuple<int, int>
224+
{
225+
auto CVodeGetLastOrder_adapt_modifiable_immutable_to_return = [](void * cvode_mem, int qlast) -> std::tuple<int, int>
226+
{
227+
int * qlast_adapt_modifiable = & qlast;
228+
229+
int r = CVodeGetLastOrder(cvode_mem, qlast_adapt_modifiable);
230+
return std::make_tuple(r, qlast);
231+
};
232+
233+
return CVodeGetLastOrder_adapt_modifiable_immutable_to_return(cvode_mem, qlast);
234+
}, nb::arg("cvode_mem"), nb::arg("qlast"));
213235

214236
m.def("CVodeGetCurrentOrder",
215-
CVodeGetCurrentOrder, nb::arg("cvode_mem"), nb::arg("qcur"));
237+
[](void * cvode_mem, int qcur) -> std::tuple<int, int>
238+
{
239+
auto CVodeGetCurrentOrder_adapt_modifiable_immutable_to_return = [](void * cvode_mem, int qcur) -> std::tuple<int, int>
240+
{
241+
int * qcur_adapt_modifiable = & qcur;
242+
243+
int r = CVodeGetCurrentOrder(cvode_mem, qcur_adapt_modifiable);
244+
return std::make_tuple(r, qcur);
245+
};
246+
247+
return CVodeGetCurrentOrder_adapt_modifiable_immutable_to_return(cvode_mem, qcur);
248+
}, nb::arg("cvode_mem"), nb::arg("qcur"));
216249

217250
m.def("CVodeGetCurrentGamma",
218251
CVodeGetCurrentGamma, nb::arg("cvode_mem"), nb::arg("gamma"));
@@ -230,7 +263,18 @@ m.def("CVodeGetCurrentStep",
230263
CVodeGetCurrentStep, nb::arg("cvode_mem"), nb::arg("hcur"));
231264

232265
m.def("CVodeGetCurrentSensSolveIndex",
233-
CVodeGetCurrentSensSolveIndex, nb::arg("cvode_mem"), nb::arg("index"));
266+
[](void * cvode_mem, int index) -> std::tuple<int, int>
267+
{
268+
auto CVodeGetCurrentSensSolveIndex_adapt_modifiable_immutable_to_return = [](void * cvode_mem, int index) -> std::tuple<int, int>
269+
{
270+
int * index_adapt_modifiable = & index;
271+
272+
int r = CVodeGetCurrentSensSolveIndex(cvode_mem, index_adapt_modifiable);
273+
return std::make_tuple(r, index);
274+
};
275+
276+
return CVodeGetCurrentSensSolveIndex_adapt_modifiable_immutable_to_return(cvode_mem, index);
277+
}, nb::arg("cvode_mem"), nb::arg("index"));
234278

235279
m.def("CVodeGetCurrentTime",
236280
CVodeGetCurrentTime, nb::arg("cvode_mem"), nb::arg("tcur"));
@@ -248,10 +292,33 @@ m.def("CVodeGetNumGEvals",
248292
CVodeGetNumGEvals, nb::arg("cvode_mem"), nb::arg("ngevals"));
249293

250294
m.def("CVodeGetRootInfo",
251-
CVodeGetRootInfo, nb::arg("cvode_mem"), nb::arg("rootsfound"));
295+
[](void * cvode_mem, int rootsfound) -> std::tuple<int, int>
296+
{
297+
auto CVodeGetRootInfo_adapt_modifiable_immutable_to_return = [](void * cvode_mem, int rootsfound) -> std::tuple<int, int>
298+
{
299+
int * rootsfound_adapt_modifiable = & rootsfound;
300+
301+
int r = CVodeGetRootInfo(cvode_mem, rootsfound_adapt_modifiable);
302+
return std::make_tuple(r, rootsfound);
303+
};
304+
305+
return CVodeGetRootInfo_adapt_modifiable_immutable_to_return(cvode_mem, rootsfound);
306+
}, nb::arg("cvode_mem"), nb::arg("rootsfound"));
252307

253308
m.def("CVodeGetIntegratorStats",
254-
CVodeGetIntegratorStats, nb::arg("cvode_mem"), nb::arg("nsteps"), nb::arg("nfevals"), nb::arg("nlinsetups"), nb::arg("netfails"), nb::arg("qlast"), nb::arg("qcur"), nb::arg("hinused"), nb::arg("hlast"), nb::arg("hcur"), nb::arg("tcur"));
309+
[](void * cvode_mem, long int * nsteps, long int * nfevals, long int * nlinsetups, long int * netfails, int qlast, int qcur, sunrealtype * hinused, sunrealtype * hlast, sunrealtype * hcur, sunrealtype * tcur) -> std::tuple<int, int, int>
310+
{
311+
auto CVodeGetIntegratorStats_adapt_modifiable_immutable_to_return = [](void * cvode_mem, long int * nsteps, long int * nfevals, long int * nlinsetups, long int * netfails, int qlast, int qcur, sunrealtype * hinused, sunrealtype * hlast, sunrealtype * hcur, sunrealtype * tcur) -> std::tuple<int, int, int>
312+
{
313+
int * qlast_adapt_modifiable = & qlast;
314+
int * qcur_adapt_modifiable = & qcur;
315+
316+
int r = CVodeGetIntegratorStats(cvode_mem, nsteps, nfevals, nlinsetups, netfails, qlast_adapt_modifiable, qcur_adapt_modifiable, hinused, hlast, hcur, tcur);
317+
return std::make_tuple(r, qlast, qcur);
318+
};
319+
320+
return CVodeGetIntegratorStats_adapt_modifiable_immutable_to_return(cvode_mem, nsteps, nfevals, nlinsetups, netfails, qlast, qcur, hinused, hlast, hcur, tcur);
321+
}, nb::arg("cvode_mem"), nb::arg("nsteps"), nb::arg("nfevals"), nb::arg("nlinsetups"), nb::arg("netfails"), nb::arg("qlast"), nb::arg("qcur"), nb::arg("hinused"), nb::arg("hlast"), nb::arg("hcur"), nb::arg("tcur"));
255322

256323
m.def("CVodeGetNumNonlinSolvIters",
257324
CVodeGetNumNonlinSolvIters, nb::arg("cvode_mem"), nb::arg("nniters"));
@@ -320,7 +387,18 @@ m.def("CVodeSetSensMaxNonlinIters",
320387
CVodeSetSensMaxNonlinIters, nb::arg("cvode_mem"), nb::arg("maxcorS"));
321388

322389
m.def("CVodeSetSensParams",
323-
CVodeSetSensParams, nb::arg("cvode_mem"), nb::arg("p"), nb::arg("pbar"), nb::arg("plist"));
390+
[](void * cvode_mem, sunrealtype * p, sunrealtype * pbar, int plist) -> std::tuple<int, int>
391+
{
392+
auto CVodeSetSensParams_adapt_modifiable_immutable_to_return = [](void * cvode_mem, sunrealtype * p, sunrealtype * pbar, int plist) -> std::tuple<int, int>
393+
{
394+
int * plist_adapt_modifiable = & plist;
395+
396+
int r = CVodeSetSensParams(cvode_mem, p, pbar, plist_adapt_modifiable);
397+
return std::make_tuple(r, plist);
398+
};
399+
400+
return CVodeSetSensParams_adapt_modifiable_immutable_to_return(cvode_mem, p, pbar, plist);
401+
}, nb::arg("cvode_mem"), nb::arg("p"), nb::arg("pbar"), nb::arg("plist"));
324402

325403
m.def("CVodeSetNonlinearSolverSensSim",
326404
CVodeSetNonlinearSolverSensSim, nb::arg("cvode_mem"), nb::arg("NLS"));
@@ -368,7 +446,18 @@ m.def("CVodeAdjReInit",
368446
CVodeAdjReInit, nb::arg("cvode_mem"));
369447

370448
m.def("CVodeCreateB",
371-
CVodeCreateB, nb::arg("cvode_mem"), nb::arg("lmmB"), nb::arg("which"));
449+
[](void * cvode_mem, int lmmB, int which) -> std::tuple<int, int>
450+
{
451+
auto CVodeCreateB_adapt_modifiable_immutable_to_return = [](void * cvode_mem, int lmmB, int which) -> std::tuple<int, int>
452+
{
453+
int * which_adapt_modifiable = & which;
454+
455+
int r = CVodeCreateB(cvode_mem, lmmB, which_adapt_modifiable);
456+
return std::make_tuple(r, which);
457+
};
458+
459+
return CVodeCreateB_adapt_modifiable_immutable_to_return(cvode_mem, lmmB, which);
460+
}, nb::arg("cvode_mem"), nb::arg("lmmB"), nb::arg("which"));
372461

373462
m.def("CVodeReInitB",
374463
CVodeReInitB, nb::arg("cvode_mem"), nb::arg("which"), nb::arg("tB0"), nb::arg("yB0"));
@@ -395,7 +484,18 @@ m.def("CVodeQuadSVtolerancesB",
395484
CVodeQuadSVtolerancesB, nb::arg("cvode_mem"), nb::arg("which"), nb::arg("reltolQB"), nb::arg("abstolQB"));
396485

397486
m.def("CVodeF",
398-
CVodeF, nb::arg("cvode_mem"), nb::arg("tout"), nb::arg("yout"), nb::arg("tret"), nb::arg("itask"), nb::arg("ncheckPtr"));
487+
[](void * cvode_mem, sunrealtype tout, N_Vector yout, sunrealtype * tret, int itask, int ncheckPtr) -> std::tuple<int, int>
488+
{
489+
auto CVodeF_adapt_modifiable_immutable_to_return = [](void * cvode_mem, sunrealtype tout, N_Vector yout, sunrealtype * tret, int itask, int ncheckPtr) -> std::tuple<int, int>
490+
{
491+
int * ncheckPtr_adapt_modifiable = & ncheckPtr;
492+
493+
int r = CVodeF(cvode_mem, tout, yout, tret, itask, ncheckPtr_adapt_modifiable);
494+
return std::make_tuple(r, ncheckPtr);
495+
};
496+
497+
return CVodeF_adapt_modifiable_immutable_to_return(cvode_mem, tout, yout, tret, itask, ncheckPtr);
498+
}, nb::arg("cvode_mem"), nb::arg("tout"), nb::arg("yout"), nb::arg("tret"), nb::arg("itask"), nb::arg("ncheckPtr"));
399499

400500
m.def("CVodeB",
401501
CVodeB, nb::arg("cvode_mem"), nb::arg("tBout"), nb::arg("itaskB"));
@@ -449,7 +549,18 @@ m.def("CVodeGetAdjDataPointHermite",
449549
CVodeGetAdjDataPointHermite, nb::arg("cvode_mem"), nb::arg("which"), nb::arg("t"), nb::arg("y"), nb::arg("yd"));
450550

451551
m.def("CVodeGetAdjDataPointPolynomial",
452-
CVodeGetAdjDataPointPolynomial, nb::arg("cvode_mem"), nb::arg("which"), nb::arg("t"), nb::arg("order"), nb::arg("y"));
552+
[](void * cvode_mem, int which, sunrealtype * t, int order, N_Vector y) -> std::tuple<int, int>
553+
{
554+
auto CVodeGetAdjDataPointPolynomial_adapt_modifiable_immutable_to_return = [](void * cvode_mem, int which, sunrealtype * t, int order, N_Vector y) -> std::tuple<int, int>
555+
{
556+
int * order_adapt_modifiable = & order;
557+
558+
int r = CVodeGetAdjDataPointPolynomial(cvode_mem, which, t, order_adapt_modifiable, y);
559+
return std::make_tuple(r, order);
560+
};
561+
562+
return CVodeGetAdjDataPointPolynomial_adapt_modifiable_immutable_to_return(cvode_mem, which, t, order, y);
563+
}, nb::arg("cvode_mem"), nb::arg("which"), nb::arg("t"), nb::arg("order"), nb::arg("y"));
453564
// #ifdef __cplusplus
454565
//
455566
// #endif

bindings/pysundials/generate.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,22 @@ def main():
8080
options.bind_library = litgen.BindLibraryType.nanobind
8181
options.python_run_black_formatter = True
8282
options.python_convert_to_snake_case = False
83+
84+
# Don't capture comments from the source for generating Python doc strings
8385
options.comments_exclude = True
86+
87+
# Export enum values to the package namespace
8488
options.enum_export_values = True
89+
90+
# Allow const char to be nullable
8591
options.fn_params_const_char_pointer_with_default_null = True
92+
93+
# Transform inplace modification of values, e.g. int CVodeGetNumSteps(void* cvode_mem, long* num_steps), to CvodeGetNumSteps(cvode_mem) -> Tuple[int, long]
94+
options.fn_params_output_modifiable_immutable_to_return__regex = r".*"
95+
96+
# Our own custom function adapters
8697
options.fn_custom_adapters = [adapt_default_arg_pointer_with_default_null]
98+
8799
options.srcmlcpp_options.code_preprocess_function = preprocess_header
88100
options.srcmlcpp_options.ignored_warning_parts.append(
89101
# "ops" functions pointers cause this warning, but we dont care cause we dont need to bind those.
@@ -100,7 +112,6 @@ def main():
100112
if not config_object:
101113
raise RuntimeError(f"modules: section not found in {config_yaml_path}")
102114

103-
# print(config_object)
104115
for module_name in config_object:
105116
if module_name == "all":
106117
continue
@@ -125,12 +136,6 @@ def main():
125136
load_macro_defines_from_yaml(config_object, module_name)
126137
)
127138

128-
# TODO(CJB): this does not seem to work
129-
# options.fn_return_force_policy_reference_for_pointers__regex = code_utils.join_string_by_pipe_char([
130-
# "N_VNewEmpty",
131-
# "N_VMake"
132-
# ])
133-
134139
source_code = ""
135140
for file_path in module["headers"]:
136141
with open(file_path, "r") as file:

bindings/pysundials/test/test_cvodes.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@
55
from pysundials.cvodes import *
66
from ode_problems import AnalyticODE
77

8-
def test_implicit():
9-
print(" testing implicit")
8+
def test_bdf():
109
sunctx = SUNContextView.Create()
1110
nv = NVectorView.Create(N_VNew_Serial(1, sunctx.get()))
1211
ls = SUNLinearSolverView.Create(SUNLinSol_SPGMR(nv.get(), 0, 0, sunctx.get()))
@@ -33,6 +32,9 @@ def test_implicit():
3332
status = CVode(solver.get(), tout, nv.get(), tret, CV_NORMAL)
3433
print(f"status={status}, ans={arr}")
3534

35+
status, last_order = CVodeGetLastOrder(solver.get(), 0)
36+
print(f"last_order={last_order}")
37+
3638

3739
if __name__ == "__main__":
38-
test_implicit()
40+
test_bdf()

0 commit comments

Comments
 (0)