Skip to content

Commit 83d76a3

Browse files
committed
interface callbacks for SUNStepper
1 parent 6ae0a12 commit 83d76a3

File tree

5 files changed

+390
-2
lines changed

5 files changed

+390
-2
lines changed

bindings/pysundials/helpers/pysundials_helpers.hpp

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ int user_supplied_fn_caller(nb::object FnTableType::*fn_member, Args... args)
5555
/// \tparam Args is the template parameter pack that contains all of the types of the function arguments to the C function
5656
///
5757
/// \param fn_member is the name of the function in the FnTableType to call
58-
/// \param args is the arguments to the C function, which will be forwarded to the user-supplied Python function, except user_data, which is intercepted and passed as a nullptr.
58+
/// \param args is the arguments to the C function, which will be forwarded to the user-supplied Python function.
5959
template<typename FnType, typename FnTableType, typename... Args>
6060
int user_supplied_fn_caller(nb::object FnTableType::*fn_member, void* user_data,
6161
Args... args)
@@ -70,6 +70,27 @@ int user_supplied_fn_caller(nb::object FnTableType::*fn_member, void* user_data,
7070
args_tuple);
7171
}
7272

73+
/// This function will call a user-supplied Python function through C++ side wrappers
74+
/// \tparam FnType is the function signature, e.g., std::remove_pointer_t<CVRhsFn>
75+
/// \tparam FnTableType is the struct function table that holds the user-supplied Python functions as std::function
76+
/// \tparam Args is the template parameter pack that contains all of the types of the function arguments to the C function
77+
///
78+
/// \param fn_member is the name of the function in the FnTableType to call
79+
/// \param args is the arguments to the C function, which will be forwarded to the user-supplied Python function.
80+
template<typename FnType, typename FnTableType, typename T, typename... Args>
81+
int user_supplied_fn_caller(nb::object FnTableType::*fn_member, Args... args)
82+
{
83+
auto args_tuple = std::tuple<Args...>(args...);
84+
85+
// Cast object->python to FnTableType*
86+
auto object = static_cast<T>(std::get<0>(args_tuple));
87+
auto fn_table = static_cast<FnTableType*>(object->python);
88+
auto fn = nb::cast<std::function<FnType>>(fn_table->*fn_member);
89+
90+
return std::apply([&](auto&&... call_args) { return fn(call_args...); },
91+
args_tuple);
92+
}
93+
7394
} // namespace pysundials
7495

7596
#endif

bindings/pysundials/sundials/pysundials_stepper.cpp

Lines changed: 219 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,13 @@
1919

2020
#include <nanobind/nanobind.h>
2121
#include <nanobind/ndarray.h>
22+
#include <nanobind/stl/vector.h>
2223

2324
#include <sundials/sundials_stepper.hpp>
24-
2525
#include "sundials_stepper_impl.h"
2626

27+
#include "pysundials_stepper_usersupplied.hpp"
28+
2729
namespace nb = nanobind;
2830

2931
using SUNStepperView = sundials::experimental::SUNStepperView;
@@ -38,4 +40,220 @@ void bind_sunstepper(nb::module_& m)
3840
.def("get", nb::overload_cast<>(&SUNStepperView::get, nb::const_),
3941
nb::rv_policy::reference)
4042
.def_static("Create", SUNStepperView::Create<SUNStepper>);
43+
44+
m.def(
45+
"SUNStepper_SetEvolveFn",
46+
[](SUNStepper stepper,
47+
std::function<std::remove_pointer_t<SUNStepperEvolveFn>> fn) -> SUNErrCode
48+
{
49+
if (!stepper->python)
50+
{
51+
stepper->python = SUNStepperFunctionTable_Alloc();
52+
}
53+
auto fntable = static_cast<SUNStepperFunctionTable*>(stepper->python);
54+
fntable->evolve = nb::cast(fn);
55+
if (fn)
56+
{
57+
return SUNStepper_SetEvolveFn(stepper, sunstepper_evolve_wrapper);
58+
}
59+
else { return SUNStepper_SetEvolveFn(stepper, nullptr); }
60+
},
61+
nb::arg("stepper"), nb::arg("fn").none());
62+
63+
m.def(
64+
"SUNStepper_SetOneStepFn",
65+
[](SUNStepper stepper,
66+
std::function<std::remove_pointer_t<SUNStepperOneStepFn>> fn) -> SUNErrCode
67+
{
68+
if (!stepper->python)
69+
{
70+
stepper->python = SUNStepperFunctionTable_Alloc();
71+
}
72+
auto fntable = static_cast<SUNStepperFunctionTable*>(stepper->python);
73+
fntable->one_step = nb::cast(fn);
74+
if (fn)
75+
{
76+
return SUNStepper_SetOneStepFn(stepper, sunstepper_one_step_wrapper);
77+
}
78+
else { return SUNStepper_SetOneStepFn(stepper, nullptr); }
79+
},
80+
nb::arg("stepper"), nb::arg("fn").none());
81+
82+
m.def(
83+
"SUNStepper_SetFullRhsFn",
84+
[](SUNStepper stepper,
85+
std::function<std::remove_pointer_t<SUNStepperFullRhsFn>> fn) -> SUNErrCode
86+
{
87+
if (!stepper->python)
88+
{
89+
stepper->python = SUNStepperFunctionTable_Alloc();
90+
}
91+
auto fntable = static_cast<SUNStepperFunctionTable*>(stepper->python);
92+
fntable->full_rhs = nb::cast(fn);
93+
if (fn)
94+
{
95+
return SUNStepper_SetFullRhsFn(stepper, sunstepper_full_rhs_wrapper);
96+
}
97+
else { return SUNStepper_SetFullRhsFn(stepper, nullptr); }
98+
},
99+
nb::arg("stepper"), nb::arg("fn").none());
100+
101+
m.def(
102+
"SUNStepper_SetReInitFn",
103+
[](SUNStepper stepper,
104+
std::function<std::remove_pointer_t<SUNStepperReInitFn>> fn) -> SUNErrCode
105+
{
106+
if (!stepper->python)
107+
{
108+
stepper->python = SUNStepperFunctionTable_Alloc();
109+
}
110+
auto fntable = static_cast<SUNStepperFunctionTable*>(stepper->python);
111+
fntable->reinit = nb::cast(fn);
112+
if (fn)
113+
{
114+
return SUNStepper_SetReInitFn(stepper, sunstepper_reinit_wrapper);
115+
}
116+
else { return SUNStepper_SetReInitFn(stepper, nullptr); }
117+
},
118+
nb::arg("stepper"), nb::arg("fn").none());
119+
120+
m.def(
121+
"SUNStepper_SetResetFn",
122+
[](SUNStepper stepper,
123+
std::function<std::remove_pointer_t<SUNStepperResetFn>> fn) -> SUNErrCode
124+
{
125+
if (!stepper->python)
126+
{
127+
stepper->python = SUNStepperFunctionTable_Alloc();
128+
}
129+
auto fntable = static_cast<SUNStepperFunctionTable*>(stepper->python);
130+
fntable->reset = nb::cast(fn);
131+
if (fn)
132+
{
133+
return SUNStepper_SetResetFn(stepper, sunstepper_reset_wrapper);
134+
}
135+
else { return SUNStepper_SetResetFn(stepper, nullptr); }
136+
},
137+
nb::arg("stepper"), nb::arg("fn").none());
138+
139+
m.def(
140+
"SUNStepper_SetResetCheckpointIndexFn",
141+
[](SUNStepper stepper,
142+
std::function<std::remove_pointer_t<SUNStepperResetCheckpointIndexFn>> fn)
143+
-> SUNErrCode
144+
{
145+
if (!stepper->python)
146+
{
147+
stepper->python = SUNStepperFunctionTable_Alloc();
148+
}
149+
auto fntable = static_cast<SUNStepperFunctionTable*>(stepper->python);
150+
fntable->reset_ckpt_idx = nb::cast(fn);
151+
if (fn)
152+
{
153+
return SUNStepper_SetResetCheckpointIndexFn(stepper,
154+
sunstepper_reset_ckpt_idx_wrapper);
155+
}
156+
else { return SUNStepper_SetResetCheckpointIndexFn(stepper, nullptr); }
157+
},
158+
nb::arg("stepper"), nb::arg("fn").none());
159+
160+
m.def(
161+
"SUNStepper_SetStopTimeFn",
162+
[](SUNStepper stepper,
163+
std::function<std::remove_pointer_t<SUNStepperSetStopTimeFn>> fn) -> SUNErrCode
164+
{
165+
if (!stepper->python)
166+
{
167+
stepper->python = SUNStepperFunctionTable_Alloc();
168+
}
169+
auto fntable = static_cast<SUNStepperFunctionTable*>(stepper->python);
170+
fntable->set_stop_time = nb::cast(fn);
171+
if (fn)
172+
{
173+
return SUNStepper_SetStopTimeFn(stepper,
174+
sunstepper_set_stop_time_wrapper);
175+
}
176+
else { return SUNStepper_SetStopTimeFn(stepper, nullptr); }
177+
},
178+
nb::arg("stepper"), nb::arg("fn").none());
179+
180+
m.def(
181+
"SUNStepper_SetStepDirectionFn",
182+
[](SUNStepper stepper,
183+
std::function<std::remove_pointer_t<SUNStepperSetStepDirectionFn>> fn) -> SUNErrCode
184+
{
185+
if (!stepper->python)
186+
{
187+
stepper->python = SUNStepperFunctionTable_Alloc();
188+
}
189+
auto fntable = static_cast<SUNStepperFunctionTable*>(stepper->python);
190+
fntable->set_step_direction = nb::cast(fn);
191+
if (fn)
192+
{
193+
return SUNStepper_SetStepDirectionFn(stepper,
194+
sunstepper_set_step_direction_wrapper);
195+
}
196+
else { return SUNStepper_SetStepDirectionFn(stepper, nullptr); }
197+
},
198+
nb::arg("stepper"), nb::arg("fn").none());
199+
200+
using SUNStepperSetForcingStdFn =
201+
SUNErrCode(SUNStepper stepper, sunrealtype tshift, sunrealtype tscale,
202+
std::vector<N_Vector> forcing, int nforcing);
203+
m.def(
204+
"SUNStepper_SetForcingFn",
205+
[](SUNStepper stepper, std::function<SUNStepperSetForcingStdFn> fn) -> SUNErrCode
206+
{
207+
if (!stepper->python)
208+
{
209+
stepper->python = SUNStepperFunctionTable_Alloc();
210+
}
211+
auto fntable = static_cast<SUNStepperFunctionTable*>(stepper->python);
212+
fntable->set_forcing = nb::cast(fn);
213+
if (fn)
214+
{
215+
return SUNStepper_SetForcingFn(stepper, sunstepper_set_forcing_wrapper);
216+
}
217+
else { return SUNStepper_SetForcingFn(stepper, nullptr); }
218+
},
219+
nb::arg("stepper"), nb::arg("fn").none());
220+
221+
m.def(
222+
"SUNStepper_SetGetNumStepsFn",
223+
[](SUNStepper stepper,
224+
std::function<std::remove_pointer_t<SUNStepperGetNumStepsFn>> fn) -> SUNErrCode
225+
{
226+
if (!stepper->python)
227+
{
228+
stepper->python = SUNStepperFunctionTable_Alloc();
229+
}
230+
auto fntable = static_cast<SUNStepperFunctionTable*>(stepper->python);
231+
fntable->get_num_steps = nb::cast(fn);
232+
if (fn)
233+
{
234+
return SUNStepper_SetGetNumStepsFn(stepper,
235+
sunstepper_get_num_steps_wrapper);
236+
}
237+
else { return SUNStepper_SetGetNumStepsFn(stepper, nullptr); }
238+
},
239+
nb::arg("stepper"), nb::arg("fn").none());
240+
241+
m.def(
242+
"SUNStepper_SetDestroyFn",
243+
[](SUNStepper stepper,
244+
std::function<std::remove_pointer_t<SUNStepperDestroyFn>> fn) -> SUNErrCode
245+
{
246+
if (!stepper->python)
247+
{
248+
stepper->python = SUNStepperFunctionTable_Alloc();
249+
}
250+
auto fntable = static_cast<SUNStepperFunctionTable*>(stepper->python);
251+
fntable->destroy = nb::cast(fn);
252+
if (fn)
253+
{
254+
return SUNStepper_SetDestroyFn(stepper, sunstepper_destroy_wrapper);
255+
}
256+
else { return SUNStepper_SetDestroyFn(stepper, nullptr); }
257+
},
258+
nb::arg("stepper"), nb::arg("fn").none());
41259
}

0 commit comments

Comments
 (0)