Skip to content

Commit 4cc4829

Browse files
committed
Refactor data generation imports and update setup dependencies
1 parent b9a0955 commit 4cc4829

File tree

8 files changed

+538
-320
lines changed

8 files changed

+538
-320
lines changed

examples/Kolmogrov2d_rk4_cn_forced_turbulence.ipynb

Lines changed: 39 additions & 32 deletions
Large diffs are not rendered by default.

setup.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,14 @@
1414
keywords = ['pytorch', 'cfd', 'pde', 'spectral', 'fluid dynamics', 'deep learning', 'neural operator'],
1515
python_requires='>=3.10',
1616
install_requires=[
17-
'numpy>=1.24.0',
17+
'numpy>=2.2.0',
1818
'torch>=2.5.0',
19+
'xarray>=2025.3.1',
20+
'tqdm>=4.62.0',
21+
'einops>=0.8.0',
22+
'dill>=0.4.0',
23+
'matplotlib>=3.5.0',
24+
'seaborn>=0.13.0',
1925
],
2026
classifiers=[
2127
'Development Status :: 4 - Beta',

sfno/data_gen/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .data_utils import *
2+
from .solvers import *
Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
# The MIT License (MIT)
2+
# Copyright © 2024 Shuhao Cao
3+
4+
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
5+
6+
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
7+
8+
# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
9+
10+
import os, sys
11+
12+
import dill
13+
14+
import torch
15+
import torch.fft as fft
16+
17+
from torch_cfd.grids import *
18+
from torch_cfd.equations import *
19+
from torch_cfd.initial_conditions import *
20+
from torch_cfd.finite_differences import *
21+
from torch_cfd.forcings import *
22+
23+
from tqdm import tqdm
24+
from data_utils import *
25+
from solvers import *
26+
27+
import logging
28+
29+
from sfno.pipeline import DATA_PATH, LOG_PATH
30+
31+
32+
def main(args):
33+
"""
34+
Generate the Kolmogorov 2d flow data in [1] that are used an examples in [2].
35+
36+
[1]: Kolmogorov, A. N. (1941). The local structure of turbulence in incompressible viscous fluid for very large Reynolds. Numbers. In Dokl. Akad. Nauk SSSR, 30, 301.
37+
38+
[2]: Kochkov, D., Smith, J. A., Alieva, A., Wang, Q., Brenner, M. P., & Hoyer, S. (2021). Machine learning-accelerated computational fluid dynamics. Proceedings of the National Academy of Sciences, 118(21), e2101784118.
39+
40+
Training dataset:
41+
>>> python data_gen_Kolmogorov2d.py --num-samples 1152 --batch-size 128 --grid-size 256 --subsample 4 --visc 1e-3 --dt 1e-3 --time 10 --time-warmup 4.5 --num-steps 100 --diam "2*torch.pi"
42+
43+
Testing dataset for plotting the enstrohpy spectrum:
44+
>>> python data_gen_Kolmogorov2d.py --num-samples 16 --batch-size 8 --grid-size 256 --subsample 1 --visc 1e-3 --dt 1e-3 --time 10 --time-warmup 4.5 --num-steps 100 --diam "2*torch.pi" --double
45+
"""
46+
args = args.parse_args()
47+
48+
current_time = datetime.now().strftime("%d_%b_%Y_%Hh%Mm")
49+
log_name = "".join(os.path.basename(__file__).split(".")[:-1])
50+
51+
log_filename = os.path.join(LOG_PATH, f"{current_time}_{log_name}.log")
52+
logger = get_logger(log_filename)
53+
54+
total_samples = args.num_samples
55+
batch_size = args.batch_size # 128
56+
n = args.grid_size # 256
57+
scale = args.scale
58+
viscosity = args.visc
59+
dt = args.dt # 1e-3
60+
T = args.time # 10
61+
subsample = args.subsample # 4
62+
ns = n // subsample
63+
T_warmup = args.time_warmup # 4.5
64+
num_snapshots = args.num_steps # 100
65+
random_state = args.seed
66+
peak_wavenumber = args.peak_wavenumber # 4
67+
diam = (
68+
eval(args.diam) if isinstance(args.diam, str) else args.diam
69+
) # "2 * torch.pi"
70+
force_rerun = args.force_rerun
71+
72+
logger = logging.getLogger()
73+
logger.info(f"Generating data for Kolmogorov 2d flow with {total_samples} samples")
74+
75+
max_velocity = args.max_velocity # 5
76+
dt = stable_time_step(diam / n, dt, max_velocity, viscosity=viscosity)
77+
logger.info(f"Using dt = {dt}")
78+
79+
warmup_steps = int(T_warmup / dt)
80+
total_steps = int((T - T_warmup) / dt)
81+
record_every_iters = int(total_steps / num_snapshots)
82+
83+
dtype = torch.float64 if args.double else torch.float32
84+
dtype_str = "_fp64" if args.double else ""
85+
filename = args.filename
86+
if filename is None:
87+
filename = f"Kolmogorov2d{dtype_str}_{ns}x{ns}_N{total_samples}_v{viscosity:.0e}_T{num_snapshots}.pt".replace(
88+
"e-0", "e-"
89+
)
90+
args.filename = filename
91+
data_filepath = os.path.join(DATA_PATH, filename)
92+
if os.path.exists(data_filepath) and not force_rerun:
93+
logger.info(f"Data already exists at {data_filepath}")
94+
return
95+
elif os.path.exists(data_filepath) and force_rerun:
96+
logger.info(f"Force rerun and save data to {data_filepath}")
97+
os.remove(data_filepath)
98+
else:
99+
logger.info(f"Save data to {data_filepath}")
100+
101+
cuda = not args.no_cuda and torch.cuda.is_available()
102+
no_tqdm = args.no_tqdm
103+
device = torch.device("cuda:0" if cuda else "cpu")
104+
105+
torch.set_default_dtype(dtype)
106+
logger.info(f"Using device: {device} | dtype: {dtype}")
107+
108+
grid = Grid(shape=(n, n), domain=((0, diam), (0, diam)), device=device)
109+
110+
forcing_fn = KolmogorovForcing(
111+
grid=grid,
112+
scale=scale,
113+
k=peak_wavenumber,
114+
swap_xy=False,
115+
)
116+
117+
ns2d = NavierStokes2DSpectral(
118+
viscosity=viscosity,
119+
grid=grid,
120+
drag=0.1,
121+
smooth=True,
122+
forcing_fn=forcing_fn,
123+
solver=rk4_crank_nicolson,
124+
).to(device)
125+
126+
num_batches = total_samples // batch_size
127+
for i, idx in enumerate(range(0, total_samples, batch_size)):
128+
logger.info(f"Generate trajectory for batch [{i+1}/{num_batches}]")
129+
logger.info(
130+
f"random states: {random_state + idx} to {random_state + idx + batch_size-1}"
131+
)
132+
133+
vort_init = torch.stack(
134+
[
135+
curl_2d(
136+
filtered_velocity_field(
137+
grid,
138+
max_velocity,
139+
peak_wavenumber,
140+
random_state=random_state + i + k,
141+
)
142+
).data
143+
for k in range(batch_size)
144+
]
145+
)
146+
vort_hat = fft.rfft2(vort_init).to(device)
147+
148+
with tqdm(total=warmup_steps, disable=no_tqdm) as pbar:
149+
for j in range(warmup_steps):
150+
vort_hat, _ = ns2d.step(vort_hat, dt)
151+
if j % 100 == 0:
152+
vort_norm = torch.linalg.norm(fft.irfft2(vort_hat)).item() / n
153+
desc = (
154+
datetime.now().strftime("%d-%b-%Y %H:%M:%S")
155+
+ f" - Warmup | vort_hat ell2 norm {vort_norm:.4e}"
156+
)
157+
pbar.set_description(desc)
158+
pbar.update(100)
159+
160+
result = get_trajectory_rk4(
161+
ns2d,
162+
vort_hat,
163+
dt,
164+
num_steps=total_steps,
165+
record_every_steps=record_every_iters,
166+
pbar=not no_tqdm,
167+
)
168+
169+
for field, value in result.items():
170+
value = fft.irfft2(value).real.cpu().to(dtype)
171+
logger.info(
172+
f"variable: {field} | shape: {value.shape} | dtype: {value.dtype}"
173+
)
174+
if subsample > 1:
175+
result[field] = F.interpolate(value, size=(ns, ns), mode="bilinear")
176+
else:
177+
result[field] = value
178+
179+
result["random_states"] = torch.tensor(
180+
[random_state + idx + k for k in range(batch_size)], dtype=torch.int32
181+
)
182+
logger.info(f"Saving batch [{i+1}/{num_batches}] to {data_filepath}")
183+
save_pickle(result, data_filepath)
184+
del result
185+
186+
pickle_to_pt(data_filepath)
187+
logger.info(f"Done saving.")
188+
if args.demo_plots:
189+
try:
190+
verify_trajectories(
191+
data_filepath,
192+
dt=record_every_iters * dt,
193+
T_warmup=T_warmup,
194+
n_samples=1,
195+
)
196+
except Exception as e:
197+
logger.error(f"Error in plotting: {e}")
198+
199+
200+
if __name__ == "__main__":
201+
args = get_args("Params Kolmogorov 2d flow data generation")
202+
main(args)

sfno/data_gen/data_gen_McWilliams2d.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@
1919
from torch_cfd.initial_conditions import *
2020
from torch_cfd.finite_differences import *
2121
from torch_cfd.forcings import *
22+
2223
from tqdm import tqdm
23-
from data_gen import *
24+
from data_utils import *
25+
from solvers import *
2426
import logging
2527

2628
from sfno.pipeline import DATA_PATH, LOG_PATH

sfno/data_gen/data_gen_fno.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from grf import GRF2d
1111
from solvers import *
12-
from data_gen import *
12+
from data_utils import *
1313
from sfno.pipeline import DATA_PATH, LOG_PATH
1414

1515
def main(args):

0 commit comments

Comments
 (0)