|
453 | 453 | ],
|
454 | 454 | "source": [
|
455 | 455 | "from sbi.utils.vector_field_utils import VectorFieldNet\n",
|
| 456 | + "\n", |
456 | 457 | "?VectorFieldNet.forward"
|
457 | 458 | ]
|
458 | 459 | },
|
|
480 | 481 | "\n",
|
481 | 482 | " def forward(self, theta, x, t):\n",
|
482 | 483 | " # Requires 3 arguments for input (i.e. theta), condition (i.e. x) and \"time\" (a scalar)\n",
|
483 |
| - " \n", |
| 484 | + "\n", |
484 | 485 | " # Whatever weird things you want to do, you can do here.\n",
|
485 | 486 | " h1 = self.in_layer1(x[...,None,:]).mean(-1)\n",
|
486 | 487 | " h2 = self.in_layer2(theta[...,None,:]).mean(-1)\n",
|
487 | 488 | " t = self.time_layer(t[...,None])\n",
|
488 | 489 | " h = torch.relu(h1 + h2 + t)\n",
|
489 | 490 | " out = self.out_layer(h)\n",
|
490 |
| - " \n", |
| 491 | + "\n", |
491 | 492 | " # Output dimension must exaclty match \"theta.shape\"\n",
|
492 | 493 | " return out\n",
|
493 |
| - " \n", |
494 |
| - "net_base = CustomNetExample() \n", |
| 494 | + "\n", |
| 495 | + "net_base = CustomNetExample()\n", |
495 | 496 | "# This will still add some of `sbi` standard \"preconditioning\" for\n",
|
496 | 497 | "# diffusion/fmpe as well as automatic z-scoring transforms.\n",
|
497 | 498 | "# If you do want this you can also simple write your own builder.\n",
|
|
0 commit comments