You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: README.md
+26-21Lines changed: 26 additions & 21 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -5,7 +5,7 @@
5
5
6
6
## A native PyTorch port of [Google's Computational Fluid Dynamics package in Jax](https://github.com/google/jax-cfd)
7
7
This port is a good pedagogical tool to learn how to implement traditional numerical solvers using modern deep learning software interfaces. The main changes are documented in the [`torch_cfd` directory](./torch_cfd/). The most significant changes in all routines include:
8
-
- (**enhanced**) Nonhomogenous boundary conditions support: the user can provide array-valued or function-valued bcs. Many routines in Jax-CFD only work with only periodic or constant boundary.
8
+
- (**enhanced**) Nonhomogenous/immersed boundary conditions support: the user can provide array-valued or function-valued boundary conditions, and can add a no-slip mask when there are obstacles. Many routines in Jax-CFD only work with only periodic or constant boundary.
9
9
- (**changed**) Routines that rely on the functional programming of Jax have been rewritten to be the PyTorch's tensor-in-tensor-out style, which is arguably more user-friendly to debugging as one can view intermediate values in tensors in VS Code debugger, opposed to Jax's `JaxprTrace`.
10
10
- (**enhanced**) Batch-dimension: all operations take into consideration the batch dimension of tensors `(b, *, n, m)` regardless of `*` dimension, for example, `(b, T, C, n, m)`, which is similar to PyTorch behavior. In the original Jax-CFD package, only a single trajectory is implemented. The stencil operators are changed to generally operate from the last dimension using negative indexing, following `torch.nn.functional.pad`'s behavior.
11
11
- (**changed**) Neural Network interface: functions and operators are in general implemented as `nn.Module` like a factory template.
@@ -15,11 +15,26 @@ This port is a good pedagogical tool to learn how to implement traditional numer
- The **Spatiotemporal Fourier Neural Operator** (SFNO) is a spacetime tensor-to-tensor learner (or trajectory-to-trajectory), available in the [`fno` directory](./fno). Different components of FNO have been re-implemented keeping the conciseness of the original implementation while allowing modern expansions. We draw inspiration from the [3D FNO in Nvidia's Neural Operator repo](https://github.com/neuraloperator/neuraloperator), [Transformers-based neural operators](https://github.com/thuml/Neural-Solver-Library), as well as Temam's book on functional analysis for the NSE.
17
17
- Major architectural changes: learnable spatiotemporal positional encodings, layernorm to replace a hard-coded global Gaussian normalizer, and many others. For more details please see [the documentation of the `SFNO` class](./fno/sfno.py#L485).
18
-
- Data generation for the meta-example of the isotropic turbulence in [McWilliams1984]. After the warmup phase, the energy spectra match the inverse cascade of Kolmogorov flow in a periodic box.
18
+
- Data generations:
19
+
- Isotropic turbulence in McWilliams, J. C. (1984). The emergence of isolated coherent vortices in turbulent flow. *Journal of Fluid Mechanics*, 146, 21-43. After the warmup phase, the energy spectra match the direct cascade in a periodic box.
20
+
- Forced turbulence example: Kolmogorov flow with inverse cascades.
19
21
- Pipelines for the *a posteriori* error estimation to fine-tune the SFNO to reach the scientific computing level of accuracy ($\le 10^{-6}$) in Bochner norm using FLOPs on par with a single evaluation, and only a fraction of FLOPs of a single `.backward()`.
20
22
-[Examples](#examples) can be found below.
21
23
22
-
[McWilliams1984]: McWilliams, J. C. (1984). The emergence of isolated coherent vortices in turbulent flow. *Journal of Fluid Mechanics*, 146, 21-43.
-[2D Lid-driven cavity with a random field perturbation using finite volume](./examples/Lid-driven_cavity_rk4_fvm.ipynb)
29
+
-[2D decaying isotropic turbulence using the pseudo-spectral method](./examples/Kolmogrov2d_rk4_spectral_forced_turbulence.ipynb)
30
+
-[2D Kolmogorov flow using finite volume method](./examples/Kolmogrov2d_rk4_fvm_forced_turbulence.ipynb)
31
+
- Demos of Spatiotemporal FNO's training and evaluation using the neural operator-assisted fluid simulation pipelines
32
+
-[Training of SFNO for only 15 epochs for the isotropic turbulence example](./examples/ex2_SFNO_train.ipynb)
33
+
-[Training of SFNO for only ***10*** epochs with 1k samples and reach `1e-2` level of relative error](./examples/ex2_SFNO_train_fnodata.ipynb) using the data in the FNO paper, which to our best knowledge no operator learner can do this in <100 epochs in the small data regime.
34
+
-[Fine-tuning of SFNO on a `256x256` grid for only 50 ADAM iterations to reach `1e-6` residual in the functional norm using FNO data](./examples/ex2_SFNO_finetune_fnodata.ipynb)
35
+
-[Fine-tuning of SFNO on the `256x256` grid for the McWilliams 2d isotropic turbulence](./examples/ex2_SFNO_finetune_McWilliams2d.ipynb)
36
+
-[Training of SFNO for only 5 epoch to match the inverse cascade of Kolmogorov flow](./examples/ex2_SFNO_5ep_spectra.ipynb)
37
+
-[Baseline of FNO3d for fixed step size that requires preloading a normalizer](./examples/ex2_FNO3d_train_normalized.ipynb)
23
38
24
39
## Installation
25
40
To install `torch-cfd`'s current release, simply do:
The data are available at [https://huggingface.co/datasets/scaomath/navier-stokes-dataset](https://huggingface.co/datasets/scaomath/navier-stokes-dataset).
39
54
Data generation instructions are available in the [SFNO folder](./fno).
40
55
41
-
42
-
## Examples
43
-
- Demos of different simulation setups:
44
-
-[2D Lid-driven cavity with a random field perturbation using finite volume](./examples/Lid-driven_cavity_rk4_fvm.ipynb)
45
-
-[2D decaying isotropic turbulence using the pseudo-spectral method](./examples/Kolmogrov2d_rk4_spectral_forced_turbulence.ipynb)
46
-
-[2D Kolmogorov flow using finite volume method](./examples/Kolmogrov2d_rk4_fvm_forced_turbulence.ipynb)
47
-
- Demos of Spatiotemporal FNO's training and evaluation using the neural operator-assisted fluid simulation pipelines
48
-
-[Training of SFNO for only 15 epochs for the isotropic turbulence example](./examples/ex2_SFNO_train.ipynb)
49
-
-[Training of SFNO for only ***10*** epochs with 1k samples and reach `1e-2` level of relative error](./examples/ex2_SFNO_train_fnodata.ipynb) using the data in the FNO paper, which to our best knowledge no operator learner can do this in <100 epochs in the small data regime.
50
-
-[Fine-tuning of SFNO on a `256x256` grid for only 50 ADAM iterations to reach `1e-6` residual in the functional norm using FNO data](./examples/ex2_SFNO_finetune_fnodata.ipynb)
51
-
-[Fine-tuning of SFNO on the `256x256` grid for the McWilliams 2d isotropic turbulence](./examples/ex2_SFNO_finetune_McWilliams2d.ipynb)
52
-
-[Training of SFNO for only 5 epoch to match the inverse cascade of Kolmogorov flow](./examples/ex2_SFNO_5ep_spectra.ipynb)
53
-
-[Baseline of FNO3d for fixed step size that requires preloading a normalizer](./examples/ex2_FNO3d_train_normalized.ipynb)
54
-
55
56
## Licenses
56
57
The Apache 2.0 License in the root folder applies to the `torch-cfd` folder of the repo that is inherited from Google's original license file for `Jax-cfd`. The `fno` folder has the MIT license inherited from [NVIDIA's Neural Operator repo](https://github.com/neuraloperator/neuraloperator). Note: the license(s) in the subfolder takes precedence.
57
58
58
59
## Contributions
59
-
PR welcome. Currently, the port of `torch-cfd` currently includes:
60
+
PR welcome for enhancing [essential functionalities with TODO tags](./torch_cfd/README.md). Currently, the port of `torch-cfd` currently includes:
60
61
- The pseudospectral method for vorticity uses anti-aliasing filtering techniques for nonlinear terms to maintain stability.
61
-
- The finite volume method on a MAC grids for velocity, and using the projection scheme to impose the divergence free condition.
62
+
- The finite volume method using MAC grids for velocity, together with a simple pressure projection scheme to impose the divergence free condition.
62
63
- Temporal discretization: Currently it has only single-step RK4-family schemes uses explicit time-stepping for advection, either implicit or explicit time-stepping for diffusion.
63
-
- Boundary conditions: periodic and Dirichlet boundary conditions for velocity, Neumann boundary for pressure.
64
-
- Solvers: pseudoinverse (either FFT-based or SVD based), Jacobi- or Multigrid V-cycle-preconditioned Conjugate gradient.
64
+
- Boundary conditions:
65
+
- velocity: periodic, Dirichlet (function-valued or array-valued), Dirichlet-Neumann (Neumann has to be 0-valued).
0 commit comments