Skip to content

Commit b994afb

Browse files
committed
von Karman vortex street example
1 parent b96fd8b commit b994afb

13 files changed

+52009
-50066
lines changed

README.md

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
## A native PyTorch port of [Google's Computational Fluid Dynamics package in Jax](https://github.com/google/jax-cfd)
77
This port is a good pedagogical tool to learn how to implement traditional numerical solvers using modern deep learning software interfaces. The main changes are documented in the [`torch_cfd` directory](./torch_cfd/). The most significant changes in all routines include:
8-
- (**enhanced**) Nonhomogenous boundary conditions support: the user can provide array-valued or function-valued bcs. Many routines in Jax-CFD only work with only periodic or constant boundary.
8+
- (**enhanced**) Nonhomogenous/immersed boundary conditions support: the user can provide array-valued or function-valued boundary conditions, and can add a no-slip mask when there are obstacles. Many routines in Jax-CFD only work with only periodic or constant boundary.
99
- (**changed**) Routines that rely on the functional programming of Jax have been rewritten to be the PyTorch's tensor-in-tensor-out style, which is arguably more user-friendly to debugging as one can view intermediate values in tensors in VS Code debugger, opposed to Jax's `JaxprTrace`.
1010
- (**enhanced**) Batch-dimension: all operations take into consideration the batch dimension of tensors `(b, *, n, m)` regardless of `*` dimension, for example, `(b, T, C, n, m)`, which is similar to PyTorch behavior. In the original Jax-CFD package, only a single trajectory is implemented. The stencil operators are changed to generally operate from the last dimension using negative indexing, following `torch.nn.functional.pad`'s behavior.
1111
- (**changed**) Neural Network interface: functions and operators are in general implemented as `nn.Module` like a factory template.
@@ -15,11 +15,26 @@ This port is a good pedagogical tool to learn how to implement traditional numer
1515
## Neural Operator-Assisted Navier-Stokes Equations simulator.
1616
- The **Spatiotemporal Fourier Neural Operator** (SFNO) is a spacetime tensor-to-tensor learner (or trajectory-to-trajectory), available in the [`fno` directory](./fno). Different components of FNO have been re-implemented keeping the conciseness of the original implementation while allowing modern expansions. We draw inspiration from the [3D FNO in Nvidia's Neural Operator repo](https://github.com/neuraloperator/neuraloperator), [Transformers-based neural operators](https://github.com/thuml/Neural-Solver-Library), as well as Temam's book on functional analysis for the NSE.
1717
- Major architectural changes: learnable spatiotemporal positional encodings, layernorm to replace a hard-coded global Gaussian normalizer, and many others. For more details please see [the documentation of the `SFNO` class](./fno/sfno.py#L485).
18-
- Data generation for the meta-example of the isotropic turbulence in [McWilliams1984]. After the warmup phase, the energy spectra match the inverse cascade of Kolmogorov flow in a periodic box.
18+
- Data generations:
19+
- Isotropic turbulence in McWilliams, J. C. (1984). The emergence of isolated coherent vortices in turbulent flow. *Journal of Fluid Mechanics*, 146, 21-43. After the warmup phase, the energy spectra match the direct cascade in a periodic box.
20+
- Forced turbulence example: Kolmogorov flow with inverse cascades.
1921
- Pipelines for the *a posteriori* error estimation to fine-tune the SFNO to reach the scientific computing level of accuracy ($\le 10^{-6}$) in Bochner norm using FLOPs on par with a single evaluation, and only a fraction of FLOPs of a single `.backward()`.
2022
- [Examples](#examples) can be found below.
2123

22-
[McWilliams1984]: McWilliams, J. C. (1984). The emergence of isolated coherent vortices in turbulent flow. *Journal of Fluid Mechanics*, 146, 21-43.
24+
25+
## Examples
26+
- Demos of different simulation setups:
27+
- [von Kármán vortex street](./examples/von_Karman_vortex_rk4_fvm.ipynb)
28+
- [2D Lid-driven cavity with a random field perturbation using finite volume](./examples/Lid-driven_cavity_rk4_fvm.ipynb)
29+
- [2D decaying isotropic turbulence using the pseudo-spectral method](./examples/Kolmogrov2d_rk4_spectral_forced_turbulence.ipynb)
30+
- [2D Kolmogorov flow using finite volume method](./examples/Kolmogrov2d_rk4_fvm_forced_turbulence.ipynb)
31+
- Demos of Spatiotemporal FNO's training and evaluation using the neural operator-assisted fluid simulation pipelines
32+
- [Training of SFNO for only 15 epochs for the isotropic turbulence example](./examples/ex2_SFNO_train.ipynb)
33+
- [Training of SFNO for only ***10*** epochs with 1k samples and reach `1e-2` level of relative error](./examples/ex2_SFNO_train_fnodata.ipynb) using the data in the FNO paper, which to our best knowledge no operator learner can do this in <100 epochs in the small data regime.
34+
- [Fine-tuning of SFNO on a `256x256` grid for only 50 ADAM iterations to reach `1e-6` residual in the functional norm using FNO data](./examples/ex2_SFNO_finetune_fnodata.ipynb)
35+
- [Fine-tuning of SFNO on the `256x256` grid for the McWilliams 2d isotropic turbulence](./examples/ex2_SFNO_finetune_McWilliams2d.ipynb)
36+
- [Training of SFNO for only 5 epoch to match the inverse cascade of Kolmogorov flow](./examples/ex2_SFNO_5ep_spectra.ipynb)
37+
- [Baseline of FNO3d for fixed step size that requires preloading a normalizer](./examples/ex2_FNO3d_train_normalized.ipynb)
2338

2439
## Installation
2540
To install `torch-cfd`'s current release, simply do:
@@ -38,30 +53,20 @@ pip install -r requirements.txt
3853
The data are available at [https://huggingface.co/datasets/scaomath/navier-stokes-dataset](https://huggingface.co/datasets/scaomath/navier-stokes-dataset).
3954
Data generation instructions are available in the [SFNO folder](./fno).
4055

41-
42-
## Examples
43-
- Demos of different simulation setups:
44-
- [2D Lid-driven cavity with a random field perturbation using finite volume](./examples/Lid-driven_cavity_rk4_fvm.ipynb)
45-
- [2D decaying isotropic turbulence using the pseudo-spectral method](./examples/Kolmogrov2d_rk4_spectral_forced_turbulence.ipynb)
46-
- [2D Kolmogorov flow using finite volume method](./examples/Kolmogrov2d_rk4_fvm_forced_turbulence.ipynb)
47-
- Demos of Spatiotemporal FNO's training and evaluation using the neural operator-assisted fluid simulation pipelines
48-
- [Training of SFNO for only 15 epochs for the isotropic turbulence example](./examples/ex2_SFNO_train.ipynb)
49-
- [Training of SFNO for only ***10*** epochs with 1k samples and reach `1e-2` level of relative error](./examples/ex2_SFNO_train_fnodata.ipynb) using the data in the FNO paper, which to our best knowledge no operator learner can do this in <100 epochs in the small data regime.
50-
- [Fine-tuning of SFNO on a `256x256` grid for only 50 ADAM iterations to reach `1e-6` residual in the functional norm using FNO data](./examples/ex2_SFNO_finetune_fnodata.ipynb)
51-
- [Fine-tuning of SFNO on the `256x256` grid for the McWilliams 2d isotropic turbulence](./examples/ex2_SFNO_finetune_McWilliams2d.ipynb)
52-
- [Training of SFNO for only 5 epoch to match the inverse cascade of Kolmogorov flow](./examples/ex2_SFNO_5ep_spectra.ipynb)
53-
- [Baseline of FNO3d for fixed step size that requires preloading a normalizer](./examples/ex2_FNO3d_train_normalized.ipynb)
54-
5556
## Licenses
5657
The Apache 2.0 License in the root folder applies to the `torch-cfd` folder of the repo that is inherited from Google's original license file for `Jax-cfd`. The `fno` folder has the MIT license inherited from [NVIDIA's Neural Operator repo](https://github.com/neuraloperator/neuraloperator). Note: the license(s) in the subfolder takes precedence.
5758

5859
## Contributions
59-
PR welcome. Currently, the port of `torch-cfd` currently includes:
60+
PR welcome for enhancing [essential functionalities with TODO tags](./torch_cfd/README.md). Currently, the port of `torch-cfd` currently includes:
6061
- The pseudospectral method for vorticity uses anti-aliasing filtering techniques for nonlinear terms to maintain stability.
61-
- The finite volume method on a MAC grids for velocity, and using the projection scheme to impose the divergence free condition.
62+
- The finite volume method using MAC grids for velocity, together with a simple pressure projection scheme to impose the divergence free condition.
6263
- Temporal discretization: Currently it has only single-step RK4-family schemes uses explicit time-stepping for advection, either implicit or explicit time-stepping for diffusion.
63-
- Boundary conditions: periodic and Dirichlet boundary conditions for velocity, Neumann boundary for pressure.
64-
- Solvers: pseudoinverse (either FFT-based or SVD based), Jacobi- or Multigrid V-cycle-preconditioned Conjugate gradient.
64+
- Boundary conditions:
65+
- velocity: periodic, Dirichlet (function-valued or array-valued), Dirichlet-Neumann (Neumann has to be 0-valued).
66+
- pressure: periodic, Neumann, Neumann-Dirichlet mixed.
67+
- Solvers:
68+
- Pseudo-inverse (either FFT-based or SVD based)
69+
- Jacobi-, Gauss-Seidel-, or Multigrid V-cycle-preconditioned Conjugate gradient.
6570

6671
## Reference
6772

examples/Kolmogrov2d_rk4_fvm_forced_turbulence.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
},
1313
{
1414
"cell_type": "code",
15-
"execution_count": 1,
15+
"execution_count": null,
1616
"id": "29f7f30e",
1717
"metadata": {},
1818
"outputs": [],
@@ -21,7 +21,7 @@
2121
"from torch_cfd import grids, boundaries\n",
2222
"from torch_cfd.initial_conditions import filtered_velocity_field\n",
2323
"\n",
24-
"from torch_cfd.equations import stable_time_step\n",
24+
"from torch_cfd.spectral import stable_time_step\n",
2525
"from torch_cfd.fvm import RKStepper, NavierStokes2DFVMProjection\n",
2626
"from torch_cfd.forcings import KolmogorovForcing\n",
2727
"import torch_cfd.finite_differences as fdm\n",

0 commit comments

Comments
 (0)