1
1
from sqlalchemy .orm import Session
2
- from app .db .models import DBTeam , DBUser , DBPrivateAIKey
2
+ from sqlalchemy import func , and_ , or_
3
+ from app .db .models import DBTeam , DBUser , DBPrivateAIKey , DBTeamProduct , DBProduct
3
4
from fastapi import HTTPException , status
4
5
from typing import Optional
5
6
from datetime import datetime , UTC
@@ -26,25 +27,27 @@ def check_team_user_limit(db: Session, team_id: int) -> None:
26
27
db: Database session
27
28
team_id: ID of the team to check
28
29
"""
29
- # Get current user count for the team
30
- current_user_count = db .query (DBUser ).filter (DBUser .team_id == team_id ).count ()
31
-
32
- # Get all active products for the team
33
- team = db .query (DBTeam ).filter (DBTeam .id == team_id ).first ()
34
- if not team :
30
+ # Get current user count and max allowed users in a single query
31
+ result = db .query (
32
+ func .count (DBUser .id ).label ('current_user_count' ),
33
+ func .coalesce (func .max (DBProduct .user_count ), DEFAULT_USER_COUNT ).label ('max_users' )
34
+ ).select_from (DBUser ).filter (
35
+ DBUser .team_id == team_id
36
+ ).outerjoin (
37
+ DBTeamProduct ,
38
+ DBTeamProduct .team_id == team_id
39
+ ).outerjoin (
40
+ DBProduct ,
41
+ DBProduct .id == DBTeamProduct .product_id
42
+ ).first ()
43
+
44
+ if not result :
35
45
raise HTTPException (status_code = status .HTTP_404_NOT_FOUND , detail = "Team not found" )
36
46
37
- # Find the maximum user count allowed across all active products
38
- max_user_count = max (
39
- (product .user_count for team_product in team .active_products
40
- for product in [team_product .product ] if product .user_count ),
41
- default = DEFAULT_USER_COUNT # Default to 2 if no products have user_count set
42
- )
43
-
44
- if current_user_count >= max_user_count :
47
+ if result .current_user_count >= result .max_users :
45
48
raise HTTPException (
46
49
status_code = status .HTTP_402_PAYMENT_REQUIRED ,
47
- detail = f"Team has reached the maximum user limit of { max_user_count } users"
50
+ detail = f"Team has reached the maximum user limit of { result . max_users } users"
48
51
)
49
52
50
53
def check_key_limits (db : Session , team_id : int , owner_id : Optional [int ] = None ) -> None :
@@ -57,70 +60,60 @@ def check_key_limits(db: Session, team_id: int, owner_id: Optional[int] = None)
57
60
team_id: ID of the team to check
58
61
owner_id: Optional ID of the user who will own the key
59
62
"""
60
- # Get the team and its active products
61
- team = db .query (DBTeam ).filter (DBTeam .id == team_id ).first ()
62
- if not team :
63
+ # Get all limits and current counts in a single query
64
+ result = db .query (
65
+ func .coalesce (func .max (DBProduct .total_key_count ), DEFAULT_TOTAL_KEYS ).label ('max_total_keys' ),
66
+ func .coalesce (func .max (DBProduct .keys_per_user ), DEFAULT_KEYS_PER_USER ).label ('max_keys_per_user' ),
67
+ func .coalesce (func .max (DBProduct .service_key_count ), DEFAULT_SERVICE_KEYS ).label ('max_service_keys' ),
68
+ func .count (DBPrivateAIKey .id ).filter (
69
+ DBPrivateAIKey .litellm_token .isnot (None )
70
+ ).label ('current_team_keys' ),
71
+ func .count (DBPrivateAIKey .id ).filter (
72
+ DBPrivateAIKey .owner_id == owner_id ,
73
+ DBPrivateAIKey .litellm_token .isnot (None )
74
+ ).label ('current_user_keys' ) if owner_id else None ,
75
+ func .count (DBPrivateAIKey .id ).filter (
76
+ DBPrivateAIKey .owner_id .is_ (None ),
77
+ DBPrivateAIKey .litellm_token .isnot (None )
78
+ ).label ('current_service_keys' )
79
+ ).select_from (DBTeam ).filter ( # Have to use Teams table as the base because not every team has a product
80
+ DBTeam .id == team_id
81
+ ).outerjoin (
82
+ DBTeamProduct ,
83
+ DBTeamProduct .team_id == DBTeam .id
84
+ ).outerjoin (
85
+ DBProduct ,
86
+ DBProduct .id == DBTeamProduct .product_id
87
+ ).outerjoin (
88
+ DBPrivateAIKey ,
89
+ or_ (
90
+ DBPrivateAIKey .team_id == DBTeam .id ,
91
+ DBPrivateAIKey .owner_id .in_ (
92
+ db .query (DBUser .id ).filter (DBUser .team_id == DBTeam .id )
93
+ )
94
+ )
95
+ ).first ()
96
+
97
+ if not result :
63
98
raise HTTPException (status_code = status .HTTP_404_NOT_FOUND , detail = "Team not found" )
64
99
65
- # Find the maximum limits across all active products, using defaults if no products
66
- max_total_keys = max (
67
- (product .total_key_count for team_product in team .active_products
68
- for product in [team_product .product ] if product .total_key_count ),
69
- default = DEFAULT_TOTAL_KEYS # Default to 2 if no products have total_key_count set
70
- )
71
- max_keys_per_user = max (
72
- (product .keys_per_user for team_product in team .active_products
73
- for product in [team_product .product ] if product .keys_per_user ),
74
- default = DEFAULT_KEYS_PER_USER # Default to 1 if no products have keys_per_user set
75
- )
76
- max_service_keys = max (
77
- (product .service_key_count for team_product in team .active_products
78
- for product in [team_product .product ] if product .service_key_count ),
79
- default = DEFAULT_SERVICE_KEYS # Default to 1 if no products have service_key_count set
80
- )
81
-
82
- # Get all users in the team
83
- team_users = db .query (DBUser ).filter (DBUser .team_id == team_id ).all ()
84
- user_ids = [user .id for user in team_users ]
85
-
86
- # Check total team LLM tokens (both team-owned and user-owned)
87
- current_team_tokens = db .query (DBPrivateAIKey ).filter (
88
- (
89
- (DBPrivateAIKey .team_id == team_id ) | # Team-owned tokens
90
- (DBPrivateAIKey .owner_id .in_ (user_ids )) # User-owned tokens
91
- ),
92
- DBPrivateAIKey .litellm_token .isnot (None ) # Only count LLM tokens
93
- ).count ()
94
- if current_team_tokens >= max_total_keys :
100
+ if result .current_team_keys >= result .max_total_keys :
95
101
raise HTTPException (
96
102
status_code = status .HTTP_402_PAYMENT_REQUIRED ,
97
- detail = f"Team has reached the maximum LLM token limit of { max_total_keys } tokens"
103
+ detail = f"Team has reached the maximum LLM token limit of { result . max_total_keys } tokens"
98
104
)
99
105
100
- # Check user LLM tokens if owner_id is provided
101
- if owner_id is not None :
102
- current_user_tokens = db .query (DBPrivateAIKey ).filter (
103
- DBPrivateAIKey .owner_id == owner_id ,
104
- DBPrivateAIKey .litellm_token .isnot (None ) # Only count LLM tokens
105
- ).count ()
106
- if current_user_tokens >= max_keys_per_user :
107
- raise HTTPException (
108
- status_code = status .HTTP_402_PAYMENT_REQUIRED ,
109
- detail = f"User has reached the maximum LLM token limit of { max_keys_per_user } tokens"
110
- )
106
+ if owner_id is not None and result .current_user_keys >= result .max_keys_per_user :
107
+ raise HTTPException (
108
+ status_code = status .HTTP_402_PAYMENT_REQUIRED ,
109
+ detail = f"User has reached the maximum LLM token limit of { result .max_keys_per_user } tokens"
110
+ )
111
111
112
- # Check service LLM tokens (team-owned tokens)
113
- if owner_id is None : # This is a team-owned token
114
- current_service_tokens = db .query (DBPrivateAIKey ).filter (
115
- DBPrivateAIKey .team_id == team_id ,
116
- DBPrivateAIKey .owner_id .is_ (None ),
117
- DBPrivateAIKey .litellm_token .isnot (None ) # Only count LLM tokens
118
- ).count ()
119
- if current_service_tokens >= max_service_keys :
120
- raise HTTPException (
121
- status_code = status .HTTP_402_PAYMENT_REQUIRED ,
122
- detail = f"Team has reached the maximum service LLM token limit of { max_service_keys } tokens"
123
- )
112
+ if owner_id is None and result .current_service_keys >= result .max_service_keys :
113
+ raise HTTPException (
114
+ status_code = status .HTTP_402_PAYMENT_REQUIRED ,
115
+ detail = f"Team has reached the maximum service LLM token limit of { result .max_service_keys } tokens"
116
+ )
124
117
125
118
def check_vector_db_limits (db : Session , team_id : int ) -> None :
126
119
"""
@@ -131,67 +124,75 @@ def check_vector_db_limits(db: Session, team_id: int) -> None:
131
124
db: Database session
132
125
team_id: ID of the team to check
133
126
"""
134
- # Get the team and its active products
135
- team = db .query (DBTeam ).filter (DBTeam .id == team_id ).first ()
136
- if not team :
127
+ # Get vector DB limits and current count in a single query
128
+ result = db .query (
129
+ func .coalesce (func .max (DBProduct .vector_db_count ), DEFAULT_VECTOR_DB_COUNT ).label ('max_vector_db_count' ),
130
+ func .count (DBPrivateAIKey .id ).filter (
131
+ DBPrivateAIKey .database_name .isnot (None )
132
+ ).label ('current_vector_db_count' )
133
+ ).select_from (DBTeam ).filter (
134
+ DBTeam .id == team_id
135
+ ).outerjoin (
136
+ DBTeamProduct ,
137
+ DBTeamProduct .team_id == DBTeam .id
138
+ ).outerjoin (
139
+ DBProduct ,
140
+ DBProduct .id == DBTeamProduct .product_id
141
+ ).outerjoin (
142
+ DBPrivateAIKey ,
143
+ or_ (
144
+ DBPrivateAIKey .team_id == DBTeam .id ,
145
+ DBPrivateAIKey .owner_id .in_ (
146
+ db .query (DBUser .id ).filter (DBUser .team_id == DBTeam .id )
147
+ )
148
+ )
149
+ ).first ()
150
+
151
+ if not result :
137
152
raise HTTPException (status_code = status .HTTP_404_NOT_FOUND , detail = "Team not found" )
138
153
139
- # Find the maximum vector DB count across all active products
140
- max_vector_db_count = max (
141
- (product .vector_db_count for team_product in team .active_products
142
- for product in [team_product .product ] if product .vector_db_count ),
143
- default = DEFAULT_VECTOR_DB_COUNT # Default to 1 if no products have vector_db_count set
144
- )
145
-
146
- # Get all users in the team
147
- team_users = db .query (DBUser ).filter (DBUser .team_id == team_id ).all ()
148
- user_ids = [user .id for user in team_users ]
149
-
150
- # Get current vector DB count for the team (both team-owned and user-owned)
151
- current_vector_db_count = db .query (DBPrivateAIKey ).filter (
152
- (
153
- (DBPrivateAIKey .team_id == team_id ) | # Team-owned vector DBs
154
- (DBPrivateAIKey .owner_id .in_ (user_ids )) # User-owned vector DBs
155
- ),
156
- DBPrivateAIKey .database_name .isnot (None ) # Only count keys with database_name set
157
- ).count ()
158
-
159
- if current_vector_db_count >= max_vector_db_count :
154
+ if result .current_vector_db_count >= result .max_vector_db_count :
160
155
raise HTTPException (
161
156
status_code = status .HTTP_402_PAYMENT_REQUIRED ,
162
- detail = f"Team has reached the maximum vector DB limit of { max_vector_db_count } databases"
157
+ detail = f"Team has reached the maximum vector DB limit of { result . max_vector_db_count } databases"
163
158
)
164
159
165
160
def get_token_restrictions (db : Session , team_id : int ) -> tuple [int , float , int ]:
166
161
"""
167
162
Get the token restrictions for a team.
168
163
"""
169
- team = db .query (DBTeam ).filter (DBTeam .id == team_id ).first ()
170
- if not team :
164
+ # Get all token restrictions in a single query
165
+ result = db .query (
166
+ func .coalesce (func .max (DBProduct .renewal_period_days ), DEFAULT_KEY_DURATION ).label ('max_key_duration' ),
167
+ func .coalesce (func .max (DBProduct .max_budget_per_key ), DEFAULT_MAX_SPEND ).label ('max_max_spend' ),
168
+ func .coalesce (func .max (DBProduct .rpm_per_key ), DEFAULT_RPM_PER_KEY ).label ('max_rpm_limit' ),
169
+ DBTeam .created_at ,
170
+ DBTeam .last_payment
171
+ ).select_from (DBTeam ).filter (
172
+ DBTeam .id == team_id
173
+ ).outerjoin (
174
+ DBTeamProduct ,
175
+ DBTeamProduct .team_id == DBTeam .id
176
+ ).outerjoin (
177
+ DBProduct ,
178
+ DBProduct .id == DBTeamProduct .product_id
179
+ ).group_by (
180
+ DBTeam .created_at ,
181
+ DBTeam .last_payment
182
+ ).first ()
183
+
184
+ if not result :
171
185
logger .error (f"Team not found for team_id: { team_id } " )
172
186
raise HTTPException (status_code = status .HTTP_404_NOT_FOUND , detail = "Team not found" )
173
187
174
- max_key_duration = max (
175
- (product .renewal_period_days for team_product in team .active_products
176
- for product in [team_product .product ] if product .renewal_period_days ),
177
- default = DEFAULT_KEY_DURATION
178
- )
179
- if team .last_payment is None :
180
- days_left_in_period = max_key_duration
188
+ if result .last_payment is None :
189
+ days_left_in_period = result .max_key_duration
181
190
else :
182
- days_left_in_period = max_key_duration - (datetime .now (UTC ) - max (team .created_at .replace (tzinfo = UTC ), team .last_payment .replace (tzinfo = UTC ))).days
183
- max_max_spend = max (
184
- (product .max_budget_per_key for team_product in team .active_products
185
- for product in [team_product .product ] if product .max_budget_per_key ),
186
- default = DEFAULT_MAX_SPEND
187
- )
188
- max_rpm_limit = max (
189
- (product .rpm_per_key for team_product in team .active_products
190
- for product in [team_product .product ] if product .rpm_per_key ),
191
- default = DEFAULT_RPM_PER_KEY
192
- )
193
-
194
- return days_left_in_period , max_max_spend , max_rpm_limit
191
+ days_left_in_period = result .max_key_duration - (
192
+ datetime .now (UTC ) - max (result .created_at .replace (tzinfo = UTC ), result .last_payment .replace (tzinfo = UTC ))
193
+ ).days
194
+
195
+ return days_left_in_period , result .max_max_spend , result .max_rpm_limit
195
196
196
197
def get_team_limits (db : Session , team_id : int ):
197
198
# TODO: Go through all products, and create a master list of the limits on all fields for this team.
0 commit comments