Skip to content

Commit e63f747

Browse files
authored
Merge pull request #2 from arvigj/main
python binding fixes
2 parents 68cb9a7 + aa0c9df commit e63f747

File tree

5 files changed

+178
-57
lines changed

5 files changed

+178
-57
lines changed

src/differentiable/adjoint.cpp

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,11 @@ void define_adjoint(py::module_ &m)
3030
[](State &state) {
3131
Eigen::VectorXd term;
3232
if (state.problem->is_time_dependent())
33-
AdjointTools::dJ_material_transient_adjoint_term(state, state.get_adjoint_mat(1), state.get_adjoint_mat(0), term);
33+
AdjointTools::dJ_material_transient_adjoint_term(
34+
state, state.get_adjoint_mat(1), state.get_adjoint_mat(0), term);
3435
else
35-
AdjointTools::dJ_material_static_adjoint_term(state, state.diff_cached.u(0), state.get_adjoint_mat(0), term);
36+
AdjointTools::dJ_material_static_adjoint_term(
37+
state, state.diff_cached.u(0), state.get_adjoint_mat(0), term);
3638

3739
return utils::unflatten(term, state.bases.size());
3840
},
@@ -43,7 +45,8 @@ void define_adjoint(py::module_ &m)
4345
[](State &state) {
4446
Eigen::VectorXd term;
4547
if (state.problem->is_time_dependent())
46-
AdjointTools::dJ_friction_transient_adjoint_term(state, state.get_adjoint_mat(1), state.get_adjoint_mat(0), term);
48+
AdjointTools::dJ_friction_transient_adjoint_term(
49+
state, state.get_adjoint_mat(1), state.get_adjoint_mat(0), term);
4750
else
4851
log_and_throw_adjoint_error(
4952
"Friction coefficient derivative is only supported for transient problems!");
@@ -88,7 +91,7 @@ void define_adjoint(py::module_ &m)
8891
vec += term.segment(state.ndof() + g.index * dim, dim);
8992
}
9093
}
91-
94+
9295
return map;
9396
},
9497
py::arg("solver"));
@@ -129,8 +132,24 @@ void define_adjoint(py::module_ &m)
129132
vec += term.segment(g.index * dim, dim);
130133
}
131134
}
132-
135+
133136
return map;
134137
},
135138
py::arg("solver"));
139+
140+
m.def(
141+
"dirichlet_derivative",
142+
[](State &state) {
143+
const int dim = state.mesh->dimension();
144+
145+
Eigen::VectorXd term;
146+
if (state.problem->is_time_dependent())
147+
log_and_throw_adjoint_error(
148+
"Dirichlet derivative is only supported for static problems!");
149+
150+
AdjointTools::dJ_dirichlet_static_adjoint_term(
151+
state, state.get_adjoint_mat(0), term);
152+
return utils::unflatten(term, state.mesh->dimension());
153+
},
154+
py::arg("solver"));
136155
}

src/differentiable/objective.cpp

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,26 +15,38 @@ using namespace polyfem::solver;
1515

1616
void define_objective(py::module_ &m)
1717
{
18-
py::class_<AdjointForm, std::shared_ptr<AdjointForm>>(m, "Objective")
19-
.def("name", &AdjointForm::name)
18+
py::class_<AdjointForm, std::shared_ptr<AdjointForm>>(m, "Objective")
19+
.def("name", &AdjointForm::name)
2020

21-
.def("value", &AdjointForm::value, py::arg("x"))
21+
.def("value", &AdjointForm::value, py::arg("x"))
2222

23-
.def("solution_changed", &AdjointForm::solution_changed, py::arg("x"))
23+
.def("solution_changed", &AdjointForm::solution_changed, py::arg("x"))
2424

25-
.def("derivative", [](AdjointForm &obj, State &solver, const Eigen::VectorXd &x, const std::string &wrt) -> Eigen::VectorXd {
25+
.def("is_step_collision_free", &AdjointForm::is_step_collision_free,
26+
py::arg("x0"), py::arg("x1"))
27+
28+
.def("max_step_size", &AdjointForm::max_step_size, py::arg("x0"),
29+
py::arg("x1"))
30+
31+
.def(
32+
"derivative",
33+
[](AdjointForm &obj, State &solver, const Eigen::VectorXd &x,
34+
const std::string &wrt) -> Eigen::VectorXd {
2635
if (wrt == "solution")
27-
return obj.compute_adjoint_rhs(x, solver);
36+
return obj.compute_adjoint_rhs(x, solver);
2837
else if (wrt == obj.get_variable_to_simulations()[0]->name())
2938
{
30-
Eigen::VectorXd grad;
31-
obj.compute_partial_gradient(x, grad);
32-
return grad;
39+
Eigen::VectorXd grad;
40+
obj.compute_partial_gradient(x, grad);
41+
return grad;
3342
}
3443
else
35-
throw std::runtime_error("Input type does not match objective derivative type!");
36-
}, py::arg("solver"), py::arg("x"), py::arg("wrt"));
44+
throw std::runtime_error(
45+
"Input type does not match objective derivative type!");
46+
},
47+
py::arg("solver"), py::arg("x"), py::arg("wrt"));
3748

38-
m.def("create_objective", &AdjointOptUtils::create_simple_form,
39-
py::arg("obj_type"), py::arg("param_type"), py::arg("solver"), py::arg("parameters"));
49+
m.def("create_objective", &AdjointOptUtils::create_simple_form,
50+
py::arg("obj_type"), py::arg("param_type"), py::arg("solver"),
51+
py::arg("parameters"));
4052
}

src/differentiable/utils.cpp

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include <polyfem/mesh/SlimSmooth.hpp>
2+
#include <polyfem/solver/AdjointTools.hpp>
23
#include <polyfem/utils/MatrixUtils.hpp>
34
#include <polyfem/State.hpp>
45
#include "binding.hpp"
@@ -8,18 +9,27 @@
89
namespace py = pybind11;
910
using namespace polyfem;
1011
using namespace polyfem::mesh;
12+
using namespace polyfem::solver;
1113

1214
void define_opt_utils(py::module_ &m)
1315
{
1416
m.def(
15-
"apply_slim",
16-
[](const Eigen::MatrixXd &V, const Eigen::MatrixXi &F,
17-
const Eigen::MatrixXd &Vnew) {
18-
Eigen::MatrixXd Vsmooth;
19-
bool succeed = apply_slim(V, F, Vnew, Vsmooth, 1000);
20-
if (!succeed)
21-
throw std::runtime_error("SLIM failed to converge!");
22-
return Vsmooth;
23-
},
24-
py::arg("Vold"), py::arg("faces"), py::arg("Vnew"));
17+
"apply_slim",
18+
[](const Eigen::MatrixXd &V, const Eigen::MatrixXi &F,
19+
const Eigen::MatrixXd &Vnew) {
20+
Eigen::MatrixXd Vsmooth;
21+
bool succeed = apply_slim(V, F, Vnew, Vsmooth, 1000);
22+
if (!succeed)
23+
throw std::runtime_error("SLIM failed to converge!");
24+
return Vsmooth;
25+
},
26+
py::arg("Vold"), py::arg("faces"), py::arg("Vnew"))
27+
28+
.def("map_primitive_to_node_order",
29+
&AdjointTools::map_primitive_to_node_order, py::arg("state"),
30+
py::arg("primitives"))
31+
32+
.def("map_node_to_primitive_order",
33+
&AdjointTools::map_node_to_primitive_order, py::arg("state"),
34+
py::arg("nodes"));
2535
}

src/mesh/mesh.cpp

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <polyfem/mesh/mesh2D/CMesh2D.hpp>
44
#include "binding.hpp"
55
#include <pybind11/eigen.h>
6+
#include <pybind11/stl.h>
67

78
namespace py = pybind11;
89
using namespace polyfem;
@@ -50,31 +51,37 @@ void define_mesh(py::module_ &m)
5051
.def("get_boundary_id", &Mesh::get_boundary_id,
5152
"Get boundary ID of one boundary primitive", py::arg("primitive"))
5253

54+
.def("get_body_ids", &Mesh::get_body_ids, "Get body IDs")
55+
5356
// .def(
5457
// "set_boundary_side_set_from_bary",
5558
// [](Mesh &mesh,
56-
// const std::function<int(const RowVectorNd &)> &boundary_marker) {
59+
// const std::function<int(const RowVectorNd &)> &boundary_marker)
60+
// {
5761
// mesh.compute_boundary_ids(boundary_marker);
5862
// },
59-
// "Sets the side set for the boundary conditions, the functions takes the barycenter of the boundary (edge or face)",
63+
// "Sets the side set for the boundary conditions, the functions takes
64+
// the barycenter of the boundary (edge or face)",
6065
// py::arg("boundary_marker"))
6166
// .def(
6267
// "set_boundary_side_set_from_bary_and_boundary",
6368
// [](Mesh &mesh, const std::function<int(const RowVectorNd &, bool)>
6469
// &boundary_marker) {
6570
// mesh.compute_boundary_ids(boundary_marker);
6671
// },
67-
// "Sets the side set for the boundary conditions, the functions takes the barycenter of the boundary (edge or face) and a flag that says if the element is boundary",
68-
// py::arg("boundary_marker"))
72+
// "Sets the side set for the boundary conditions, the functions takes
73+
// the barycenter of the boundary (edge or face) and a flag that says
74+
// if the element is boundary", py::arg("boundary_marker"))
6975
// .def(
7076
// "set_boundary_side_set_from_v_ids",
7177
// [](Mesh &mesh,
7278
// const std::function<int(const std::vector<int> &, bool)>
7379
// &boundary_marker) {
7480
// mesh.compute_boundary_ids(boundary_marker);
7581
// },
76-
// "Sets the side set for the boundary conditions, the functions takes the sorted list of vertex id and a flag that says if the element is boundary",
77-
// py::arg("boundary_marker"))
82+
// "Sets the side set for the boundary conditions, the functions takes
83+
// the sorted list of vertex id and a flag that says if the element is
84+
// boundary", py::arg("boundary_marker"))
7885

7986
.def("set_body_ids", &Mesh::set_body_ids, "Set body IDs with an array",
8087
py::arg("ids"))

0 commit comments

Comments
 (0)