diff --git a/app/api/teams.py b/app/api/teams.py index c181f4b..826ab1c 100644 --- a/app/api/teams.py +++ b/app/api/teams.py @@ -6,7 +6,7 @@ import logging from app.db.database import get_db -from app.db.models import DBTeam, DBTeamProduct, DBUser, DBPrivateAIKey, DBRegion, DBTeamRegion, DBProduct +from app.db.models import DBTeam, DBTeamProduct, DBUser, DBPrivateAIKey, DBRegion, DBTeamRegion, DBProduct, DBTeamMetrics from app.core.security import get_role_min_system_admin, get_role_min_specific_team_admin, get_current_user_from_auth, check_sales_or_higher from app.schemas.models import ( Team, TeamCreate, TeamUpdate, @@ -248,14 +248,6 @@ async def list_teams_for_sales( all_regions = db.query(DBRegion).filter(DBRegion.is_active == True).all() regions_map = {r.id: r for r in all_regions} - # Pre-create LiteLLM services for each region to avoid re-instantiation - litellm_services = {} - for region in all_regions: - litellm_services[region.id] = LiteLLMService( - api_url=region.litellm_api_url, - api_key=region.litellm_api_key - ) - # Get all teams with their basic information teams = db.query(DBTeam).all() @@ -277,41 +269,70 @@ async def list_teams_for_sales( for team_product in team_products ] - # Get team AI keys (both team-owned and user-owned) and calculate total spend - team_users = db.query(DBUser).filter(DBUser.team_id == team.id).all() - team_user_ids = [user.id for user in team_users] - - team_keys = db.query(DBPrivateAIKey).filter( - (DBPrivateAIKey.team_id == team.id) | # Team-owned keys - (DBPrivateAIKey.owner_id.in_(team_user_ids)) # User-owned keys by team members - ).all() + # Try to get cached metrics first + team_metrics = db.query(DBTeamMetrics).filter(DBTeamMetrics.team_id == team.id).first() + + if team_metrics: + # Use cached data + total_spend = team_metrics.total_spend + regions = team_metrics.regions or [] + else: + # Fallback to real-time calculation if no cached data + logger.warning(f"No cached metrics found for team {team.id}, falling back to real-time calculation") + current_time = datetime.now(UTC) + + # Create LiteLLM services only when needed for fallback + litellm_services = {} + for region in all_regions: + litellm_services[region.id] = LiteLLMService( + api_url=region.litellm_api_url, + api_key=region.litellm_api_key + ) - # Calculate total spend from all AI keys and build regions list as we go - total_spend = 0.0 - regions_set = set() - - for key in team_keys: - if key.litellm_token and key.region_id in regions_map: - try: - # Use pre-fetched region info and pre-created LiteLLM service - region = regions_map[key.region_id] - litellm_service = litellm_services[region.id] - - # Add region name to our set - regions_set.add(region.name) - - # Get spend data from LiteLLM - key_data = await litellm_service.get_key_info(key.litellm_token) - key_spend = key_data.get("info", {}).get("spend", 0.0) - total_spend += float(key_spend) - except Exception as e: - # Track unreachable endpoint for logging at the end (only once per region) - region = regions_map[key.region_id] - endpoint_info = f"Region: {region.name}" - unreachable_endpoints.add(endpoint_info) - - # Convert set to list for the response - regions = list(regions_set) + # Get team AI keys (both team-owned and user-owned) and calculate total spend + team_users = db.query(DBUser).filter(DBUser.team_id == team.id).all() + team_user_ids = [user.id for user in team_users] + + team_keys = db.query(DBPrivateAIKey).filter( + (DBPrivateAIKey.team_id == team.id) | # Team-owned keys + (DBPrivateAIKey.owner_id.in_(team_user_ids)) # User-owned keys by team members + ).all() + + # Calculate total spend from all AI keys and build regions list as we go + total_spend = 0.0 + regions_set = set() + + for key in team_keys: + if key.litellm_token and key.region_id in regions_map: + try: + # Use pre-fetched region info and pre-created LiteLLM service + region = regions_map[key.region_id] + litellm_service = litellm_services[region.id] + + # Add region name to our set + regions_set.add(region.name) + + # Get spend data from LiteLLM + key_data = await litellm_service.get_key_info(key.litellm_token) + key_spend = key_data.get("info", {}).get("spend", 0.0) + total_spend += float(key_spend) + except Exception as e: + # Track unreachable endpoint for logging at the end (only once per region) + region = regions_map[key.region_id] + endpoint_info = f"Region: {region.name}" + unreachable_endpoints.add(endpoint_info) + + # Convert set to list for the response + regions = list(regions_set) + # Create new metrics record + team_metrics = DBTeamMetrics( + team_id=team.id, + total_spend=total_spend, + last_spend_calculation=current_time, + regions=regions, + last_updated=current_time + ) + db.add(team_metrics) # Calculate trial status trial_status = _calculate_trial_status(team, products) @@ -330,6 +351,8 @@ async def list_teams_for_sales( ) sales_teams.append(sales_team) + # Any metrics calculated on-the-fly will be cached + db.commit() # Log all unreachable endpoints at the end if unreachable_endpoints: diff --git a/app/core/worker.py b/app/core/worker.py index 8377dcf..1e8b35a 100644 --- a/app/core/worker.py +++ b/app/core/worker.py @@ -1,7 +1,7 @@ import re from datetime import datetime, UTC, timedelta from sqlalchemy.orm import Session -from app.db.models import DBTeam, DBProduct, DBTeamProduct, DBPrivateAIKey, DBUser, DBRegion +from app.db.models import DBTeam, DBProduct, DBTeamProduct, DBPrivateAIKey, DBUser, DBRegion, DBTeamMetrics from app.services.litellm import LiteLLMService from app.services.ses import SESService import logging @@ -581,6 +581,31 @@ async def monitor_teams(db: Session): team_name=team.name ).set(team_total) + # Update or create team metrics record + regions_list = list(keys_by_region.keys()) + region_names = [region.name for region in regions_list] + + # Check if metrics record exists + team_metrics = db.query(DBTeamMetrics).filter(DBTeamMetrics.team_id == team.id).first() + + if team_metrics: + logger.info(f"metrics last updated at {team_metrics.last_updated}, curent time is {current_time}") + # Update existing metrics + team_metrics.total_spend = team_total + team_metrics.last_spend_calculation = current_time + team_metrics.regions = region_names + team_metrics.last_updated = current_time + else: + # Create new metrics record + team_metrics = DBTeamMetrics( + team_id=team.id, + total_spend=team_total, + last_spend_calculation=current_time, + regions=region_names, + last_updated=current_time + ) + db.add(team_metrics) + # Update last_monitored timestamp only if notifications were sent if should_send_notifications: team.last_monitored = current_time diff --git a/app/db/models.py b/app/db/models.py index a878621..684c5f9 100644 --- a/app/db/models.py +++ b/app/db/models.py @@ -18,7 +18,7 @@ class DBTeamProduct(Base): """ __tablename__ = "team_products" - team_id = Column(Integer, ForeignKey('teams.id', ondelete='CASCADE'), primary_key=True, nullable=False) + team_id = Column(Integer, ForeignKey('teams.id'), primary_key=True, nullable=False) product_id = Column(String, ForeignKey('products.id', ondelete='CASCADE'), primary_key=True, nullable=False) created_at = Column(DateTime(timezone=True), default=func.now(), nullable=False) updated_at = Column(DateTime(timezone=True), onupdate=func.now(), nullable=True) @@ -38,7 +38,7 @@ class DBTeamRegion(Base): """ __tablename__ = "team_regions" - team_id = Column(Integer, ForeignKey('teams.id', ondelete='CASCADE'), primary_key=True, nullable=False) + team_id = Column(Integer, ForeignKey('teams.id'), primary_key=True, nullable=False) region_id = Column(Integer, ForeignKey('regions.id', ondelete='CASCADE'), primary_key=True, nullable=False) created_at = Column(DateTime(timezone=True), default=func.now(), nullable=False) updated_at = Column(DateTime(timezone=True), onupdate=func.now(), nullable=True) @@ -114,6 +114,31 @@ class DBTeam(Base): private_ai_keys = relationship("DBPrivateAIKey", back_populates="team") active_products = relationship("DBTeamProduct", back_populates="team") dedicated_regions = relationship("DBTeamRegion", back_populates="team") + metrics = relationship("DBTeamMetrics", back_populates="team", uselist=False, cascade="all, delete") + +class DBTeamMetrics(Base): + """ + Cached team metrics table populated by the monitor_teams worker. + This table stores pre-calculated metrics to avoid expensive real-time + LiteLLM API calls in the sales dashboard. + """ + __tablename__ = "team_metrics" + + id = Column(Integer, primary_key=True, index=True) + team_id = Column(Integer, ForeignKey('teams.id', ondelete='CASCADE'), unique=True, nullable=False, index=True) + + # Spend metrics (the expensive calculation we want to cache) + total_spend = Column(Float, default=0.0, nullable=False) + last_spend_calculation = Column(DateTime(timezone=True), nullable=False) + + # Region information (derived from team keys) + regions = Column(JSON, nullable=True) # List of region names + + # Monitoring metadata + last_updated = Column(DateTime(timezone=True), default=func.now(), nullable=False) + + # Relationships + team = relationship("DBTeam", back_populates="metrics") class DBPrivateAIKey(Base): __tablename__ = "ai_tokens" diff --git a/app/migrations/versions/20250127_120000_add_team_metrics_table.py b/app/migrations/versions/20250127_120000_add_team_metrics_table.py new file mode 100644 index 0000000..fa5c21a --- /dev/null +++ b/app/migrations/versions/20250127_120000_add_team_metrics_table.py @@ -0,0 +1,45 @@ +"""add team metrics table + +Revision ID: add_team_metrics_table +Revises: 5bda44ccd008 +Create Date: 2025-01-27 12:00:00.000000+00:00 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'add_team_metrics_table' +down_revision: Union[str, None] = '5bda44ccd008' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('team_metrics', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('team_id', sa.Integer(), nullable=False), + sa.Column('total_spend', sa.Float(), nullable=False), + sa.Column('last_spend_calculation', sa.DateTime(timezone=True), nullable=False), + sa.Column('regions', sa.JSON(), nullable=True), + sa.Column('last_updated', sa.DateTime(timezone=True), nullable=False), + sa.ForeignKeyConstraint(['team_id'], ['teams.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('team_id') + ) + op.create_index(op.f('ix_team_metrics_id'), 'team_metrics', ['id'], unique=False) + op.create_index(op.f('ix_team_metrics_team_id'), 'team_metrics', ['team_id'], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f('ix_team_metrics_team_id'), table_name='team_metrics') + op.drop_index(op.f('ix_team_metrics_id'), table_name='team_metrics') + op.drop_table('team_metrics') + # ### end Alembic commands ### + diff --git a/tests/test_sales_api.py b/tests/test_sales_api.py index ed29291..e03aa85 100644 --- a/tests/test_sales_api.py +++ b/tests/test_sales_api.py @@ -5,7 +5,8 @@ from datetime import datetime, UTC, timedelta from unittest.mock import patch, AsyncMock from sqlalchemy.orm import Session -from app.db.models import DBTeam, DBProduct, DBTeamProduct, DBPrivateAIKey, DBRegion, DBUser +from app.db.models import DBTeam, DBTeamMetrics, DBTeamProduct, DBPrivateAIKey, DBRegion, DBUser +from app.api.teams import list_teams_for_sales @pytest.fixture @@ -106,7 +107,8 @@ def test_list_teams_for_sales_requires_admin(client, test_team): assert response.status_code == 401 # Unauthorized -def test_list_teams_for_sales_success(client, admin_token, test_team, test_product, +@patch('app.services.litellm.LiteLLMService.get_key_info', new_callable=AsyncMock) +def test_list_teams_for_sales_success(mock_get_info, client, admin_token, test_team, test_product, test_region, test_ai_key, mock_litellm_response, db): """Test successful retrieval of sales data.""" # Create team-product association @@ -117,14 +119,25 @@ def test_list_teams_for_sales_success(client, admin_token, test_team, test_produ db.add(team_product) db.commit() + validation_data = { + "id": test_team.id, + "name": test_team.name, + "admin_email": test_team.admin_email, + "is_always_free": False, + "products": [{"id": test_product.id, "name": test_product.name, "active": True}], + "regions": [test_region.name], + "total_spend": 25.50, + "trial_status": "Active Product", + "last_payment": None, + } + # Mock LiteLLM service - with patch('app.services.litellm.LiteLLMService.get_key_info', new_callable=AsyncMock) as mock_get_info: - mock_get_info.return_value = mock_litellm_response + mock_get_info.return_value = mock_litellm_response - response = client.get( - "/teams/sales/list-teams", - headers={"Authorization": f"Bearer {admin_token}"} - ) + response = client.get( + "/teams/sales/list-teams", + headers={"Authorization": f"Bearer {admin_token}"} + ) assert response.status_code == 200 data = response.json() @@ -132,21 +145,12 @@ def test_list_teams_for_sales_success(client, admin_token, test_team, test_produ assert len(data["teams"]) == 1 team_data = data["teams"][0] - assert team_data["id"] == test_team.id - assert team_data["name"] == test_team.name - assert team_data["admin_email"] == test_team.admin_email - assert team_data["is_always_free"] == False - assert len(team_data["products"]) == 1 - assert team_data["products"][0]["id"] == test_product.id - assert team_data["products"][0]["name"] == test_product.name - assert team_data["products"][0]["active"] == True - assert len(team_data["regions"]) == 1 - assert team_data["regions"][0] == test_region.name - assert team_data["total_spend"] == 25.50 - assert team_data["trial_status"] == "Active Product" + team_data.pop("created_at", None) + assert team_data == validation_data -def test_always_free_team_trial_status(client, admin_token, test_always_free_team, +@patch('app.services.litellm.LiteLLMService.get_key_info', new_callable=AsyncMock) +def test_always_free_team_trial_status(mock_get_info, client, admin_token, test_always_free_team, test_product, test_region, test_ai_key, db): """Test that always-free teams show correct trial status.""" # Create team-product association @@ -158,13 +162,12 @@ def test_always_free_team_trial_status(client, admin_token, test_always_free_tea db.commit() # Mock LiteLLM service - with patch('app.services.litellm.LiteLLMService.get_key_info', new_callable=AsyncMock) as mock_get_info: - mock_get_info.return_value = {"info": {"spend": 0.0}} + mock_get_info.return_value = {"info": {"spend": 0.0}} - response = client.get( - "/teams/sales/list-teams", - headers={"Authorization": f"Bearer {admin_token}"} - ) + response = client.get( + "/teams/sales/list-teams", + headers={"Authorization": f"Bearer {admin_token}"} + ) assert response.status_code == 200 data = response.json() @@ -172,7 +175,8 @@ def test_always_free_team_trial_status(client, admin_token, test_always_free_tea assert team_data["trial_status"] == "Always Free" -def test_paid_team_trial_status(client, admin_token, test_paid_team, +@patch('app.services.litellm.LiteLLMService.get_key_info', new_callable=AsyncMock) +def test_paid_team_trial_status(mock_get_info, client, admin_token, test_paid_team, test_product, test_region, test_ai_key, db): """Test that teams with payment history show correct trial status.""" # Create team-product association @@ -184,13 +188,12 @@ def test_paid_team_trial_status(client, admin_token, test_paid_team, db.commit() # Mock LiteLLM service - with patch('app.services.litellm.LiteLLMService.get_key_info', new_callable=AsyncMock) as mock_get_info: - mock_get_info.return_value = {"info": {"spend": 0.0}} + mock_get_info.return_value = {"info": {"spend": 0.0}} - response = client.get( - "/teams/sales/list-teams", - headers={"Authorization": f"Bearer {admin_token}"} - ) + response = client.get( + "/teams/sales/list-teams", + headers={"Authorization": f"Bearer {admin_token}"} + ) assert response.status_code == 200 data = response.json() @@ -199,16 +202,16 @@ def test_paid_team_trial_status(client, admin_token, test_paid_team, assert team_data["trial_status"] == "Active Product" -def test_team_without_products(client, admin_token, test_team, test_region, test_ai_key, db): +@patch('app.services.litellm.LiteLLMService.get_key_info', new_callable=AsyncMock) +def test_team_without_products(mock_get_info, client, admin_token, test_team, test_region, test_ai_key, db): """Test team without any products shows trial status based on creation date.""" # Mock LiteLLM service - with patch('app.services.litellm.LiteLLMService.get_key_info', new_callable=AsyncMock) as mock_get_info: - mock_get_info.return_value = {"info": {"spend": 0.0}} + mock_get_info.return_value = {"info": {"spend": 0.0}} - response = client.get( - "/teams/sales/list-teams", - headers={"Authorization": f"Bearer {admin_token}"} - ) + response = client.get( + "/teams/sales/list-teams", + headers={"Authorization": f"Bearer {admin_token}"} + ) assert response.status_code == 200 data = response.json() @@ -218,8 +221,8 @@ def test_team_without_products(client, admin_token, test_team, test_region, test assert team_data["trial_status"] == "30 days left" assert len(team_data["products"]) == 0 - -def test_team_with_multiple_ai_keys(client, admin_token, test_team, test_product, +@patch('app.services.litellm.LiteLLMService.get_key_info', new_callable=AsyncMock) +def test_team_with_multiple_ai_keys(mock_get_info, client, admin_token, test_team, test_product, test_region, test_ai_key, db): """Test team with multiple AI keys aggregates spend correctly.""" # Create second AI key @@ -247,24 +250,23 @@ def test_team_with_multiple_ai_keys(client, admin_token, test_team, test_product db.commit() # Mock LiteLLM service with different spend values - with patch('app.services.litellm.LiteLLMService.get_key_info', new_callable=AsyncMock) as mock_get_info: - mock_get_info.side_effect = [ - {"info": {"spend": 25.50}}, - {"info": {"spend": 15.25}} - ] + mock_get_info.side_effect = [ + {"info": {"spend": 25.50}}, + {"info": {"spend": 15.25}} + ] - response = client.get( - "/teams/sales/list-teams", - headers={"Authorization": f"Bearer {admin_token}"} - ) + response = client.get( + "/teams/sales/list-teams", + headers={"Authorization": f"Bearer {admin_token}"} + ) assert response.status_code == 200 data = response.json() team_data = data["teams"][0] assert team_data["total_spend"] == 40.75 # 25.50 + 15.25 - -def test_litellm_service_error_handling(client, admin_token, test_team, test_product, +@patch('app.services.litellm.LiteLLMService.get_key_info', new_callable=AsyncMock) +def test_litellm_service_error_handling(mock_get_info, client, admin_token, test_team, test_product, test_region, test_ai_key, db): """Test that LiteLLM service errors don't break the entire response.""" # Create team-product association @@ -276,13 +278,12 @@ def test_litellm_service_error_handling(client, admin_token, test_team, test_pro db.commit() # Mock LiteLLM service to raise an exception - with patch('app.services.litellm.LiteLLMService.get_key_info', new_callable=AsyncMock) as mock_get_info: - mock_get_info.side_effect = Exception("LiteLLM service error") + mock_get_info.side_effect = Exception("LiteLLM service error") - response = client.get( - "/teams/sales/list-teams", - headers={"Authorization": f"Bearer {admin_token}"} - ) + response = client.get( + "/teams/sales/list-teams", + headers={"Authorization": f"Bearer {admin_token}"} + ) # Should still succeed but with 0 spend assert response.status_code == 200 @@ -291,20 +292,20 @@ def test_litellm_service_error_handling(client, admin_token, test_team, test_pro assert team_data["total_spend"] == 0.0 -def test_expired_trial_status(client, admin_token, test_team, test_region, test_ai_key, db): +@patch('app.services.litellm.LiteLLMService.get_key_info', new_callable=AsyncMock) +def test_expired_trial_status(mock_get_info, client, admin_token, test_team, test_region, test_ai_key, db): """Test that expired trials show correct status.""" # Update team to be older than 30 days (no products, so should show expired) test_team.created_at = datetime.now(UTC) - timedelta(days=35) db.commit() # Mock LiteLLM service - with patch('app.services.litellm.LiteLLMService.get_key_info', new_callable=AsyncMock) as mock_get_info: - mock_get_info.return_value = {"info": {"spend": 0.0}} + mock_get_info.return_value = {"info": {"spend": 0.0}} - response = client.get( - "/teams/sales/list-teams", - headers={"Authorization": f"Bearer {admin_token}"} - ) + response = client.get( + "/teams/sales/list-teams", + headers={"Authorization": f"Bearer {admin_token}"} + ) assert response.status_code == 200 data = response.json() @@ -312,7 +313,8 @@ def test_expired_trial_status(client, admin_token, test_team, test_region, test_ assert team_data["trial_status"] == "Expired" -def test_list_teams_for_sales_includes_user_owned_keys(client, admin_token, test_team, test_product, +@patch('app.services.litellm.LiteLLMService.get_key_info', new_callable=AsyncMock) +def test_list_teams_for_sales_includes_user_owned_keys(mock_get_info, client, admin_token, test_team, test_product, test_region, test_ai_key, test_team_user, test_user_owned_ai_key, mock_litellm_response, db): """Test that sales data includes both team-owned and user-owned AI keys.""" @@ -323,19 +325,20 @@ def test_list_teams_for_sales_includes_user_owned_keys(client, admin_token, test ) db.add(team_product) db.commit() + team_id = test_team.id + region_name = test_region.name # Mock LiteLLM service to return different spend for each key - with patch('app.services.litellm.LiteLLMService.get_key_info', new_callable=AsyncMock) as mock_get_info: - # Return different spend values for team key vs user key - mock_get_info.side_effect = [ - {"info": {"spend": 25.50}}, # Team key spend - {"info": {"spend": 15.25}} # User key spend - ] - - response = client.get( - "/teams/sales/list-teams", - headers={"Authorization": f"Bearer {admin_token}"} - ) + # Return different spend values for team key vs user key + mock_get_info.side_effect = [ + {"info": {"spend": 25.50}}, # Team key spend + {"info": {"spend": 15.25}} # User key spend + ] + + response = client.get( + "/teams/sales/list-teams", + headers={"Authorization": f"Bearer {admin_token}"} + ) assert response.status_code == 200 data = response.json() @@ -343,30 +346,30 @@ def test_list_teams_for_sales_includes_user_owned_keys(client, admin_token, test assert len(data["teams"]) == 1 team_data = data["teams"][0] - assert team_data["id"] == test_team.id + assert team_data["id"] == team_id # Should include both keys in total spend (25.50 + 15.25 = 40.75) assert team_data["total_spend"] == 40.75 # Should include region from both keys assert len(team_data["regions"]) == 1 - assert team_data["regions"][0] == test_region.name + assert team_data["regions"][0] == region_name -def test_team_with_15_days_remaining(client, admin_token, test_team, test_region, test_ai_key, db): +@patch('app.services.litellm.LiteLLMService.get_key_info', new_callable=AsyncMock) +def test_team_with_15_days_remaining(mock_get_info, client, admin_token, test_team, test_region, test_ai_key, db): """Test that teams with 15 days remaining show correct format.""" # Update team to be 15 days old test_team.created_at = datetime.now(UTC) - timedelta(days=15) db.commit() # Mock LiteLLM service - with patch('app.services.litellm.LiteLLMService.get_key_info', new_callable=AsyncMock) as mock_get_info: - mock_get_info.return_value = {"info": {"spend": 0.0}} + mock_get_info.return_value = {"info": {"spend": 0.0}} - response = client.get( - "/teams/sales/list-teams", - headers={"Authorization": f"Bearer {admin_token}"} - ) + response = client.get( + "/teams/sales/list-teams", + headers={"Authorization": f"Bearer {admin_token}"} + ) assert response.status_code == 200 data = response.json() @@ -375,20 +378,20 @@ def test_team_with_15_days_remaining(client, admin_token, test_team, test_region assert team_data["trial_status"] == "15 days left" -def test_team_with_7_days_remaining(client, admin_token, test_team, test_region, test_ai_key, db): +@patch('app.services.litellm.LiteLLMService.get_key_info', new_callable=AsyncMock) +def test_team_with_7_days_remaining(mock_get_info, client, admin_token, test_team, test_region, test_ai_key, db): """Test that teams with 7 days remaining show correct format.""" # Update team to be 23 days old (7 days remaining) test_team.created_at = datetime.now(UTC) - timedelta(days=23) db.commit() # Mock LiteLLM service - with patch('app.services.litellm.LiteLLMService.get_key_info', new_callable=AsyncMock) as mock_get_info: - mock_get_info.return_value = {"info": {"spend": 0.0}} + mock_get_info.return_value = {"info": {"spend": 0.0}} - response = client.get( - "/teams/sales/list-teams", - headers={"Authorization": f"Bearer {admin_token}"} - ) + response = client.get( + "/teams/sales/list-teams", + headers={"Authorization": f"Bearer {admin_token}"} + ) assert response.status_code == 200 data = response.json() @@ -397,20 +400,20 @@ def test_team_with_7_days_remaining(client, admin_token, test_team, test_region, assert team_data["trial_status"] == "7 days left" -def test_team_with_last_payment_days_calculation(client, admin_token, test_team, test_region, test_ai_key, db): +@patch('app.services.litellm.LiteLLMService.get_key_info', new_callable=AsyncMock) +def test_team_with_last_payment_days_calculation(mock_get_info, client, admin_token, test_team, test_region, test_ai_key, db): """Test that teams with last_payment calculate days remaining from payment date.""" # Set last_payment to 10 days ago (so 20 days remaining from payment) test_team.last_payment = datetime.now(UTC) - timedelta(days=10) db.commit() # Mock LiteLLM service - with patch('app.services.litellm.LiteLLMService.get_key_info', new_callable=AsyncMock) as mock_get_info: - mock_get_info.return_value = {"info": {"spend": 0.0}} + mock_get_info.return_value = {"info": {"spend": 0.0}} - response = client.get( - "/teams/sales/list-teams", - headers={"Authorization": f"Bearer {admin_token}"} - ) + response = client.get( + "/teams/sales/list-teams", + headers={"Authorization": f"Bearer {admin_token}"} + ) assert response.status_code == 200 data = response.json() @@ -476,7 +479,7 @@ def test_list_teams_for_sales_multiple_unreachable_regions(mock_get_info, client Then the function should default spend to 0 and include all regions in response """ # Create a second region - region2 = DBRegion( + region_two = DBRegion( name="Test Region 2", postgres_host="test-host-2", postgres_port=5432, @@ -486,7 +489,7 @@ def test_list_teams_for_sales_multiple_unreachable_regions(mock_get_info, client litellm_api_key="test-key-2", is_active=True ) - db.add(region2) + db.add(region_two) db.commit() # Create team-product association @@ -507,10 +510,12 @@ def test_list_teams_for_sales_multiple_unreachable_regions(mock_get_info, client litellm_api_url="https://test-litellm-2.com", owner_id=None, team_id=test_team.id, - region_id=region2.id + region_id=region_two.id ) db.add(ai_key2) db.commit() + region_name = test_region.name + region_two_name = region_two.name # Mock LiteLLM service to raise different exceptions for different regions mock_get_info.side_effect = [ @@ -533,5 +538,100 @@ def test_list_teams_for_sales_multiple_unreachable_regions(mock_get_info, client # Should include both regions in the response assert len(team_data["regions"]) == 2 - assert test_region.name in team_data["regions"] - assert region2.name in team_data["regions"] + assert region_name in team_data["regions"] + assert region_two_name in team_data["regions"] + + +@pytest.mark.asyncio +@patch('app.api.teams.LiteLLMService') +async def test_list_teams_for_sales_uses_cached_metrics(mock_litellm_class, db, test_team, test_region): + """ + Test that list_teams_for_sales API uses cached metrics instead of making API calls. + + GIVEN: A team with cached metrics in DBTeamMetrics + WHEN: list_teams_for_sales API is called + THEN: It returns data from cache without making LiteLLM API calls + """ + # Arrange + # Create cached metrics + cached_metrics = DBTeamMetrics( + team_id=test_team.id, + total_spend=200.50, + last_spend_calculation=datetime.now(UTC), + regions=[test_region.name], + last_updated=datetime.now(UTC) + ) + db.add(cached_metrics) + db.commit() + + # Create a test key (but we shouldn't need to call LiteLLM for it) + test_key = DBPrivateAIKey( + name="test-key", + team_id=test_team.id, + region_id=test_region.id, + litellm_token="test-token-123", + created_at=datetime.now(UTC) + ) + db.add(test_key) + db.commit() + + # Ensure LiteLLM service is never called + mock_litellm_class.side_effect = AssertionError("LiteLLM service should not be called when using cached metrics") + + # Act + response = await list_teams_for_sales(db) + + # Assert + assert len(response.teams) == 1 + team_data = response.teams[0] + assert team_data.id == test_team.id + assert team_data.total_spend == 200.50 + assert test_region.name in team_data.regions + + +@pytest.mark.asyncio +@patch('app.api.teams.LiteLLMService') +async def test_list_teams_for_sales_fallback_to_api_calls(mock_litellm_class, db, test_team, test_region): + """ + Test that list_teams_for_sales falls back to API calls when no cached metrics exist. + + GIVEN: A team without cached metrics + WHEN: list_teams_for_sales API is called + THEN: It makes LiteLLM API calls and returns the data + """ + # Arrange + # Create a test key + test_key = DBPrivateAIKey( + name="test-key", + team_id=test_team.id, + region_id=test_region.id, + litellm_token="test-token-123", + created_at=datetime.now(UTC) + ) + db.add(test_key) + db.commit() + + # Mock LiteLLM service responses + mock_litellm_service = AsyncMock() + mock_litellm_class.return_value = mock_litellm_service + mock_litellm_service.get_key_info.return_value = { + "info": { + "spend": 300.25, + "max_budget": 500.0, + "key_alias": "test-key" + } + } + + # Act + response = await list_teams_for_sales(db) + updated_metrics = db.query(DBTeamMetrics).filter(DBTeamMetrics.team_id == test_team.id).first() + + # Assert + assert len(response.teams) == 1 + team_data = response.teams[0] + assert team_data.id == test_team.id + assert team_data.total_spend == 300.25 + assert test_region.name in team_data.regions + assert updated_metrics is not None + assert updated_metrics.total_spend == 300.25 + assert test_region.name in updated_metrics.regions diff --git a/tests/test_worker.py b/tests/test_worker.py index 5024ce4..68c21c6 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -1,5 +1,5 @@ import pytest -from app.db.models import DBProduct, DBTeamProduct, DBPrivateAIKey, DBTeam, DBUser +from app.db.models import DBProduct, DBTeamProduct, DBPrivateAIKey, DBTeam, DBTeamMetrics from datetime import datetime, UTC, timedelta from app.core.worker import ( apply_product_for_team, @@ -1832,3 +1832,104 @@ async def test_monitor_team_keys_expiry_beyond_next_month(mock_litellm, db, test # Verify team total spend is calculated correctly assert team_total == 10.0 + +@pytest.mark.asyncio +@patch('app.core.worker.LiteLLMService') +async def test_monitor_teams_populates_team_metrics(mock_litellm_class, db, test_team, test_region): + """ + Test that monitor_teams function populates DBTeamMetrics table. + + GIVEN: A team with AI keys and regions + WHEN: monitor_teams is called + THEN: DBTeamMetrics record is created/updated with spend data + """ + # Arrange + # Create a test key for the team + test_key = DBPrivateAIKey( + name="test-key", + team_id=test_team.id, + region_id=test_region.id, + litellm_token="test-token-123", + created_at=datetime.now(UTC) + ) + db.add(test_key) + db.commit() + + # Mock LiteLLM service responses + mock_litellm_service = AsyncMock() + mock_litellm_class.return_value = mock_litellm_service + mock_litellm_service.get_key_info.return_value = { + "info": { + "spend": 75.50, + "max_budget": 100.0, + "key_alias": "test-key" + } + } + + # Act + await monitor_teams(db) + + # Assert + metrics = db.query(DBTeamMetrics).filter(DBTeamMetrics.team_id == test_team.id).first() + assert metrics is not None + assert metrics.total_spend == 75.50 + assert test_region.name in metrics.regions + assert metrics.last_spend_calculation is not None + + +@pytest.mark.asyncio +@patch('app.core.worker.LiteLLMService') +async def test_monitor_teams_updates_existing_metrics(mock_litellm_class, db, test_team, test_region): + """ + Test that monitor_teams updates existing DBTeamMetrics records. + + GIVEN: A team with existing metrics record + WHEN: monitor_teams is called again + THEN: The existing metrics record is updated with new data + """ + # Arrange + # Create existing metrics with a fixed old timestamp + old_timestamp = datetime.now(UTC) - timedelta(hours=1) + existing_metrics = DBTeamMetrics( + team_id=test_team.id, + total_spend=50.0, + last_spend_calculation=old_timestamp, + regions=["old-region"], + last_updated=old_timestamp + ) + db.add(existing_metrics) + db.commit() + old_update_date = existing_metrics.last_updated + + # Create a test key + test_key = DBPrivateAIKey( + name="test-key", + team_id=test_team.id, + region_id=test_region.id, + litellm_token="test-token-123", + created_at=datetime.now(UTC) + ) + db.add(test_key) + db.commit() + + # Mock LiteLLM service responses + mock_litellm_service = AsyncMock() + mock_litellm_class.return_value = mock_litellm_service + mock_litellm_service.get_key_info.return_value = { + "info": { + "spend": 125.75, + "max_budget": 200.0, + "key_alias": "test-key" + } + } + + # Act + await monitor_teams(db) + + # Assert + updated_metrics = db.query(DBTeamMetrics).filter(DBTeamMetrics.team_id == test_team.id).first() + assert updated_metrics is not None + assert updated_metrics.total_spend == 125.75 + assert test_region.name in updated_metrics.regions + assert updated_metrics.last_updated > old_update_date +