diff --git a/tests/config/test_config.py b/tests/config/test_config.py index 9b08b4e7..4e2e44b4 100644 --- a/tests/config/test_config.py +++ b/tests/config/test_config.py @@ -4,9 +4,20 @@ from ocf_data_sampler.config import Configuration, load_yaml_configuration +def _load_config_and_provider(config_path): + config = load_yaml_configuration(config_path) + provider = next(iter(config.input_data.nwp.root.keys())) + return config, provider + + +def _validate_configuration(config): + """Recreate config instance from dict to trigger validation.""" + return Configuration(**config.model_dump()) + + def test_default_configuration(test_config_gsp_path): """Test default pydantic class""" - _ = load_yaml_configuration(test_config_gsp_path) + load_yaml_configuration(test_config_gsp_path) def test_extra_field_error(test_config_gsp_path): @@ -17,127 +28,115 @@ def test_extra_field_error(test_config_gsp_path): configuration_dict = configuration.model_dump() configuration_dict["extra_field"] = "extra_value" with pytest.raises(ValidationError, match="Extra inputs are not permitted"): - _ = Configuration(**configuration_dict) + Configuration(**configuration_dict) def test_incorrect_interval_start_minutes(test_config_filename): """ Check a history length not divisible by time resolution causes error """ - - configuration = load_yaml_configuration(test_config_filename) - - configuration.input_data.nwp["ukv"].interval_start_minutes = -1111 + configuration, provider = _load_config_and_provider(test_config_filename) + configuration.input_data.nwp[provider].interval_start_minutes = -1111 with pytest.raises( ValueError, match=r"interval_start_minutes \(-1111\) " r"must be divisible by time_resolution_minutes \(60\)", ): - _ = Configuration(**configuration.model_dump()) + _validate_configuration(configuration) def test_incorrect_interval_end_minutes(test_config_filename): """ Check a forecast length not divisible by time resolution causes error """ - - configuration = 
load_yaml_configuration(test_config_filename) - - configuration.input_data.nwp["ukv"].interval_end_minutes = 1111 + configuration, provider = _load_config_and_provider(test_config_filename) + configuration.input_data.nwp[provider].interval_end_minutes = 1111 with pytest.raises( ValueError, match=r"interval_end_minutes \(1111\) " r"must be divisible by time_resolution_minutes \(60\)", ): - _ = Configuration(**configuration.model_dump()) + _validate_configuration(configuration) def test_incorrect_nwp_provider(test_config_filename): """ Check an unexpected nwp provider causes error """ - - configuration = load_yaml_configuration(test_config_filename) - - configuration.input_data.nwp["ukv"].provider = "unexpected_provider" + configuration, provider = _load_config_and_provider(test_config_filename) + configuration.input_data.nwp[provider].provider = "unexpected_provider" with pytest.raises(Exception, match="NWP provider"): - _ = Configuration(**configuration.model_dump()) + _validate_configuration(configuration) def test_incorrect_dropout(test_config_filename): """ Check a dropout timedelta over 0 causes error and 0 doesn't """ + configuration, provider = _load_config_and_provider(test_config_filename) - configuration = load_yaml_configuration(test_config_filename) - - # check a positive number is not allowed - configuration.input_data.nwp["ukv"].dropout_timedeltas_minutes = [120] + # Check that a positive number is not allowed + configuration.input_data.nwp[provider].dropout_timedeltas_minutes = [120] with pytest.raises(Exception, match="Dropout timedeltas must be negative"): - _ = Configuration(**configuration.model_dump()) + _validate_configuration(configuration) - # check 0 is allowed - configuration.input_data.nwp["ukv"].dropout_timedeltas_minutes = [0] - _ = Configuration(**configuration.model_dump()) + # Check that zero is allowed + configuration.input_data.nwp[provider].dropout_timedeltas_minutes = [0] + _validate_configuration(configuration) def 
test_incorrect_dropout_fraction(test_config_filename): """ Check dropout fraction outside of range causes error """ + configuration, provider = _load_config_and_provider(test_config_filename) - configuration = load_yaml_configuration(test_config_filename) - - configuration.input_data.nwp["ukv"].dropout_fraction = 1.1 - + configuration.input_data.nwp[provider].dropout_fraction = 1.1 with pytest.raises(ValidationError, match=r"Dropout fractions must be in range *"): - _ = Configuration(**configuration.model_dump()) + _validate_configuration(configuration) - configuration.input_data.nwp["ukv"].dropout_fraction = -0.1 + configuration.input_data.nwp[provider].dropout_fraction = -0.1 with pytest.raises(ValidationError, match=r"Dropout fractions must be in range *"): - _ = Configuration(**configuration.model_dump()) + _validate_configuration(configuration) - configuration.input_data.nwp["ukv"].dropout_fraction = [1.0,0.1] + configuration.input_data.nwp[provider].dropout_fraction = [1.0, 0.1] with pytest.raises(ValidationError, match=r"The sum of dropout fractions must be in range *"): - _ = Configuration(**configuration.model_dump()) + _validate_configuration(configuration) - configuration.input_data.nwp["ukv"].dropout_fraction = [-0.1,1.1] + configuration.input_data.nwp[provider].dropout_fraction = [-0.1, 1.1] with pytest.raises(ValidationError, match=r"All dropout fractions must be in range *"): - _ = Configuration(**configuration.model_dump()) + _validate_configuration(configuration) - configuration.input_data.nwp["ukv"].dropout_fraction = [] + configuration.input_data.nwp[provider].dropout_fraction = [] with pytest.raises(ValidationError, match="List cannot be empty"): - _ = Configuration(**configuration.model_dump()) + _validate_configuration(configuration) def test_inconsistent_dropout_use(test_config_filename): """ Check dropout fraction outside of range causes error """ - configuration = load_yaml_configuration(test_config_filename) 
configuration.input_data.satellite.dropout_fraction = 1.0 configuration.input_data.satellite.dropout_timedeltas_minutes = [] - with pytest.raises( ValueError, match="To dropout fraction > 0 requires a list of dropout timedeltas", ): - _ = Configuration(**configuration.model_dump()) + _validate_configuration(configuration) + configuration.input_data.satellite.dropout_fraction = 0.0 configuration.input_data.satellite.dropout_timedeltas_minutes = [-120, -60] with pytest.raises( ValueError, match="To use dropout timedeltas dropout fraction should be > 0", ): - _ = Configuration(**configuration.model_dump()) + _validate_configuration(configuration) def test_accum_channels_validation(test_config_filename): """Test accum_channels validation with required normalization constants.""" - # Load valid config (implicitly tests valid case) - config = load_yaml_configuration(test_config_filename) - nwp_name, _ = next(iter(config.input_data.nwp.root.items())) + config, nwp_name = _load_config_and_provider(test_config_filename) # Test invalid channel scenario invalid_config = config.model_copy(deep=True) @@ -152,7 +151,7 @@ def test_accum_channels_validation(test_config_filename): r"Extra values found: {'invalid_channel'}.*" ) with pytest.raises(ValidationError, match=expected_error): - _ = Configuration(**invalid_config.model_dump()) + _validate_configuration(invalid_config) def test_configuration_requires_site_or_gsp(): @@ -161,4 +160,3 @@ def test_configuration_requires_site_or_gsp(): """ with pytest.raises(ValidationError, match="You must provide either `site` or `gsp`"): Configuration() - diff --git a/tests/config/test_load.py b/tests/config/test_load.py index 46c89f82..cdcbfc3f 100644 --- a/tests/config/test_load.py +++ b/tests/config/test_load.py @@ -2,5 +2,4 @@ def test_load_yaml_configuration(test_config_filename): - loaded_config = load_yaml_configuration(test_config_filename) - assert isinstance(loaded_config, Configuration) + assert 
isinstance(load_yaml_configuration(test_config_filename), Configuration) diff --git a/tests/config/test_save.py b/tests/config/test_save.py index e46c2b45..90ef0d60 100644 --- a/tests/config/test_save.py +++ b/tests/config/test_save.py @@ -1,6 +1,5 @@ """Tests for configuration saving functionality.""" -import os from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration @@ -8,11 +7,9 @@ def test_save_yaml_configuration_basic(tmp_path, test_config_gsp_path): """Save an empty configuration object""" config = load_yaml_configuration(test_config_gsp_path) - - filepath = f"{tmp_path}/config.yaml" + filepath = tmp_path / "config.yaml" save_yaml_configuration(config, filepath) - - assert os.path.exists(filepath) + assert filepath.exists() def test_save_load_yaml_configuration(tmp_path, test_config_filename): @@ -21,10 +18,7 @@ def test_save_load_yaml_configuration(tmp_path, test_config_filename): # Start with this config initial_config = load_yaml_configuration(test_config_filename) - # Save it - filepath = f"{tmp_path}/config.yaml" + # Save it - then load and check it is identical + filepath = tmp_path / "config.yaml" save_yaml_configuration(initial_config, filepath) - - # Load it and check it is still the same - loaded_config = load_yaml_configuration(filepath) - assert loaded_config == initial_config + assert load_yaml_configuration(filepath) == initial_config diff --git a/tests/load/test_load_gsp.py b/tests/load/test_load_gsp.py index 1e39d10e..5f549531 100644 --- a/tests/load/test_load_gsp.py +++ b/tests/load/test_load_gsp.py @@ -15,9 +15,7 @@ def test_get_gsp_boundaries(version, expected_length): assert isinstance(df, pd.DataFrame) assert len(df) == expected_length - assert "x_osgb" in df.columns - assert "y_osgb" in df.columns - + assert {"x_osgb", "y_osgb"}.issubset(df.columns) assert df.index.is_unique @@ -27,12 +25,8 @@ def test_open_gsp(uk_gsp_zarr_path): assert isinstance(da, xr.DataArray) assert da.dims == ("time_utc", "gsp_id") - - 
assert "effective_capacity_mwp" in da.coords - assert "x_osgb" in da.coords - assert "y_osgb" in da.coords + assert {"effective_capacity_mwp", "x_osgb", "y_osgb"}.issubset(da.coords) assert da.shape == (49, 318) - assert len(np.unique(da.coords["gsp_id"])) == da.shape[1] diff --git a/tests/load/test_load_nwp.py b/tests/load/test_load_nwp.py index e4f0cd78..b9c96d9f 100755 --- a/tests/load/test_load_nwp.py +++ b/tests/load/test_load_nwp.py @@ -1,5 +1,3 @@ -import os - import numpy as np import pytest from xarray import DataArray @@ -50,7 +48,7 @@ def test_load_gfs(nwp_gfs_zarr_path): def test_load_ecmwf_bad_dtype_latitude(tmp_path): """Test validation fails for ECMWF with bad latitude dtype.""" - zarr_path = os.path.join(tmp_path, "bad_ecmwf_latitude.zarr") + zarr_path = tmp_path / "bad_ecmwf_latitude.zarr" bad_array = DataArray( np.random.rand(1, 1, 1, 1, 1).astype(np.float32), dims=("init_time", "step", "variable", "longitude", "latitude"), @@ -69,7 +67,7 @@ def test_load_ecmwf_bad_dtype_latitude(tmp_path): def test_load_ecmwf_bad_dtype_init_time(tmp_path): """Test validation fails for ECMWF with bad init_time_utc dtype.""" - zarr_path = os.path.join(tmp_path, "bad_ecmwf_init_time.zarr") + zarr_path = tmp_path / "bad_ecmwf_init_time.zarr" bad_array = DataArray( np.random.rand(1, 1, 1, 1, 1).astype(np.float32), dims=("init_time", "step", "variable", "longitude", "latitude"), @@ -88,7 +86,7 @@ def test_load_ecmwf_bad_dtype_init_time(tmp_path): def test_load_ecmwf_bad_dtype_step(tmp_path): """Test validation fails for ECMWF with bad step dtype.""" - zarr_path = os.path.join(tmp_path, "bad_ecmwf_step.zarr") + zarr_path = tmp_path / "bad_ecmwf_step.zarr" bad_array = DataArray( np.random.rand(1, 1, 1, 1, 1).astype(np.float32), dims=("init_time", "step", "variable", "longitude", "latitude"), diff --git a/tests/load/test_load_satellite.py b/tests/load/test_load_satellite.py index c6519f88..83ae359a 100755 --- a/tests/load/test_load_satellite.py +++ 
b/tests/load/test_load_satellite.py @@ -18,7 +18,6 @@ def test_open_satellite(sat_zarr_path): # There are 11 channels # There are 100 x 100 pixels assert da.shape == (288, 11, 100, 100) - assert len(np.unique(da.coords["channel"])) == da.shape[1] @@ -35,7 +34,7 @@ def test_open_satellite_bad_dtype(tmp_path: Path): ), }, coords={ - "time": pd.to_datetime(pd.date_range("2023-01-01", periods=10, freq="5min")), + "time": pd.date_range("2023-01-01", periods=10, freq="5min"), "variable": [1, 2], "y_geostationary": np.arange(4), "x_geostationary": np.arange(4), diff --git a/tests/load/test_load_sites.py b/tests/load/test_load_sites.py index 2c0879f8..6e47437d 100644 --- a/tests/load/test_load_sites.py +++ b/tests/load/test_load_sites.py @@ -10,15 +10,15 @@ def test_open_site(default_data_site_model): """Test the site data loader with valid data.""" - da = open_site(default_data_site_model.file_path, default_data_site_model.metadata_file_path) + da = open_site( + default_data_site_model.file_path, + default_data_site_model.metadata_file_path, + ) assert isinstance(da, xr.DataArray) assert da.dims == ("time_utc", "site_id") - assert "capacity_kwp" in da.coords - assert "latitude" in da.coords - assert "longitude" in da.coords + assert {"capacity_kwp", "latitude", "longitude"}.issubset(da.coords) assert da.shape == (49, 10) - assert len(np.unique(da.coords["site_id"])) == da.shape[1] @@ -28,11 +28,9 @@ def test_open_site_bad_dtype(tmp_path: Path): meta_path = tmp_path / "site_meta.csv" bad_ds = xr.Dataset( - data_vars={ - "generation_kw": (("time_utc", "site_id"), np.random.rand(10, 2)), - }, + data_vars={"generation_kw": (("time_utc", "site_id"), np.random.rand(10, 2))}, coords={ - "time_utc": pd.to_datetime(pd.date_range("2023-01-01", periods=10, freq="30min")), + "time_utc": pd.date_range("2023-01-01", periods=10, freq="30min"), "site_id": np.array([1.0, 2.0]), }, ) diff --git a/tests/load/test_open_xarray_tensorstore.py b/tests/load/test_open_xarray_tensorstore.py index 
b2582fb7..95f2f598 100644 --- a/tests/load/test_open_xarray_tensorstore.py +++ b/tests/load/test_open_xarray_tensorstore.py @@ -46,35 +46,29 @@ def nwp_like_zarr3_paths(session_tmp_path, concatable_nwp_like_data): def test_open_zarr(nwp_like_zarr2_paths, nwp_like_zarr3_paths): - - # Check the function can open zarr2 + # Check function can open zarr2 ds_ts = open_zarr(nwp_like_zarr2_paths[0]) - - # Check the tensorstore version returns the same results as the dask version + # Check tensorstore version returns same results as dask version ds_dask = xr.open_zarr(nwp_like_zarr2_paths[0]) assert ds_ts.compute().equals(ds_dask.compute()) - # Check the function can open zarr3 + # Check function can open zarr3 ds_ts = open_zarr(nwp_like_zarr3_paths[0]) - - # Check the tensorstore version returns the same results as the dask version + # Check tensorstore version returns same results as dask version ds_dask = xr.open_zarr(nwp_like_zarr3_paths[0]) assert ds_ts.compute().equals(ds_dask.compute()) def test_open_zarrs(nwp_like_zarr2_paths, nwp_like_zarr3_paths): - - # Check the function can open zarr2 + # Check function can open zarr2 ds_ts = open_zarrs(nwp_like_zarr2_paths, concat_dim="init_time") - - # Check the tensorstore version returns the same results as the dask version + # Check tensorstore version returns same results as dask version kwargs = {"concat_dim": "init_time", "combine": "nested", "engine": "zarr"} ds_dask = xr.open_mfdataset(nwp_like_zarr2_paths, **kwargs) assert ds_ts.compute().equals(ds_dask.compute()) - # Check the function can open zarr3 + # Check function can open zarr3 ds_ts = open_zarrs(nwp_like_zarr3_paths, concat_dim="init_time") - - # Check the tensorstore version returns the same results as the dask version + # Check tensorstore version returns same results as dask version ds_dask = xr.open_mfdataset(nwp_like_zarr3_paths, **kwargs) assert ds_ts.compute().equals(ds_dask.compute()) diff --git a/tests/numpy_sample/test_collate.py 
b/tests/numpy_sample/test_collate.py index a63c69a7..df5e472e 100644 --- a/tests/numpy_sample/test_collate.py +++ b/tests/numpy_sample/test_collate.py @@ -3,19 +3,15 @@ def test_stack_np_samples_into_batch(pvnet_config_filename): - # Create dataset object - dataset = PVNetUKRegionalDataset(pvnet_config_filename) - - # Generate 2 samples - sample1 = dataset[0] - sample2 = dataset[1] - batch = stack_np_samples_into_batch([sample1, sample2]) + # Create dataset object - generate two samples + dataset = PVNetUKRegionalDataset(pvnet_config_filename) + batch = stack_np_samples_into_batch([dataset[0], dataset[1]]) assert isinstance(batch, dict) assert "nwp" in batch assert isinstance(batch["nwp"], dict) assert "ukv" in batch["nwp"] - assert "gsp" in batch - assert "satellite_actual" in batch - assert "t0" in batch + + for key in ("gsp", "satellite_actual", "t0"): + assert key in batch diff --git a/tests/numpy_sample/test_datetime_features.py b/tests/numpy_sample/test_datetime_features.py index 93846921..cd02a1f1 100644 --- a/tests/numpy_sample/test_datetime_features.py +++ b/tests/numpy_sample/test_datetime_features.py @@ -5,19 +5,14 @@ def test_encode_datetimes(): - # Pick the day of the summer solstice + # Pick summer solstice day and calculate encoding features datetimes = pd.to_datetime(["2024-06-20 12:00", "2024-06-20 12:30", "2024-06-20 13:00"]) + features = encode_datetimes(datetimes) - # Calculate datetime encoding features - datetime_features = encode_datetimes(datetimes) + assert len(features) == 4 + assert all(len(arr) == len(datetimes) for arr in features.values()) + assert (features["date_cos"] != features["date_sin"]).all() - assert len(datetime_features) == 4 - - assert len(datetime_features["date_sin"]) == len(datetimes) - assert (datetime_features["date_cos"] != datetime_features["date_sin"]).all() - - # assert all values are between -1 and 1 - assert all(np.abs(datetime_features["date_sin"]) <= 1) - assert all(np.abs(datetime_features["date_cos"]) <= 1) - 
assert all(np.abs(datetime_features["time_sin"]) <= 1) - assert all(np.abs(datetime_features["time_cos"]) <= 1) + # Values should be between -1 and 1 + for key in ("date_sin", "date_cos", "time_sin", "time_cos"): + assert np.all(np.abs(features[key]) <= 1) diff --git a/tests/numpy_sample/test_gsp.py b/tests/numpy_sample/test_gsp.py index 6045963b..3c2480e3 100644 --- a/tests/numpy_sample/test_gsp.py +++ b/tests/numpy_sample/test_gsp.py @@ -6,32 +6,27 @@ def test_convert_gsp_to_numpy_sample(uk_gsp_zarr_path): da = open_gsp(uk_gsp_zarr_path).isel(time_utc=slice(0, 10)).sel(gsp_id=1) - numpy_sample = convert_gsp_to_numpy_sample(da) - # Test data structure - assert isinstance(numpy_sample, dict), "Should be dict" - assert set(numpy_sample.keys()).issubset( - { - GSPSampleKey.gsp, - GSPSampleKey.effective_capacity_mwp, - GSPSampleKey.time_utc, - }, - ), "Unexpected keys" + # Assert structure + expected_keys = { + GSPSampleKey.gsp, + GSPSampleKey.effective_capacity_mwp, + GSPSampleKey.time_utc, + } + assert isinstance(numpy_sample, dict) + assert set(numpy_sample) <= expected_keys + + # Assert content and capacity values + assert np.array_equal(numpy_sample[GSPSampleKey.gsp], da.values) + assert isinstance(numpy_sample[GSPSampleKey.time_utc], np.ndarray) + assert numpy_sample[GSPSampleKey.time_utc].dtype == float - # Assert data content and capacity values - assert np.array_equal(numpy_sample[GSPSampleKey.gsp], da.values), "GSP values mismatch" - assert isinstance( - numpy_sample[GSPSampleKey.time_utc], - np.ndarray, - ), "Time UTC should be numpy array" - assert numpy_sample[GSPSampleKey.time_utc].dtype == float, "Time UTC should be float type" - assert ( - numpy_sample[GSPSampleKey.effective_capacity_mwp] - == da.isel(time_utc=0)["effective_capacity_mwp"].values + assert numpy_sample[GSPSampleKey.effective_capacity_mwp] == ( + da.effective_capacity_mwp.isel(time_utc=0).values ) - # Test with t0_idx + # With t0_idx t0_idx = 5 numpy_sample_with_t0 = 
convert_gsp_to_numpy_sample(da, t0_idx=t0_idx) - assert numpy_sample_with_t0[GSPSampleKey.t0_idx] == t0_idx, "t0_idx not correctly set" + assert numpy_sample_with_t0[GSPSampleKey.t0_idx] == t0_idx diff --git a/tests/numpy_sample/test_nwp.py b/tests/numpy_sample/test_nwp.py index 561f3872..26cda095 100644 --- a/tests/numpy_sample/test_nwp.py +++ b/tests/numpy_sample/test_nwp.py @@ -2,11 +2,8 @@ def test_convert_nwp_to_numpy_sample(ds_nwp_ukv_time_sliced): - # Call the function numpy_sample = convert_nwp_to_numpy_sample(ds_nwp_ukv_time_sliced) - # Assert the output type + # Assert output type and shape of sample assert isinstance(numpy_sample, dict) - - # Assert the shape of the numpy sample assert (numpy_sample[NWPSampleKey.nwp] == ds_nwp_ukv_time_sliced.values).all() diff --git a/tests/numpy_sample/test_satellite.py b/tests/numpy_sample/test_satellite.py index 98ecf6d2..a2df9c4f 100644 --- a/tests/numpy_sample/test_satellite.py +++ b/tests/numpy_sample/test_satellite.py @@ -27,11 +27,8 @@ def da_sat_like(): def test_convert_satellite_to_numpy_sample(da_sat_like): - # Call the function numpy_sample = convert_satellite_to_numpy_sample(da_sat_like) - # Assert the output type + # Assert output type and shape of sample assert isinstance(numpy_sample, dict) - - # Assert the shape of the numpy sample assert (numpy_sample[SatelliteSampleKey.satellite_actual] == da_sat_like.values).all() diff --git a/tests/numpy_sample/test_sun_position.py b/tests/numpy_sample/test_sun_position.py index 17aee89e..2c1a2546 100644 --- a/tests/numpy_sample/test_sun_position.py +++ b/tests/numpy_sample/test_sun_position.py @@ -10,68 +10,48 @@ @pytest.mark.parametrize("lat", [0, 5, 10, 23.5]) def test_calculate_azimuth_and_elevation(lat): - # Pick the day of the summer solstice - datetimes = pd.to_datetime(["2024-06-20 12:00"]) - # Calculate sun angles + # Summer solstice day and sun angle calculation + datetimes = pd.to_datetime(["2024-06-20 12:00"]) azimuth, elevation = 
calculate_azimuth_and_elevation(datetimes, lon=0, lat=lat) assert len(azimuth) == len(datetimes) assert len(elevation) == len(datetimes) - - # elevation should be close to (90 - (23.5-lat) degrees + # Elevation should be close to 90 - (23.5 - lat) assert np.abs(elevation - (90 - 23.5 + lat)) < 1 def test_calculate_azimuth_and_elevation_random(): """Test that the function produces the expected range of azimuths and elevations""" - # Set seed so we know the test should pass np.random.seed(0) - # Pick the day of the summer solstice + # Pick day of summer solstice datetimes = pd.to_datetime(["2024-06-20 12:00"]) - # Pick 100 random locations and measure their azimuth and elevations - azimuths = [] - elevations = [] - + # For 100 random locations - calculate azimuth and elevations + azimuths, elevations = [], [] for _ in range(100): lon = np.random.uniform(low=0, high=360) lat = np.random.uniform(low=-90, high=90) - - # Calculate sun angles azimuth, elevation = calculate_azimuth_and_elevation(datetimes, lon=lon, lat=lat) - azimuths.append(azimuth.item()) elevations.append(elevation.item()) - azimuths = np.array(azimuths) - elevations = np.array(elevations) - - assert (azimuths >= 0).all() and (azimuths <= 360).all() - assert (elevations >= -90).all() and (elevations <= 90).all() - - # Azimuth range is [0, 360] - assert azimuths.min() < 30 - assert azimuths.max() > 330 + azimuths, elevations = np.array(azimuths), np.array(elevations) - # Elevation range is [-90, 90] - assert elevations.min() < -70 - assert elevations.max() > 70 + # Assert both azimuth range is [0, 360] and elevation range is [-90, 90] + assert np.all((azimuths >= 0) & (azimuths <= 360)) + assert np.all((elevations >= -90) & (elevations <= 90)) + assert azimuths.min() < 30 and azimuths.max() > 330 + assert elevations.min() < -70 and elevations.max() > 70 def test_make_sun_position_numpy_sample(): datetimes = pd.date_range("2024-06-20 12:00", "2024-06-20 16:00", freq="30min") - lon, lat = 0, 51.5 - - 
sample = make_sun_position_numpy_sample(datetimes, lon, lat) - - assert "solar_elevation" in sample - assert "solar_azimuth" in sample + sample = make_sun_position_numpy_sample(datetimes, lon=0, lat=51.5) - # The solar coords are normalised in the function - assert (sample["solar_elevation"] >= 0).all() - assert (sample["solar_elevation"] <= 1).all() - assert (sample["solar_azimuth"] >= 0).all() - assert (sample["solar_azimuth"] <= 1).all() + # Assertion accounting for solar coord normalisation + assert {"solar_elevation", "solar_azimuth"} <= set(sample) + assert np.all((sample["solar_elevation"] >= 0) & (sample["solar_elevation"] <= 1)) + assert np.all((sample["solar_azimuth"] >= 0) & (sample["solar_azimuth"] <= 1)) diff --git a/tests/select/test_diff_channels.py b/tests/select/test_diff_channels.py index 5ce6efda..951b8602 100644 --- a/tests/select/test_diff_channels.py +++ b/tests/select/test_diff_channels.py @@ -1,28 +1,23 @@ - from ocf_data_sampler.select.diff_channels import diff_channels def test_diff_channels(ds_nwp_ukv_time_sliced): - - # Make a copy since the function edits the inputs in-place + # Construct copy as function edits inputs in-place + # Assert more than one channel in fixture da = ds_nwp_ukv_time_sliced.copy(deep=True) - channels = [*da.channel.values] - - # This test relies on there being more than one channel in the fixture - assert len(channels)>1 + channels = list(da.channel.values) + assert len(channels) > 1 + # Assert diff function reduces the steps by one da_diffed = diff_channels(da, accum_channels=channels[:1]) + assert (da_diffed.step.values == ds_nwp_ukv_time_sliced.step.values[:-1]).all() - # The diff function reduces the steps by one - assert (da_diffed.step.values==ds_nwp_ukv_time_sliced.step.values[:-1]).all() - - # Check that these channels have not beeen changed - expected_result = ds_nwp_ukv_time_sliced.isel(channel=slice(1, None), step=slice(None, -1)) - assert da_diffed.isel(channel=slice(1, None)).equals(expected_result) 
+ # Check these channels have not been changed + expected_unchanged = ds_nwp_ukv_time_sliced.isel(channel=slice(1, None), step=slice(None, -1)) + assert da_diffed.isel(channel=slice(1, None)).equals(expected_unchanged) - # Check that these channels have been properly diffed - expected_result = ( - ds_nwp_ukv_time_sliced.diff(dim="step", label="lower") - .isel(channel=slice(None, 1)) + # Check these channels have been properly diffed + expected_diffed = ( + ds_nwp_ukv_time_sliced.diff(dim="step", label="lower").isel(channel=slice(None, 1)) ).values - assert (da_diffed.isel(channel=slice(None, 1)).values==expected_result).all() + assert (da_diffed.isel(channel=slice(None, 1)).values == expected_diffed).all() diff --git a/tests/select/test_dropout.py b/tests/select/test_dropout.py index 624799ad..a10ede13 100644 --- a/tests/select/test_dropout.py +++ b/tests/select/test_dropout.py @@ -48,19 +48,14 @@ def test_apply_history_dropout_none(da_sample): dropout_frac=0, da=da_sample, ) - - # Check data arrays are equal as dropout time would be None xr.testing.assert_equal(da_sample_dropout, da_sample) - # No dropout timedeltas and dropout fraction is 0 da_sample_dropout = apply_history_dropout( t0, dropout_timedeltas=[], dropout_frac=0, da=da_sample, ) - - # Check data arrays are equal as dropout time would be None xr.testing.assert_equal(da_sample_dropout, da_sample) @@ -70,14 +65,15 @@ def test_apply_history_dropout_list(da_sample): da_sample_dropout = apply_history_dropout( t0, dropout_timedeltas=minutes([-30, -45]), - dropout_frac=[0.5,0.5], + dropout_frac=[0.5, 0.5], da=da_sample, ) latest_expected_cut_off = t0 + minutes(-30) assert ( - da_sample_dropout.sel(time_utc=slice(latest_expected_cut_off + minutes(5), None)) + da_sample_dropout + .sel(time_utc=slice(latest_expected_cut_off + minutes(5), None)) .isnull() .all() ) @@ -96,5 +92,4 @@ def test_apply_history_dropout(da_sample, t0_str): ) assert da_dropout.sel(time_utc=slice(None, dropout_time)).notnull().all() - 
assert (da_dropout.sel(time_utc=slice(dropout_time + minutes(5), t0_time)).isnull().all()) - + assert da_dropout.sel(time_utc=slice(dropout_time + minutes(5), t0_time)).isnull().all() diff --git a/tests/select/test_fill_time_periods.py b/tests/select/test_fill_time_periods.py index dbb0d050..94cf5ae2 100644 --- a/tests/select/test_fill_time_periods.py +++ b/tests/select/test_fill_time_periods.py @@ -20,23 +20,23 @@ def test_fill_time_periods(): ], }, ) - freq = pd.Timedelta("30min") - filled_time_periods = fill_time_periods(time_periods, freq) - expected_times = [ - "04:30", - "05:00", - "05:30", - "06:00", - "09:00", - "12:00", - "12:30", - "13:00", - "13:30", - "14:00", - "14:30", - ] + filled = fill_time_periods(time_periods, freq=pd.Timedelta("30min")) - expected_times = pd.DatetimeIndex([f"2021-01-01 {t}" for t in expected_times]) + expected = pd.to_datetime( + [ + "2021-01-01 04:30", + "2021-01-01 05:00", + "2021-01-01 05:30", + "2021-01-01 06:00", + "2021-01-01 09:00", + "2021-01-01 12:00", + "2021-01-01 12:30", + "2021-01-01 13:00", + "2021-01-01 13:30", + "2021-01-01 14:00", + "2021-01-01 14:30", + ], + ) - pd.testing.assert_index_equal(filled_time_periods, expected_times) + pd.testing.assert_index_equal(filled, expected) diff --git a/tests/select/test_find_contiguous_time_periods.py b/tests/select/test_find_contiguous_time_periods.py index 2561878f..20ba19ec 100644 --- a/tests/select/test_find_contiguous_time_periods.py +++ b/tests/select/test_find_contiguous_time_periods.py @@ -12,8 +12,8 @@ def construct_time_periods_df(start_dt: list[str], end_dt: list[str]) -> pd.Data """Helper function to construct a DataFrame of time periods Args: - start_dts: List of period start datetimes - end_dts: List of period end datetimes + start_dt: List of period start datetimes + end_dt: List of period end datetimes Returns: pd.DataFrame: DataFrame with start and end datetimes columns where each period is a row @@ -43,41 +43,39 @@ def test_find_contiguous_t0_periods(): 
expected_results = construct_time_periods_df( start_dt=["2023-01-01 13:35", "2023-01-01 15:35"], - end_dt = ["2023-01-01 14:10", "2023-01-01 16:45"], + end_dt=["2023-01-01 14:10", "2023-01-01 16:45"], ) assert periods.equals(expected_results) def test_find_contiguous_t0_periods_nwp(): - # These are the expected results of the test exp_res1 = construct_time_periods_df( start_dt=["2023-01-01 03:00", "2023-01-02 03:00"], - end_dt = ["2023-01-01 21:00", "2023-01-03 06:00"], + end_dt=["2023-01-01 21:00", "2023-01-03 06:00"], ) exp_res2 = construct_time_periods_df( start_dt=["2023-01-01 05:00", "2023-01-02 05:00"], - end_dt = ["2023-01-01 21:00", "2023-01-03 06:00"], + end_dt=["2023-01-01 21:00", "2023-01-03 06:00"], ) exp_res3 = construct_time_periods_df( start_dt=["2023-01-01 05:00", "2023-01-02 05:00", "2023-01-02 14:00"], - end_dt = ["2023-01-01 18:00", "2023-01-02 09:00", "2023-01-03 03:00"], + end_dt=["2023-01-01 18:00", "2023-01-02 09:00", "2023-01-03 03:00"], ) exp_res4 = construct_time_periods_df( start_dt=["2023-01-01 05:00", "2023-01-01 11:00", "2023-01-02 05:00", "2023-01-02 14:00"], - end_dt = ["2023-01-01 06:00", "2023-01-01 15:00", "2023-01-02 06:00", "2023-01-03 00:00"], + end_dt=["2023-01-01 06:00", "2023-01-01 15:00", "2023-01-02 06:00", "2023-01-03 00:00"], ) exp_res5 = construct_time_periods_df( start_dt=["2023-01-01 06:00", "2023-01-01 12:00", "2023-01-02 06:00", "2023-01-02 15:00"], - end_dt = ["2023-01-01 09:00", "2023-01-01 18:00", "2023-01-02 09:00", "2023-01-03 03:00"], + end_dt=["2023-01-01 09:00", "2023-01-01 18:00", "2023-01-02 09:00", "2023-01-03 03:00"], ) expected_results = [exp_res1, exp_res2, exp_res3, exp_res4, exp_res5] # Create 3-hourly init times with a few time stamps missing freq = pd.Timedelta(3, "h") - init_times = pd.date_range("2023-01-01 03:00", "2023-01-02 21:00", freq=freq).delete( [1, 4, 5, 6, 7, 9, 10], ) @@ -87,7 +85,7 @@ def test_find_contiguous_t0_periods_nwp(): max_stalenesses_hr = [9, 9, 6, 3, 6] max_dropouts_hr = [0, 
0, 0, 0, 3] - for i in range(len(expected_results)): + for i, expected in enumerate(expected_results): interval_start = pd.Timedelta(-history_durations_hr[i], "h") max_staleness = pd.Timedelta(max_stalenesses_hr[i], "h") max_dropout = pd.Timedelta(max_dropouts_hr[i], "h") @@ -100,13 +98,12 @@ def test_find_contiguous_t0_periods_nwp(): ) # Check if results are as expected - assert time_periods.equals(expected_results[i]) + assert time_periods.equals(expected) def test_intersection_of_2_dataframes_of_periods(): - def assert_expected_result_with_reverse(a, b, expected_result): - """Assert the calulated intersection is as expected with and without a and b switched""" + """Assert the calculated intersection is as expected with and without a and b switched""" assert intersection_of_2_dataframes_of_periods(a, b).equals(expected_result) assert intersection_of_2_dataframes_of_periods(b, a).equals(expected_result) @@ -158,29 +155,29 @@ def assert_expected_result_with_reverse(a, b, expected_result): # b: | a = construct_time_periods_df(start_dt=["2025-01-01 00:00"], end_dt=["2025-01-01 00:00"]) b = construct_time_periods_df(start_dt=["2025-01-01 06:00"], end_dt=["2025-01-01 06:00"]) - exp_res = construct_time_periods_df([], []) # no intersection + exp_res = construct_time_periods_df([], []) # no intersection assert_expected_result_with_reverse(a, b, expected_result=exp_res) def test_intersection_of_multiple_dataframes_of_periods(): periods_1 = construct_time_periods_df( start_dt=["2023-01-01 05:00", "2023-01-01 14:10"], - end_dt = ["2023-01-01 13:35", "2023-01-01 18:00"], + end_dt=["2023-01-01 13:35", "2023-01-01 18:00"], ) periods_2 = construct_time_periods_df( start_dt=["2023-01-01 12:00"], - end_dt = ["2023-01-02 00:00"], + end_dt=["2023-01-02 00:00"], ) periods_3 = construct_time_periods_df( start_dt=["2023-01-01 00:00", "2023-01-01 13:00"], - end_dt = ["2023-01-01 12:30", "2023-01-01 23:00"], + end_dt=["2023-01-01 12:30", "2023-01-01 23:00"], ) expected_result = 
construct_time_periods_df( start_dt=["2023-01-01 12:00", "2023-01-01 13:00", "2023-01-01 14:10"], - end_dt = ["2023-01-01 12:30", "2023-01-01 13:35", "2023-01-01 18:00"], + end_dt=["2023-01-01 12:30", "2023-01-01 13:35", "2023-01-01 18:00"], ) result = intersection_of_multiple_dataframes_of_periods([periods_1, periods_2, periods_3]) diff --git a/tests/select/test_location.py b/tests/select/test_location.py index 37e50401..389d95f1 100644 --- a/tests/select/test_location.py +++ b/tests/select/test_location.py @@ -2,5 +2,4 @@ def test_make_valid_location_object(): - x, y = -1000.5, 50000 - _ = Location(x=x, y=y, coord_system="osgb") + Location(x=-1000.5, y=50000, coord_system="osgb") diff --git a/tests/select/test_select_time_slice.py b/tests/select/test_select_time_slice.py index 81669aea..124c8d18 100644 --- a/tests/select/test_select_time_slice.py +++ b/tests/select/test_select_time_slice.py @@ -55,7 +55,7 @@ def test_select_time_slice(da_sat_like, t0_str): # Slice parameters t0 = pd.Timestamp(f"2024-01-02 {t0_str}") - interval_start = pd.Timedelta(-0, "min") + interval_start = pd.Timedelta(0, "min") interval_end = pd.Timedelta(60, "min") freq = pd.Timedelta("5min") @@ -113,7 +113,7 @@ def test_select_time_slice_out_of_bounds(da_sat_like, t0_str): if expected_datetimes[0] < min_time: assert all_nan_space.sel(time_utc=slice(None, min_time - freq)).all(dim="time_utc") - # Check all the values before the first timestamp available in the data are NaN + # Check all the values after the last timestamp available in the data are NaN if expected_datetimes[-1] > max_time: assert all_nan_space.sel(time_utc=slice(max_time + freq, None)).all(dim="time_utc") @@ -145,7 +145,6 @@ def test_select_time_slice_nwp_basic(da_nwp_like, t0_str): # Check the target-times are as expected expected_target_times = pd.date_range(t0 + interval_start, t0 + interval_end, freq=freq) - valid_times = da_slice.init_time_utc + da_slice.step assert (valid_times == expected_target_times).all() @@ -154,7 
+153,7 @@ def test_select_time_slice_nwp_basic(da_nwp_like, t0_str): expected_init_times = pd.to_datetime( [t if t < t0 else t0 for t in expected_target_times], ).floor(NWP_FREQ) - assert (expected_init_times==da_slice.init_time_utc.values).all() + assert (expected_init_times == da_slice.init_time_utc.values).all() @pytest.mark.parametrize("dropout_hours", [1, 2, 5]) @@ -187,4 +186,4 @@ def test_select_time_slice_nwp_with_dropout(da_nwp_like, dropout_hours): expected_init_times = pd.to_datetime( [t if t < t0_delayed else t0_delayed for t in expected_target_times], ).floor(NWP_FREQ) - assert (expected_init_times==da_slice.init_time_utc.values).all() + assert (expected_init_times == da_slice.init_time_utc.values).all() diff --git a/tests/torch_datasets/sample/test_base.py b/tests/torch_datasets/sample/test_base.py index fa687de5..dcf3d7b0 100644 --- a/tests/torch_datasets/sample/test_base.py +++ b/tests/torch_datasets/sample/test_base.py @@ -29,7 +29,7 @@ def plot(self): return None @override - def to_numpy(self) -> None: + def to_numpy(self) -> dict: return {key: np.array(value) for key, value in self._data.items()} @override @@ -67,7 +67,6 @@ def test_sample_base_save_load(tmp_path): def test_sample_base_abstract_methods(): """Test abstract method enforcement""" - with pytest.raises(TypeError, match="Can't instantiate abstract class"): SampleBase() @@ -83,17 +82,13 @@ def test_sample_base_to_numpy(): numpy_data = sample.to_numpy() assert isinstance(numpy_data, dict) - assert all(isinstance(value, np.ndarray) for value in numpy_data.values()) + assert all(isinstance(v, np.ndarray) for v in numpy_data.values()) assert np.array_equal(numpy_data["list_data"], np.array([1, 2, 3])) def test_batch_to_tensor_nested(): """Test nested dictionary conversion""" - batch = { - "outer": { - "inner": np.array([1, 2, 3]), - }, - } + batch = {"outer": {"inner": np.array([1, 2, 3])}} tensor_batch = batch_to_tensor(batch) assert torch.equal(tensor_batch["outer"]["inner"], 
torch.tensor([1, 2, 3])) @@ -104,10 +99,7 @@ def test_batch_to_tensor_mixed_types(): batch = { "tensor_data": np.array([1, 2, 3]), "string_data": "not_a_tensor", - "nested": { - "numbers": np.array([4, 5, 6]), - "text": "still_not_a_tensor", - }, + "nested": {"numbers": np.array([4, 5, 6]), "text": "still_not_a_tensor"}, } tensor_batch = batch_to_tensor(batch) @@ -147,15 +139,14 @@ def test_batch_to_tensor_multidimensional(): def test_copy_batch_to_device(): """Test moving tensors to a different device""" - device = torch.device("cuda", index=0) if torch.cuda.is_available() else torch.device("cpu") + device = torch.device("cuda", index=0) if torch.cuda.is_available() else torch.device("cpu") batch = { "tensor_data": torch.tensor([1, 2, 3]), - "nested": { - "matrix": torch.tensor([[1, 2], [3, 4]]), - }, + "nested": {"matrix": torch.tensor([[1, 2], [3, 4]])}, "non_tensor": "unchanged", } moved_batch = copy_batch_to_device(batch, device) + assert moved_batch["tensor_data"].device == device assert moved_batch["nested"]["matrix"].device == device assert moved_batch["non_tensor"] == "unchanged" # Non-tensors should remain unchanged diff --git a/tests/torch_datasets/sample/test_site_sample.py b/tests/torch_datasets/sample/test_site_sample.py index bc53b2c0..77776be6 100644 --- a/tests/torch_datasets/sample/test_site_sample.py +++ b/tests/torch_datasets/sample/test_site_sample.py @@ -1,6 +1,7 @@ """ Site class testing - SiteSample """ + import tempfile import numpy as np @@ -26,9 +27,7 @@ def numpy_sample(): } return { - "nwp": { - "ukv": nwp_data, - }, + "nwp": {"ukv": nwp_data}, SiteSampleKey.generation: np.random.rand(*expected_site_shape), SatelliteSampleKey.satellite_actual: np.random.rand(*expected_sat_shape), "solar_azimuth": np.random.rand(*expected_solar_shape), @@ -44,9 +43,7 @@ def test_site_sample_with_data(numpy_sample): """Testing of defined sample with actual data""" sample = SiteSample(numpy_sample) - # Assert data structure assert isinstance(sample._data, 
dict) - assert sample._data["satellite_actual"].shape == (7, 1, 2, 2) assert sample._data["nwp"]["ukv"]["nwp"].shape == (4, 1, 2, 2) assert sample._data["site"].shape == (7,) @@ -61,10 +58,9 @@ def test_sample_save_load(numpy_sample): sample.save(tf.name) loaded = SiteSample.load(tf.name) - assert set(loaded._data.keys()) == set(sample._data.keys()) + assert set(loaded._data) == set(sample._data) assert isinstance(loaded._data["nwp"], dict) assert "ukv" in loaded._data["nwp"] - assert loaded._data[SiteSampleKey.generation].shape == (7,) assert loaded._data[SatelliteSampleKey.satellite_actual].shape == (7, 1, 2, 2) @@ -79,17 +75,15 @@ def test_to_numpy(numpy_sample): sample = SiteSample(numpy_sample) numpy_data = sample.to_numpy() - # Assert structure assert isinstance(numpy_data, dict) - assert "site" in numpy_data - assert "nwp" in numpy_data + assert "site" in numpy_data and "nwp" in numpy_data # Check site - numpy array instead of dict site_data = numpy_data["site"] assert isinstance(site_data, np.ndarray) assert site_data.ndim == 1 assert len(site_data) == 7 - assert np.all(site_data >= 0) and np.all(site_data <= 1) + assert np.all((site_data >= 0) & (site_data <= 1)) # Check NWP assert "ukv" in numpy_data["nwp"] diff --git a/tests/torch_datasets/sample/test_uk_regional_sample.py b/tests/torch_datasets/sample/test_uk_regional_sample.py index 5ff4a90f..eb3a2e83 100644 --- a/tests/torch_datasets/sample/test_uk_regional_sample.py +++ b/tests/torch_datasets/sample/test_uk_regional_sample.py @@ -53,7 +53,7 @@ def test_sample_save_load(numpy_sample): sample.save(tf.name) loaded = UKRegionalSample.load(tf.name) - assert set(loaded._data.keys()) == set(sample._data.keys()) + assert set(loaded._data) == set(sample._data) assert isinstance(loaded._data["nwp"], dict) assert "ukv" in loaded._data["nwp"] @@ -68,19 +68,15 @@ def test_sample_save_load(numpy_sample): def test_load_corrupted_file(): """Test loading - corrupted / empty file""" - with 
tempfile.NamedTemporaryFile(suffix=".pt") as tf, open(tf.name, "wb") as f: f.write(b"corrupted data") - with pytest.raises(EOFError): UKRegionalSample.load(tf.name) def test_to_numpy(numpy_sample): """To numpy conversion check""" - sample = UKRegionalSample(numpy_sample) - numpy_data = sample.to_numpy() # Check returned data matches @@ -100,8 +96,8 @@ def test_to_numpy(numpy_sample): def test_validate_sample(numpy_sample, pvnet_configuration_object: Configuration, caplog): """Test the validate_sample method succeeds with no warnings for a valid sample.""" sample = UKRegionalSample(numpy_sample) - caplog.set_level(logging.WARNING) - result = sample.validate_sample(pvnet_configuration_object) + with caplog.at_level(logging.WARNING): + result = sample.validate_sample(pvnet_configuration_object) assert isinstance(result, dict) assert result["valid"] is True @@ -149,6 +145,7 @@ def test_validate_sample_with_missing_solar_coors( modified_data = numpy_sample.copy() solar_key = "solar_azimuth" modified_data.pop(solar_key) + sample = UKRegionalSample(modified_data) expected_error_pattern = f"^Configuration expects {solar_key} data but is missing" @@ -163,6 +160,7 @@ def test_validate_sample_with_wrong_solar_shapes( """Test validation raises ValueError when solar data shape is incorrect.""" modified_data = numpy_sample.copy() modified_data["solar_azimuth"] = np.random.rand(10) + sample = UKRegionalSample(modified_data) with pytest.raises(ValueError, match="'Solar Azimuth data' shape mismatch: Actual shape:"): @@ -184,7 +182,7 @@ def test_validate_sample_with_unexpected_provider( NWPSampleKey.channel_names: ["t"], } if NWPSampleKey.nwp not in modified_data: - modified_data[NWPSampleKey.nwp] = {} + modified_data[NWPSampleKey.nwp] = {} modified_data[NWPSampleKey.nwp][unexpected_provider] = nwp_data sample = UKRegionalSample(modified_data) @@ -213,6 +211,7 @@ def test_validate_sample_with_unexpected_component( modified_data = numpy_sample.copy() unexpected_key = 
"unexpected_component_key_xyz" modified_data[unexpected_key] = np.random.rand(7).astype(np.float32) + sample = UKRegionalSample(modified_data) with caplog.at_level(logging.WARNING): @@ -226,5 +225,4 @@ def test_validate_sample_with_unexpected_component( assert len(warning_logs) == 1, "Expected exactly one warning log" log_message = warning_logs[0].message - expected_substring = f"Unexpected component '{unexpected_key}'" - assert expected_substring in log_message + assert f"Unexpected component '{unexpected_key}'" in log_message diff --git a/tests/torch_datasets/test_merge_and_fill_utils.py b/tests/torch_datasets/test_merge_and_fill_utils.py index 0caf0978..e20bcbbc 100644 --- a/tests/torch_datasets/test_merge_and_fill_utils.py +++ b/tests/torch_datasets/test_merge_and_fill_utils.py @@ -31,9 +31,7 @@ def test_fill_nans_in_arrays(): array_with_nans = np.array([1.0, np.nan, 3.0, np.nan]) nested_dict = { "array1": array_with_nans, - "nested": { - "array2": np.array([np.nan, 2.0, np.nan, 4.0]), - }, + "nested": {"array2": np.array([np.nan, 2.0, np.nan, 4.0])}, "string_key": "not_an_array", } diff --git a/tests/torch_datasets/test_pvnet_uk.py b/tests/torch_datasets/test_pvnet_uk.py index c023a052..f4e3eef1 100644 --- a/tests/torch_datasets/test_pvnet_uk.py +++ b/tests/torch_datasets/test_pvnet_uk.py @@ -15,31 +15,28 @@ def test_pvnet_uk_regional_dataset(pvnet_config_filename): - # Create dataset object dataset = PVNetUKRegionalDataset(pvnet_config_filename) - assert len(dataset.locations) == 317 # Number of regional GSPs + assert len(dataset.locations) == 317 # Quantity of regional GSPs # NB. 
I have not checked the value (39 below) is in fact correct assert len(dataset.valid_t0_times) == 39 assert len(dataset) == 317 * 39 - # Generate a sample sample = dataset[0] - assert isinstance(sample, dict) - # These keys should always be present + # Specific keys should always be present required_keys = ["nwp", "satellite_actual", "gsp", "t0"] for key in required_keys: assert key in sample solar_keys = ["solar_azimuth", "solar_elevation"] if dataset.config.input_data.solar_position is not None: - # Test that solar position keys are present when configured + # Test solar position keys are present when configured for key in solar_keys: assert key in sample, f"Solar position key {key} should be present in sample" - # Get expected time steps from configuration + # Get expected time steps from config expected_time_steps = ( dataset.config.input_data.solar_position.interval_end_minutes - dataset.config.input_data.solar_position.interval_start_minutes @@ -49,7 +46,7 @@ def test_pvnet_uk_regional_dataset(pvnet_config_filename): assert sample["solar_azimuth"].shape == (expected_time_steps,) assert sample["solar_elevation"].shape == (expected_time_steps,) else: - # Test that solar position keys are not present + # Assert that solar position keys are not present for key in solar_keys: assert key not in sample, f"Solar position key {key} should not be present" @@ -66,10 +63,9 @@ def test_pvnet_uk_regional_dataset(pvnet_config_filename): def test_pvnet_uk_regional_dataset_limit_gsp_ids(pvnet_config_filename): - # Create dataset object dataset = PVNetUKRegionalDataset(pvnet_config_filename, gsp_ids=[1, 2, 3]) - assert len(dataset.locations) == 3 # Number of regional GSPs + assert len(dataset.locations) == 3 # Quantity of regional GSPs assert len(dataset.datasets_dict["gsp"].gsp_id) == 3 @@ -80,71 +76,56 @@ def test_pvnet_no_gsp(tmp_path, pvnet_config_filename): new_config_path = tmp_path / "pvnet_config_no_gsp.yaml" save_yaml_configuration(config, new_config_path) - # Create 
dataset object + # Create dataset object and generate sample dataset = PVNetUKRegionalDataset(new_config_path) - - # Generate a sample _ = dataset[0] def test_pvnet_uk_concurrent_dataset(pvnet_config_filename): - # Create dataset object using a limited set of GSPs for test + # Create dataset object using limited set of GSPs gsp_ids = [1, 2, 3] num_gsps = len(gsp_ids) - dataset = PVNetUKConcurrentDataset(pvnet_config_filename, gsp_ids=gsp_ids) - assert len(dataset.locations) == num_gsps # Number of regional GSPs + assert len(dataset.locations) == num_gsps # Quantity of regional GSPs # NB. I have not checked the value (39 below) is in fact correct assert len(dataset.valid_t0_times) == 39 assert len(dataset) == 39 - # Generate a sample sample = dataset[0] - assert isinstance(sample, dict) - # These keys should always be present required_keys = ["nwp", "satellite_actual", "gsp"] for key in required_keys: assert key in sample - # Check if solar position is configured in the dataset solar_keys = ["solar_azimuth", "solar_elevation"] if dataset.config.input_data.solar_position is not None: - # Solar position keys should be present when configured for key in solar_keys: assert key in sample, f"Solar position key {key} should be present in sample" - # Get expected time steps from configuration expected_time_steps = ( dataset.config.input_data.solar_position.interval_end_minutes - dataset.config.input_data.solar_position.interval_start_minutes ) // dataset.config.input_data.solar_position.time_resolution_minutes + 1 - # Test solar angle shapes based on configuration assert sample["solar_azimuth"].shape == (num_gsps, expected_time_steps) assert sample["solar_elevation"].shape == (num_gsps, expected_time_steps) else: - # Solar position keys should not be present when not configured for key in solar_keys: assert key not in sample, f"Solar position key {key} should not be present" for nwp_source in ["ukv"]: assert nwp_source in sample["nwp"] - # Check the shape of the data is
correct - # 30 minutes of 5 minute data (inclusive), one channel, 2x2 pixels + # Shape assertion checking assert sample["satellite_actual"].shape == (num_gsps, 7, 1, 2, 2) - # 3 hours of 60 minute data (inclusive), one channel, 2x2 pixels assert sample["nwp"]["ukv"]["nwp"].shape == (num_gsps, 4, 1, 2, 2) - # 3 hours of 30 minute data (inclusive) assert sample["gsp"].shape == (num_gsps, 7) def test_solar_position_decoupling(tmp_path, pvnet_config_filename): """Test that solar position calculations are properly decoupled from data sources.""" - config = load_yaml_configuration(pvnet_config_filename) config_without_solar = config.model_copy(deep=True) config_without_solar.input_data.solar_position = None @@ -157,13 +138,13 @@ def test_solar_position_decoupling(tmp_path, pvnet_config_filename): interval_end_minutes=180, ) - # Save both testing configurations + # Save both testing configs config_without_solar_path = tmp_path / "config_without_solar.yaml" config_with_solar_path = tmp_path / "config_with_solar.yaml" save_yaml_configuration(config_without_solar, config_without_solar_path) save_yaml_configuration(config_with_solar, config_with_solar_path) - # Create datasets with both configurations + # Create datasets with both configs dataset_without_solar = PVNetUKRegionalDataset(config_without_solar_path, gsp_ids=[1]) dataset_with_solar = PVNetUKRegionalDataset(config_with_solar_path, gsp_ids=[1]) @@ -171,24 +152,17 @@ def test_solar_position_decoupling(tmp_path, pvnet_config_filename): sample_without_solar = dataset_without_solar[0] sample_with_solar = dataset_with_solar[0] - # Assert solar position keys are only in sample specifically with solar configuration + # Assert solar position keys are only in sample specifically with solar config solar_keys = ["solar_azimuth", "solar_elevation"] - # Sample without solar config should not have solar position data for key in solar_keys: assert key not in sample_without_solar, f"Solar key {key} should not be in sample" - - # 
Sample with solar config should have solar position data for key in solar_keys: assert key in sample_with_solar, f"Solar key {key} should be in sample" def test_pvnet_uk_regional_dataset_raw_sample_iteration(pvnet_config_filename): - """ - Tests iterating raw samples (dict of tensors) from PVNetUKRegionalDataset - """ - - # Create dataset object + """Tests iterating raw samples (dict of tensors) from PVNetUKRegionalDataset""" dataset = PVNetUKRegionalDataset(pvnet_config_filename) dataloader = DataLoader( dataset, @@ -201,15 +175,22 @@ def test_pvnet_uk_regional_dataset_raw_sample_iteration(pvnet_config_filename): raw_sample = next(iter(dataloader)) # Assertions for the raw sample - assert isinstance(raw_sample, dict), \ - "Sample yielded by DataLoader with batch_size=None should be a dict" - - # Check for expected keys directly - required_keys = ["nwp", "satellite_actual", "gsp", "solar_azimuth", "solar_elevation", "gsp_id"] + assert isinstance( + raw_sample, dict, + ), "Sample yielded by DataLoader with batch_size=None should be a dict" + + required_keys = [ + "nwp", + "satellite_actual", + "gsp", + "solar_azimuth", + "solar_elevation", + "gsp_id", + ] for key in required_keys: assert key in raw_sample, f"Raw Sample: Expected key '{key}' not found" - # Check types are primarily torch.Tensor + # Type assertions assert isinstance(raw_sample["satellite_actual"], torch.Tensor) assert isinstance(raw_sample["gsp"], torch.Tensor) assert isinstance(raw_sample["solar_azimuth"], torch.Tensor) @@ -219,16 +200,15 @@ def test_pvnet_uk_regional_dataset_raw_sample_iteration(pvnet_config_filename): assert isinstance(raw_sample["nwp"]["ukv"]["nwp"], torch.Tensor) assert isinstance(raw_sample["nwp"]["ukv"]["nwp_channel_names"], np.ndarray) - # Check shapes + # Shape assertions assert raw_sample["satellite_actual"].shape == (7, 1, 2, 2) assert raw_sample["nwp"]["ukv"]["nwp"].shape == (4, 1, 2, 2) assert raw_sample["gsp"].shape == (7,) - # Check solar position shapes (no batch 
dimension) + # Solar position shapes - no batch dimension solar_config = dataset.config.input_data.solar_position expected_time_steps = ( - solar_config.interval_end_minutes - - solar_config.interval_start_minutes + solar_config.interval_end_minutes - solar_config.interval_start_minutes ) // solar_config.time_resolution_minutes + 1 assert raw_sample["solar_azimuth"].shape == (expected_time_steps,) assert raw_sample["solar_elevation"].shape == (expected_time_steps,) @@ -237,34 +217,25 @@ def test_pvnet_uk_regional_dataset_raw_sample_iteration(pvnet_config_filename): def test_pvnet_uk_regional_dataset_pickle(tmp_path, pvnet_config_filename): - pickle_path = f"{tmp_path}.pkl" dataset = PVNetUKRegionalDataset(pvnet_config_filename) - # Presave the pickled dataset + # Assert path is in pickle object dataset.presave_pickle(pickle_path) - - # Since its been pe-pickled this should just return a reference to the previous pickle pickle_bytes = pickle.dumps(dataset) - - # Check the path is in the pickle object assert pickle_path.encode("utf-8") in pickle_bytes # Check we can reload the object - _ = pickle.loads(pickle_bytes) # noqa: S301 - + _ = pickle.loads(pickle_bytes) # noqa: S301 # Check we can still pickle and unpickle if we don't presave dataset = PVNetUKRegionalDataset(pvnet_config_filename) pickle_bytes = pickle.dumps(dataset) - _ = pickle.loads(pickle_bytes) # noqa: S301 + _ = pickle.loads(pickle_bytes) # noqa: S301 -def test_pvnet_uk_regional_dataset_batch_size_2(pvnet_config_filename): - """ - Tests makeing batches from PVNetUKRegionalDataset - """ - # Create dataset object +def test_pvnet_uk_regional_dataset_batch_size_2(pvnet_config_filename): + """Tests making batches from PVNetUKRegionalDataset""" dataset = PVNetUKRegionalDataset(pvnet_config_filename) dataloader = DataLoader( dataset, @@ -279,16 +250,21 @@ def test_pvnet_uk_regional_dataset_batch_size_2(pvnet_config_filename): batch = copy_batch_to_device(batch, torch.device("cpu")) # Assertions for the raw 
batch - assert isinstance(batch, dict), \ - "Sample yielded by DataLoader with batch_size=2 should be a dict" - - # Check for expected keys directly - required_keys = \ - ["nwp", "satellite_actual", "gsp", "solar_azimuth", "solar_elevation", "gsp_id", "t0"] + assert isinstance(batch, dict), "Sample yielded by DataLoader with batch_size=2 should be dict" + + required_keys = [ + "nwp", + "satellite_actual", + "gsp", + "solar_azimuth", + "solar_elevation", + "gsp_id", + "t0", + ] for key in required_keys: assert key in batch, f"Raw Sample: Expected key '{key}' not found" - # Check types are primarily torch.Tensor + # Type assertions assert isinstance(batch["satellite_actual"], torch.Tensor) assert isinstance(batch["gsp"], torch.Tensor) assert isinstance(batch["solar_azimuth"], torch.Tensor) @@ -299,8 +275,8 @@ def test_pvnet_uk_regional_dataset_batch_size_2(pvnet_config_filename): assert isinstance(batch["nwp"]["ukv"]["nwp_channel_names"], np.ndarray) assert isinstance(batch["t0"], torch.Tensor) - # Check shapes + # Shape assertions assert batch["satellite_actual"].shape == (2, 7, 1, 2, 2) assert batch["nwp"]["ukv"]["nwp"].shape == (2, 4, 1, 2, 2) - assert batch["gsp"].shape == (2,7) + assert batch["gsp"].shape == (2, 7) assert batch["t0"].shape == (2,) diff --git a/tests/torch_datasets/test_site.py b/tests/torch_datasets/test_site.py index f80f68cd..2cdd1706 100644 --- a/tests/torch_datasets/test_site.py +++ b/tests/torch_datasets/test_site.py @@ -3,24 +3,19 @@ from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration from ocf_data_sampler.config.model import SolarPosition -from ocf_data_sampler.torch_datasets.datasets.site import ( - SitesDataset, -) +from ocf_data_sampler.torch_datasets.datasets.site import SitesDataset def test_site(site_config_filename): - # Create dataset object dataset = SitesDataset(site_config_filename) - # (10 sites * 24 valid t0s per site = 240) assert len(dataset) == 240 # Generate a sample sample = dataset[0] - 
assert isinstance(sample, dict) - # # Expected dimensions and data variables + # Expected keys expected_keys = { "date_cos", "date_sin", @@ -39,31 +34,26 @@ def test_site(site_config_filename): "nwp", "t0", } - - # Check keys assert set(sample.keys()) == expected_keys, ( f"Missing or extra dimensions: {set(sample.keys()) ^ expected_keys}" ) - # check the shape of the data is correct based on new config intervals and image sizes - # Satellite: (0 - (-30)) / 5 + 1 = 7 time steps; 2 channels (IR_016, VIS006); 24x24 pixels + + # Check shapes based on config intervals and image sizes + # Satellite: (0 - (-30)) / 5 + 1 = 7 time steps; 2 channels; 24x24 pixels assert sample["satellite_actual"].shape == (7, 2, 24, 24) - # NWP-UKV: (480 - (-60)) / 60 + 1 = 10 time steps; 1 channel (t); 24x24 pixels + # NWP-UKV: (480 - (-60)) / 60 + 1 = 10 time steps; 1 channel; 24x24 pixels assert sample["nwp"]["ukv"]["nwp"].shape == (10, 1, 24, 24) - # Site: (60 - (-30)) / 30 + 1 = 4 time steps (from site_config_filename interval) + # Site: (60 - (-30)) / 30 + 1 = 4 time steps assert sample["site"].shape == (4,) def test_site_time_filter_start(site_config_filename): - # Create dataset object dataset = SitesDataset(site_config_filename, start_time="2024-01-01") - assert len(dataset) == 0 def test_site_time_filter_end(site_config_filename): - # Create dataset object dataset = SitesDataset(site_config_filename, end_time="2000-01-01") - assert len(dataset) == 0 @@ -83,28 +73,22 @@ def test_site_dataset_with_dataloader(sites_dataset) -> None: individual_sample = next(iter(dataloader)) except StopIteration: pytest.skip("Skipping test as dataloader is empty.") - return - assert isinstance(individual_sample, dict) - # check the shape of the data is correct based on new config intervals and image sizes - # Satellite: (0 - (-30)) / 5 + 1 = 7 time steps; 2 channels (IR_016, VIS006); 24x24 pixels + # Check shapes based on config intervals and image sizes assert 
individual_sample["satellite_actual"].shape == (7, 2, 24, 24) - # NWP-UKV: (480 - (-60)) / 60 + 1 = 10 time steps; 1 channel (t); 24x24 pixels assert individual_sample["nwp"]["ukv"]["nwp"].shape == (10, 1, 24, 24) - # Site: (60 - (-30)) / 30 + 1 = 4 time steps (from site_config_filename interval) assert individual_sample["site"].shape == (4,) def test_solar_position_decoupling_site(tmp_path, site_config_filename): """Test that solar position calculations are properly decoupled from data sources.""" - config = load_yaml_configuration(site_config_filename) config_without_solar = config.model_copy(deep=True) config_without_solar.input_data.solar_position = None - # Create version with explicit solar position configuration + # Version with explicit solar position config config_with_solar = config.model_copy(deep=True) config_with_solar.input_data.solar_position = SolarPosition( time_resolution_minutes=30, @@ -112,27 +96,19 @@ def test_solar_position_decoupling_site(tmp_path, site_config_filename): interval_end_minutes=60, ) - # Save both testing configurations + # Save both testing configs config_without_solar_path = tmp_path / "site_config_without_solar.yaml" config_with_solar_path = tmp_path / "site_config_with_solar.yaml" save_yaml_configuration(config_without_solar, config_without_solar_path) save_yaml_configuration(config_with_solar, config_with_solar_path) - # Create datasets with both configurations + # Create datasets and generate samples dataset_without_solar = SitesDataset(config_without_solar_path) dataset_with_solar = SitesDataset(config_with_solar_path) - - # Generate samples sample_without_solar = dataset_without_solar[0] sample_with_solar = dataset_with_solar[0] - # Assert solar position keys are only in sample specifically with solar configuration - solar_keys = ["solar_azimuth", "solar_elevation"] - - # Sample without solar config should not have solar position data - for key in solar_keys: + # Assert solar position keys presence/absence + for key in 
["solar_azimuth", "solar_elevation"]: assert key not in sample_without_solar, f"Solar key {key} should not be in sample" - - # Sample with solar config should have solar position data - for key in solar_keys: assert key in sample_with_solar, f"Solar key {key} should be in sample" diff --git a/tests/torch_datasets/test_time_slice_for_dataset.py b/tests/torch_datasets/test_time_slice_for_dataset.py index c298fd0d..cdb2e204 100644 --- a/tests/torch_datasets/test_time_slice_for_dataset.py +++ b/tests/torch_datasets/test_time_slice_for_dataset.py @@ -6,22 +6,22 @@ def test_time_slice_for_dataset_site_dropout(site_config_filename): - # Create dataset object dataset = SitesDataset(site_config_filename) - datasets_dict = dataset.datasets_dict config = dataset.config - # set dropout + # Set dropout config.input_data.site.dropout_timedeltas_minutes = [-30] config.input_data.site.dropout_fraction = 1.0 - sliced_datasets_dict = slice_datasets_by_time( - datasets_dict=datasets_dict, t0=pd.Timestamp("2023-01-01 12:00"), config=config, + sliced = slice_datasets_by_time( + datasets_dict=dataset.datasets_dict, + t0=pd.Timestamp("2023-01-01 12:00"), + config=config, ) - site_dataset = sliced_datasets_dict["site"] + site_dataset = sliced["site"] - # for all 10 site ids the second element should nan due to dropout + # For all 10 site IDs the second time step should be NaN due to dropout assert np.all(np.isnan(site_dataset[1, :])) - # the last element which is after t0 should not be impacted by dropout + # The last time step (after t0) should not be impacted by dropout assert np.all(~np.isnan(site_dataset[-1, :]))