Commit ab992fd

Implementation of other NNJA sensors (#134)
* nnjai support
* ruff format
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* changes as per requested
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* test_nnjai.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* NNJAI with different tensors
* Refactor: Updated naming for consistency and clarity
* Refactor: data/__init__.py
* pytest fixture implementation
* nnjai_wrapp removal
* removal of nnjai_wrapp.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 251f4db commit ab992fd

4 files changed: +197 −111 lines changed

graph_weather/__init__.py

File mode changed from 100755 to 100644.
Lines changed: 1 addition & 1 deletion

@@ -1,5 +1,5 @@
 """Main import for the complete models"""

-from .data.nnja_ai import AMSUDataset, collate_fn
+from .data.nnja_ai import SensorDataset, collate_fn
 from .models.analysis import GraphWeatherAssimilator
 from .models.forecast import GraphWeatherForecaster

graph_weather/data/__init__.py

File mode changed from 100755 to 100644.
Lines changed: 1 addition & 1 deletion

@@ -1,3 +1,3 @@
 """Dataloaders and data processing utilities"""

-from .nnja_ai import AMSUDataset, collate_fn
+from .nnja_ai import SensorDataset, collate_fn

graph_weather/data/nnja_ai.py

Lines changed: 41 additions & 23 deletions

@@ -1,7 +1,6 @@
 """
-A custom PyTorch Dataset implementation for AMSU datasets.
+A custom PyTorch Dataset implementation for various sensors like AMSU, ATMS, MHS, IASI, CrIS

-This script defines a custom PyTorch Dataset (`AMSUDataset`) for working with AMSU datasets.
 The dataset is loaded via the nnja library's `DataCatalog` and filtered for specific times and
 variables. Each data point consists of a timestamp, latitude, longitude, and associated metadata.
 """
@@ -18,36 +17,62 @@
 )


-class AMSUDataset(Dataset):
-    """A custom PyTorch Dataset for handling AMSU data.
+class SensorDataset(Dataset):
+    """A custom PyTorch Dataset for handling various sensor data."""

-    This dataset retrieves observations and their metadata, filtered by the provided time and
-    variable descriptors.
-    """
-
-    def __init__(self, dataset_name, time, primary_descriptors, additional_variables):
-        """Initialize the AMSU dataset loader.
+    def __init__(
+        self, dataset_name, time, primary_descriptors, additional_variables, sensor_type="AMSU"
+    ):
+        """Initialize the dataset loader for various sensors.

         Args:
             dataset_name: Name of the dataset to load.
             time: Specific timestamp to filter the data.
-            primary_descriptors: List of primary descriptor variables to include (e.g.,
-                OBS_TIMESTAMP, LAT, LON).
+            primary_descriptors: List of primary descriptor variables to include (e.g., OBS_TIMESTAMP, LAT, LON).
             additional_variables: List of additional variables to include in metadata.
+            sensor_type: Type of sensor (AMSU, ATMS, MHS, IASI, CrIS)
         """
         self.dataset_name = dataset_name
         self.time = time
         self.primary_descriptors = primary_descriptors
         self.additional_variables = additional_variables
+        self.sensor_type = sensor_type  # New argument for selecting sensor type

         # Load data catalog and dataset
         self.catalog = DataCatalog(skip_manifest=True)
         self.dataset = self.catalog[self.dataset_name]
         self.dataset.load_manifest()

-        self.dataset = self.dataset.sel(
-            time=self.time, variables=self.primary_descriptors + self.additional_variables
-        )
+        if self.sensor_type == "AMSU":
+            self.dataset = self.dataset.sel(
+                time=self.time,
+                variables=self.primary_descriptors + [f"TMBR_000{i:02d}" for i in range(1, 16)],
+            )
+        elif self.sensor_type == "ATMS":
+            self.dataset = self.dataset.sel(
+                time=self.time,
+                variables=self.primary_descriptors + [f"TMBR_000{i:02d}" for i in range(1, 23)],
+            )
+        elif self.sensor_type == "MHS":
+            self.dataset = self.dataset.sel(
+                time=self.time,
+                variables=self.primary_descriptors + [f"TMBR_000{i:02d}" for i in range(1, 6)],
+            )
+        elif self.sensor_type == "IASI":
+            self.dataset = self.dataset.sel(
+                time=self.time,
+                variables=self.primary_descriptors
+                + ["SCRA_" + str(i).zfill(5) for i in range(1, 617)],
+            )
+        elif self.sensor_type == "CrIS":
+            self.dataset = self.dataset.sel(
+                time=self.time,
+                variables=self.primary_descriptors
+                + [f"SRAD01_{str(i).zfill(5)}" for i in range(1, 432)],
+            )
+        else:
+            raise ValueError(f"Unsupported sensor type: {self.sensor_type}")
+
         self.dataframe = self.dataset.load_dataset(engine="pandas")

         for col in primary_descriptors:
@@ -63,14 +88,7 @@ def __len__(self):
         return len(self.dataframe)

     def __getitem__(self, index):
-        """Return the observation and metadata for a given index.
-
-        Args:
-            index: Index of the observation to retrieve.
-
-        Returns:
-            A dictionary containing timestamp, latitude, longitude, and metadata.
-        """
+        """Return the observation and metadata for a given index."""
         row = self.dataframe.iloc[index]
         time = row["OBS_TIMESTAMP"].timestamp()
         latitude = row["LAT"]
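
For context, a minimal usage sketch of the new API (not part of this commit; it assumes the nnja DataCatalog is installed and that the ATMS dataset name used in the tests below is available in your environment):

# Hypothetical usage sketch: load ATMS observations with the new sensor_type
# argument and batch them with the module's collate_fn.
from datetime import datetime

from torch.utils.data import DataLoader

from graph_weather.data.nnja_ai import SensorDataset, collate_fn

dataset = SensorDataset(
    dataset_name="atms-atms-NC021203",  # ATMS dataset name from the tests below
    time=datetime(2021, 1, 1, 0, 0),  # timestamp filter passed through to Dataset.sel
    primary_descriptors=["OBS_TIMESTAMP", "LAT", "LON"],
    additional_variables=[f"TMBR_000{i:02d}" for i in range(1, 23)],  # 22 ATMS channels
    sensor_type="ATMS",
)

loader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn)
batch = next(iter(loader))
print(batch["metadata"].shape)  # expected: torch.Size([4, 22])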

tests/test_nnjai.py

File mode changed from 100755 to 100644.
Lines changed: 154 additions & 86 deletions

@@ -1,111 +1,185 @@
 """
-Tests for the nnjai_wrapp module in the graph_weather package.
-
-This file contains unit tests for AMSUDataset and collate_fn functions.
+Unit tests for the `SensorDataset` class, mocking the `DataCatalog` to simulate sensor data loading and validate dataset behavior.
+The tests ensure correct handling of data types, shapes, and batch processing for various sensor types.
 """

 from datetime import datetime
 from unittest.mock import MagicMock, patch
-
+import numpy as np
 import pytest
 import torch
+import pandas as pd
+
+from graph_weather.data.nnja_ai import SensorDataset, collate_fn

-from graph_weather.data.nnja_ai import AMSUDataset, collate_fn
+
+def get_sensor_variables(sensor_type):
+    """Helper function to get the correct variables for each sensor type."""
+    if sensor_type == "AMSU":
+        return [f"TMBR_000{i:02d}" for i in range(1, 16)]  # 15 channels
+    elif sensor_type == "ATMS":
+        return [f"TMBR_000{i:02d}" for i in range(1, 23)]  # 22 channels
+    elif sensor_type == "MHS":
+        return [f"TMBR_000{i:02d}" for i in range(1, 6)]  # 5 channels
+    elif sensor_type == "IASI":
+        return [f"SCRA_{str(i).zfill(5)}" for i in range(1, 617)]  # 616 channels
+    elif sensor_type == "CrIS":
+        return [f"SRAD01_{str(i).zfill(5)}" for i in range(1, 432)]  # 431 channels
+    return []


-# Mock the DataCatalog to avoid actual data loading
 @pytest.fixture
 def mock_datacatalog():
     """
     Fixture to mock the DataCatalog for unit tests to avoid actual data loading.
-
-    This mock provides a mock dataset with predefined columns and values.
     """
-    with patch("graph_weather.data.nnjai_wrapp.DataCatalog") as mock:
-        # Mock dataset structure
-        mock_df = MagicMock()
-        mock_df.columns = ["OBS_TIMESTAMP", "LAT", "LON", "TMBR_00001", "TMBR_00002"]
-
-        # Define a mock row
-        class MockRow:
-            def __getitem__(self, key):
-                data = {
-                    "OBS_TIMESTAMP": datetime.now(),
-                    "LAT": 45.0,
-                    "LON": -120.0,
-                    "TMBR_00001": 250.0,
-                    "TMBR_00002": 260.0,
-                }
-                return data.get(key, None)
-
-        # Configure mock dataset
-        mock_row = MockRow()
-        mock_df.iloc = MagicMock()
-        mock_df.iloc.__getitem__.return_value = mock_row
-        mock_df.__len__.return_value = 100
+    with patch("graph_weather.data.nnja_ai.DataCatalog") as mock:
+        # Create a mock catalog
+        mock_catalog = MagicMock()

+        # Create a mock dataset with direct DataFrame return
         mock_dataset = MagicMock()
-        mock_dataset.load_dataset.return_value = mock_df
-        mock_dataset.sel.return_value = mock_dataset
         mock_dataset.load_manifest = MagicMock()
+        mock_dataset.sel = MagicMock(return_value=mock_dataset)  # Return self to chain calls
+
+        def create_mock_df(engine="pandas"):
+            # Get the sensor type from the mock dataset
+            sensor_vars = get_sensor_variables(mock_dataset.sensor_type)
+
+            # Create DataFrame with required columns
+            df = pd.DataFrame(
+                {
+                    "OBS_TIMESTAMP": pd.date_range(
+                        start=datetime(2021, 1, 1), periods=100, freq="H"
+                    ),
+                    "LAT": np.full(100, 45.0),
+                    "LON": np.full(100, -120.0),
+                }
+            )

-        mock.return_value.__getitem__.return_value = mock_dataset
-        yield mock
+            # Add sensor-specific variables
+            for var in sensor_vars:
+                df[var] = np.full(100, 250.0)

+            return df

-def test_amsu_dataset(mock_datacatalog):
-    """
-    Test the AMSUDataset class to ensure proper data loading and tensor structure.
+        # Set up the mock to return our DataFrame
+        mock_dataset.load_dataset = create_mock_df

-    This test validates the AMSUDataset class for its ability to load the dataset
-    correctly, check for the appropriate tensor properties, and ensure the keys
-    and data types match expectations.
-    """
-    # Initialize dataset parameters
-    dataset_name = "amsua-1bamua-NC021023"
-    time = "2021-01-01 00Z"
+        # Configure the catalog to return our mock dataset
+        def get_mock_dataset(self, name):
+            # Set the sensor type based on the requested dataset name
+            mock_dataset.sensor_type = next(
+                config["sensor_type"] for config in SENSOR_CONFIGS if config["name"] == name
+            )
+            return mock_dataset
+
+        mock_catalog.__getitem__ = get_mock_dataset  # Fix: Explicitly define the method with `self`
+        mock.return_value = mock_catalog
+
+        yield mock
+
+
+# Test configurations
+SENSOR_CONFIGS = [
+    {
+        "name": "amsu-1bamua-NC021023",
+        "sensor_type": "AMSU",
+        "expected_metadata_size": 15,  # 15 TMBR channels
+    },
+    {
+        "name": "atms-atms-NC021203",
+        "sensor_type": "ATMS",
+        "expected_metadata_size": 22,  # 22 TMBR channels
+    },
+    {
+        "name": "mhs-1bmhs-NC021027",
+        "sensor_type": "MHS",
+        "expected_metadata_size": 5,  # 5 TMBR channels
+    },
+    {
+        "name": "iasi-mtiasi-NC021241",
+        "sensor_type": "IASI",
+        "expected_metadata_size": 616,  # 616 SCRA channels
+    },
+    {
+        "name": "cris-crisf4-NC021206",
+        "sensor_type": "CrIS",
+        "expected_metadata_size": 431,  # 431 SRAD channels
+    },
+]
+
+
+@pytest.mark.parametrize("sensor_config", SENSOR_CONFIGS)
+def test_sensor_dataset(mock_datacatalog, sensor_config):
+    """Test the SensorDataset class for different sensor types."""
+    time = datetime(2021, 1, 1, 0, 0)
     primary_descriptors = ["OBS_TIMESTAMP", "LAT", "LON"]
-    additional_variables = ["TMBR_00001", "TMBR_00002"]

-    dataset = AMSUDataset(dataset_name, time, primary_descriptors, additional_variables)
+    dataset = SensorDataset(
+        dataset_name=sensor_config["name"],
+        time=time,
+        primary_descriptors=primary_descriptors,
+        additional_variables=get_sensor_variables(sensor_config["sensor_type"]),
+        sensor_type=sensor_config["sensor_type"],
+    )

     # Test dataset length
-    assert len(dataset) > 0, "Dataset should not be empty."
+    assert len(dataset) > 0, f"Dataset should not be empty for {sensor_config['sensor_type']}"

+    # Test single item structure
     item = dataset[0]
     expected_keys = {"timestamp", "latitude", "longitude", "metadata"}
-    assert set(item.keys()) == expected_keys, "Dataset item keys are not as expected."
+    assert (
+        set(item.keys()) == expected_keys
+    ), f"Dataset item keys are not as expected for {sensor_config['sensor_type']}"

     # Validate tensor properties
-    assert isinstance(item["timestamp"], torch.Tensor), "Timestamp should be a tensor."
-    assert item["timestamp"].dtype == torch.float32, "Timestamp should have dtype float32."
-    assert item["timestamp"].ndim == 0, "Timestamp should be a scalar tensor."
-
-    assert isinstance(item["latitude"], torch.Tensor), "Latitude should be a tensor."
-    assert item["latitude"].dtype == torch.float32, "Latitude should have dtype float32."
-    assert item["latitude"].ndim == 0, "Latitude should be a scalar tensor."
-
-    assert isinstance(item["longitude"], torch.Tensor), "Longitude should be a tensor."
-    assert item["longitude"].dtype == torch.float32, "Longitude should have dtype float32."
-    assert item["longitude"].ndim == 0, "Longitude should be a scalar tensor."
-
-    assert isinstance(item["metadata"], torch.Tensor), "Metadata should be a tensor."
+    assert isinstance(
+        item["timestamp"], torch.Tensor
+    ), f"Timestamp should be a tensor for {sensor_config['sensor_type']}"
+    assert (
+        item["timestamp"].dtype == torch.float32
+    ), f"Timestamp should have dtype float32 for {sensor_config['sensor_type']}"
+    assert (
+        item["timestamp"].ndim == 0
+    ), f"Timestamp should be a scalar tensor for {sensor_config['sensor_type']}"
+
+    assert isinstance(
+        item["latitude"], torch.Tensor
+    ), f"Latitude should be a tensor for {sensor_config['sensor_type']}"
+    assert (
+        item["latitude"].dtype == torch.float32
+    ), f"Latitude should have dtype float32 for {sensor_config['sensor_type']}"
+    assert (
+        item["latitude"].ndim == 0
+    ), f"Latitude should be a scalar tensor for {sensor_config['sensor_type']}"
+
+    assert isinstance(
+        item["longitude"], torch.Tensor
+    ), f"Longitude should be a tensor for {sensor_config['sensor_type']}"
+    assert (
+        item["longitude"].dtype == torch.float32
+    ), f"Longitude should have dtype float32 for {sensor_config['sensor_type']}"
+    assert (
+        item["longitude"].ndim == 0
+    ), f"Longitude should be a scalar tensor for {sensor_config['sensor_type']}"
+
+    assert isinstance(
+        item["metadata"], torch.Tensor
+    ), f"Metadata should be a tensor for {sensor_config['sensor_type']}"
     assert item["metadata"].shape == (
-        len(additional_variables),
-    ), f"Metadata shape mismatch. Expected ({len(additional_variables)},)."
-    assert item["metadata"].dtype == torch.float32, "Metadata should have dtype float32."
+        sensor_config["expected_metadata_size"],
+    ), f"Metadata shape mismatch for {sensor_config['sensor_type']}. Expected ({sensor_config['expected_metadata_size']},)"
+    assert (
+        item["metadata"].dtype == torch.float32
+    ), f"Metadata should have dtype float32 for {sensor_config['sensor_type']}"


 def test_collate_function():
-    """
-    Test the collate_fn function to ensure proper batching of dataset items.
-
-    This test checks that the collate_fn properly batches the timestamp, latitude,
-    longitude, and metadata fields of the dataset, ensuring correct shapes and data types.
-    """
-    # Mock a batch of items
+    """Test the collate_fn function to ensure proper batching of dataset items."""
     batch_size = 4
-    metadata_size = 2
+    metadata_size = 15  # Using AMSU size for this test
     mock_batch = [
         {
             "timestamp": torch.tensor(datetime.now().timestamp(), dtype=torch.float32),
@@ -116,19 +190,13 @@ def test_collate_function():
         for _ in range(batch_size)
     ]

-    # Collate the batch
     batched = collate_fn(mock_batch)

-    # Validate batched shapes and types
-    assert batched["timestamp"].shape == (batch_size,), "Timestamp batch shape mismatch."
-    assert batched["latitude"].shape == (batch_size,), "Latitude batch shape mismatch."
-    assert batched["longitude"].shape == (batch_size,), "Longitude batch shape mismatch."
-    assert batched["metadata"].shape == (
-        batch_size,
-        metadata_size,
-    ), "Metadata batch shape mismatch."
-
-    assert batched["timestamp"].dtype == torch.float32, "Timestamp dtype mismatch."
-    assert batched["latitude"].dtype == torch.float32, "Latitude dtype mismatch."
-    assert batched["longitude"].dtype == torch.float32, "Longitude dtype mismatch."
-    assert batched["metadata"].dtype == torch.float32, "Metadata dtype mismatch."
+    assert batched["timestamp"].shape == (batch_size,), "Timestamp batch shape mismatch"
+    assert batched["latitude"].shape == (batch_size,), "Latitude batch shape mismatch"
+    assert batched["longitude"].shape == (batch_size,), "Longitude batch shape mismatch"
+    assert batched["metadata"].shape == (batch_size, metadata_size), "Metadata batch shape mismatch"
+    assert batched["timestamp"].dtype == torch.float32, "Timestamp dtype mismatch"
+    assert batched["latitude"].dtype == torch.float32, "Latitude dtype mismatch"
+    assert batched["longitude"].dtype == torch.float32, "Longitude dtype mismatch"
+    assert batched["metadata"].dtype == torch.float32, "Metadata dtype mismatch"
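
One design note on the diff above: the per-sensor channel lists now live in two places, the if/elif chain in SensorDataset.__init__ and the get_sensor_variables helper in the tests. A table-driven mapping that both could share is sketched below; this is a suggestion only, not part of the commit, and it simply reproduces the channel lists shown in the diff:

# Sketch: one shared mapping from sensor type to channel variables, mirroring
# the branches in SensorDataset.__init__ (names here are illustrative only).
SENSOR_VARIABLES = {
    "AMSU": [f"TMBR_000{i:02d}" for i in range(1, 16)],  # 15 channels
    "ATMS": [f"TMBR_000{i:02d}" for i in range(1, 23)],  # 22 channels
    "MHS": [f"TMBR_000{i:02d}" for i in range(1, 6)],  # 5 channels
    "IASI": [f"SCRA_{str(i).zfill(5)}" for i in range(1, 617)],  # 616 channels
    "CrIS": [f"SRAD01_{str(i).zfill(5)}" for i in range(1, 432)],  # 431 channels
}


def select_variables(primary_descriptors, sensor_type):
    """Return the full variable list for a sensor, as SensorDataset.__init__ does."""
    if sensor_type not in SENSOR_VARIABLES:
        raise ValueError(f"Unsupported sensor type: {sensor_type}")
    return primary_descriptors + SENSOR_VARIABLES[sensor_type]


print(len(select_variables(["OBS_TIMESTAMP", "LAT", "LON"], "ATMS")))  # 25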
