Skip to content

Commit 5d706c7

Browse files
committed
fix tests
1 parent 3a5fed6 commit 5d706c7

File tree

3 files changed

+27
-5
lines changed

3 files changed

+27
-5
lines changed

src/india_api/internal/inputs/indiadb/client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ def get_predicted_solar_power_production_for_location(
195195
forecast_horizon=forecast_horizon,
196196
forecast_horizon_minutes=forecast_horizon_minutes,
197197
smooth_flag=smooth_flag,
198-
model_name=model_name,
198+
ml_model_name=model_name,
199199
)
200200

201201
def get_predicted_wind_power_production_for_location(
@@ -224,7 +224,7 @@ def get_predicted_wind_power_production_for_location(
224224
forecast_horizon=forecast_horizon,
225225
forecast_horizon_minutes=forecast_horizon_minutes,
226226
smooth_flag=smooth_flag,
227-
model_name=model_name,
227+
ml_model_name=model_name,
228228
)
229229

230230
def get_actual_solar_power_production_for_location(

src/india_api/internal/inputs/indiadb/conftest.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import pytest
77
from pvsite_datamodel.sqlmodels import Base, ForecastSQL, ForecastValueSQL, GenerationSQL, SiteSQL
88
from pvsite_datamodel.read.user import get_user_by_email
9+
from pvsite_datamodel.read.model import get_or_create_model
910
from sqlalchemy import create_engine
1011
from sqlalchemy.orm import Session
1112
from testcontainers.postgres import PostgresContainer
@@ -123,6 +124,23 @@ def generations(db_session, sites):
123124
@pytest.fixture()
124125
def forecast_values(db_session, sites):
125126
"""Create some fake forecast values"""
127+
128+
make_fake_forecast_values(db_session, sites, "pvnet_india")
129+
130+
@pytest.fixture()
131+
def forecast_values_wind(db_session, sites):
132+
"""Create some fake forecast values"""
133+
134+
make_fake_forecast_values(db_session, sites, "windnet_india")
135+
136+
@pytest.fixture()
137+
def forecast_values_site(db_session, sites):
138+
"""Create some fake forecast values"""
139+
140+
make_fake_forecast_values(db_session, sites, "pvnet_ad_sites")
141+
142+
143+
def make_fake_forecast_values(db_session, sites, model_name):
126144
forecast_values = []
127145
forecast_version: str = "0.0.0"
128146

@@ -134,6 +152,9 @@ def forecast_values(db_session, sites):
134152
# To make things trickier we make a second forecast at the same for one of the timestamps.
135153
timestamps = timestamps + timestamps[-1:]
136154

155+
# get model
156+
ml_model = get_or_create_model(db_session, model_name)
157+
137158
for site in sites:
138159
for timestamp in timestamps:
139160
forecast: ForecastSQL = ForecastSQL(
@@ -154,6 +175,7 @@ def forecast_values(db_session, sites):
154175
end_utc=timestamp + timedelta(minutes=horizon + duration),
155176
horizon_minutes=horizon,
156177
)
178+
forecast_value.ml_model = ml_model
157179

158180
forecast_values.append(forecast_value)
159181

src/india_api/internal/inputs/indiadb/test_indiadb.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def client(engine, db_session):
2323

2424
class TestIndiaDBClient:
2525
def test_get_predicted_wind_power_production_for_location(
26-
self, client, forecast_values
26+
self, client, forecast_values_wind
2727
) -> None:
2828
locID = "testID"
2929
result = client.get_predicted_wind_power_production_for_location(locID)
@@ -33,7 +33,7 @@ def test_get_predicted_wind_power_production_for_location(
3333
assert isinstance(record, PredictedPower)
3434

3535
def test_get_predicted_wind_power_production_for_location_raise_error(
36-
self, client, forecast_values
36+
self, client, forecast_values_wind
3737
) -> None:
3838

3939
with pytest.raises(Exception):
@@ -83,7 +83,7 @@ def test_get_sites_no_sites(self, client, sites) -> None:
8383
sites_from_api = client.get_sites(email="test2@test.com")
8484
assert len(sites_from_api) == 0
8585

86-
def test_get_site_forecast(self, client, sites, forecast_values) -> None:
86+
def test_get_site_forecast(self, client, sites, forecast_values_site) -> None:
8787
out = client.get_site_forecast(site_uuid=str(sites[0].site_uuid), email="test@test.com")
8888
assert len(out) > 0
8989

0 commit comments

Comments
 (0)