Skip to content

Commit 7644c47

Browse files
committed
Enhance resource limits validation and optimize queries
- Refactored the `check_team_user_limit`, `check_key_limits`, and `check_vector_db_limits` functions to consolidate database queries, improving performance and reducing redundancy. - Updated the logic to retrieve current counts and maximum limits in a single query, enhancing efficiency. - Enhanced error handling to provide clearer messages when limits are exceeded. - Added new tests to validate the behavior of resource limits under various scenarios, ensuring comprehensive coverage for the updated functionality.
1 parent 3ba1d8e commit 7644c47

File tree

2 files changed

+276
-120
lines changed

2 files changed

+276
-120
lines changed

app/core/resource_limits.py

Lines changed: 121 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from sqlalchemy.orm import Session
2-
from app.db.models import DBTeam, DBUser, DBPrivateAIKey
2+
from sqlalchemy import func, and_, or_
3+
from app.db.models import DBTeam, DBUser, DBPrivateAIKey, DBTeamProduct, DBProduct
34
from fastapi import HTTPException, status
45
from typing import Optional
56
from datetime import datetime, UTC
@@ -26,25 +27,27 @@ def check_team_user_limit(db: Session, team_id: int) -> None:
2627
db: Database session
2728
team_id: ID of the team to check
2829
"""
29-
# Get current user count for the team
30-
current_user_count = db.query(DBUser).filter(DBUser.team_id == team_id).count()
31-
32-
# Get all active products for the team
33-
team = db.query(DBTeam).filter(DBTeam.id == team_id).first()
34-
if not team:
30+
# Get current user count and max allowed users in a single query
31+
result = db.query(
32+
func.count(DBUser.id).label('current_user_count'),
33+
func.coalesce(func.max(DBProduct.user_count), DEFAULT_USER_COUNT).label('max_users')
34+
).select_from(DBUser).filter(
35+
DBUser.team_id == team_id
36+
).outerjoin(
37+
DBTeamProduct,
38+
DBTeamProduct.team_id == team_id
39+
).outerjoin(
40+
DBProduct,
41+
DBProduct.id == DBTeamProduct.product_id
42+
).first()
43+
44+
if not result:
3545
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Team not found")
3646

37-
# Find the maximum user count allowed across all active products
38-
max_user_count = max(
39-
(product.user_count for team_product in team.active_products
40-
for product in [team_product.product] if product.user_count),
41-
default=DEFAULT_USER_COUNT # Default to 2 if no products have user_count set
42-
)
43-
44-
if current_user_count >= max_user_count:
47+
if result.current_user_count >= result.max_users:
4548
raise HTTPException(
4649
status_code=status.HTTP_402_PAYMENT_REQUIRED,
47-
detail=f"Team has reached the maximum user limit of {max_user_count} users"
50+
detail=f"Team has reached the maximum user limit of {result.max_users} users"
4851
)
4952

5053
def check_key_limits(db: Session, team_id: int, owner_id: Optional[int] = None) -> None:
@@ -57,70 +60,60 @@ def check_key_limits(db: Session, team_id: int, owner_id: Optional[int] = None)
5760
team_id: ID of the team to check
5861
owner_id: Optional ID of the user who will own the key
5962
"""
60-
# Get the team and its active products
61-
team = db.query(DBTeam).filter(DBTeam.id == team_id).first()
62-
if not team:
63+
# Get all limits and current counts in a single query
64+
result = db.query(
65+
func.coalesce(func.max(DBProduct.total_key_count), DEFAULT_TOTAL_KEYS).label('max_total_keys'),
66+
func.coalesce(func.max(DBProduct.keys_per_user), DEFAULT_KEYS_PER_USER).label('max_keys_per_user'),
67+
func.coalesce(func.max(DBProduct.service_key_count), DEFAULT_SERVICE_KEYS).label('max_service_keys'),
68+
func.count(DBPrivateAIKey.id).filter(
69+
DBPrivateAIKey.litellm_token.isnot(None)
70+
).label('current_team_keys'),
71+
func.count(DBPrivateAIKey.id).filter(
72+
DBPrivateAIKey.owner_id == owner_id,
73+
DBPrivateAIKey.litellm_token.isnot(None)
74+
).label('current_user_keys') if owner_id else None,
75+
func.count(DBPrivateAIKey.id).filter(
76+
DBPrivateAIKey.owner_id.is_(None),
77+
DBPrivateAIKey.litellm_token.isnot(None)
78+
).label('current_service_keys')
79+
).select_from(DBTeam).filter( # Have to use Teams table as the base because not every team has a product
80+
DBTeam.id == team_id
81+
).outerjoin(
82+
DBTeamProduct,
83+
DBTeamProduct.team_id == DBTeam.id
84+
).outerjoin(
85+
DBProduct,
86+
DBProduct.id == DBTeamProduct.product_id
87+
).outerjoin(
88+
DBPrivateAIKey,
89+
or_(
90+
DBPrivateAIKey.team_id == DBTeam.id,
91+
DBPrivateAIKey.owner_id.in_(
92+
db.query(DBUser.id).filter(DBUser.team_id == DBTeam.id)
93+
)
94+
)
95+
).first()
96+
97+
if not result:
6398
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Team not found")
6499

65-
# Find the maximum limits across all active products, using defaults if no products
66-
max_total_keys = max(
67-
(product.total_key_count for team_product in team.active_products
68-
for product in [team_product.product] if product.total_key_count),
69-
default=DEFAULT_TOTAL_KEYS # Default to 2 if no products have total_key_count set
70-
)
71-
max_keys_per_user = max(
72-
(product.keys_per_user for team_product in team.active_products
73-
for product in [team_product.product] if product.keys_per_user),
74-
default=DEFAULT_KEYS_PER_USER # Default to 1 if no products have keys_per_user set
75-
)
76-
max_service_keys = max(
77-
(product.service_key_count for team_product in team.active_products
78-
for product in [team_product.product] if product.service_key_count),
79-
default=DEFAULT_SERVICE_KEYS # Default to 1 if no products have service_key_count set
80-
)
81-
82-
# Get all users in the team
83-
team_users = db.query(DBUser).filter(DBUser.team_id == team_id).all()
84-
user_ids = [user.id for user in team_users]
85-
86-
# Check total team LLM tokens (both team-owned and user-owned)
87-
current_team_tokens = db.query(DBPrivateAIKey).filter(
88-
(
89-
(DBPrivateAIKey.team_id == team_id) | # Team-owned tokens
90-
(DBPrivateAIKey.owner_id.in_(user_ids)) # User-owned tokens
91-
),
92-
DBPrivateAIKey.litellm_token.isnot(None) # Only count LLM tokens
93-
).count()
94-
if current_team_tokens >= max_total_keys:
100+
if result.current_team_keys >= result.max_total_keys:
95101
raise HTTPException(
96102
status_code=status.HTTP_402_PAYMENT_REQUIRED,
97-
detail=f"Team has reached the maximum LLM token limit of {max_total_keys} tokens"
103+
detail=f"Team has reached the maximum LLM token limit of {result.max_total_keys} tokens"
98104
)
99105

100-
# Check user LLM tokens if owner_id is provided
101-
if owner_id is not None:
102-
current_user_tokens = db.query(DBPrivateAIKey).filter(
103-
DBPrivateAIKey.owner_id == owner_id,
104-
DBPrivateAIKey.litellm_token.isnot(None) # Only count LLM tokens
105-
).count()
106-
if current_user_tokens >= max_keys_per_user:
107-
raise HTTPException(
108-
status_code=status.HTTP_402_PAYMENT_REQUIRED,
109-
detail=f"User has reached the maximum LLM token limit of {max_keys_per_user} tokens"
110-
)
106+
if owner_id is not None and result.current_user_keys >= result.max_keys_per_user:
107+
raise HTTPException(
108+
status_code=status.HTTP_402_PAYMENT_REQUIRED,
109+
detail=f"User has reached the maximum LLM token limit of {result.max_keys_per_user} tokens"
110+
)
111111

112-
# Check service LLM tokens (team-owned tokens)
113-
if owner_id is None: # This is a team-owned token
114-
current_service_tokens = db.query(DBPrivateAIKey).filter(
115-
DBPrivateAIKey.team_id == team_id,
116-
DBPrivateAIKey.owner_id.is_(None),
117-
DBPrivateAIKey.litellm_token.isnot(None) # Only count LLM tokens
118-
).count()
119-
if current_service_tokens >= max_service_keys:
120-
raise HTTPException(
121-
status_code=status.HTTP_402_PAYMENT_REQUIRED,
122-
detail=f"Team has reached the maximum service LLM token limit of {max_service_keys} tokens"
123-
)
112+
if owner_id is None and result.current_service_keys >= result.max_service_keys:
113+
raise HTTPException(
114+
status_code=status.HTTP_402_PAYMENT_REQUIRED,
115+
detail=f"Team has reached the maximum service LLM token limit of {result.max_service_keys} tokens"
116+
)
124117

125118
def check_vector_db_limits(db: Session, team_id: int) -> None:
126119
"""
@@ -131,67 +124,75 @@ def check_vector_db_limits(db: Session, team_id: int) -> None:
131124
db: Database session
132125
team_id: ID of the team to check
133126
"""
134-
# Get the team and its active products
135-
team = db.query(DBTeam).filter(DBTeam.id == team_id).first()
136-
if not team:
127+
# Get vector DB limits and current count in a single query
128+
result = db.query(
129+
func.coalesce(func.max(DBProduct.vector_db_count), DEFAULT_VECTOR_DB_COUNT).label('max_vector_db_count'),
130+
func.count(DBPrivateAIKey.id).filter(
131+
DBPrivateAIKey.database_name.isnot(None)
132+
).label('current_vector_db_count')
133+
).select_from(DBTeam).filter(
134+
DBTeam.id == team_id
135+
).outerjoin(
136+
DBTeamProduct,
137+
DBTeamProduct.team_id == DBTeam.id
138+
).outerjoin(
139+
DBProduct,
140+
DBProduct.id == DBTeamProduct.product_id
141+
).outerjoin(
142+
DBPrivateAIKey,
143+
or_(
144+
DBPrivateAIKey.team_id == DBTeam.id,
145+
DBPrivateAIKey.owner_id.in_(
146+
db.query(DBUser.id).filter(DBUser.team_id == DBTeam.id)
147+
)
148+
)
149+
).first()
150+
151+
if not result:
137152
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Team not found")
138153

139-
# Find the maximum vector DB count across all active products
140-
max_vector_db_count = max(
141-
(product.vector_db_count for team_product in team.active_products
142-
for product in [team_product.product] if product.vector_db_count),
143-
default=DEFAULT_VECTOR_DB_COUNT # Default to 1 if no products have vector_db_count set
144-
)
145-
146-
# Get all users in the team
147-
team_users = db.query(DBUser).filter(DBUser.team_id == team_id).all()
148-
user_ids = [user.id for user in team_users]
149-
150-
# Get current vector DB count for the team (both team-owned and user-owned)
151-
current_vector_db_count = db.query(DBPrivateAIKey).filter(
152-
(
153-
(DBPrivateAIKey.team_id == team_id) | # Team-owned vector DBs
154-
(DBPrivateAIKey.owner_id.in_(user_ids)) # User-owned vector DBs
155-
),
156-
DBPrivateAIKey.database_name.isnot(None) # Only count keys with database_name set
157-
).count()
158-
159-
if current_vector_db_count >= max_vector_db_count:
154+
if result.current_vector_db_count >= result.max_vector_db_count:
160155
raise HTTPException(
161156
status_code=status.HTTP_402_PAYMENT_REQUIRED,
162-
detail=f"Team has reached the maximum vector DB limit of {max_vector_db_count} databases"
157+
detail=f"Team has reached the maximum vector DB limit of {result.max_vector_db_count} databases"
163158
)
164159

165160
def get_token_restrictions(db: Session, team_id: int) -> tuple[int, float, int]:
166161
"""
167162
Get the token restrictions for a team.
168163
"""
169-
team = db.query(DBTeam).filter(DBTeam.id == team_id).first()
170-
if not team:
164+
# Get all token restrictions in a single query
165+
result = db.query(
166+
func.coalesce(func.max(DBProduct.renewal_period_days), DEFAULT_KEY_DURATION).label('max_key_duration'),
167+
func.coalesce(func.max(DBProduct.max_budget_per_key), DEFAULT_MAX_SPEND).label('max_max_spend'),
168+
func.coalesce(func.max(DBProduct.rpm_per_key), DEFAULT_RPM_PER_KEY).label('max_rpm_limit'),
169+
DBTeam.created_at,
170+
DBTeam.last_payment
171+
).select_from(DBTeam).filter(
172+
DBTeam.id == team_id
173+
).outerjoin(
174+
DBTeamProduct,
175+
DBTeamProduct.team_id == DBTeam.id
176+
).outerjoin(
177+
DBProduct,
178+
DBProduct.id == DBTeamProduct.product_id
179+
).group_by(
180+
DBTeam.created_at,
181+
DBTeam.last_payment
182+
).first()
183+
184+
if not result:
171185
logger.error(f"Team not found for team_id: {team_id}")
172186
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Team not found")
173187

174-
max_key_duration = max(
175-
(product.renewal_period_days for team_product in team.active_products
176-
for product in [team_product.product] if product.renewal_period_days),
177-
default=DEFAULT_KEY_DURATION
178-
)
179-
if team.last_payment is None:
180-
days_left_in_period = max_key_duration
188+
if result.last_payment is None:
189+
days_left_in_period = result.max_key_duration
181190
else:
182-
days_left_in_period = max_key_duration - (datetime.now(UTC) - max(team.created_at.replace(tzinfo=UTC), team.last_payment.replace(tzinfo=UTC))).days
183-
max_max_spend = max(
184-
(product.max_budget_per_key for team_product in team.active_products
185-
for product in [team_product.product] if product.max_budget_per_key),
186-
default=DEFAULT_MAX_SPEND
187-
)
188-
max_rpm_limit = max(
189-
(product.rpm_per_key for team_product in team.active_products
190-
for product in [team_product.product] if product.rpm_per_key),
191-
default=DEFAULT_RPM_PER_KEY
192-
)
193-
194-
return days_left_in_period, max_max_spend, max_rpm_limit
191+
days_left_in_period = result.max_key_duration - (
192+
datetime.now(UTC) - max(result.created_at.replace(tzinfo=UTC), result.last_payment.replace(tzinfo=UTC))
193+
).days
194+
195+
return days_left_in_period, result.max_max_spend, result.max_rpm_limit
195196

196197
def get_team_limits(db: Session, team_id: int):
197198
# TODO: Go through all products, and create a master list of the limits on all fields for this team.

0 commit comments

Comments
 (0)