Skip to content

Commit f472717

Browse files
authored
Merge pull request #63 from amazeeio/periodic-job
Periodic job for team management
2 parents 6425a91 + f74b7da commit f472717

File tree

10 files changed

+634
-37
lines changed

10 files changed

+634
-37
lines changed

app/api/private_ai_keys.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,9 @@ async def create_vector_db(
118118
if settings.ENABLE_LIMITS:
119119
if not team_id: # if the team_id is not set we have already validated the owner_id
120120
user = db.query(DBUser).filter(DBUser.id == owner_id).first()
121-
team_id = user.team_id or FAKE_ID
122-
check_vector_db_limits(db, team_id)
121+
team_id = user.team_id # Remove the FAKE_ID fallback
122+
if team_id: # Only check limits if we have a valid team_id
123+
check_vector_db_limits(db, team_id)
123124

124125
try:
125126
# Create new postgres database
@@ -332,7 +333,7 @@ async def create_llm_token(
332333
litellm_token=litellm_token,
333334
litellm_api_url=region.litellm_api_url,
334335
owner_id=owner_id,
335-
team_id=team_id,
336+
team_id=None if team_id is None else team_id,
336337
name=private_ai_key.name,
337338
region_id = private_ai_key.region_id
338339
)

app/core/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ class Settings(BaseSettings):
3434

3535
model_config = ConfigDict(env_file=".env")
3636
main_route: str = os.getenv("LAGOON_ROUTE", "http://localhost:8800")
37+
frontend_route: str = os.getenv("FRONTEND_ROUTE", "http://localhost:3000")
3738

3839
def model_post_init(self, values):
3940
# Add Lagoon routes to CORS origins if available

app/core/resource_limits.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
DEFAULT_SERVICE_KEYS = 1
1616
DEFAULT_VECTOR_DB_COUNT = 1
1717
DEFAULT_KEY_DURATION = 30
18-
DEFAULT_MAX_SPEND = 20.0
18+
DEFAULT_MAX_SPEND = 27.0
1919
DEFAULT_RPM_PER_KEY = 500
2020

2121
def check_team_user_limit(db: Session, team_id: int) -> None:

app/core/worker.py

Lines changed: 292 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
from datetime import datetime, UTC
1+
from datetime import datetime, UTC, timedelta
22
from sqlalchemy.orm import Session
3-
from app.db.models import DBTeam, DBProduct, DBTeamProduct, DBPrivateAIKey, DBUser
3+
from app.db.models import DBTeam, DBProduct, DBTeamProduct, DBPrivateAIKey, DBUser, DBRegion
44
from app.services.litellm import LiteLLMService
5+
from app.services.ses import SESService
56
import logging
67
from collections import defaultdict
78
from app.core.resource_limits import get_token_restrictions
@@ -15,9 +16,79 @@
1516
INVOICE_FAILURE_EVENTS,
1617
INVOICE_SUCCESS_EVENTS
1718
)
19+
from prometheus_client import Gauge, Counter
20+
from typing import Dict, List
21+
from app.core.security import create_access_token
22+
from app.core.config import settings
23+
from urllib.parse import urljoin
1824

1925
logger = logging.getLogger(__name__)
2026

27+
# Trial notification schedule, in days. Teams without products are emailed
# when exactly this many days of their trial remain.
FIRST_EMAIL_DAYS_LEFT = 7
SECOND_EMAIL_DAYS_LEFT = 5
# Length of the trial, in days, counted from the team's freshness value.
TRIAL_OVER_DAYS = 30

# Prometheus metrics

# Days since creation (teams without products) or since the last payment
# (teams with products).
team_freshness_days = Gauge(
    name="team_freshness_days",
    documentation="Freshness of teams in days (since creation for teams without products, since last payment for teams with products)",
    labelnames=["team_id", "team_name"],
)

# Incremented for teams without products whose trial has run out.
team_expired_metric = Counter(
    name="team_expired_total",
    documentation="Total number of teams that have expired without products",
    labelnames=["team_id", "team_name"],
)

# Share of a key's budget that has been spent, as a percentage.
key_spend_percentage = Gauge(
    name="key_spend_percentage",
    documentation="Percentage of budget used for each key",
    labelnames=["team_id", "team_name", "key_alias"],
)

# Sum of the spend reported for every key that belongs to a team.
team_total_spend = Gauge(
    name="team_total_spend",
    documentation="Total spend across all keys in a team for the current budget period",
    labelnames=["team_id", "team_name"],
)

# (team_id, team_name) label pairs seen on the previous monitoring run; used
# to zero out metrics for teams that have since disappeared.
active_team_labels: set[tuple[str, str]] = set()
58+
59+
def get_team_keys_by_region(db: Session, team_id: int) -> Dict[DBRegion, List[DBPrivateAIKey]]:
    """
    Collect a team's private AI keys and bucket them by region.

    A key counts as belonging to the team when it is owned by one of the
    team's users or is assigned to the team directly. Keys without a LiteLLM
    token or without a region are logged and left out of the result.

    Args:
        db: Database session
        team_id: ID of the team whose keys are collected

    Returns:
        Dictionary mapping each region to the list of usable keys in it
    """
    # IDs of every user that is a member of the team
    member_ids = [
        member.id
        for member in db.query(DBUser).filter(DBUser.team_id == team_id).all()
    ]

    # Keys owned by a team member OR owned by the team itself
    owned_by_member = DBPrivateAIKey.owner_id.in_(member_ids)
    owned_by_team = DBPrivateAIKey.team_id == team_id
    candidate_keys = db.query(DBPrivateAIKey).filter(owned_by_member | owned_by_team).all()

    grouped = defaultdict(list)
    for candidate in candidate_keys:
        if candidate.litellm_token and candidate.region:
            grouped[candidate.region].append(candidate)
        elif not candidate.litellm_token:
            logger.warning(f"Key {candidate.id} has no LiteLLM token, skipping")
        else:
            logger.warning(f"Key {candidate.id} has no region, skipping")

    return grouped
91+
2192
async def handle_stripe_event_background(event, db: Session):
2293
"""
2394
Background task to handle Stripe webhook events.
@@ -60,7 +131,6 @@ async def handle_stripe_event_background(event, db: Session):
60131
except Exception as e:
61132
logger.error(f"Error in background event handler: {str(e)}")
62133

63-
64134
async def apply_product_for_team(db: Session, customer_id: str, product_id: str, start_date: datetime):
65135
"""
66136
Apply a product to a team and update their last payment date.
@@ -107,25 +177,8 @@ async def apply_product_for_team(db: Session, customer_id: str, product_id: str,
107177

108178
days_left_in_period, max_max_spend, max_rpm_limit = get_token_restrictions(db, team.id)
109179

110-
# Get all keys for the team with their regions
111-
team_users = db.query(DBUser).filter(DBUser.team_id == team.id).all()
112-
team_user_ids = [user.id for user in team_users]
113-
# Return keys owned by users in the team OR owned by the team
114-
team_keys = db.query(DBPrivateAIKey).filter(
115-
(DBPrivateAIKey.owner_id.in_(team_user_ids)) |
116-
(DBPrivateAIKey.team_id == team.id)
117-
).all()
118-
119-
# Group keys by region
120-
keys_by_region = defaultdict(list)
121-
for key in team_keys:
122-
if not key.litellm_token:
123-
logger.warning(f"Key {key.id} has no LiteLLM token, skipping")
124-
continue
125-
if not key.region:
126-
logger.warning(f"Key {key.id} has no region, skipping")
127-
continue
128-
keys_by_region[key.region].append(key)
180+
# Get all keys for the team grouped by region
181+
keys_by_region = get_team_keys_by_region(db, team.id)
129182

130183
# Update keys for each region
131184
for region, keys in keys_by_region.items():
@@ -188,3 +241,220 @@ async def remove_product_from_team(db: Session, customer_id: str, product_id: st
188241
db.rollback()
189242
logger.error(f"Error removing product from team: {str(e)}")
190243
raise e
244+
245+
async def monitor_teams(db: Session):
    """
    Daily monitoring task for teams that:
    1. Posts freshness metrics for teams (days since creation for teams without
       products, days since last payment for teams with products)
    2. Emails teams without products when exactly FIRST_EMAIL_DAYS_LEFT or
       SECOND_EMAIL_DAYS_LEFT days of the TRIAL_OVER_DAYS-day trial remain
    3. Increments the expired-team counter for teams without products whose
       trial has run out
    4. Posts per-key spend percentage and per-team total spend, and logs a
       warning for keys at or above 80% of their budget

    Failures while emailing a team, talking to a region, or reading a single
    key are logged and skipped so the remaining teams are still processed.

    Args:
        db: Database session

    Raises:
        Exception: Any error that is not handled per team/region/key is logged
            and re-raised to the caller
    """
    logger.info("Monitoring teams")
    try:
        # Initialize SES service
        ses_service = SESService()

        # Get all teams
        teams = db.query(DBTeam).all()
        current_time = datetime.now(UTC)

        # Track current active team labels
        current_team_labels = set()

        logger.info(f"Found {len(teams)} teams to track")
        for team in teams:
            team_label = (str(team.id), team.name)
            current_team_labels.add(team_label)

            # Check if team has any products
            has_products = db.query(DBTeamProduct).filter(
                DBTeamProduct.team_id == team.id
            ).first() is not None

            # Calculate team age based on whether they have products
            # NOTE(review): current_time is timezone-aware; this assumes
            # last_payment / created_at are stored timezone-aware too - confirm.
            if has_products and team.last_payment:
                team_freshness = (current_time - team.last_payment).days
            else:  # If a subscription is cancelled this will jump dramatically.
                team_freshness = (current_time - team.created_at).days

            if team_freshness < 0:
                logger.warning(f"Team {team.name} (ID: {team.id}) has a negative age: {team_freshness} days")
                team_freshness = 0

            # Post freshness metric
            team_freshness_days.labels(
                team_id=str(team.id),
                team_name=team.name
            ).set(team_freshness)

            # Check for notification conditions for teams without products
            if not has_products:
                days_remaining = TRIAL_OVER_DAYS - team_freshness
                if days_remaining in (FIRST_EMAIL_DAYS_LEFT, SECOND_EMAIL_DAYS_LEFT):
                    logger.warning(f"Team {team.name} (ID: {team.id}) is approaching expiration in {days_remaining} days")
                    # Send expiration notification email; a failure here
                    # (including a missing team admin when building the URL)
                    # must not stop monitoring of the remaining teams.
                    try:
                        if team.admin_email:
                            template_data = {
                                "team_name": team.name,
                                "days_remaining": days_remaining,
                                "dashboard_url": generate_pricing_url(db, team)
                            }
                            ses_service.send_email(
                                to_addresses=[team.admin_email],
                                template_name="team-expiring",
                                template_data=template_data
                            )
                            logger.info(f"Sent expiration notification email to team {team.name} (ID: {team.id})")
                        else:
                            logger.warning(f"No email found for team {team.name} (ID: {team.id})")
                    except Exception as e:
                        logger.error(f"Failed to send expiration notification email to team {team.name}: {str(e)}")
                elif days_remaining <= 0:
                    # Post expired metric
                    # NOTE(review): this increments on every run for which the
                    # team is still expired, not once per team - confirm that
                    # is the intended meaning of the counter.
                    team_expired_metric.labels(
                        team_id=str(team.id),
                        team_name=team.name
                    ).inc()

            # Get all keys for the team grouped by region
            keys_by_region = get_team_keys_by_region(db, team.id)

            # Track total spend across all keys for this team
            team_total = 0

            # Monitor keys for each region
            for region, keys in keys_by_region.items():
                try:
                    # Initialize LiteLLM service for this region
                    litellm_service = LiteLLMService(
                        api_url=region.litellm_api_url,
                        api_key=region.litellm_api_key
                    )

                    # Check spend for each key in this region
                    for key in keys:
                        try:
                            # Get current spend using get_key_info
                            key_info = await litellm_service.get_key_info(key.litellm_token)
                            # dict.get() only falls back to its default when the
                            # key is absent; an explicit null in the payload
                            # would come through as None and break the
                            # arithmetic/comparison below, so coerce with `or`.
                            info = key_info.get("info") or {}
                            current_spend = info.get("spend") or 0
                            budget = info.get("max_budget") or 0
                            # Fallback to key-{id} if no alias
                            key_alias = info.get("key_alias") or f"key-{key.id}"

                            # Add to team total
                            team_total += current_spend

                            # Calculate and post percentage used
                            if budget > 0:
                                percentage_used = (current_spend / budget) * 100
                                key_spend_percentage.labels(
                                    team_id=str(team.id),
                                    team_name=team.name,
                                    key_alias=key_alias
                                ).set(percentage_used)

                                # Log warning if approaching limit
                                if percentage_used >= 80:
                                    logger.warning(
                                        f"Key {key_alias} for team {team.name} is approaching spend limit: "
                                        f"${current_spend:.2f} of ${budget:.2f} ({percentage_used:.1f}%)"
                                    )
                            else:
                                # Set to 0 if no budget is set
                                key_spend_percentage.labels(
                                    team_id=str(team.id),
                                    team_name=team.name,
                                    key_alias=key_alias
                                ).set(0)

                        except Exception as e:
                            logger.error(f"Error monitoring key {key.id} spend: {str(e)}")
                            continue

                except Exception as e:
                    logger.error(f"Error initializing LiteLLM service for region {region.name}: {str(e)}")
                    continue

            # Set the total spend metric for the team
            team_total_spend.labels(
                team_id=str(team.id),
                team_name=team.name
            ).set(team_total)

        # Zero out the freshness metric for teams seen on the previous run
        # that no longer exist (or were renamed, which changes the label pair)
        for old_label in active_team_labels - current_team_labels:
            team_freshness_days.labels(
                team_id=old_label[0],
                team_name=old_label[1]
            ).set(0)

        # Update active team labels for next run
        active_team_labels.clear()
        active_team_labels.update(current_team_labels)

    except Exception as e:
        logger.error(f"Error in team monitoring task: {str(e)}")
        # Bare raise keeps the original traceback intact
        raise
400+
401+
def generate_team_admin_token(db: Session, team: DBTeam) -> str:
    """
    Create a short-lived access token on behalf of one of the team's admins.

    The first user found with role "admin" in the team becomes the token
    subject (by email). The token is requested with a one-day lifetime.

    Args:
        db: Database session
        team: The team object to generate the token for

    Returns:
        str: The token produced by create_access_token

    Raises:
        ValueError: If no admin user is found for the team
    """
    lifetime = timedelta(days=1)

    # Any admin of the team will do; take the first match
    admin_user = (
        db.query(DBUser)
        .filter(DBUser.team_id == team.id, DBUser.role == "admin")
        .first()
    )
    if not admin_user:
        raise ValueError(f"No admin user found for team {team.name} (ID: {team.id})")

    # Subject is the admin's email; expiry is set both in the payload and via
    # expires_delta, using the same lifetime
    claims = {"sub": admin_user.email, "exp": datetime.now(UTC) + lifetime}

    return create_access_token(data=claims, expires_delta=lifetime)
438+
439+
def generate_pricing_url(db: Session, team: DBTeam) -> str:
    """
    Build a link to the frontend pricing page that carries a team admin token.

    Args:
        db: Database session
        team: The team object to generate the URL for

    Returns:
        str: The pricing page URL with the token appended as the
            ``token`` query parameter
    """
    # Token first: if the team has no admin this raises before any URL work
    admin_token = generate_team_admin_token(db, team)

    # '/pricing' is an absolute path, so urljoin replaces any path component
    # already present in the configured frontend route
    pricing_page = urljoin(settings.frontend_route, '/pricing')

    return f"{pricing_page}?token={admin_token}"
460+

0 commit comments

Comments
 (0)