Commit 94b57c5

Browse files
committed
fix memory issue, support multiple states
1 parent d317a58 commit 94b57c5

2 files changed (+47, -23 lines)


src/state/state.cpp

Lines changed: 2 additions & 2 deletions

@@ -370,8 +370,8 @@ void define_solver(py::module_ &m)
     py::arg("log_level"))

   .def(
-    "mesh", [](State &s) -> mesh::Mesh * { return s.mesh.get(); },
-    "Get mesh in simulator")
+    "mesh", [](State &s) -> mesh::Mesh& { return *s.mesh.get(); },
+    "Get mesh in simulator", py::return_value_policy::reference)

   .def(
     "load_mesh_from_settings",

test/test_differentiable.ipynb

Lines changed: 45 additions & 21 deletions
@@ -20,26 +20,36 @@
  "metadata": {},
  "outputs": [],
  "source": [
+  "\n",
   "# Differentiable simulator that computes shape derivatives\n",
   "class Simulate(torch.autograd.Function):\n",
   "\n",
   "    @staticmethod\n",
-  "    def forward(ctx, solver, vertices):\n",
-  "        solver.mesh().set_vertices(vertices)\n",
-  "        solver.set_cache_level(pf.CacheLevel.Derivatives) # enable backward derivatives\n",
-  "        solver.solve()\n",
-  "        cache = solver.get_solution_cache()\n",
-  "        solutions = torch.zeros((solver.ndof(), cache.size()))\n",
-  "        for t in range(cache.size()):\n",
-  "            solutions[:, t] = torch.tensor(cache.solution(t))\n",
+  "    def forward(ctx, solvers, vertices):\n",
+  "        solutions = []\n",
+  "        for solver in solvers:\n",
+  "            solver.mesh().set_vertices(vertices)\n",
+  "            solver.set_cache_level(pf.CacheLevel.Derivatives) # enable backward derivatives\n",
+  "            solver.solve()\n",
+  "            cache = solver.get_solution_cache()\n",
+  "            sol = torch.zeros((solver.ndof(), cache.size()))\n",
+  "            for t in range(cache.size()):\n",
+  "                sol[:, t] = torch.tensor(cache.solution(t))\n",
+  "            solutions.append(sol)\n",
   "        \n",
-  "        ctx.solver = solver\n",
-  "        return solutions\n",
+  "        ctx.save_for_backward(vertices)\n",
+  "        ctx.solvers = solvers\n",
+  "        return tuple(solutions)\n",
   "\n",
   "    @staticmethod\n",
-  "    def backward(ctx, grad_output):\n",
-  "        ctx.solver.solve_adjoint(grad_output)\n",
-  "        return None, torch.tensor(pf.shape_derivative(ctx.solver))"
+  "    @torch.autograd.function.once_differentiable\n",
+  "    def backward(ctx, *grad_output):\n",
+  "        vertices, = ctx.saved_tensors\n",
+  "        grad = torch.zeros_like(vertices)\n",
+  "        for i, solver in enumerate(ctx.solvers):\n",
+  "            solver.solve_adjoint(grad_output[i])\n",
+  "            grad += torch.tensor(pf.shape_derivative(solver))\n",
+  "        return None, grad\n"
  ]
 },
 {
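
The rewritten Simulate class follows PyTorch's multi-output custom-Function pattern: forward returns a tuple with one solution tensor per solver, so backward receives one upstream gradient per output via *grad_output, accumulates a single gradient for vertices, and returns None for the non-tensor solvers argument. A self-contained toy sketch of that pattern, with a scalar computation standing in for the simulation (ToyMultiSolve and scales are illustrative names, not part of this commit):

import torch

class ToyMultiSolve(torch.autograd.Function):

    @staticmethod
    def forward(ctx, scales, x):
        ctx.save_for_backward(x)
        ctx.scales = scales
        # one output per "solver", analogous to tuple(solutions) above
        return tuple(s * x**2 for s in scales)

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, *grad_output):
        x, = ctx.saved_tensors
        grad = torch.zeros_like(x)
        for s, g in zip(ctx.scales, grad_output):
            grad += g * 2 * s * x  # adjoint of each output w.r.t. x
        return None, grad          # None matches the non-tensor scales argument

x = torch.randn(5, dtype=torch.double, requires_grad=True)
y1, y2 = ToyMultiSolve.apply([2.0, 3.0], x)
(y1.sum() * y2.sum()).backward()
print(x.grad)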
@@ -48,20 +58,33 @@
  "metadata": {},
  "outputs": [],
  "source": [
+  "\n",
   "root = \"../data/differentiable/input\"\n",
   "with open(root + \"/initial-contact.json\", \"r\") as f:\n",
   "    config = json.load(f)\n",
   "\n",
   "config[\"root_path\"] = root + \"/initial-contact.json\"\n",
   "\n",
-  "solver = pf.Solver()\n",
-  "solver.set_settings(json.dumps(config), False)\n",
-  "solver.set_log_level(2)\n",
-  "solver.load_mesh_from_settings()\n",
+  "# Simulation 1\n",
+  "\n",
+  "solver1 = pf.Solver()\n",
+  "solver1.set_settings(json.dumps(config), False)\n",
+  "solver1.set_log_level(2)\n",
+  "solver1.load_mesh_from_settings()\n",
+  "# solver1.solve()\n",
   "\n",
-  "mesh = solver.mesh()\n",
+  "mesh = solver1.mesh()\n",
   "v = mesh.vertices()\n",
-  "vertices = torch.tensor(solver.mesh().vertices(), requires_grad=True)"
+  "vertices = torch.tensor(solver1.mesh().vertices(), requires_grad=True)\n",
+  "\n",
+  "# Simulation 2\n",
+  "\n",
+  "config[\"initial_conditions\"][\"velocity\"][0][\"value\"] = [3, 0]\n",
+  "solver2 = pf.Solver()\n",
+  "solver2.set_settings(json.dumps(config), False)\n",
+  "solver2.set_log_level(2)\n",
+  "solver2.load_mesh_from_settings()\n",
+  "# solver2.solve()"
  ]
 },
 {
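
Both solvers are built from the same JSON, with only the initial velocity changed for the second one, so the two simulations share one set of design vertices but start from different states. A hypothetical generalization to more than two states, using only the calls shown above (the velocity values are illustrative, not from this commit):

initial_velocities = [[0, 0], [3, 0]]  # illustrative values
solvers = []
for vel in initial_velocities:
    config["initial_conditions"]["velocity"][0]["value"] = vel
    s = pf.Solver()
    s.set_settings(json.dumps(config), False)
    s.set_log_level(2)
    s.load_mesh_from_settings()
    solvers.append(s)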
@@ -70,11 +93,12 @@
  "metadata": {},
  "outputs": [],
  "source": [
+  "\n",
   "# Verify gradient\n",
   "\n",
   "def loss(vertices):\n",
-  "    solutions = Simulate.apply(solver, vertices)\n",
-  "    return torch.linalg.norm(solutions[:, -1])\n",
+  "    solutions1, solutions2 = Simulate.apply([solver1, solver2], vertices)\n",
+  "    return torch.linalg.norm(solutions1[:, -1]) * torch.linalg.norm(solutions2[:, -1])\n",
   "\n",
   "torch.set_printoptions(12)\n",
   "\n",
