Skip to content

Commit a8f7004

Browse files
committed
fix SetLinearSolver in arkode
1 parent dc2d3a4 commit a8f7004

File tree

3 files changed

+33
-4
lines changed

3 files changed

+33
-4
lines changed

bindings/pysundials/arkode/generate.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ modules:
2525
macro_define_include_by_name__regex:
2626
- "^SUN_"
2727
- "^ARK_"
28+
fn_params_optional_with_default_null:
29+
"SetLinearSolver":
30+
- "A"
2831
arkode:
2932
path: arkode/pysundials_arkode_generated.hpp
3033
headers:
@@ -45,6 +48,8 @@ modules:
4548
- "ARKodeSetWFtolerances"
4649
- "ARKodeSet.*Preconditioner"
4750
- "ARKodeSet.*Times"
51+
# generator cannot handle functions with optional (i.e. NULLable) parameters that is not followed by only optional parameters
52+
- "ARKodeSetMassLinearSolver"
4853
arkstep:
4954
path: arkode/pysundials_arkode_arkstep_generated.hpp
5055
headers:

bindings/pysundials/arkode/pysundials_arkode.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,11 @@
1212
* SUNDIALS Copyright End
1313
*----------------------------------------------------------------------------*/
1414

15+
#include <optional>
16+
1517
#include <nanobind/nanobind.h>
1618
#include <nanobind/stl/function.h>
19+
#include <nanobind/stl/optional.h>
1720

1821
#include <sundials/sundials_core.hpp>
1922

@@ -44,7 +47,9 @@ void bind_arkode(nb::module_& m)
4447
.def("get", nb::overload_cast<>(&ARKodeView::get, nb::const_),
4548
nb::rv_policy::reference);
4649

50+
//
4751
// ARKode function pointer setters
52+
//
4853
m.def("ARKodeSetPostprocessStepFn",
4954
[](void* ark_mem, std::function<std::remove_pointer_t<ARKPostProcessFn>> fn)
5055
{
@@ -240,6 +245,15 @@ void bind_arkode(nb::module_& m)
240245
return ARKodeSetLinSysFn(ark_mem, &arkode_lslinsysfn_wrapper);
241246
});
242247

248+
/////////////////////////////////////////////////////////////////////////////
249+
// Additional functions that litgen cannot generate
250+
/////////////////////////////////////////////////////////////////////////////
251+
252+
// These functions have optional arguments which litgen cannot deal with
253+
m.def("ARKodeSetMassLinearSolver",
254+
ARKodeSetMassLinearSolver, nb::arg("arkode_mem"), nb::arg("LS"), nb::arg("M").none(), nb::arg("time_dep"));
255+
256+
243257
bind_arkode_arkstep(m);
244258
bind_arkode_erkstep(m);
245259
bind_arkode_sprkstep(m);

bindings/pysundials/arkode/pysundials_arkode_generated.hpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -471,10 +471,20 @@ m.def("ARKodeGetNumRelaxSolveIters",
471471
//
472472

473473
m.def("ARKodeSetLinearSolver",
474-
ARKodeSetLinearSolver, nb::arg("arkode_mem"), nb::arg("LS"), nb::arg("A"));
475-
476-
m.def("ARKodeSetMassLinearSolver",
477-
ARKodeSetMassLinearSolver, nb::arg("arkode_mem"), nb::arg("LS"), nb::arg("M"), nb::arg("time_dep"));
474+
[](void * arkode_mem, SUNLinearSolver LS, std::optional<SUNMatrix> A = std::nullopt) -> int
475+
{
476+
auto ARKodeSetLinearSolver_adapt_optional_arg_with_default_null = [](void * arkode_mem, SUNLinearSolver LS, std::optional<SUNMatrix> A = std::nullopt) -> int
477+
{
478+
SUNMatrix A_adapt_default_null = nullptr;
479+
if (A.has_value())
480+
A_adapt_default_null = A.value();
481+
482+
auto lambda_result = ARKodeSetLinearSolver(arkode_mem, LS, A_adapt_default_null);
483+
return lambda_result;
484+
};
485+
486+
return ARKodeSetLinearSolver_adapt_optional_arg_with_default_null(arkode_mem, LS, A);
487+
}, nb::arg("arkode_mem"), nb::arg("LS"), nb::arg("A") = nb::none());
478488

479489
m.def("ARKodeSetJacEvalFrequency",
480490
ARKodeSetJacEvalFrequency, nb::arg("arkode_mem"), nb::arg("msbj"));

0 commit comments

Comments
 (0)