Skip to content

Commit 9deda26

Browse files
authored
Merge pull request #101 from amazeeio/dev
Support dedicated regions and fix error messages
2 parents 92cf6ff + 387b665 commit 9deda26

File tree

9 files changed

+1631
-82
lines changed

9 files changed

+1631
-82
lines changed

app/api/regions.py

Lines changed: 248 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,100 @@
11
from fastapi import APIRouter, Depends, HTTPException, status
22
from sqlalchemy.orm import Session
33
from typing import List
4+
import requests
5+
import asyncpg
6+
import logging
47

58
from app.db.database import get_db
69
from app.api.auth import get_current_user_from_auth
7-
from app.schemas.models import Region, RegionCreate, RegionResponse, User, RegionUpdate
8-
from app.db.models import DBRegion, DBPrivateAIKey
10+
from app.schemas.models import Region, RegionCreate, RegionResponse, User, RegionUpdate, TeamSummary
11+
from app.db.models import DBRegion, DBPrivateAIKey, DBTeamRegion, DBTeam
12+
from app.core.security import check_system_admin
13+
14+
logger = logging.getLogger(__name__)
915

1016
router = APIRouter(
1117
tags=["regions"]
1218
)
1319

14-
@router.post("", response_model=Region)
15-
@router.post("/", response_model=Region)
20+
async def validate_litellm_endpoint(api_url: str, api_key: str) -> bool:
21+
"""
22+
Validate LiteLLM endpoint by making a test request to the health endpoint.
23+
24+
Args:
25+
api_url: The LiteLLM API URL
26+
api_key: The LiteLLM API key
27+
28+
Returns:
29+
bool: True if validation succeeds, raises HTTPException if it fails
30+
"""
31+
try:
32+
# Test the LiteLLM health endpoint
33+
response = requests.get(
34+
f"{api_url}/health/liveliness",
35+
headers={"Authorization": f"Bearer {api_key}"},
36+
timeout=10
37+
)
38+
response.raise_for_status()
39+
logger.info(f"LiteLLM endpoint validation successful for {api_url}")
40+
return True
41+
except requests.exceptions.RequestException as e:
42+
error_msg = str(e)
43+
if hasattr(e, 'response') and e.response is not None:
44+
try:
45+
error_details = e.response.json()
46+
error_msg = f"Status {e.response.status_code}: {error_details}"
47+
except ValueError:
48+
error_msg = f"Status {e.response.status_code}: {e.response.text}"
49+
logger.error(f"LiteLLM endpoint validation failed for {api_url}: {error_msg}")
50+
raise HTTPException(
51+
status_code=status.HTTP_400_BAD_REQUEST,
52+
detail=f"LiteLLM endpoint validation failed: {error_msg}"
53+
)
54+
55+
async def validate_database_connection(host: str, port: int, user: str, password: str) -> bool:
56+
"""
57+
Validate database connection by attempting to connect to PostgreSQL.
58+
59+
Args:
60+
host: Database host
61+
port: Database port
62+
user: Database admin user
63+
password: Database admin password
64+
65+
Returns:
66+
bool: True if validation succeeds, raises HTTPException if it fails
67+
"""
68+
try:
69+
# Attempt to connect to the database
70+
conn = await asyncpg.connect(
71+
host=host,
72+
port=port,
73+
user=user,
74+
password=password
75+
)
76+
await conn.close()
77+
logger.info(f"Database connection validation successful for {host}:{port}")
78+
return True
79+
except asyncpg.exceptions.PostgresError as e:
80+
logger.error(f"Database connection validation failed for {host}:{port}: {str(e)}")
81+
raise HTTPException(
82+
status_code=status.HTTP_400_BAD_REQUEST,
83+
detail=f"Database connection validation failed: {str(e)}"
84+
)
85+
except Exception as e:
86+
logger.error(f"Unexpected error during database validation for {host}:{port}: {str(e)}")
87+
raise HTTPException(
88+
status_code=status.HTTP_400_BAD_REQUEST,
89+
detail=f"Database connection validation failed: {str(e)}"
90+
)
91+
92+
@router.post("", response_model=Region, dependencies=[Depends(check_system_admin)])
93+
@router.post("/", response_model=Region, dependencies=[Depends(check_system_admin)])
1694
async def create_region(
1795
region: RegionCreate,
18-
current_user: User = Depends(get_current_user_from_auth),
1996
db: Session = Depends(get_db)
2097
):
21-
if not current_user.is_admin:
22-
raise HTTPException(
23-
status_code=status.HTTP_403_FORBIDDEN,
24-
detail="Only administrators can create regions"
25-
)
26-
2798
# Check if region with this name already exists
2899
existing_region = db.query(DBRegion).filter(DBRegion.name == region.name).first()
29100
if existing_region:
@@ -32,6 +103,17 @@ async def create_region(
32103
detail=f"A region with the name '{region.name}' already exists"
33104
)
34105

106+
# Validate LiteLLM endpoint
107+
await validate_litellm_endpoint(region.litellm_api_url, region.litellm_api_key)
108+
109+
# Validate database connection
110+
await validate_database_connection(
111+
region.postgres_host,
112+
region.postgres_port,
113+
region.postgres_admin_user,
114+
region.postgres_admin_password
115+
)
116+
35117
db_region = DBRegion(**region.model_dump())
36118
db.add(db_region)
37119
try:
@@ -51,24 +133,40 @@ async def list_regions(
51133
current_user: User = Depends(get_current_user_from_auth),
52134
db: Session = Depends(get_db)
53135
):
54-
return db.query(DBRegion).filter(DBRegion.is_active == True).all()
136+
# System admin users can see all regions
137+
if current_user.is_admin:
138+
return db.query(DBRegion).filter(DBRegion.is_active == True).all()
139+
140+
# Regular users can only see non-dedicated regions
141+
if not current_user.team_id:
142+
return db.query(DBRegion).filter(
143+
DBRegion.is_active == True,
144+
DBRegion.is_dedicated == False
145+
).all()
55146

56-
@router.get("/admin", response_model=List[Region])
147+
# Team members can see non-dedicated regions plus their team's dedicated regions
148+
team_dedicated_regions = db.query(DBRegion).join(DBTeamRegion).filter(
149+
DBRegion.is_active == True,
150+
DBRegion.is_dedicated == True,
151+
DBTeamRegion.team_id == current_user.team_id
152+
).all()
153+
154+
non_dedicated_regions = db.query(DBRegion).filter(
155+
DBRegion.is_active == True,
156+
DBRegion.is_dedicated == False
157+
).all()
158+
159+
return non_dedicated_regions + team_dedicated_regions
160+
161+
@router.get("/admin", response_model=List[Region], dependencies=[Depends(check_system_admin)])
57162
async def list_admin_regions(
58-
current_user: User = Depends(get_current_user_from_auth),
59163
db: Session = Depends(get_db)
60164
):
61-
if not current_user.is_admin:
62-
raise HTTPException(
63-
status_code=status.HTTP_403_FORBIDDEN,
64-
detail="Only administrators can access this endpoint"
65-
)
66165
return db.query(DBRegion).all()
67166

68-
@router.get("/{region_id}", response_model=RegionResponse)
167+
@router.get("/{region_id}", response_model=RegionResponse, dependencies=[Depends(check_system_admin)])
69168
async def get_region(
70169
region_id: int,
71-
current_user: User = Depends(get_current_user_from_auth),
72170
db: Session = Depends(get_db)
73171
):
74172
region = db.query(DBRegion).filter(DBRegion.id == region_id).first()
@@ -79,31 +177,24 @@ async def get_region(
79177
)
80178
return region
81179

82-
@router.delete("/{region_id}")
180+
@router.delete("/{region_id}", dependencies=[Depends(check_system_admin)])
83181
async def delete_region(
84182
region_id: int,
85-
current_user: User = Depends(get_current_user_from_auth),
86183
db: Session = Depends(get_db)
87184
):
88-
if not current_user.is_admin:
89-
raise HTTPException(
90-
status_code=status.HTTP_403_FORBIDDEN,
91-
detail="Only administrators can delete regions"
92-
)
93-
94185
region = db.query(DBRegion).filter(DBRegion.id == region_id).first()
95186
if not region:
96187
raise HTTPException(
97188
status_code=status.HTTP_404_NOT_FOUND,
98189
detail="Region not found"
99190
)
100191

101-
# Check if there are any databases using this region
102-
existing_databases = db.query(DBPrivateAIKey).filter(DBPrivateAIKey.region_id == region_id).count()
103-
if existing_databases > 0:
192+
# Check if there are any keys using this region
193+
existing_keys = db.query(DBPrivateAIKey).filter(DBPrivateAIKey.region_id == region_id).count()
194+
if existing_keys > 0:
104195
raise HTTPException(
105196
status_code=status.HTTP_400_BAD_REQUEST,
106-
detail=f"Cannot delete region: {existing_databases} database(s) are currently using this region. Please delete these databases first."
197+
detail=f"Cannot delete region: {existing_keys} keys(s) are currently using this region. Please delete these keys first."
107198
)
108199

109200
# Instead of deleting, mark as inactive
@@ -118,18 +209,12 @@ async def delete_region(
118209
)
119210
return {"message": "Region deleted successfully"}
120211

121-
@router.put("/{region_id}", response_model=Region)
212+
@router.put("/{region_id}", response_model=Region, dependencies=[Depends(check_system_admin)])
122213
async def update_region(
123214
region_id: int,
124215
region: RegionUpdate,
125-
current_user: User = Depends(get_current_user_from_auth),
126216
db: Session = Depends(get_db)
127217
):
128-
if not current_user.is_admin:
129-
raise HTTPException(
130-
status_code=status.HTTP_403_FORBIDDEN,
131-
detail="Only administrators can update regions"
132-
)
133218

134219
db_region = db.query(DBRegion).filter(DBRegion.id == region_id).first()
135220
if not db_region:
@@ -164,4 +249,126 @@ async def update_region(
164249
status_code=status.HTTP_400_BAD_REQUEST,
165250
detail=f"Failed to update region: {str(e)}"
166251
)
167-
return db_region
252+
return db_region
253+
254+
@router.post("/{region_id}/teams/{team_id}", dependencies=[Depends(check_system_admin)])
255+
async def associate_team_with_region(
256+
region_id: int,
257+
team_id: int,
258+
db: Session = Depends(get_db)
259+
):
260+
"""Associate a team with a dedicated region. Only system admins can do this."""
261+
262+
# Check if region exists and is dedicated
263+
region = db.query(DBRegion).filter(DBRegion.id == region_id).first()
264+
if not region:
265+
raise HTTPException(
266+
status_code=status.HTTP_404_NOT_FOUND,
267+
detail="Region not found"
268+
)
269+
270+
if not region.is_dedicated:
271+
raise HTTPException(
272+
status_code=status.HTTP_400_BAD_REQUEST,
273+
detail="Can only associate teams with dedicated regions"
274+
)
275+
276+
# Check if team exists
277+
team = db.query(DBTeam).filter(DBTeam.id == team_id).first()
278+
if not team:
279+
raise HTTPException(
280+
status_code=status.HTTP_404_NOT_FOUND,
281+
detail="Team not found"
282+
)
283+
284+
# Check if association already exists
285+
existing_association = db.query(DBTeamRegion).filter(
286+
DBTeamRegion.team_id == team_id,
287+
DBTeamRegion.region_id == region_id
288+
).first()
289+
290+
if existing_association:
291+
raise HTTPException(
292+
status_code=status.HTTP_400_BAD_REQUEST,
293+
detail="Team is already associated with this region"
294+
)
295+
296+
# Create the association
297+
team_region = DBTeamRegion(
298+
team_id=team_id,
299+
region_id=region_id
300+
)
301+
db.add(team_region)
302+
303+
try:
304+
db.commit()
305+
except Exception as e:
306+
db.rollback()
307+
raise HTTPException(
308+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
309+
detail=f"Failed to associate team with region: {str(e)}"
310+
)
311+
312+
return {"message": "Team associated with region successfully"}
313+
314+
@router.delete("/{region_id}/teams/{team_id}", dependencies=[Depends(check_system_admin)])
315+
async def disassociate_team_from_region(
316+
region_id: int,
317+
team_id: int,
318+
db: Session = Depends(get_db)
319+
):
320+
"""Disassociate a team from a dedicated region. Only system admins can do this."""
321+
322+
# Check if association exists
323+
association = db.query(DBTeamRegion).filter(
324+
DBTeamRegion.team_id == team_id,
325+
DBTeamRegion.region_id == region_id
326+
).first()
327+
328+
if not association:
329+
raise HTTPException(
330+
status_code=status.HTTP_404_NOT_FOUND,
331+
detail="Team-region association not found"
332+
)
333+
334+
# Remove the association
335+
db.delete(association)
336+
337+
try:
338+
db.commit()
339+
except Exception as e:
340+
db.rollback()
341+
raise HTTPException(
342+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
343+
detail=f"Failed to disassociate team from region: {str(e)}"
344+
)
345+
346+
return {"message": "Team disassociated from region successfully"}
347+
348+
@router.get("/{region_id}/teams", response_model=List[TeamSummary], dependencies=[Depends(check_system_admin)])
349+
async def list_teams_for_region(
350+
region_id: int,
351+
db: Session = Depends(get_db)
352+
):
353+
"""List teams associated with a dedicated region. Only system admins can do this."""
354+
355+
# Check if region exists and is dedicated
356+
region = db.query(DBRegion).filter(DBRegion.id == region_id).first()
357+
if not region:
358+
raise HTTPException(
359+
status_code=status.HTTP_404_NOT_FOUND,
360+
detail="Region not found"
361+
)
362+
363+
if not region.is_dedicated:
364+
raise HTTPException(
365+
status_code=status.HTTP_400_BAD_REQUEST,
366+
detail="Can only list teams for dedicated regions"
367+
)
368+
369+
# Get associated teams
370+
teams = db.query(DBTeam).join(DBTeamRegion).filter(
371+
DBTeamRegion.region_id == region_id
372+
).all()
373+
374+
return teams

app/core/resource_limits.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,19 +100,19 @@ def check_key_limits(db: Session, team_id: int, owner_id: Optional[int] = None)
100100
if result.current_team_keys >= result.max_total_keys:
101101
raise HTTPException(
102102
status_code=status.HTTP_402_PAYMENT_REQUIRED,
103-
detail=f"Team has reached the maximum LLM token limit of {result.max_total_keys} tokens"
103+
detail=f"Team has reached the maximum LLM key limit of {result.max_total_keys} keys"
104104
)
105105

106106
if owner_id is not None and result.current_user_keys >= result.max_keys_per_user:
107107
raise HTTPException(
108108
status_code=status.HTTP_402_PAYMENT_REQUIRED,
109-
detail=f"User has reached the maximum LLM token limit of {result.max_keys_per_user} tokens"
109+
detail=f"User has reached the maximum LLM key limit of {result.max_keys_per_user} keys"
110110
)
111111

112112
if owner_id is None and result.current_service_keys >= result.max_service_keys:
113113
raise HTTPException(
114114
status_code=status.HTTP_402_PAYMENT_REQUIRED,
115-
detail=f"Team has reached the maximum service LLM token limit of {result.max_service_keys} tokens"
115+
detail=f"Team has reached the maximum service LLM key limit of {result.max_service_keys} keys"
116116
)
117117

118118
def check_vector_db_limits(db: Session, team_id: int) -> None:

0 commit comments

Comments
 (0)