|
| 1 | +import logging |
| 2 | + |
| 3 | +import pandas as pd |
| 4 | +from fastapi import HTTPException |
| 5 | +import pytest |
| 6 | + |
| 7 | +from india_api.internal import PredictedPower, ActualPower, SiteProperties |
| 8 | + |
| 9 | +from pvsite_datamodel.sqlmodels import APIRequestSQL |
| 10 | + |
| 11 | +from .client import Client |
| 12 | +from .conftest import forecast_values |
| 13 | +from ...models import ForecastHorizon |
| 14 | +from ...service.csv import format_csv_and_created_time |
| 15 | + |
| 16 | +log = logging.getLogger(__name__) |
| 17 | + |
| 18 | +# TODO add list of test that are here |
| 19 | + |
| 20 | + |
| 21 | +@pytest.fixture() |
| 22 | +def client(engine, db_session): |
| 23 | + """Hooks Client into pytest db_session fixture""" |
| 24 | + client = Client(database_url=str(engine.url)) |
| 25 | + client.session = db_session |
| 26 | + |
| 27 | + return client |
| 28 | + |
| 29 | +# Skip for now |
| 30 | +@pytest.mark.skip(reason="Not finished yet") |
| 31 | +class TestCsvExport: |
| 32 | + def test_format_csv_and_created_time(self, client, forecast_values_wind) -> None: |
| 33 | + """Test the format_csv_and_created_time function.""" |
| 34 | + forecast_values_wind = client.get_predicted_wind_power_production_for_location( |
| 35 | + location="testID" |
| 36 | + ) |
| 37 | + assert forecast_values_wind is not None |
| 38 | + assert len(forecast_values_wind) > 0 |
| 39 | + assert isinstance(forecast_values_wind[0], PredictedPower) |
| 40 | + |
| 41 | + result = format_csv_and_created_time( |
| 42 | + forecast_values_wind, |
| 43 | + ForecastHorizon.latest, |
| 44 | + ) |
| 45 | + assert isinstance(result, tuple) |
| 46 | + assert isinstance(result[0], pd.DataFrame) |
| 47 | + assert isinstance(result[1], pd.Timestamp) |
| 48 | + logging.info(f"CSV created at: {result[1]}") |
| 49 | + logging.info(f"CSV content: {result[0].head()}") |
| 50 | + # Check the shape of the DataFrame |
| 51 | + # The shape should match the number of forecast values |
| 52 | + # and the number of columns in the DataFrame |
| 53 | + # The DataFrame should have 3 columns: Date [IST], Time, PowerMW |
| 54 | + assert result[0].shape[1] == 3 |
| 55 | + # Check the first row of the DataFrame |
| 56 | + # The date of the first row should be the nearest rounded 15min from now |
| 57 | + rounded_15_min = pd.Timestamp.now(tz="Asia/Kolkata").round("15min") |
| 58 | + assert result[0].iloc[0]["Time"] == rounded_15_min.strftime("%H:%M") |
| 59 | + # Check the number of rows in the DataFrame |
| 60 | + # For the latest forecast, it should be the number of |
| 61 | + # forecast values after now |
| 62 | + forecast_values_from_now = [ |
| 63 | + value for value in forecast_values_wind if value.Time >= rounded_15_min |
| 64 | + ] |
| 65 | + assert result[0].shape[0] == len(forecast_values_from_now) |
| 66 | + # Check the column names |
| 67 | + assert list(result[0].columns) == ["Date [IST]", "Time", "PowerMW"] |
| 68 | + # Check the data types of the columns |
| 69 | + assert result[0]["Date [IST]"].dtype == "datetime64[ns, Asia/Kolkata]" |
| 70 | + assert result[0]["Time"].dtype == "object" |
| 71 | + assert result[0]["PowerMW"].dtype == "float64" |
0 commit comments