Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions src/india_api/internal/inputs/indiadb/test_csv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import logging

import pandas as pd
from fastapi import HTTPException
import pytest

from india_api.internal import PredictedPower, ActualPower, SiteProperties

from pvsite_datamodel.sqlmodels import APIRequestSQL

from .client import Client
from .conftest import forecast_values
from ...models import ForecastHorizon
from ...service.csv import format_csv_and_created_time

log = logging.getLogger(__name__)

# TODO add list of test that are here


@pytest.fixture()
def client(engine, db_session):
"""Hooks Client into pytest db_session fixture"""
client = Client(database_url=str(engine.url))
client.session = db_session

return client

# Skip for now
@pytest.mark.skip(reason="Not finished yet")
class TestCsvExport:
def test_format_csv_and_created_time(self, client, forecast_values_wind) -> None:
"""Test the format_csv_and_created_time function."""
forecast_values_wind = client.get_predicted_wind_power_production_for_location(
location="testID"
)
assert forecast_values_wind is not None
assert len(forecast_values_wind) > 0
assert isinstance(forecast_values_wind[0], PredictedPower)

result = format_csv_and_created_time(
forecast_values_wind,
ForecastHorizon.latest,
)
assert isinstance(result, tuple)
assert isinstance(result[0], pd.DataFrame)
assert isinstance(result[1], pd.Timestamp)
logging.info(f"CSV created at: {result[1]}")
logging.info(f"CSV content: {result[0].head()}")
# Check the shape of the DataFrame
# The shape should match the number of forecast values
# and the number of columns in the DataFrame
# The DataFrame should have 3 columns: Date [IST], Time, PowerMW
assert result[0].shape[1] == 3
# Check the first row of the DataFrame
# The date of the first row should be the nearest rounded 15min from now
rounded_15_min = pd.Timestamp.now(tz="Asia/Kolkata").round("15min")
assert result[0].iloc[0]["Time"] == rounded_15_min.strftime("%H:%M")
# Check the number of rows in the DataFrame
# For the latest forecast, it should be the number of
# forecast values after now
forecast_values_from_now = [
value for value in forecast_values_wind if value.Time >= rounded_15_min
]
assert result[0].shape[0] == len(forecast_values_from_now)
# Check the column names
assert list(result[0].columns) == ["Date [IST]", "Time", "PowerMW"]
# Check the data types of the columns
assert result[0]["Date [IST]"].dtype == "datetime64[ns, Asia/Kolkata]"
assert result[0]["Time"].dtype == "object"
assert result[0]["PowerMW"].dtype == "float64"
19 changes: 12 additions & 7 deletions src/india_api/internal/service/csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
from datetime import datetime

from india_api.internal import PredictedPower
from india_api.internal.models import ForecastHorizon


def format_csv_and_created_time(values: list[PredictedPower]) -> (pd.DataFrame, datetime):
def format_csv_and_created_time(values: list[PredictedPower], forecast_horizon: ForecastHorizon) -> (pd.DataFrame, datetime):
"""
Format the predicted power values into a pandas dataframe ready for CSV export.

Expand All @@ -26,16 +27,20 @@ def format_csv_and_created_time(values: list[PredictedPower]) -> (pd.DataFrame,
df["Date [IST]"] = df["Time"].dt.date
# create start and end time column and only show HH:MM
df["Start Time [IST]"] = df["Time"].dt.strftime("%H:%M")
df["End Time [IST]"] = (df["Time"] + pd.to_timedelta("15T")).dt.strftime("%H:%M")
df["End Time [IST]"] = (df["Time"] + pd.to_timedelta("15min")).dt.strftime("%H:%M")

now_ist = pd.Timestamp.now(tz="Asia/Kolkata")
if forecast_horizon == ForecastHorizon.day_ahead:
# only get tomorrow's results, for IST time.
tomorrow = now_ist + pd.Timedelta(days=1)
df = df[df["Date [IST]"] == tomorrow.date()]
elif forecast_horizon == ForecastHorizon.latest:
# only get results from now onwards, for IST time.
df = df[df["Time"] >= now_ist]

# combine start and end times
df["Time"] = df["Start Time [IST]"].astype(str) + " - " + df["End Time [IST]"].astype(str)

# only get tomorrows results. This is for IST time.
now_ist = pd.Timestamp.now(tz="Asia/Kolkata")
tomorrow = now_ist + pd.Timedelta(days=1)
df = df[df["Date [IST]"] == tomorrow.date()]

# get the max created time
created_time = df["CreatedTime"].max()

Expand Down
38 changes: 32 additions & 6 deletions src/india_api/internal/service/regions.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,35 +186,61 @@ def get_forecast_timeseries_route(
response_class=FileResponse,
include_in_schema=False,
)
def get_forecast_da_csv(
def get_forecast_csv(
source: ValidSourceDependency,
region: str,
db: DBClientDependency,
auth: dict = Depends(auth),
forecast_horizon: Optional[ForecastHorizon] = ForecastHorizon.latest,
forecast_horizon_minutes: Optional[int] = 0,
):
"""
Route to get the day ahead forecast as a CSV file.
By default, the CSV file will be for the latest forecast, from now forwards.
The forecast_horizon can be set to 'latest', 'day_ahead' or 'horizon'.
- latest: The latest forecast, from now forwards.
- day_ahead: The forecast for the next day, from midnight.
- horizon: The forecast for the next horizon_horizon_minutes minutes, from default forecast history start.
The forecast_horizon_minutes is only used if the forecast_horizon is set to 'horizon'.
"""

forcasts: GetForecastGenerationResponse = get_forecast_timeseries_route(
if forecast_horizon is not None:
if forecast_horizon not in [ForecastHorizon.latest, ForecastHorizon.day_ahead, ForecastHorizon.horizon]:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid forecast_horizon {forecast_horizon}. Must be 'latest', 'day_ahead', or 'horizon.",
)

forecasts: GetForecastGenerationResponse = get_forecast_timeseries_route(
source=source,
region=region,
db=db,
auth=auth,
forecast_horizon=ForecastHorizon.day_ahead,
forecast_horizon=forecast_horizon,
forecast_horizon_minutes=forecast_horizon_minutes,
smooth_flag=False,
)

# format to dataframe
df, created_time = format_csv_and_created_time(forcasts.values)
df, created_time = format_csv_and_created_time(forecasts.values, forecast_horizon=forecast_horizon)

# make file format
now_ist = pd.Timestamp.now(tz="Asia/Kolkata")
tomorrow_ist = df["Date [IST]"].iloc[0]
csv_file_path = f"{region}_{source}_da_{tomorrow_ist}.csv"
match forecast_horizon:
case ForecastHorizon.latest:
forecast_type = "intraday"
case ForecastHorizon.day_ahead:
forecast_type = "da"
case ForecastHorizon.horizon:
forecast_type = f"horizon_{forecast_horizon_minutes}"
case _:
# this shouldn't happen but will handle if class is changed
Copy link
Contributor

@peterdudfield peterdudfield Apr 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

raise 404 instead? or other appropriate error

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should raise a 400 at the top of this route handler if it's not one of the defined options – this was more to keep the inferred Python types happy because I don't think they can quite work out this should always hit one of these cases? 😅

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if we remove above - we should be using ForecastHorizon.horizon and this should raise a 400 then

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have stripped out the horizon option and doubled up on the 400 exception; feels like duplication, but can't see how else the inferred types could want this if we still want to validate at the start and return early if the param isn't valid (which we do)

forecast_type = "default"
csv_file_path = f"{region}_{source}_{forecast_type}_{tomorrow_ist}.csv"

description = (
f"Forecast for {region} for {source} for {tomorrow_ist}. "
f"Forecast for {region} for {source}, {forecast_type}, for {tomorrow_ist}. "
f"The Forecast was created at {created_time} and downloaded at {now_ist}"
)

Expand Down
Loading