Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ classifiers = ["Programming Language :: Python :: 3"]
dependencies = [
"cryptography >= 42.0.7",
"fastapi >= 0.105.0",
"pvsite-datamodel >= 1.0.45",
"pvsite-datamodel@git+https://github.com/openclimatefix/pv-site-datamodel.git#egg=site-location",
"pyjwt >= 2.8.0",
"pyproj >= 3.3.0",
"pytz >= 2023.3",
Expand Down
37 changes: 20 additions & 17 deletions src/india_api/internal/inputs/indiadb/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
get_site_by_uuid,
)
from pvsite_datamodel.write.generation import insert_generation_values
from pvsite_datamodel.sqlmodels import SiteAssetType, ForecastValueSQL
from pvsite_datamodel.sqlmodels import LocationAssetType, ForecastValueSQL
from pvsite_datamodel.write.database import save_api_call_to_db
from pvsite_datamodel.write.user_and_site import edit_site
from pvsite_datamodel.pydantic_models import PVSiteEditMetadata
Expand Down Expand Up @@ -60,7 +60,7 @@ def save_api_call_to_db(self, url: str, email=""):
def get_predicted_power_production_for_location(
self,
location: str,
asset_type: SiteAssetType,
asset_type: LocationAssetType,
ml_model_name: str,
forecast_horizon: ForecastHorizon = ForecastHorizon.latest,
forecast_horizon_minutes: Optional[int] = None,
Expand Down Expand Up @@ -115,14 +115,14 @@ def get_predicted_power_production_for_location(
# read actual generations
values = get_latest_forecast_values_by_site(
session,
site_uuids=[site.site_uuid],
site_uuids=[site.location_uuid],
start_utc=start,
day_ahead_hours=day_ahead_hours,
day_ahead_timezone_delta_hours=day_ahead_timezone_delta_hours,
forecast_horizon_minutes=forecast_horizon_minutes,
model_name=ml_model_name,
)
forecast_values: [ForecastValueSQL] = values[site.site_uuid]
forecast_values: [ForecastValueSQL] = values[site.location_uuid]

# convert ForecastValueSQL to PredictedPower
values = [
Expand All @@ -145,7 +145,7 @@ def get_predicted_power_production_for_location(
def get_generation_for_location(
self,
location: str,
asset_type: SiteAssetType,
asset_type: LocationAssetType,
) -> [internal.PredictedPower]:
"""Gets the predicted power production for a location."""

Expand All @@ -162,7 +162,7 @@ def get_generation_for_location(

# read actual generations
values = get_pv_generation_by_sites(
session=session, site_uuids=[site.site_uuid], start_utc=start, end_utc=end
session=session, site_uuids=[site.location_uuid], start_utc=start, end_utc=end
)

# convert from GenerationSQL to ActualPower
Expand Down Expand Up @@ -200,7 +200,7 @@ def get_predicted_solar_power_production_for_location(

return self.get_predicted_power_production_for_location(
location=location,
asset_type=SiteAssetType.pv,
asset_type=LocationAssetType.pv,
forecast_horizon=forecast_horizon,
forecast_horizon_minutes=forecast_horizon_minutes,
smooth_flag=smooth_flag,
Expand Down Expand Up @@ -229,7 +229,7 @@ def get_predicted_wind_power_production_for_location(

return self.get_predicted_power_production_for_location(
location=location,
asset_type=SiteAssetType.wind,
asset_type=LocationAssetType.wind,
forecast_horizon=forecast_horizon,
forecast_horizon_minutes=forecast_horizon_minutes,
smooth_flag=smooth_flag,
Expand All @@ -241,14 +241,14 @@ def get_actual_solar_power_production_for_location(
) -> list[internal.PredictedPower]:
"""Gets the actual solar power production for a location."""

return self.get_generation_for_location(location=location, asset_type=SiteAssetType.pv)
return self.get_generation_for_location(location=location, asset_type=LocationAssetType.pv)

def get_actual_wind_power_production_for_location(
self, location: str
) -> list[internal.PredictedPower]:
"""Gets the actual wind power production for a location."""

return self.get_generation_for_location(location=location, asset_type=SiteAssetType.wind)
return self.get_generation_for_location(location=location, asset_type=LocationAssetType.wind)

def get_wind_regions(self) -> list[str]:
"""Gets the valid wind regions."""
Expand All @@ -269,8 +269,8 @@ def get_sites(self, email: str) -> list[internal.Site]:
sites = []
for site_sql in sites_sql:
site = internal.Site(
site_uuid=str(site_sql.site_uuid),
client_site_name=site_sql.client_site_name,
site_uuid=str(site_sql.location_uuid),
client_site_name=site_sql.client_location_name,
orientation=site_sql.orientation,
tilt=site_sql.tilt,
capacity_kw=site_sql.capacity_kw,
Expand All @@ -290,11 +290,14 @@ def put_site(
with self._get_session() as session:
user = get_user_by_email(session, email)
site = get_site_by_uuid(session, site_uuid)
check_user_has_access_to_site(session, email, site.site_uuid)
check_user_has_access_to_site(session, email, site.location_uuid)

site_info = PVSiteEditMetadata(
**site_properties.model_dump(exclude_unset=True, exclude_none=False)
)
site_dict = site_properties.model_dump(exclude_unset=True, exclude_none=False)
if "client_site_name" in site_dict:
site_dict["client_location_name"] = site_dict["client_site_name"]
site_dict.pop("client_site_name")

site_info = PVSiteEditMetadata(**site_dict)

site, _ = edit_site(
session=session,
Expand Down Expand Up @@ -434,7 +437,7 @@ def check_user_has_access_to_site(session: Session, email: str, site_uuid: str):
"""

user = get_user_by_email(session=session, email=email)
site_uuids = [str(site.site_uuid) for site in user.site_group.sites]
site_uuids = [str(site.location_uuid) for site in user.location_group.locations]
site_uuid = str(site_uuid)

if site_uuid not in site_uuids:
Expand Down
20 changes: 10 additions & 10 deletions src/india_api/internal/inputs/indiadb/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from datetime import datetime, timedelta

import pytest
from pvsite_datamodel.sqlmodels import Base, ForecastSQL, ForecastValueSQL, GenerationSQL, SiteSQL
from pvsite_datamodel.sqlmodels import Base, ForecastSQL, ForecastValueSQL, GenerationSQL, LocationSQL
from pvsite_datamodel.read.user import get_user_by_email
from pvsite_datamodel.read.model import get_or_create_model
from sqlalchemy import create_engine
Expand Down Expand Up @@ -58,31 +58,31 @@ def sites(db_session):

sites = []
# PV site
site = SiteSQL(
client_site_id=1,
site = LocationSQL(
client_location_id=1,
latitude=20.59,
longitude=78.96,
capacity_kw=4,
ml_id=1,
asset_type="pv",
country="india",
region="testID",
client_site_name="ruvnl_pv_testID1",
client_location_name="ruvnl_pv_testID1",
)
db_session.add(site)
sites.append(site)

# Wind site
site = SiteSQL(
client_site_id=2,
site = LocationSQL(
client_location_id=2,
latitude=20.59,
longitude=78.96,
capacity_kw=4,
ml_id=2,
asset_type="wind",
country="india",
region="testID",
client_site_name="ruvnl_wind_testID",
client_location_name="ruvnl_wind_testID",
)
db_session.add(site)
sites.append(site)
Expand All @@ -91,7 +91,7 @@ def sites(db_session):

# create user
user = get_user_by_email(session=db_session, email="test@test.com")
user.site_group.sites = sites
user.location_group.locations = sites

db_session.commit()

Expand All @@ -108,7 +108,7 @@ def generations(db_session, sites):
for site in sites:
for i in range(0, 10):
generation = GenerationSQL(
site_uuid=site.site_uuid,
location_uuid=site.location_uuid,
generation_power_kw=i,
start_utc=start_times[i],
end_utc=start_times[i] + timedelta(minutes=5),
Expand Down Expand Up @@ -158,7 +158,7 @@ def make_fake_forecast_values(db_session, sites, model_name):
for site in sites:
for timestamp in timestamps:
forecast: ForecastSQL = ForecastSQL(
site_uuid=site.site_uuid, forecast_version=forecast_version, timestamp_utc=timestamp
location_uuid=site.location_uuid, forecast_version=forecast_version, timestamp_utc=timestamp
)

db_session.add(forecast)
Expand Down
16 changes: 8 additions & 8 deletions src/india_api/internal/inputs/indiadb/test_indiadb.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,40 +88,40 @@ def test_get_put_site(self, client, sites) -> None:
sites_from_api = client.get_sites(email="test@test.com")
assert sites_from_api[0].client_site_name == "ruvnl_pv_testID1"
site = client.put_site(
site_uuid=sites[0].site_uuid,
site_uuid=sites[0].location_uuid,
site_properties=SiteProperties(client_site_name="test_zzz"),
email="test@test.com",
)
assert site.client_site_name == "test_zzz"
assert site.client_location_name == "test_zzz"
assert site.latitude is not None

def test_get_site_forecast(self, client, sites, forecast_values_site) -> None:
out = client.get_site_forecast(site_uuid=str(sites[0].site_uuid), email="test@test.com")
out = client.get_site_forecast(site_uuid=str(sites[0].location_uuid), email="test@test.com")
assert len(out) > 0

def test_get_site_forecast_no_forecast_values(self, client, sites) -> None:
out = client.get_site_forecast(site_uuid=sites[0].site_uuid, email="test@test.com")
out = client.get_site_forecast(site_uuid=sites[0].location_uuid, email="test@test.com")
assert len(out) == 0

def test_get_site_forecast_no_access(self, client, sites) -> None:
with pytest.raises(Exception):
_ = client.get_site_forecast(site_uuid=sites[0].site_uuid, email="test2@test.com")
_ = client.get_site_forecast(site_uuid=sites[0].location_uuid, email="test2@test.com")

def test_get_site_generation(self, client, sites, generations) -> None:
out = client.get_site_generation(site_uuid=str(sites[0].site_uuid), email="test@test.com")
out = client.get_site_generation(site_uuid=str(sites[0].location_uuid), email="test@test.com")
assert len(out) > 0

def test_post_site_generation(self, client, sites) -> None:
client.post_site_generation(
site_uuid=sites[0].site_uuid,
site_uuid=sites[0].location_uuid,
generation=[ActualPower(Time=1, PowerKW=1)],
email="test@test.com",
)

def test_post_site_generation_exceding_max_capacity(self, client, sites):
try:
client.post_site_generation(
site_uuid=sites[0].site_uuid,
site_uuid=sites[0].location_uuid,
generation=[ActualPower(Time=1, PowerKW=1000)],
email="test@test.com",
)
Expand Down
Loading