Skip to content

Commit a393f69

Browse files
Aurora model implementation (#136)
* aurora model implementation * Temp commit for switching branches * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * a new implementation of Aurora (3d+Unstructured) * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * checkpointing configuration with smaller nums * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * EINOPS and processor configuration added * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * EINOPS implementation in encoder and processor for clear reshaping * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * checkpointing * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 01b9688 commit a393f69

File tree

8 files changed

+1164
-0
lines changed

8 files changed

+1164
-0
lines changed
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
"""
2+
Aurora: A Foundation Model for Earth System Science
3+
- Combines 3D Swin Transformer encoding
4+
- Perceiver processing for efficient computation
5+
- 3D decoding for spatial-temporal predictions
6+
"""
7+
8+
from .decoder import Decoder3D
9+
from .encoder import Swin3DEncoder
10+
from .model import AuroraModel, EarthSystemLoss
11+
from .processor import PerceiverProcessor
12+
13+
__version__ = "0.1.0"
14+
15+
__all__ = [
16+
"AuroraModel",
17+
"EarthSystemLoss",
18+
"Swin3DEncoder",
19+
"Decoder3D",
20+
"PerceiverProcessor",
21+
]
22+
23+
# Default configurations for different model sizes
24+
MODEL_CONFIGS = {
25+
"tiny": {
26+
"in_channels": 1,
27+
"out_channels": 1,
28+
"embed_dim": 48,
29+
"latent_dim": 256,
30+
"spatial_shape": (16, 16, 16),
31+
"max_seq_len": 2048,
32+
},
33+
"base": {
34+
"in_channels": 1,
35+
"out_channels": 1,
36+
"embed_dim": 96,
37+
"latent_dim": 512,
38+
"spatial_shape": (32, 32, 32),
39+
"max_seq_len": 4096,
40+
},
41+
"large": {
42+
"in_channels": 1,
43+
"out_channels": 1,
44+
"embed_dim": 192,
45+
"latent_dim": 1024,
46+
"spatial_shape": (64, 64, 64),
47+
"max_seq_len": 8192,
48+
},
49+
}
50+
51+
52+
def create_model(config="base", **kwargs):
53+
"""
54+
Create an Aurora model with specified configuration.
55+
56+
Args:
57+
config (str): Model size configuration ('tiny', 'base', or 'large')
58+
**kwargs: Override default configuration parameters
59+
60+
Returns:
61+
AuroraModel: Initialized model with specified configuration
62+
"""
63+
if config not in MODEL_CONFIGS:
64+
raise ValueError(
65+
f"Unknown configuration: {config}. Choose from {list(MODEL_CONFIGS.keys())}"
66+
)
67+
68+
# Start with default config and update with any provided kwargs
69+
model_config = MODEL_CONFIGS[config].copy()
70+
model_config.update(kwargs)
71+
72+
return AuroraModel(**model_config)
73+
74+
75+
def create_loss(alpha=0.5, beta=0.3, gamma=0.2):
76+
"""
77+
Create an EarthSystemLoss instance with specified weights.
78+
79+
Args:
80+
alpha (float): Weight for MSE loss
81+
beta (float): Weight for gradient loss
82+
gamma (float): Weight for physical consistency loss
83+
84+
Returns:
85+
EarthSystemLoss: Initialized loss function
86+
"""
87+
return EarthSystemLoss(alpha=alpha, beta=beta, gamma=gamma)
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
"""
2+
3D Decoder:
3+
- Takes processed latent representations and reconstructs output.
4+
- Uses transposed convolution to upscale back to spatial-temporal format.
5+
"""
6+
7+
import torch.nn as nn
8+
9+
10+
class Decoder3D(nn.Module):
11+
"""
12+
3D Decoder:
13+
- Takes processed latent representations and reconstructs the spatial-temporal output.
14+
- Uses transposed convolutions to upscale latent features to the original format.
15+
"""
16+
17+
def __init__(self, output_channels=1, embed_dim=96, target_shape=(32, 32, 32)):
18+
"""
19+
Args:
20+
output_channels (int): Number of channels in the output tensor (e.g., 1 for grayscale).
21+
embed_dim (int): Dimension of the latent features (matches the encoder's output).
22+
target_shape (tuple): The desired shape of the reconstructed 3D tensor (D, H, W).
23+
"""
24+
super().__init__()
25+
self.embed_dim = embed_dim
26+
self.target_shape = target_shape
27+
self.deconv1 = nn.ConvTranspose3d(
28+
embed_dim, output_channels, kernel_size=3, padding=1, stride=1
29+
)
30+
31+
def forward(self, x):
32+
"""
33+
Forward pass for the decoder.
34+
35+
Args:
36+
x (torch.Tensor): Input latent representation, shape (batch, seq_len, embed_dim).
37+
38+
Returns:
39+
torch.Tensor: Reconstructed 3D tensor, shape (batch, output_channels, *target_shape).
40+
"""
41+
batch_size = x.shape[0]
42+
depth, height, width = self.target_shape
43+
# Reshape latent features into 3D tensor
44+
x = x.view(batch_size, self.embed_dim, depth, height, width)
45+
# Transposed convolution to upscale to the final shape
46+
x = self.deconv1(x)
47+
return x
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
"""
2+
Swin 3D Transformer Encoder:
3+
- Uses a 3D convolution for initial feature extraction.
4+
- Applies layer normalization and reshapes data.
5+
- Uses a transformer-based encoder to learn spatial-temporal features.
6+
"""
7+
8+
import torch.nn as nn
9+
from einops import rearrange
10+
from einops.layers.torch import Rearrange
11+
12+
13+
class Swin3DEncoder(nn.Module):
14+
def __init__(self, in_channels=1, embed_dim=96):
15+
super().__init__()
16+
self.conv1 = nn.Conv3d(in_channels, embed_dim, kernel_size=3, padding=1, stride=1)
17+
self.norm = nn.LayerNorm(embed_dim)
18+
self.swin_transformer = nn.Transformer(
19+
d_model=embed_dim,
20+
nhead=8,
21+
num_encoder_layers=4,
22+
num_decoder_layers=4,
23+
dim_feedforward=embed_dim * 4,
24+
)
25+
self.embed_dim = embed_dim
26+
27+
# Define rearrangement patterns using einops
28+
self.to_transformer_format = Rearrange("b d h w c -> (d h w) b c")
29+
self.from_transformer_format = Rearrange("(d h w) b c -> b d h w c", d=None, h=None, w=None)
30+
31+
# To use rearrange function directly instead of the Rearrange layer
32+
def forward(self, x):
33+
# 3D convolution with einops rearrangement
34+
x = self.conv1(x)
35+
36+
# Rearrange for normalization using einops
37+
x = rearrange(x, "b c d h w -> b d h w c")
38+
x = self.norm(x)
39+
40+
# Store spatial dimensions for later reconstruction
41+
d, h, w = x.shape[1:4]
42+
43+
# Transform to sequence format for transformer
44+
x = rearrange(x, "b d h w c -> (d h w) b c")
45+
x = self.swin_transformer.encoder(x)
46+
47+
# Restore original spatial structure
48+
x = rearrange(x, "(d h w) b c -> b (d h w) c", d=d, h=h, w=w)
49+
50+
# Reshape to the expected output format (batch, seq_len, embed_dim)
51+
x = rearrange(x, "b (d h w) c -> b (d h w) c", d=d, h=h, w=w)
52+
53+
return x
54+
55+
def convolution(self, x):
56+
"""Apply 3D convolution with clear shape transformation."""
57+
return self.conv1(x) # b c d h w -> b embed_dim d h w
58+
59+
def normalization_layer(self, x):
60+
"""Apply layer normalization with einops rearrangement."""
61+
x = rearrange(x, "b c d h w -> b d h w c")
62+
return self.norm(x)
63+
64+
def transformer_encoder(self, x, spatial_dims):
65+
"""
66+
Apply transformer encoding with proper shape handling.
67+
68+
Args:
69+
x (torch.Tensor): Input tensor
70+
spatial_dims (tuple): Original (depth, height, width) dimensions
71+
"""
72+
d, h, w = spatial_dims
73+
x = self.to_transformer_format(x)
74+
x = self.swin_transformer.encoder(x)
75+
x = self.from_transformer_format(x, d=d, h=h, w=w)
76+
return x

0 commit comments

Comments
 (0)