Skip to content

Commit ff43cc1

Browse files
Generating physically-consistent high-resolution climate data with hard-constrained neural networks (#137)
* Generating physically-consistent high-resolution climate data with hard-constrained neural networks * Generating physically-consistent high-resolution climate data with hard-constrained neural networks * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Modify constraint layer and update UTs to comply with constraints Add functional logic for additive, multiplicative, and softmax physical constraints * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Use einops for tensor manipulation, update constraint set-up configuration and add UTs for each constraint type * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Minor fix due to rebase * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add default constraint as none and fix NormalizedLoss function when there are no constraints applied for forecasting * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Rearrange graph to grid only when constraints are to be applied * Fix test_forecaster_and_loss_irregular after rebase * Resolve ruff errors regarding docstrings * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 526de40 commit ff43cc1

File tree

4 files changed

+405
-10
lines changed

4 files changed

+405
-10
lines changed

graph_weather/models/forecast.py

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33
from typing import Optional
44

55
import torch
6+
from einops import rearrange, repeat
67
from huggingface_hub import PyTorchModelHubMixin
78

89
from graph_weather.models import Decoder, Encoder, Processor
10+
from graph_weather.models.layers.constraint_layer import PhysicalConstraintLayer
911

1012

1113
class GraphWeatherForecaster(torch.nn.Module, PyTorchModelHubMixin):
12-
"""Main weather prediction model from the paper"""
14+
"""Main weather prediction model from the paper with physical constraints"""
1315

1416
def __init__(
1517
self,
@@ -29,6 +31,7 @@ def __init__(
2931
hidden_layers_decoder: int = 2,
3032
norm_type: str = "LayerNorm",
3133
use_checkpointing: bool = False,
34+
constraint_type: str = "none",
3235
):
3336
"""
3437
Graph Weather Model based off https://arxiv.org/pdf/2202.07575.pdf
@@ -53,11 +56,24 @@ def __init__(
5356
norm_type: Type of norm for the MLPs
5457
one of 'LayerNorm', 'GraphNorm', 'InstanceNorm', 'BatchNorm', 'MessageNorm', or None
5558
use_checkpointing: Use gradient checkpointing to reduce model memory
59+
constraint_type: Type of constraint to apply for physical constraints
60+
one of 'additive', 'multiplicative', 'softmax', or 'none'
5661
"""
5762
super().__init__()
5863
self.feature_dim = feature_dim
64+
self.constraint_type = constraint_type
5965
if output_dim is None:
6066
output_dim = self.feature_dim
67+
self.output_dim = output_dim
68+
69+
# Compute the geographical grid shape from lat_lons.
70+
unique_lats = sorted(set(lat for lat, _ in lat_lons))
71+
unique_lons = sorted(set(lon for _, lon in lat_lons))
72+
self.grid_shape = (len(unique_lats), len(unique_lons)) # (H, W)
73+
74+
# Store original node order and create grid mapping
75+
self.original_lat_lons = lat_lons.copy()
76+
self._create_grid_mapping(unique_lats, unique_lons)
6177

6278
self.encoder = Encoder(
6379
lat_lons=lat_lons,
@@ -98,6 +114,51 @@ def __init__(
98114
use_checkpointing=use_checkpointing,
99115
)
100116

117+
# Add physical constraint layer if constraint_type is not "none"
118+
if self.constraint_type != "none":
119+
self.constraint = PhysicalConstraintLayer(
120+
model=self,
121+
grid_shape=self.grid_shape,
122+
constraint_type=constraint_type,
123+
upsampling_factor=1,
124+
)
125+
126+
def _create_grid_mapping(self, unique_lats, unique_lons):
127+
"""Create (row,col) mapping for original node order"""
128+
self.node_to_grid = []
129+
for lat, lon in self.original_lat_lons:
130+
row = int(
131+
(lat - min(unique_lats))
132+
/ (max(unique_lats) - min(unique_lats))
133+
* (len(unique_lats) - 1)
134+
)
135+
col = int(
136+
(lon - min(unique_lons))
137+
/ (max(unique_lons) - min(unique_lons))
138+
* (len(unique_lons) - 1)
139+
)
140+
self.node_to_grid.append((row, col))
141+
142+
def graph_to_grid(self, graph_tensor):
    """Convert a graph tensor to a grid using the stored spatial mapping.

    Shapes: [B, N, C] -> [B, C, H, W], where (H, W) is ``self.grid_shape`` and
    node ``i`` is written to the cell ``self.node_to_grid[i]``. Cells that no
    node maps to stay zero.

    Args:
        graph_tensor: Tensor of shape [B, N, C].

    Returns:
        Tensor of shape [B, C, H, W] with the same dtype and device as
        ``graph_tensor``.
    """
    batch_size, _, features = graph_tensor.shape
    # new_zeros keeps the input's dtype and device; a bare torch.zeros would
    # allocate on the default device and silently cast the values.
    grid = graph_tensor.new_zeros((batch_size, features, *self.grid_shape))
    for node_idx, (row, col) in enumerate(self.node_to_grid):
        grid[..., row, col] = graph_tensor[..., node_idx, :]
    return grid
153+
154+
def grid_to_graph(self, grid_tensor):
    """Convert a grid tensor to a graph tensor: [B, C, H, W] -> [B, N, C].

    Node ``i`` is read from the cell ``self.node_to_grid[i]``. The node
    dimension of the result has length ``H * W``; rows beyond
    ``len(self.node_to_grid)`` stay zero.

    Args:
        grid_tensor: Tensor of shape [B, C, H, W].

    Returns:
        Tensor of shape [B, H * W, C] with the same dtype and device as
        ``grid_tensor``.
    """
    batch_size, features, height, width = grid_tensor.shape
    # new_zeros keeps the input's dtype and device; a bare torch.zeros would
    # allocate on the default device and silently cast the values.
    graph = grid_tensor.new_zeros((batch_size, height * width, features))
    for node_idx, (row, col) in enumerate(self.node_to_grid):
        graph[..., node_idx, :] = grid_tensor[..., row, col]
    return graph
161+
101162
def forward(self, features: torch.Tensor) -> torch.Tensor:
102163
"""
103164
Compute the new state of the forecast
@@ -111,4 +172,22 @@ def forward(self, features: torch.Tensor) -> torch.Tensor:
111172
x, edge_idx, edge_attr = self.encoder(features)
112173
x = self.processor(x, edge_idx, edge_attr)
113174
x = self.decoder(x, features[..., : self.feature_dim])
175+
176+
# The decoder output x is a 3D graph tensor [B, num_nodes, output_dim];
177+
# it is only rearranged to grid format [B, output_dim, H, W] when a constraint is applied.
178+
# Convert graph output to grid format
179+
180+
# Apply physical constraints to decoder output
181+
if self.constraint_type != "none":
182+
x = rearrange(x, "b (h w) c -> b c h w", h=self.grid_shape[0], w=self.grid_shape[1])
183+
# Extract the low-res reference from the input.
184+
# (Original features has shape [B, num_nodes, feature_dim])
185+
lr = features[..., : self.feature_dim] # shape: [B, num_nodes, feature_dim]
186+
# Convert from node format to grid format using the grid_shape computed in __init__
187+
# From [B, num_nodes, feature_dim] to [B, feature_dim, H, W]
188+
lr = rearrange(lr, "b (h w) c -> b c h w", h=self.grid_shape[0], w=self.grid_shape[1])
189+
if lr.size(1) != x.size(1):
190+
repeat_factor = x.size(1) // lr.size(1)
191+
lr = repeat(lr, "b c h w -> b (r c) h w", r=repeat_factor)
192+
x = self.constraint(x, lr)
114193
return x
Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
"""Module for physical constraint layers used in graph weather models.
2+
3+
This module implements several constraints on a network’s intermediate outputs,
4+
ensuring physical consistency with an input at a lower resolution.
5+
6+
"""
7+
8+
import torch
9+
import torch.nn as nn
10+
11+
12+
class PhysicalConstraintLayer(nn.Module):
    """Apply a hard physical constraint to a network's intermediate outputs.

    Given the intermediate outputs ỹ and the corresponding low-resolution
    input x, one of the following corrections is applied:

    Additive constraint:
        y = ỹ + x - avg(ỹ)

    Multiplicative constraint:
        y = ỹ * ( x / avg(ỹ) )

    Softmax constraint:
        y = exp(ỹ) * ( x / sum(exp(ỹ)) )

    For the additive and multiplicative constraints the average is taken over
    the node dimension of the graph representation (i.e. over the whole grid).
    For the softmax constraint the sum is taken over blocks of
    ``upsampling_factor x upsampling_factor`` pixels.

    Both the intermediate outputs and the low-resolution reference are handled
    as 4D tensors in grid format, [B, C, H, W], where n = H*W is the number of
    pixels (or nodes).
    """

    def __init__(
        self, model, grid_shape, upsampling_factor, constraint_type="none", exp_factor=1.0
    ):
        """Initialize the PhysicalConstraintLayer.

        Args:
            model (nn.Module): The model providing the helper methods
                'graph_to_grid' and 'grid_to_graph'.
            grid_shape (tuple): Expected spatial dimensions (H, W) of the
                high-resolution grid.
            upsampling_factor (int): Factor by which the low-resolution grid is upsampled.
            constraint_type (str, optional): The constraint to apply. Options are
                'additive', 'multiplicative', or 'softmax'. Defaults to "none".
            exp_factor (float, optional): Exponent factor for the softmax constraint.
                Defaults to 1.0.
        """
        super().__init__()
        # Store the model as a plain attribute instead of a registered submodule.
        # The model typically owns this layer as one of its own submodules, so
        # registering it here would create a module cycle and make train()/eval(),
        # .to() and state_dict() recurse without end.
        object.__setattr__(self, "model", model)
        self.constraint_type = constraint_type
        self.grid_shape = grid_shape
        self.exp_factor = exp_factor
        self.upsampling_factor = upsampling_factor
        self.pool = nn.AvgPool2d(kernel_size=upsampling_factor)

    def forward(self, hr_graph, lr_graph):
        """Apply the selected physical constraint.

        Converts the inputs to grid format when they arrive in graph format,
        applies the configured constraint, and converts the result back to
        graph format with ``model.grid_to_graph``.

        Args:
            hr_graph (torch.Tensor): High-resolution model output in either graph (3D)
                or grid (4D) format.
            lr_graph (torch.Tensor): Low-resolution input in the corresponding
                graph or grid format.

        Returns:
            torch.Tensor: The adjusted output in graph format.

        Raises:
            ValueError: If ``hr_graph`` is neither 3D nor 4D, if a 4D input does
                not match ``grid_shape``, or if the constraint type is unknown.
        """
        if hr_graph.dim() == 3:
            # Graph format: convert both tensors to grids.
            hr_grid = self.model.graph_to_grid(hr_graph)
            lr_grid = self.model.graph_to_grid(lr_graph)
        elif hr_graph.dim() == 4:
            # Already in grid format: [B, C, H, W]
            _, _, H, W = hr_graph.shape
            if (H, W) != self.grid_shape:
                raise ValueError(f"Expected spatial dimensions {self.grid_shape}, got {(H, W)}")
            hr_grid = hr_graph
            lr_grid = lr_graph
        else:
            raise ValueError("Input tensor must be either 3D (graph) or 4D (grid).")

        # Apply the constraint in grid format.
        if self.constraint_type == "additive":
            result = self.additive_constraint(hr_grid, lr_grid)
        elif self.constraint_type == "multiplicative":
            result = self.multiplicative_constraint(hr_grid, lr_grid)
        elif self.constraint_type == "softmax":
            result = self.softmax_constraint(hr_grid, lr_grid)
        else:
            raise ValueError(f"Unknown constraint type: {self.constraint_type}")

        # Convert the grid back to graph format.
        return self.model.grid_to_graph(result)

    def additive_constraint(self, hr, lr):
        """Enforce conservation using an additive correction.

        Computes y = ỹ + ( x - avg(ỹ) ), where avg(ỹ) is the mean of the
        high-resolution output over the node dimension of its graph
        representation.

        Args:
            hr (torch.Tensor): High-resolution tensor in grid format [B, C, H_hr, W_hr].
            lr (torch.Tensor): Low-resolution tensor in grid format [B, C, h_lr, w_lr].

        Returns:
            torch.Tensor: Adjusted high-resolution tensor in grid format.
        """
        # Work in graph space using the model's node <-> grid mapping.
        hr_graph = self.model.grid_to_graph(hr)
        lr_graph = self.model.grid_to_graph(lr)

        # Average over the node dimension, then take the discrepancy to lr.
        avg_hr = hr_graph.mean(dim=1, keepdim=True)
        diff = lr_graph - avg_hr

        # NOTE(review): repeat() tiles the node dimension; this lines up with the
        # high-resolution nodes when upsampling_factor == 1 — confirm the intended
        # expansion for larger factors.
        diff_expanded = diff.repeat(1, self.upsampling_factor**2, 1)

        adjusted_graph = hr_graph + diff_expanded
        return self.model.graph_to_grid(adjusted_graph)

    def multiplicative_constraint(self, hr, lr):
        """Enforce conservation using a multiplicative correction in graph space.

        The high-resolution output is scaled by the ratio between the node-mean
        of the low-resolution input and the node-mean of the high-resolution
        output.

        Args:
            hr (torch.Tensor): High-resolution tensor in grid format [B, C, H_hr, W_hr].
            lr (torch.Tensor): Low-resolution tensor in grid format [B, C, h_lr, w_lr].

        Returns:
            torch.Tensor: Adjusted high-resolution tensor in grid format.
        """
        # Work in graph space using the model's node <-> grid mapping.
        hr_graph = self.model.grid_to_graph(hr)
        lr_graph = self.model.grid_to_graph(lr)

        # Averages over the node dimension.
        avg_hr = hr_graph.mean(dim=1, keepdim=True)
        lr_patch_avg = lr_graph.mean(dim=1, keepdim=True)

        # Small epsilon guards against division by an exactly-zero mean.
        ratio = lr_patch_avg / (avg_hr + 1e-8)

        adjusted_graph = hr_graph * ratio
        return self.model.graph_to_grid(adjusted_graph)

    def softmax_constraint(self, y, lr):
        """Apply a softmax-based constraint correction.

        The exponentiated high-resolution output is rescaled so that its sum over
        each ``upsampling_factor x upsampling_factor`` block matches the
        low-resolution reference.

        Args:
            y (torch.Tensor): High-resolution tensor in grid format [B, C, H, W].
            lr (torch.Tensor): Low-resolution tensor in grid format [B, C, h, w].

        Returns:
            torch.Tensor: Adjusted high-resolution tensor in grid format after applying
                the softmax constraint.
        """
        factor = self.upsampling_factor
        # Block of ones used by kron to expand a low-res field to high-res.
        expansion = torch.ones((factor, factor), device=lr.device)

        scaled = self.exp_factor * y
        # Subtract each block's maximum before exponentiating. The shift cancels in
        # exp(.) / sum(exp(.)), so the result is unchanged mathematically, but large
        # activations no longer overflow to inf (which produced NaN outputs).
        block_max = nn.functional.max_pool2d(scaled.detach(), kernel_size=factor)
        exp_y = torch.exp(scaled - torch.kron(block_max, expansion))

        # Average pooling times the kernel area gives the per-block sum.
        kernel_area = factor**2
        sum_y = self.pool(exp_y) * kernel_area

        # Ensure that lr * (1/sum_y) is contiguous
        ratio = (lr * (1 / sum_y)).contiguous()

        # Expand the low-resolution ratio so each block sum of the output matches lr.
        out = exp_y * torch.kron(ratio, expansion)
        return out

graph_weather/models/losses.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,12 @@ def __init__(
3232
super().__init__()
3333
self.feature_variance = torch.tensor(feature_variance)
3434
assert not torch.isnan(self.feature_variance).any()
35-
weights = []
36-
for lat, lon in lat_lons:
37-
weights.append(np.cos(lat * np.pi / 180.0))
38-
self.weights = torch.tensor(weights, dtype=torch.float)
35+
# Compute unique latitudes from the provided lat/lon pairs.
36+
unique_lats = sorted(set(lat for lat, _ in lat_lons))
37+
# Use the cosine of each unique latitude (converted to radians) as its weight.
38+
self.weights = torch.tensor(
39+
[np.cos(lat * np.pi / 180.0) for lat in unique_lats], dtype=torch.float
40+
)
3941
self.normalize = normalize
4042
assert not torch.isnan(self.weights).any()
4143

@@ -67,8 +69,24 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor):
6769
assert not torch.isnan(out).any()
6870
# Mean of the physical variables
6971
out = out.mean(-1)
70-
print(out.shape)
71-
# Weight by the latitude, as that changes, so does the size of the pixel
72-
out = out * self.weights.expand_as(out)
72+
73+
# Flatten all dimensions except the batch dimension.
74+
B, *dims = out.shape
75+
num_nodes = np.prod(
76+
dims
77+
) # Total number of grid nodes (e.g., if grid is HxW, then num_nodes = H*W)
78+
out = out.view(B, num_nodes)
79+
80+
# Determine the number of unique latitude weights and infer the number of grid columns.
81+
num_unique = self.weights.shape[0] # e.g., number of unique latitudes (rows)
82+
num_lon = num_nodes // num_unique # e.g. if 2592 nodes and 36 unique lat, then num_lon=72
83+
84+
# Tile the unique latitude weights into a full weight grid
85+
weight_grid = self.weights.unsqueeze(1).expand(num_unique, num_lon).reshape(1, num_nodes)
86+
weight_grid = weight_grid.expand(B, num_nodes) # Now weight_grid is [B, num_nodes]
87+
88+
# Multiply the per-node error by the corresponding weight.
89+
out = out * weight_grid
90+
7391
assert not torch.isnan(out).any()
7492
return out.mean()

0 commit comments

Comments
 (0)