Skip to content

Commit d317a58

Browse files
committed
shape derivative python interface
1 parent a5de325 commit d317a58

File tree

9 files changed

+331
-47
lines changed

9 files changed

+331
-47
lines changed

polyfempy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
# from .polyfempy import solve
77
from .polyfempy import CacheLevel
88
from .polyfempy import polyfem_command
9+
from .polyfempy import shape_derivative
910

1011
from .Settings import Settings
1112
from .Selection import Selection

src/binding.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,5 @@ PYBIND11_MODULE(polyfempy, m)
3030
define_nonlinear_problem(m);
3131

3232
define_differentiable_cache(m);
33+
define_adjoint(m);
3334
}

src/differentiable/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
set(SOURCES
22
diff_cache.cpp
3+
adjoint.cpp
34
)
45

56
source_group(TREE "${CMAKE_CURRENT_SOURCE_DIR}" PREFIX "Source Files" FILES ${SOURCES})

src/differentiable/adjoint.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#include <polyfem/solver/AdjointTools.hpp>
2+
#include <polyfem/utils/MatrixUtils.hpp>
3+
#include <polyfem/State.hpp>
4+
#include "binding.hpp"
5+
#include <pybind11/eigen.h>
6+
7+
namespace py = pybind11;
8+
using namespace polyfem;
9+
using namespace polyfem::solver;
10+
11+
void define_adjoint(py::module_ &m)
12+
{
13+
m.def("shape_derivative", [](State &state) {
14+
Eigen::VectorXd term;
15+
if (state.problem->is_time_dependent())
16+
AdjointTools::dJ_shape_transient_adjoint_term(
17+
state, state.get_adjoint_mat(1), state.get_adjoint_mat(0), term);
18+
else
19+
AdjointTools::dJ_shape_static_adjoint_term(
20+
state, state.diff_cached.u(0), state.get_adjoint_mat(0), term);
21+
return utils::unflatten(term, state.mesh->dimension());
22+
}, py::arg("solver"));
23+
}

src/differentiable/binding.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@
55
namespace py = pybind11;
66

77
void define_differentiable_cache(py::module_ &m);
8+
void define_adjoint(py::module_ &m);

src/mesh/mesh.cpp

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,25 @@ void define_mesh(py::module_ &m)
4646
"Set boundary IDs with an array", py::arg("ids"))
4747

4848
.def("set_body_ids", &Mesh::set_body_ids, "Set body IDs with an array",
49-
py::arg("ids"));
49+
py::arg("ids"))
50+
51+
.def("point", &Mesh::point, "Get vertex position",
52+
py::arg("vertex_id"))
53+
54+
.def("set_point", &Mesh::set_point, "Set vertex position",
55+
py::arg("vertex_id"), py::arg("position"))
56+
57+
.def("vertices", [](const Mesh &mesh) {
58+
Eigen::MatrixXd points(mesh.n_vertices(), mesh.dimension());
59+
for (int i = 0; i < mesh.n_vertices(); i++)
60+
points.row(i) = mesh.point(i);
61+
return points;
62+
}, "Get all vertex positions")
63+
64+
.def("set_vertices", [](Mesh &mesh, const Eigen::MatrixXd &points) {
65+
for (int i = 0; i < mesh.n_vertices(); i++)
66+
mesh.set_point(i, points.row(i));
67+
}, "Set all vertex positions");
5068

5169
py::class_<CMesh2D, Mesh>(m, "Mesh2D", "");
5270
py::class_<CMesh3D, Mesh>(m, "Mesh3D", "");

src/state/state.cpp

Lines changed: 40 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,22 @@
1-
#include <polyfem/assembler/AssemblerUtils.hpp>
2-
#include <polyfem/utils/Logger.hpp>
31
#include <polyfem/mesh/MeshUtils.hpp>
2+
#include <polyfem/assembler/AssemblerUtils.hpp>
43
#include <polyfem/assembler/GenericProblem.hpp>
5-
#include <polyfem/utils/StringUtils.hpp>
4+
#include <polyfem/io/Evaluator.hpp>
65
#include <polyfem/io/YamlToJson.hpp>
6+
#include <polyfem/utils/Logger.hpp>
7+
#include <polyfem/utils/StringUtils.hpp>
78
#include <polyfem/utils/JSONUtils.hpp>
8-
9-
#include <polyfem/State.hpp>
9+
#include <polyfem/utils/GeogramUtils.hpp>
1010
#include <polyfem/solver/NLProblem.hpp>
1111
#include <polyfem/time_integrator/ImplicitTimeIntegrator.hpp>
12-
#include <polyfem/io/Evaluator.hpp>
12+
#include <polyfem/State.hpp>
1313

1414
// #include "raster.hpp"
1515

16-
#include <geogram/basic/command_line.h>
17-
#include <geogram/basic/command_line_args.h>
18-
1916
#include <igl/boundary_facets.h>
2017
#include <igl/remove_unreferenced.h>
21-
#include <stdexcept>
2218

23-
#ifdef USE_TBB
24-
#include <tbb/task_scheduler_init.h>
25-
#include <thread>
26-
#endif
19+
#include <stdexcept>
2720

2821
#include <pybind11_json/pybind11_json.hpp>
2922

@@ -90,25 +83,7 @@ namespace
9083

9184
if (!initialized)
9285
{
93-
#ifndef WIN32
94-
setenv("GEO_NO_SIGNAL_HANDLER", "1", 1);
95-
#endif
96-
97-
GEO::initialize();
98-
99-
#ifdef USE_TBB
100-
const size_t MB = 1024 * 1024;
101-
const size_t stack_size = 64 * MB;
102-
unsigned int num_threads =
103-
std::max(1u, std::thread::hardware_concurrency());
104-
tbb::task_scheduler_init scheduler(num_threads, stack_size);
105-
#endif
106-
107-
// Import standard command line arguments, and custom ones
108-
GEO::CmdLine::import_arg_group("standard");
109-
GEO::CmdLine::import_arg_group("pre");
110-
GEO::CmdLine::import_arg_group("algo");
111-
86+
state.set_max_threads(1);
11287
state.init_logger("", spdlog::level::level_enum::info,
11388
spdlog::level::level_enum::debug, false);
11489

@@ -381,6 +356,8 @@ void define_solver(py::module_ &m)
381356
"load PDE and problem parameters from the settings", py::arg("json"),
382357
py::arg("strict_validation") = false)
383358

359+
.def("ndof", &State::ndof, "Dimension of the solution")
360+
384361
.def(
385362
"set_log_level",
386363
[](State &s, int log_level) {
@@ -401,7 +378,6 @@ void define_solver(py::module_ &m)
401378
[](State &s, const bool normalize_mesh, const double vismesh_rel_area,
402379
const int n_refs, const double boundary_id_threshold) {
403380
init_globals(s);
404-
// py::scoped_ostream_redirect output;
405381
s.args["geometry"][0]["advanced"]["normalize_mesh"] =
406382
normalize_mesh;
407383
s.args["geometry"][0]["surface_selection"] =
@@ -425,7 +401,7 @@ void define_solver(py::module_ &m)
425401
const double vismesh_rel_area, const int n_refs,
426402
const double boundary_id_threshold) {
427403
init_globals(s);
428-
// py::scoped_ostream_redirect output;
404+
s.args["geometry"] = R"([{ }])"_json;
429405
s.args["geometry"][0]["mesh"] = path;
430406
s.args["geometry"][0]["advanced"]["normalize_mesh"] =
431407
normalize_mesh;
@@ -449,7 +425,7 @@ void define_solver(py::module_ &m)
449425
const bool normalize_mesh, const double vismesh_rel_area,
450426
const int n_refs, const double boundary_id_threshold) {
451427
init_globals(s);
452-
// py::scoped_ostream_redirect output;
428+
s.args["geometry"] = R"([{ }])"_json;
453429
s.args["geometry"][0]["mesh"] = path;
454430
s.args["bc_tag"] = bc_tag;
455431
s.args["geometry"][0]["advanced"]["normalize_mesh"] =
@@ -471,28 +447,20 @@ void define_solver(py::module_ &m)
471447
.def(
472448
"set_mesh",
473449
[](State &s, const Eigen::MatrixXd &V, const Eigen::MatrixXi &F,
474-
const bool normalize_mesh, const double vismesh_rel_area,
475450
const int n_refs, const double boundary_id_threshold) {
476451
init_globals(s);
477-
// py::scoped_ostream_redirect output;
478-
479452
s.mesh = mesh::Mesh::create(V, F);
480-
481-
s.args["geometry"][0]["advanced"]["normalize_mesh"] =
482-
normalize_mesh;
453+
s.args["geometry"] = R"([{ }])"_json;
483454
s.args["geometry"][0]["n_refs"] = n_refs;
484455
s.args["geometry"][0]["surface_selection"] =
485456
R"({ "threshold": 0.0 })"_json;
486457
s.args["geometry"][0]["surface_selection"]["threshold"] =
487458
boundary_id_threshold;
488-
s.args["output"]["paraview"]["vismesh_rel_area"] = vismesh_rel_area;
489459

490460
s.load_mesh();
491461
},
492462
"Loads a mesh from vertices and connectivity", py::arg("vertices"),
493-
py::arg("connectivity"), py::arg("normalize_mesh") = bool(false),
494-
py::arg("vismesh_rel_area") = double(0.00001),
495-
py::arg("n_refs") = int(0),
463+
py::arg("connectivity"), py::arg("n_refs") = int(0),
496464
py::arg("boundary_id_threshold") = double(-1))
497465

498466
.def(
@@ -643,6 +611,32 @@ void define_solver(py::module_ &m)
643611
"step in time", py::arg("solution"), py::arg("t0"), py::arg("dt"),
644612
py::arg("t"))
645613

614+
.def(
615+
"solve_adjoint",
616+
[](State &s, const Eigen::MatrixXd &adjoint_rhs) {
617+
if (adjoint_rhs.cols() != s.diff_cached.size()
618+
|| adjoint_rhs.rows() != s.diff_cached.u(0).size())
619+
throw std::runtime_error("Invalid adjoint_rhs shape!");
620+
if (!s.problem->is_time_dependent()
621+
&& !s.lin_solver_cached) // nonlinear static solve only
622+
{
623+
Eigen::MatrixXd reduced;
624+
for (int i = 0; i < adjoint_rhs.cols(); i++)
625+
{
626+
Eigen::VectorXd reduced_vec =
627+
s.solve_data.nl_problem->full_to_reduced_grad(
628+
adjoint_rhs.col(i));
629+
if (i == 0)
630+
reduced.setZero(reduced_vec.rows(), adjoint_rhs.cols());
631+
reduced.col(i) = reduced_vec;
632+
}
633+
return s.solve_adjoint_cached(reduced);
634+
}
635+
else
636+
return s.solve_adjoint_cached(adjoint_rhs);
637+
},
638+
"Solve the adjoint equation given the gradient of objective wrt. PDE solution")
639+
646640
.def(
647641
"set_cache_level",
648642
[](State &s, solver::CacheLevel level) {

test/test_basic.ipynb

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"import polyfempy as pf\n",
10+
"import json\n",
11+
"import numpy as np"
12+
]
13+
},
14+
{
15+
"cell_type": "code",
16+
"execution_count": null,
17+
"metadata": {},
18+
"outputs": [],
19+
"source": [
20+
"root = \"../data/differentiable/input\"\n",
21+
"with open(root + \"/initial-contact.json\", \"r\") as f:\n",
22+
" config = json.load(f)\n",
23+
"\n",
24+
"config[\"root_path\"] = root + \"/initial-contact.json\"\n",
25+
"\n",
26+
"solver = pf.Solver()\n",
27+
"solver.set_settings(json.dumps(config), False)\n",
28+
"solver.set_log_level(2)\n",
29+
"solver.load_mesh_from_settings()"
30+
]
31+
},
32+
{
33+
"cell_type": "code",
34+
"execution_count": null,
35+
"metadata": {},
36+
"outputs": [],
37+
"source": [
38+
"mesh = solver.mesh()\n",
39+
"\n",
40+
"print(mesh.n_vertices())\n",
41+
"print(mesh.n_elements())\n",
42+
"print(mesh.n_cell_vertices(1))\n",
43+
"print(mesh.element_vertex(3, 0))\n",
44+
"print(mesh.boundary_element_vertex(3, 0))\n",
45+
"assert(mesh.is_boundary_vertex(1))\n",
46+
"\n",
47+
"min, max = mesh.bounding_box()"
48+
]
49+
},
50+
{
51+
"cell_type": "code",
52+
"execution_count": null,
53+
"metadata": {},
54+
"outputs": [],
55+
"source": [
56+
"config = solver.settings()\n",
57+
"t0 = config[\"time\"][\"t0\"]\n",
58+
"dt = config[\"time\"][\"dt\"]\n",
59+
"\n",
60+
"# inits stuff\n",
61+
"solver.build_basis()\n",
62+
"solver.assemble()\n",
63+
"sol = solver.init_timestepping(t0, dt)\n",
64+
"\n",
65+
"for i in range(1, 5):\n",
66+
" \n",
67+
" # substepping\n",
68+
" for t in range(1):\n",
69+
" sol = solver.step_in_time(sol, t0, dt, t+1)\n",
70+
"\n",
71+
" t0 += dt\n",
72+
" solver.export_vtu(sol, np.zeros((0, 0)), t0, dt, \"step_\" + str(i) + \".vtu\")\n"
73+
]
74+
},
75+
{
76+
"cell_type": "code",
77+
"execution_count": null,
78+
"metadata": {},
79+
"outputs": [],
80+
"source": [
81+
"prob = solver.nl_problem()\n",
82+
"\n",
83+
"h = prob.hessian(sol)\n",
84+
"reduced_sol = prob.full_to_reduced(sol)\n",
85+
"full_sol = prob.reduced_to_full(reduced_sol)\n",
86+
"\n",
87+
"assert(np.linalg.norm(full_sol - sol.flatten()) < 1e-12)"
88+
]
89+
},
90+
{
91+
"cell_type": "code",
92+
"execution_count": null,
93+
"metadata": {},
94+
"outputs": [],
95+
"source": [
96+
"cache = solver.get_solution_cache()\n",
97+
"\n",
98+
"print(cache.solution(1).shape)\n",
99+
"print(cache.velocity(2).shape)\n",
100+
"print(cache.acceleration(3).shape)\n",
101+
"print(cache.hessian(4).shape)"
102+
]
103+
}
104+
],
105+
"metadata": {
106+
"kernelspec": {
107+
"display_name": "base",
108+
"language": "python",
109+
"name": "python3"
110+
},
111+
"language_info": {
112+
"codemirror_mode": {
113+
"name": "ipython",
114+
"version": 3
115+
},
116+
"file_extension": ".py",
117+
"mimetype": "text/x-python",
118+
"name": "python",
119+
"nbconvert_exporter": "python",
120+
"pygments_lexer": "ipython3",
121+
"version": "3.11.8"
122+
}
123+
},
124+
"nbformat": 4,
125+
"nbformat_minor": 2
126+
}

0 commit comments

Comments
 (0)