19
19
20
20
#include < nanobind/nanobind.h>
21
21
#include < nanobind/ndarray.h>
22
+ #include < nanobind/stl/vector.h>
22
23
23
24
#include < sundials/sundials_stepper.hpp>
24
-
25
25
#include " sundials_stepper_impl.h"
26
26
27
+ #include " pysundials_stepper_usersupplied.hpp"
28
+
27
29
namespace nb = nanobind;
28
30
29
31
using SUNStepperView = sundials::experimental::SUNStepperView;
@@ -38,4 +40,220 @@ void bind_sunstepper(nb::module_& m)
38
40
.def (" get" , nb::overload_cast<>(&SUNStepperView::get, nb::const_),
39
41
nb::rv_policy::reference)
40
42
.def_static (" Create" , SUNStepperView::Create<SUNStepper>);
43
+
44
+ m.def (
45
+ " SUNStepper_SetEvolveFn" ,
46
+ [](SUNStepper stepper,
47
+ std::function<std::remove_pointer_t <SUNStepperEvolveFn>> fn) -> SUNErrCode
48
+ {
49
+ if (!stepper->python )
50
+ {
51
+ stepper->python = SUNStepperFunctionTable_Alloc ();
52
+ }
53
+ auto fntable = static_cast <SUNStepperFunctionTable*>(stepper->python );
54
+ fntable->evolve = nb::cast (fn);
55
+ if (fn)
56
+ {
57
+ return SUNStepper_SetEvolveFn (stepper, sunstepper_evolve_wrapper);
58
+ }
59
+ else { return SUNStepper_SetEvolveFn (stepper, nullptr ); }
60
+ },
61
+ nb::arg (" stepper" ), nb::arg (" fn" ).none ());
62
+
63
+ m.def (
64
+ " SUNStepper_SetOneStepFn" ,
65
+ [](SUNStepper stepper,
66
+ std::function<std::remove_pointer_t <SUNStepperOneStepFn>> fn) -> SUNErrCode
67
+ {
68
+ if (!stepper->python )
69
+ {
70
+ stepper->python = SUNStepperFunctionTable_Alloc ();
71
+ }
72
+ auto fntable = static_cast <SUNStepperFunctionTable*>(stepper->python );
73
+ fntable->one_step = nb::cast (fn);
74
+ if (fn)
75
+ {
76
+ return SUNStepper_SetOneStepFn (stepper, sunstepper_one_step_wrapper);
77
+ }
78
+ else { return SUNStepper_SetOneStepFn (stepper, nullptr ); }
79
+ },
80
+ nb::arg (" stepper" ), nb::arg (" fn" ).none ());
81
+
82
+ m.def (
83
+ " SUNStepper_SetFullRhsFn" ,
84
+ [](SUNStepper stepper,
85
+ std::function<std::remove_pointer_t <SUNStepperFullRhsFn>> fn) -> SUNErrCode
86
+ {
87
+ if (!stepper->python )
88
+ {
89
+ stepper->python = SUNStepperFunctionTable_Alloc ();
90
+ }
91
+ auto fntable = static_cast <SUNStepperFunctionTable*>(stepper->python );
92
+ fntable->full_rhs = nb::cast (fn);
93
+ if (fn)
94
+ {
95
+ return SUNStepper_SetFullRhsFn (stepper, sunstepper_full_rhs_wrapper);
96
+ }
97
+ else { return SUNStepper_SetFullRhsFn (stepper, nullptr ); }
98
+ },
99
+ nb::arg (" stepper" ), nb::arg (" fn" ).none ());
100
+
101
+ m.def (
102
+ " SUNStepper_SetReInitFn" ,
103
+ [](SUNStepper stepper,
104
+ std::function<std::remove_pointer_t <SUNStepperReInitFn>> fn) -> SUNErrCode
105
+ {
106
+ if (!stepper->python )
107
+ {
108
+ stepper->python = SUNStepperFunctionTable_Alloc ();
109
+ }
110
+ auto fntable = static_cast <SUNStepperFunctionTable*>(stepper->python );
111
+ fntable->reinit = nb::cast (fn);
112
+ if (fn)
113
+ {
114
+ return SUNStepper_SetReInitFn (stepper, sunstepper_reinit_wrapper);
115
+ }
116
+ else { return SUNStepper_SetReInitFn (stepper, nullptr ); }
117
+ },
118
+ nb::arg (" stepper" ), nb::arg (" fn" ).none ());
119
+
120
+ m.def (
121
+ " SUNStepper_SetResetFn" ,
122
+ [](SUNStepper stepper,
123
+ std::function<std::remove_pointer_t <SUNStepperResetFn>> fn) -> SUNErrCode
124
+ {
125
+ if (!stepper->python )
126
+ {
127
+ stepper->python = SUNStepperFunctionTable_Alloc ();
128
+ }
129
+ auto fntable = static_cast <SUNStepperFunctionTable*>(stepper->python );
130
+ fntable->reset = nb::cast (fn);
131
+ if (fn)
132
+ {
133
+ return SUNStepper_SetResetFn (stepper, sunstepper_reset_wrapper);
134
+ }
135
+ else { return SUNStepper_SetResetFn (stepper, nullptr ); }
136
+ },
137
+ nb::arg (" stepper" ), nb::arg (" fn" ).none ());
138
+
139
+ m.def (
140
+ " SUNStepper_SetResetCheckpointIndexFn" ,
141
+ [](SUNStepper stepper,
142
+ std::function<std::remove_pointer_t <SUNStepperResetCheckpointIndexFn>> fn)
143
+ -> SUNErrCode
144
+ {
145
+ if (!stepper->python )
146
+ {
147
+ stepper->python = SUNStepperFunctionTable_Alloc ();
148
+ }
149
+ auto fntable = static_cast <SUNStepperFunctionTable*>(stepper->python );
150
+ fntable->reset_ckpt_idx = nb::cast (fn);
151
+ if (fn)
152
+ {
153
+ return SUNStepper_SetResetCheckpointIndexFn (stepper,
154
+ sunstepper_reset_ckpt_idx_wrapper);
155
+ }
156
+ else { return SUNStepper_SetResetCheckpointIndexFn (stepper, nullptr ); }
157
+ },
158
+ nb::arg (" stepper" ), nb::arg (" fn" ).none ());
159
+
160
+ m.def (
161
+ " SUNStepper_SetStopTimeFn" ,
162
+ [](SUNStepper stepper,
163
+ std::function<std::remove_pointer_t <SUNStepperSetStopTimeFn>> fn) -> SUNErrCode
164
+ {
165
+ if (!stepper->python )
166
+ {
167
+ stepper->python = SUNStepperFunctionTable_Alloc ();
168
+ }
169
+ auto fntable = static_cast <SUNStepperFunctionTable*>(stepper->python );
170
+ fntable->set_stop_time = nb::cast (fn);
171
+ if (fn)
172
+ {
173
+ return SUNStepper_SetStopTimeFn (stepper,
174
+ sunstepper_set_stop_time_wrapper);
175
+ }
176
+ else { return SUNStepper_SetStopTimeFn (stepper, nullptr ); }
177
+ },
178
+ nb::arg (" stepper" ), nb::arg (" fn" ).none ());
179
+
180
+ m.def (
181
+ " SUNStepper_SetStepDirectionFn" ,
182
+ [](SUNStepper stepper,
183
+ std::function<std::remove_pointer_t <SUNStepperSetStepDirectionFn>> fn) -> SUNErrCode
184
+ {
185
+ if (!stepper->python )
186
+ {
187
+ stepper->python = SUNStepperFunctionTable_Alloc ();
188
+ }
189
+ auto fntable = static_cast <SUNStepperFunctionTable*>(stepper->python );
190
+ fntable->set_step_direction = nb::cast (fn);
191
+ if (fn)
192
+ {
193
+ return SUNStepper_SetStepDirectionFn (stepper,
194
+ sunstepper_set_step_direction_wrapper);
195
+ }
196
+ else { return SUNStepper_SetStepDirectionFn (stepper, nullptr ); }
197
+ },
198
+ nb::arg (" stepper" ), nb::arg (" fn" ).none ());
199
+
200
+ using SUNStepperSetForcingStdFn =
201
+ SUNErrCode (SUNStepper stepper, sunrealtype tshift, sunrealtype tscale,
202
+ std::vector<N_Vector> forcing, int nforcing);
203
+ m.def (
204
+ " SUNStepper_SetForcingFn" ,
205
+ [](SUNStepper stepper, std::function<SUNStepperSetForcingStdFn> fn) -> SUNErrCode
206
+ {
207
+ if (!stepper->python )
208
+ {
209
+ stepper->python = SUNStepperFunctionTable_Alloc ();
210
+ }
211
+ auto fntable = static_cast <SUNStepperFunctionTable*>(stepper->python );
212
+ fntable->set_forcing = nb::cast (fn);
213
+ if (fn)
214
+ {
215
+ return SUNStepper_SetForcingFn (stepper, sunstepper_set_forcing_wrapper);
216
+ }
217
+ else { return SUNStepper_SetForcingFn (stepper, nullptr ); }
218
+ },
219
+ nb::arg (" stepper" ), nb::arg (" fn" ).none ());
220
+
221
+ m.def (
222
+ " SUNStepper_SetGetNumStepsFn" ,
223
+ [](SUNStepper stepper,
224
+ std::function<std::remove_pointer_t <SUNStepperGetNumStepsFn>> fn) -> SUNErrCode
225
+ {
226
+ if (!stepper->python )
227
+ {
228
+ stepper->python = SUNStepperFunctionTable_Alloc ();
229
+ }
230
+ auto fntable = static_cast <SUNStepperFunctionTable*>(stepper->python );
231
+ fntable->get_num_steps = nb::cast (fn);
232
+ if (fn)
233
+ {
234
+ return SUNStepper_SetGetNumStepsFn (stepper,
235
+ sunstepper_get_num_steps_wrapper);
236
+ }
237
+ else { return SUNStepper_SetGetNumStepsFn (stepper, nullptr ); }
238
+ },
239
+ nb::arg (" stepper" ), nb::arg (" fn" ).none ());
240
+
241
+ m.def (
242
+ " SUNStepper_SetDestroyFn" ,
243
+ [](SUNStepper stepper,
244
+ std::function<std::remove_pointer_t <SUNStepperDestroyFn>> fn) -> SUNErrCode
245
+ {
246
+ if (!stepper->python )
247
+ {
248
+ stepper->python = SUNStepperFunctionTable_Alloc ();
249
+ }
250
+ auto fntable = static_cast <SUNStepperFunctionTable*>(stepper->python );
251
+ fntable->destroy = nb::cast (fn);
252
+ if (fn)
253
+ {
254
+ return SUNStepper_SetDestroyFn (stepper, sunstepper_destroy_wrapper);
255
+ }
256
+ else { return SUNStepper_SetDestroyFn (stepper, nullptr ); }
257
+ },
258
+ nb::arg (" stepper" ), nb::arg (" fn" ).none ());
41
259
}
0 commit comments