Skip to content

Commit 08220f9

Browse files
Add Initial WeatherMesh-2 Implementation (#138)
* Push updates * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add test for irregular grid * Add working encoder Lots of inspiration and code from this: https://github.com/Brayden-Zhang/WeatherMesh with some changes to work with natten and some other changes. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add configurable decoder and WeatherMesh2 model * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add dacite configs * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix keeping the correct depth (vertical pressure level) * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add new updates and change to pixi install --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 34406d1 commit 08220f9

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

77 files changed

+698
-28
lines changed

.all-contributorsrc

100644100755
File mode changed.

.bumpversion.cfg

100644100755
File mode changed.

.github/workflows/release.yaml

100644100755
File mode changed.

.github/workflows/workflows.yaml

100644100755
Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,9 @@ jobs:
1212
fail-fast: true
1313
matrix:
1414
os: [ubuntu-latest]
15-
python-version: ["3.10", "3.11", "3.12"]
16-
torch-version: [2.3.0, 2.4.0]
15+
python-version: ["3.11", "3.12"]
16+
torch-version: [2.4.0]
1717
include:
18-
- torch-version: 2.3.0
19-
torchvision-version: 0.18.0
2018
- torch-version: 2.4.0
2119
torchvision-version: 0.19.0
2220
steps:

.gitignore

100644100755
File mode changed.

.pre-commit-config.yaml

100644100755
File mode changed.

Dockerfile

100644100755
File mode changed.

LICENSE

100644100755
File mode changed.

MANIFEST.in

100644100755
File mode changed.

README.md

100644100755
File mode changed.

environment_cpu.yml

100644100755
File mode changed.

environment_cuda.yml

100644100755
File mode changed.

graph_weather/__init__.py

100644100755
File mode changed.

graph_weather/data/IFSAnalysis_dataloader.py

100644100755
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ def __init__(self, filepath: str, features: list, start_year: int = 2016, end_ye
4343
"""
4444

4545
super().__init__()
46-
assert start_year <= end_year, (
47-
f"start_year ({start_year}) cannot be greater than end_year ({end_year})."
48-
)
46+
assert (
47+
start_year <= end_year
48+
), f"start_year ({start_year}) cannot be greater than end_year ({end_year})."
4949
assert start_year >= 2016 and start_year <= 2022, "Time data range from 2016 to 2022"
5050
assert end_year >= 2016 and end_year <= 2022, "Time data range from 2016 to 2022"
5151
self.data = xr.open_zarr(filepath)

graph_weather/data/__init__.py

100644100755
File mode changed.

graph_weather/data/const.py

100644100755
File mode changed.

graph_weather/data/dataloader.py

100644100755
File mode changed.

graph_weather/data/gencast_dataloader.py

100644100755
File mode changed.

graph_weather/data/nnja_ai.py

100644100755
File mode changed.

graph_weather/models/__init__.py

100644100755
File mode changed.

graph_weather/models/analysis.py

100644100755
File mode changed.

graph_weather/models/fengwu_ghr/__init__.py

100644100755
File mode changed.

graph_weather/models/fengwu_ghr/layers.py

100644100755
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,9 @@ def __init__(
105105
)
106106
)
107107
if self.res:
108-
assert image_size is not None and scale_factor is not None, (
109-
"If res=True, you must provide h, w and scale_factor"
110-
)
108+
assert (
109+
image_size is not None and scale_factor is not None
110+
), "If res=True, you must provide h, w and scale_factor"
111111
h, w = pair(image_size)
112112
s_h, s_w = pair(scale_factor)
113113
self.res_layers.append(

graph_weather/models/forecast.py

100644100755
File mode changed.

graph_weather/models/gencast/README.md

100644100755
File mode changed.

graph_weather/models/gencast/__init__.py

100644100755
File mode changed.

graph_weather/models/gencast/denoiser.py

100644100755
File mode changed.

graph_weather/models/gencast/graph/__init__.py

100644100755
File mode changed.

graph_weather/models/gencast/graph/graph_builder.py

100644100755
Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,9 @@ def _init_khop_mesh_graph(self):
309309
edge_index,
310310
values=torch.ones_like(edge_index[0], dtype=torch.float32),
311311
size=(self._num_mesh_nodes, self._num_mesh_nodes),
312-
).to(self.khop_device) # cpu is more memory-efficient, why?
312+
).to(
313+
self.khop_device
314+
) # cpu is more memory-efficient, why?
313315

314316
adj_k = adj.coalesce()
315317
for _ in range(self.num_hops - 1):

graph_weather/models/gencast/graph/grid_mesh_connectivity.py

100644100755
File mode changed.

graph_weather/models/gencast/graph/icosahedral_mesh.py

100644100755
File mode changed.

graph_weather/models/gencast/graph/model_utils.py

100644100755
File mode changed.

graph_weather/models/gencast/images/animated.gif

100644100755
File mode changed.

graph_weather/models/gencast/images/autoregressive.gif

100644100755
File mode changed.

graph_weather/models/gencast/images/fullmodel.png

100644100755
File mode changed.

graph_weather/models/gencast/images/readme.md

100644100755
File mode changed.

graph_weather/models/gencast/layers/__init__.py

100644100755
File mode changed.

graph_weather/models/gencast/layers/decoder.py

100644100755
File mode changed.

graph_weather/models/gencast/layers/encoder.py

100644100755
File mode changed.

graph_weather/models/gencast/layers/experimental/__init__.py

100644100755
File mode changed.

graph_weather/models/gencast/layers/experimental/sparse_transformer.py

100644100755
File mode changed.

graph_weather/models/gencast/layers/modules.py

100644100755
File mode changed.

graph_weather/models/gencast/layers/processor.py

100644100755
File mode changed.

graph_weather/models/gencast/sampler.py

100644100755
File mode changed.

graph_weather/models/gencast/train.py

100644100755
File mode changed.

graph_weather/models/gencast/utils/__init__.py

100644100755
File mode changed.

graph_weather/models/gencast/utils/batching.py

100644100755
File mode changed.

graph_weather/models/gencast/utils/noise.py

100644100755
File mode changed.

graph_weather/models/gencast/utils/statistics.py

100644100755
File mode changed.

graph_weather/models/gencast/weighted_mse_loss.py

100644100755
File mode changed.

graph_weather/models/layers/__init__.py

100644100755
File mode changed.

graph_weather/models/layers/assimilator_decoder.py

100644100755
File mode changed.

graph_weather/models/layers/assimilator_encoder.py

100644100755
File mode changed.

graph_weather/models/layers/decoder.py

100644100755
File mode changed.

graph_weather/models/layers/encoder.py

100644100755
File mode changed.

graph_weather/models/layers/graph_net_block.py

100644100755
File mode changed.

graph_weather/models/layers/processor.py

100644100755
File mode changed.

graph_weather/models/losses.py

100644100755
Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def __init__(
2828
device: checks for device whether it supports gpu or not
2929
normalize: option for normalize
3030
"""
31-
# TODO Rescale by nominal static air density at each pressure level
31+
# TODO Rescale by nominal static air density at each pressure level, could be 1/pressure level or something similar
3232
super().__init__()
3333
self.feature_variance = torch.tensor(feature_variance)
3434
assert not torch.isnan(self.feature_variance).any()
@@ -55,15 +55,19 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor):
5555
"""
5656
self.feature_variance = self.feature_variance.to(pred.device)
5757
self.weights = self.weights.to(pred.device)
58+
print(pred.shape)
59+
print(target.shape)
60+
print(self.weights.shape)
5861

5962
out = (pred - target) ** 2
60-
63+
print(out.shape)
6164
if self.normalize:
6265
out = out / self.feature_variance
6366

6467
assert not torch.isnan(out).any()
6568
# Mean of the physical variables
6669
out = out.mean(-1)
70+
print(out.shape)
6771
# Weight by the latitude, as that changes, so does the size of the pixel
6872
out = out * self.weights.expand_as(out)
6973
assert not torch.isnan(out).any()

graph_weather/models/weathermesh/__init__.py

Whitespace-only changes.
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
"""
2+
Implementation based off the technical report and this repo: https://github.com/Brayden-Zhang/WeatherMesh
3+
"""
4+
5+
from dataclasses import dataclass
6+
7+
import dacite
8+
import einops
9+
import torch
10+
import torch.nn as nn
11+
from natten import NeighborhoodAttention3D
12+
13+
from graph_weather.models.weathermesh.layers import ConvUpBlock
14+
15+
16+
@dataclass
17+
class WeatherMeshDecoderConfig:
18+
latent_dim: int
19+
output_channels_2d: int
20+
output_channels_3d: int
21+
n_conv_blocks: int
22+
hidden_dim: int
23+
kernel_size: tuple
24+
num_heads: int
25+
num_transformer_layers: int
26+
27+
@staticmethod
28+
def from_json(json: dict) -> "WeatherMeshDecoder":
29+
return dacite.from_dict(data_class=WeatherMeshDecoderConfig, data=json)
30+
31+
def to_json(self) -> dict:
32+
return dacite.asdict(self)
33+
34+
35+
class WeatherMeshDecoder(nn.Module):
36+
def __init__(
37+
self,
38+
latent_dim,
39+
output_channels_2d,
40+
output_channels_3d,
41+
n_conv_blocks=3,
42+
hidden_dim=256,
43+
kernel_size: tuple = (5, 7, 7),
44+
num_heads: int = 8,
45+
num_transformer_layers: int = 3,
46+
):
47+
super().__init__()
48+
49+
# Transformer layers for initial decoding
50+
self.transformer_layers = nn.ModuleList(
51+
[
52+
NeighborhoodAttention3D(
53+
dim=latent_dim, num_heads=num_heads, kernel_size=kernel_size
54+
)
55+
for _ in range(num_transformer_layers)
56+
]
57+
)
58+
59+
# Split into pressure levels and surface paths
60+
self.split = nn.Conv3d(latent_dim, hidden_dim * (2**n_conv_blocks), kernel_size=1)
61+
62+
# Pressure levels (3D) path
63+
self.pressure_path = nn.ModuleList(
64+
[
65+
ConvUpBlock(
66+
hidden_dim * (2 ** (i + 1)),
67+
hidden_dim * (2**i) if i > 0 else output_channels_3d,
68+
is_3d=True,
69+
)
70+
for i in reversed(range(n_conv_blocks))
71+
]
72+
)
73+
74+
# Surface (2D) path
75+
self.surface_path = nn.ModuleList(
76+
[
77+
ConvUpBlock(
78+
hidden_dim * (2 ** (i + 1)),
79+
hidden_dim * (2**i) if i > 0 else output_channels_2d,
80+
)
81+
for i in reversed(range(n_conv_blocks))
82+
]
83+
)
84+
85+
def forward(self, latent: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
86+
# Needs to be (B,D,H,W,C) with Batch, Depth (vertical levels), Height, Width, Channels
87+
# Apply transformer layers
88+
for transformer in self.transformer_layers:
89+
latent = transformer(latent)
90+
91+
latent = einops.rearrange(latent, "B D H W C -> B C D H W")
92+
# Split features
93+
features = self.split(latent)
94+
pressure_features = features[:, :, :-1]
95+
surface_features = features[:, :, -1:]
96+
# Decode pressure levels
97+
for block in self.pressure_path:
98+
pressure_features = block(pressure_features)
99+
# Decode surface features
100+
surface_features = surface_features.squeeze(2)
101+
for block in self.surface_path:
102+
surface_features = block(surface_features)
103+
104+
return surface_features, pressure_features
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
"""
2+
Implementation based off the technical report and this repo: https://github.com/Brayden-Zhang/WeatherMesh
3+
"""
4+
5+
from dataclasses import dataclass
6+
7+
import dacite
8+
import einops
9+
import torch
10+
import torch.nn as nn
11+
from natten import NeighborhoodAttention3D
12+
13+
from graph_weather.models.weathermesh.layers import ConvDownBlock
14+
15+
16+
@dataclass
17+
class WeatherMeshEncoderConfig:
18+
input_channels_2d: int
19+
input_channels_3d: int
20+
latent_dim: int
21+
n_pressure_levels: int
22+
num_conv_blocks: int
23+
hidden_dim: int
24+
kernel_size: tuple
25+
num_heads: int
26+
num_transformer_layers: int
27+
28+
@staticmethod
29+
def from_json(json: dict) -> "WeatherMeshEncoder":
30+
return dacite.from_dict(data_class=WeatherMeshEncoderConfig, data=json)
31+
32+
def to_json(self) -> dict:
33+
return dacite.asdict(self)
34+
35+
36+
class WeatherMeshEncoder(nn.Module):
37+
def __init__(
38+
self,
39+
input_channels_2d: int,
40+
input_channels_3d: int,
41+
latent_dim: int,
42+
n_pressure_levels: int,
43+
num_conv_blocks: int = 3,
44+
hidden_dim: int = 256,
45+
kernel_size: tuple = (5, 7, 7),
46+
num_heads: int = 8,
47+
num_transformer_layers: int = 3,
48+
):
49+
super().__init__()
50+
51+
# Surface (2D) path
52+
self.surface_path = nn.ModuleList(
53+
[
54+
ConvDownBlock(
55+
input_channels_2d if i == 0 else hidden_dim * (2**i),
56+
hidden_dim * (2 ** (i + 1)),
57+
)
58+
for i in range(num_conv_blocks)
59+
]
60+
)
61+
62+
# Pressure levels (3D) path
63+
self.pressure_path = nn.ModuleList(
64+
[
65+
ConvDownBlock(
66+
input_channels_3d if i == 0 else hidden_dim * (2**i),
67+
hidden_dim * (2 ** (i + 1)),
68+
stride=(1, 2, 2), # Want to keep depth the same size
69+
is_3d=True,
70+
)
71+
for i in range(num_conv_blocks)
72+
]
73+
)
74+
75+
# Transformer layers for final encoding
76+
self.transformer_layers = nn.ModuleList(
77+
[
78+
NeighborhoodAttention3D(
79+
dim=latent_dim, kernel_size=kernel_size, num_heads=num_heads
80+
)
81+
for _ in range(num_transformer_layers)
82+
]
83+
)
84+
85+
# Final projection to latent space
86+
self.to_latent = nn.Conv3d(hidden_dim * (2**num_conv_blocks), latent_dim, kernel_size=1)
87+
88+
def forward(self, surface: torch.Tensor, pressure: torch.Tensor) -> torch.Tensor:
89+
# Process surface data
90+
for block in self.surface_path:
91+
surface = block(surface)
92+
93+
# Process pressure level data
94+
for block in self.pressure_path:
95+
pressure = block(pressure)
96+
# Combine features
97+
features = torch.cat(
98+
[pressure, surface.unsqueeze(2)], dim=2
99+
) # B C D H W currently, want it to be B D H W C
100+
101+
# Transform to latent space
102+
latent = self.to_latent(features)
103+
104+
# Reshape to get the shapes
105+
latent = einops.rearrange(latent, "B C D H W -> B D H W C")
106+
# Apply transformer layers
107+
for transformer in self.transformer_layers:
108+
latent = transformer(latent)
109+
return latent

0 commit comments

Comments
 (0)