"""
- Tests for the nnjai_wrapp module in the graph_weather package.
-
- This file contains unit tests for AMSUDataset and collate_fn functions.
+ Unit tests for the `SensorDataset` class, mocking the `DataCatalog` to simulate sensor data loading and validate dataset behavior.
+ The tests ensure correct handling of data types, shapes, and batch processing for various sensor types.
"""

from datetime import datetime
from unittest.mock import MagicMock, patch
-
+ import numpy as np
import pytest
import torch
+ import pandas as pd
+
+ from graph_weather.data.nnja_ai import SensorDataset, collate_fn

- from graph_weather.data.nnja_ai import AMSUDataset, collate_fn
+
+ def get_sensor_variables(sensor_type):
+     """Helper function to get the correct variables for each sensor type."""
+     if sensor_type == "AMSU":
+         return [f"TMBR_000{i:02d}" for i in range(1, 16)]  # 15 channels
+     elif sensor_type == "ATMS":
+         return [f"TMBR_000{i:02d}" for i in range(1, 23)]  # 22 channels
+     elif sensor_type == "MHS":
+         return [f"TMBR_000{i:02d}" for i in range(1, 6)]  # 5 channels
+     elif sensor_type == "IASI":
+         return [f"SCRA_{str(i).zfill(5)}" for i in range(1, 617)]  # 616 channels
+     elif sensor_type == "CrIS":
+         return [f"SRAD01_{str(i).zfill(5)}" for i in range(1, 432)]  # 431 channels
+     return []
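
For reference, the channel names this helper emits can be read straight off the f-strings above; e.g. the MHS branch:

    >>> get_sensor_variables("MHS")
    ['TMBR_00001', 'TMBR_00002', 'TMBR_00003', 'TMBR_00004', 'TMBR_00005']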


- # Mock the DataCatalog to avoid actual data loading
@pytest.fixture
def mock_datacatalog():
    """
    Fixture to mock the DataCatalog for unit tests to avoid actual data loading.
-
-     This mock provides a mock dataset with predefined columns and values.
    """
-     with patch("graph_weather.data.nnjai_wrapp.DataCatalog") as mock:
-         # Mock dataset structure
-         mock_df = MagicMock()
-         mock_df.columns = ["OBS_TIMESTAMP", "LAT", "LON", "TMBR_00001", "TMBR_00002"]
-
-         # Define a mock row
-         class MockRow:
-             def __getitem__(self, key):
-                 data = {
-                     "OBS_TIMESTAMP": datetime.now(),
-                     "LAT": 45.0,
-                     "LON": -120.0,
-                     "TMBR_00001": 250.0,
-                     "TMBR_00002": 260.0,
-                 }
-                 return data.get(key, None)
-
-         # Configure mock dataset
-         mock_row = MockRow()
-         mock_df.iloc = MagicMock()
-         mock_df.iloc.__getitem__.return_value = mock_row
-         mock_df.__len__.return_value = 100
+     with patch("graph_weather.data.nnja_ai.DataCatalog") as mock:
+         # Create a mock catalog
+         mock_catalog = MagicMock()

+         # Create a mock dataset with direct DataFrame return
        mock_dataset = MagicMock()
-         mock_dataset.load_dataset.return_value = mock_df
-         mock_dataset.sel.return_value = mock_dataset
        mock_dataset.load_manifest = MagicMock()
+         mock_dataset.sel = MagicMock(return_value=mock_dataset)  # Return self to chain calls
+
+         def create_mock_df(engine="pandas"):
+             # Get the sensor type from the mock dataset
+             sensor_vars = get_sensor_variables(mock_dataset.sensor_type)
+
+             # Create DataFrame with required columns
+             df = pd.DataFrame(
+                 {
+                     "OBS_TIMESTAMP": pd.date_range(
+                         start=datetime(2021, 1, 1), periods=100, freq="H"
+                     ),
+                     "LAT": np.full(100, 45.0),
+                     "LON": np.full(100, -120.0),
+                 }
+             )

-         mock.return_value.__getitem__.return_value = mock_dataset
-         yield mock
+             # Add sensor-specific variables
+             for var in sensor_vars:
+                 df[var] = np.full(100, 250.0)

+             return df

- def test_amsu_dataset(mock_datacatalog):
-     """
-     Test the AMSUDataset class to ensure proper data loading and tensor structure.
+         # Set up the mock to return our DataFrame
+         mock_dataset.load_dataset = create_mock_df

-     This test validates the AMSUDataset class for its ability to load the dataset
-     correctly, check for the appropriate tensor properties, and ensure the keys
-     and data types match expectations.
-     """
-     # Initialize dataset parameters
-     dataset_name = "amsua-1bamua-NC021023"
-     time = "2021-01-01 00Z"
+         # Configure the catalog to return our mock dataset
+         def get_mock_dataset(self, name):
+             # Set the sensor type based on the requested dataset name
+             mock_dataset.sensor_type = next(
+                 config["sensor_type"] for config in SENSOR_CONFIGS if config["name"] == name
+             )
+             return mock_dataset
+
+         mock_catalog.__getitem__ = get_mock_dataset  # Fix: Explicitly define the method with `self`
+         mock.return_value = mock_catalog
+
+         yield mock
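
A note on the `__getitem__` assignment above: `unittest.mock` requires that a plain function assigned to a magic method take `self` as its first argument (the mock instance is passed in explicitly), which is why `get_mock_dataset(self, name)` is written that way even though it is not a bound method. A minimal standalone illustration with hypothetical names, not part of the PR:

    from unittest.mock import MagicMock

    catalog = MagicMock()

    def fake_getitem(self, key):  # `self` receives the mock instance
        return f"dataset-for-{key}"

    catalog.__getitem__ = fake_getitem
    assert catalog["amsu-1bamua-NC021023"] == "dataset-for-amsu-1bamua-NC021023"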
+
+
+ # Test configurations
+ SENSOR_CONFIGS = [
+     {
+         "name": "amsu-1bamua-NC021023",
+         "sensor_type": "AMSU",
+         "expected_metadata_size": 15,  # 15 TMBR channels
+     },
+     {
+         "name": "atms-atms-NC021203",
+         "sensor_type": "ATMS",
+         "expected_metadata_size": 22,  # 22 TMBR channels
+     },
+     {
+         "name": "mhs-1bmhs-NC021027",
+         "sensor_type": "MHS",
+         "expected_metadata_size": 5,  # 5 TMBR channels
+     },
+     {
+         "name": "iasi-mtiasi-NC021241",
+         "sensor_type": "IASI",
+         "expected_metadata_size": 616,  # 616 SCRA channels
+     },
+     {
+         "name": "cris-crisf4-NC021206",
+         "sensor_type": "CrIS",
+         "expected_metadata_size": 431,  # 431 SRAD channels
+     },
+ ]
+
+
+ @pytest.mark.parametrize("sensor_config", SENSOR_CONFIGS)
+ def test_sensor_dataset(mock_datacatalog, sensor_config):
+     """Test the SensorDataset class for different sensor types."""
+     time = datetime(2021, 1, 1, 0, 0)
    primary_descriptors = ["OBS_TIMESTAMP", "LAT", "LON"]
-     additional_variables = ["TMBR_00001", "TMBR_00002"]

-     dataset = AMSUDataset(dataset_name, time, primary_descriptors, additional_variables)
+     dataset = SensorDataset(
+         dataset_name=sensor_config["name"],
+         time=time,
+         primary_descriptors=primary_descriptors,
+         additional_variables=get_sensor_variables(sensor_config["sensor_type"]),
+         sensor_type=sensor_config["sensor_type"],
+     )

    # Test dataset length
-     assert len(dataset) > 0, "Dataset should not be empty."
+     assert len(dataset) > 0, f"Dataset should not be empty for {sensor_config['sensor_type']}"

+     # Test single item structure
    item = dataset[0]
    expected_keys = {"timestamp", "latitude", "longitude", "metadata"}
-     assert set(item.keys()) == expected_keys, "Dataset item keys are not as expected."
+     assert (
+         set(item.keys()) == expected_keys
+     ), f"Dataset item keys are not as expected for {sensor_config['sensor_type']}"

    # Validate tensor properties
-     assert isinstance(item["timestamp"], torch.Tensor), "Timestamp should be a tensor."
-     assert item["timestamp"].dtype == torch.float32, "Timestamp should have dtype float32."
-     assert item["timestamp"].ndim == 0, "Timestamp should be a scalar tensor."
-
-     assert isinstance(item["latitude"], torch.Tensor), "Latitude should be a tensor."
-     assert item["latitude"].dtype == torch.float32, "Latitude should have dtype float32."
-     assert item["latitude"].ndim == 0, "Latitude should be a scalar tensor."
-
-     assert isinstance(item["longitude"], torch.Tensor), "Longitude should be a tensor."
-     assert item["longitude"].dtype == torch.float32, "Longitude should have dtype float32."
-     assert item["longitude"].ndim == 0, "Longitude should be a scalar tensor."
-
-     assert isinstance(item["metadata"], torch.Tensor), "Metadata should be a tensor."
+     assert isinstance(
+         item["timestamp"], torch.Tensor
+     ), f"Timestamp should be a tensor for {sensor_config['sensor_type']}"
+     assert (
+         item["timestamp"].dtype == torch.float32
+     ), f"Timestamp should have dtype float32 for {sensor_config['sensor_type']}"
+     assert (
+         item["timestamp"].ndim == 0
+     ), f"Timestamp should be a scalar tensor for {sensor_config['sensor_type']}"
+
+     assert isinstance(
+         item["latitude"], torch.Tensor
+     ), f"Latitude should be a tensor for {sensor_config['sensor_type']}"
+     assert (
+         item["latitude"].dtype == torch.float32
+     ), f"Latitude should have dtype float32 for {sensor_config['sensor_type']}"
+     assert (
+         item["latitude"].ndim == 0
+     ), f"Latitude should be a scalar tensor for {sensor_config['sensor_type']}"
+
+     assert isinstance(
+         item["longitude"], torch.Tensor
+     ), f"Longitude should be a tensor for {sensor_config['sensor_type']}"
+     assert (
+         item["longitude"].dtype == torch.float32
+     ), f"Longitude should have dtype float32 for {sensor_config['sensor_type']}"
+     assert (
+         item["longitude"].ndim == 0
+     ), f"Longitude should be a scalar tensor for {sensor_config['sensor_type']}"
+
+     assert isinstance(
+         item["metadata"], torch.Tensor
+     ), f"Metadata should be a tensor for {sensor_config['sensor_type']}"
    assert item["metadata"].shape == (
-         len(additional_variables),
-     ), f"Metadata shape mismatch. Expected ({len(additional_variables)},)."
-     assert item["metadata"].dtype == torch.float32, "Metadata should have dtype float32."
+         sensor_config["expected_metadata_size"],
+     ), f"Metadata shape mismatch for {sensor_config['sensor_type']}. Expected ({sensor_config['expected_metadata_size']},)"
+     assert (
+         item["metadata"].dtype == torch.float32
+     ), f"Metadata should have dtype float32 for {sensor_config['sensor_type']}"


def test_collate_function():
-     """
-     Test the collate_fn function to ensure proper batching of dataset items.
-
-     This test checks that the collate_fn properly batches the timestamp, latitude,
-     longitude, and metadata fields of the dataset, ensuring correct shapes and data types.
-     """
-     # Mock a batch of items
+     """Test the collate_fn function to ensure proper batching of dataset items."""
    batch_size = 4
-     metadata_size = 2
+     metadata_size = 15  # Using AMSU size for this test
    mock_batch = [
        {
            "timestamp": torch.tensor(datetime.now().timestamp(), dtype=torch.float32),
@@ -116,19 +190,13 @@ def test_collate_function():
        for _ in range(batch_size)
    ]

-     # Collate the batch
    batched = collate_fn(mock_batch)

-     # Validate batched shapes and types
-     assert batched["timestamp"].shape == (batch_size,), "Timestamp batch shape mismatch."
-     assert batched["latitude"].shape == (batch_size,), "Latitude batch shape mismatch."
-     assert batched["longitude"].shape == (batch_size,), "Longitude batch shape mismatch."
-     assert batched["metadata"].shape == (
-         batch_size,
-         metadata_size,
-     ), "Metadata batch shape mismatch."
-
-     assert batched["timestamp"].dtype == torch.float32, "Timestamp dtype mismatch."
-     assert batched["latitude"].dtype == torch.float32, "Latitude dtype mismatch."
-     assert batched["longitude"].dtype == torch.float32, "Longitude dtype mismatch."
-     assert batched["metadata"].dtype == torch.float32, "Metadata dtype mismatch."
+     assert batched["timestamp"].shape == (batch_size,), "Timestamp batch shape mismatch"
+     assert batched["latitude"].shape == (batch_size,), "Latitude batch shape mismatch"
+     assert batched["longitude"].shape == (batch_size,), "Longitude batch shape mismatch"
+     assert batched["metadata"].shape == (batch_size, metadata_size), "Metadata batch shape mismatch"
+     assert batched["timestamp"].dtype == torch.float32, "Timestamp dtype mismatch"
+     assert batched["latitude"].dtype == torch.float32, "Latitude dtype mismatch"
+     assert batched["longitude"].dtype == torch.float32, "Longitude dtype mismatch"
+     assert batched["metadata"].dtype == torch.float32, "Metadata dtype mismatch"
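
For context, the round trip these tests exercise looks roughly like the following in application code. This is a minimal sketch, assuming catalog access works as the mocked `DataCatalog` simulates; the `SensorDataset` arguments mirror `test_sensor_dataset` above, and the `DataLoader`/`collate_fn` pairing mirrors `test_collate_function`:

    from torch.utils.data import DataLoader

    # Build an AMSU dataset as test_sensor_dataset does, then batch it with the
    # same collate_fn that test_collate_function validates.
    dataset = SensorDataset(
        dataset_name="amsu-1bamua-NC021023",
        time=datetime(2021, 1, 1, 0, 0),
        primary_descriptors=["OBS_TIMESTAMP", "LAT", "LON"],
        additional_variables=get_sensor_variables("AMSU"),
        sensor_type="AMSU",
    )
    loader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn)
    batch = next(iter(loader))
    print(batch["metadata"].shape)  # expected: torch.Size([4, 15])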