|
20 | 20 | "metadata": {},
|
21 | 21 | "outputs": [],
|
22 | 22 | "source": [
|
| 23 | + "\n", |
23 | 24 | "# Differentiable simulator that computes shape derivatives\n",
|
24 | 25 | "class Simulate(torch.autograd.Function):\n",
|
25 | 26 | "\n",
|
26 | 27 | " @staticmethod\n",
|
27 |
| - " def forward(ctx, solver, vertices):\n", |
28 |
| - " solver.mesh().set_vertices(vertices)\n", |
29 |
| - " solver.set_cache_level(pf.CacheLevel.Derivatives) # enable backward derivatives\n", |
30 |
| - " solver.solve()\n", |
31 |
| - " cache = solver.get_solution_cache()\n", |
32 |
| - " solutions = torch.zeros((solver.ndof(), cache.size()))\n", |
33 |
| - " for t in range(cache.size()):\n", |
34 |
| - " solutions[:, t] = torch.tensor(cache.solution(t))\n", |
| 28 | + " def forward(ctx, solvers, vertices):\n", |
| 29 | + " solutions = []\n", |
| 30 | + " for solver in solvers:\n", |
| 31 | + " solver.mesh().set_vertices(vertices)\n", |
| 32 | + " solver.set_cache_level(pf.CacheLevel.Derivatives) # enable backward derivatives\n", |
| 33 | + " solver.solve()\n", |
| 34 | + " cache = solver.get_solution_cache()\n", |
| 35 | + " sol = torch.zeros((solver.ndof(), cache.size()))\n", |
| 36 | + " for t in range(cache.size()):\n", |
| 37 | + " sol[:, t] = torch.tensor(cache.solution(t))\n", |
| 38 | + " solutions.append(sol)\n", |
35 | 39 | " \n",
|
36 |
| - " ctx.solver = solver\n", |
37 |
| - " return solutions\n", |
| 40 | + " ctx.save_for_backward(vertices)\n", |
| 41 | + " ctx.solvers = solvers\n", |
| 42 | + " return tuple(solutions)\n", |
38 | 43 | "\n",
|
39 | 44 | " @staticmethod\n",
|
40 |
| - " def backward(ctx, grad_output):\n", |
41 |
| - " ctx.solver.solve_adjoint(grad_output)\n", |
42 |
| - " return None, torch.tensor(pf.shape_derivative(ctx.solver))" |
| 45 | + " @torch.autograd.function.once_differentiable\n", |
| 46 | + " def backward(ctx, *grad_output):\n", |
| 47 | + " vertices, = ctx.saved_tensors\n", |
| 48 | + " grad = torch.zeros_like(vertices)\n", |
| 49 | + " for i, solver in enumerate(ctx.solvers):\n", |
| 50 | + " solver.solve_adjoint(grad_output[i])\n", |
| 51 | + " grad += torch.tensor(pf.shape_derivative(solver))\n", |
| 52 | + " return None, grad\n" |
43 | 53 | ]
|
44 | 54 | },
|
45 | 55 | {
|
|
48 | 58 | "metadata": {},
|
49 | 59 | "outputs": [],
|
50 | 60 | "source": [
|
| 61 | + "\n", |
51 | 62 | "root = \"../data/differentiable/input\"\n",
|
52 | 63 | "with open(root + \"/initial-contact.json\", \"r\") as f:\n",
|
53 | 64 | " config = json.load(f)\n",
|
54 | 65 | "\n",
|
55 | 66 | "config[\"root_path\"] = root + \"/initial-contact.json\"\n",
|
56 | 67 | "\n",
|
57 |
| - "solver = pf.Solver()\n", |
58 |
| - "solver.set_settings(json.dumps(config), False)\n", |
59 |
| - "solver.set_log_level(2)\n", |
60 |
| - "solver.load_mesh_from_settings()\n", |
| 68 | + "# Simulation 1\n", |
| 69 | + "\n", |
| 70 | + "solver1 = pf.Solver()\n", |
| 71 | + "solver1.set_settings(json.dumps(config), False)\n", |
| 72 | + "solver1.set_log_level(2)\n", |
| 73 | + "solver1.load_mesh_from_settings()\n", |
| 74 | + "# solver1.solve()\n", |
61 | 75 | "\n",
|
62 |
| - "mesh = solver.mesh()\n", |
| 76 | + "mesh = solver1.mesh()\n", |
63 | 77 | "v = mesh.vertices()\n",
|
64 |
| - "vertices = torch.tensor(solver.mesh().vertices(), requires_grad=True)" |
| 78 | + "vertices = torch.tensor(solver1.mesh().vertices(), requires_grad=True)\n", |
| 79 | + "\n", |
| 80 | + "# Simulation 2\n", |
| 81 | + "\n", |
| 82 | + "config[\"initial_conditions\"][\"velocity\"][0][\"value\"] = [3, 0]\n", |
| 83 | + "solver2 = pf.Solver()\n", |
| 84 | + "solver2.set_settings(json.dumps(config), False)\n", |
| 85 | + "solver2.set_log_level(2)\n", |
| 86 | + "solver2.load_mesh_from_settings()\n", |
| 87 | + "# solver2.solve()" |
65 | 88 | ]
|
66 | 89 | },
|
67 | 90 | {
|
|
70 | 93 | "metadata": {},
|
71 | 94 | "outputs": [],
|
72 | 95 | "source": [
|
| 96 | + "\n", |
73 | 97 | "# Verify gradient\n",
|
74 | 98 | "\n",
|
75 | 99 | "def loss(vertices):\n",
|
76 |
| - " solutions = Simulate.apply(solver, vertices)\n", |
77 |
| - " return torch.linalg.norm(solutions[:, -1])\n", |
| 100 | + " solutions1, solutions2 = Simulate.apply([solver1, solver2], vertices)\n", |
| 101 | + " return torch.linalg.norm(solutions1[:, -1]) * torch.linalg.norm(solutions2[:, -1])\n", |
78 | 102 | "\n",
|
79 | 103 | "torch.set_printoptions(12)\n",
|
80 | 104 | "\n",
|
|
0 commit comments