Skip to content
88 changes: 43 additions & 45 deletions tests/config/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,19 @@
from ocf_data_sampler.config import Configuration, load_yaml_configuration


def _load_config_and_provider(config_path):
cfg = load_yaml_configuration(config_path)
provider = next(iter(cfg.input_data.nwp.root.keys()))
return cfg, provider


def _revalidate(cfg):
return Configuration(**cfg.model_dump())


def test_default_configuration(test_config_gsp_path):
"""Test default pydantic class"""
_ = load_yaml_configuration(test_config_gsp_path)
load_yaml_configuration(test_config_gsp_path)


def test_extra_field_error(test_config_gsp_path):
Expand All @@ -17,125 +27,114 @@ def test_extra_field_error(test_config_gsp_path):
configuration_dict = configuration.model_dump()
configuration_dict["extra_field"] = "extra_value"
with pytest.raises(ValidationError, match="Extra inputs are not permitted"):
_ = Configuration(**configuration_dict)
Configuration(**configuration_dict)


def test_incorrect_interval_start_minutes(test_config_filename):
"""
Check a history length not divisible by time resolution causes error
"""

configuration = load_yaml_configuration(test_config_filename)

configuration.input_data.nwp["ukv"].interval_start_minutes = -1111
configuration, provider = _load_config_and_provider(test_config_filename)
configuration.input_data.nwp[provider].interval_start_minutes = -1111
with pytest.raises(
ValueError,
match=r"interval_start_minutes \(-1111\) "
r"must be divisible by time_resolution_minutes \(60\)",
):
_ = Configuration(**configuration.model_dump())
_revalidate(configuration)


def test_incorrect_interval_end_minutes(test_config_filename):
"""
Check a forecast length not divisible by time resolution causes error
"""

configuration = load_yaml_configuration(test_config_filename)

configuration.input_data.nwp["ukv"].interval_end_minutes = 1111
configuration, provider = _load_config_and_provider(test_config_filename)
configuration.input_data.nwp[provider].interval_end_minutes = 1111
with pytest.raises(
ValueError,
match=r"interval_end_minutes \(1111\) "
r"must be divisible by time_resolution_minutes \(60\)",
):
_ = Configuration(**configuration.model_dump())
_revalidate(configuration)


def test_incorrect_nwp_provider(test_config_filename):
"""
Check an unexpected nwp provider causes error
"""

configuration = load_yaml_configuration(test_config_filename)

configuration.input_data.nwp["ukv"].provider = "unexpected_provider"
configuration, provider = _load_config_and_provider(test_config_filename)
configuration.input_data.nwp[provider].provider = "unexpected_provider"
with pytest.raises(Exception, match="NWP provider"):
_ = Configuration(**configuration.model_dump())
_revalidate(configuration)


def test_incorrect_dropout(test_config_filename):
"""
Check a dropout timedelta over 0 causes error and 0 doesn't
"""
configuration, provider = _load_config_and_provider(test_config_filename)

configuration = load_yaml_configuration(test_config_filename)

# check a positive number is not allowed
configuration.input_data.nwp["ukv"].dropout_timedeltas_minutes = [120]
# Check that a positive number is not allowed
configuration.input_data.nwp[provider].dropout_timedeltas_minutes = [120]
with pytest.raises(Exception, match="Dropout timedeltas must be negative"):
_ = Configuration(**configuration.model_dump())
_revalidate(configuration)

# check 0 is allowed
configuration.input_data.nwp["ukv"].dropout_timedeltas_minutes = [0]
_ = Configuration(**configuration.model_dump())
# Check that zero is allowed
configuration.input_data.nwp[provider].dropout_timedeltas_minutes = [0]
_revalidate(configuration)


def test_incorrect_dropout_fraction(test_config_filename):
"""
Check dropout fraction outside of range causes error
"""
configuration, provider = _load_config_and_provider(test_config_filename)

configuration = load_yaml_configuration(test_config_filename)

configuration.input_data.nwp["ukv"].dropout_fraction = 1.1

configuration.input_data.nwp[provider].dropout_fraction = 1.1
with pytest.raises(ValidationError, match=r"Dropout fractions must be in range *"):
_ = Configuration(**configuration.model_dump())
_revalidate(configuration)

configuration.input_data.nwp["ukv"].dropout_fraction = -0.1
configuration.input_data.nwp[provider].dropout_fraction = -0.1
with pytest.raises(ValidationError, match=r"Dropout fractions must be in range *"):
_ = Configuration(**configuration.model_dump())
_revalidate(configuration)

configuration.input_data.nwp["ukv"].dropout_fraction = [1.0,0.1]
configuration.input_data.nwp[provider].dropout_fraction = [1.0, 0.1]
with pytest.raises(ValidationError, match=r"The sum of dropout fractions must be in range *"):
_ = Configuration(**configuration.model_dump())
_revalidate(configuration)

configuration.input_data.nwp["ukv"].dropout_fraction = [-0.1,1.1]
configuration.input_data.nwp[provider].dropout_fraction = [-0.1, 1.1]
with pytest.raises(ValidationError, match=r"All dropout fractions must be in range *"):
_ = Configuration(**configuration.model_dump())
_revalidate(configuration)

configuration.input_data.nwp["ukv"].dropout_fraction = []
configuration.input_data.nwp[provider].dropout_fraction = []
with pytest.raises(ValidationError, match="List cannot be empty"):
_ = Configuration(**configuration.model_dump())
_revalidate(configuration)


def test_inconsistent_dropout_use(test_config_filename):
"""
Check dropout fraction outside of range causes error
"""

configuration = load_yaml_configuration(test_config_filename)
configuration.input_data.satellite.dropout_fraction = 1.0
configuration.input_data.satellite.dropout_timedeltas_minutes = []

with pytest.raises(
ValueError,
match="To dropout fraction > 0 requires a list of dropout timedeltas",
):
_ = Configuration(**configuration.model_dump())
_revalidate(configuration)

configuration.input_data.satellite.dropout_fraction = 0.0
configuration.input_data.satellite.dropout_timedeltas_minutes = [-120, -60]
with pytest.raises(
ValueError,
match="To use dropout timedeltas dropout fraction should be > 0",
):
_ = Configuration(**configuration.model_dump())
_revalidate(configuration)


def test_accum_channels_validation(test_config_filename):
"""Test accum_channels validation with required normalization constants."""
# Load valid config (implicitly tests valid case)
config = load_yaml_configuration(test_config_filename)
nwp_name, _ = next(iter(config.input_data.nwp.root.items()))

Expand All @@ -152,7 +151,7 @@ def test_accum_channels_validation(test_config_filename):
r"Extra values found: {'invalid_channel'}.*"
)
with pytest.raises(ValidationError, match=expected_error):
_ = Configuration(**invalid_config.model_dump())
_revalidate(invalid_config)


def test_configuration_requires_site_or_gsp():
Expand All @@ -161,4 +160,3 @@ def test_configuration_requires_site_or_gsp():
"""
with pytest.raises(ValidationError, match="You must provide either `site` or `gsp`"):
Configuration()

3 changes: 1 addition & 2 deletions tests/config/test_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,4 @@


def test_load_yaml_configuration(test_config_filename):
loaded_config = load_yaml_configuration(test_config_filename)
assert isinstance(loaded_config, Configuration)
assert isinstance(load_yaml_configuration(test_config_filename), Configuration)
16 changes: 5 additions & 11 deletions tests/config/test_save.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
"""Tests for configuration saving functionality."""

import os

from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration


def test_save_yaml_configuration_basic(tmp_path, test_config_gsp_path):
"""Save an empty configuration object"""
config = load_yaml_configuration(test_config_gsp_path)

filepath = f"{tmp_path}/config.yaml"
filepath = tmp_path / "config.yaml"
save_yaml_configuration(config, filepath)

assert os.path.exists(filepath)
assert filepath.exists()


def test_save_load_yaml_configuration(tmp_path, test_config_filename):
Expand All @@ -21,10 +18,7 @@ def test_save_load_yaml_configuration(tmp_path, test_config_filename):
# Start with this config
initial_config = load_yaml_configuration(test_config_filename)

# Save it
filepath = f"{tmp_path}/config.yaml"
# Save it - then load and check it is identical
filepath = tmp_path / "config.yaml"
save_yaml_configuration(initial_config, filepath)

# Load it and check it is still the same
loaded_config = load_yaml_configuration(filepath)
assert loaded_config == initial_config
assert load_yaml_configuration(filepath) == initial_config
144 changes: 143 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
import pytest
import xarray as xr

from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration
from ocf_data_sampler.config import Configuration, load_yaml_configuration, save_yaml_configuration
from ocf_data_sampler.config.model import Site, SolarPosition
from ocf_data_sampler.numpy_sample import GSPSampleKey, NWPSampleKey, SatelliteSampleKey
from ocf_data_sampler.torch_datasets.datasets.site import SitesDataset

_top_test_directory = os.path.dirname(os.path.realpath(__file__))
Expand Down Expand Up @@ -466,3 +467,144 @@ def site_config_filename(
@pytest.fixture()
def sites_dataset(site_config_filename):
return SitesDataset(site_config_filename)

@pytest.fixture(scope="module")
def concatable_nwp_like_data(ds_nwp_ecmwf):

# Make a second NWP-like dataset so we can concat them
ds_2 = ds_nwp_ecmwf.copy(deep=True)
ds_2["init_time"] = pd.date_range(
start=ds_nwp_ecmwf.init_time.max().values + pd.Timedelta("6h"),
freq=pd.Timedelta("6h"),
periods=len(ds_nwp_ecmwf.init_time),
)

return ds_nwp_ecmwf, ds_2


@pytest.fixture(scope="module")
def nwp_like_zarr2_paths(session_tmp_path, concatable_nwp_like_data):

data_paths = [
f"{session_tmp_path}/nwp_like_data_{n}.zarr2" for n in range(len(concatable_nwp_like_data))
]

for ds, path in zip(concatable_nwp_like_data, data_paths, strict=False):
ds.to_zarr(path, zarr_format=2)

return data_paths


@pytest.fixture(scope="module")
def nwp_like_zarr3_paths(session_tmp_path, concatable_nwp_like_data):

data_paths = [
f"{session_tmp_path}/nwp_like_data_{n}.zarr3" for n in range(len(concatable_nwp_like_data))
]

for ds, path in zip(concatable_nwp_like_data, data_paths, strict=False):
ds.to_zarr(path, zarr_format=3)

return data_paths


@pytest.fixture(scope="module")
def da_sample():

datetimes = pd.date_range("2024-01-01 12:00", "2024-01-01 13:00", freq="5min")

da_sat = xr.DataArray(
np.random.normal(size=(len(datetimes),)),
coords={"time_utc": (["time_utc"], datetimes)},
)
return da_sat


@pytest.fixture(scope="module")
def da():
# Create dummy data
x = np.arange(-100, 100)
y = np.arange(-100, 100)

da = xr.DataArray(
np.random.normal(size=(len(x), len(y))),
coords={
"x_osgb": (["x_osgb"], x),
"y_osgb": (["y_osgb"], y),
},
)
return da


@pytest.fixture(scope="module")
def da_sat_like():
"""Create dummy data which looks like satellite data"""
x = np.arange(-100, 100)
y = np.arange(-100, 100)
datetimes = pd.date_range("2024-01-02 00:00", "2024-01-03 00:00", freq="5min")

return xr.DataArray(
np.random.normal(size=(len(datetimes), len(x), len(y))),
coords={
"time_utc": datetimes,
"x_geostationary": x,
"y_geostationary": y,
},
)


@pytest.fixture(scope="module")
def da_nwp_like():
"""Create dummy data which looks like NWP data"""
x = np.arange(-100, 100)
y = np.arange(-100, 100)
datetimes = pd.date_range(
"2024-01-02 00:00",
"2024-01-03 00:00",
freq=pd.Timedelta("3h"),
)
steps = pd.timedelta_range("0h", "16h", freq="1h")
channels = ["t", "dswrf"]

return xr.DataArray(
np.random.normal(size=(len(datetimes), len(steps), len(channels), len(x), len(y))),
coords={
"init_time_utc": datetimes,
"step": steps,
"channel": channels,
"x_osgb": x,
"y_osgb": y,
},
)


@pytest.fixture
def numpy_sample():
"""Synthetic data generation"""
expected_gsp_shape = (7,)
expected_nwp_ukv_shape = (4, 1, 2, 2)
expected_sat_shape = (7, 1, 2, 2)
expected_solar_shape = (7,)

nwp_data = {
"nwp": np.random.rand(*expected_nwp_ukv_shape),
"x": np.array([1, 2]),
"y": np.array([1, 2]),
NWPSampleKey.channel_names: ["t"],
}

return {
"nwp": {
"ukv": nwp_data,
},
GSPSampleKey.gsp: np.random.rand(*expected_gsp_shape),
SatelliteSampleKey.satellite_actual: np.random.rand(*expected_sat_shape),
"solar_azimuth": np.random.rand(*expected_solar_shape),
"solar_elevation": np.random.rand(*expected_solar_shape),
}


@pytest.fixture
def pvnet_configuration_object(pvnet_config_filename) -> Configuration:
"""Loads the configuration from the temporary file path."""
return load_yaml_configuration(pvnet_config_filename)
Loading