|
1 |
| -from datetime import datetime, UTC |
| 1 | +from datetime import datetime, UTC, timedelta |
2 | 2 | from sqlalchemy.orm import Session
|
3 |
| -from app.db.models import DBTeam, DBProduct, DBTeamProduct, DBPrivateAIKey, DBUser |
| 3 | +from app.db.models import DBTeam, DBProduct, DBTeamProduct, DBPrivateAIKey, DBUser, DBRegion |
4 | 4 | from app.services.litellm import LiteLLMService
|
| 5 | +from app.services.ses import SESService |
5 | 6 | import logging
|
6 | 7 | from collections import defaultdict
|
7 | 8 | from app.core.resource_limits import get_token_restrictions
|
|
15 | 16 | INVOICE_FAILURE_EVENTS,
|
16 | 17 | INVOICE_SUCCESS_EVENTS
|
17 | 18 | )
|
| 19 | +from prometheus_client import Gauge, Counter |
| 20 | +from typing import Dict, List |
| 21 | +from app.core.security import create_access_token |
| 22 | +from app.core.config import settings |
| 23 | +from urllib.parse import urljoin |
18 | 24 |
|
19 | 25 | logger = logging.getLogger(__name__)
|
20 | 26 |
|
| 27 | +FIRST_EMAIL_DAYS_LEFT = 7 |
| 28 | +SECOND_EMAIL_DAYS_LEFT = 5 |
| 29 | +TRIAL_OVER_DAYS = 30 |
| 30 | + |
| 31 | +# Prometheus metrics |
| 32 | +team_freshness_days = Gauge( |
| 33 | + "team_freshness_days", |
| 34 | + "Freshness of teams in days (since creation for teams without products, since last payment for teams with products)", |
| 35 | + ["team_id", "team_name"] |
| 36 | +) |
| 37 | + |
| 38 | +team_expired_metric = Counter( |
| 39 | + "team_expired_total", |
| 40 | + "Total number of teams that have expired without products", |
| 41 | + ["team_id", "team_name"] |
| 42 | +) |
| 43 | + |
| 44 | +key_spend_percentage = Gauge( |
| 45 | + "key_spend_percentage", |
| 46 | + "Percentage of budget used for each key", |
| 47 | + ["team_id", "team_name", "key_alias"] |
| 48 | +) |
| 49 | + |
| 50 | +team_total_spend = Gauge( |
| 51 | + "team_total_spend", |
| 52 | + "Total spend across all keys in a team for the current budget period", |
| 53 | + ["team_id", "team_name"] |
| 54 | +) |
| 55 | + |
| 56 | +# Track active team labels to zero out metrics for inactive teams |
| 57 | +active_team_labels = set() |
| 58 | + |
| 59 | +def get_team_keys_by_region(db: Session, team_id: int) -> Dict[DBRegion, List[DBPrivateAIKey]]: |
| 60 | + """ |
| 61 | + Get all keys for a team grouped by region. |
| 62 | +
|
| 63 | + Args: |
| 64 | + db: Database session |
| 65 | + team_id: ID of the team to get keys for |
| 66 | +
|
| 67 | + Returns: |
| 68 | + Dictionary mapping regions to lists of keys |
| 69 | + """ |
| 70 | + # Get all keys for the team with their regions |
| 71 | + team_users = db.query(DBUser).filter(DBUser.team_id == team_id).all() |
| 72 | + team_user_ids = [user.id for user in team_users] |
| 73 | + # Return keys owned by users in the team OR owned by the team |
| 74 | + team_keys = db.query(DBPrivateAIKey).filter( |
| 75 | + (DBPrivateAIKey.owner_id.in_(team_user_ids)) | |
| 76 | + (DBPrivateAIKey.team_id == team_id) |
| 77 | + ).all() |
| 78 | + |
| 79 | + # Group keys by region |
| 80 | + keys_by_region = defaultdict(list) |
| 81 | + for key in team_keys: |
| 82 | + if not key.litellm_token: |
| 83 | + logger.warning(f"Key {key.id} has no LiteLLM token, skipping") |
| 84 | + continue |
| 85 | + if not key.region: |
| 86 | + logger.warning(f"Key {key.id} has no region, skipping") |
| 87 | + continue |
| 88 | + keys_by_region[key.region].append(key) |
| 89 | + |
| 90 | + return keys_by_region |
| 91 | + |
21 | 92 | async def handle_stripe_event_background(event, db: Session):
|
22 | 93 | """
|
23 | 94 | Background task to handle Stripe webhook events.
|
@@ -60,7 +131,6 @@ async def handle_stripe_event_background(event, db: Session):
|
60 | 131 | except Exception as e:
|
61 | 132 | logger.error(f"Error in background event handler: {str(e)}")
|
62 | 133 |
|
63 |
| - |
64 | 134 | async def apply_product_for_team(db: Session, customer_id: str, product_id: str, start_date: datetime):
|
65 | 135 | """
|
66 | 136 | Apply a product to a team and update their last payment date.
|
@@ -107,25 +177,8 @@ async def apply_product_for_team(db: Session, customer_id: str, product_id: str,
|
107 | 177 |
|
108 | 178 | days_left_in_period, max_max_spend, max_rpm_limit = get_token_restrictions(db, team.id)
|
109 | 179 |
|
110 |
| - # Get all keys for the team with their regions |
111 |
| - team_users = db.query(DBUser).filter(DBUser.team_id == team.id).all() |
112 |
| - team_user_ids = [user.id for user in team_users] |
113 |
| - # Return keys owned by users in the team OR owned by the team |
114 |
| - team_keys = db.query(DBPrivateAIKey).filter( |
115 |
| - (DBPrivateAIKey.owner_id.in_(team_user_ids)) | |
116 |
| - (DBPrivateAIKey.team_id == team.id) |
117 |
| - ).all() |
118 |
| - |
119 |
| - # Group keys by region |
120 |
| - keys_by_region = defaultdict(list) |
121 |
| - for key in team_keys: |
122 |
| - if not key.litellm_token: |
123 |
| - logger.warning(f"Key {key.id} has no LiteLLM token, skipping") |
124 |
| - continue |
125 |
| - if not key.region: |
126 |
| - logger.warning(f"Key {key.id} has no region, skipping") |
127 |
| - continue |
128 |
| - keys_by_region[key.region].append(key) |
| 180 | + # Get all keys for the team grouped by region |
| 181 | + keys_by_region = get_team_keys_by_region(db, team.id) |
129 | 182 |
|
130 | 183 | # Update keys for each region
|
131 | 184 | for region, keys in keys_by_region.items():
|
@@ -188,3 +241,220 @@ async def remove_product_from_team(db: Session, customer_id: str, product_id: st
|
188 | 241 | db.rollback()
|
189 | 242 | logger.error(f"Error removing product from team: {str(e)}")
|
190 | 243 | raise e
|
| 244 | + |
| 245 | +async def monitor_teams(db: Session): |
| 246 | + """ |
| 247 | + Daily monitoring task for teams that: |
| 248 | + 1. Posts age metrics for teams (since creation for teams without products, since last payment for teams with products) |
| 249 | + 2. Sends notifications for teams approaching expiration (25-30 days) |
| 250 | + 3. Posts metrics for expired teams (>30 days) |
| 251 | + 4. Monitors key spend and notifies if approaching limits |
| 252 | + """ |
| 253 | + logger.info("Monitoring teams") |
| 254 | + try: |
| 255 | + # Initialize SES service |
| 256 | + ses_service = SESService() |
| 257 | + |
| 258 | + # Get all teams |
| 259 | + teams = db.query(DBTeam).all() |
| 260 | + current_time = datetime.now(UTC) |
| 261 | + |
| 262 | + # Track current active team labels |
| 263 | + current_team_labels = set() |
| 264 | + |
| 265 | + logger.info(f"Found {len(teams)} teams to track") |
| 266 | + for team in teams: |
| 267 | + team_label = (str(team.id), team.name) |
| 268 | + current_team_labels.add(team_label) |
| 269 | + |
| 270 | + # Check if team has any products |
| 271 | + has_products = db.query(DBTeamProduct).filter( |
| 272 | + DBTeamProduct.team_id == team.id |
| 273 | + ).first() is not None |
| 274 | + |
| 275 | + # Calculate team age based on whether they have products |
| 276 | + if has_products and team.last_payment: |
| 277 | + team_freshness = (current_time - team.last_payment).days |
| 278 | + else: # If a subscription is cancelled this will jump dramatically. |
| 279 | + team_freshness = (current_time - team.created_at).days |
| 280 | + |
| 281 | + if team_freshness < 0: |
| 282 | + logger.warning(f"Team {team.name} (ID: {team.id}) has a negative age: {team_freshness} days") |
| 283 | + team_freshness = 0 |
| 284 | + |
| 285 | + # Post freshness metric |
| 286 | + team_freshness_days.labels( |
| 287 | + team_id=str(team.id), |
| 288 | + team_name=team.name |
| 289 | + ).set(team_freshness) |
| 290 | + |
| 291 | + # Check for notification conditions for teams without products |
| 292 | + if not has_products: |
| 293 | + days_remaining = TRIAL_OVER_DAYS - team_freshness |
| 294 | + if days_remaining == FIRST_EMAIL_DAYS_LEFT or days_remaining == SECOND_EMAIL_DAYS_LEFT: |
| 295 | + logger.warning(f"Team {team.name} (ID: {team.id}) is approaching expiration in {days_remaining} days") |
| 296 | + # Send expiration notification email |
| 297 | + try: |
| 298 | + if team.admin_email: |
| 299 | + template_data = { |
| 300 | + "team_name": team.name, |
| 301 | + "days_remaining": days_remaining, |
| 302 | + "dashboard_url": generate_pricing_url(db, team) |
| 303 | + } |
| 304 | + ses_service.send_email( |
| 305 | + to_addresses=[team.admin_email], |
| 306 | + template_name="team-expiring", |
| 307 | + template_data=template_data |
| 308 | + ) |
| 309 | + logger.info(f"Sent expiration notification email to team {team.name} (ID: {team.id})") |
| 310 | + else: |
| 311 | + logger.warning(f"No email found for team {team.name} (ID: {team.id})") |
| 312 | + except Exception as e: |
| 313 | + logger.error(f"Failed to send expiration notification email to team {team.name}: {str(e)}") |
| 314 | + elif days_remaining <= 0: |
| 315 | + # Post expired metric |
| 316 | + team_expired_metric.labels( |
| 317 | + team_id=str(team.id), |
| 318 | + team_name=team.name |
| 319 | + ).inc() |
| 320 | + |
| 321 | + # Get all keys for the team grouped by region |
| 322 | + keys_by_region = get_team_keys_by_region(db, team.id) |
| 323 | + |
| 324 | + # Track total spend across all keys for this team |
| 325 | + team_total = 0 |
| 326 | + |
| 327 | + # Monitor keys for each region |
| 328 | + for region, keys in keys_by_region.items(): |
| 329 | + try: |
| 330 | + # Initialize LiteLLM service for this region |
| 331 | + litellm_service = LiteLLMService( |
| 332 | + api_url=region.litellm_api_url, |
| 333 | + api_key=region.litellm_api_key |
| 334 | + ) |
| 335 | + |
| 336 | + # Check spend for each key in this region |
| 337 | + for key in keys: |
| 338 | + try: |
| 339 | + # Get current spend using get_key_info |
| 340 | + key_info = await litellm_service.get_key_info(key.litellm_token) |
| 341 | + info = key_info.get("info", {}) |
| 342 | + current_spend = info.get("spend", 0) |
| 343 | + budget = info.get("max_budget", 0) |
| 344 | + key_alias = info.get("key_alias", f"key-{key.id}") # Fallback to key-{id} if no alias |
| 345 | + |
| 346 | + # Add to team total |
| 347 | + team_total += current_spend |
| 348 | + |
| 349 | + # Calculate and post percentage used |
| 350 | + if budget > 0: |
| 351 | + percentage_used = (current_spend / budget) * 100 |
| 352 | + key_spend_percentage.labels( |
| 353 | + team_id=str(team.id), |
| 354 | + team_name=team.name, |
| 355 | + key_alias=key_alias |
| 356 | + ).set(percentage_used) |
| 357 | + |
| 358 | + # Log warning if approaching limit |
| 359 | + if percentage_used >= 80: |
| 360 | + logger.warning( |
| 361 | + f"Key {key_alias} for team {team.name} is approaching spend limit: " |
| 362 | + f"${current_spend:.2f} of ${budget:.2f} ({percentage_used:.1f}%)" |
| 363 | + ) |
| 364 | + else: |
| 365 | + # Set to 0 if no budget is set |
| 366 | + key_spend_percentage.labels( |
| 367 | + team_id=str(team.id), |
| 368 | + team_name=team.name, |
| 369 | + key_alias=key_alias |
| 370 | + ).set(0) |
| 371 | + |
| 372 | + except Exception as e: |
| 373 | + logger.error(f"Error monitoring key {key.id} spend: {str(e)}") |
| 374 | + continue |
| 375 | + |
| 376 | + except Exception as e: |
| 377 | + logger.error(f"Error initializing LiteLLM service for region {region.name}: {str(e)}") |
| 378 | + continue |
| 379 | + |
| 380 | + # Set the total spend metric for the team |
| 381 | + team_total_spend.labels( |
| 382 | + team_id=str(team.id), |
| 383 | + team_name=team.name |
| 384 | + ).set(team_total) |
| 385 | + |
| 386 | + # Zero out metrics for teams that are no longer active |
| 387 | + for old_label in active_team_labels - current_team_labels: |
| 388 | + team_freshness_days.labels( |
| 389 | + team_id=old_label[0], |
| 390 | + team_name=old_label[1] |
| 391 | + ).set(0) |
| 392 | + |
| 393 | + # Update active team labels for next run |
| 394 | + active_team_labels.clear() |
| 395 | + active_team_labels.update(current_team_labels) |
| 396 | + |
| 397 | + except Exception as e: |
| 398 | + logger.error(f"Error in team monitoring task: {str(e)}") |
| 399 | + raise e |
| 400 | + |
| 401 | +def generate_team_admin_token(db: Session, team: DBTeam) -> str: |
| 402 | + """ |
| 403 | + Generate a JWT token that authorizes the bearer as an administrator of the team. |
| 404 | +
|
| 405 | + Args: |
| 406 | + db: Database session |
| 407 | + team: The team object to generate the token for |
| 408 | +
|
| 409 | + Returns: |
| 410 | + str: The generated JWT token |
| 411 | +
|
| 412 | + Raises: |
| 413 | + ValueError: If no admin user is found for the team |
| 414 | + """ |
| 415 | + token_validity_days = 1 |
| 416 | + # Find a team admin user |
| 417 | + admin_user = db.query(DBUser).filter( |
| 418 | + DBUser.team_id == team.id, |
| 419 | + DBUser.role == "admin" |
| 420 | + ).first() |
| 421 | + |
| 422 | + if not admin_user: |
| 423 | + raise ValueError(f"No admin user found for team {team.name} (ID: {team.id})") |
| 424 | + |
| 425 | + # Create token payload with team admin claims |
| 426 | + payload = { |
| 427 | + "sub": admin_user.email, |
| 428 | + "exp": datetime.now(UTC) + timedelta(days=token_validity_days) |
| 429 | + } |
| 430 | + |
| 431 | + # Generate the token |
| 432 | + token = create_access_token( |
| 433 | + data=payload, |
| 434 | + expires_delta=timedelta(days=token_validity_days) |
| 435 | + ) |
| 436 | + |
| 437 | + return token |
| 438 | + |
| 439 | +def generate_pricing_url(db: Session, team: DBTeam) -> str: |
| 440 | + """ |
| 441 | + Generate a URL for the team admin pricing page with a JWT token. |
| 442 | +
|
| 443 | + Args: |
| 444 | + db: Database session |
| 445 | + team: The team object to generate the URL for |
| 446 | +
|
| 447 | + Returns: |
| 448 | + str: The generated URL with the JWT token |
| 449 | + """ |
| 450 | + # Generate the token |
| 451 | + token = generate_team_admin_token(db, team) |
| 452 | + |
| 453 | + # Get the frontend URL from settings |
| 454 | + base_url = settings.frontend_route |
| 455 | + path = '/pricing' |
| 456 | + url = urljoin(base_url, path) |
| 457 | + |
| 458 | + # Add the token as a query parameter |
| 459 | + return f"{url}?token={token}" |
| 460 | + |
0 commit comments