# Finite basis physics-informed neural networks (FBPINNs)

---

This repository allows you to solve forward and inverse problems related to partial differential equations (PDEs) using **finite basis physics-informed neural networks** (FBPINNs).

> 🔥 MAJOR UPDATE 🔥: we have rewritten the `fbpinns` library in [JAX](https://jax.readthedocs.io/en/latest/index.html): it now runs 10-1000X faster than the original PyTorch code (by parallelising subdomain computations using `jax.vmap`) and scales to 1000s+ subdomains. We have also added extra functionality: you can now solve inverse problems, add arbitrary types of boundary/data constraints, define irregular domain decompositions and custom subdomain networks, and the high-level interface is much more flexible and easier to use. See the [Release note]() for more info.

FBPINNs were first presented here: *[Finite Basis Physics-Informed Neural Networks (FBPINNs): a scalable domain decomposition approach for solving differential equations](https://arxiv.org/abs/2107.07871), B. Moseley, T. Nissen-Meyer and A. Markham, Jul 2021 ArXiv*.

---
<figure>
<center>
<figcaption><b>Fig 1: FBPINN solution of the (2+1)D wave equation with multiscale sources</b></figcaption>
</center>
</figure>

<figure>
<center>
<img src="images/FBPINN.gif" alt="FBPINN solving the high-frequency 1D harmonic oscillator" style="width:45%"> <img src="images/PINN.gif" alt="PINN solving the high-frequency 1D harmonic oscillator" style="width:45%">
<img src="images/test-loss.png" alt="Test loss comparison" style="width:30%">
<figcaption><b>Fig 2: FBPINN vs PINN solving the high-frequency 1D harmonic oscillator</b></figcaption>
</center>
</figure>
## Why FBPINNs?

- [Physics-informed neural networks](https://benmoseley.blog/my-research/so-what-is-a-physics-informed-neural-network/) (PINNs) are a popular approach for solving **forward and inverse problems** related to PDEs
- However, PINNs often struggle to solve problems with **high frequencies** and/or **multi-scale solutions**
- This is due to the **spectral bias** of neural networks and the **rapidly increasing complexity** of the PINN optimisation problem as the solution frequency grows
- FBPINNs improve the performance of PINNs in this regime by combining them with **domain decomposition**, **individual subdomain normalisation** and **flexible subdomain training schedules**
- Empirically, FBPINNs **significantly outperform** PINNs (in terms of accuracy and computational efficiency) when solving problems with high frequencies and multi-scale solutions (Figs 1 and 2)
## How are FBPINNs different to PINNs?

<figure>
<center>
<img src="images/workflow.png" alt="FBPINN workflow overview" style="width:100%">
<figcaption><b>Fig 3: FBPINN workflow overview</b></figcaption>
</center>
</figure>

To improve the scalability of PINNs to high-frequency/multi-scale solutions:

- FBPINNs divide the problem domain into many small, **overlapping subdomains** (Fig 3).

- A neural network is placed within each subdomain, and the solution to the PDE is defined as the **summation over all subdomain networks**.

- Each subdomain network is **locally confined** to its subdomain by multiplying it by a smooth, differentiable window function.

- Finally, the inputs of each network are **individually normalised** over their subdomain.

The hypothesis is that this "divide and conquer" approach significantly reduces the complexity of the PINN optimisation problem. Furthermore, individual subdomain normalisation ensures that the "effective" frequency each subdomain network sees is low, reducing the effect of spectral bias.
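
To make this construction concrete, here is a minimal, illustrative sketch of a 1D FBPINN solution assembled from windowed, individually normalised subdomain networks. The window shape, normalisation range and stand-in "networks" are simplifying assumptions for illustration, not the exact `fbpinns` implementation:

```python
import numpy as np

def window(x, xmin, xmax):
    # smooth, differentiable window that is zero outside [xmin, xmax]
    # (a cosine-squared bump; the actual fbpinns window functions may differ)
    t = np.clip((x - xmin) / (xmax - xmin), 0.0, 1.0)
    return np.sin(np.pi * t) ** 2

def fbpinn_solution(x, subdomains, networks):
    # FBPINN ansatz: sum over all windowed subdomain networks, each
    # evaluated on inputs normalised over its own subdomain
    u = np.zeros_like(x)
    for (xmin, xmax), net in zip(subdomains, networks):
        x_norm = 2 * (x - xmin) / (xmax - xmin) - 1  # individual normalisation to [-1, 1]
        u += window(x, xmin, xmax) * net(x_norm)     # local confinement + summation
    return u

# usage: three overlapping subdomains covering [0, 1], with dummy "networks"
subdomains = [(0.0, 0.4), (0.3, 0.7), (0.6, 1.0)]
networks = [np.tanh, np.sin, np.cos]  # stand-ins for trained subdomain networks
x = np.linspace(0, 1, 100)
u = fbpinn_solution(x, subdomains, networks)
```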
## Subdomain scheduling

<figure>
<center>
<img src="images/scheduling.gif" alt="Solving the time-dependent Burgers' equation using a time-stepping subdomain scheduler" style="width:100%">
<figcaption><b>Fig 4: Solving the time-dependent Burgers' equation using a time-stepping subdomain scheduler</b></figcaption>
</center>
</figure>

Another advantage of using domain decomposition is that we can control which **parts** of the domain are solved at each training step.

This is useful if we want to control how boundary conditions are **communicated** across the domain.

For example, we can define a **time-stepping scheduler** to solve time-dependent PDEs, learning the solution forwards in time from a set of initial conditions (Fig 4).

This is done by specifying a **subdomain scheduler** (from `fbpinns.schedulers`), which defines which subdomains are actively training and which subdomains have fixed parameters at each training step.
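
To illustrate the idea, here is a toy sketch of what a time-stepping schedule does. This is illustrative logic only, not the `fbpinns.schedulers` API:

```python
import numpy as np

def time_stepping_schedule(step, n_steps, n_subdomains):
    # toy time-stepping schedule: sweep an "active" training window over the
    # subdomains (ordered along the time axis) as training progresses;
    # subdomains behind the window keep their parameters fixed
    frontier = min(int(n_subdomains * step / n_steps), n_subdomains - 1)
    active = np.zeros(n_subdomains, dtype=bool)
    fixed = np.zeros(n_subdomains, dtype=bool)
    active[frontier] = True   # train only the subdomain at the time frontier
    fixed[:frontier] = True   # freeze subdomains that have already been solved
    return active, fixed

# usage: halfway through training with 10 subdomains along the time axis,
# subdomain 5 is actively training and subdomains 0-4 are fixed
active, fixed = time_stepping_schedule(step=5000, n_steps=10000, n_subdomains=10)
```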
## Installation

`fbpinns` only requires Python libraries to run.

> [JAX](https://jax.readthedocs.io/en/latest/index.html) is used as the main computational engine for `fbpinns`.

To install `fbpinns`, we recommend setting up a new Python environment, for example:

```bash
conda create -n fbpinns python=3  # using conda
conda activate fbpinns
```

then cloning this repository:

```bash
git clone git@github.com:benmoseley/FBPINNs.git
```

and running this command in the base `FBPINNs/` directory (this will also install all of the dependencies):

```bash
pip install -e .
```

> Note this installs the `fbpinns` package in "editable mode" - you can make changes to the source code and they are immediately reflected in the installed package.
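
A quick way to check the install worked is to import the package (standard Python; nothing here is `fbpinns`-specific beyond the package name):

```python
import fbpinns  # should import without error after `pip install -e .`
print(fbpinns.__file__)  # should point back at your cloned FBPINNs/ source directory
```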
## Getting started

Forward and inverse PDE problems are defined and solved by carrying out the following steps:

1. Define the **problem domain**, by selecting or defining your own `fbpinns.domains.Domain` class
2. Define the **PDE** to solve, and any **problem constraints** (such as boundary conditions or data constraints), by selecting or defining your own `fbpinns.problems.Problem` class
3. Define the **domain decomposition** used by the FBPINN, by selecting or defining your own `fbpinns.decompositions.Decomposition` class
4. Define the **neural network** placed in each subdomain, by selecting or defining your own `fbpinns.networks.Network` class
5. Keep track of all the training hyperparameters by passing these classes and their initialisation values to a `fbpinns.constants.Constants` object
6. Start the FBPINN training by instantiating a `fbpinns.trainers.FBPINNTrainer` using the `Constants` object

For example, to solve the 1D harmonic oscillator problem shown above (Fig 2):
```python
import numpy as np

from fbpinns.domains import RectangularDomainND
from fbpinns.problems import HarmonicOscillator1D
from fbpinns.decompositions import RectangularDecompositionND
from fbpinns.networks import FCN
from fbpinns.constants import Constants
from fbpinns.trainers import FBPINNTrainer

c = Constants(
    domain=RectangularDomainND,  # use a 1D problem domain [0, 1]
    domain_init_kwargs=dict(
        xmin=np.array([0,]),
        xmax=np.array([1,]),
    ),
    problem=HarmonicOscillator1D,  # solve the 1D harmonic oscillator problem
    problem_init_kwargs=dict(
        d=2, w0=80,  # define the ODE parameters
    ),
    decomposition=RectangularDecompositionND,  # use a rectangular domain decomposition
    decomposition_init_kwargs=dict(
        subdomain_xs=[np.linspace(0,1,15)],  # use 15 equally spaced subdomains
        subdomain_ws=[0.15*np.ones((15,))],  # with widths of 0.15
        unnorm=(0.,1.),  # define unnormalisation of the subdomain networks
    ),
    network=FCN,  # place a fully-connected network in each subdomain
    network_init_kwargs=dict(
        layer_sizes=[1,32,1],  # with 1 hidden layer of 32 units
    ),
    ns=((200,),),  # use 200 collocation points for training
    n_test=(500,),  # use 500 points for testing
    n_steps=20000,  # number of training steps
    optimiser_kwargs=dict(learning_rate=1e-3),
    show_figures=True,  # display plots during training
)

run = FBPINNTrainer(c)
run.train()  # start training the FBPINN
```

The `FBPINNTrainer` will automatically start outputting training statistics, plots and tensorboard summaries. The tensorboard summaries can be viewed by installing [tensorboard](https://www.tensorflow.org/tensorboard) and then running `tensorboard --logdir results/summaries/`.

### Comparing to PINNs

You can easily train a PINN using the same hyperparameters as above, using:

```python
from fbpinns.trainers import PINNTrainer

c["network_init_kwargs"] = dict(layer_sizes=[1,64,64,1])  # use a larger neural network
run = PINNTrainer(c)
run.train()  # start training a PINN on the same problem
```
## Going further

See the [examples](https://github.com/benmoseley/FBPINNs/tree/main/examples) folder for more advanced examples covering:
- how to define your own `Problem` class
- how to use hard boundary constraints
- how to solve an inverse problem
- how to use subdomain scheduling
## FAQs

### Installation

I get the error: `RuntimeError: This version of jaxlib was built using AVX instructions, which your CPU and/or operating system do not support.` when using Apple GPUs.
- As of this commit, JAX only has experimental support for Apple GPUs. Either build JAX from [source](https://developer.apple.com/metal/jax/) or install a CPU-only version using conda: `pip uninstall jax jaxlib` and `conda install jax -c conda-forge`
### Using GPUs

How do I train FBPINNs using a GPU?
- Exactly the same code should run on a GPU automatically, without needing any modification. Make sure you have installed the GPU version of JAX, and that JAX can see your GPU devices (e.g. by checking `jax.devices()`)
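
For example, a quick check using standard JAX calls:

```python
import jax

print(jax.devices())          # lists available devices, e.g. a GPU if one is visible
print(jax.default_backend())  # "gpu" if JAX will run on the GPU by default
```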
### Understanding the repository

But I don't know JAX!?
- We highly recommend becoming familiar with JAX - it is a fantastic, general-purpose library for accelerated differentiable computing. But even if you don't want to learn JAX, that's ok - all of the front-end classes (`Domain`, `Problem`, `Decomposition`, and `Network`) can be defined with only a basic understanding of `jax.numpy` (which has essentially the [same](https://jax.readthedocs.io/en/latest/jax-101/01-jax-basics.html) API as `numpy` anyway).
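
For example, `jax.numpy` code reads almost identically to `numpy` code:

```python
import numpy as np
import jax.numpy as jnp

x = np.linspace(0, 1, 5)   # ordinary numpy array
y = jnp.linspace(0, 1, 5)  # same call in jax.numpy, returns a JAX array
print(np.sin(x))           # computed by numpy on the CPU
print(jnp.sin(y))          # computed by JAX, on CPU or GPU
```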

### Methodology

How are FBPINNs different to other PINN + domain decomposition methods?
- In contrast to other PINN + domain decomposition methods (such as [XPINNs](https://global-sci.org/intro/article_detail/cicp/18403.html)), FBPINNs by their mathematical construction **do not require additional interface terms** in their loss function, and their **solution is continuous** across subdomain interfaces. Essentially, FBPINNs can just be thought of as defining a custom neural network architecture for PINNs - everything else stays the same.
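
Schematically (the notation here is ours, not the paper's), this is because the FBPINN solution is defined as a single smooth function over the whole domain:

$$u(x) = \sum_{i=1}^{N} w_i(x)\, \mathrm{NN}_i\big(\mathrm{norm}_i(x)\big),$$

where $w_i$ is the smooth window function confining subdomain $i$ and $\mathrm{norm}_i$ normalises inputs over subdomain $i$. Because every term in the sum is smooth, $u(x)$ is continuous (and differentiable) across subdomain interfaces by construction, so no extra interface loss terms are needed.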
## Citation

If you find FBPINNs useful and use them in your own work, please use the following citations:

```
@article{Moseley2021,
  arxivId = {2107.07871},
  author = {Moseley, Ben and Markham, Andrew and Nissen-Meyer, Tarje},
  journal = {arXiv},
  month = {jul},
  title = {{Finite Basis Physics-Informed Neural Networks (FBPINNs): a scalable domain decomposition approach for solving differential equations}},
  url = {https://arxiv.org/abs/2107.07871},
  year = {2021}
}

@article{Dolean2023,
  arxivId = {2306.05486},
  author = {Dolean, Victorita and Heinlein, Alexander and Mishra, Siddhartha and Moseley, Ben},
  journal = {arXiv},
  month = {jun},
  title = {{Multilevel domain decomposition-based architectures for physics-informed neural networks}},
  url = {https://arxiv.org/abs/2306.05486},
  year = {2023}
}
```
## Reproducing our original paper

To reproduce the exact results of our original FBPINN paper (*[Finite Basis Physics-Informed Neural Networks (FBPINNs): a scalable domain decomposition approach for solving differential equations](https://arxiv.org/abs/2107.07871), B. Moseley, T. Nissen-Meyer and A. Markham, Jul 2021 ArXiv*), you will need to use the legacy PyTorch FBPINN implementation, which is available at this [commit]().
## Further questions?

Please raise a GitHub [issue](https://github.com/benmoseley/FBPINNs/issues) or feel free to contact us.
