Skip to content

Commit a406ccf

Browse files
committed
move mristepinnerstepper function table to special python member
1 parent 5158e37 commit a406ccf

File tree

4 files changed

+24
-36
lines changed

4 files changed

+24
-36
lines changed

bindings/pysundials/arkode/pysundials_arkode_mristep.cpp

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -67,15 +67,7 @@ void bind_arkode_mristep(nb::module_& m)
6767

6868
auto cb_fns = mristepinnerstepper_user_supplied_fn_table_alloc();
6969

70-
status = MRIStepInnerStepper_SetContent(stepper,
71-
static_cast<void*>(cb_fns));
72-
if (status != ARK_SUCCESS)
73-
{
74-
throw std::runtime_error(
75-
"Failed to set content in MRIStepInnerStepper");
76-
}
77-
78-
// TODO(CJB): need to set ownership of content so that MRIStepInnerStepper_Free frees the user-supplied function table
70+
stepper->python = static_cast<void*>(cb_fns);
7971

8072
return stepper;
8173
},
@@ -113,9 +105,6 @@ void bind_arkode_mristep(nb::module_& m)
113105
throw std::runtime_error("Failed to create ARKODE memory");
114106
}
115107

116-
void* content = nullptr;
117-
MRIStepInnerStepper_GetContent(stepper, &content);
118-
119108
// Create the user-supplied function table to store the Python user functions
120109
auto cb_fns = arkode_user_supplied_fn_table_alloc();
121110

bindings/pysundials/arkode/pysundials_arkode_usersupplied.hpp

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ inline int erkstep_f_wrapper(sunrealtype t, N_Vector y, N_Vector ydot,
346346
}
347347

348348
inline int erkstep_adjf_wrapper(sunrealtype t, N_Vector y, N_Vector sens,
349-
N_Vector sens_dot, void* user_data)
349+
N_Vector sens_dot, void* user_data)
350350
{
351351
return pysundials::user_supplied_fn_caller<
352352
std::remove_pointer_t<SUNAdjRhsFn>, arkode_user_supplied_fn_table,
@@ -479,10 +479,9 @@ inline int mristepinner_evolvefn_wrapper(MRIStepInnerStepper stepper,
479479
MRIStepInnerStepper_GetContent(stepper, &user_data);
480480

481481
return pysundials::user_supplied_fn_caller<
482-
std::remove_pointer_t<MRIStepInnerEvolveFn>,
483-
mristepinnerstepper_user_supplied_fn_table>(&mristepinnerstepper_user_supplied_fn_table::
484-
mristepinner_evolvefn,
485-
user_data, stepper, t0, tout, y);
482+
std::remove_pointer_t<MRIStepInnerEvolveFn>, mristepinnerstepper_user_supplied_fn_table,
483+
MRIStepInnerStepper>(&mristepinnerstepper_user_supplied_fn_table::mristepinner_evolvefn,
484+
stepper, t0, tout, y);
486485
}
487486

488487
inline int mristepinner_fullrhsfn_wrapper(MRIStepInnerStepper stepper,
@@ -493,11 +492,9 @@ inline int mristepinner_fullrhsfn_wrapper(MRIStepInnerStepper stepper,
493492
MRIStepInnerStepper_GetContent(stepper, &user_data);
494493

495494
return pysundials::user_supplied_fn_caller<
496-
std::remove_pointer_t<MRIStepInnerFullRhsFn>,
497-
mristepinnerstepper_user_supplied_fn_table>(&mristepinnerstepper_user_supplied_fn_table::
498-
mristepinner_fullrhsfn,
499-
user_data, stepper, t, y, f,
500-
mode);
495+
std::remove_pointer_t<MRIStepInnerFullRhsFn>, mristepinnerstepper_user_supplied_fn_table,
496+
MRIStepInnerStepper>(&mristepinnerstepper_user_supplied_fn_table::mristepinner_fullrhsfn,
497+
stepper, t, y, f, mode);
501498
}
502499

503500
inline int mristepinner_resetfn_wrapper(MRIStepInnerStepper stepper,
@@ -507,10 +504,9 @@ inline int mristepinner_resetfn_wrapper(MRIStepInnerStepper stepper,
507504
MRIStepInnerStepper_GetContent(stepper, &user_data);
508505

509506
return pysundials::user_supplied_fn_caller<
510-
std::remove_pointer_t<MRIStepInnerResetFn>,
511-
mristepinnerstepper_user_supplied_fn_table>(&mristepinnerstepper_user_supplied_fn_table::
512-
mristepinner_resetfn,
513-
user_data, stepper, tR, yR);
507+
std::remove_pointer_t<MRIStepInnerResetFn>, mristepinnerstepper_user_supplied_fn_table,
508+
MRIStepInnerStepper>(&mristepinnerstepper_user_supplied_fn_table::mristepinner_resetfn,
509+
stepper, tR, yR);
514510
}
515511

516512
inline int mristepinner_getaccumulatederrorfn_wrapper(MRIStepInnerStepper stepper,
@@ -521,9 +517,9 @@ inline int mristepinner_getaccumulatederrorfn_wrapper(MRIStepInnerStepper steppe
521517

522518
return pysundials::user_supplied_fn_caller<
523519
std::remove_pointer_t<MRIStepInnerGetAccumulatedError>,
524-
mristepinnerstepper_user_supplied_fn_table>(&mristepinnerstepper_user_supplied_fn_table::
525-
mristepinner_getaccumulatederrorfn,
526-
user_data, stepper, accum_error);
520+
mristepinnerstepper_user_supplied_fn_table,
521+
MRIStepInnerStepper>(&mristepinnerstepper_user_supplied_fn_table::mristepinner_getaccumulatederrorfn,
522+
stepper, accum_error);
527523
}
528524

529525
inline int mristepinner_resetaccumulatederrorfn_wrapper(MRIStepInnerStepper stepper)
@@ -533,9 +529,9 @@ inline int mristepinner_resetaccumulatederrorfn_wrapper(MRIStepInnerStepper step
533529

534530
return pysundials::user_supplied_fn_caller<
535531
std::remove_pointer_t<MRIStepInnerResetAccumulatedError>,
536-
mristepinnerstepper_user_supplied_fn_table>(&mristepinnerstepper_user_supplied_fn_table::
537-
mristepinner_resetaccumulatederrorfn,
538-
user_data, stepper);
532+
mristepinnerstepper_user_supplied_fn_table,
533+
MRIStepInnerStepper>(&mristepinnerstepper_user_supplied_fn_table::mristepinner_resetaccumulatederrorfn,
534+
stepper);
539535
}
540536

541537
inline int mristepinner_setrtolfn_wrapper(MRIStepInnerStepper stepper,
@@ -545,10 +541,9 @@ inline int mristepinner_setrtolfn_wrapper(MRIStepInnerStepper stepper,
545541
MRIStepInnerStepper_GetContent(stepper, &user_data);
546542

547543
return pysundials::user_supplied_fn_caller<
548-
std::remove_pointer_t<MRIStepInnerSetRTol>,
549-
mristepinnerstepper_user_supplied_fn_table>(&mristepinnerstepper_user_supplied_fn_table::
550-
mristepinner_setrtolfn,
551-
user_data, stepper, rtol);
544+
std::remove_pointer_t<MRIStepInnerSetRTol>, mristepinnerstepper_user_supplied_fn_table,
545+
MRIStepInnerStepper>(&mristepinnerstepper_user_supplied_fn_table::mristepinner_setrtolfn,
546+
stepper, rtol);
552547
}
553548

554549
#endif

src/arkode/arkode_mristep.c

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4324,6 +4324,9 @@ int MRIStepInnerStepper_Free(MRIStepInnerStepper* stepper)
43244324
/* free operations structure */
43254325
free((*stepper)->ops);
43264326

4327+
/* free python data */
4328+
free((*stepper)->python);
4329+
43274330
/* free inner stepper mem */
43284331
free(*stepper);
43294332
*stepper = NULL;

src/arkode/arkode_mristep_impl.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ struct _MRIStepInnerStepper
183183
{
184184
/* stepper specific content and operations */
185185
void* content;
186+
void* python;
186187
MRIStepInnerStepper_Ops ops;
187188

188189
/* stepper context */

0 commit comments

Comments
 (0)