Skip to content

Commit 85bbfba

Browse files
authored
Merge pull request #135 from amazeeio/async-litellm
Async litellm
2 parents 04ee9da + ea7676a commit 85bbfba

12 files changed

+554
-590
lines changed

app/core/resource_limits.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def check_team_user_limit(db: Session, team_id: int) -> None:
2828
"""
2929
# Get current user count and max allowed users in a single query
3030
result = db.query(
31-
func.count(DBUser.id).label('current_user_count'),
31+
func.count(func.distinct(DBUser.id)).label('current_user_count'),
3232
func.coalesce(func.max(DBProduct.user_count), DEFAULT_USER_COUNT).label('max_users')
3333
).select_from(DBUser).filter(
3434
DBUser.team_id == team_id
@@ -64,14 +64,14 @@ def check_key_limits(db: Session, team_id: int, owner_id: Optional[int] = None)
6464
func.coalesce(func.max(DBProduct.total_key_count), DEFAULT_TOTAL_KEYS).label('max_total_keys'),
6565
func.coalesce(func.max(DBProduct.keys_per_user), DEFAULT_KEYS_PER_USER).label('max_keys_per_user'),
6666
func.coalesce(func.max(DBProduct.service_key_count), DEFAULT_SERVICE_KEYS).label('max_service_keys'),
67-
func.count(DBPrivateAIKey.id).filter(
67+
func.count(func.distinct(DBPrivateAIKey.id)).filter(
6868
DBPrivateAIKey.litellm_token.isnot(None)
6969
).label('current_team_keys'),
70-
func.count(DBPrivateAIKey.id).filter(
70+
func.count(func.distinct(DBPrivateAIKey.id)).filter(
7171
DBPrivateAIKey.owner_id == owner_id,
7272
DBPrivateAIKey.litellm_token.isnot(None)
7373
).label('current_user_keys') if owner_id else None,
74-
func.count(DBPrivateAIKey.id).filter(
74+
func.count(func.distinct(DBPrivateAIKey.id)).filter(
7575
DBPrivateAIKey.owner_id.is_(None),
7676
DBPrivateAIKey.litellm_token.isnot(None)
7777
).label('current_service_keys')
@@ -126,7 +126,7 @@ def check_vector_db_limits(db: Session, team_id: int) -> None:
126126
# Get vector DB limits and current count in a single query
127127
result = db.query(
128128
func.coalesce(func.max(DBProduct.vector_db_count), DEFAULT_VECTOR_DB_COUNT).label('max_vector_db_count'),
129-
func.count(DBPrivateAIKey.id).filter(
129+
func.count(func.distinct(DBPrivateAIKey.id)).filter(
130130
DBPrivateAIKey.database_name.isnot(None)
131131
).label('current_vector_db_count')
132132
).select_from(DBTeam).filter(

app/services/litellm.py

Lines changed: 96 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import requests
1+
import httpx
22
from fastapi import HTTPException, status
33
import logging
44
from app.core.resource_limits import DEFAULT_KEY_DURATION, DEFAULT_MAX_SPEND, DEFAULT_RPM_PER_KEY
@@ -57,20 +57,21 @@ async def create_key(self, email: str, name: str, user_id: int, team_id: str, du
5757
if user_id is not None:
5858
request_data["user_id"] = str(user_id)
5959

60-
logger.info(f"Making request to LiteLLM API to generate key with data: {request_data}")
61-
response = requests.post(
62-
f"{self.api_url}/key/generate",
63-
json=request_data,
64-
headers={
65-
"Authorization": f"Bearer {self.master_key}"
66-
}
67-
)
60+
async with httpx.AsyncClient() as client:
61+
response = await client.post(
62+
f"{self.api_url}/key/generate",
63+
json=request_data,
64+
headers={
65+
"Authorization": f"Bearer {self.master_key}"
66+
}
67+
)
6868

69-
response.raise_for_status()
70-
key = response.json()["key"]
71-
logger.info("Successfully generated new LiteLLM API key")
72-
return key
73-
except requests.exceptions.RequestException as e:
69+
response.raise_for_status()
70+
response_data = response.json()
71+
key = response_data["key"]
72+
logger.info("Successfully generated new LiteLLM API key")
73+
return key
74+
except httpx.HTTPStatusError as e:
7475
error_msg = str(e)
7576
if hasattr(e, 'response') and e.response is not None:
7677
try:
@@ -87,21 +88,22 @@ async def create_key(self, email: str, name: str, user_id: int, team_id: str, du
8788
async def delete_key(self, key: str) -> bool:
8889
"""Delete a LiteLLM API key"""
8990
try:
90-
response = requests.post(
91-
f"{self.api_url}/key/delete",
92-
json={"keys": [key]}, # API expects an array of keys
93-
headers={
94-
"Authorization": f"Bearer {self.master_key}"
95-
}
96-
)
91+
async with httpx.AsyncClient() as client:
92+
response = await client.post(
93+
f"{self.api_url}/key/delete",
94+
json={"keys": [key]}, # API expects an array of keys
95+
headers={
96+
"Authorization": f"Bearer {self.master_key}"
97+
}
98+
)
9799

98-
# Treat 404 (key not found) as success
99-
if response.status_code == 404:
100-
return True
100+
# Treat 404 (key not found) as success
101+
if response.status_code == 404:
102+
return True
101103

102-
response.raise_for_status()
103-
return True
104-
except requests.exceptions.RequestException as e:
104+
response.raise_for_status()
105+
return True
106+
except httpx.HTTPStatusError as e:
105107
error_msg = str(e)
106108
if hasattr(e, 'response') and e.response is not None:
107109
try:
@@ -118,19 +120,21 @@ async def delete_key(self, key: str) -> bool:
118120
async def get_key_info(self, litellm_token: str) -> dict:
119121
"""Get information about a LiteLLM API key"""
120122
try:
121-
response = requests.get(
122-
f"{self.api_url}/key/info",
123-
headers={
124-
"Authorization": f"Bearer {self.master_key}"
125-
},
126-
params={
127-
"key": litellm_token
128-
}
129-
)
130-
response.raise_for_status()
131-
logger.info(f"LiteLLM key information: {response.json()}")
132-
return response.json()
133-
except requests.exceptions.RequestException as e:
123+
async with httpx.AsyncClient() as client:
124+
response = await client.get(
125+
f"{self.api_url}/key/info",
126+
headers={
127+
"Authorization": f"Bearer {self.master_key}"
128+
},
129+
params={
130+
"key": litellm_token
131+
}
132+
)
133+
response.raise_for_status()
134+
response_data = response.json()
135+
logger.info(f"LiteLLM key information: {response_data}")
136+
return response_data
137+
except httpx.HTTPStatusError as e:
134138
error_msg = str(e)
135139
logger.error(f"Error getting LiteLLM key information: {error_msg}")
136140
if hasattr(e, 'response') and e.response is not None:
@@ -156,15 +160,16 @@ async def update_budget(self, litellm_token: str, budget_duration: str, budget_a
156160
if budget_amount:
157161
request_data["max_budget"] = budget_amount
158162

159-
response = requests.post(
160-
f"{self.api_url}/key/update",
161-
headers={
162-
"Authorization": f"Bearer {self.master_key}"
163-
},
164-
json=request_data
165-
)
166-
response.raise_for_status()
167-
except requests.exceptions.RequestException as e:
163+
async with httpx.AsyncClient() as client:
164+
response = await client.post(
165+
f"{self.api_url}/key/update",
166+
headers={
167+
"Authorization": f"Bearer {self.master_key}"
168+
},
169+
json=request_data
170+
)
171+
response.raise_for_status()
172+
except httpx.HTTPStatusError as e:
168173
error_msg = str(e)
169174
if hasattr(e, 'response') and e.response is not None:
170175
try:
@@ -180,18 +185,19 @@ async def update_budget(self, litellm_token: str, budget_duration: str, budget_a
180185
async def update_key_duration(self, litellm_token: str, duration: str):
181186
"""Update the duration for a LiteLLM API key"""
182187
try:
183-
response = requests.post(
184-
f"{self.api_url}/key/update",
185-
headers={
186-
"Authorization": f"Bearer {self.master_key}"
187-
},
188-
json={
189-
"key": litellm_token,
190-
"duration": duration
191-
}
192-
)
193-
response.raise_for_status()
194-
except requests.exceptions.RequestException as e:
188+
async with httpx.AsyncClient() as client:
189+
response = await client.post(
190+
f"{self.api_url}/key/update",
191+
headers={
192+
"Authorization": f"Bearer {self.master_key}"
193+
},
194+
json={
195+
"key": litellm_token,
196+
"duration": duration
197+
}
198+
)
199+
response.raise_for_status()
200+
except httpx.HTTPStatusError as e:
195201
error_msg = str(e)
196202
if hasattr(e, 'response') and e.response is not None:
197203
try:
@@ -207,21 +213,22 @@ async def update_key_duration(self, litellm_token: str, duration: str):
207213
async def set_key_restrictions(self, litellm_token: str, duration: str, budget_amount: float, rpm_limit: int, budget_duration: Optional[str] = None):
208214
"""Set the restrictions for a LiteLLM API key"""
209215
try:
210-
response = requests.post(
211-
f"{self.api_url}/key/update",
212-
headers={
213-
"Authorization": f"Bearer {self.master_key}"
214-
},
215-
json={
216-
"key": litellm_token,
217-
"duration": duration,
218-
"budget_duration": budget_duration,
219-
"max_budget": budget_amount,
220-
"rpm_limit": rpm_limit
221-
}
222-
)
223-
response.raise_for_status()
224-
except requests.exceptions.RequestException as e:
216+
async with httpx.AsyncClient() as client:
217+
response = await client.post(
218+
f"{self.api_url}/key/update",
219+
headers={
220+
"Authorization": f"Bearer {self.master_key}"
221+
},
222+
json={
223+
"key": litellm_token,
224+
"duration": duration,
225+
"budget_duration": budget_duration,
226+
"max_budget": budget_amount,
227+
"rpm_limit": rpm_limit
228+
}
229+
)
230+
response.raise_for_status()
231+
except httpx.HTTPStatusError as e:
225232
error_msg = str(e)
226233
if hasattr(e, 'response') and e.response is not None:
227234
try:
@@ -237,18 +244,19 @@ async def set_key_restrictions(self, litellm_token: str, duration: str, budget_a
237244
async def update_key_team_association(self, litellm_token: str, new_team_id: str):
238245
"""Update the team association for a LiteLLM API key"""
239246
try:
240-
response = requests.post(
241-
f"{self.api_url}/key/update",
242-
headers={
243-
"Authorization": f"Bearer {self.master_key}"
244-
},
245-
json={
246-
"key": litellm_token,
247-
"team_id": new_team_id
248-
}
249-
)
250-
response.raise_for_status()
251-
except requests.exceptions.RequestException as e:
247+
async with httpx.AsyncClient() as client:
248+
response = await client.post(
249+
f"{self.api_url}/key/update",
250+
headers={
251+
"Authorization": f"Bearer {self.master_key}"
252+
},
253+
json={
254+
"key": litellm_token,
255+
"team_id": new_team_id
256+
}
257+
)
258+
response.raise_for_status()
259+
except httpx.HTTPStatusError as e:
252260
error_msg = str(e)
253261
if hasattr(e, 'response') and e.response is not None:
254262
try:

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,5 @@ prometheus-client==0.21.1
1818
prometheus-fastapi-instrumentator==7.0.0
1919
stripe==12.1.0
2020
six==1.17.0
21-
apscheduler==3.10.4
21+
apscheduler==3.10.4
22+
httpx==0.28.1

scripts/add_test_data.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,16 @@
44
import sys
55
from datetime import datetime, timedelta, UTC
66
import random
7+
import asyncio
78

89
# Add the parent directory to the Python path
910
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
1011

1112
from sqlalchemy.orm import sessionmaker, Session
1213
from app.db.database import engine
13-
from app.db.models import DBTeam, DBUser, DBProduct, DBTeamProduct
14+
from app.db.models import DBTeam, DBUser, DBProduct, DBTeamProduct, DBRegion, DBPrivateAIKey
1415
from app.core.security import get_password_hash
16+
from app.services.litellm import LiteLLMService
1517

1618
def create_test_data():
1719
"""Create test data for teams, users, and products"""
@@ -264,10 +266,49 @@ def create_test_data():
264266
finally:
265267
db.close()
266268

269+
async def create_test_keys(count: int):
270+
# Create database session
271+
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
272+
db = SessionLocal()
273+
274+
try:
275+
team = db.query(DBTeam).first()
276+
region = db.query(DBRegion).filter(DBRegion.is_active == True).first()
277+
if not region:
278+
print(f"No active regions, not creating test keys")
279+
return
280+
litellm = LiteLLMService(region.litellm_api_url, region.litellm_api_key)
281+
team_id = team.id
282+
for i in range(0, count):
283+
key_name = f"auto_test_{i}"
284+
litellm_token = await litellm.create_key(
285+
email=team.admin_email,
286+
name=key_name,
287+
user_id=team_id,
288+
team_id=LiteLLMService.format_team_id(region.name, team_id),
289+
)
290+
291+
# Create response object
292+
db_token = DBPrivateAIKey(
293+
litellm_token=litellm_token,
294+
litellm_api_url=region.litellm_api_url,
295+
owner_id=None,
296+
team_id=None if team_id is None else team_id,
297+
name=key_name,
298+
region_id = region.id
299+
)
300+
db.add(db_token)
301+
db.commit()
302+
print(f"Created LLM token {key_name} int team {team.name}")
303+
304+
except Exception as e:
305+
print(f"failed to create test keys {str(e)}")
306+
267307
def main():
268308
"""Main function to run the script"""
269309
try:
270310
create_test_data()
311+
asyncio.run(create_test_keys(50))
271312
except Exception as e:
272313
print(f"Script failed: {str(e)}")
273314
sys.exit(1)

0 commit comments

Comments
 (0)