From d875a8ed2666f9e848bd3f7c738d3bbff02a7787 Mon Sep 17 00:00:00 2001 From: Pippa H Date: Wed, 14 May 2025 10:34:32 +0200 Subject: [PATCH 01/23] Add endpoint to extend private AI key duration - Introduced a new PUT endpoint `/private-ai-keys/{key_id}/extend-token-life` to allow users to extend the life of their private AI keys. - Added `TokenDurationUpdate` schema for specifying the duration to extend. - Implemented logic to verify user access and update the key duration in LiteLLM. - Enhanced `LiteLLMService` with a method to update key duration and handle potential errors. - Updated tests to validate the new functionality and ensure correct duration handling based on environment variables. --- app/api/private_ai_keys.py | 56 +++++++++++++++++++++++++++++++++++++- app/schemas/models.py | 4 +++ app/services/litellm.py | 29 +++++++++++++++++++- tests/test_private_ai.py | 36 ++++++++++++++++++++++-- 4 files changed, 121 insertions(+), 4 deletions(-) diff --git a/app/api/private_ai_keys.py b/app/api/private_ai_keys.py index 9264da0..ce1af1b 100644 --- a/app/api/private_ai_keys.py +++ b/app/api/private_ai_keys.py @@ -7,7 +7,8 @@ from app.db.database import get_db from app.schemas.models import ( PrivateAIKey, PrivateAIKeyCreate, PrivateAIKeySpend, - BudgetPeriodUpdate, LiteLLMToken, VectorDBCreate, VectorDB + BudgetPeriodUpdate, LiteLLMToken, VectorDBCreate, VectorDB, + TokenDurationUpdate ) from app.db.postgres import PostgresManager from app.db.models import DBPrivateAIKey, DBRegion, DBUser, DBTeam @@ -571,3 +572,56 @@ async def update_budget_period( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to update budget period: {str(e)}" ) + +@router.put("/{key_id}/extend-token-life") +async def extend_token_life( + key_id: int, + duration_update: TokenDurationUpdate, + current_user = Depends(get_current_user_from_auth), + user_role: UserRole = Depends(get_role_min_team_admin), + db: Session = Depends(get_db) +): + """ + Extend the life of a private AI key. + + This endpoint will: + 1. Verify the user has access to the key + 2. Update the key's duration in LiteLLM + 3. Return the updated key information + + Required parameters: + - **duration**: The amount of time to add to the key's life (e.g. "30d" for 30 days, "1y" for 1 year) + + Note: You must be authenticated to use this endpoint. + Only the owner of the key or an admin can update it. + """ + private_ai_key = _get_key_if_allowed(key_id, current_user, user_role, db) + + # Get the region + region = db.query(DBRegion).filter(DBRegion.id == private_ai_key.region_id).first() + if not region: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Region not found" + ) + + litellm_service = LiteLLMService( + api_url=region.litellm_api_url, + api_key=region.litellm_api_key + ) + + try: + # Update key duration in LiteLLM + await litellm_service.update_key_duration( + litellm_token=private_ai_key.litellm_token, + duration=duration_update.duration + ) + + # Get updated key information + key_data = await litellm_service.get_key_info(private_ai_key.litellm_token) + return key_data + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to extend token life: {str(e)}" + ) diff --git a/app/schemas/models.py b/app/schemas/models.py index b243c11..ea49ee8 100644 --- a/app/schemas/models.py +++ b/app/schemas/models.py @@ -150,6 +150,10 @@ class PrivateAIKey(PrivateAIKeyBase): class BudgetPeriodUpdate(BaseModel): budget_duration: str +class TokenDurationUpdate(BaseModel): + """Schema for updating a token's duration""" + duration: str # e.g. "30d" for 30 days, "1y" for 1 year + class PrivateAIKeySpend(BaseModel): spend: float expires: datetime diff --git a/app/services/litellm.py b/app/services/litellm.py index a31a9fa..112224e 100644 --- a/app/services/litellm.py +++ b/app/services/litellm.py @@ -20,7 +20,7 @@ async def create_key(self, email: str, name: str, user_id: int, team_id: str) -> try: logger.info(f"Creating new LiteLLM API key for email: {email}, name: {name}, user_id: {user_id}, team_id: {team_id}") request_data = { - "duration": "8760h", # Set token duration to 1 year (365 days * 24 hours) + "duration": "14d" if os.getenv("EXPIRE_KEYS", "").lower() == "true" else "365d", # 14 days if EXPIRE_KEYS is true, otherwise 1 year "models": ["all-team-models"], # Allow access to all models "aliases": {}, "config": {}, @@ -141,3 +141,30 @@ async def update_budget(self, litellm_token: str, budget_duration: str): status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to update LiteLLM budget: {error_msg}" ) + + async def update_key_duration(self, litellm_token: str, duration: str): + """Update the duration for a LiteLLM API key""" + try: + response = requests.post( + f"{self.api_url}/key/update", + headers={ + "Authorization": f"Bearer {self.master_key}" + }, + json={ + "key": litellm_token, + "duration": duration + } + ) + response.raise_for_status() + except requests.exceptions.RequestException as e: + error_msg = str(e) + if hasattr(e, 'response') and e.response is not None: + try: + error_details = e.response.json() + error_msg = f"Status {e.response.status_code}: {error_details}" + except ValueError: + error_msg = f"Status {e.response.status_code}: {e.response.text}" + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to update LiteLLM key duration: {error_msg}" + ) diff --git a/tests/test_private_ai.py b/tests/test_private_ai.py index 6a7ba4d..2e8b186 100644 --- a/tests/test_private_ai.py +++ b/tests/test_private_ai.py @@ -5,6 +5,7 @@ import logging from datetime import datetime, UTC from app.core.security import get_password_hash +import os @pytest.fixture def mock_litellm_response(): @@ -980,7 +981,6 @@ def test_update_budget_period_as_key_creator(mock_post, client, team_key_creator db.commit() @patch("app.services.litellm.requests.post") -@patch("app.services.litellm.requests.get") def test_update_budget_duration_as_team_admin(mock_get, mock_post, client, team_admin_token, test_region, mock_litellm_response, db, test_team): """Test that a team admin can update the budget duration for a team-owned key""" # Mock the LiteLLM API responses @@ -1096,6 +1096,38 @@ def test_create_llm_token_as_system_admin(mock_post, client, admin_token, test_r for key in list_data ) +@patch("app.services.litellm.requests.post") +@patch.dict(os.environ, {"EXPIRE_KEYS": "true"}) +def test_create_llm_token_with_expiration(mock_post, client, admin_token, test_region, mock_litellm_response): + """Test that when EXPIRE_KEYS is true, new LiteLLM tokens are created with a 14-day expiration duration""" + # Mock the LiteLLM API response + mock_post.return_value.status_code = 200 + mock_post.return_value.json.return_value = mock_litellm_response + mock_post.return_value.raise_for_status.return_value = None + + # Create LLM token + response = client.post( + "/private-ai-keys/token", + headers={"Authorization": f"Bearer {admin_token}"}, + json={ + "region_id": test_region.id, + "name": "Test LLM Token with Expiration" + } + ) + + # Verify the create response + assert response.status_code == 200 + data = response.json() + assert data["litellm_token"] == "test-private-key-123" + assert data["litellm_api_url"] == test_region.litellm_api_url + assert data["region"] == test_region.name + assert data["name"] == "Test LLM Token with Expiration" + + # Verify that the LiteLLM API was called with the correct duration + mock_post.assert_called_once() + call_args = mock_post.call_args[1] + assert call_args["json"]["duration"] == "14d" # Verify 14-day duration format + def test_create_vector_db_as_system_admin(client, admin_token, test_region): """Test that a system admin can create a vector database for themselves""" region_name = test_region.name @@ -1323,4 +1355,4 @@ def test_list_private_ai_keys_as_non_team_user(mock_post, client, admin_token, t assert len(data) == 1 assert data[0]["name"] == "user-owned-key" assert data[0]["owner_id"] == test_user.id - assert data[0].get("team_id") is None \ No newline at end of file + assert data[0].get("team_id") is None From a4906f54646350b341aa80ac07082a47e289b338 Mon Sep 17 00:00:00 2001 From: Pippa H Date: Fri, 16 May 2025 08:56:35 +0200 Subject: [PATCH 02/23] Update dependencies in requirements files - Updated `requests-mock` to version 1.12.1 in `requirements-test.txt`. - Upgraded several packages in `requirements.txt`: - `uvicorn` to 0.34.2 - `sqlalchemy` to 2.0.41 - `pydantic` to 2.11.4 - `pydantic-settings` to 2.9.1 - `asyncpg` to 0.30.0 - `psycopg2-binary` to 2.9.10 - `requests` to 2.32.3 - `python-dotenv` to 1.1.0 - `boto3` to 1.38.20 - `markdown` to 3.8.0 - `prometheus-client` to 0.21.1 - `prometheus-fastapi-instrumentator` to 7.0.0 --- requirements-test.txt | 2 +- requirements.txt | 22 +++++++++++----------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/requirements-test.txt b/requirements-test.txt index 9e78e23..f958534 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -2,6 +2,6 @@ pytest==8.3.3 pytest-asyncio==0.23.5 pytest-cov==4.1.0 httpx==0.28.1 -requests-mock==1.11.0 +requests-mock==1.12.1 faker==22.7.0 pytest-env==1.1.5 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 3e4a166..75005b5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,18 +1,18 @@ fastapi==0.115.12 -uvicorn==0.27.1 -sqlalchemy==2.0.40 -pydantic==2.6.1 -pydantic-settings==2.1.0 +uvicorn==0.34.2 +sqlalchemy==2.0.41 +pydantic==2.11.4 +pydantic-settings==2.9.1 python-jose[cryptography]==3.4.0 passlib[bcrypt]==1.7.4 python-multipart==0.0.18 -asyncpg==0.29.0 -psycopg2-binary==2.9.9 -requests==2.32.2 -python-dotenv==1.0.1 +asyncpg==0.30.0 +psycopg2-binary==2.9.10 +requests==2.32.3 +python-dotenv==1.1.0 alembic==1.15.2 boto3==1.38.20 -markdown==3.5.2 +markdown==3.8.0 email-validator==2.1.2 -prometheus-client==0.19.0 -prometheus-fastapi-instrumentator==6.1.0 +prometheus-client==0.21.1 +prometheus-fastapi-instrumentator==7.0.0 From 674d332626b0f7be8ff9ac8ae22473e6ade4062f Mon Sep 17 00:00:00 2001 From: Pippa H Date: Fri, 16 May 2025 08:57:01 +0200 Subject: [PATCH 03/23] Add endpoint to retrieve private AI key details - Introduced a new GET endpoint `/private-ai-keys/{key_id}` to fetch detailed information about a specific private AI key. - Implemented access control to ensure only system administrators can access this endpoint. - Enhanced the `PrivateAIKeyDetail` schema to include additional fields such as spend, key name, key alias, and LiteLLM-specific data. - Updated logging to capture key retrieval processes and potential errors. - Add basic test cases for the new API --- app/api/private_ai_keys.py | 81 +++++++++++++++++++++++- app/schemas/models.py | 18 ++++++ tests/test_private_ai.py | 122 ++++++++++++++++++++++++++++++++++++- tests/test_users.py | 29 +++++++++ 4 files changed, 247 insertions(+), 3 deletions(-) diff --git a/app/api/private_ai_keys.py b/app/api/private_ai_keys.py index ce1af1b..2ae52f6 100644 --- a/app/api/private_ai_keys.py +++ b/app/api/private_ai_keys.py @@ -8,12 +8,12 @@ from app.schemas.models import ( PrivateAIKey, PrivateAIKeyCreate, PrivateAIKeySpend, BudgetPeriodUpdate, LiteLLMToken, VectorDBCreate, VectorDB, - TokenDurationUpdate + TokenDurationUpdate, PrivateAIKeyDetail ) from app.db.postgres import PostgresManager from app.db.models import DBPrivateAIKey, DBRegion, DBUser, DBTeam from app.services.litellm import LiteLLMService -from app.core.security import get_current_user_from_auth, get_role_min_key_creator, get_role_min_team_admin, UserRole +from app.core.security import get_current_user_from_auth, get_role_min_key_creator, get_role_min_team_admin, UserRole, check_system_admin # Set up logging logger = logging.getLogger(__name__) @@ -391,6 +391,83 @@ async def list_private_ai_keys( private_ai_keys = query.all() return [key.to_dict() for key in private_ai_keys] +@router.get("/{key_id}", response_model=PrivateAIKeyDetail, dependencies=[Depends(check_system_admin)]) +async def get_private_ai_key( + key_id: int, + current_user = Depends(get_current_user_from_auth), + db: Session = Depends(get_db) +): + """ + Get details of a specific private AI key. + + This endpoint will: + 1. Verify the user has access to the key + 2. Return the full details of the key including LiteLLM-specific data + + Required parameters: + - **key_id**: The ID of the private AI key to retrieve + + The response will include: + - Database connection details (host, database name, username, password) + - LiteLLM API token for authentication + - LiteLLM API URL for making requests + - Owner and team information + - Region information + - LiteLLM-specific data (spend, duration, budget, etc.) + + Note: You must be authenticated to use this endpoint. + Only system administrators can access this endpoint. + """ + private_ai_key = _get_key_if_allowed(key_id, current_user, "system_admin", db) + + # Get the region + region = db.query(DBRegion).filter(DBRegion.id == private_ai_key.region_id).first() + if not region: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Region not found" + ) + + # Create LiteLLM service instance + litellm_service = LiteLLMService( + api_url=region.litellm_api_url, + api_key=region.litellm_api_key + ) + + try: + # Get LiteLLM key info + litellm_data = await litellm_service.get_key_info(private_ai_key.litellm_token) + info = litellm_data.get("info", {}) + logger.info(f"LiteLLM key info: {info}") + + # Combine database key info with LiteLLM info + key_data = private_ai_key.to_dict() + key_data.update({ + "spend": info.get("spend", 0.0), + "key_name": info.get("key_name"), + "key_alias": info.get("key_alias"), + "soft_budget_cooldown": info.get("soft_budget_cooldown"), + "models": info.get("models"), + "max_parallel_requests": info.get("max_parallel_requests"), + "tpm_limit": info.get("tpm_limit"), + "rpm_limit": info.get("rpm_limit"), + "max_budget": info.get("max_budget"), + "budget_duration": info.get("budget_duration"), + "budget_reset_at": info.get("budget_reset_at"), + "expires_at": info.get("expires"), + "created_at": info.get("created_at"), + "updated_at": info.get("updated_at"), + "metadata": info.get("metadata") + }) + + return PrivateAIKeyDetail.model_validate(key_data) + except Exception as e: + logger.error(f"Failed to get Private AI Key details: {str(e)}", exc_info=True) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to get Private AI Key details: {str(e)}" + ) + def _get_key_if_allowed(key_id: int, current_user: DBUser, user_role: UserRole, db: Session) -> DBPrivateAIKey: # First try to find the key private_ai_key = db.query(DBPrivateAIKey).filter( diff --git a/app/schemas/models.py b/app/schemas/models.py index ea49ee8..a5ad334 100644 --- a/app/schemas/models.py +++ b/app/schemas/models.py @@ -147,6 +147,24 @@ class PrivateAIKey(PrivateAIKeyBase): team_id: Optional[int] = None model_config = ConfigDict(from_attributes=True) +class PrivateAIKeyDetail(PrivateAIKey): + spend: Optional[float] = None + key_name: Optional[str] = None + key_alias: Optional[str] = None + soft_budget_cooldown: Optional[bool] = None + models: Optional[List[str]] = None + max_parallel_requests: Optional[int] = None + tpm_limit: Optional[int] = None + rpm_limit: Optional[int] = None + max_budget: Optional[float] = None + budget_duration: Optional[str] = None + budget_reset_at: Optional[datetime] = None + expires_at: Optional[datetime] = None + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + metadata: Optional[dict] = None + model_config = ConfigDict(from_attributes=True) + class BudgetPeriodUpdate(BaseModel): budget_duration: str diff --git a/tests/test_private_ai.py b/tests/test_private_ai.py index 2e8b186..8ba9e79 100644 --- a/tests/test_private_ai.py +++ b/tests/test_private_ai.py @@ -980,8 +980,9 @@ def test_update_budget_period_as_key_creator(mock_post, client, team_key_creator db.delete(test_key) db.commit() +@patch("app.services.litellm.requests.get") @patch("app.services.litellm.requests.post") -def test_update_budget_duration_as_team_admin(mock_get, mock_post, client, team_admin_token, test_region, mock_litellm_response, db, test_team): +def test_update_budget_duration_as_team_admin(mock_post, mock_get, client, team_admin_token, test_region, mock_litellm_response, db, test_team): """Test that a team admin can update the budget duration for a team-owned key""" # Mock the LiteLLM API responses mock_post.return_value.status_code = 200 @@ -1356,3 +1357,122 @@ def test_list_private_ai_keys_as_non_team_user(mock_post, client, admin_token, t assert data[0]["name"] == "user-owned-key" assert data[0]["owner_id"] == test_user.id assert data[0].get("team_id") is None + +@patch("app.services.litellm.requests.get") +def test_get_private_ai_key_success(mock_get, client, admin_token, test_region, db, test_team): + """Test successfully retrieving a private AI key""" + region_id = test_region.id + region_name = test_region.name + # Create a test key owned by the team + test_key = DBPrivateAIKey( + database_name="test-db-get", + name="Test Key for Get", + database_host="test-host", + database_username="test-user", + database_password="test-pass", + litellm_token="test-token-get", + litellm_api_url="https://test-litellm.com", + team_id=test_team.id, + region_id=region_id + ) + db.add(test_key) + db.commit() + db.refresh(test_key) + + # Mock the LiteLLM API response for key info + mock_get.return_value.status_code = 200 + mock_get.return_value.json.return_value = { + "info": { + "key_name": "Test Key for Get", + "key_alias": "test-key-alias", + "spend": 0.0, + "expires": "2024-12-31T23:59:59Z", + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-02T00:00:00Z", + "max_budget": 100.0, + "budget_duration": "monthly", + "budget_reset_at": "2024-02-01T00:00:00Z", + "metadata": { + "team_id": str(test_team.id), + "region_id": str(test_region.id) + } + } + } + mock_get.return_value.raise_for_status.return_value = None + + # Get the key as admin + response = client.get( + f"/private-ai-keys/{test_key.id}", + headers={"Authorization": f"Bearer {admin_token}"} + ) + + # Verify the response + assert response.status_code == 200 + data = response.json() + assert data["id"] == test_key.id + assert data["name"] == "Test Key for Get" + assert data["database_name"] == "test-db-get" + assert data["database_host"] == "test-host" + assert data["database_username"] == "test-user" + assert data["database_password"] == "test-pass" + assert data["litellm_token"] == "test-token-get" + assert data["litellm_api_url"] == "https://test-litellm.com" + assert data["team_id"] == test_team.id + assert data["region"] == region_name + + # Verify the LiteLLM API was called correctly + mock_get.assert_called_with( + f"{test_region.litellm_api_url}/key/info", + headers={"Authorization": f"Bearer {test_region.litellm_api_key}"}, + params={"key": test_key.litellm_token} + ) + + # Clean up + db.delete(test_key) + db.commit() + +@patch("app.services.litellm.requests.get") +def test_get_private_ai_key_not_found(mock_get, client, admin_token): + """Test getting a non-existent private AI key""" + # Try to get a non-existent key + response = client.get( + "/private-ai-keys/99999", + headers={"Authorization": f"Bearer {admin_token}"} + ) + + # Verify the response + assert response.status_code == 404 + assert "Private AI Key not found" in response.json()["detail"] + +@patch("app.services.litellm.requests.get") +def test_get_private_ai_key_unauthorized(mock_get, client, test_token, test_region, db, test_team): + """Test getting a private AI key without proper authorization""" + # Create a test key owned by the team + test_key = DBPrivateAIKey( + database_name="test-db-unauthorized", + name="Test Key for Unauthorized", + database_host="test-host", + database_username="test-user", + database_password="test-pass", + litellm_token="test-token-unauthorized", + litellm_api_url="https://test-litellm.com", + team_id=test_team.id, + region_id=test_region.id + ) + db.add(test_key) + db.commit() + db.refresh(test_key) + + # Try to get the key as a regular user + response = client.get( + f"/private-ai-keys/{test_key.id}", + headers={"Authorization": f"Bearer {test_token}"} + ) + + # Verify the response + assert response.status_code == 403 + assert "Not authorized to perform this action" in response.json()["detail"] + + # Clean up + db.delete(test_key) + db.commit() diff --git a/tests/test_users.py b/tests/test_users.py index d75bf57..9b077fe 100644 --- a/tests/test_users.py +++ b/tests/test_users.py @@ -319,6 +319,35 @@ def test_create_user_with_invalid_role_by_team_admin(client, team_admin_token, t assert response.status_code == 400 assert "Invalid role" in response.json()["detail"] +def test_make_non_team_user_admin(client, admin_token, test_user, db): + """ + Test that a system admin can make a non-team user an admin. + + GIVEN: The authenticated user is a system admin + WHEN: They try to make a non-team user an admin + THEN: A 200 success is returned and the user is updated + """ + # Ensure test_user is not an admin and not in a team + test_user = db.merge(test_user) + test_user.is_admin = False + test_user.team_id = None + db.commit() + db.refresh(test_user) + + # Update the user to make them an admin + response = client.put( + f"/users/{test_user.id}", + headers={"Authorization": f"Bearer {admin_token}"}, + json={ + "email": test_user.email, + "is_admin": True, + "is_active": True + } + ) + assert response.status_code == 200 + user_data = response.json() + assert user_data["is_admin"] is True + def test_make_team_member_admin_by_team_admin(client, team_admin_token, admin_token): """ Test that a team admin cannot make a user an admin. From 659abff482477de75d838617f9e24fcedf775cd9 Mon Sep 17 00:00:00 2001 From: Pippa H Date: Fri, 16 May 2025 12:29:55 +0200 Subject: [PATCH 04/23] Add billing functionality with Stripe integration - Introduced a new billing module with endpoints for creating checkout sessions and handling Stripe webhook events. - Added a new `billing.py` API file to manage team subscriptions and payment processing. - Implemented `create_checkout_session` and `handle_stripe_event` functions in the Stripe service for managing subscriptions. - Updated the main application to include the billing router and integrated it with existing middleware. - Added a new `DBSystemSecret` model to store sensitive Stripe webhook secrets. - Enhanced the configuration settings to include AWS and SES parameters. - Created a migration script to add the `system_secrets` table to the database. - Updated the initialization script to set up the Stripe webhook in specific environments. - Added a new `CheckoutSessionCreate` schema for validating checkout session requests. --- app/api/billing.py | 151 ++++++++++++ app/core/config.py | 9 + app/db/models.py | 12 +- app/main.py | 16 +- ...090727_891e02a9ce6e_add_system_settings.py | 42 ++++ app/schemas/models.py | 5 +- app/services/stripe.py | 225 ++++++++++++++++++ requirements.txt | 2 + scripts/initialise_resources.py | 18 +- 9 files changed, 468 insertions(+), 12 deletions(-) create mode 100644 app/api/billing.py create mode 100644 app/migrations/versions/20250516_090727_891e02a9ce6e_add_system_settings.py create mode 100644 app/services/stripe.py diff --git a/app/api/billing.py b/app/api/billing.py new file mode 100644 index 0000000..a00f03c --- /dev/null +++ b/app/api/billing.py @@ -0,0 +1,151 @@ +from typing import Optional +from fastapi import APIRouter, Depends, HTTPException, status, Request, Response +from sqlalchemy.orm import Session +import logging +import os +from pydantic import BaseModel + +from app.db.database import get_db +from app.core.security import get_current_user_from_auth, check_specific_team_admin +from app.db.models import DBUser, DBTeam, DBSystemSecret +from app.schemas.models import CheckoutSessionCreate +from app.services.stripe import ( + create_checkout_session, + handle_stripe_event, + create_portal_session +) + +# Configure logger +logger = logging.getLogger(__name__) + +router = APIRouter( + tags=["billing"] +) + +@router.post("/teams/{team_id}/checkout", dependencies=[Depends(check_specific_team_admin)]) +async def checkout( + team_id: int, + request_data: CheckoutSessionCreate, + request: Request, + current_user: DBUser = Depends(get_current_user_from_auth), + db: Session = Depends(get_db) +): + """ + Create a Stripe Checkout Session for team subscription. + + Args: + team_id: The ID of the team to create the subscription for + request_data: Contains the price_lookup_token to identify the specific price + + Returns: + redirect to the checkout session + """ + try: + # Get the team + team = db.query(DBTeam).filter(DBTeam.id == team_id).first() + if not team: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Team not found" + ) + + # Get the frontend URL from environment + frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000") + + # Create checkout session using the service + checkout_url = await create_checkout_session(team, request_data.price_lookup_token, frontend_url) + + return Response( + status_code=status.HTTP_303_SEE_OTHER, + headers={"Location": checkout_url} + ) + except Exception as e: + logger.error(f"Error creating checkout session: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Error creating checkout session" + ) + +@router.post("/events") +async def handle_events( + request: Request, + db: Session = Depends(get_db) +): + """ + Handle Stripe webhook events. + + This endpoint processes various Stripe events like subscription updates, + payment successes, and failures. + """ + try: + # Get the webhook secret from database + webhook_secret = db.query(DBSystemSecret).filter( + DBSystemSecret.key == "stripe_webhook_secret" + ).first() + + if not webhook_secret: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Stripe webhook secret not configured" + ) + + # Get the raw request body + payload = await request.body() + sig_header = request.headers.get("stripe-signature") + + # Handle the event using the service + await handle_stripe_event(payload, sig_header, webhook_secret.value, db) + + return Response( + status_code=status.HTTP_200_OK, + content="Webhook processed successfully" + ) + + except Exception as e: + logger.error(f"Error handling Stripe event: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Error processing webhook" + ) + +@router.post("/teams/{team_id}/portal", dependencies=[Depends(check_specific_team_admin)]) +async def get_portal( + team_id: int, + request: Request, + current_user: DBUser = Depends(get_current_user_from_auth), + db: Session = Depends(get_db) +): + """ + Create a Stripe Customer Portal session for team subscription management and redirect to it. + + Args: + team_id: The ID of the team to create the portal session for + + Returns: + Redirects to the Stripe Customer Portal URL + """ + try: + # Get the team + team = db.query(DBTeam).filter(DBTeam.id == team_id).first() + if not team: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Team not found" + ) + + # Get the frontend URL from environment + frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000") + + # Create portal session using the service + portal_url = await create_portal_session(team, frontend_url) + + return Response( + status_code=status.HTTP_303_SEE_OTHER, + headers={"Location": portal_url} + ) + except Exception as e: + logger.error(f"Error creating portal session: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Error creating portal session" + ) diff --git a/app/core/config.py b/app/core/config.py index 260d849..2aa3bc5 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -21,6 +21,15 @@ class Settings(BaseSettings): ALLOWED_HOSTS: list[str] = ["*"] # In production, restrict this PUBLIC_PATHS: list[str] = ["/health", "/docs", "/openapi.json", "/metrics"] + AWS_ACCESS_KEY_ID: str = "AKIATEST" + AWS_SECRET_ACCESS_KEY: str = "sk-string" + SES_SENDER_EMAIL: str = "info@example.com" + PASSWORDLESS_SIGN_IN: str = "true" + ENV_SUFFIX: str = "local" + DYNAMODB_REGION: str = "eu-west-1" + SES_REGION: str = "eu-west-1" + EXPIRE_KEYS: str = "true" + model_config = ConfigDict(env_file=".env") def model_post_init(self, values): diff --git a/app/db/models.py b/app/db/models.py index e95fb86..8252f1d 100644 --- a/app/db/models.py +++ b/app/db/models.py @@ -118,4 +118,14 @@ class DBAuditLog(Base): user_agent = Column(String, nullable=True) request_source = Column(String, nullable=True) # Values: 'frontend', 'api', or None - user = relationship("DBUser", back_populates="audit_logs") \ No newline at end of file + user = relationship("DBUser", back_populates="audit_logs") + +class DBSystemSecret(Base): + __tablename__ = "system_secrets" + + id = Column(Integer, primary_key=True, index=True) + key = Column(String, unique=True, index=True, nullable=False) + value = Column(String, nullable=False) + description = Column(String, nullable=True) + created_at = Column(DateTime(timezone=True), default=func.now()) + updated_at = Column(DateTime(timezone=True), onupdate=func.now()) \ No newline at end of file diff --git a/app/main.py b/app/main.py index c514898..2d9536c 100644 --- a/app/main.py +++ b/app/main.py @@ -5,6 +5,13 @@ from fastapi.openapi.docs import get_swagger_ui_html from fastapi.openapi.utils import get_openapi from prometheus_fastapi_instrumentator import Instrumentator, metrics +from app.api import auth, private_ai_keys, users, regions, audit, teams, billing +from app.core.config import settings +from app.db.database import get_db +from app.middleware.audit import AuditLogMiddleware +from app.middleware.prometheus import PrometheusMiddleware +from app.middleware.auth import AuthMiddleware + import os import logging @@ -22,13 +29,6 @@ async def dispatch(self, request, call_next): request.scope["scheme"] = "https" return await call_next(request) -from app.api import auth, private_ai_keys, users, regions, audit, teams -from app.core.config import settings -from app.db.database import get_db -from app.middleware.audit import AuditLogMiddleware -from app.middleware.prometheus import PrometheusMiddleware -from app.middleware.auth import AuthMiddleware - app = FastAPI( title="Private AI Keys as a Service", description=""" @@ -132,7 +132,7 @@ async def health_check(): app.include_router(regions.router, prefix="/regions", tags=["regions"]) app.include_router(audit.router, prefix="/audit", tags=["audit"]) app.include_router(teams.router, prefix="/teams", tags=["teams"]) - +app.include_router(billing.router, prefix="/billing", tags=["billing"]) @app.get("/", include_in_schema=False) async def custom_swagger_ui_html(): return get_swagger_ui_html( diff --git a/app/migrations/versions/20250516_090727_891e02a9ce6e_add_system_settings.py b/app/migrations/versions/20250516_090727_891e02a9ce6e_add_system_settings.py new file mode 100644 index 0000000..875def7 --- /dev/null +++ b/app/migrations/versions/20250516_090727_891e02a9ce6e_add_system_settings.py @@ -0,0 +1,42 @@ +"""Add system settings + +Revision ID: 891e02a9ce6e +Revises: 2bb3f48b650d +Create Date: 2025-05-16 09:07:27.786841+00:00 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = '891e02a9ce6e' +down_revision: Union[str, None] = '2bb3f48b650d' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('system_secrets', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('key', sa.String(), nullable=False), + sa.Column('value', sa.String(), nullable=False), + sa.Column('description', sa.String(), nullable=True), + sa.Column('created_at', sa.DateTime(timezone=True), nullable=True), + sa.Column('updated_at', sa.DateTime(timezone=True), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_system_secrets_id'), 'system_secrets', ['id'], unique=False) + op.create_index(op.f('ix_system_secrets_key'), 'system_secrets', ['key'], unique=True) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f('ix_system_secrets_key'), table_name='system_secrets') + op.drop_index(op.f('ix_system_secrets_id'), table_name='system_secrets') + op.drop_table('system_secrets') + # ### end Alembic commands ### \ No newline at end of file diff --git a/app/schemas/models.py b/app/schemas/models.py index a5ad334..db49e88 100644 --- a/app/schemas/models.py +++ b/app/schemas/models.py @@ -268,4 +268,7 @@ class UserRoleUpdate(BaseModel): class SignInData(BaseModel): username: EmailStr - verification_code: str \ No newline at end of file + verification_code: str + +class CheckoutSessionCreate(BaseModel): + price_lookup_token: str \ No newline at end of file diff --git a/app/services/stripe.py b/app/services/stripe.py new file mode 100644 index 0000000..9765ef0 --- /dev/null +++ b/app/services/stripe.py @@ -0,0 +1,225 @@ +from typing import Optional +import stripe +import os +import logging +from urllib.parse import urljoin +from fastapi import HTTPException, status +from sqlalchemy.orm import Session + +from app.db.models import DBTeam, DBSystemSecret + +# Configure logger +stripe_logger = logging.getLogger(__name__) + +# Initialize Stripe +stripe.api_key = os.getenv("STRIPE_SECRET_KEY") + +async def create_checkout_session( + team: DBTeam, + price_lookup_token: str, + frontend_url: str +) -> str: + """ + Create a Stripe Checkout Session for team subscription. + + Args: + team: The team to create the subscription for + price_lookup_token: Token to identify the specific price + frontend_url: The frontend URL for success/cancel redirects + + Returns: + str: The checkout session URL + """ + try: + # Fetch the specific price using lookup_keys + prices = stripe.Price.list( + active=True, + lookup_keys=[price_lookup_token], + expand=['data.product'] + ) + + if not prices.data: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"No active subscription price found for token: {price_lookup_token}" + ) + + subscription_price = prices.data[0] + + # Create the checkout session + checkout_session = stripe.checkout.Session.create( + customer_email=team.admin_email, + success_url=f"{frontend_url}/teams/{team.id}/dashboard?session_id={{CHECKOUT_SESSION_ID}}", + cancel_url=f"{frontend_url}/teams/{team.id}/pricing", + mode="subscription", + line_items=[{ + "price": subscription_price.id, + "quantity": 1, + }], + metadata={ + "team_id": team.id, + "team_name": team.name, + "admin_email": team.admin_email + } + ) + + return checkout_session.url + except Exception as e: + stripe_logger.error(f"Error creating checkout session: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Error creating checkout session" + ) + +async def handle_stripe_event( + payload: bytes, + sig_header: str, + webhook_secret: str, + db: Session +) -> None: + """ + Handle Stripe webhook events. + + Args: + payload: The raw request body + sig_header: The Stripe signature header + webhook_secret: The webhook signing secret + db: Database session + """ + try: + event = stripe.Webhook.construct_event( + payload, sig_header, webhook_secret + ) + + # Handle the event + if event.type == "checkout.session.completed": + session = event.data.object + # Handle successful checkout + # Update team's subscription status in database + team = db.query(DBTeam).filter(DBTeam.id == session.metadata.get("team_id")).first() + if team: + team.is_subscribed = True + team.stripe_customer_id = session.customer + db.commit() + + elif event.type == "customer.subscription.deleted": + subscription = event.data.object + # Handle subscription cancellation + # Update team's subscription status in database + team = db.query(DBTeam).filter(DBTeam.stripe_customer_id == subscription.customer).first() + if team: + team.is_subscribed = False + db.commit() + + except ValueError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid payload" + ) + except stripe.error.SignatureVerificationError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid signature" + ) + except Exception as e: + stripe_logger.error(f"Error handling Stripe event: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Error processing webhook" + ) + +async def create_portal_session( + team: DBTeam, + frontend_url: str +) -> str: + """ + Create a Stripe Customer Portal session for team subscription management. + + Args: + team: The team to create the portal session for + frontend_url: The frontend URL for return redirect + + Returns: + str: The portal session URL + """ + try: + if not team.stripe_customer_id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Team has no active Stripe subscription" + ) + + # Create the portal session + portal_session = stripe.billing_portal.Session.create( + customer=team.stripe_customer_id, + return_url=f"{frontend_url}/teams/{team.id}/dashboard" + ) + + return portal_session.url + except Exception as e: + stripe_logger.error(f"Error creating portal session: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Error creating portal session" + ) + +async def setup_stripe_webhook(db: Session) -> None: + """ + Set up the Stripe webhook endpoint if it doesn't exist and store its signing secret. + + Args: + db: Database session + """ + try: + # Check if we already have a webhook secret stored + existing_secret = db.query(DBSystemSecret).filter( + DBSystemSecret.key == "stripe_webhook_secret" + ).first() + + if existing_secret: + return + + # Get the base URL from environment + base_url = os.getenv("BACKEND_URL", "http://localhost:8800") + webhook_url = urljoin(base_url, "/api/stripe/handle-event") + + # List existing webhook endpoints + endpoints = stripe.WebhookEndpoint.list() + + # Check if we already have an endpoint for this URL + existing_endpoint = None + for endpoint in endpoints.data: + if endpoint.url == webhook_url: + existing_endpoint = endpoint + break + + if existing_endpoint: + # For existing endpoints, we need to create a new one to get the secret + # First delete the old endpoint + stripe.WebhookEndpoint.delete(existing_endpoint.id) + stripe_logger.info(f"Deleted existing webhook endpoint: {existing_endpoint.id}") + + # Create new webhook endpoint + endpoint = stripe.WebhookEndpoint.create( + url=webhook_url, + enabled_events=[ + "checkout.session.completed", + "customer.subscription.deleted" + ] + ) + + # Store the signing secret + secret = DBSystemSecret( + key="stripe_webhook_secret", + value=endpoint.secret, + description="Stripe webhook signing secret for handling events" + ) + db.add(secret) + db.commit() + + except Exception as e: + stripe_logger.error(f"Error setting up Stripe webhook: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Error setting up Stripe webhook" + ) diff --git a/requirements.txt b/requirements.txt index 75005b5..fef4846 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,3 +16,5 @@ markdown==3.8.0 email-validator==2.1.2 prometheus-client==0.21.1 prometheus-fastapi-instrumentator==7.0.0 +stripe==12.1.0 +six==1.17.0 \ No newline at end of file diff --git a/scripts/initialise_resources.py b/scripts/initialise_resources.py index b6b4f53..98c584e 100644 --- a/scripts/initialise_resources.py +++ b/scripts/initialise_resources.py @@ -15,6 +15,7 @@ from app.db.models import Base, DBUser from app.core.security import get_password_hash from app.services.ses import SESService +from app.services.stripe import setup_stripe_webhook import glob def init_database(): @@ -74,8 +75,21 @@ def init_database(): except Exception as e: print(f"Error creating admin user: {str(e)}") db.rollback() - finally: - db.close() + + # Only set up Stripe webhook in specific environments + env_suffix = os.getenv("ENV_SUFFIX", "").lower() + if env_suffix in ["dev", "main", "prod"]: + try: + # Set up Stripe webhook + print("Setting up Stripe webhook...") + setup_stripe_webhook(db) + print("Stripe webhook set up successfully") + except Exception as e: + print(f"Warning: Failed to set up Stripe webhook: {str(e)}") + else: + print(f"Skipping Stripe webhook setup for environment: {env_suffix}") + + db.close() def init_ses_templates(): if os.getenv("PASSWORDLESS_SIGN_IN", "").lower() == "true": From bc3fbd523006d8195241f5c417be6352ae066cf2 Mon Sep 17 00:00:00 2001 From: Pippa H Date: Fri, 16 May 2025 12:30:07 +0200 Subject: [PATCH 05/23] Refactor API routers to standardize tags - Updated the `APIRouter` instances in multiple API files to use consistent lowercase tags. - Changed tags for `auth`, `private-ai-keys`, `regions`, `teams`, and `users` to improve clarity and maintainability. - Removed redundant router definitions in `private_ai_keys.py` to streamline the codebase. --- app/api/auth.py | 2 +- app/api/private_ai_keys.py | 8 ++++---- app/api/regions.py | 4 +++- app/api/teams.py | 4 +++- app/api/users.py | 4 +++- 5 files changed, 14 insertions(+), 8 deletions(-) diff --git a/app/api/auth.py b/app/api/auth.py index b11625a..b3df166 100644 --- a/app/api/auth.py +++ b/app/api/auth.py @@ -67,7 +67,7 @@ auth_logger = logging.getLogger(__name__) router = APIRouter( - tags=["Authentication"] + tags=["auth"] ) def get_cookie_domain(): diff --git a/app/api/private_ai_keys.py b/app/api/private_ai_keys.py index 2ae52f6..a46d663 100644 --- a/app/api/private_ai_keys.py +++ b/app/api/private_ai_keys.py @@ -15,6 +15,10 @@ from app.services.litellm import LiteLLMService from app.core.security import get_current_user_from_auth, get_role_min_key_creator, get_role_min_team_admin, UserRole, check_system_admin +router = APIRouter( + tags=["private-ai-keys"] +) + # Set up logging logger = logging.getLogger(__name__) @@ -58,10 +62,6 @@ def _validate_permissions_and_get_ownership_info( return owner_id, team_id -router = APIRouter( - tags=["Private AI Keys"] -) - @router.post("/vector-db", response_model=VectorDB) async def create_vector_db( vector_db: VectorDBCreate, diff --git a/app/api/regions.py b/app/api/regions.py index 19f03f9..d96bde6 100644 --- a/app/api/regions.py +++ b/app/api/regions.py @@ -7,7 +7,9 @@ from app.schemas.models import Region, RegionCreate, RegionResponse, User, RegionUpdate from app.db.models import DBRegion, DBPrivateAIKey -router = APIRouter() +router = APIRouter( + tags=["regions"] +) @router.post("", response_model=Region) @router.post("/", response_model=Region) diff --git a/app/api/teams.py b/app/api/teams.py index 7fda28a..75c1632 100644 --- a/app/api/teams.py +++ b/app/api/teams.py @@ -11,7 +11,9 @@ TeamWithUsers ) -router = APIRouter() +router = APIRouter( + tags=["teams"] +) @router.post("", response_model=Team, status_code=status.HTTP_201_CREATED) @router.post("/", response_model=Team, status_code=status.HTTP_201_CREATED) diff --git a/app/api/users.py b/app/api/users.py index c8208c4..5a9fb29 100644 --- a/app/api/users.py +++ b/app/api/users.py @@ -8,7 +8,9 @@ from app.core.security import get_password_hash, check_system_admin, get_current_user_from_auth, UserRole, get_role_min_team_admin from datetime import datetime, UTC -router = APIRouter() +router = APIRouter( + tags=["users"] +) @router.get("/search", response_model=List[User], dependencies=[Depends(check_system_admin)]) async def search_users( From 66146fbe4da7dac6684eb9d376a44153a897da73 Mon Sep 17 00:00:00 2001 From: Pippa H Date: Fri, 16 May 2025 13:39:18 +0200 Subject: [PATCH 06/23] Add product management functionality - Introduced a new `products` API module with endpoints for creating, listing, retrieving, updating, and deleting products. - Implemented access control to restrict product management actions to system and team administrators. - Added a new `DBProduct` model to represent products in the database, including fields for name, stripe lookup key, and active status. - Updated the main application to include the products router. - Enhanced configuration settings to include a Stripe key for product management. - Created migration scripts to add the `products` table and update the `teams` table with a new `stripe_customer_id` field. - Developed comprehensive tests to validate product management functionality and access control. --- app/api/products.py | 126 ++++++++ app/core/config.py | 3 +- app/db/models.py | 11 + app/main.py | 4 +- ...516_113604_1069b3617025_manage_products.py | 43 +++ app/schemas/models.py | 21 +- app/services/stripe.py | 47 ++- tests/test_products.py | 288 ++++++++++++++++++ 8 files changed, 539 insertions(+), 4 deletions(-) create mode 100644 app/api/products.py create mode 100644 app/migrations/versions/20250516_113604_1069b3617025_manage_products.py create mode 100644 tests/test_products.py diff --git a/app/api/products.py b/app/api/products.py new file mode 100644 index 0000000..ab9c5f8 --- /dev/null +++ b/app/api/products.py @@ -0,0 +1,126 @@ +from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy.orm import Session +from typing import List +from datetime import datetime, UTC + +from app.db.database import get_db +from app.db.models import DBProduct +from app.core.security import check_system_admin, get_current_user_from_auth, get_role_min_team_admin +from app.schemas.models import Product, ProductCreate, ProductUpdate + +router = APIRouter( + tags=["products"] +) + +@router.post("", response_model=Product, status_code=status.HTTP_201_CREATED, dependencies=[Depends(check_system_admin)]) +@router.post("/", response_model=Product, status_code=status.HTTP_201_CREATED, dependencies=[Depends(check_system_admin)]) +async def create_product( + product: ProductCreate, + db: Session = Depends(get_db) +): + """ + Create a new product. Only accessible by system admin users. + """ + # Check if stripe_lookup_key already exists + existing_product = db.query(DBProduct).filter(DBProduct.stripe_lookup_key == product.stripe_lookup_key).first() + if existing_product: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Product with this stripe_lookup_key already exists" + ) + + # Create the product + db_product = DBProduct( + name=product.name, + stripe_lookup_key=product.stripe_lookup_key, + active=product.active, + created_at=datetime.now(UTC) + ) + + db.add(db_product) + db.commit() + db.refresh(db_product) + + return db_product + +@router.get("/", response_model=List[Product], dependencies=[Depends(get_role_min_team_admin)]) +async def list_products( + db: Session = Depends(get_db) +): + """ + List all products. Only accessible by team admin users or higher privileges. + """ + return db.query(DBProduct).all() + +@router.get("/{product_id}", response_model=Product, dependencies=[Depends(get_role_min_team_admin)]) +async def get_product( + product_id: int, + db: Session = Depends(get_db) +): + """ + Get a specific product by ID. Only accessible by team admin users or higher privileges. + """ + product = db.query(DBProduct).filter(DBProduct.id == product_id).first() + if not product: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Product not found" + ) + return product + +@router.put("/{product_id}", response_model=Product, dependencies=[Depends(check_system_admin)]) +async def update_product( + product_id: int, + product_update: ProductUpdate, + db: Session = Depends(get_db) +): + """ + Update a product. Only accessible by system admin users. + """ + product = db.query(DBProduct).filter(DBProduct.id == product_id).first() + if not product: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Product not found" + ) + + # If updating stripe_lookup_key, check if it already exists + if product_update.stripe_lookup_key and product_update.stripe_lookup_key != product.stripe_lookup_key: + existing_product = db.query(DBProduct).filter( + DBProduct.stripe_lookup_key == product_update.stripe_lookup_key, + DBProduct.id != product_id + ).first() + if existing_product: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Product with this stripe_lookup_key already exists" + ) + + # Update the product + for key, value in product_update.model_dump(exclude_unset=True).items(): + setattr(product, key, value) + + db.commit() + db.refresh(product) + + return product + +@router.delete("/{product_id}", dependencies=[Depends(check_system_admin)]) +async def delete_product( + product_id: int, + db: Session = Depends(get_db) +): + """ + Delete a product. Only accessible by system admin users. + """ + product = db.query(DBProduct).filter(DBProduct.id == product_id).first() + if not product: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Product not found" + ) + + db.delete(product) + db.commit() + + return {"message": "Product deleted successfully"} \ No newline at end of file diff --git a/app/core/config.py b/app/core/config.py index 2aa3bc5..8ac3536 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -29,6 +29,7 @@ class Settings(BaseSettings): DYNAMODB_REGION: str = "eu-west-1" SES_REGION: str = "eu-west-1" EXPIRE_KEYS: str = "true" + STRIPE_KEY: str = "sk_test_string" model_config = ConfigDict(env_file=".env") @@ -37,4 +38,4 @@ def model_post_init(self, values): lagoon_routes = os.getenv("LAGOON_ROUTES", "").split(",") self.CORS_ORIGINS.extend([route.strip() for route in lagoon_routes if route.strip()]) -settings = Settings() \ No newline at end of file +settings = Settings() diff --git a/app/db/models.py b/app/db/models.py index 8252f1d..10e3927 100644 --- a/app/db/models.py +++ b/app/db/models.py @@ -62,6 +62,7 @@ class DBTeam(Base): is_active = Column(Boolean, default=True) created_at = Column(DateTime(timezone=True), default=func.now()) updated_at = Column(DateTime(timezone=True), onupdate=func.now()) + stripe_customer_id = Column(String, nullable=True, unique=True, index=True) users = relationship("DBUser", back_populates="team") private_ai_keys = relationship("DBPrivateAIKey", back_populates="team") @@ -128,4 +129,14 @@ class DBSystemSecret(Base): value = Column(String, nullable=False) description = Column(String, nullable=True) created_at = Column(DateTime(timezone=True), default=func.now()) + updated_at = Column(DateTime(timezone=True), onupdate=func.now()) + +class DBProduct(Base): + __tablename__ = "products" + + id = Column(Integer, primary_key=True, index=True) + name = Column(String, nullable=False) + stripe_lookup_key = Column(String, unique=True, index=True, nullable=False) + active = Column(Boolean, default=True) + created_at = Column(DateTime(timezone=True), default=func.now()) updated_at = Column(DateTime(timezone=True), onupdate=func.now()) \ No newline at end of file diff --git a/app/main.py b/app/main.py index 2d9536c..df2029a 100644 --- a/app/main.py +++ b/app/main.py @@ -5,7 +5,7 @@ from fastapi.openapi.docs import get_swagger_ui_html from fastapi.openapi.utils import get_openapi from prometheus_fastapi_instrumentator import Instrumentator, metrics -from app.api import auth, private_ai_keys, users, regions, audit, teams, billing +from app.api import auth, private_ai_keys, users, regions, audit, teams, billing, products from app.core.config import settings from app.db.database import get_db from app.middleware.audit import AuditLogMiddleware @@ -133,6 +133,8 @@ async def health_check(): app.include_router(audit.router, prefix="/audit", tags=["audit"]) app.include_router(teams.router, prefix="/teams", tags=["teams"]) app.include_router(billing.router, prefix="/billing", tags=["billing"]) +app.include_router(products.router, prefix="/products", tags=["products"]) + @app.get("/", include_in_schema=False) async def custom_swagger_ui_html(): return get_swagger_ui_html( diff --git a/app/migrations/versions/20250516_113604_1069b3617025_manage_products.py b/app/migrations/versions/20250516_113604_1069b3617025_manage_products.py new file mode 100644 index 0000000..18be89e --- /dev/null +++ b/app/migrations/versions/20250516_113604_1069b3617025_manage_products.py @@ -0,0 +1,43 @@ +"""manage products + +Revision ID: 1069b3617025 +Revises: 891e02a9ce6e +Create Date: 2025-05-16 11:36:04.879027+00:00 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision: str = '1069b3617025' +down_revision: Union[str, None] = '891e02a9ce6e' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('products', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('name', sa.String(), nullable=False), + sa.Column('stripe_lookup_key', sa.String(), nullable=False), + sa.Column('active', sa.Boolean(), nullable=True), + sa.Column('created_at', sa.DateTime(timezone=True), nullable=True), + sa.Column('updated_at', sa.DateTime(timezone=True), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_products_id'), 'products', ['id'], unique=False) + op.create_index(op.f('ix_products_stripe_lookup_key'), 'products', ['stripe_lookup_key'], unique=True) + op.add_column('teams', sa.Column('stripe_customer_id', sa.String(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('teams', 'stripe_customer_id') + op.drop_index(op.f('ix_products_stripe_lookup_key'), table_name='products') + op.drop_index(op.f('ix_products_id'), table_name='products') + op.drop_table('products') + # ### end Alembic commands ### \ No newline at end of file diff --git a/app/schemas/models.py b/app/schemas/models.py index db49e88..76c757d 100644 --- a/app/schemas/models.py +++ b/app/schemas/models.py @@ -271,4 +271,23 @@ class SignInData(BaseModel): verification_code: str class CheckoutSessionCreate(BaseModel): - price_lookup_token: str \ No newline at end of file + price_lookup_token: str + +class ProductBase(BaseModel): + name: str + stripe_lookup_key: str + active: bool = True + +class ProductCreate(ProductBase): + pass + +class ProductUpdate(BaseModel): + name: Optional[str] = None + stripe_lookup_key: Optional[str] = None + active: Optional[bool] = None + +class Product(ProductBase): + id: int + created_at: datetime + updated_at: Optional[datetime] = None + model_config = ConfigDict(from_attributes=True) \ No newline at end of file diff --git a/app/services/stripe.py b/app/services/stripe.py index 9765ef0..43ac7b9 100644 --- a/app/services/stripe.py +++ b/app/services/stripe.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Union import stripe import os import logging @@ -223,3 +223,48 @@ async def setup_stripe_webhook(db: Session) -> None: status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Error setting up Stripe webhook" ) + +async def create_stripe_customer( + team: DBTeam, + db: Session +) -> str: + """ + Create a Stripe customer for a team and save the customer ID. + + Args: + team: The team to create a Stripe customer for + db: Database session + + Returns: + str: The Stripe customer ID + + Raises: + HTTPException: If error creating customer + """ + try: + # Check if team already has a Stripe customer + if team.stripe_customer_id: + return team.stripe_customer_id + + # Create Stripe customer + customer = stripe.Customer.create( + email=team.admin_email, + name=team.name, + metadata={ + "team_id": team.id, + "team_name": team.name + } + ) + + # Save customer ID to team + team.stripe_customer_id = customer.id + db.commit() + + return customer.id + + except Exception as e: + stripe_logger.error(f"Error creating Stripe customer: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Error creating Stripe customer" + ) diff --git a/tests/test_products.py b/tests/test_products.py new file mode 100644 index 0000000..8114f94 --- /dev/null +++ b/tests/test_products.py @@ -0,0 +1,288 @@ +import pytest +from fastapi.testclient import TestClient +from sqlalchemy.orm import Session +from app.db.models import DBProduct, DBUser, DBTeam +from datetime import datetime, UTC + +def test_create_product_as_system_admin(client, admin_token, db): + """ + Test that a system admin can create a product. + + GIVEN: The authenticated user is a system admin + WHEN: They create a product + THEN: A 201 - Created is returned with the product data + """ + response = client.post( + "/products/", + headers={"Authorization": f"Bearer {admin_token}"}, + json={ + "name": "Test Product", + "stripe_lookup_key": "test_product_123", + "active": True + } + ) + assert response.status_code == 201 + data = response.json() + assert data["name"] == "Test Product" + assert data["stripe_lookup_key"] == "test_product_123" + assert data["active"] is True + assert "id" in data + assert "created_at" in data + +def test_create_product_duplicate_key(client, admin_token, db): + """ + Test that creating a product with a duplicate stripe_lookup_key fails. + + GIVEN: A product with a specific stripe_lookup_key exists + WHEN: A system admin tries to create another product with the same key + THEN: A 400 - Bad Request is returned + """ + # First create a product + client.post( + "/products/", + headers={"Authorization": f"Bearer {admin_token}"}, + json={ + "name": "Test Product", + "stripe_lookup_key": "test_product_123", + "active": True + } + ) + + # Try to create another product with the same key + response = client.post( + "/products/", + headers={"Authorization": f"Bearer {admin_token}"}, + json={ + "name": "Another Product", + "stripe_lookup_key": "test_product_123", + "active": True + } + ) + assert response.status_code == 400 + assert "already exists" in response.json()["detail"] + +def test_create_product_unauthorized(client, test_token, db): + """ + Test that a non-admin user cannot create a product. + + GIVEN: The authenticated user is not a system admin + WHEN: They try to create a product + THEN: A 403 - Forbidden is returned + """ + response = client.post( + "/products/", + headers={"Authorization": f"Bearer {test_token}"}, + json={ + "name": "Test Product", + "stripe_lookup_key": "test_product_123", + "active": True + } + ) + assert response.status_code == 403 + +def test_list_products_as_team_admin(client, team_admin_token, db): + """ + Test that a team admin can list products. + + GIVEN: The authenticated user is a team admin + WHEN: They request the list of products + THEN: A 200 - OK is returned with the list of products + """ + # Create some test products + db_product1 = DBProduct( + name="Test Product 1", + stripe_lookup_key="test_product_1", + active=True, + created_at=datetime.now(UTC) + ) + db_product2 = DBProduct( + name="Test Product 2", + stripe_lookup_key="test_product_2", + active=True, + created_at=datetime.now(UTC) + ) + db.add(db_product1) + db.add(db_product2) + db.commit() + + response = client.get( + "/products/", + headers={"Authorization": f"Bearer {team_admin_token}"} + ) + assert response.status_code == 200 + data = response.json() + assert len(data) >= 2 + assert any(p["name"] == "Test Product 1" for p in data) + assert any(p["name"] == "Test Product 2" for p in data) + +def test_list_products_unauthorized(client, test_token, db): + """ + Test that a regular user cannot list products. + + GIVEN: The authenticated user is not a team admin + WHEN: They try to list products + THEN: A 403 - Forbidden is returned + """ + response = client.get( + "/products/", + headers={"Authorization": f"Bearer {test_token}"} + ) + assert response.status_code == 403 + +def test_get_product_as_team_admin(client, team_admin_token, db): + """ + Test that a team admin can get a specific product. + + GIVEN: The authenticated user is a team admin + WHEN: They request a specific product + THEN: A 200 - OK is returned with the product data + """ + # Create a test product + db_product = DBProduct( + name="Test Product", + stripe_lookup_key="test_product_123", + active=True, + created_at=datetime.now(UTC) + ) + db.add(db_product) + db.commit() + + response = client.get( + f"/products/{db_product.id}", + headers={"Authorization": f"Bearer {team_admin_token}"} + ) + assert response.status_code == 200 + data = response.json() + assert data["name"] == "Test Product" + assert data["stripe_lookup_key"] == "test_product_123" + assert data["active"] is True + +def test_get_product_not_found(client, team_admin_token, db): + """ + Test that getting a non-existent product returns 404. + + GIVEN: The authenticated user is a team admin + WHEN: They request a non-existent product + THEN: A 404 - Not Found is returned + """ + response = client.get( + "/products/99999", + headers={"Authorization": f"Bearer {team_admin_token}"} + ) + assert response.status_code == 404 + +def test_update_product_as_system_admin(client, admin_token, db): + """ + Test that a system admin can update a product. + + GIVEN: The authenticated user is a system admin + WHEN: They update a product + THEN: A 200 - OK is returned with the updated product data + """ + # Create a test product + db_product = DBProduct( + name="Test Product", + stripe_lookup_key="test_product_123", + active=True, + created_at=datetime.now(UTC) + ) + db.add(db_product) + db.commit() + + response = client.put( + f"/products/{db_product.id}", + headers={"Authorization": f"Bearer {admin_token}"}, + json={ + "name": "Updated Product", + "active": False + } + ) + assert response.status_code == 200 + data = response.json() + assert data["name"] == "Updated Product" + assert data["stripe_lookup_key"] == "test_product_123" # Unchanged + assert data["active"] is False + +def test_update_product_unauthorized(client, team_admin_token, db): + """ + Test that a team admin cannot update a product. + + GIVEN: The authenticated user is a team admin + WHEN: They try to update a product + THEN: A 403 - Forbidden is returned + """ + # Create a test product + db_product = DBProduct( + name="Test Product", + stripe_lookup_key="test_product_123", + active=True, + created_at=datetime.now(UTC) + ) + db.add(db_product) + db.commit() + + response = client.put( + f"/products/{db_product.id}", + headers={"Authorization": f"Bearer {team_admin_token}"}, + json={ + "name": "Updated Product" + } + ) + assert response.status_code == 403 + +def test_delete_product_as_system_admin(client, admin_token, db): + """ + Test that a system admin can delete a product. + + GIVEN: The authenticated user is a system admin + WHEN: They delete a product + THEN: A 200 - OK is returned and the product is deleted + """ + # Create a test product + db_product = DBProduct( + name="Test Product", + stripe_lookup_key="test_product_123", + active=True, + created_at=datetime.now(UTC) + ) + db.add(db_product) + db.commit() + + response = client.delete( + f"/products/{db_product.id}", + headers={"Authorization": f"Bearer {admin_token}"} + ) + assert response.status_code == 200 + assert response.json()["message"] == "Product deleted successfully" + + # Verify the product is deleted + deleted_product = db.query(DBProduct).filter(DBProduct.id == db_product.id).first() + assert deleted_product is None + +def test_delete_product_unauthorized(client, team_admin_token, db): + """ + Test that a team admin cannot delete a product. + + GIVEN: The authenticated user is a team admin + WHEN: They try to delete a product + THEN: A 403 - Forbidden is returned + """ + # Create a test product + db_product = DBProduct( + name="Test Product", + stripe_lookup_key="test_product_123", + active=True, + created_at=datetime.now(UTC) + ) + db.add(db_product) + db.commit() + + response = client.delete( + f"/products/{db_product.id}", + headers={"Authorization": f"Bearer {team_admin_token}"} + ) + assert response.status_code == 403 + + # Verify the product still exists + product = db.query(DBProduct).filter(DBProduct.id == db_product.id).first() + assert product is not None \ No newline at end of file From d491226c167b40dc88d6fae2586419a680d52e3b Mon Sep 17 00:00:00 2001 From: Pippa H Date: Mon, 19 May 2025 10:34:41 +0200 Subject: [PATCH 07/23] Enhance Stripe billing integration and configuration - Refactored the billing API to handle Stripe webhook events asynchronously using background tasks. - Introduced a new function to retrieve Stripe events and updated the webhook setup to use a consistent key. - Added comprehensive tests for handling various Stripe events, including successful checkouts and subscription deletions. - Created a new diagram to illustrate the Stripe flow for better understanding of the billing process. --- app/api/billing.py | 49 ++++++++++++++++---- app/services/stripe.py | 41 +++++------------ docs/design/StripeFlow.diagram | 22 +++++++++ scripts/initialise_resources.py | 9 ++-- tests/test_billing.py | 80 +++++++++++++++++++++++++++++++++ 5 files changed, 158 insertions(+), 43 deletions(-) create mode 100644 docs/design/StripeFlow.diagram create mode 100644 tests/test_billing.py diff --git a/app/api/billing.py b/app/api/billing.py index a00f03c..3c246b0 100644 --- a/app/api/billing.py +++ b/app/api/billing.py @@ -1,9 +1,10 @@ from typing import Optional -from fastapi import APIRouter, Depends, HTTPException, status, Request, Response +from fastapi import APIRouter, Depends, HTTPException, status, Request, Response, BackgroundTasks from sqlalchemy.orm import Session import logging import os from pydantic import BaseModel +import threading from app.db.database import get_db from app.core.security import get_current_user_from_auth, check_specific_team_admin @@ -11,12 +12,13 @@ from app.schemas.models import CheckoutSessionCreate from app.services.stripe import ( create_checkout_session, - handle_stripe_event, + get_stripe_event, create_portal_session ) # Configure logger logger = logging.getLogger(__name__) +BILLING_WEBHOOK_KEY = "stripe_webhook_secret" router = APIRouter( tags=["billing"] @@ -66,21 +68,50 @@ async def checkout( detail="Error creating checkout session" ) +async def handle_stripe_event_background(event, db: Session): + """ + Background task to handle Stripe webhook events. + This runs in a separate thread to avoid blocking the webhook response. + """ + try: + event_type = event.type + + if event_type == "checkout.session.completed": + session = event.data.object + # Handle successful checkout + # Update team's subscription status in database + team = db.query(DBTeam).filter(DBTeam.id == session.metadata.get("team_id")).first() + if team: + team.stripe_customer_id = session.customer + db.commit() + + elif event_type == "customer.subscription.deleted": + subscription = event.data.object + # Handle subscription cancellation + # Update team's subscription status in database + team = db.query(DBTeam).filter(DBTeam.stripe_customer_id == subscription.customer).first() + if team: + team.stripe_customer_id = None + db.commit() + except Exception as e: + logger.error(f"Error in background event handler: {str(e)}") + @router.post("/events") async def handle_events( request: Request, + background_tasks: BackgroundTasks, db: Session = Depends(get_db) ): """ Handle Stripe webhook events. This endpoint processes various Stripe events like subscription updates, - payment successes, and failures. + payment successes, and failures. Events are processed asynchronously in the background. """ try: # Get the webhook secret from database webhook_secret = db.query(DBSystemSecret).filter( - DBSystemSecret.key == "stripe_webhook_secret" + DBSystemSecret.key == BILLING_WEBHOOK_KEY ).first() if not webhook_secret: @@ -91,14 +122,16 @@ async def handle_events( # Get the raw request body payload = await request.body() - sig_header = request.headers.get("stripe-signature") + signature = request.headers.get("stripe-signature") + + event = get_stripe_event(payload, signature, webhook_secret.value) - # Handle the event using the service - await handle_stripe_event(payload, sig_header, webhook_secret.value, db) + # Add the event handling to background tasks + background_tasks.add_task(handle_stripe_event_background, event, db) return Response( status_code=status.HTTP_200_OK, - content="Webhook processed successfully" + content="Webhook received and processing started" ) except Exception as e: diff --git a/app/services/stripe.py b/app/services/stripe.py index 43ac7b9..f8fbd72 100644 --- a/app/services/stripe.py +++ b/app/services/stripe.py @@ -71,45 +71,23 @@ async def create_checkout_session( detail="Error creating checkout session" ) -async def handle_stripe_event( - payload: bytes, - sig_header: str, - webhook_secret: str, - db: Session -) -> None: +def get_stripe_event( payload: bytes, signature: str, webhook_secret: str) -> stripe.Event: """ Handle Stripe webhook events. Args: payload: The raw request body - sig_header: The Stripe signature header + signature: The Stripe signature header webhook_secret: The webhook signing secret - db: Database session + + Returns: + stripe.Event: The Stripe event """ try: event = stripe.Webhook.construct_event( - payload, sig_header, webhook_secret + payload, signature, webhook_secret ) - - # Handle the event - if event.type == "checkout.session.completed": - session = event.data.object - # Handle successful checkout - # Update team's subscription status in database - team = db.query(DBTeam).filter(DBTeam.id == session.metadata.get("team_id")).first() - if team: - team.is_subscribed = True - team.stripe_customer_id = session.customer - db.commit() - - elif event.type == "customer.subscription.deleted": - subscription = event.data.object - # Handle subscription cancellation - # Update team's subscription status in database - team = db.query(DBTeam).filter(DBTeam.stripe_customer_id == subscription.customer).first() - if team: - team.is_subscribed = False - db.commit() + return event except ValueError as e: raise HTTPException( @@ -163,17 +141,18 @@ async def create_portal_session( detail="Error creating portal session" ) -async def setup_stripe_webhook(db: Session) -> None: +async def setup_stripe_webhook(webhook_key: str, db: Session) -> None: """ Set up the Stripe webhook endpoint if it doesn't exist and store its signing secret. Args: + webhook_key: The key to store the webhook secret under db: Database session """ try: # Check if we already have a webhook secret stored existing_secret = db.query(DBSystemSecret).filter( - DBSystemSecret.key == "stripe_webhook_secret" + DBSystemSecret.key == webhook_key ).first() if existing_secret: diff --git a/docs/design/StripeFlow.diagram b/docs/design/StripeFlow.diagram new file mode 100644 index 0000000..2006c62 --- /dev/null +++ b/docs/design/StripeFlow.diagram @@ -0,0 +1,22 @@ +actor #green:0.5 Lauren +actor #blue Customer +participant #green amazee.ai +rparticipant #red Stripe + +Lauren -> Stripe: Create Products & Prices +Lauren -> amazee.ai: createPricingOptions + +Customer -> amazee.ai: chooseProduct +amazee.ai -> Stripe: createCustomer +amazee.ai <-- Stripe: customerID +amazee.ai -> Stripe: createCheckoutSession +amazee.ai <-- Stripe: redirectURL +Customer <-- amazee.ai: redirectURL +Customer -> Stripe: makePayment +Stripe -> amazee.ai: paymentSucceeded +activate amazee.ai +amazee.ai -> amazee.ai: extendKey +amazee.ai -> amazee.ai: setPaymentDate +deactivateafter amazee.ai +Customer <-- amazee.ai: Success + diff --git a/scripts/initialise_resources.py b/scripts/initialise_resources.py index 98c584e..723d3e3 100644 --- a/scripts/initialise_resources.py +++ b/scripts/initialise_resources.py @@ -16,6 +16,7 @@ from app.core.security import get_password_hash from app.services.ses import SESService from app.services.stripe import setup_stripe_webhook +from app.api.billing import BILLING_WEBHOOK_KEY import glob def init_database(): @@ -81,11 +82,11 @@ def init_database(): if env_suffix in ["dev", "main", "prod"]: try: # Set up Stripe webhook - print("Setting up Stripe webhook...") - setup_stripe_webhook(db) - print("Stripe webhook set up successfully") + print("Setting up Stripe Billing webhook...") + setup_stripe_webhook(BILLING_WEBHOOK_KEY, db) + print("Stripe Billing webhook set up successfully") except Exception as e: - print(f"Warning: Failed to set up Stripe webhook: {str(e)}") + print(f"Warning: Failed to set up Stripe Billing webhook: {str(e)}") else: print(f"Skipping Stripe webhook setup for environment: {env_suffix}") diff --git a/tests/test_billing.py b/tests/test_billing.py new file mode 100644 index 0000000..e9a6ab1 --- /dev/null +++ b/tests/test_billing.py @@ -0,0 +1,80 @@ +import pytest +from unittest.mock import Mock, patch +from app.api.billing import handle_stripe_event_background + +@pytest.mark.asyncio +async def test_handle_checkout_session_completed(db, test_team): + # Arrange + mock_event = Mock() + mock_event.type = "checkout.session.completed" + mock_session = Mock() + mock_session.metadata = {"team_id": str(test_team.id)} + mock_session.customer = "cus_123" + mock_event.data.object = mock_session + + # Act + await handle_stripe_event_background(mock_event, db) + + # Assert + db.refresh(test_team) + assert test_team.stripe_customer_id == "cus_123" + +@pytest.mark.asyncio +async def test_handle_subscription_deleted(db, test_team): + # Arrange + # First set up the team with a stripe customer ID + test_team.stripe_customer_id = "cus_123" + db.commit() + db.refresh(test_team) + + mock_event = Mock() + mock_event.type = "customer.subscription.deleted" + mock_subscription = Mock() + mock_subscription.customer = "cus_123" + mock_event.data.object = mock_subscription + + # Act + await handle_stripe_event_background(mock_event, db) + + # Assert + db.refresh(test_team) + assert test_team.stripe_customer_id is None + +@pytest.mark.asyncio +async def test_handle_checkout_session_completed_team_not_found(db): + # Arrange + mock_event = Mock() + mock_event.type = "checkout.session.completed" + mock_session = Mock() + mock_session.metadata = {"team_id": "999"} # Non-existent team ID + mock_event.data.object = mock_session + + # Act + await handle_stripe_event_background(mock_event, db) + + # No assertion needed as we're just verifying no error occurs + +@pytest.mark.asyncio +async def test_handle_subscription_deleted_team_not_found(db): + # Arrange + mock_event = Mock() + mock_event.type = "customer.subscription.deleted" + mock_subscription = Mock() + mock_subscription.customer = "cus_999" # Non-existent customer ID + mock_event.data.object = mock_subscription + + # Act + await handle_stripe_event_background(mock_event, db) + + # No assertion needed as we're just verifying no error occurs + +@pytest.mark.asyncio +async def test_handle_unknown_event_type(db): + # Arrange + mock_event = Mock() + mock_event.type = "unknown.event.type" + + # Act + await handle_stripe_event_background(mock_event, db) + + # No assertion needed as we're just verifying no error occurs \ No newline at end of file From 6385f16b0a8e2d7dff5aa2f8c929d56c9b13c2a2 Mon Sep 17 00:00:00 2001 From: Pippa H Date: Tue, 20 May 2025 13:49:57 +0200 Subject: [PATCH 08/23] Implement product management and user limits functionality - Added a new `DBTeamProduct` model to establish a many-to-many relationship between teams and products. - Enhanced the `products` API to support new fields for product management, including user limits and budget settings. - Introduced a `check_team_user_limit` function to enforce user limits based on active products. - Updated billing API to handle product application for teams, including extending key durations and setting budgets. - Created comprehensive tests for product management, user limits, and billing integration. - Refactored existing code to improve maintainability and clarity. --- app/api/billing.py | 60 +++-- app/api/products.py | 43 ++-- app/core/config.py | 3 +- app/core/resource_limits.py | 42 +++ app/core/worker.py | 107 ++++++++ app/db/models.py | 40 ++- ..._c2f3f1999f62_team_product_relationship.py | 73 ++++++ app/schemas/models.py | 23 +- app/services/litellm.py | 10 +- app/services/stripe.py | 57 ++++- scripts/initialise_resources.py | 21 +- tests/conftest.py | 24 +- tests/test_private_ai.py | 5 +- tests/test_products.py | 196 ++++++++++---- tests/test_resource_limits.py | 206 +++++++++++++++ tests/test_worker.py | 240 ++++++++++++++++++ 16 files changed, 1020 insertions(+), 130 deletions(-) create mode 100644 app/core/resource_limits.py create mode 100644 app/core/worker.py create mode 100644 app/migrations/versions/20250520_112722_c2f3f1999f62_team_product_relationship.py create mode 100644 tests/test_resource_limits.py create mode 100644 tests/test_worker.py diff --git a/app/api/billing.py b/app/api/billing.py index 3c246b0..19d05d8 100644 --- a/app/api/billing.py +++ b/app/api/billing.py @@ -1,11 +1,7 @@ -from typing import Optional from fastapi import APIRouter, Depends, HTTPException, status, Request, Response, BackgroundTasks from sqlalchemy.orm import Session import logging import os -from pydantic import BaseModel -import threading - from app.db.database import get_db from app.core.security import get_current_user_from_auth, check_specific_team_admin from app.db.models import DBUser, DBTeam, DBSystemSecret @@ -13,12 +9,16 @@ from app.services.stripe import ( create_checkout_session, get_stripe_event, - create_portal_session + create_portal_session, + get_product_id_from_sub, + get_product_id_from_session ) +from app.core.worker import apply_product_for_team # Configure logger logger = logging.getLogger(__name__) BILLING_WEBHOOK_KEY = "stripe_webhook_secret" +BILLING_WEBHOOK_ROUTE = "/billing/events" router = APIRouter( tags=["billing"] @@ -73,26 +73,29 @@ async def handle_stripe_event_background(event, db: Session): Background task to handle Stripe webhook events. This runs in a separate thread to avoid blocking the webhook response. """ + checkout_session_success = ["checkout.session.async_payment_succeeded", "checkout.session.completed"] + invoice_success = ["invoice.payment_succeeded"] + failure_events = ["checkout.session.async_payment_failed", "checkout.session.expired", "subscription.payment_failed", "customer.subscription.deleted"] try: event_type = event.type - - if event_type == "checkout.session.completed": - session = event.data.object - # Handle successful checkout - # Update team's subscription status in database - team = db.query(DBTeam).filter(DBTeam.id == session.metadata.get("team_id")).first() - if team: - team.stripe_customer_id = session.customer - db.commit() - - elif event_type == "customer.subscription.deleted": - subscription = event.data.object - # Handle subscription cancellation - # Update team's subscription status in database - team = db.query(DBTeam).filter(DBTeam.stripe_customer_id == subscription.customer).first() - if team: - team.stripe_customer_id = None - db.commit() + if not event_type in checkout_session_success + invoice_success + failure_events: + logger.info(f"Unknown event type: {event_type}") + return + event_object = event.data.object + customer_id = event_object.customer + if event_type in invoice_success: + subscription = event_object.parent.subscription_details.subscription + product_id = await get_product_id_from_sub(subscription) + await apply_product_for_team(db, customer_id, product_id) + elif event_type in checkout_session_success: + product_id = await get_product_id_from_session(event_object.id) + await apply_product_for_team(db, customer_id, product_id) + elif event_type in failure_events: + logger.info(f"Checkout session failed") + event_object = event.data.object + logger.info(f"ID: {event_object.id}") + logger.info(f"Full object: {event_object}") + # Handle failed checkout except Exception as e: logger.error(f"Error in background event handler: {str(e)}") @@ -110,9 +113,12 @@ async def handle_events( """ try: # Get the webhook secret from database - webhook_secret = db.query(DBSystemSecret).filter( - DBSystemSecret.key == BILLING_WEBHOOK_KEY - ).first() + if os.getenv("WEBHOOK_SIG"): + webhook_secret = os.getenv("WEBHOOK_SIG") + else: + webhook_secret = db.query(DBSystemSecret).filter( + DBSystemSecret.key == BILLING_WEBHOOK_KEY + ).first().value if not webhook_secret: raise HTTPException( @@ -124,7 +130,7 @@ async def handle_events( payload = await request.body() signature = request.headers.get("stripe-signature") - event = get_stripe_event(payload, signature, webhook_secret.value) + event = get_stripe_event(payload, signature, webhook_secret) # Add the event handling to background tasks background_tasks.add_task(handle_stripe_event_background, event, db) diff --git a/app/api/products.py b/app/api/products.py index ab9c5f8..2ed51de 100644 --- a/app/api/products.py +++ b/app/api/products.py @@ -21,18 +21,27 @@ async def create_product( """ Create a new product. Only accessible by system admin users. """ - # Check if stripe_lookup_key already exists - existing_product = db.query(DBProduct).filter(DBProduct.stripe_lookup_key == product.stripe_lookup_key).first() + # Check if product ID already exists + existing_product = db.query(DBProduct).filter(DBProduct.id == product.id).first() if existing_product: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Product with this stripe_lookup_key already exists" + detail="Product with this ID already exists" ) - # Create the product + # Create the product with all fields db_product = DBProduct( + id=product.id, name=product.name, - stripe_lookup_key=product.stripe_lookup_key, + user_count=product.user_count, + keys_per_user=product.keys_per_user, + total_key_count=product.total_key_count, + service_key_count=product.service_key_count, + max_budget_per_key=product.max_budget_per_key, + rpm_per_key=product.rpm_per_key, + vector_db_count=product.vector_db_count, + vector_db_storage=product.vector_db_storage, + renewal_period_days=product.renewal_period_days, active=product.active, created_at=datetime.now(UTC) ) @@ -54,7 +63,7 @@ async def list_products( @router.get("/{product_id}", response_model=Product, dependencies=[Depends(get_role_min_team_admin)]) async def get_product( - product_id: int, + product_id: str, db: Session = Depends(get_db) ): """ @@ -70,7 +79,7 @@ async def get_product( @router.put("/{product_id}", response_model=Product, dependencies=[Depends(check_system_admin)]) async def update_product( - product_id: int, + product_id: str, product_update: ProductUpdate, db: Session = Depends(get_db) ): @@ -84,22 +93,12 @@ async def update_product( detail="Product not found" ) - # If updating stripe_lookup_key, check if it already exists - if product_update.stripe_lookup_key and product_update.stripe_lookup_key != product.stripe_lookup_key: - existing_product = db.query(DBProduct).filter( - DBProduct.stripe_lookup_key == product_update.stripe_lookup_key, - DBProduct.id != product_id - ).first() - if existing_product: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Product with this stripe_lookup_key already exists" - ) - - # Update the product - for key, value in product_update.model_dump(exclude_unset=True).items(): + # Update the product with all provided fields + update_data = product_update.model_dump(exclude_unset=True) + for key, value in update_data.items(): setattr(product, key, value) + product.updated_at = datetime.now(UTC) db.commit() db.refresh(product) @@ -107,7 +106,7 @@ async def update_product( @router.delete("/{product_id}", dependencies=[Depends(check_system_admin)]) async def delete_product( - product_id: int, + product_id: str, db: Session = Depends(get_db) ): """ diff --git a/app/core/config.py b/app/core/config.py index 8ac3536..12f7b9e 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -29,7 +29,8 @@ class Settings(BaseSettings): DYNAMODB_REGION: str = "eu-west-1" SES_REGION: str = "eu-west-1" EXPIRE_KEYS: str = "true" - STRIPE_KEY: str = "sk_test_string" + STRIPE_SECRET_KEY: str = "sk_test_string" + WEBHOOK_SIG: str = "whsec_test_1234567890" model_config = ConfigDict(env_file=".env") diff --git a/app/core/resource_limits.py b/app/core/resource_limits.py new file mode 100644 index 0000000..447feee --- /dev/null +++ b/app/core/resource_limits.py @@ -0,0 +1,42 @@ +from sqlalchemy.orm import Session +from app.db.models import DBTeam, DBUser, DBProduct +from fastapi import HTTPException, status + +def check_team_user_limit(db: Session, team_id: int) -> None: + """ + Check if adding a user would exceed the team's product limits. + Raises HTTPException if the limit would be exceeded. + + Args: + db: Database session + team_id: ID of the team to check + """ + # Get current user count for the team + current_user_count = db.query(DBUser).filter(DBUser.team_id == team_id).count() + + # Get all active products for the team + team = db.query(DBTeam).filter(DBTeam.id == team_id).first() + if not team: + raise HTTPException(status_code=404, detail="Team not found") + + # If team has no products, use default limit of 2 + if not team.active_products: + if current_user_count >= 2: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Team has reached the default user limit of 2 users" + ) + return + + # Find the maximum user count allowed across all active products + max_user_count = max( + (product.user_count for team_product in team.active_products + for product in [team_product.product] if product.user_count), + default=2 # Default to 2 if no products have user_count set + ) + + if current_user_count >= max_user_count: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Team has reached the maximum user limit of {max_user_count} users" + ) diff --git a/app/core/worker.py b/app/core/worker.py new file mode 100644 index 0000000..79f9cb4 --- /dev/null +++ b/app/core/worker.py @@ -0,0 +1,107 @@ +from datetime import datetime, timedelta, UTC +from sqlalchemy.orm import Session +from app.db.models import DBTeam, DBProduct, DBTeamProduct, DBPrivateAIKey, DBRegion, DBUser +from app.services.litellm import LiteLLMService +from app.core.config import settings +import logging +from collections import defaultdict + +logger = logging.getLogger(__name__) + +async def apply_product_for_team(db: Session, customer_id: str, product_id: str) -> bool: + """ + Apply a product to a team and update their last payment date. + Also extends all team keys and sets their max budgets via LiteLLM service. + + Args: + db: Database session + customer_id: Stripe customer ID + product_id: Product ID from the database + + Returns: + bool: True if update was successful, False otherwise + """ + try: + # Find the team and product + team = db.query(DBTeam).filter(DBTeam.stripe_customer_id == customer_id).first() + product = db.query(DBProduct).filter(DBProduct.id == product_id).first() + + if not team: + logger.error(f"Team not found for customer ID: {customer_id}") + return False + if not product: + logger.error(f"Product not found for ID: {product_id}") + return False + + # Update the last payment date + team.last_payment = datetime.now(UTC) + + # Check if the product is already active for the team + existing_association = db.query(DBTeamProduct).filter( + DBTeamProduct.team_id == team.id, + DBTeamProduct.product_id == product.id + ).first() + + # Only create new association if it doesn't exist + if not existing_association: + team_product = DBTeamProduct( + team_id=team.id, + product_id=product.id + ) + db.add(team_product) + + # Get all keys for the team with their regions + team_users = db.query(DBUser).filter(DBUser.team_id == team.id).all() + team_user_ids = [user.id for user in team_users] + # Return keys owned by users in the team OR owned by the team + team_keys = db.query(DBPrivateAIKey).filter( + (DBPrivateAIKey.owner_id.in_(team_user_ids)) | + (DBPrivateAIKey.team_id == team.id) + ).all() + + # Group keys by region + keys_by_region = defaultdict(list) + for key in team_keys: + if not key.litellm_token: + logger.warning(f"Key {key.id} has no LiteLLM token, skipping") + continue + if not key.region: + logger.warning(f"Key {key.id} has no region, skipping") + continue + keys_by_region[key.region].append(key) + + # Update keys for each region + for region, keys in keys_by_region.items(): + # Initialize LiteLLM service for this region + litellm_service = LiteLLMService( + api_url=region.litellm_api_url, + api_key=region.litellm_api_key + ) + + # Update each key's duration and budget via LiteLLM + for key in keys: + try: + # Update key duration + await litellm_service.update_key_duration( + litellm_token=key.litellm_token, + duration=f"{product.renewal_period_days}d" + ) + + # Update key budget + await litellm_service.update_budget( + litellm_token=key.litellm_token, + budget_duration=f"{product.renewal_period_days}d", + budget_amount=product.max_budget_per_key + ) + except Exception as e: + logger.error(f"Failed to update key {key.id} via LiteLLM: {str(e)}") + # Continue with other keys even if one fails + continue + + db.commit() + return True + + except Exception as e: + db.rollback() + logger.error(f"Error applying product to team: {str(e)}") + raise e diff --git a/app/db/models.py b/app/db/models.py index 10e3927..a59cdda 100644 --- a/app/db/models.py +++ b/app/db/models.py @@ -1,10 +1,30 @@ -from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, DateTime, JSON +from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, DateTime, JSON, Float, Table from sqlalchemy.orm import relationship, declarative_base from datetime import datetime, UTC from sqlalchemy.sql import func Base = declarative_base() +class DBTeamProduct(Base): + """ + Association table for team-product relationship. + This model is required to implement a many-to-many relationship between teams and products. + It allows: + - Teams to subscribe to multiple products + - Products to be used by multiple teams + - Tracking when products were added to teams + - Maintaining referential integrity between teams and products + """ + __tablename__ = "team_products" + + team_id = Column(Integer, ForeignKey('teams.id', ondelete='CASCADE'), primary_key=True, nullable=False) + product_id = Column(String, ForeignKey('products.id', ondelete='CASCADE'), primary_key=True, nullable=False) + created_at = Column(DateTime(timezone=True), default=func.now(), nullable=False) + updated_at = Column(DateTime(timezone=True), onupdate=func.now(), nullable=True) + + team = relationship("DBTeam", back_populates="active_products") + product = relationship("DBProduct", back_populates="teams") + class DBRegion(Base): __tablename__ = "regions" @@ -63,9 +83,11 @@ class DBTeam(Base): created_at = Column(DateTime(timezone=True), default=func.now()) updated_at = Column(DateTime(timezone=True), onupdate=func.now()) stripe_customer_id = Column(String, nullable=True, unique=True, index=True) + last_payment = Column(DateTime(timezone=True), nullable=True) users = relationship("DBUser", back_populates="team") private_ai_keys = relationship("DBPrivateAIKey", back_populates="team") + active_products = relationship("DBTeamProduct", back_populates="team") class DBPrivateAIKey(Base): __tablename__ = "ai_tokens" @@ -134,9 +156,19 @@ class DBSystemSecret(Base): class DBProduct(Base): __tablename__ = "products" - id = Column(Integer, primary_key=True, index=True) + id = Column(String, primary_key=True, index=True) name = Column(String, nullable=False) - stripe_lookup_key = Column(String, unique=True, index=True, nullable=False) + user_count = Column(Integer, default=0) + keys_per_user = Column(Integer, default=0) + total_key_count = Column(Integer, default=0) + service_key_count = Column(Integer, default=0) + max_budget_per_key = Column(Float, default=0.0) + rpm_per_key = Column(Integer, default=0) + vector_db_count = Column(Integer, default=0) + vector_db_storage = Column(Integer, default=0) + renewal_period_days = Column(Integer, default=30) active = Column(Boolean, default=True) created_at = Column(DateTime(timezone=True), default=func.now()) - updated_at = Column(DateTime(timezone=True), onupdate=func.now()) \ No newline at end of file + updated_at = Column(DateTime(timezone=True), onupdate=func.now()) + + teams = relationship("DBTeamProduct", back_populates="product") \ No newline at end of file diff --git a/app/migrations/versions/20250520_112722_c2f3f1999f62_team_product_relationship.py b/app/migrations/versions/20250520_112722_c2f3f1999f62_team_product_relationship.py new file mode 100644 index 0000000..f1f1956 --- /dev/null +++ b/app/migrations/versions/20250520_112722_c2f3f1999f62_team_product_relationship.py @@ -0,0 +1,73 @@ +"""team product relationship + +Revision ID: c2f3f1999f62 +Revises: 1069b3617025 +Create Date: 2025-05-20 11:27:22.451900+00:00 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = 'c2f3f1999f62' +down_revision: Union[str, None] = '1069b3617025' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('products', sa.Column('user_count', sa.Integer(), nullable=True)) + op.add_column('products', sa.Column('keys_per_user', sa.Integer(), nullable=True)) + op.add_column('products', sa.Column('total_key_count', sa.Integer(), nullable=True)) + op.add_column('products', sa.Column('service_key_count', sa.Integer(), nullable=True)) + op.add_column('products', sa.Column('max_budget_per_key', sa.Float(), nullable=True)) + op.add_column('products', sa.Column('rpm_per_key', sa.Integer(), nullable=True)) + op.add_column('products', sa.Column('vector_db_count', sa.Integer(), nullable=True)) + op.add_column('products', sa.Column('vector_db_storage', sa.Integer(), nullable=True)) + op.add_column('products', sa.Column('renewal_period_days', sa.Integer(), nullable=True)) + op.alter_column('products', 'id', + existing_type=sa.INTEGER(), + type_=sa.String(), + existing_nullable=False) + op.drop_index('ix_products_stripe_lookup_key', table_name='products') + op.drop_column('products', 'stripe_lookup_key') + + op.create_table('team_products', + sa.Column('team_id', sa.Integer(), nullable=False), + sa.Column('product_id', sa.String(), nullable=False), + sa.Column('created_at', sa.DateTime(timezone=True), nullable=False), + sa.Column('updated_at', sa.DateTime(timezone=True), nullable=True), + sa.ForeignKeyConstraint(['product_id'], ['products.id'], ondelete='CASCADE'), + sa.ForeignKeyConstraint(['team_id'], ['teams.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('team_id', 'product_id') + ) + op.add_column('teams', sa.Column('last_payment', sa.DateTime(timezone=True), nullable=True)) + op.create_index(op.f('ix_teams_stripe_customer_id'), 'teams', ['stripe_customer_id'], unique=True) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f('ix_teams_stripe_customer_id'), table_name='teams') + op.drop_column('teams', 'last_payment') + op.add_column('products', sa.Column('stripe_lookup_key', sa.VARCHAR(), autoincrement=False, nullable=False)) + op.create_index('ix_products_stripe_lookup_key', 'products', ['stripe_lookup_key'], unique=True) + op.drop_table('team_products') + op.alter_column('products', 'id', + existing_type=sa.String(), + type_=sa.INTEGER(), + existing_nullable=False) + op.drop_column('products', 'renewal_period_days') + op.drop_column('products', 'vector_db_storage') + op.drop_column('products', 'vector_db_count') + op.drop_column('products', 'rpm_per_key') + op.drop_column('products', 'max_budget_per_key') + op.drop_column('products', 'service_key_count') + op.drop_column('products', 'total_key_count') + op.drop_column('products', 'keys_per_user') + op.drop_column('products', 'user_count') + # ### end Alembic commands ### \ No newline at end of file diff --git a/app/schemas/models.py b/app/schemas/models.py index 76c757d..194a595 100644 --- a/app/schemas/models.py +++ b/app/schemas/models.py @@ -275,7 +275,16 @@ class CheckoutSessionCreate(BaseModel): class ProductBase(BaseModel): name: str - stripe_lookup_key: str + id: str # This is the Stripe product ID, format should be prod_XXX + user_count: int = 1 + keys_per_user: int = 1 + total_key_count: int = 6 + service_key_count: int = 5 + max_budget_per_key: float = 20.0 + rpm_per_key: int = 500 + vector_db_count: int = 1 + vector_db_storage: int = 50 # Not used yet, should be a number in GiB + renewal_period_days: int = 30 active: bool = True class ProductCreate(ProductBase): @@ -283,11 +292,19 @@ class ProductCreate(ProductBase): class ProductUpdate(BaseModel): name: Optional[str] = None - stripe_lookup_key: Optional[str] = None + user_count: Optional[int] = None + keys_per_user: Optional[int] = None + total_key_count: Optional[int] = None + service_key_count: Optional[int] = None + max_budget_per_key: Optional[float] = None + rpm_per_key: Optional[int] = None + vector_db_count: Optional[int] = None + vector_db_storage: Optional[int] = None + renewal_period_days: Optional[int] = None active: Optional[bool] = None + model_config = ConfigDict(from_attributes=True) class Product(ProductBase): - id: int created_at: datetime updated_at: Optional[datetime] = None model_config = ConfigDict(from_attributes=True) \ No newline at end of file diff --git a/app/services/litellm.py b/app/services/litellm.py index 112224e..e5393c0 100644 --- a/app/services/litellm.py +++ b/app/services/litellm.py @@ -20,7 +20,6 @@ async def create_key(self, email: str, name: str, user_id: int, team_id: str) -> try: logger.info(f"Creating new LiteLLM API key for email: {email}, name: {name}, user_id: {user_id}, team_id: {team_id}") request_data = { - "duration": "14d" if os.getenv("EXPIRE_KEYS", "").lower() == "true" else "365d", # 14 days if EXPIRE_KEYS is true, otherwise 1 year "models": ["all-team-models"], # Allow access to all models "aliases": {}, "config": {}, @@ -39,6 +38,15 @@ async def create_key(self, email: str, name: str, user_id: int, team_id: str) -> request_data["key_alias"] = key_alias request_data["metadata"] = metadata request_data["team_id"] = team_id + + if os.getenv("EXPIRE_KEYS", "").lower() == "true": + request_data["duration"] = "30d" + request_data["budget_duration"] = "30d" + request_data["max_budget"] = 20.0 + request_data["rpm_limit"] = 500 + else: + request_data["duration"] = "365d" + if user_id is not None: request_data["user_id"] = str(user_id) diff --git a/app/services/stripe.py b/app/services/stripe.py index f8fbd72..1782d12 100644 --- a/app/services/stripe.py +++ b/app/services/stripe.py @@ -5,11 +5,11 @@ from urllib.parse import urljoin from fastapi import HTTPException, status from sqlalchemy.orm import Session - +from stripe._list_object import ListObject from app.db.models import DBTeam, DBSystemSecret # Configure logger -stripe_logger = logging.getLogger(__name__) +logger = logging.getLogger(__name__) # Initialize Stripe stripe.api_key = os.getenv("STRIPE_SECRET_KEY") @@ -65,7 +65,7 @@ async def create_checkout_session( return checkout_session.url except Exception as e: - stripe_logger.error(f"Error creating checkout session: {str(e)}") + logger.error(f"Error creating checkout session: {str(e)}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Error creating checkout session" @@ -84,6 +84,7 @@ def get_stripe_event( payload: bytes, signature: str, webhook_secret: str) -> st stripe.Event: The Stripe event """ try: + logger.info(f"Trying to decode event") event = stripe.Webhook.construct_event( payload, signature, webhook_secret ) @@ -100,7 +101,7 @@ def get_stripe_event( payload: bytes, signature: str, webhook_secret: str) -> st detail="Invalid signature" ) except Exception as e: - stripe_logger.error(f"Error handling Stripe event: {str(e)}") + logger.error(f"Error handling Stripe event: {str(e)}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Error processing webhook" @@ -135,13 +136,13 @@ async def create_portal_session( return portal_session.url except Exception as e: - stripe_logger.error(f"Error creating portal session: {str(e)}") + logger.error(f"Error creating portal session: {str(e)}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Error creating portal session" ) -async def setup_stripe_webhook(webhook_key: str, db: Session) -> None: +async def setup_stripe_webhook(webhook_key: str, webhook_route: str, db: Session) -> None: """ Set up the Stripe webhook endpoint if it doesn't exist and store its signing secret. @@ -160,7 +161,7 @@ async def setup_stripe_webhook(webhook_key: str, db: Session) -> None: # Get the base URL from environment base_url = os.getenv("BACKEND_URL", "http://localhost:8800") - webhook_url = urljoin(base_url, "/api/stripe/handle-event") + webhook_url = urljoin(base_url, webhook_route) # List existing webhook endpoints endpoints = stripe.WebhookEndpoint.list() @@ -176,7 +177,7 @@ async def setup_stripe_webhook(webhook_key: str, db: Session) -> None: # For existing endpoints, we need to create a new one to get the secret # First delete the old endpoint stripe.WebhookEndpoint.delete(existing_endpoint.id) - stripe_logger.info(f"Deleted existing webhook endpoint: {existing_endpoint.id}") + logger.info(f"Deleted existing webhook endpoint: {existing_endpoint.id}") # Create new webhook endpoint endpoint = stripe.WebhookEndpoint.create( @@ -197,10 +198,10 @@ async def setup_stripe_webhook(webhook_key: str, db: Session) -> None: db.commit() except Exception as e: - stripe_logger.error(f"Error setting up Stripe webhook: {str(e)}") + logger.error(f"Error setting up Stripe webhook: {str(e)}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Error setting up Stripe webhook" + detail=f"Error setting up Stripe webhook, {str(e)}" ) async def create_stripe_customer( @@ -242,8 +243,42 @@ async def create_stripe_customer( return customer.id except Exception as e: - stripe_logger.error(f"Error creating Stripe customer: {str(e)}") + logger.error(f"Error creating Stripe customer: {str(e)}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Error creating Stripe customer" ) + +async def get_product_id_from_sub(subscription_id: str) -> str: + """ + Get the Stripe product ID for the team's subscription. + + Args: + subscription_id: The Stripe subscription ID + + Returns: + str: The Stripe product ID + """ + # Get the list of subscription items + subscription_items = stripe.SubscriptionItem.list( + subscription=subscription_id, + expand=['data.price.product'] + ) + + if not subscription_items.data: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="No items found in subscription" + ) + + return subscription_items.data[0].price.product.id + +async def get_product_id_from_session(session_id: str) -> str: + """ + Get the Stripe product ID for the team's subscription from a checkout session. + + Args: + session_id: The Stripe checkout session ID + """ + line_items = stripe.checkout.Session.list_line_items(session_id) + return line_items.data[0].price.product diff --git a/scripts/initialise_resources.py b/scripts/initialise_resources.py index 723d3e3..e8b9bf8 100644 --- a/scripts/initialise_resources.py +++ b/scripts/initialise_resources.py @@ -6,20 +6,20 @@ # Add the parent directory to the Python path sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) -import psycopg2 import alembic.config import alembic.command +import asyncio +import glob from sqlalchemy import inspect -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import sessionmaker, Session from app.db.database import engine from app.db.models import Base, DBUser from app.core.security import get_password_hash from app.services.ses import SESService from app.services.stripe import setup_stripe_webhook -from app.api.billing import BILLING_WEBHOOK_KEY -import glob +from app.api.billing import BILLING_WEBHOOK_KEY, BILLING_WEBHOOK_ROUTE -def init_database(): +def init_database() -> Session: # Check if database is empty (no tables exist) inspector = inspect(engine) existing_tables = inspector.get_table_names() @@ -76,22 +76,23 @@ def init_database(): except Exception as e: print(f"Error creating admin user: {str(e)}") db.rollback() + finally: + return db +def init_webhooks(db: Session): # Only set up Stripe webhook in specific environments env_suffix = os.getenv("ENV_SUFFIX", "").lower() if env_suffix in ["dev", "main", "prod"]: try: # Set up Stripe webhook print("Setting up Stripe Billing webhook...") - setup_stripe_webhook(BILLING_WEBHOOK_KEY, db) + asyncio.run(setup_stripe_webhook(BILLING_WEBHOOK_KEY, BILLING_WEBHOOK_ROUTE, db)) print("Stripe Billing webhook set up successfully") except Exception as e: print(f"Warning: Failed to set up Stripe Billing webhook: {str(e)}") else: print(f"Skipping Stripe webhook setup for environment: {env_suffix}") - db.close() - def init_ses_templates(): if os.getenv("PASSWORDLESS_SIGN_IN", "").lower() == "true": # Initialize SES templates @@ -113,8 +114,10 @@ def init_ses_templates(): def main(): try: - init_database() + db = init_database() + init_webhooks(db) init_ses_templates() + db.close() except Exception as e: print(f"Error during initialization: {str(e)}") sys.exit(1) diff --git a/tests/conftest.py b/tests/conftest.py index e5cc30a..9d8916a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,7 +4,7 @@ from sqlalchemy.orm import sessionmaker from app.main import app from app.db.database import get_db -from app.db.models import Base, DBRegion, DBUser, DBTeam +from app.db.models import Base, DBRegion, DBUser, DBTeam, DBProduct import os from app.core.security import get_password_hash from datetime import datetime, UTC, timedelta @@ -110,6 +110,28 @@ def test_team(db): db.refresh(team) return team +@pytest.fixture +def test_product(db): + product = DBProduct( + id="prod_test123", + name="Test Product", + user_count=5, + keys_per_user=2, + total_key_count=10, + service_key_count=2, + max_budget_per_key=50.0, + rpm_per_key=1000, + vector_db_count=1, + vector_db_storage=100, + renewal_period_days=30, + active=True, + created_at=datetime.now(UTC) + ) + db.add(product) + db.commit() + db.refresh(product) + return product + @pytest.fixture def test_team_id(test_team): return test_team.id diff --git a/tests/test_private_ai.py b/tests/test_private_ai.py index 8ba9e79..bb03022 100644 --- a/tests/test_private_ai.py +++ b/tests/test_private_ai.py @@ -1127,7 +1127,10 @@ def test_create_llm_token_with_expiration(mock_post, client, admin_token, test_r # Verify that the LiteLLM API was called with the correct duration mock_post.assert_called_once() call_args = mock_post.call_args[1] - assert call_args["json"]["duration"] == "14d" # Verify 14-day duration format + assert call_args["json"]["duration"] == "30d" # Verify 1 month + assert call_args["json"]["budget_duration"] == "30d" # Verify 1 month + assert call_args["json"]["max_budget"] == 20.0 # Verify 20.0 + assert call_args["json"]["rpm_limit"] == 500 # Verify 500 def test_create_vector_db_as_system_admin(client, admin_token, test_region): """Test that a system admin can create a vector database for themselves""" diff --git a/tests/test_products.py b/tests/test_products.py index 8114f94..775c845 100644 --- a/tests/test_products.py +++ b/tests/test_products.py @@ -16,25 +16,42 @@ def test_create_product_as_system_admin(client, admin_token, db): "/products/", headers={"Authorization": f"Bearer {admin_token}"}, json={ + "id": "prod_test123", "name": "Test Product", - "stripe_lookup_key": "test_product_123", + "user_count": 5, + "keys_per_user": 2, + "total_key_count": 10, + "service_key_count": 2, + "max_budget_per_key": 50.0, + "rpm_per_key": 1000, + "vector_db_count": 1, + "vector_db_storage": 100, + "renewal_period_days": 30, "active": True } ) assert response.status_code == 201 data = response.json() assert data["name"] == "Test Product" - assert data["stripe_lookup_key"] == "test_product_123" + assert data["id"] == "prod_test123" + assert data["user_count"] == 5 + assert data["keys_per_user"] == 2 + assert data["total_key_count"] == 10 + assert data["service_key_count"] == 2 + assert data["max_budget_per_key"] == 50.0 + assert data["rpm_per_key"] == 1000 + assert data["vector_db_count"] == 1 + assert data["vector_db_storage"] == 100 + assert data["renewal_period_days"] == 30 assert data["active"] is True - assert "id" in data assert "created_at" in data -def test_create_product_duplicate_key(client, admin_token, db): +def test_create_product_duplicate_id(client, admin_token, db): """ - Test that creating a product with a duplicate stripe_lookup_key fails. + Test that creating a product with a duplicate ID fails. - GIVEN: A product with a specific stripe_lookup_key exists - WHEN: A system admin tries to create another product with the same key + GIVEN: A product with a specific ID exists + WHEN: A system admin tries to create another product with the same ID THEN: A 400 - Bad Request is returned """ # First create a product @@ -42,19 +59,37 @@ def test_create_product_duplicate_key(client, admin_token, db): "/products/", headers={"Authorization": f"Bearer {admin_token}"}, json={ + "id": "prod_test123", "name": "Test Product", - "stripe_lookup_key": "test_product_123", + "user_count": 5, + "keys_per_user": 2, + "total_key_count": 10, + "service_key_count": 2, + "max_budget_per_key": 50.0, + "rpm_per_key": 1000, + "vector_db_count": 1, + "vector_db_storage": 100, + "renewal_period_days": 30, "active": True } ) - # Try to create another product with the same key + # Try to create another product with the same ID response = client.post( "/products/", headers={"Authorization": f"Bearer {admin_token}"}, json={ + "id": "prod_test123", "name": "Another Product", - "stripe_lookup_key": "test_product_123", + "user_count": 5, + "keys_per_user": 2, + "total_key_count": 10, + "service_key_count": 2, + "max_budget_per_key": 50.0, + "rpm_per_key": 1000, + "vector_db_count": 1, + "vector_db_storage": 100, + "renewal_period_days": 30, "active": True } ) @@ -73,8 +108,17 @@ def test_create_product_unauthorized(client, test_token, db): "/products/", headers={"Authorization": f"Bearer {test_token}"}, json={ + "id": "prod_test123", "name": "Test Product", - "stripe_lookup_key": "test_product_123", + "user_count": 5, + "keys_per_user": 2, + "total_key_count": 10, + "service_key_count": 2, + "max_budget_per_key": 50.0, + "rpm_per_key": 1000, + "vector_db_count": 1, + "vector_db_storage": 100, + "renewal_period_days": 30, "active": True } ) @@ -90,14 +134,32 @@ def test_list_products_as_team_admin(client, team_admin_token, db): """ # Create some test products db_product1 = DBProduct( + id="prod_test1", name="Test Product 1", - stripe_lookup_key="test_product_1", + user_count=5, + keys_per_user=2, + total_key_count=10, + service_key_count=2, + max_budget_per_key=50.0, + rpm_per_key=1000, + vector_db_count=1, + vector_db_storage=100, + renewal_period_days=30, active=True, created_at=datetime.now(UTC) ) db_product2 = DBProduct( + id="prod_test2", name="Test Product 2", - stripe_lookup_key="test_product_2", + user_count=10, + keys_per_user=3, + total_key_count=30, + service_key_count=5, + max_budget_per_key=100.0, + rpm_per_key=2000, + vector_db_count=2, + vector_db_storage=200, + renewal_period_days=60, active=True, created_at=datetime.now(UTC) ) @@ -139,8 +201,17 @@ def test_get_product_as_team_admin(client, team_admin_token, db): """ # Create a test product db_product = DBProduct( + id="prod_test123", name="Test Product", - stripe_lookup_key="test_product_123", + user_count=5, + keys_per_user=2, + total_key_count=10, + service_key_count=2, + max_budget_per_key=50.0, + rpm_per_key=1000, + vector_db_count=1, + vector_db_storage=100, + renewal_period_days=30, active=True, created_at=datetime.now(UTC) ) @@ -154,7 +225,16 @@ def test_get_product_as_team_admin(client, team_admin_token, db): assert response.status_code == 200 data = response.json() assert data["name"] == "Test Product" - assert data["stripe_lookup_key"] == "test_product_123" + assert data["id"] == "prod_test123" + assert data["user_count"] == 5 + assert data["keys_per_user"] == 2 + assert data["total_key_count"] == 10 + assert data["service_key_count"] == 2 + assert data["max_budget_per_key"] == 50.0 + assert data["rpm_per_key"] == 1000 + assert data["vector_db_count"] == 1 + assert data["vector_db_storage"] == 100 + assert data["renewal_period_days"] == 30 assert data["active"] is True def test_get_product_not_found(client, team_admin_token, db): @@ -166,7 +246,7 @@ def test_get_product_not_found(client, team_admin_token, db): THEN: A 404 - Not Found is returned """ response = client.get( - "/products/99999", + "/products/prod_nonexistent", headers={"Authorization": f"Bearer {team_admin_token}"} ) assert response.status_code == 404 @@ -181,8 +261,17 @@ def test_update_product_as_system_admin(client, admin_token, db): """ # Create a test product db_product = DBProduct( + id="prod_test123", name="Test Product", - stripe_lookup_key="test_product_123", + user_count=5, + keys_per_user=2, + total_key_count=10, + service_key_count=2, + max_budget_per_key=50.0, + rpm_per_key=1000, + vector_db_count=1, + vector_db_storage=100, + renewal_period_days=30, active=True, created_at=datetime.now(UTC) ) @@ -194,13 +283,31 @@ def test_update_product_as_system_admin(client, admin_token, db): headers={"Authorization": f"Bearer {admin_token}"}, json={ "name": "Updated Product", + "user_count": 10, + "keys_per_user": 3, + "total_key_count": 30, + "service_key_count": 5, + "max_budget_per_key": 100.0, + "rpm_per_key": 2000, + "vector_db_count": 2, + "vector_db_storage": 200, + "renewal_period_days": 60, "active": False } ) assert response.status_code == 200 data = response.json() assert data["name"] == "Updated Product" - assert data["stripe_lookup_key"] == "test_product_123" # Unchanged + assert data["id"] == "prod_test123" # ID should remain unchanged + assert data["user_count"] == 10 + assert data["keys_per_user"] == 3 + assert data["total_key_count"] == 30 + assert data["service_key_count"] == 5 + assert data["max_budget_per_key"] == 100.0 + assert data["rpm_per_key"] == 2000 + assert data["vector_db_count"] == 2 + assert data["vector_db_storage"] == 200 + assert data["renewal_period_days"] == 60 assert data["active"] is False def test_update_product_unauthorized(client, team_admin_token, db): @@ -213,8 +320,17 @@ def test_update_product_unauthorized(client, team_admin_token, db): """ # Create a test product db_product = DBProduct( + id="prod_test123", name="Test Product", - stripe_lookup_key="test_product_123", + user_count=5, + keys_per_user=2, + total_key_count=10, + service_key_count=2, + max_budget_per_key=50.0, + rpm_per_key=1000, + vector_db_count=1, + vector_db_storage=100, + renewal_period_days=30, active=True, created_at=datetime.now(UTC) ) @@ -240,8 +356,17 @@ def test_delete_product_as_system_admin(client, admin_token, db): """ # Create a test product db_product = DBProduct( + id="prod_test123", name="Test Product", - stripe_lookup_key="test_product_123", + user_count=5, + keys_per_user=2, + total_key_count=10, + service_key_count=2, + max_budget_per_key=50.0, + rpm_per_key=1000, + vector_db_count=1, + vector_db_storage=100, + renewal_period_days=30, active=True, created_at=datetime.now(UTC) ) @@ -253,36 +378,7 @@ def test_delete_product_as_system_admin(client, admin_token, db): headers={"Authorization": f"Bearer {admin_token}"} ) assert response.status_code == 200 - assert response.json()["message"] == "Product deleted successfully" # Verify the product is deleted deleted_product = db.query(DBProduct).filter(DBProduct.id == db_product.id).first() - assert deleted_product is None - -def test_delete_product_unauthorized(client, team_admin_token, db): - """ - Test that a team admin cannot delete a product. - - GIVEN: The authenticated user is a team admin - WHEN: They try to delete a product - THEN: A 403 - Forbidden is returned - """ - # Create a test product - db_product = DBProduct( - name="Test Product", - stripe_lookup_key="test_product_123", - active=True, - created_at=datetime.now(UTC) - ) - db.add(db_product) - db.commit() - - response = client.delete( - f"/products/{db_product.id}", - headers={"Authorization": f"Bearer {team_admin_token}"} - ) - assert response.status_code == 403 - - # Verify the product still exists - product = db.query(DBProduct).filter(DBProduct.id == db_product.id).first() - assert product is not None \ No newline at end of file + assert deleted_product is None \ No newline at end of file diff --git a/tests/test_resource_limits.py b/tests/test_resource_limits.py new file mode 100644 index 0000000..42b8e01 --- /dev/null +++ b/tests/test_resource_limits.py @@ -0,0 +1,206 @@ +import pytest +from fastapi.testclient import TestClient +from sqlalchemy.orm import Session +from app.db.models import DBUser, DBTeam, DBProduct, DBTeamProduct +from datetime import datetime, UTC + +def test_add_user_within_product_limit(client, admin_token, db, test_team, test_product): + """Test adding a user when within product user limit""" + # Add product to team + team_id = test_team.id + product_id = test_product.id + team_product = DBTeamProduct( + team_id=team_id, + product_id=product_id + ) + db.add(team_product) + db.commit() + + # Create a user to add + user = DBUser( + email="newuser@example.com", + hashed_password="hashed_password", + is_active=True, + is_admin=False, + role="user", + team_id=None, + created_at=datetime.now(UTC) + ) + db.add(user) + db.commit() + + # Add user to team + response = client.post( + f"/users/{user.id}/add-to-team", + headers={"Authorization": f"Bearer {admin_token}"}, + json={"team_id": team_id} + ) + assert response.status_code == 200 + assert response.json()["team_id"] == team_id + +def test_add_user_exceeding_product_limit(client, admin_token, db, test_team, test_product): + """Test adding a user when it would exceed product user limit""" + # Add product to team + team_product = DBTeamProduct( + team_id=test_team.id, + product_id=test_product.id + ) + db.add(team_product) + db.commit() + + # Create and add users up to the limit + for i in range(test_product.user_count): + user = DBUser( + email=f"user{i}@example.com", + hashed_password="hashed_password", + is_active=True, + is_admin=False, + role="user", + team_id=test_team.id, + created_at=datetime.now(UTC) + ) + db.add(user) + db.commit() + + # Try to add one more user + new_user = DBUser( + email="newuser@example.com", + hashed_password="hashed_password", + is_active=True, + is_admin=False, + role="user", + team_id=None, + created_at=datetime.now(UTC) + ) + db.add(new_user) + db.commit() + + response = client.post( + f"/users/{new_user.id}/add-to-team", + headers={"Authorization": f"Bearer {admin_token}"}, + json={"team_id": test_team.id} + ) + assert response.status_code == 400 + assert f"Team has reached the maximum user limit of {test_product.user_count} users" in response.json()["detail"] + +def test_add_user_with_default_limit(client, admin_token, db, test_team): + """Test adding users with default limit when team has no products""" + # Create and add users up to the default limit (2) + for i in range(2): + user = DBUser( + email=f"user{i}@example.com", + hashed_password="hashed_password", + is_active=True, + is_admin=False, + role="user", + team_id=test_team.id, + created_at=datetime.now(UTC) + ) + db.add(user) + db.commit() + + # Try to add one more user + new_user = DBUser( + email="newuser@example.com", + hashed_password="hashed_password", + is_active=True, + is_admin=False, + role="user", + team_id=None, + created_at=datetime.now(UTC) + ) + db.add(new_user) + db.commit() + + response = client.post( + f"/users/{new_user.id}/add-to-team", + headers={"Authorization": f"Bearer {admin_token}"}, + json={"team_id": test_team.id} + ) + assert response.status_code == 400 + assert "Team has reached the default user limit of 2 users" in response.json()["detail"] + +def test_add_user_with_multiple_products(client, admin_token, db, test_team): + """Test adding users when team has multiple products with different limits""" + # Create two products with different user limits + product1 = DBProduct( + id="prod_test1", + name="Test Product 1", + user_count=3, + keys_per_user=2, + total_key_count=10, + service_key_count=2, + max_budget_per_key=50.0, + rpm_per_key=1000, + vector_db_count=1, + vector_db_storage=100, + renewal_period_days=30, + active=True, + created_at=datetime.now(UTC) + ) + product2 = DBProduct( + id="prod_test2", + name="Test Product 2", + user_count=5, + keys_per_user=2, + total_key_count=10, + service_key_count=2, + max_budget_per_key=50.0, + rpm_per_key=1000, + vector_db_count=1, + vector_db_storage=100, + renewal_period_days=30, + active=True, + created_at=datetime.now(UTC) + ) + db.add(product1) + db.add(product2) + db.commit() + + # Add both products to team + team_product1 = DBTeamProduct( + team_id=test_team.id, + product_id=product1.id + ) + team_product2 = DBTeamProduct( + team_id=test_team.id, + product_id=product2.id + ) + db.add(team_product1) + db.add(team_product2) + db.commit() + + # Create and add users up to the higher limit (5) + for i in range(5): + user = DBUser( + email=f"user{i}@example.com", + hashed_password="hashed_password", + is_active=True, + is_admin=False, + role="user", + team_id=test_team.id, + created_at=datetime.now(UTC) + ) + db.add(user) + db.commit() + + # Try to add one more user + new_user = DBUser( + email="newuser@example.com", + hashed_password="hashed_password", + is_active=True, + is_admin=False, + role="user", + team_id=None, + created_at=datetime.now(UTC) + ) + db.add(new_user) + db.commit() + + response = client.post( + f"/users/{new_user.id}/add-to-team", + headers={"Authorization": f"Bearer {admin_token}"}, + json={"team_id": test_team.id} + ) + assert response.status_code == 400 + assert f"Team has reached the maximum user limit of {product2.user_count} users" in response.json()["detail"] diff --git a/tests/test_worker.py b/tests/test_worker.py new file mode 100644 index 0000000..d1b21aa --- /dev/null +++ b/tests/test_worker.py @@ -0,0 +1,240 @@ +import pytest +from fastapi.testclient import TestClient +from sqlalchemy.orm import Session +from app.db.models import DBProduct, DBTeam, DBUser, DBPrivateAIKey +from datetime import datetime, UTC, timedelta +from app.core.worker import apply_product_for_team +from unittest.mock import AsyncMock, patch + +@pytest.mark.asyncio +async def test_apply_product_success(db, test_team, test_product): + """ + Test successful application of a product to a team. + + GIVEN: A team and a product exist in the database + WHEN: The product is applied to the team + THEN: The team's active products list is updated and last payment date is set + """ + # Set stripe customer ID for the test team + test_team.stripe_customer_id = "cus_test123" + db.commit() + + # Apply product to team + result = await apply_product_for_team(db, test_team.stripe_customer_id, test_product.id) + + # Verify the result + assert result is True + + # Refresh team from database + db.refresh(test_team) + + # Verify team was updated correctly + assert len(test_team.active_products) == 1 + assert test_team.active_products[0].product.id == test_product.id + assert test_team.last_payment is not None + +@pytest.mark.asyncio +async def test_apply_product_team_not_found(db, test_product): + """ + Test applying a product when team is not found. + + GIVEN: A product exists but team does not + WHEN: Attempting to apply the product + THEN: The operation returns False + """ + # Try to apply product to non-existent team + result = await apply_product_for_team(db, "cus_nonexistent", test_product.id) + assert result is False + +@pytest.mark.asyncio +async def test_apply_product_product_not_found(db, test_team): + """ + Test applying a non-existent product to a team. + + GIVEN: A team exists but product does not + WHEN: Attempting to apply the product + THEN: The operation returns False + """ + # Set stripe customer ID for the test team + test_team.stripe_customer_id = "cus_test123" + db.commit() + + # Try to apply non-existent product + result = await apply_product_for_team(db, test_team.stripe_customer_id, "prod_nonexistent") + assert result is False + +@pytest.mark.asyncio +async def test_apply_product_multiple_products(db, test_team, test_product): + """ + Test applying multiple products to a team. + + GIVEN: A team and multiple products exist + WHEN: Multiple products are applied to the team + THEN: All products are added to the team's active products list + """ + # Set stripe customer ID for the test team + test_team.stripe_customer_id = "cus_test123" + db.commit() + + # Create additional test products + products = [test_product] # Start with the fixture product + for i in range(2): # Create 2 more products + product = DBProduct( + id=f"prod_test{i+1}", + name=f"Test Product {i+1}", + user_count=5, + keys_per_user=2, + total_key_count=10, + service_key_count=2, + max_budget_per_key=50.0, + rpm_per_key=1000, + vector_db_count=1, + vector_db_storage=100, + renewal_period_days=30, + active=True, + created_at=datetime.now(UTC) + ) + db.add(product) + products.append(product) + db.commit() + + # Apply each product to the team + for product in products: + result = await apply_product_for_team(db, test_team.stripe_customer_id, product.id) + assert result is True + + # Refresh team from database + db.refresh(test_team) + + # Verify all products were added + assert len(test_team.active_products) == 3 + product_ids = [team_product.product.id for team_product in test_team.active_products] + assert all(expected_product.id in product_ids for expected_product in products) + +@pytest.mark.asyncio +async def test_apply_product_already_active(db, test_team, test_product): + """ + Test applying a product that is already active for a team. + + GIVEN: A team has a specific product already active + WHEN: That product is applied for the team + THEN: The last payment date is updated, but the list of products is unchanged + """ + # Set stripe customer ID for the test team + test_team.stripe_customer_id = "cus_test123" + db.add(test_team) # Ensure the team is added to the session + db.commit() + db.refresh(test_team) # Refresh to ensure we have the latest data + + # First apply the product + result = await apply_product_for_team(db, test_team.stripe_customer_id, test_product.id) + assert result is True + + # Get the initial last payment date + db.refresh(test_team) + initial_last_payment = test_team.last_payment + + # Apply the same product again + result = await apply_product_for_team(db, test_team.stripe_customer_id, test_product.id) + assert result is True + + # Refresh team from database + db.refresh(test_team) + + # Verify the product list is unchanged + assert len(test_team.active_products) == 1 + assert test_team.active_products[0].product.id == test_product.id + + # Verify the last payment date was updated + assert test_team.last_payment > initial_last_payment + +@pytest.mark.asyncio +@patch('app.core.worker.LiteLLMService') +async def test_apply_product_extends_keys_and_sets_budget(mock_litellm, db, test_team, test_product, test_region, test_team_user, test_team_key_creator): + """ + Test that applying a product extends keys and sets max budget correctly. + + GIVEN: A team with users and keys (both team-owned and user-owned), and a product which specifies a max_budget of $20 per key + with a renewal period of 30 days + WHEN: The product is applied to the team + THEN: All keys for the team and users in the team are extended and the max_budget is set correctly + """ + # Set stripe customer ID for the test team + test_team.stripe_customer_id = "cus_test123" + db.commit() + + # Create test keys for the team + team_keys = [] + for i in range(2): # 2 team-owned keys + key = DBPrivateAIKey( + name=f"Team Key {i}", + database_name=f"db_team_{i}", + database_username="test_user", + database_password="test_pass", + team_id=test_team.id, + region_id=test_region.id, + litellm_token=f"test_token_team_{i}", + created_at=datetime.now(UTC) + ) + db.add(key) + team_keys.append(key) + + # Create test keys for both team users + user_keys = [] + for user in [test_team_user, test_team_key_creator]: + for i in range(2): # 2 keys per user + key = DBPrivateAIKey( + name=f"User Key {i} for {user.email}", + database_name=f"db_user_{user.id}_{i}", + database_username="test_user", + database_password="test_pass", + owner_id=user.id, + team_id=test_team.id, + region_id=test_region.id, + litellm_token=f"test_token_user_{user.id}_{i}", + created_at=datetime.now(UTC) + ) + db.add(key) + user_keys.append(key) + db.commit() + + # Setup mock instance + mock_instance = mock_litellm.return_value + mock_instance.update_key_duration = AsyncMock() + mock_instance.update_budget = AsyncMock() + + # Apply product to team + result = await apply_product_for_team(db, test_team.stripe_customer_id, test_product.id) + assert result is True + + # Verify LiteLLM service was initialized with correct region settings + mock_litellm.assert_called_once_with( + api_url=test_region.litellm_api_url, + api_key=test_region.litellm_api_key + ) + + # Verify LiteLLM service was called for all keys (both team and user owned) + all_keys = team_keys + user_keys + assert mock_instance.update_key_duration.call_count == len(all_keys) + assert mock_instance.update_budget.call_count == len(all_keys) + + # Verify each key was updated with correct duration and budget + for key in all_keys: + # Verify duration update + duration_calls = [call for call in mock_instance.update_key_duration.call_args_list + if call[1]['litellm_token'] == key.litellm_token] + assert len(duration_calls) == 1 + assert duration_calls[0][1]['duration'] == f"{test_product.renewal_period_days}d" + + # Verify budget update + budget_calls = [call for call in mock_instance.update_budget.call_args_list + if call[1]['litellm_token'] == key.litellm_token] + assert len(budget_calls) == 1 + assert budget_calls[0][1]['budget_duration'] == f"{test_product.renewal_period_days}d" + assert budget_calls[0][1]['budget_amount'] == test_product.max_budget_per_key + + # Verify team was updated correctly + db.refresh(test_team) + assert len(test_team.active_products) == 1 + assert test_team.active_products[0].product.id == test_product.id + assert test_team.last_payment is not None \ No newline at end of file From d48e46eef113888174703c94bc3fd939f80a283f Mon Sep 17 00:00:00 2001 From: Pippa H Date: Tue, 20 May 2025 14:28:24 +0200 Subject: [PATCH 09/23] Add key limit checks and enhance user limit validation - Introduced a new `check_key_limits` function to validate LLM token creation against team and user limits. - Updated `check_team_user_limit` to remove default limit checks for teams without products. - Enhanced tests to cover scenarios for creating LLM tokens, including exceeding total, user, and service key limits. - Refactored existing tests to utilize the new key limit checks and ensure comprehensive coverage for user and key limits. --- app/core/resource_limits.py | 87 +++++++- tests/test_resource_limits.py | 361 ++++++++++++++++++++++++++-------- 2 files changed, 361 insertions(+), 87 deletions(-) diff --git a/app/core/resource_limits.py b/app/core/resource_limits.py index 447feee..4109b05 100644 --- a/app/core/resource_limits.py +++ b/app/core/resource_limits.py @@ -1,6 +1,7 @@ from sqlalchemy.orm import Session -from app.db.models import DBTeam, DBUser, DBProduct +from app.db.models import DBTeam, DBUser, DBProduct, DBPrivateAIKey from fastapi import HTTPException, status +from typing import Optional def check_team_user_limit(db: Session, team_id: int) -> None: """ @@ -19,15 +20,6 @@ def check_team_user_limit(db: Session, team_id: int) -> None: if not team: raise HTTPException(status_code=404, detail="Team not found") - # If team has no products, use default limit of 2 - if not team.active_products: - if current_user_count >= 2: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Team has reached the default user limit of 2 users" - ) - return - # Find the maximum user count allowed across all active products max_user_count = max( (product.user_count for team_product in team.active_products @@ -40,3 +32,78 @@ def check_team_user_limit(db: Session, team_id: int) -> None: status_code=status.HTTP_400_BAD_REQUEST, detail=f"Team has reached the maximum user limit of {max_user_count} users" ) + +def check_key_limits(db: Session, team_id: int, owner_id: Optional[int] = None) -> None: + """ + Check if creating a new LLM token would exceed the team's or user's key limits. + Raises HTTPException if any limit would be exceeded. + + Args: + db: Database session + team_id: ID of the team to check + owner_id: Optional ID of the user who will own the key + """ + # Get the team and its active products + team = db.query(DBTeam).filter(DBTeam.id == team_id).first() + if not team: + raise HTTPException(status_code=404, detail="Team not found") + + # Find the maximum limits across all active products, using defaults if no products + max_total_keys = max( + (product.total_key_count for team_product in team.active_products + for product in [team_product.product] if product.total_key_count), + default=2 # Default to 2 if no products have total_key_count set + ) + max_keys_per_user = max( + (product.keys_per_user for team_product in team.active_products + for product in [team_product.product] if product.keys_per_user), + default=1 # Default to 1 if no products have keys_per_user set + ) + max_service_keys = max( + (product.service_key_count for team_product in team.active_products + for product in [team_product.product] if product.service_key_count), + default=1 # Default to 1 if no products have service_key_count set + ) + + # Get all users in the team + team_users = db.query(DBUser).filter(DBUser.team_id == team_id).all() + user_ids = [user.id for user in team_users] + + # Check total team LLM tokens (both team-owned and user-owned) + current_team_tokens = db.query(DBPrivateAIKey).filter( + ( + (DBPrivateAIKey.team_id == team_id) | # Team-owned tokens + (DBPrivateAIKey.owner_id.in_(user_ids)) # User-owned tokens + ), + DBPrivateAIKey.litellm_token.isnot(None) # Only count LLM tokens + ).count() + if current_team_tokens >= max_total_keys: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Team has reached the maximum LLM token limit of {max_total_keys} tokens" + ) + + # Check user LLM tokens if owner_id is provided + if owner_id is not None: + current_user_tokens = db.query(DBPrivateAIKey).filter( + DBPrivateAIKey.owner_id == owner_id, + DBPrivateAIKey.litellm_token.isnot(None) # Only count LLM tokens + ).count() + if current_user_tokens >= max_keys_per_user: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"User has reached the maximum LLM token limit of {max_keys_per_user} tokens" + ) + + # Check service LLM tokens (team-owned tokens) + if owner_id is None: # This is a team-owned token + current_service_tokens = db.query(DBPrivateAIKey).filter( + DBPrivateAIKey.team_id == team_id, + DBPrivateAIKey.owner_id.is_(None), + DBPrivateAIKey.litellm_token.isnot(None) # Only count LLM tokens + ).count() + if current_service_tokens >= max_service_keys: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Team has reached the maximum service LLM token limit of {max_service_keys} tokens" + ) diff --git a/tests/test_resource_limits.py b/tests/test_resource_limits.py index 42b8e01..0ee1412 100644 --- a/tests/test_resource_limits.py +++ b/tests/test_resource_limits.py @@ -1,42 +1,23 @@ import pytest from fastapi.testclient import TestClient from sqlalchemy.orm import Session -from app.db.models import DBUser, DBTeam, DBProduct, DBTeamProduct +from app.db.models import DBUser, DBTeam, DBProduct, DBTeamProduct, DBPrivateAIKey from datetime import datetime, UTC +from fastapi import HTTPException +from app.core.resource_limits import check_key_limits, check_team_user_limit def test_add_user_within_product_limit(client, admin_token, db, test_team, test_product): """Test adding a user when within product user limit""" # Add product to team - team_id = test_team.id - product_id = test_product.id team_product = DBTeamProduct( - team_id=team_id, - product_id=product_id + team_id=test_team.id, + product_id=test_product.id ) db.add(team_product) db.commit() - # Create a user to add - user = DBUser( - email="newuser@example.com", - hashed_password="hashed_password", - is_active=True, - is_admin=False, - role="user", - team_id=None, - created_at=datetime.now(UTC) - ) - db.add(user) - db.commit() - - # Add user to team - response = client.post( - f"/users/{user.id}/add-to-team", - headers={"Authorization": f"Bearer {admin_token}"}, - json={"team_id": team_id} - ) - assert response.status_code == 200 - assert response.json()["team_id"] == team_id + # Test that check_team_user_limit doesn't raise an exception + check_team_user_limit(db, test_team.id) def test_add_user_exceeding_product_limit(client, admin_token, db, test_team, test_product): """Test adding a user when it would exceed product user limit""" @@ -62,26 +43,11 @@ def test_add_user_exceeding_product_limit(client, admin_token, db, test_team, te db.add(user) db.commit() - # Try to add one more user - new_user = DBUser( - email="newuser@example.com", - hashed_password="hashed_password", - is_active=True, - is_admin=False, - role="user", - team_id=None, - created_at=datetime.now(UTC) - ) - db.add(new_user) - db.commit() - - response = client.post( - f"/users/{new_user.id}/add-to-team", - headers={"Authorization": f"Bearer {admin_token}"}, - json={"team_id": test_team.id} - ) - assert response.status_code == 400 - assert f"Team has reached the maximum user limit of {test_product.user_count} users" in response.json()["detail"] + # Test that check_team_user_limit raises an exception + with pytest.raises(HTTPException) as exc_info: + check_team_user_limit(db, test_team.id) + assert exc_info.value.status_code == 400 + assert f"Team has reached the maximum user limit of {test_product.user_count} users" in str(exc_info.value.detail) def test_add_user_with_default_limit(client, admin_token, db, test_team): """Test adding users with default limit when team has no products""" @@ -99,26 +65,11 @@ def test_add_user_with_default_limit(client, admin_token, db, test_team): db.add(user) db.commit() - # Try to add one more user - new_user = DBUser( - email="newuser@example.com", - hashed_password="hashed_password", - is_active=True, - is_admin=False, - role="user", - team_id=None, - created_at=datetime.now(UTC) - ) - db.add(new_user) - db.commit() - - response = client.post( - f"/users/{new_user.id}/add-to-team", - headers={"Authorization": f"Bearer {admin_token}"}, - json={"team_id": test_team.id} - ) - assert response.status_code == 400 - assert "Team has reached the default user limit of 2 users" in response.json()["detail"] + # Test that check_team_user_limit raises an exception + with pytest.raises(HTTPException) as exc_info: + check_team_user_limit(db, test_team.id) + assert exc_info.value.status_code == 400 + assert "Team has reached the maximum user limit of 2 users" in str(exc_info.value.detail) def test_add_user_with_multiple_products(client, admin_token, db, test_team): """Test adding users when team has multiple products with different limits""" @@ -184,23 +135,279 @@ def test_add_user_with_multiple_products(client, admin_token, db, test_team): db.add(user) db.commit() - # Try to add one more user - new_user = DBUser( - email="newuser@example.com", + # Test that check_team_user_limit raises an exception + with pytest.raises(HTTPException) as exc_info: + check_team_user_limit(db, test_team.id) + assert exc_info.value.status_code == 400 + assert f"Team has reached the maximum user limit of {product2.user_count} users" in str(exc_info.value.detail) + +def test_create_key_within_limits(client, admin_token, db, test_team, test_product, test_region): + """Test creating an LLM token when within product limits""" + # Add product to team + team_product = DBTeamProduct( + team_id=test_team.id, + product_id=test_product.id + ) + db.add(team_product) + db.commit() + + # Test that check_key_limits doesn't raise an exception + check_key_limits(db, test_team.id, None) + +def test_create_key_exceeding_total_limit(client, admin_token, db, test_team, test_product, test_region): + """Test creating an LLM token when it would exceed total token limit""" + # Add product to team + team_product = DBTeamProduct( + team_id=test_team.id, + product_id=test_product.id + ) + db.add(team_product) + db.commit() + + # Create LLM tokens up to the limit + for i in range(test_product.total_key_count): + key = DBPrivateAIKey( + name=f"Test Token {i}", + database_name=f"test_db_{i}", + database_host="localhost", + database_username="test_user", + database_password="test_pass", + litellm_token=f"test_token_{i}", # Add LLM token + owner_id=None, + team_id=test_team.id, + region_id=test_region.id, + created_at=datetime.now(UTC) + ) + db.add(key) + db.commit() + + # Test that check_key_limits raises an exception + with pytest.raises(HTTPException) as exc_info: + check_key_limits(db, test_team.id, None) + assert exc_info.value.status_code == 400 + assert f"Team has reached the maximum LLM token limit of {test_product.total_key_count} tokens" in str(exc_info.value.detail) + +def test_create_key_exceeding_user_limit(client, admin_token, db, test_team, test_product, test_region): + """Test creating an LLM token when it would exceed user token limit""" + # Add product to team + team_product = DBTeamProduct( + team_id=test_team.id, + product_id=test_product.id + ) + db.add(team_product) + db.commit() + + # Create a test user + user = DBUser( + email="testuser@example.com", hashed_password="hashed_password", is_active=True, is_admin=False, role="user", - team_id=None, + team_id=test_team.id, created_at=datetime.now(UTC) ) - db.add(new_user) + db.add(user) db.commit() - response = client.post( - f"/users/{new_user.id}/add-to-team", - headers={"Authorization": f"Bearer {admin_token}"}, - json={"team_id": test_team.id} + # Create LLM tokens up to the user limit + for i in range(test_product.keys_per_user): + key = DBPrivateAIKey( + name=f"Test Token {i}", + database_name=f"test_db_{i}", + database_host="localhost", + database_username="test_user", + database_password="test_pass", + litellm_token=f"test_token_{i}", # Add LLM token + owner_id=user.id, + team_id=None, + region_id=test_region.id, + created_at=datetime.now(UTC) + ) + db.add(key) + db.commit() + + # Test that check_key_limits raises an exception + with pytest.raises(HTTPException) as exc_info: + check_key_limits(db, test_team.id, user.id) + assert exc_info.value.status_code == 400 + assert f"User has reached the maximum LLM token limit of {test_product.keys_per_user} tokens" in str(exc_info.value.detail) + +def test_create_key_exceeding_service_key_limit(client, admin_token, db, test_team, test_product, test_region): + """Test creating an LLM token when it would exceed service token limit""" + # Add product to team + team_product = DBTeamProduct( + team_id=test_team.id, + product_id=test_product.id ) - assert response.status_code == 400 - assert f"Team has reached the maximum user limit of {product2.user_count} users" in response.json()["detail"] + db.add(team_product) + db.commit() + + # Create service LLM tokens up to the limit + for i in range(test_product.service_key_count): + key = DBPrivateAIKey( + name=f"Test Service Token {i}", + database_name=f"test_db_{i}", + database_host="localhost", + database_username="test_user", + database_password="test_pass", + litellm_token=f"test_token_{i}", # Add LLM token + owner_id=None, # Service tokens have no owner + team_id=test_team.id, + region_id=test_region.id, + created_at=datetime.now(UTC) + ) + db.add(key) + db.commit() + + # Test that check_key_limits raises an exception + with pytest.raises(HTTPException) as exc_info: + check_key_limits(db, test_team.id, None) + assert exc_info.value.status_code == 400 + assert f"Team has reached the maximum service LLM token limit of {test_product.service_key_count} tokens" in str(exc_info.value.detail) + +def test_create_key_with_default_limits(client, admin_token, db, test_team, test_region): + """Test creating LLM tokens with default limits when team has no products""" + # Create LLM tokens up to the default limit (2) + for i in range(2): + key = DBPrivateAIKey( + name=f"Test Token {i}", + database_name=f"test_db_{i}", + database_host="localhost", + database_username="test_user", + database_password="test_pass", + litellm_token=f"test_token_{i}", # Add LLM token + owner_id=None, + team_id=test_team.id, + region_id=test_region.id, + created_at=datetime.now(UTC) + ) + db.add(key) + db.commit() + + # Test that check_key_limits raises an exception + with pytest.raises(HTTPException) as exc_info: + check_key_limits(db, test_team.id, None) + assert exc_info.value.status_code == 400 + assert "Team has reached the maximum LLM token limit of 2 tokens" in str(exc_info.value.detail) + +def test_create_key_with_multiple_products(client, admin_token, db, test_team, test_region): + """Test creating LLM tokens when team has multiple products with different limits""" + # Create two products with different token limits + product1 = DBProduct( + id="prod_test1", + name="Test Product 1", + user_count=3, + keys_per_user=2, + total_key_count=3, + service_key_count=1, + max_budget_per_key=50.0, + rpm_per_key=1000, + vector_db_count=1, + vector_db_storage=100, + renewal_period_days=30, + active=True, + created_at=datetime.now(UTC) + ) + product2 = DBProduct( + id="prod_test2", + name="Test Product 2", + user_count=3, + keys_per_user=3, + total_key_count=5, + service_key_count=2, + max_budget_per_key=50.0, + rpm_per_key=1000, + vector_db_count=1, + vector_db_storage=100, + renewal_period_days=30, + active=True, + created_at=datetime.now(UTC) + ) + db.add(product1) + db.add(product2) + db.commit() + + # Add both products to team + team_product1 = DBTeamProduct( + team_id=test_team.id, + product_id=product1.id + ) + team_product2 = DBTeamProduct( + team_id=test_team.id, + product_id=product2.id + ) + db.add(team_product1) + db.add(team_product2) + db.commit() + + # Create LLM tokens up to the higher total token limit (5) + for i in range(5): + key = DBPrivateAIKey( + name=f"Test Token {i}", + database_name=f"test_db_{i}", + database_host="localhost", + database_username="test_user", + database_password="test_pass", + litellm_token=f"test_token_{i}", # Add LLM token + owner_id=None, + team_id=test_team.id, + region_id=test_region.id, + created_at=datetime.now(UTC) + ) + db.add(key) + db.commit() + + # Test that check_key_limits raises an exception + with pytest.raises(HTTPException) as exc_info: + check_key_limits(db, test_team.id, None) + assert exc_info.value.status_code == 400 + assert f"Team has reached the maximum LLM token limit of {product2.total_key_count} tokens" in str(exc_info.value.detail) + +def test_create_key_with_multiple_users_default_limits(db, test_team, test_region): + """Test creating a key when team has no products and multiple users have keys""" + # Create two users + user1 = DBUser( + email="user1@example.com", + hashed_password="hashed_password", + is_active=True, + is_admin=False, + role="user", + team_id=test_team.id, + created_at=datetime.now(UTC) + ) + user2 = DBUser( + email="user2@example.com", + hashed_password="hashed_password", + is_active=True, + is_admin=False, + role="user", + team_id=test_team.id, + created_at=datetime.now(UTC) + ) + db.add(user1) + db.add(user2) + db.commit() + + # Create one key for each user + for user in [user1, user2]: + key = DBPrivateAIKey( + name=f"Test Token for {user.email}", + database_name=f"test_db_{user.id}", + database_host="localhost", + database_username="test_user", + database_password="test_pass", + litellm_token=f"test_token_{user.id}", + owner_id=user.id, + team_id=None, # Keys with owner_id should not have team_id + region_id=test_region.id, + created_at=datetime.now(UTC) + ) + db.add(key) + db.commit() + + # Test that check_key_limits raises an exception when trying to create a team-owned key + with pytest.raises(HTTPException) as exc_info: + check_key_limits(db, test_team.id, None) + assert exc_info.value.status_code == 400 + assert "Team has reached the maximum LLM token limit of 2 tokens" in str(exc_info.value.detail) From 383084cda6b0627e02d7ad0a295c8631d9ea854d Mon Sep 17 00:00:00 2001 From: Pippa H Date: Tue, 20 May 2025 15:08:13 +0200 Subject: [PATCH 10/23] Add vector DB limit checks and enhance resource limits validation - Introduced a new `check_vector_db_limits` function to validate vector DB creation against team limits. - Updated resource limits tests to include scenarios for creating vector DBs, including exceeding limits and handling user-owned keys. - Refactored existing tests for clarity and to ensure comprehensive coverage of resource limits functionality. --- app/core/resource_limits.py | 42 +++++++- app/core/worker.py | 5 +- tests/test_billing.py | 76 ++++++++------ tests/test_resource_limits.py | 186 +++++++++++++++++++++++++++++++--- 4 files changed, 261 insertions(+), 48 deletions(-) diff --git a/app/core/resource_limits.py b/app/core/resource_limits.py index 4109b05..8f0c3e1 100644 --- a/app/core/resource_limits.py +++ b/app/core/resource_limits.py @@ -1,5 +1,5 @@ from sqlalchemy.orm import Session -from app.db.models import DBTeam, DBUser, DBProduct, DBPrivateAIKey +from app.db.models import DBTeam, DBUser, DBPrivateAIKey from fastapi import HTTPException, status from typing import Optional @@ -107,3 +107,43 @@ def check_key_limits(db: Session, team_id: int, owner_id: Optional[int] = None) status_code=status.HTTP_400_BAD_REQUEST, detail=f"Team has reached the maximum service LLM token limit of {max_service_keys} tokens" ) + +def check_vector_db_limits(db: Session, team_id: int) -> None: + """ + Check if creating a new vector DB would exceed the team's vector DB limits. + Raises HTTPException if the limit would be exceeded. + + Args: + db: Database session + team_id: ID of the team to check + """ + # Get the team and its active products + team = db.query(DBTeam).filter(DBTeam.id == team_id).first() + if not team: + raise HTTPException(status_code=404, detail="Team not found") + + # Find the maximum vector DB count across all active products + max_vector_db_count = max( + (product.vector_db_count for team_product in team.active_products + for product in [team_product.product] if product.vector_db_count), + default=1 # Default to 1 if no products have vector_db_count set + ) + + # Get all users in the team + team_users = db.query(DBUser).filter(DBUser.team_id == team_id).all() + user_ids = [user.id for user in team_users] + + # Get current vector DB count for the team (both team-owned and user-owned) + current_vector_db_count = db.query(DBPrivateAIKey).filter( + ( + (DBPrivateAIKey.team_id == team_id) | # Team-owned vector DBs + (DBPrivateAIKey.owner_id.in_(user_ids)) # User-owned vector DBs + ), + DBPrivateAIKey.database_name.isnot(None) # Only count keys with database_name set + ).count() + + if current_vector_db_count >= max_vector_db_count: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Team has reached the maximum vector DB limit of {max_vector_db_count} databases" + ) diff --git a/app/core/worker.py b/app/core/worker.py index 79f9cb4..2527878 100644 --- a/app/core/worker.py +++ b/app/core/worker.py @@ -1,8 +1,7 @@ -from datetime import datetime, timedelta, UTC +from datetime import datetime, UTC from sqlalchemy.orm import Session -from app.db.models import DBTeam, DBProduct, DBTeamProduct, DBPrivateAIKey, DBRegion, DBUser +from app.db.models import DBTeam, DBProduct, DBTeamProduct, DBPrivateAIKey, DBUser from app.services.litellm import LiteLLMService -from app.core.config import settings import logging from collections import defaultdict diff --git a/tests/test_billing.py b/tests/test_billing.py index e9a6ab1..85a0ba6 100644 --- a/tests/test_billing.py +++ b/tests/test_billing.py @@ -1,5 +1,5 @@ import pytest -from unittest.mock import Mock, patch +from unittest.mock import Mock, patch, AsyncMock from app.api.billing import handle_stripe_event_background @pytest.mark.asyncio @@ -10,23 +10,46 @@ async def test_handle_checkout_session_completed(db, test_team): mock_session = Mock() mock_session.metadata = {"team_id": str(test_team.id)} mock_session.customer = "cus_123" + mock_session.id = "cs_123" mock_event.data.object = mock_session - # Act - await handle_stripe_event_background(mock_event, db) + with patch('app.api.billing.get_product_id_from_session', new_callable=AsyncMock) as mock_get_product: + mock_get_product.return_value = "prod_123" + with patch('app.api.billing.apply_product_for_team', new_callable=AsyncMock) as mock_apply: + # Act + await handle_stripe_event_background(mock_event, db) - # Assert - db.refresh(test_team) - assert test_team.stripe_customer_id == "cus_123" + # Assert + mock_get_product.assert_called_once_with("cs_123") + mock_apply.assert_called_once_with(db, "cus_123", "prod_123") @pytest.mark.asyncio -async def test_handle_subscription_deleted(db, test_team): +async def test_handle_invoice_payment_succeeded(db, test_team): # Arrange - # First set up the team with a stripe customer ID - test_team.stripe_customer_id = "cus_123" - db.commit() - db.refresh(test_team) + mock_event = Mock() + mock_event.type = "invoice.payment_succeeded" + mock_invoice = Mock() + mock_invoice.customer = "cus_123" + mock_subscription = Mock() + mock_subscription.id = "sub_123" + mock_invoice.parent = Mock() + mock_invoice.parent.subscription_details = Mock() + mock_invoice.parent.subscription_details.subscription = "sub_123" + mock_event.data.object = mock_invoice + + with patch('app.api.billing.get_product_id_from_sub', new_callable=AsyncMock) as mock_get_product: + mock_get_product.return_value = "prod_123" + with patch('app.api.billing.apply_product_for_team', new_callable=AsyncMock) as mock_apply: + # Act + await handle_stripe_event_background(mock_event, db) + + # Assert + mock_get_product.assert_called_once_with("sub_123") + mock_apply.assert_called_once_with(db, "cus_123", "prod_123") +@pytest.mark.asyncio +async def test_handle_subscription_deleted(db, test_team): + # Arrange mock_event = Mock() mock_event.type = "customer.subscription.deleted" mock_subscription = Mock() @@ -37,8 +60,8 @@ async def test_handle_subscription_deleted(db, test_team): await handle_stripe_event_background(mock_event, db) # Assert - db.refresh(test_team) - assert test_team.stripe_customer_id is None + # No assertions needed as we're just verifying no error occurs + # The function now only logs the event @pytest.mark.asyncio async def test_handle_checkout_session_completed_team_not_found(db): @@ -47,26 +70,19 @@ async def test_handle_checkout_session_completed_team_not_found(db): mock_event.type = "checkout.session.completed" mock_session = Mock() mock_session.metadata = {"team_id": "999"} # Non-existent team ID + mock_session.customer = "cus_123" + mock_session.id = "cs_123" mock_event.data.object = mock_session - # Act - await handle_stripe_event_background(mock_event, db) - - # No assertion needed as we're just verifying no error occurs - -@pytest.mark.asyncio -async def test_handle_subscription_deleted_team_not_found(db): - # Arrange - mock_event = Mock() - mock_event.type = "customer.subscription.deleted" - mock_subscription = Mock() - mock_subscription.customer = "cus_999" # Non-existent customer ID - mock_event.data.object = mock_subscription - - # Act - await handle_stripe_event_background(mock_event, db) + with patch('app.api.billing.get_product_id_from_session', new_callable=AsyncMock) as mock_get_product: + mock_get_product.return_value = "prod_123" + with patch('app.api.billing.apply_product_for_team', new_callable=AsyncMock) as mock_apply: + # Act + await handle_stripe_event_background(mock_event, db) - # No assertion needed as we're just verifying no error occurs + # Assert + mock_get_product.assert_called_once_with("cs_123") + mock_apply.assert_called_once_with(db, "cus_123", "prod_123") @pytest.mark.asyncio async def test_handle_unknown_event_type(db): diff --git a/tests/test_resource_limits.py b/tests/test_resource_limits.py index 0ee1412..611c7c8 100644 --- a/tests/test_resource_limits.py +++ b/tests/test_resource_limits.py @@ -1,12 +1,10 @@ import pytest -from fastapi.testclient import TestClient -from sqlalchemy.orm import Session -from app.db.models import DBUser, DBTeam, DBProduct, DBTeamProduct, DBPrivateAIKey +from app.db.models import DBUser, DBProduct, DBTeamProduct, DBPrivateAIKey from datetime import datetime, UTC from fastapi import HTTPException -from app.core.resource_limits import check_key_limits, check_team_user_limit +from app.core.resource_limits import check_key_limits, check_team_user_limit, check_vector_db_limits -def test_add_user_within_product_limit(client, admin_token, db, test_team, test_product): +def test_add_user_within_product_limit(db, test_team, test_product): """Test adding a user when within product user limit""" # Add product to team team_product = DBTeamProduct( @@ -19,7 +17,7 @@ def test_add_user_within_product_limit(client, admin_token, db, test_team, test_ # Test that check_team_user_limit doesn't raise an exception check_team_user_limit(db, test_team.id) -def test_add_user_exceeding_product_limit(client, admin_token, db, test_team, test_product): +def test_add_user_exceeding_product_limit(db, test_team, test_product): """Test adding a user when it would exceed product user limit""" # Add product to team team_product = DBTeamProduct( @@ -49,7 +47,7 @@ def test_add_user_exceeding_product_limit(client, admin_token, db, test_team, te assert exc_info.value.status_code == 400 assert f"Team has reached the maximum user limit of {test_product.user_count} users" in str(exc_info.value.detail) -def test_add_user_with_default_limit(client, admin_token, db, test_team): +def test_add_user_with_default_limit(db, test_team): """Test adding users with default limit when team has no products""" # Create and add users up to the default limit (2) for i in range(2): @@ -71,7 +69,7 @@ def test_add_user_with_default_limit(client, admin_token, db, test_team): assert exc_info.value.status_code == 400 assert "Team has reached the maximum user limit of 2 users" in str(exc_info.value.detail) -def test_add_user_with_multiple_products(client, admin_token, db, test_team): +def test_add_user_with_multiple_products(db, test_team): """Test adding users when team has multiple products with different limits""" # Create two products with different user limits product1 = DBProduct( @@ -141,7 +139,7 @@ def test_add_user_with_multiple_products(client, admin_token, db, test_team): assert exc_info.value.status_code == 400 assert f"Team has reached the maximum user limit of {product2.user_count} users" in str(exc_info.value.detail) -def test_create_key_within_limits(client, admin_token, db, test_team, test_product, test_region): +def test_create_key_within_limits(db, test_team, test_product, test_region): """Test creating an LLM token when within product limits""" # Add product to team team_product = DBTeamProduct( @@ -154,7 +152,7 @@ def test_create_key_within_limits(client, admin_token, db, test_team, test_produ # Test that check_key_limits doesn't raise an exception check_key_limits(db, test_team.id, None) -def test_create_key_exceeding_total_limit(client, admin_token, db, test_team, test_product, test_region): +def test_create_key_exceeding_total_limit(db, test_team, test_product, test_region): """Test creating an LLM token when it would exceed total token limit""" # Add product to team team_product = DBTeamProduct( @@ -187,7 +185,7 @@ def test_create_key_exceeding_total_limit(client, admin_token, db, test_team, te assert exc_info.value.status_code == 400 assert f"Team has reached the maximum LLM token limit of {test_product.total_key_count} tokens" in str(exc_info.value.detail) -def test_create_key_exceeding_user_limit(client, admin_token, db, test_team, test_product, test_region): +def test_create_key_exceeding_user_limit(db, test_team, test_product, test_region): """Test creating an LLM token when it would exceed user token limit""" # Add product to team team_product = DBTeamProduct( @@ -233,7 +231,7 @@ def test_create_key_exceeding_user_limit(client, admin_token, db, test_team, tes assert exc_info.value.status_code == 400 assert f"User has reached the maximum LLM token limit of {test_product.keys_per_user} tokens" in str(exc_info.value.detail) -def test_create_key_exceeding_service_key_limit(client, admin_token, db, test_team, test_product, test_region): +def test_create_key_exceeding_service_key_limit(db, test_team, test_product, test_region): """Test creating an LLM token when it would exceed service token limit""" # Add product to team team_product = DBTeamProduct( @@ -266,7 +264,7 @@ def test_create_key_exceeding_service_key_limit(client, admin_token, db, test_te assert exc_info.value.status_code == 400 assert f"Team has reached the maximum service LLM token limit of {test_product.service_key_count} tokens" in str(exc_info.value.detail) -def test_create_key_with_default_limits(client, admin_token, db, test_team, test_region): +def test_create_key_with_default_limits(db, test_team, test_region): """Test creating LLM tokens with default limits when team has no products""" # Create LLM tokens up to the default limit (2) for i in range(2): @@ -291,7 +289,7 @@ def test_create_key_with_default_limits(client, admin_token, db, test_team, test assert exc_info.value.status_code == 400 assert "Team has reached the maximum LLM token limit of 2 tokens" in str(exc_info.value.detail) -def test_create_key_with_multiple_products(client, admin_token, db, test_team, test_region): +def test_create_key_with_multiple_products(db, test_team, test_region): """Test creating LLM tokens when team has multiple products with different limits""" # Create two products with different token limits product1 = DBProduct( @@ -411,3 +409,163 @@ def test_create_key_with_multiple_users_default_limits(db, test_team, test_regio check_key_limits(db, test_team.id, None) assert exc_info.value.status_code == 400 assert "Team has reached the maximum LLM token limit of 2 tokens" in str(exc_info.value.detail) + +def test_create_vector_db_within_limits(db, test_team, test_product): + """Test creating a vector DB when within product limits""" + # Add product to team + team_product = DBTeamProduct( + team_id=test_team.id, + product_id=test_product.id + ) + db.add(team_product) + db.commit() + + # Test that check_vector_db_limits doesn't raise an exception + check_vector_db_limits(db, test_team.id) + +def test_create_vector_db_exceeding_limit(db, test_team, test_product, test_region): + """Test creating a vector DB when it would exceed the limit""" + # Add product to team + team_product = DBTeamProduct( + team_id=test_team.id, + product_id=test_product.id + ) + db.add(team_product) + db.commit() + + # Create vector DBs up to the limit + for i in range(test_product.vector_db_count): + key = DBPrivateAIKey( + name=f"Test Vector DB {i}", + database_name=f"test_db_{i}", + database_host="localhost", + database_username="test_user", + database_password="test_pass", + team_id=test_team.id, + region_id=test_region.id, + created_at=datetime.now(UTC) + ) + db.add(key) + db.commit() + + # Test that check_vector_db_limits raises an exception + with pytest.raises(HTTPException) as exc_info: + check_vector_db_limits(db, test_team.id) + assert exc_info.value.status_code == 400 + assert f"Team has reached the maximum vector DB limit of {test_product.vector_db_count} databases" in str(exc_info.value.detail) + +def test_create_vector_db_with_default_limit(db, test_team, test_region): + """Test creating vector DBs with default limit when team has no products""" + # Create a vector DB + key = DBPrivateAIKey( + name="Test Vector DB", + database_name="test_db", + database_host="localhost", + database_username="test_user", + database_password="test_pass", + team_id=test_team.id, + region_id=test_region.id, + created_at=datetime.now(UTC) + ) + db.add(key) + db.commit() + + # Test that check_vector_db_limits raises an exception + with pytest.raises(HTTPException) as exc_info: + check_vector_db_limits(db, test_team.id) + assert exc_info.value.status_code == 400 + assert "Team has reached the maximum vector DB limit of 1 databases" in str(exc_info.value.detail) + +def test_create_vector_db_with_multiple_products(db, test_team, test_region): + """Test creating vector DBs when team has multiple products with different limits""" + # Create two products with different vector DB limits + product1 = DBProduct( + id="prod_test1", + name="Test Product 1", + user_count=3, + keys_per_user=2, + total_key_count=10, + service_key_count=2, + max_budget_per_key=50.0, + rpm_per_key=1000, + vector_db_count=2, + vector_db_storage=100, + renewal_period_days=30, + active=True, + created_at=datetime.now(UTC) + ) + product2 = DBProduct( + id="prod_test2", + name="Test Product 2", + user_count=3, + keys_per_user=2, + total_key_count=10, + service_key_count=2, + max_budget_per_key=50.0, + rpm_per_key=1000, + vector_db_count=3, + vector_db_storage=100, + renewal_period_days=30, + active=True, + created_at=datetime.now(UTC) + ) + db.add(product1) + db.add(product2) + db.commit() + + # Add both products to team + team_product1 = DBTeamProduct( + team_id=test_team.id, + product_id=product1.id + ) + team_product2 = DBTeamProduct( + team_id=test_team.id, + product_id=product2.id + ) + db.add(team_product1) + db.add(team_product2) + db.commit() + + # Create vector DBs up to the higher limit (3) + for i in range(3): + key = DBPrivateAIKey( + name=f"Test Vector DB {i}", + database_name=f"test_db_{i}", + database_host="localhost", + database_username="test_user", + database_password="test_pass", + team_id=test_team.id, + region_id=test_region.id, + created_at=datetime.now(UTC) + ) + db.add(key) + db.commit() + + # Test that check_vector_db_limits raises an exception + with pytest.raises(HTTPException) as exc_info: + check_vector_db_limits(db, test_team.id) + assert exc_info.value.status_code == 400 + assert f"Team has reached the maximum vector DB limit of {product2.vector_db_count} databases" in str(exc_info.value.detail) + +def test_create_vector_db_with_user_owned_key(db, test_team, test_region, test_team_user): + """Test vector DB limit check when a user-owned key has a vector DB and team has no products""" + # Create a user-owned key with a vector DB + key = DBPrivateAIKey( + name="Test User Vector DB", + database_name="test_user_db", + database_host="localhost", + database_username="test_user", + database_password="test_pass", + owner_id=test_team_user.id, + team_id=None, # User-owned keys should not have team_id + region_id=test_region.id, + created_at=datetime.now(UTC) + ) + db.add(key) + db.commit() + + # Test that check_vector_db_limits raises an exception + with pytest.raises(HTTPException) as exc_info: + check_vector_db_limits(db, test_team.id) + assert exc_info.value.status_code == 400 + assert "Team has reached the maximum vector DB limit of 1 databases" in str(exc_info.value.detail) From 7b0a776fdf193340d4b8223498603aa0e8e80cbc Mon Sep 17 00:00:00 2001 From: Pippa H Date: Tue, 20 May 2025 19:02:02 +0200 Subject: [PATCH 11/23] Refactor billing API to enhance Stripe customer portal functionality - Updated the `get_portal` endpoint to create a Stripe customer if one does not exist for the team. - Refactored the `create_portal_session` function to accept a Stripe customer ID and return URL. - Improved error handling for team not found scenarios in the billing API. - Added comprehensive tests for the `get_portal` functionality, covering both existing and new customer cases. --- app/api/billing.py | 45 ++++----- app/services/stripe.py | 16 +--- tests/test_billing.py | 209 ++++++++++++++++++++++++++++++++++++++++- 3 files changed, 236 insertions(+), 34 deletions(-) diff --git a/app/api/billing.py b/app/api/billing.py index 19d05d8..c90c5ae 100644 --- a/app/api/billing.py +++ b/app/api/billing.py @@ -11,7 +11,8 @@ get_stripe_event, create_portal_session, get_product_id_from_sub, - get_product_id_from_session + get_product_id_from_session, + create_stripe_customer ) from app.core.worker import apply_product_for_team @@ -28,8 +29,6 @@ async def checkout( team_id: int, request_data: CheckoutSessionCreate, - request: Request, - current_user: DBUser = Depends(get_current_user_from_auth), db: Session = Depends(get_db) ): """ @@ -42,15 +41,15 @@ async def checkout( Returns: redirect to the checkout session """ - try: - # Get the team - team = db.query(DBTeam).filter(DBTeam.id == team_id).first() - if not team: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Team not found" - ) + # Get the team + team = db.query(DBTeam).filter(DBTeam.id == team_id).first() + if not team: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Team not found" + ) + try: # Get the frontend URL from environment frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000") @@ -150,12 +149,11 @@ async def handle_events( @router.post("/teams/{team_id}/portal", dependencies=[Depends(check_specific_team_admin)]) async def get_portal( team_id: int, - request: Request, - current_user: DBUser = Depends(get_current_user_from_auth), db: Session = Depends(get_db) ): """ Create a Stripe Customer Portal session for team subscription management and redirect to it. + If the team doesn't have a Stripe customer ID, one will be created first. Args: team_id: The ID of the team to create the portal session for @@ -163,20 +161,25 @@ async def get_portal( Returns: Redirects to the Stripe Customer Portal URL """ + # Get the team + team = db.query(DBTeam).filter(DBTeam.id == team_id).first() + if not team: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Team not found" + ) + try: - # Get the team - team = db.query(DBTeam).filter(DBTeam.id == team_id).first() - if not team: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Team not found" - ) + # Create Stripe customer if one doesn't exist + if not team.stripe_customer_id: + team.stripe_customer_id = await create_stripe_customer(team, db) # Get the frontend URL from environment frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000") + return_url = f"{frontend_url}/teams/{team.id}/dashboard" # Create portal session using the service - portal_url = await create_portal_session(team, frontend_url) + portal_url = await create_portal_session(team.stripe_customer_id, return_url) return Response( status_code=status.HTTP_303_SEE_OTHER, diff --git a/app/services/stripe.py b/app/services/stripe.py index 1782d12..ab3165d 100644 --- a/app/services/stripe.py +++ b/app/services/stripe.py @@ -108,30 +108,24 @@ def get_stripe_event( payload: bytes, signature: str, webhook_secret: str) -> st ) async def create_portal_session( - team: DBTeam, - frontend_url: str + stripe_customer_id: str, + return_url: str ) -> str: """ Create a Stripe Customer Portal session for team subscription management. Args: - team: The team to create the portal session for + stripe_customer_id: The Stripe customer ID to create the portal session for frontend_url: The frontend URL for return redirect Returns: str: The portal session URL """ try: - if not team.stripe_customer_id: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Team has no active Stripe subscription" - ) - # Create the portal session portal_session = stripe.billing_portal.Session.create( - customer=team.stripe_customer_id, - return_url=f"{frontend_url}/teams/{team.id}/dashboard" + customer=stripe_customer_id, + return_url=return_url ) return portal_session.url diff --git a/tests/test_billing.py b/tests/test_billing.py index 85a0ba6..5679df6 100644 --- a/tests/test_billing.py +++ b/tests/test_billing.py @@ -1,6 +1,8 @@ import pytest from unittest.mock import Mock, patch, AsyncMock -from app.api.billing import handle_stripe_event_background +from fastapi import HTTPException +from app.api.billing import handle_stripe_event_background, get_portal +from app.db.models import DBTeam @pytest.mark.asyncio async def test_handle_checkout_session_completed(db, test_team): @@ -93,4 +95,207 @@ async def test_handle_unknown_event_type(db): # Act await handle_stripe_event_background(mock_event, db) - # No assertion needed as we're just verifying no error occurs \ No newline at end of file + # No assertion needed as we're just verifying no error occurs + +@patch('app.api.billing.create_portal_session', new_callable=AsyncMock) +def test_get_portal_existing_customer(mock_create_portal, client, db, test_team, team_admin_token): + # Arrange + test_team.stripe_customer_id = "cus_123" + db.add(test_team) + db.commit() + + mock_portal_url = "https://billing.stripe.com/portal/123" + mock_create_portal.return_value = mock_portal_url + + # Act + response = client.post( + f"/billing/teams/{test_team.id}/portal", + headers={"Authorization": f"Bearer {team_admin_token}"}, + follow_redirects=False + ) + + # Assert + assert response.status_code == 303 + assert response.headers["location"] == mock_portal_url + mock_create_portal.assert_called_once_with( + "cus_123", + f"http://localhost:3000/teams/{test_team.id}/dashboard" + ) + +@patch('app.api.billing.create_portal_session', new_callable=AsyncMock) +@patch('app.api.billing.create_stripe_customer', new_callable=AsyncMock) +def test_get_portal_create_customer(mock_create_customer, mock_create_portal, client, db, test_team, team_admin_token): + # Arrange + test_team.stripe_customer_id = None + db.add(test_team) + db.commit() + + mock_portal_url = "https://billing.stripe.com/portal/123" + mock_create_customer.return_value = "cus_new_123" + mock_create_portal.return_value = mock_portal_url + + # Act + response = client.post( + f"/billing/teams/{test_team.id}/portal", + headers={"Authorization": f"Bearer {team_admin_token}"}, + follow_redirects=False + ) + + # Assert + assert response.status_code == 303 + assert response.headers["location"] == mock_portal_url + mock_create_customer.assert_called_once() + mock_create_portal.assert_called_once_with( + "cus_new_123", + f"http://localhost:3000/teams/{test_team.id}/dashboard" + ) + +def test_get_portal_team_not_found(client, db, admin_token): + # Arrange + non_existent_team_id = 999 + + # Act + response = client.post( + f"/billing/teams/{non_existent_team_id}/portal", + headers={"Authorization": f"Bearer {admin_token}"}, + follow_redirects=False + ) + + # Assert + assert response.status_code == 404 + assert response.json()["detail"] == "Team not found" + +@patch('app.api.billing.create_portal_session', new_callable=AsyncMock) +def test_get_portal_as_system_admin(mock_create_portal, client, db, test_team, admin_token): + # Arrange + test_team.stripe_customer_id = "cus_123" + db.add(test_team) + db.commit() + + mock_portal_url = "https://billing.stripe.com/portal/123" + mock_create_portal.return_value = mock_portal_url + + # Act + response = client.post( + f"/billing/teams/{test_team.id}/portal", + headers={"Authorization": f"Bearer {admin_token}"}, + follow_redirects=False + ) + + # Assert + assert response.status_code == 303 + assert response.headers["location"] == mock_portal_url + mock_create_portal.assert_called_once_with( + "cus_123", + f"http://localhost:3000/teams/{test_team.id}/dashboard" + ) + +@patch('app.api.billing.create_portal_session', new_callable=AsyncMock) +def test_get_portal_stripe_error(mock_create_portal, client, db, test_team, team_admin_token): + # Arrange + test_team.stripe_customer_id = "cus_123" + db.add(test_team) + db.commit() + + mock_create_portal.side_effect = Exception("Stripe API error") + + # Act + response = client.post( + f"/billing/teams/{test_team.id}/portal", + headers={"Authorization": f"Bearer {team_admin_token}"}, + follow_redirects=False + ) + + # Assert + assert response.status_code == 500 + assert response.json()["detail"] == "Error creating portal session" + +@patch('app.api.billing.create_checkout_session', new_callable=AsyncMock) +def test_checkout_existing_customer(mock_create_checkout, client, db, test_team, team_admin_token): + # Arrange + test_team.stripe_customer_id = "cus_123" + db.add(test_team) + db.commit() + + mock_checkout_url = "https://checkout.stripe.com/123" + mock_create_checkout.return_value = mock_checkout_url + + # Act + response = client.post( + f"/billing/teams/{test_team.id}/checkout", + headers={"Authorization": f"Bearer {team_admin_token}"}, + json={"price_lookup_token": "price_123"}, + follow_redirects=False + ) + + # Assert + assert response.status_code == 303 + assert response.headers["location"] == mock_checkout_url + mock_create_checkout.assert_called_once_with( + test_team, + "price_123", + "http://localhost:3000" + ) + +@patch('app.api.billing.create_checkout_session', new_callable=AsyncMock) +def test_checkout_as_system_admin(mock_create_checkout, client, db, test_team, admin_token): + # Arrange + test_team.stripe_customer_id = "cus_123" + db.add(test_team) + db.commit() + + mock_checkout_url = "https://checkout.stripe.com/123" + mock_create_checkout.return_value = mock_checkout_url + + # Act + response = client.post( + f"/billing/teams/{test_team.id}/checkout", + headers={"Authorization": f"Bearer {admin_token}"}, + json={"price_lookup_token": "price_123"}, + follow_redirects=False + ) + + # Assert + assert response.status_code == 303 + assert response.headers["location"] == mock_checkout_url + mock_create_checkout.assert_called_once_with( + test_team, + "price_123", + "http://localhost:3000" + ) + +def test_checkout_team_not_found(client, db, admin_token): + # Arrange + non_existent_team_id = 999 + + # Act + response = client.post( + f"/billing/teams/{non_existent_team_id}/checkout", + headers={"Authorization": f"Bearer {admin_token}"}, + json={"price_lookup_token": "price_123"}, + follow_redirects=False + ) + + # Assert + assert response.status_code == 404 + assert response.json()["detail"] == "Team not found" + +@patch('app.api.billing.create_checkout_session', new_callable=AsyncMock) +def test_checkout_stripe_error(mock_create_checkout, client, db, test_team, team_admin_token): + # Arrange + test_team.stripe_customer_id = "cus_123" + db.add(test_team) + db.commit() + + mock_create_checkout.side_effect = Exception("Stripe API error") + + # Act + response = client.post( + f"/billing/teams/{test_team.id}/checkout", + headers={"Authorization": f"Bearer {team_admin_token}"}, + json={"price_lookup_token": "price_123"}, + follow_redirects=False + ) + + # Assert + assert response.status_code == 500 \ No newline at end of file From 0129d587fdc417a043a9e766d5a8b8ee0f1e2bb3 Mon Sep 17 00:00:00 2001 From: Pippa H Date: Wed, 21 May 2025 11:15:31 +0200 Subject: [PATCH 12/23] Enhance Stripe event handling and product management in billing API - Refactored the billing API to improve handling of various Stripe webhook events, including success and failure scenarios for checkout sessions and subscriptions. - Updated the event processing logic to utilize a unified list of known events, enhancing clarity and maintainability. - Introduced a new `remove_product_from_team` function to manage product removals based on event types. - Renamed functions for consistency, such as `get_product_id_from_sub` to `get_product_id_from_subscription`. - Added comprehensive tests for new event handling scenarios, ensuring robust coverage for product application and removal processes. --- app/api/billing.py | 57 +++++--- app/core/worker.py | 40 +++++- app/services/stripe.py | 19 +-- tests/test_billing.py | 293 ++++++++++++++++++++++++++++++++++++----- tests/test_worker.py | 151 ++++++++++++++++++--- 5 files changed, 476 insertions(+), 84 deletions(-) diff --git a/app/api/billing.py b/app/api/billing.py index c90c5ae..41ffbf3 100644 --- a/app/api/billing.py +++ b/app/api/billing.py @@ -3,18 +3,18 @@ import logging import os from app.db.database import get_db -from app.core.security import get_current_user_from_auth, check_specific_team_admin -from app.db.models import DBUser, DBTeam, DBSystemSecret +from app.core.security import check_specific_team_admin +from app.db.models import DBTeam, DBSystemSecret from app.schemas.models import CheckoutSessionCreate from app.services.stripe import ( create_checkout_session, - get_stripe_event, + decode_stripe_event, create_portal_session, - get_product_id_from_sub, + get_product_id_from_subscription, get_product_id_from_session, create_stripe_customer ) -from app.core.worker import apply_product_for_team +from app.core.worker import apply_product_for_team, remove_product_from_team # Configure logger logger = logging.getLogger(__name__) @@ -72,29 +72,48 @@ async def handle_stripe_event_background(event, db: Session): Background task to handle Stripe webhook events. This runs in a separate thread to avoid blocking the webhook response. """ - checkout_session_success = ["checkout.session.async_payment_succeeded", "checkout.session.completed"] - invoice_success = ["invoice.payment_succeeded"] - failure_events = ["checkout.session.async_payment_failed", "checkout.session.expired", "subscription.payment_failed", "customer.subscription.deleted"] + # Full list of possible events: https://docs.stripe.com/api/events/types + session_success_events = ["checkout.session.async_payment_succeeded", "checkout.session.completed"] + invoice_success_events = ["invoice.payment_succeeded"] + subscription_success_events = ["customer.subscription.created", "invoice.payment_succeeded"] + session_failure_events = ["checkout.session.async_payment_failed", "checkout.session.expired"] + subscription_failure_events = ["subscription.payment_failed", "customer.subscription.deleted", "customer.subscription.paused"] + invoice_failure_events = ["invoice.payment_failed"] + + success_events = session_success_events + invoice_success_events + subscription_success_events + failure_events = session_failure_events + subscription_failure_events + invoice_failure_events + known_events = success_events + failure_events try: event_type = event.type - if not event_type in checkout_session_success + invoice_success + failure_events: + if not event_type in known_events: logger.info(f"Unknown event type: {event_type}") return event_object = event.data.object customer_id = event_object.customer - if event_type in invoice_success: + # Success Events + if event_type in invoice_success_events: + # We assume that the invoice is related to a subscription subscription = event_object.parent.subscription_details.subscription - product_id = await get_product_id_from_sub(subscription) + product_id = await get_product_id_from_subscription(subscription) await apply_product_for_team(db, customer_id, product_id) - elif event_type in checkout_session_success: + elif event_type in subscription_success_events: + product_id = await get_product_id_from_subscription(event_object.id) + await apply_product_for_team(db, customer_id, product_id) + elif event_type in session_success_events: product_id = await get_product_id_from_session(event_object.id) await apply_product_for_team(db, customer_id, product_id) - elif event_type in failure_events: - logger.info(f"Checkout session failed") - event_object = event.data.object - logger.info(f"ID: {event_object.id}") - logger.info(f"Full object: {event_object}") - # Handle failed checkout + # Failure Events + elif event_type in session_failure_events: + product_id = await get_product_id_from_session(event_object.id) + await remove_product_from_team(db, customer_id, product_id) + elif event_type in subscription_failure_events: + product_id = await get_product_id_from_subscription(event_object.id) + await remove_product_from_team(db, customer_id, product_id) + elif event_type in invoice_failure_events: + # We assume that the invoice is related to a subscription + subscription = event_object.parent.subscription_details.subscription + product_id = await get_product_id_from_subscription(subscription) + await remove_product_from_team(db, customer_id, product_id) except Exception as e: logger.error(f"Error in background event handler: {str(e)}") @@ -129,7 +148,7 @@ async def handle_events( payload = await request.body() signature = request.headers.get("stripe-signature") - event = get_stripe_event(payload, signature, webhook_secret) + event = decode_stripe_event(payload, signature, webhook_secret) # Add the event handling to background tasks background_tasks.add_task(handle_stripe_event_background, event, db) diff --git a/app/core/worker.py b/app/core/worker.py index 2527878..7144c53 100644 --- a/app/core/worker.py +++ b/app/core/worker.py @@ -7,7 +7,7 @@ logger = logging.getLogger(__name__) -async def apply_product_for_team(db: Session, customer_id: str, product_id: str) -> bool: +async def apply_product_for_team(db: Session, customer_id: str, product_id: str): """ Apply a product to a team and update their last payment date. Also extends all team keys and sets their max budgets via LiteLLM service. @@ -20,6 +20,7 @@ async def apply_product_for_team(db: Session, customer_id: str, product_id: str) Returns: bool: True if update was successful, False otherwise """ + logger.info(f"Applying product {product_id} to team {customer_id}") try: # Find the team and product team = db.query(DBTeam).filter(DBTeam.stripe_customer_id == customer_id).first() @@ -27,10 +28,10 @@ async def apply_product_for_team(db: Session, customer_id: str, product_id: str) if not team: logger.error(f"Team not found for customer ID: {customer_id}") - return False + return if not product: logger.error(f"Product not found for ID: {product_id}") - return False + return # Update the last payment date team.last_payment = datetime.now(UTC) @@ -98,9 +99,40 @@ async def apply_product_for_team(db: Session, customer_id: str, product_id: str) continue db.commit() - return True except Exception as e: db.rollback() logger.error(f"Error applying product to team: {str(e)}") raise e + +async def remove_product_from_team(db: Session, customer_id: str, product_id: str): + logger.info(f"Removing product {product_id} from team {customer_id}") + try: + # Find the team and product + team = db.query(DBTeam).filter(DBTeam.stripe_customer_id == customer_id).first() + product = db.query(DBProduct).filter(DBProduct.id == product_id).first() + + if not team: + logger.error(f"Team not found for customer ID: {customer_id}") + return + if not product: + logger.error(f"Product not found for ID: {product_id}") + return + # Check if the product is already active for the team + existing_association = db.query(DBTeamProduct).filter( + DBTeamProduct.team_id == team.id, + DBTeamProduct.product_id == product.id + ).first() + if not existing_association: + logger.error(f"Product {product_id} not found for team {customer_id}") + return + # Remove the product association + db.delete(existing_association) + + # TODO: Send notification + # TODO: Expire keys if applicable + db.commit() + except Exception as e: + db.rollback() + logger.error(f"Error removing product from team: {str(e)}") + raise e diff --git a/app/services/stripe.py b/app/services/stripe.py index ab3165d..100593b 100644 --- a/app/services/stripe.py +++ b/app/services/stripe.py @@ -71,9 +71,9 @@ async def create_checkout_session( detail="Error creating checkout session" ) -def get_stripe_event( payload: bytes, signature: str, webhook_secret: str) -> stripe.Event: +def decode_stripe_event( payload: bytes, signature: str, webhook_secret: str) -> stripe.Event: """ - Handle Stripe webhook events. + Decode Stripe webhook events. Args: payload: The raw request body @@ -84,21 +84,22 @@ def get_stripe_event( payload: bytes, signature: str, webhook_secret: str) -> st stripe.Event: The Stripe event """ try: - logger.info(f"Trying to decode event") event = stripe.Webhook.construct_event( payload, signature, webhook_secret ) + logger.info(f"Decoded event of type: {event.type}") return event - except ValueError as e: + # If the signature doesn't match, assume bad intent + except stripe.error.SignatureVerificationError as e: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Invalid payload" + status_code=status.HTTP_404_NOT_FOUND, + detail="Not found" ) - except stripe.error.SignatureVerificationError as e: + except ValueError as e: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Invalid signature" + detail="Invalid payload" ) except Exception as e: logger.error(f"Error handling Stripe event: {str(e)}") @@ -243,7 +244,7 @@ async def create_stripe_customer( detail="Error creating Stripe customer" ) -async def get_product_id_from_sub(subscription_id: str) -> str: +async def get_product_id_from_subscription(subscription_id: str) -> str: """ Get the Stripe product ID for the team's subscription. diff --git a/tests/test_billing.py b/tests/test_billing.py index 5679df6..b5e6934 100644 --- a/tests/test_billing.py +++ b/tests/test_billing.py @@ -2,11 +2,16 @@ from unittest.mock import Mock, patch, AsyncMock from fastapi import HTTPException from app.api.billing import handle_stripe_event_background, get_portal -from app.db.models import DBTeam +from app.db.models import DBTeam, DBTeamProduct @pytest.mark.asyncio -async def test_handle_checkout_session_completed(db, test_team): +@patch('app.api.billing.get_product_id_from_session', new_callable=AsyncMock) +async def test_handle_checkout_session_completed(mock_get_product, db, test_team, test_product): # Arrange + test_team.stripe_customer_id = "cus_123" + db.add(test_team) + db.commit() + mock_event = Mock() mock_event.type = "checkout.session.completed" mock_session = Mock() @@ -15,19 +20,30 @@ async def test_handle_checkout_session_completed(db, test_team): mock_session.id = "cs_123" mock_event.data.object = mock_session - with patch('app.api.billing.get_product_id_from_session', new_callable=AsyncMock) as mock_get_product: - mock_get_product.return_value = "prod_123" - with patch('app.api.billing.apply_product_for_team', new_callable=AsyncMock) as mock_apply: - # Act - await handle_stripe_event_background(mock_event, db) + mock_get_product.return_value = test_product.id + # Act + await handle_stripe_event_background(mock_event, db) - # Assert - mock_get_product.assert_called_once_with("cs_123") - mock_apply.assert_called_once_with(db, "cus_123", "prod_123") + # Assert + mock_get_product.assert_called_once_with("cs_123") + # Verify team-product association was created + team_product = db.query(DBTeamProduct).filter( + DBTeamProduct.team_id == test_team.id, + DBTeamProduct.product_id == test_product.id + ).first() + assert team_product is not None + # Verify last payment was updated + db.refresh(test_team) + assert test_team.last_payment is not None @pytest.mark.asyncio -async def test_handle_invoice_payment_succeeded(db, test_team): +@patch('app.api.billing.get_product_id_from_subscription', new_callable=AsyncMock) +async def test_handle_invoice_payment_succeeded(mock_get_product, db, test_team, test_product): # Arrange + test_team.stripe_customer_id = "cus_123" + db.add(test_team) + db.commit() + mock_event = Mock() mock_event.type = "invoice.payment_succeeded" mock_invoice = Mock() @@ -39,34 +55,62 @@ async def test_handle_invoice_payment_succeeded(db, test_team): mock_invoice.parent.subscription_details.subscription = "sub_123" mock_event.data.object = mock_invoice - with patch('app.api.billing.get_product_id_from_sub', new_callable=AsyncMock) as mock_get_product: - mock_get_product.return_value = "prod_123" - with patch('app.api.billing.apply_product_for_team', new_callable=AsyncMock) as mock_apply: - # Act - await handle_stripe_event_background(mock_event, db) + mock_get_product.return_value = test_product.id + # Act + await handle_stripe_event_background(mock_event, db) - # Assert - mock_get_product.assert_called_once_with("sub_123") - mock_apply.assert_called_once_with(db, "cus_123", "prod_123") + # Assert + mock_get_product.assert_called_once_with("sub_123") + # Verify team-product association was created + team_product = db.query(DBTeamProduct).filter( + DBTeamProduct.team_id == test_team.id, + DBTeamProduct.product_id == test_product.id + ).first() + assert team_product is not None + # Verify last payment was updated + db.refresh(test_team) + assert test_team.last_payment is not None @pytest.mark.asyncio -async def test_handle_subscription_deleted(db, test_team): +@patch('app.api.billing.get_product_id_from_subscription', new_callable=AsyncMock) +async def test_handle_subscription_deleted(mock_get_product, db, test_team, test_product): # Arrange + test_team.stripe_customer_id = "cus_123" + db.add(test_team) + db.commit() + mock_event = Mock() mock_event.type = "customer.subscription.deleted" mock_subscription = Mock() mock_subscription.customer = "cus_123" + mock_subscription.id = "sub_123" mock_event.data.object = mock_subscription + mock_get_product.return_value = test_product.id + + # Set up initial team-product association + team_product = DBTeamProduct( + team_id=test_team.id, + product_id=test_product.id + ) + db.add(team_product) + db.commit() + # Act await handle_stripe_event_background(mock_event, db) # Assert - # No assertions needed as we're just verifying no error occurs - # The function now only logs the event + mock_get_product.assert_called_once_with("sub_123") + # Verify team-product association was removed + team_product = db.query(DBTeamProduct).filter( + DBTeamProduct.team_id == test_team.id, + DBTeamProduct.product_id == test_product.id + ).first() + assert team_product is None @pytest.mark.asyncio -async def test_handle_checkout_session_completed_team_not_found(db): +@patch('app.api.billing.get_product_id_from_session', new_callable=AsyncMock) +async def test_handle_checkout_session_completed_team_not_found(mock_get_product, db, test_product): # Arrange mock_event = Mock() mock_event.type = "checkout.session.completed" @@ -76,15 +120,17 @@ async def test_handle_checkout_session_completed_team_not_found(db): mock_session.id = "cs_123" mock_event.data.object = mock_session - with patch('app.api.billing.get_product_id_from_session', new_callable=AsyncMock) as mock_get_product: - mock_get_product.return_value = "prod_123" - with patch('app.api.billing.apply_product_for_team', new_callable=AsyncMock) as mock_apply: - # Act - await handle_stripe_event_background(mock_event, db) + mock_get_product.return_value = test_product.id + # Act + await handle_stripe_event_background(mock_event, db) - # Assert - mock_get_product.assert_called_once_with("cs_123") - mock_apply.assert_called_once_with(db, "cus_123", "prod_123") + # Assert + mock_get_product.assert_called_once_with("cs_123") + # Verify no team-product association was created + team_product = db.query(DBTeamProduct).filter( + DBTeamProduct.product_id == test_product.id + ).first() + assert team_product is None @pytest.mark.asyncio async def test_handle_unknown_event_type(db): @@ -298,4 +344,187 @@ def test_checkout_stripe_error(mock_create_checkout, client, db, test_team, team ) # Assert - assert response.status_code == 500 \ No newline at end of file + assert response.status_code == 500 + +@patch('app.api.billing.get_product_id_from_session', new_callable=AsyncMock) +@pytest.mark.asyncio +async def test_handle_checkout_session_async_payment_succeeded(mock_get_product, db, test_team, test_product): + # Arrange + test_team.stripe_customer_id = "cus_123" + db.add(test_team) + db.commit() + + mock_event = Mock() + mock_event.type = "checkout.session.async_payment_succeeded" + mock_session = Mock() + mock_session.metadata = {"team_id": str(test_team.id)} + mock_session.customer = "cus_123" + mock_session.id = "cs_123" + mock_event.data.object = mock_session + + mock_get_product.return_value = test_product.id + + # Act + await handle_stripe_event_background(mock_event, db) + + # Assert + mock_get_product.assert_called_once_with("cs_123") + # Verify team-product association was created + team_product = db.query(DBTeamProduct).filter( + DBTeamProduct.team_id == test_team.id, + DBTeamProduct.product_id == test_product.id + ).first() + assert team_product is not None + # Verify last payment was updated + db.refresh(test_team) + assert test_team.last_payment is not None + +@patch('app.api.billing.get_product_id_from_session', new_callable=AsyncMock) +@pytest.mark.asyncio +async def test_handle_checkout_session_async_payment_failed(mock_get_product, db, test_team, test_product): + # Arrange + test_team.stripe_customer_id = "cus_123" + db.add(test_team) + db.commit() + + mock_event = Mock() + mock_event.type = "checkout.session.async_payment_failed" + mock_session = Mock() + mock_session.metadata = {"team_id": str(test_team.id)} + mock_session.customer = "cus_123" + mock_session.id = "cs_123" + mock_event.data.object = mock_session + + mock_get_product.return_value = test_product.id + + # Set up initial team-product association + team_product = DBTeamProduct( + team_id=test_team.id, + product_id=test_product.id + ) + db.add(team_product) + db.commit() + + # Act + await handle_stripe_event_background(mock_event, db) + + # Assert + mock_get_product.assert_called_once_with("cs_123") + # Verify team-product association was removed + team_product = db.query(DBTeamProduct).filter( + DBTeamProduct.team_id == test_team.id, + DBTeamProduct.product_id == test_product.id + ).first() + assert team_product is None + +@patch('app.api.billing.get_product_id_from_session', new_callable=AsyncMock) +@pytest.mark.asyncio +async def test_handle_checkout_session_expired(mock_get_product, db, test_team, test_product): + # Arrange + test_team.stripe_customer_id = "cus_123" + db.add(test_team) + db.commit() + + mock_event = Mock() + mock_event.type = "checkout.session.expired" + mock_session = Mock() + mock_session.metadata = {"team_id": str(test_team.id)} + mock_session.customer = "cus_123" + mock_session.id = "cs_123" + mock_event.data.object = mock_session + + mock_get_product.return_value = test_product.id + + # Set up initial team-product association + team_product = DBTeamProduct( + team_id=test_team.id, + product_id=test_product.id + ) + db.add(team_product) + db.commit() + + # Act + await handle_stripe_event_background(mock_event, db) + + # Assert + mock_get_product.assert_called_once_with("cs_123") + # Verify team-product association was removed + team_product = db.query(DBTeamProduct).filter( + DBTeamProduct.team_id == test_team.id, + DBTeamProduct.product_id == test_product.id + ).first() + assert team_product is None + +@patch('app.api.billing.get_product_id_from_subscription', new_callable=AsyncMock) +@pytest.mark.asyncio +async def test_handle_subscription_payment_failed(mock_get_product, db, test_team, test_product): + # Arrange + test_team.stripe_customer_id = "cus_123" + db.add(test_team) + db.commit() + + mock_event = Mock() + mock_event.type = "subscription.payment_failed" + mock_subscription = Mock() + mock_subscription.customer = "cus_123" + mock_subscription.id = "sub_123" + mock_event.data.object = mock_subscription + + mock_get_product.return_value = test_product.id + + # Set up initial team-product association + team_product = DBTeamProduct( + team_id=test_team.id, + product_id=test_product.id + ) + db.add(team_product) + db.commit() + + # Act + await handle_stripe_event_background(mock_event, db) + + # Assert + mock_get_product.assert_called_once_with("sub_123") + # Verify team-product association was removed + team_product = db.query(DBTeamProduct).filter( + DBTeamProduct.team_id == test_team.id, + DBTeamProduct.product_id == test_product.id + ).first() + assert team_product is None + +@patch('app.api.billing.get_product_id_from_subscription', new_callable=AsyncMock) +@pytest.mark.asyncio +async def test_handle_subscription_paused(mock_get_product, db, test_team, test_product): + # Arrange + test_team.stripe_customer_id = "cus_123" + db.add(test_team) + db.commit() + + mock_event = Mock() + mock_event.type = "customer.subscription.paused" + mock_subscription = Mock() + mock_subscription.customer = "cus_123" + mock_subscription.id = "sub_123" + mock_event.data.object = mock_subscription + + mock_get_product.return_value = test_product.id + + # Set up initial team-product association + team_product = DBTeamProduct( + team_id=test_team.id, + product_id=test_product.id + ) + db.add(team_product) + db.commit() + + # Act + await handle_stripe_event_background(mock_event, db) + + # Assert + mock_get_product.assert_called_once_with("sub_123") + # Verify team-product association was removed + team_product = db.query(DBTeamProduct).filter( + DBTeamProduct.team_id == test_team.id, + DBTeamProduct.product_id == test_product.id + ).first() + assert team_product is None \ No newline at end of file diff --git a/tests/test_worker.py b/tests/test_worker.py index d1b21aa..e03dbc6 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -3,7 +3,7 @@ from sqlalchemy.orm import Session from app.db.models import DBProduct, DBTeam, DBUser, DBPrivateAIKey from datetime import datetime, UTC, timedelta -from app.core.worker import apply_product_for_team +from app.core.worker import apply_product_for_team, remove_product_from_team from unittest.mock import AsyncMock, patch @pytest.mark.asyncio @@ -20,10 +20,7 @@ async def test_apply_product_success(db, test_team, test_product): db.commit() # Apply product to team - result = await apply_product_for_team(db, test_team.stripe_customer_id, test_product.id) - - # Verify the result - assert result is True + await apply_product_for_team(db, test_team.stripe_customer_id, test_product.id) # Refresh team from database db.refresh(test_team) @@ -40,11 +37,11 @@ async def test_apply_product_team_not_found(db, test_product): GIVEN: A product exists but team does not WHEN: Attempting to apply the product - THEN: The operation returns False + THEN: The operation completes without error """ # Try to apply product to non-existent team - result = await apply_product_for_team(db, "cus_nonexistent", test_product.id) - assert result is False + await apply_product_for_team(db, "cus_nonexistent", test_product.id) + # No assertions needed as function should complete without error @pytest.mark.asyncio async def test_apply_product_product_not_found(db, test_team): @@ -53,15 +50,15 @@ async def test_apply_product_product_not_found(db, test_team): GIVEN: A team exists but product does not WHEN: Attempting to apply the product - THEN: The operation returns False + THEN: The operation completes without error """ # Set stripe customer ID for the test team test_team.stripe_customer_id = "cus_test123" db.commit() # Try to apply non-existent product - result = await apply_product_for_team(db, test_team.stripe_customer_id, "prod_nonexistent") - assert result is False + await apply_product_for_team(db, test_team.stripe_customer_id, "prod_nonexistent") + # No assertions needed as function should complete without error @pytest.mark.asyncio async def test_apply_product_multiple_products(db, test_team, test_product): @@ -100,8 +97,7 @@ async def test_apply_product_multiple_products(db, test_team, test_product): # Apply each product to the team for product in products: - result = await apply_product_for_team(db, test_team.stripe_customer_id, product.id) - assert result is True + await apply_product_for_team(db, test_team.stripe_customer_id, product.id) # Refresh team from database db.refresh(test_team) @@ -127,16 +123,14 @@ async def test_apply_product_already_active(db, test_team, test_product): db.refresh(test_team) # Refresh to ensure we have the latest data # First apply the product - result = await apply_product_for_team(db, test_team.stripe_customer_id, test_product.id) - assert result is True + await apply_product_for_team(db, test_team.stripe_customer_id, test_product.id) # Get the initial last payment date db.refresh(test_team) initial_last_payment = test_team.last_payment # Apply the same product again - result = await apply_product_for_team(db, test_team.stripe_customer_id, test_product.id) - assert result is True + await apply_product_for_team(db, test_team.stripe_customer_id, test_product.id) # Refresh team from database db.refresh(test_team) @@ -204,8 +198,7 @@ async def test_apply_product_extends_keys_and_sets_budget(mock_litellm, db, test mock_instance.update_budget = AsyncMock() # Apply product to team - result = await apply_product_for_team(db, test_team.stripe_customer_id, test_product.id) - assert result is True + await apply_product_for_team(db, test_team.stripe_customer_id, test_product.id) # Verify LiteLLM service was initialized with correct region settings mock_litellm.assert_called_once_with( @@ -237,4 +230,122 @@ async def test_apply_product_extends_keys_and_sets_budget(mock_litellm, db, test db.refresh(test_team) assert len(test_team.active_products) == 1 assert test_team.active_products[0].product.id == test_product.id - assert test_team.last_payment is not None \ No newline at end of file + assert test_team.last_payment is not None + +@pytest.mark.asyncio +async def test_remove_product_success(db, test_team, test_product): + """ + Test successful removal of a product from a team. + + GIVEN: A team with an active product + WHEN: The product is removed from the team + THEN: The product association is removed from the team's active products + """ + # Set stripe customer ID for the test team + test_team.stripe_customer_id = "cus_test123" + db.commit() + + # First apply the product to ensure it exists + await apply_product_for_team(db, test_team.stripe_customer_id, test_product.id) + + # Remove the product + await remove_product_from_team(db, test_team.stripe_customer_id, test_product.id) + + # Refresh team from database + db.refresh(test_team) + + # Verify product was removed + assert len(test_team.active_products) == 0 + +@pytest.mark.asyncio +async def test_remove_product_team_not_found(db, test_product): + """ + Test removing a product when team is not found. + + GIVEN: A product exists but team does not + WHEN: Attempting to remove the product + THEN: The operation returns None + """ + # Try to remove product from non-existent team + result = await remove_product_from_team(db, "cus_nonexistent", test_product.id) + assert result is None + +@pytest.mark.asyncio +async def test_remove_product_product_not_found(db, test_team): + """ + Test removing a non-existent product from a team. + + GIVEN: A team exists but product does not + WHEN: Attempting to remove the product + THEN: The operation returns None + """ + # Set stripe customer ID for the test team + test_team.stripe_customer_id = "cus_test123" + db.commit() + + # Try to remove non-existent product + result = await remove_product_from_team(db, test_team.stripe_customer_id, "prod_nonexistent") + assert result is None + +@pytest.mark.asyncio +async def test_remove_product_not_active(db, test_team, test_product): + """ + Test removing a product that is not active for a team. + + GIVEN: A team exists but does not have the specified product active + WHEN: Attempting to remove the product + THEN: The operation returns None + """ + # Set stripe customer ID for the test team + test_team.stripe_customer_id = "cus_test123" + db.commit() + + # Try to remove product that was never added + result = await remove_product_from_team(db, test_team.stripe_customer_id, test_product.id) + assert result is None + +@pytest.mark.asyncio +async def test_remove_product_multiple_products(db, test_team, test_product): + """ + Test removing one product while keeping others active. + + GIVEN: A team with multiple active products + WHEN: One product is removed + THEN: Only the specified product is removed, others remain active + """ + # Set stripe customer ID for the test team + test_team.stripe_customer_id = "cus_test123" + db.commit() + + # Create additional test product + second_product = DBProduct( + id="prod_test456", + name="Test Product 2", + user_count=5, + keys_per_user=2, + total_key_count=10, + service_key_count=2, + max_budget_per_key=50.0, + rpm_per_key=1000, + vector_db_count=1, + vector_db_storage=100, + renewal_period_days=30, + active=True, + created_at=datetime.now(UTC) + ) + db.add(second_product) + db.commit() + + # Apply both products to the team + await apply_product_for_team(db, test_team.stripe_customer_id, test_product.id) + await apply_product_for_team(db, test_team.stripe_customer_id, second_product.id) + + # Remove only the first product + await remove_product_from_team(db, test_team.stripe_customer_id, test_product.id) + + # Refresh team from database + db.refresh(test_team) + + # Verify only the first product was removed + assert len(test_team.active_products) == 1 + assert test_team.active_products[0].product.id == second_product.id \ No newline at end of file From 64544e740de3bfaebb5409cded1c483015cfec54 Mon Sep 17 00:00:00 2001 From: Pippa H Date: Wed, 21 May 2025 13:14:00 +0200 Subject: [PATCH 13/23] Add pricing table session endpoint and improve error handling in billing API - Introduced a new endpoint `/teams/{team_id}/pricing-table-session` to create a Stripe Customer Session for team subscription management. - Enhanced error handling for scenarios where the team is not found and when the Stripe webhook secret is not configured. - Updated the subscription success events to include `customer.subscription.resumed`. - Added comprehensive tests for the new pricing table session functionality, covering both existing and new customer cases, as well as error scenarios. --- app/api/billing.py | 55 ++++++++++++++++++++-- app/services/stripe.py | 27 +++++++++++ tests/test_billing.py | 101 ++++++++++++++++++++++++++++++++++++++++- 3 files changed, 177 insertions(+), 6 deletions(-) diff --git a/app/api/billing.py b/app/api/billing.py index 41ffbf3..a75556b 100644 --- a/app/api/billing.py +++ b/app/api/billing.py @@ -12,7 +12,8 @@ create_portal_session, get_product_id_from_subscription, get_product_id_from_session, - create_stripe_customer + create_stripe_customer, + get_pricing_table_session ) from app.core.worker import apply_product_for_team, remove_product_from_team @@ -75,11 +76,14 @@ async def handle_stripe_event_background(event, db: Session): # Full list of possible events: https://docs.stripe.com/api/events/types session_success_events = ["checkout.session.async_payment_succeeded", "checkout.session.completed"] invoice_success_events = ["invoice.payment_succeeded"] - subscription_success_events = ["customer.subscription.created", "invoice.payment_succeeded"] + subscription_success_events = ["customer.subscription.resumed", "customer.subscription.created", "invoice.payment_succeeded"] session_failure_events = ["checkout.session.async_payment_failed", "checkout.session.expired"] subscription_failure_events = ["subscription.payment_failed", "customer.subscription.deleted", "customer.subscription.paused"] invoice_failure_events = ["invoice.payment_failed"] + # TODO: Manage invoicing + # invoice_respose_needed_events = ["invoice.created", "invoice.upcoming"] + success_events = session_success_events + invoice_success_events + subscription_success_events failure_events = session_failure_events + subscription_failure_events + invoice_failure_events known_events = success_events + failure_events @@ -130,7 +134,7 @@ async def handle_events( payment successes, and failures. Events are processed asynchronously in the background. """ try: - # Get the webhook secret from database + # Get the webhook secret from database or environment variable if os.getenv("WEBHOOK_SIG"): webhook_secret = os.getenv("WEBHOOK_SIG") else: @@ -139,9 +143,11 @@ async def handle_events( ).first().value if not webhook_secret: + logger.error("Stripe webhook secret not configured") + # 404 for security reasons - if we're not accepting traffic here, then it doesn't exist raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Stripe webhook secret not configured" + status_code=status.HTTP_404_NOT_FOUND, + detail="Not found" ) # Get the raw request body @@ -210,3 +216,42 @@ async def get_portal( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Error creating portal session" ) + +@router.get("/teams/{team_id}/pricing-table-session", dependencies=[Depends(check_specific_team_admin)]) +async def get_pricing_table_session( + team_id: int, + db: Session = Depends(get_db) +): + """ + Create a Stripe Customer Session client secret for team subscription management. + If the team doesn't have a Stripe customer ID, one will be created first. + + Args: + team_id: The ID of the team to create the customer session for + + Returns: + JSON response containing the client secret + """ + # Get the team + team = db.query(DBTeam).filter(DBTeam.id == team_id).first() + if not team: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Team not found" + ) + + try: + # Create Stripe customer if one doesn't exist + if not team.stripe_customer_id: + team.stripe_customer_id = await create_stripe_customer(team, db) + + # Create customer session using the service + client_secret = await get_pricing_table_session(team.stripe_customer_id) + + return {"client_secret": client_secret} + except Exception as e: + logger.error(f"Error creating customer session: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Error creating customer session" + ) diff --git a/app/services/stripe.py b/app/services/stripe.py index 100593b..ce0c04d 100644 --- a/app/services/stripe.py +++ b/app/services/stripe.py @@ -277,3 +277,30 @@ async def get_product_id_from_session(session_id: str) -> str: """ line_items = stripe.checkout.Session.list_line_items(session_id) return line_items.data[0].price.product + +async def get_pricing_table_session(customer_id: str) -> str: + """ + Create a Stripe Customer Session client secret for a customer. + + Args: + customer_id: The Stripe customer ID to create the session for + + Returns: + str: The customer session client secret + """ + try: + # Create the customer session + session = stripe.CustomerSession.create( + customer=customer_id, + components={ + "pricing_table": {"enabled": True} + } + ) + + return session.client_secret + except Exception as e: + logger.error(f"Error creating customer session: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Error creating customer session" + ) diff --git a/tests/test_billing.py b/tests/test_billing.py index b5e6934..796f129 100644 --- a/tests/test_billing.py +++ b/tests/test_billing.py @@ -527,4 +527,103 @@ async def test_handle_subscription_paused(mock_get_product, db, test_team, test_ DBTeamProduct.team_id == test_team.id, DBTeamProduct.product_id == test_product.id ).first() - assert team_product is None \ No newline at end of file + assert team_product is None + +@patch('app.api.billing.get_pricing_table_session', new_callable=AsyncMock) +def test_get_pricing_table_session_existing_customer(mock_get_session, client, db, test_team, team_admin_token): + # Arrange + test_team.stripe_customer_id = "cus_123" + db.add(test_team) + db.commit() + + mock_client_secret = "cs_test_123" + mock_get_session.return_value = mock_client_secret + + # Act + response = client.get( + f"/billing/teams/{test_team.id}/pricing-table-session", + headers={"Authorization": f"Bearer {team_admin_token}"} + ) + + # Assert + assert response.status_code == 200 + assert response.json()["client_secret"] == mock_client_secret + mock_get_session.assert_called_once_with("cus_123") + +@patch('app.api.billing.get_pricing_table_session', new_callable=AsyncMock) +@patch('app.api.billing.create_stripe_customer', new_callable=AsyncMock) +def test_get_pricing_table_session_create_customer(mock_create_customer, mock_get_session, client, db, test_team, team_admin_token): + # Arrange + test_team.stripe_customer_id = None + db.add(test_team) + db.commit() + + mock_client_secret = "cs_test_123" + mock_create_customer.return_value = "cus_new_123" + mock_get_session.return_value = mock_client_secret + + # Act + response = client.get( + f"/billing/teams/{test_team.id}/pricing-table-session", + headers={"Authorization": f"Bearer {team_admin_token}"} + ) + + # Assert + assert response.status_code == 200 + assert response.json()["client_secret"] == mock_client_secret + mock_create_customer.assert_called_once() + mock_get_session.assert_called_once_with("cus_new_123") + +def test_get_pricing_table_session_team_not_found(client, db, admin_token): + # Arrange + non_existent_team_id = 999 + + # Act + response = client.get( + f"/billing/teams/{non_existent_team_id}/pricing-table-session", + headers={"Authorization": f"Bearer {admin_token}"} + ) + + # Assert + assert response.status_code == 404 + assert response.json()["detail"] == "Team not found" + +@patch('app.api.billing.get_pricing_table_session', new_callable=AsyncMock) +def test_get_pricing_table_session_as_system_admin(mock_get_session, client, db, test_team, admin_token): + # Arrange + test_team.stripe_customer_id = "cus_123" + db.add(test_team) + db.commit() + + mock_client_secret = "cs_test_123" + mock_get_session.return_value = mock_client_secret + + # Act + response = client.get( + f"/billing/teams/{test_team.id}/pricing-table-session", + headers={"Authorization": f"Bearer {admin_token}"} + ) + + # Assert + assert response.status_code == 200 + assert response.json()["client_secret"] == mock_client_secret + mock_get_session.assert_called_once_with("cus_123") + +@patch('app.api.billing.get_pricing_table_session', new_callable=AsyncMock) +def test_get_pricing_table_session_stripe_error(mock_get_session, client, db, test_team, team_admin_token): + # Arrange + test_team.stripe_customer_id = "cus_123" + db.add(test_team) + db.commit() + + mock_get_session.side_effect = Exception("Stripe API error") + + # Act + response = client.get( + f"/billing/teams/{test_team.id}/pricing-table-session", + headers={"Authorization": f"Bearer {team_admin_token}"} + ) + + # Assert + assert response.status_code == 500 + assert response.json()["detail"] == "Error creating customer session" \ No newline at end of file From 8421411c4100d35ff8e282ba30d8073544bad80a Mon Sep 17 00:00:00 2001 From: Pippa H Date: Wed, 21 May 2025 13:46:34 +0200 Subject: [PATCH 14/23] Implement resource limits checks for team and user management - Added functionality to enforce limits on team users and LLM tokens based on configuration settings. - Introduced `check_key_limits` and `check_team_user_limit` functions to validate resource usage against defined limits. - Updated the `create_user` and `create_llm_token` endpoints to incorporate these checks. - Enhanced the `check_vector_db_limits` function to validate vector DB creation against team limits. - Refactored existing tests to ensure comprehensive coverage for new limit checks and updated error handling for limit exceedances. --- app/api/private_ai_keys.py | 11 +++++++++++ app/api/users.py | 6 +++++- app/core/config.py | 2 +- app/core/resource_limits.py | 36 ++++++++++++++++++++++------------- app/services/litellm.py | 13 +++++++------ tests/test_private_ai.py | 4 ++-- tests/test_resource_limits.py | 26 ++++++++++++------------- 7 files changed, 62 insertions(+), 36 deletions(-) diff --git a/app/api/private_ai_keys.py b/app/api/private_ai_keys.py index a46d663..fe70ea5 100644 --- a/app/api/private_ai_keys.py +++ b/app/api/private_ai_keys.py @@ -14,6 +14,8 @@ from app.db.models import DBPrivateAIKey, DBRegion, DBUser, DBTeam from app.services.litellm import LiteLLMService from app.core.security import get_current_user_from_auth, get_role_min_key_creator, get_role_min_team_admin, UserRole, check_system_admin +from app.core.config import settings +from app.core.resource_limits import check_key_limits, check_vector_db_limits router = APIRouter( tags=["private-ai-keys"] @@ -113,6 +115,12 @@ async def create_vector_db( detail="Region not found or inactive" ) + if settings.ENABLE_LIMITS: + if not team_id: # if the team_id is not set we have already validated the owner_id + user = db.query(DBUser).filter(DBUser.id == owner_id).first() + team_id = user.team_id or FAKE_ID + check_vector_db_limits(db, team_id) + try: # Create new postgres database postgres_manager = PostgresManager(region=region) @@ -294,6 +302,9 @@ async def create_llm_token( litellm_team = owner.team_id or FAKE_ID try: + if settings.ENABLE_LIMITS: # Have to do this check so late since we always need the team ID + check_key_limits(db, litellm_team, owner_id) + # Generate LiteLLM token litellm_service = LiteLLMService( api_url=region.litellm_api_url, diff --git a/app/api/users.py b/app/api/users.py index 5a9fb29..c78ca3f 100644 --- a/app/api/users.py +++ b/app/api/users.py @@ -1,7 +1,8 @@ from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy.orm import Session from typing import List, get_args - +from app.core.config import settings +from app.core.resource_limits import check_team_user_limit from app.db.database import get_db from app.schemas.models import User, UserUpdate, UserCreate, TeamOperation, UserRoleUpdate from app.db.models import DBUser, DBTeam @@ -72,6 +73,9 @@ async def create_user( detail="Not authorized to perform this action" ) + if settings.ENABLE_LIMITS and user.team_id is not None: + check_team_user_limit(db, user.team_id) + # Validate role if provided if user.role and user.role not in get_args(UserRole): raise HTTPException( diff --git a/app/core/config.py b/app/core/config.py index 12f7b9e..33aa33e 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -28,7 +28,7 @@ class Settings(BaseSettings): ENV_SUFFIX: str = "local" DYNAMODB_REGION: str = "eu-west-1" SES_REGION: str = "eu-west-1" - EXPIRE_KEYS: str = "true" + ENABLE_LIMITS: bool = True STRIPE_SECRET_KEY: str = "sk_test_string" WEBHOOK_SIG: str = "whsec_test_1234567890" diff --git a/app/core/resource_limits.py b/app/core/resource_limits.py index 8f0c3e1..8cd4a43 100644 --- a/app/core/resource_limits.py +++ b/app/core/resource_limits.py @@ -3,6 +3,16 @@ from fastapi import HTTPException, status from typing import Optional +# Default limits across all customers and products +DEFAULT_USER_COUNT = 2 +DEFAULT_KEYS_PER_USER = 1 +DEFAULT_TOTAL_KEYS = 2 +DEFAULT_SERVICE_KEYS = 1 +DEFAULT_VECTOR_DB_COUNT = 1 +DEFAULT_KEY_DURATION = "30d" +DEFAULT_MAX_SPEND = 20.0 +DEFAULT_RPM_PER_KEY = 500 + def check_team_user_limit(db: Session, team_id: int) -> None: """ Check if adding a user would exceed the team's product limits. @@ -18,18 +28,18 @@ def check_team_user_limit(db: Session, team_id: int) -> None: # Get all active products for the team team = db.query(DBTeam).filter(DBTeam.id == team_id).first() if not team: - raise HTTPException(status_code=404, detail="Team not found") + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Team not found") # Find the maximum user count allowed across all active products max_user_count = max( (product.user_count for team_product in team.active_products for product in [team_product.product] if product.user_count), - default=2 # Default to 2 if no products have user_count set + default=DEFAULT_USER_COUNT # Default to 2 if no products have user_count set ) if current_user_count >= max_user_count: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, + status_code=status.HTTP_402_PAYMENT_REQUIRED, detail=f"Team has reached the maximum user limit of {max_user_count} users" ) @@ -46,23 +56,23 @@ def check_key_limits(db: Session, team_id: int, owner_id: Optional[int] = None) # Get the team and its active products team = db.query(DBTeam).filter(DBTeam.id == team_id).first() if not team: - raise HTTPException(status_code=404, detail="Team not found") + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Team not found") # Find the maximum limits across all active products, using defaults if no products max_total_keys = max( (product.total_key_count for team_product in team.active_products for product in [team_product.product] if product.total_key_count), - default=2 # Default to 2 if no products have total_key_count set + default=DEFAULT_TOTAL_KEYS # Default to 2 if no products have total_key_count set ) max_keys_per_user = max( (product.keys_per_user for team_product in team.active_products for product in [team_product.product] if product.keys_per_user), - default=1 # Default to 1 if no products have keys_per_user set + default=DEFAULT_KEYS_PER_USER # Default to 1 if no products have keys_per_user set ) max_service_keys = max( (product.service_key_count for team_product in team.active_products for product in [team_product.product] if product.service_key_count), - default=1 # Default to 1 if no products have service_key_count set + default=DEFAULT_SERVICE_KEYS # Default to 1 if no products have service_key_count set ) # Get all users in the team @@ -79,7 +89,7 @@ def check_key_limits(db: Session, team_id: int, owner_id: Optional[int] = None) ).count() if current_team_tokens >= max_total_keys: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, + status_code=status.HTTP_402_PAYMENT_REQUIRED, detail=f"Team has reached the maximum LLM token limit of {max_total_keys} tokens" ) @@ -91,7 +101,7 @@ def check_key_limits(db: Session, team_id: int, owner_id: Optional[int] = None) ).count() if current_user_tokens >= max_keys_per_user: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, + status_code=status.HTTP_402_PAYMENT_REQUIRED, detail=f"User has reached the maximum LLM token limit of {max_keys_per_user} tokens" ) @@ -104,7 +114,7 @@ def check_key_limits(db: Session, team_id: int, owner_id: Optional[int] = None) ).count() if current_service_tokens >= max_service_keys: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, + status_code=status.HTTP_402_PAYMENT_REQUIRED, detail=f"Team has reached the maximum service LLM token limit of {max_service_keys} tokens" ) @@ -120,13 +130,13 @@ def check_vector_db_limits(db: Session, team_id: int) -> None: # Get the team and its active products team = db.query(DBTeam).filter(DBTeam.id == team_id).first() if not team: - raise HTTPException(status_code=404, detail="Team not found") + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Team not found") # Find the maximum vector DB count across all active products max_vector_db_count = max( (product.vector_db_count for team_product in team.active_products for product in [team_product.product] if product.vector_db_count), - default=1 # Default to 1 if no products have vector_db_count set + default=DEFAULT_VECTOR_DB_COUNT # Default to 1 if no products have vector_db_count set ) # Get all users in the team @@ -144,6 +154,6 @@ def check_vector_db_limits(db: Session, team_id: int) -> None: if current_vector_db_count >= max_vector_db_count: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, + status_code=status.HTTP_402_PAYMENT_REQUIRED, detail=f"Team has reached the maximum vector DB limit of {max_vector_db_count} databases" ) diff --git a/app/services/litellm.py b/app/services/litellm.py index e5393c0..3128b16 100644 --- a/app/services/litellm.py +++ b/app/services/litellm.py @@ -1,7 +1,8 @@ import requests from fastapi import HTTPException, status -import os import logging +from app.core.resource_limits import DEFAULT_KEY_DURATION, DEFAULT_MAX_SPEND, DEFAULT_RPM_PER_KEY +from app.core.config import settings logger = logging.getLogger(__name__) @@ -39,11 +40,11 @@ async def create_key(self, email: str, name: str, user_id: int, team_id: str) -> request_data["metadata"] = metadata request_data["team_id"] = team_id - if os.getenv("EXPIRE_KEYS", "").lower() == "true": - request_data["duration"] = "30d" - request_data["budget_duration"] = "30d" - request_data["max_budget"] = 20.0 - request_data["rpm_limit"] = 500 + if settings.ENABLE_LIMITS: + request_data["duration"] = DEFAULT_KEY_DURATION + request_data["budget_duration"] = DEFAULT_KEY_DURATION + request_data["max_budget"] = DEFAULT_MAX_SPEND + request_data["rpm_limit"] = DEFAULT_RPM_PER_KEY else: request_data["duration"] = "365d" diff --git a/tests/test_private_ai.py b/tests/test_private_ai.py index bb03022..99466be 100644 --- a/tests/test_private_ai.py +++ b/tests/test_private_ai.py @@ -1098,9 +1098,9 @@ def test_create_llm_token_as_system_admin(mock_post, client, admin_token, test_r ) @patch("app.services.litellm.requests.post") -@patch.dict(os.environ, {"EXPIRE_KEYS": "true"}) +@patch.dict(os.environ, {"ENABLE_LIMITS": "true"}) def test_create_llm_token_with_expiration(mock_post, client, admin_token, test_region, mock_litellm_response): - """Test that when EXPIRE_KEYS is true, new LiteLLM tokens are created with a 14-day expiration duration""" + """Test that when ENABLE_LIMITS is true, new LiteLLM tokens are created with a 14-day expiration duration""" # Mock the LiteLLM API response mock_post.return_value.status_code = 200 mock_post.return_value.json.return_value = mock_litellm_response diff --git a/tests/test_resource_limits.py b/tests/test_resource_limits.py index 611c7c8..6cc8604 100644 --- a/tests/test_resource_limits.py +++ b/tests/test_resource_limits.py @@ -44,7 +44,7 @@ def test_add_user_exceeding_product_limit(db, test_team, test_product): # Test that check_team_user_limit raises an exception with pytest.raises(HTTPException) as exc_info: check_team_user_limit(db, test_team.id) - assert exc_info.value.status_code == 400 + assert exc_info.value.status_code == 402 assert f"Team has reached the maximum user limit of {test_product.user_count} users" in str(exc_info.value.detail) def test_add_user_with_default_limit(db, test_team): @@ -66,7 +66,7 @@ def test_add_user_with_default_limit(db, test_team): # Test that check_team_user_limit raises an exception with pytest.raises(HTTPException) as exc_info: check_team_user_limit(db, test_team.id) - assert exc_info.value.status_code == 400 + assert exc_info.value.status_code == 402 assert "Team has reached the maximum user limit of 2 users" in str(exc_info.value.detail) def test_add_user_with_multiple_products(db, test_team): @@ -136,7 +136,7 @@ def test_add_user_with_multiple_products(db, test_team): # Test that check_team_user_limit raises an exception with pytest.raises(HTTPException) as exc_info: check_team_user_limit(db, test_team.id) - assert exc_info.value.status_code == 400 + assert exc_info.value.status_code == 402 assert f"Team has reached the maximum user limit of {product2.user_count} users" in str(exc_info.value.detail) def test_create_key_within_limits(db, test_team, test_product, test_region): @@ -182,7 +182,7 @@ def test_create_key_exceeding_total_limit(db, test_team, test_product, test_regi # Test that check_key_limits raises an exception with pytest.raises(HTTPException) as exc_info: check_key_limits(db, test_team.id, None) - assert exc_info.value.status_code == 400 + assert exc_info.value.status_code == 402 assert f"Team has reached the maximum LLM token limit of {test_product.total_key_count} tokens" in str(exc_info.value.detail) def test_create_key_exceeding_user_limit(db, test_team, test_product, test_region): @@ -228,7 +228,7 @@ def test_create_key_exceeding_user_limit(db, test_team, test_product, test_regio # Test that check_key_limits raises an exception with pytest.raises(HTTPException) as exc_info: check_key_limits(db, test_team.id, user.id) - assert exc_info.value.status_code == 400 + assert exc_info.value.status_code == 402 assert f"User has reached the maximum LLM token limit of {test_product.keys_per_user} tokens" in str(exc_info.value.detail) def test_create_key_exceeding_service_key_limit(db, test_team, test_product, test_region): @@ -261,7 +261,7 @@ def test_create_key_exceeding_service_key_limit(db, test_team, test_product, tes # Test that check_key_limits raises an exception with pytest.raises(HTTPException) as exc_info: check_key_limits(db, test_team.id, None) - assert exc_info.value.status_code == 400 + assert exc_info.value.status_code == 402 assert f"Team has reached the maximum service LLM token limit of {test_product.service_key_count} tokens" in str(exc_info.value.detail) def test_create_key_with_default_limits(db, test_team, test_region): @@ -286,7 +286,7 @@ def test_create_key_with_default_limits(db, test_team, test_region): # Test that check_key_limits raises an exception with pytest.raises(HTTPException) as exc_info: check_key_limits(db, test_team.id, None) - assert exc_info.value.status_code == 400 + assert exc_info.value.status_code == 402 assert "Team has reached the maximum LLM token limit of 2 tokens" in str(exc_info.value.detail) def test_create_key_with_multiple_products(db, test_team, test_region): @@ -359,7 +359,7 @@ def test_create_key_with_multiple_products(db, test_team, test_region): # Test that check_key_limits raises an exception with pytest.raises(HTTPException) as exc_info: check_key_limits(db, test_team.id, None) - assert exc_info.value.status_code == 400 + assert exc_info.value.status_code == 402 assert f"Team has reached the maximum LLM token limit of {product2.total_key_count} tokens" in str(exc_info.value.detail) def test_create_key_with_multiple_users_default_limits(db, test_team, test_region): @@ -407,7 +407,7 @@ def test_create_key_with_multiple_users_default_limits(db, test_team, test_regio # Test that check_key_limits raises an exception when trying to create a team-owned key with pytest.raises(HTTPException) as exc_info: check_key_limits(db, test_team.id, None) - assert exc_info.value.status_code == 400 + assert exc_info.value.status_code == 402 assert "Team has reached the maximum LLM token limit of 2 tokens" in str(exc_info.value.detail) def test_create_vector_db_within_limits(db, test_team, test_product): @@ -451,7 +451,7 @@ def test_create_vector_db_exceeding_limit(db, test_team, test_product, test_regi # Test that check_vector_db_limits raises an exception with pytest.raises(HTTPException) as exc_info: check_vector_db_limits(db, test_team.id) - assert exc_info.value.status_code == 400 + assert exc_info.value.status_code == 402 assert f"Team has reached the maximum vector DB limit of {test_product.vector_db_count} databases" in str(exc_info.value.detail) def test_create_vector_db_with_default_limit(db, test_team, test_region): @@ -473,7 +473,7 @@ def test_create_vector_db_with_default_limit(db, test_team, test_region): # Test that check_vector_db_limits raises an exception with pytest.raises(HTTPException) as exc_info: check_vector_db_limits(db, test_team.id) - assert exc_info.value.status_code == 400 + assert exc_info.value.status_code == 402 assert "Team has reached the maximum vector DB limit of 1 databases" in str(exc_info.value.detail) def test_create_vector_db_with_multiple_products(db, test_team, test_region): @@ -544,7 +544,7 @@ def test_create_vector_db_with_multiple_products(db, test_team, test_region): # Test that check_vector_db_limits raises an exception with pytest.raises(HTTPException) as exc_info: check_vector_db_limits(db, test_team.id) - assert exc_info.value.status_code == 400 + assert exc_info.value.status_code == 402 assert f"Team has reached the maximum vector DB limit of {product2.vector_db_count} databases" in str(exc_info.value.detail) def test_create_vector_db_with_user_owned_key(db, test_team, test_region, test_team_user): @@ -567,5 +567,5 @@ def test_create_vector_db_with_user_owned_key(db, test_team, test_region, test_t # Test that check_vector_db_limits raises an exception with pytest.raises(HTTPException) as exc_info: check_vector_db_limits(db, test_team.id) - assert exc_info.value.status_code == 400 + assert exc_info.value.status_code == 402 assert "Team has reached the maximum vector DB limit of 1 databases" in str(exc_info.value.detail) From 1666f0621e46f354b641d3f48d49bc2d4fa8cc14 Mon Sep 17 00:00:00 2001 From: Pippa H Date: Thu, 22 May 2025 11:36:12 +0200 Subject: [PATCH 15/23] Enhance product management and Stripe integration - Added a new page for product management in the frontend, allowing users to create, edit, and delete products with detailed forms and validation. - Updated the backend to include checks preventing the deletion of products associated with teams. - Introduced a new function to retrieve the Stripe customer ID from payment intents. - Enhanced the README with instructions for testing Stripe integration. - Added tests to ensure product creation and deletion functionality, including checks for inconsistent key counts and team associations. --- README.md | 3 + app/api/billing.py | 6 +- app/api/products.py | 11 +- app/schemas/models.py | 16 +- app/services/stripe.py | 8 + frontend/src/app/admin/products/page.tsx | 473 +++++++++++++++++++++ frontend/src/components/sidebar-layout.tsx | 4 +- tests/stripe_test_triggers.md | 43 ++ tests/test_products.py | 61 ++- 9 files changed, 612 insertions(+), 13 deletions(-) create mode 100644 frontend/src/app/admin/products/page.tsx create mode 100644 tests/stripe_test_triggers.md diff --git a/README.md b/README.md index a202446..4834231 100644 --- a/README.md +++ b/README.md @@ -67,6 +67,9 @@ make backend-test-cov # Run backend tests with coverage report make backend-test-regex # Waits for a string which pytest will parse to only collect a subset of tests ``` +### 💳 Testing Stripe +See [[tests/stripe_test_trigger.md]] for detailed instructions on testing integration with Stripe for billing purposes. + ### Frontend Tests ```bash make frontend-test # Run frontend tests diff --git a/app/api/billing.py b/app/api/billing.py index a75556b..31a1ec0 100644 --- a/app/api/billing.py +++ b/app/api/billing.py @@ -13,7 +13,8 @@ get_product_id_from_subscription, get_product_id_from_session, create_stripe_customer, - get_pricing_table_session + get_pricing_table_session, + get_customer_from_pi ) from app.core.worker import apply_product_for_team, remove_product_from_team @@ -94,6 +95,9 @@ async def handle_stripe_event_background(event, db: Session): return event_object = event.data.object customer_id = event_object.customer + if not customer_id: + logger.warning(f"No customer ID found in event, cannot complete processing") + return # Success Events if event_type in invoice_success_events: # We assume that the invoice is related to a subscription diff --git a/app/api/products.py b/app/api/products.py index 2ed51de..45c243a 100644 --- a/app/api/products.py +++ b/app/api/products.py @@ -4,7 +4,7 @@ from datetime import datetime, UTC from app.db.database import get_db -from app.db.models import DBProduct +from app.db.models import DBProduct, DBTeamProduct from app.core.security import check_system_admin, get_current_user_from_auth, get_role_min_team_admin from app.schemas.models import Product, ProductCreate, ProductUpdate @@ -52,6 +52,7 @@ async def create_product( return db_product +@router.get("", response_model=List[Product], dependencies=[Depends(get_role_min_team_admin)]) @router.get("/", response_model=List[Product], dependencies=[Depends(get_role_min_team_admin)]) async def list_products( db: Session = Depends(get_db) @@ -119,6 +120,14 @@ async def delete_product( detail="Product not found" ) + # Check if the product is associated with any teams + team_association = db.query(DBTeamProduct).filter(DBTeamProduct.product_id == product_id).first() + if team_association: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Cannot delete product that is associated with one or more teams" + ) + db.delete(product) db.commit() diff --git a/app/schemas/models.py b/app/schemas/models.py index 194a595..42c6a3d 100644 --- a/app/schemas/models.py +++ b/app/schemas/models.py @@ -276,14 +276,14 @@ class CheckoutSessionCreate(BaseModel): class ProductBase(BaseModel): name: str id: str # This is the Stripe product ID, format should be prod_XXX - user_count: int = 1 - keys_per_user: int = 1 - total_key_count: int = 6 - service_key_count: int = 5 - max_budget_per_key: float = 20.0 - rpm_per_key: int = 500 - vector_db_count: int = 1 - vector_db_storage: int = 50 # Not used yet, should be a number in GiB + user_count: Optional[int] = 1 + keys_per_user: Optional[int] = 1 + total_key_count: Optional[int] = 6 + service_key_count: Optional[int] = 5 + max_budget_per_key: Optional[float] = 20.0 + rpm_per_key: Optional[int] = 500 + vector_db_count: Optional[int] = 1 + vector_db_storage: Optional[int] = 50 # Not used yet, should be a number in GiB renewal_period_days: int = 30 active: bool = True diff --git a/app/services/stripe.py b/app/services/stripe.py index ce0c04d..66429a5 100644 --- a/app/services/stripe.py +++ b/app/services/stripe.py @@ -278,6 +278,14 @@ async def get_product_id_from_session(session_id: str) -> str: line_items = stripe.checkout.Session.list_line_items(session_id) return line_items.data[0].price.product +async def get_customer_from_pi(payment_intent: str) -> str: + """ + Get the Stripe customer ID from a payment intent. + """ + payment_intent = stripe.PaymentIntent.retrieve(payment_intent) + logger.info(f"Payment intent is:\n{payment_intent}") + return payment_intent.customer + async def get_pricing_table_session(customer_id: str) -> str: """ Create a Stripe Customer Session client secret for a customer. diff --git a/frontend/src/app/admin/products/page.tsx b/frontend/src/app/admin/products/page.tsx new file mode 100644 index 0000000..24929b3 --- /dev/null +++ b/frontend/src/app/admin/products/page.tsx @@ -0,0 +1,473 @@ +'use client'; + +import { useState } from 'react'; +import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'; +import { + Table, + TableBody, + TableCell, + TableHead, + TableHeader, + TableRow, +} from "@/components/ui/table"; +import { Button } from "@/components/ui/button"; +import { + Dialog, + DialogContent, + DialogHeader, + DialogTitle, + DialogTrigger, +} from "@/components/ui/dialog"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; +import { useToast } from '@/hooks/use-toast'; +import { get, post, put, del } from '@/utils/api'; + +interface Product { + id: string; + name: string; + user_count: number; + keys_per_user: number; + total_key_count: number; + service_key_count: number; + max_budget_per_key: number; + rpm_per_key: number; + vector_db_count: number; + vector_db_storage: number; + renewal_period_days: number; + active: boolean; + created_at: string; +} + +export default function ProductsPage() { + const { toast } = useToast(); + const queryClient = useQueryClient(); + const [isCreateDialogOpen, setIsCreateDialogOpen] = useState(false); + const [isEditDialogOpen, setIsEditDialogOpen] = useState(false); + const [selectedProduct, setSelectedProduct] = useState(null); + const [formData, setFormData] = useState>({}); + + // Update form data + const updateFormData = (newData: Partial) => { + setFormData(newData); + }; + + // Queries + const { data: products = [], isLoading } = useQuery({ + queryKey: ['products'], + queryFn: async () => { + const response = await get('/products'); + return response.json(); + }, + }); + + // Mutations + const createProductMutation = useMutation({ + mutationFn: async (productData: Partial) => { + const response = await post('/products', productData); + return response.json(); + }, + onSuccess: () => { + queryClient.invalidateQueries({ queryKey: ['products'] }); + setIsCreateDialogOpen(false); + setFormData({}); + toast({ + title: "Success", + description: "Product created successfully" + }); + }, + onError: (error: Error) => { + toast({ + variant: "destructive", + title: "Error", + description: error.message + }); + }, + }); + + const updateProductMutation = useMutation({ + mutationFn: async ({ id, data }: { id: string; data: Partial }) => { + const response = await put(`/products/${id}`, data); + return response.json(); + }, + onSuccess: () => { + queryClient.invalidateQueries({ queryKey: ['products'] }); + setIsEditDialogOpen(false); + setSelectedProduct(null); + setFormData({}); + toast({ + title: "Success", + description: "Product updated successfully" + }); + }, + onError: (error: Error) => { + toast({ + variant: "destructive", + title: "Error", + description: error.message + }); + }, + }); + + const deleteProductMutation = useMutation({ + mutationFn: async (id: string) => { + await del(`/products/${id}`); + }, + onSuccess: () => { + queryClient.invalidateQueries({ queryKey: ['products'] }); + toast({ + title: "Success", + description: "Product deleted successfully" + }); + }, + onError: (error: Error) => { + toast({ + variant: "destructive", + title: "Error", + description: error.message + }); + }, + }); + + const handleCreate = () => { + createProductMutation.mutate(formData); + }; + + const handleUpdate = () => { + if (!selectedProduct) return; + updateProductMutation.mutate({ id: selectedProduct.id, data: formData }); + }; + + const handleDelete = (id: string) => { + if (!confirm('Are you sure you want to delete this product?')) return; + deleteProductMutation.mutate(id); + }; + + return ( +
+
+

Product Management

+ + + + + + + Create New Product + +
+
+ + updateFormData({ ...formData, id: e.target.value })} + /> +
+
+ + updateFormData({ ...formData, name: e.target.value })} + /> +
+
+ + updateFormData({ ...formData, user_count: parseInt(e.target.value) })} + /> +
+
+ + updateFormData({ ...formData, keys_per_user: parseInt(e.target.value) })} + /> +
+
+ + updateFormData({ ...formData, total_key_count: parseInt(e.target.value) })} + /> +
+
+ + updateFormData({ ...formData, service_key_count: parseInt(e.target.value) })} + /> +
+
+ + updateFormData({ ...formData, max_budget_per_key: parseFloat(e.target.value) })} + /> +
+
+ + updateFormData({ ...formData, rpm_per_key: parseInt(e.target.value) })} + /> +
+
+ + updateFormData({ ...formData, vector_db_count: parseInt(e.target.value) })} + /> +
+
+ + updateFormData({ ...formData, vector_db_storage: parseInt(e.target.value) })} + /> +
+
+ + updateFormData({ ...formData, renewal_period_days: parseInt(e.target.value) })} + /> +
+
+
+ updateFormData({ ...formData, active: e.target.checked })} + className="h-4 w-4 rounded border-gray-300" + /> + +
+
+
+
+ +
+
+
+
+ + + + + ID + Name + User Count + Keys/User + Total Keys + Service Keys + Budget/Key + RPM/Key + Vector DBs + Storage (GiB) + Renewal (Days) + Status + Created + Actions + + + + {products.map((product) => ( + + {product.id} + {product.name} + {product.user_count} + {product.keys_per_user} + {product.total_key_count} + {product.service_key_count} + ${product.max_budget_per_key.toFixed(2)} + {product.rpm_per_key} + {product.vector_db_count} + {product.vector_db_storage} + {product.renewal_period_days} + + + {product.active ? 'Active' : 'Inactive'} + + + {new Date(product.created_at).toLocaleDateString()} + +
+ + +
+
+
+ ))} +
+
+ + + + + Edit Product + +
+
+ + +
+
+ + updateFormData({ ...formData, name: e.target.value })} + /> +
+
+ + updateFormData({ ...formData, user_count: parseInt(e.target.value) })} + /> +
+
+ + updateFormData({ ...formData, keys_per_user: parseInt(e.target.value) })} + /> +
+
+ + updateFormData({ ...formData, total_key_count: parseInt(e.target.value) })} + /> +
+
+ + updateFormData({ ...formData, service_key_count: parseInt(e.target.value) })} + /> +
+
+ + updateFormData({ ...formData, max_budget_per_key: parseFloat(e.target.value) })} + /> +
+
+ + updateFormData({ ...formData, rpm_per_key: parseInt(e.target.value) })} + /> +
+
+ + updateFormData({ ...formData, vector_db_count: parseInt(e.target.value) })} + /> +
+
+ + updateFormData({ ...formData, vector_db_storage: parseInt(e.target.value) })} + /> +
+
+ + updateFormData({ ...formData, renewal_period_days: parseInt(e.target.value) })} + /> +
+
+
+ updateFormData({ ...formData, active: e.target.checked })} + className="h-4 w-4 rounded border-gray-300" + /> + +
+
+
+
+ +
+
+
+
+ ); +} \ No newline at end of file diff --git a/frontend/src/components/sidebar-layout.tsx b/frontend/src/components/sidebar-layout.tsx index 8b2e147..f2143af 100644 --- a/frontend/src/components/sidebar-layout.tsx +++ b/frontend/src/components/sidebar-layout.tsx @@ -10,7 +10,8 @@ import { ChevronDown, PanelLeftClose, PanelLeft, - Users2 + Users2, + Package } from 'lucide-react'; import { Sidebar, SidebarProvider } from '@/components/ui/sidebar'; import { NavUser } from '@/components/nav-user'; @@ -45,6 +46,7 @@ const navigation = [ { name: 'Teams', href: '/admin/teams', icon: }, { name: 'Users', href: '/admin/users', icon: }, { name: 'Regions', href: '/admin/regions', icon: }, + { name: 'Products', href: '/admin/products', icon: }, { name: 'Private AI Keys', href: '/admin/private-ai-keys', icon: }, { name: 'Audit Logs', href: '/admin/audit-logs', icon: }, ], diff --git a/tests/stripe_test_triggers.md b/tests/stripe_test_triggers.md new file mode 100644 index 0000000..45048b3 --- /dev/null +++ b/tests/stripe_test_triggers.md @@ -0,0 +1,43 @@ +Useful commands for testing the Stripe integration from the CLI + +## Setup: +Follow the [installation instructions](https://docs.stripe.com/stripe-cli#install), and then log in. If you have your test key, you can use that when you log in to ensure you're using it for everything. + +## Testing: +You will need to have at least four terminal panes open. A multiplexer like tmux will make this super easy, but you can do new tabs or windows if you need to. +### Pane 1 - Logs: +Watch the logs for what you're sending to Stripe by running +```sh +stripe logs tail +``` +### Pane 2 - Webhook: +Listen for stripe events, and forward them to your local testing environment: +```sh +stripe listen --forward-to localhost:8800/billing/events +``` +You can optionally limit which events are forwarded, but the handler defaults to accepting anything and then choosing what to do with it in the background. +Running this command will give you the webhook secret you need to decode all events. make sure to copy it into the appropriate environment variable. +### Pane 3 - Backend service: +Stand up the service using docker compose, and watch the logs to see events being processed as they com in: +```sh +docker compose up --build -d +docker compose logs -f backend +``` +### Pane 4 - Triggers: +Each of these triggers will initiate a different flow in the backend. USe them to ensure everything is still working as expected: + +Add a product +```sh +stripe trigger checkout.session.completed --override checkout_session:customer=cus_SLpVFWQFHmls9T # forces a customer ID to be set +stripe trigger checkout.session.async_payment_succeeded --override checkout_session:customer=cus_SLpVFWQFHmls9T # will succeed twice +stripe trigger subscription.payment_succeeded # Will succeed twice +``` + +Remove a product +```sh +stripe trigger checkout.session.async_payment_failed --override checkout_session:customer=cus_SLpVFWQFHmls9T # Should add then remove +stripe trigger checkout.session.expired --override checkout_session:customer=cus_SLpVFWQFHmls9T # Will remove the product without first adding it. +stripe trigger subscription.payment_failed # Add then remove +stripe trigger customer.subscription.paused # Add then remove +stripe trigger customer.subscription.deleted # Add twice, then remove +``` \ No newline at end of file diff --git a/tests/test_products.py b/tests/test_products.py index 775c845..26e750f 100644 --- a/tests/test_products.py +++ b/tests/test_products.py @@ -1,7 +1,7 @@ import pytest from fastapi.testclient import TestClient from sqlalchemy.orm import Session -from app.db.models import DBProduct, DBUser, DBTeam +from app.db.models import DBProduct, DBUser, DBTeam, DBTeamProduct from datetime import datetime, UTC def test_create_product_as_system_admin(client, admin_token, db): @@ -96,6 +96,35 @@ def test_create_product_duplicate_id(client, admin_token, db): assert response.status_code == 400 assert "already exists" in response.json()["detail"] +def test_create_product_inconsistent_key_counts(client, admin_token, db): + """ + Test that creating a product with inconsistent key counts fails. + + GIVEN: The authenticated user is a system admin + WHEN: They create a product with user_count = 5, keys_per_user = 2, service_key_count = 2, and total_key_count = 6 + THEN: A 400 - Bad Request is returned + """ + response = client.post( + "/products/", + headers={"Authorization": f"Bearer {admin_token}"}, + json={ + "id": "prod_test123", + "name": "Test Product", + "user_count": 5, + "keys_per_user": 2, + "total_key_count": 6, # Should be 12 (5 * 2 + 2) + "service_key_count": 2, + "max_budget_per_key": 50.0, + "rpm_per_key": 1000, + "vector_db_count": 1, + "vector_db_storage": 100, + "renewal_period_days": 30, + "active": True + } + ) + assert response.status_code == 400 + assert "inconsistent" in response.json()["detail"].lower() + def test_create_product_unauthorized(client, test_token, db): """ Test that a non-admin user cannot create a product. @@ -381,4 +410,32 @@ def test_delete_product_as_system_admin(client, admin_token, db): # Verify the product is deleted deleted_product = db.query(DBProduct).filter(DBProduct.id == db_product.id).first() - assert deleted_product is None \ No newline at end of file + assert deleted_product is None + +def test_delete_product_with_team_association(client, admin_token, db, test_team, test_product): + """ + Test that a product cannot be deleted if it's associated with a team. + + GIVEN: A product which has been applied to a team + WHEN: An authorised user tries to delete the product + THEN: An error is returned + """ + # Associate the product with a team + team_product = DBTeamProduct( + team_id=test_team.id, + product_id=test_product.id + ) + db.add(team_product) + db.commit() + + # Try to delete the product + response = client.delete( + f"/products/{test_product.id}", + headers={"Authorization": f"Bearer {admin_token}"} + ) + assert response.status_code == 400 + assert "cannot be deleted" in response.json()["detail"].lower() + + # Verify the product still exists + existing_product = db.query(DBProduct).filter(DBProduct.id == test_product.id).first() + assert existing_product is not None \ No newline at end of file From 3467782b59b9b529788b86cda348a55e8abe9baf Mon Sep 17 00:00:00 2001 From: Pippa H Date: Thu, 22 May 2025 12:18:13 +0200 Subject: [PATCH 16/23] Refactor API key management and enhance resource limits validation - Updated the `create_llm_token` function to check key limits based on team and user configurations. - Modified the `Settings` class to allow environment variable overrides for limit settings. - Adjusted default user count limit to improve resource management. - Enhanced tests to validate behavior when limits are enabled, ensuring proper error handling for exceeding key creation limits. - Removed outdated test for inconsistent key counts in product creation. --- app/api/private_ai_keys.py | 7 +- app/core/config.py | 8 +- app/core/resource_limits.py | 2 +- tests/test_private_ai.py | 152 +++++++++++++++++++++++++++++++++- tests/test_products.py | 31 +------ tests/test_resource_limits.py | 2 +- tests/test_users.py | 23 +++++ 7 files changed, 184 insertions(+), 41 deletions(-) diff --git a/app/api/private_ai_keys.py b/app/api/private_ai_keys.py index fe70ea5..6258d9d 100644 --- a/app/api/private_ai_keys.py +++ b/app/api/private_ai_keys.py @@ -294,6 +294,10 @@ async def create_llm_token( detail="Team not found" ) + if owner.team_id or team_id: + if settings.ENABLE_LIMITS: + check_key_limits(db, owner.team_id or team_id, owner_id) + if team is not None: owner_email = team.admin_email litellm_team = team.id @@ -302,9 +306,6 @@ async def create_llm_token( litellm_team = owner.team_id or FAKE_ID try: - if settings.ENABLE_LIMITS: # Have to do this check so late since we always need the team ID - check_key_limits(db, litellm_team, owner_id) - # Generate LiteLLM token litellm_service = LiteLLMService( api_url=region.litellm_api_url, diff --git a/app/core/config.py b/app/core/config.py index 33aa33e..8864b85 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -25,12 +25,12 @@ class Settings(BaseSettings): AWS_SECRET_ACCESS_KEY: str = "sk-string" SES_SENDER_EMAIL: str = "info@example.com" PASSWORDLESS_SIGN_IN: str = "true" - ENV_SUFFIX: str = "local" + ENV_SUFFIX: str = os.getenv("ENV_SUFFIX", "local") DYNAMODB_REGION: str = "eu-west-1" SES_REGION: str = "eu-west-1" - ENABLE_LIMITS: bool = True - STRIPE_SECRET_KEY: str = "sk_test_string" - WEBHOOK_SIG: str = "whsec_test_1234567890" + ENABLE_LIMITS: bool = os.getenv("ENABLE_LIMITS", "false") == "true" + STRIPE_SECRET_KEY: str = os.getenv("STRIPE_SECRET_KEY", "sk_test_string") + WEBHOOK_SIG: str = os.getenv("WEBHOOK_SIG", "whsec_test_1234567890") model_config = ConfigDict(env_file=".env") diff --git a/app/core/resource_limits.py b/app/core/resource_limits.py index 8cd4a43..b960fcd 100644 --- a/app/core/resource_limits.py +++ b/app/core/resource_limits.py @@ -4,7 +4,7 @@ from typing import Optional # Default limits across all customers and products -DEFAULT_USER_COUNT = 2 +DEFAULT_USER_COUNT = 1 DEFAULT_KEYS_PER_USER = 1 DEFAULT_TOTAL_KEYS = 2 DEFAULT_SERVICE_KEYS = 1 diff --git a/tests/test_private_ai.py b/tests/test_private_ai.py index 99466be..784ee2a 100644 --- a/tests/test_private_ai.py +++ b/tests/test_private_ai.py @@ -1098,9 +1098,9 @@ def test_create_llm_token_as_system_admin(mock_post, client, admin_token, test_r ) @patch("app.services.litellm.requests.post") -@patch.dict(os.environ, {"ENABLE_LIMITS": "true"}) +@patch('app.core.config.settings.ENABLE_LIMITS', True) def test_create_llm_token_with_expiration(mock_post, client, admin_token, test_region, mock_litellm_response): - """Test that when ENABLE_LIMITS is true, new LiteLLM tokens are created with a 14-day expiration duration""" + """Test that when ENABLE_LIMITS is true, new LiteLLM tokens are created with a 30-day expiration duration""" # Mock the LiteLLM API response mock_post.return_value.status_code = 200 mock_post.return_value.json.return_value = mock_litellm_response @@ -1479,3 +1479,151 @@ def test_get_private_ai_key_unauthorized(mock_get, client, test_token, test_regi # Clean up db.delete(test_key) db.commit() + +@patch("app.services.litellm.requests.post") +@patch('app.core.config.settings.ENABLE_LIMITS', True) +def test_create_too_many_service_keys(mock_post, client, admin_token, test_region, mock_litellm_response, db, test_team): + """Test that when ENABLE_LIMITS is true, creating too many service keys fails""" + # Mock the LiteLLM API response + mock_post.return_value.status_code = 200 + mock_post.return_value.json.return_value = mock_litellm_response + mock_post.return_value.raise_for_status.return_value = None + + team_id = test_team.id + # Create first service key + response = client.post( + "/private-ai-keys/token", + headers={"Authorization": f"Bearer {admin_token}"}, + json={ + "region_id": test_region.id, + "name": "First Service Key", + "team_id": team_id + } + ) + assert response.status_code == 200 + + # Try to create second service key + response = client.post( + "/private-ai-keys/token", + headers={"Authorization": f"Bearer {admin_token}"}, + json={ + "region_id": test_region.id, + "name": "Second Service Key", + "team_id": team_id + } + ) + assert response.status_code == 402 + assert "Team has reached the maximum service LLM token limit of 1 tokens" in response.json()["detail"] + +@patch("app.services.litellm.requests.post") +@patch('app.core.config.settings.ENABLE_LIMITS', True) +def test_create_too_many_user_keys(mock_post, client, admin_token, test_region, mock_litellm_response, db, test_team_user): + """Test that when ENABLE_LIMITS is true, creating too many user keys fails""" + # Mock the LiteLLM API response + mock_post.return_value.status_code = 200 + mock_post.return_value.json.return_value = mock_litellm_response + mock_post.return_value.raise_for_status.return_value = None + + user_id = test_team_user.id + # Create first user key + response = client.post( + "/private-ai-keys/token", + headers={"Authorization": f"Bearer {admin_token}"}, + json={ + "region_id": test_region.id, + "name": "First User Key", + "owner_id": user_id + } + ) + assert response.status_code == 200 + + # Try to create second user key + response = client.post( + "/private-ai-keys/token", + headers={"Authorization": f"Bearer {admin_token}"}, + json={ + "region_id": test_region.id, + "name": "Second User Key", + "owner_id": user_id + } + ) + assert response.status_code == 402 + assert "User has reached the maximum LLM token limit of 1 tokens" in response.json()["detail"] + +@patch('app.core.config.settings.ENABLE_LIMITS', True) +def test_create_too_many_vector_dbs(client, admin_token, test_region, db, test_team): + """Test that when ENABLE_LIMITS is true, creating too many vector DBs fails""" + # Create first vector DB + team_id = test_team.id + response = client.post( + "/private-ai-keys/vector-db", + headers={"Authorization": f"Bearer {admin_token}"}, + json={ + "region_id": test_region.id, + "name": "First Vector DB", + "team_id": team_id + } + ) + assert response.status_code == 200 + + # Try to create second vector DB + response = client.post( + "/private-ai-keys/vector-db", + headers={"Authorization": f"Bearer {admin_token}"}, + json={ + "region_id": test_region.id, + "name": "Second Vector DB", + "team_id": team_id + } + ) + assert response.status_code == 402 + assert "Team has reached the maximum vector DB limit of 1 databases" in response.json()["detail"] + +@patch("app.services.litellm.requests.post") +@patch('app.core.config.settings.ENABLE_LIMITS', True) +def test_create_too_many_total_keys(mock_post, client, admin_token, test_region, mock_litellm_response, db, test_team, test_team_user): + """Test that when ENABLE_LIMITS is true, creating too many total keys fails""" + # Mock the LiteLLM API response + mock_post.return_value.status_code = 200 + mock_post.return_value.json.return_value = mock_litellm_response + mock_post.return_value.raise_for_status.return_value = None + + team_id = test_team.id + user_id = test_team_user.id + region_id = test_region.id + # Create first key (service key) + response = client.post( + "/private-ai-keys/token", + headers={"Authorization": f"Bearer {admin_token}"}, + json={ + "region_id": region_id, + "name": "First Key", + "team_id": team_id + } + ) + assert response.status_code == 200 + + # Create second key (user key) + response = client.post( + "/private-ai-keys/token", + headers={"Authorization": f"Bearer {admin_token}"}, + json={ + "region_id": region_id, + "name": "Second Key", + "owner_id": user_id + } + ) + assert response.status_code == 200 + + # Try to create third key + response = client.post( + "/private-ai-keys/token", + headers={"Authorization": f"Bearer {admin_token}"}, + json={ + "region_id": region_id, + "name": "Third Key", + "owner_id": user_id + } + ) + assert response.status_code == 402 + assert "Team has reached the maximum LLM token limit of 2 tokens" in response.json()["detail"] diff --git a/tests/test_products.py b/tests/test_products.py index 26e750f..76080aa 100644 --- a/tests/test_products.py +++ b/tests/test_products.py @@ -96,35 +96,6 @@ def test_create_product_duplicate_id(client, admin_token, db): assert response.status_code == 400 assert "already exists" in response.json()["detail"] -def test_create_product_inconsistent_key_counts(client, admin_token, db): - """ - Test that creating a product with inconsistent key counts fails. - - GIVEN: The authenticated user is a system admin - WHEN: They create a product with user_count = 5, keys_per_user = 2, service_key_count = 2, and total_key_count = 6 - THEN: A 400 - Bad Request is returned - """ - response = client.post( - "/products/", - headers={"Authorization": f"Bearer {admin_token}"}, - json={ - "id": "prod_test123", - "name": "Test Product", - "user_count": 5, - "keys_per_user": 2, - "total_key_count": 6, # Should be 12 (5 * 2 + 2) - "service_key_count": 2, - "max_budget_per_key": 50.0, - "rpm_per_key": 1000, - "vector_db_count": 1, - "vector_db_storage": 100, - "renewal_period_days": 30, - "active": True - } - ) - assert response.status_code == 400 - assert "inconsistent" in response.json()["detail"].lower() - def test_create_product_unauthorized(client, test_token, db): """ Test that a non-admin user cannot create a product. @@ -434,7 +405,7 @@ def test_delete_product_with_team_association(client, admin_token, db, test_team headers={"Authorization": f"Bearer {admin_token}"} ) assert response.status_code == 400 - assert "cannot be deleted" in response.json()["detail"].lower() + assert "cannot delete product" in response.json()["detail"].lower() # Verify the product still exists existing_product = db.query(DBProduct).filter(DBProduct.id == test_product.id).first() diff --git a/tests/test_resource_limits.py b/tests/test_resource_limits.py index 6cc8604..7bb61ee 100644 --- a/tests/test_resource_limits.py +++ b/tests/test_resource_limits.py @@ -67,7 +67,7 @@ def test_add_user_with_default_limit(db, test_team): with pytest.raises(HTTPException) as exc_info: check_team_user_limit(db, test_team.id) assert exc_info.value.status_code == 402 - assert "Team has reached the maximum user limit of 2 users" in str(exc_info.value.detail) + assert "Team has reached the maximum user limit" in str(exc_info.value.detail) def test_add_user_with_multiple_products(db, test_team): """Test adding users when team has multiple products with different limits""" diff --git a/tests/test_users.py b/tests/test_users.py index 9b077fe..d74b354 100644 --- a/tests/test_users.py +++ b/tests/test_users.py @@ -3,6 +3,7 @@ from app.db.models import DBUser, DBTeam from datetime import datetime, UTC from unittest.mock import patch +import os def test_create_user(client, test_admin, admin_token): response = client.post( @@ -406,3 +407,25 @@ def test_user_privilege_escalation(client, team_admin_token): ) assert response.status_code == 404 assert "User not found" in response.json()["detail"] + +@patch('app.api.users.settings.ENABLE_LIMITS', True) +def test_create_user_with_limits_enabled(client, team_admin_token, test_team, db): + """ + Test that a team cannot create more users when ENABLE_LIMITS is true and they have reached their limit. + + GIVEN: a team with one user, and ENABLE_LIMITS is true + WHEN: the team tries to create another user + THEN: a 402 payment required is returned + """ + # Create a new user in the team + response = client.post( + "/users/", + headers={"Authorization": f"Bearer {team_admin_token}"}, + json={ + "email": "newteamuser@example.com", + "password": "newpassword", + "team_id": test_team.id + } + ) + assert response.status_code == 402 + assert "Team has reached the maximum user limit" in response.json()["detail"] From 66b91ee16b1707601c0843db7379f80466ee855f Mon Sep 17 00:00:00 2001 From: Pippa H Date: Fri, 23 May 2025 11:27:04 +0200 Subject: [PATCH 17/23] Enhance token restrictions management and resource limits validation - Introduced a new `get_token_restrictions` function to retrieve token limits based on team products and payment history. - Updated the `create_llm_token` function to apply token restrictions dynamically based on the retrieved values. - Refactored the `LiteLLMService` to include a new method for setting key restrictions, consolidating budget and duration updates. - Enhanced tests to validate the behavior of token restrictions under various scenarios, ensuring proper handling of default and product-specific limits. - Updated existing tests to reflect changes in key management and resource limits functionality. --- app/api/private_ai_keys.py | 13 +++- app/core/resource_limits.py | 37 ++++++++++- app/core/worker.py | 19 +++--- app/services/litellm.py | 55 +++++++++++++--- tests/test_private_ai.py | 3 + tests/test_resource_limits.py | 120 +++++++++++++++++++++++++++++++++- tests/test_worker.py | 26 +++----- 7 files changed, 232 insertions(+), 41 deletions(-) diff --git a/app/api/private_ai_keys.py b/app/api/private_ai_keys.py index 6258d9d..f7a513c 100644 --- a/app/api/private_ai_keys.py +++ b/app/api/private_ai_keys.py @@ -15,7 +15,7 @@ from app.services.litellm import LiteLLMService from app.core.security import get_current_user_from_auth, get_role_min_key_creator, get_role_min_team_admin, UserRole, check_system_admin from app.core.config import settings -from app.core.resource_limits import check_key_limits, check_vector_db_limits +from app.core.resource_limits import check_key_limits, check_vector_db_limits, get_token_restrictions, DEFAULT_KEY_DURATION, DEFAULT_MAX_SPEND, DEFAULT_RPM_PER_KEY router = APIRouter( tags=["private-ai-keys"] @@ -297,6 +297,12 @@ async def create_llm_token( if owner.team_id or team_id: if settings.ENABLE_LIMITS: check_key_limits(db, owner.team_id or team_id, owner_id) + # Limits are conditionally applied in LiteLLM service + days_left_in_period, max_max_spend, max_rpm_limit = get_token_restrictions(db, owner.team_id or team_id) + else: # Super system users... + days_left_in_period = DEFAULT_KEY_DURATION + max_max_spend = DEFAULT_MAX_SPEND + max_rpm_limit = DEFAULT_RPM_PER_KEY if team is not None: owner_email = team.admin_email @@ -315,7 +321,10 @@ async def create_llm_token( email=owner_email, name=private_ai_key.name, user_id=owner_id, - team_id=f"{region.name.replace(' ', '_')}_{litellm_team}" + team_id=f"{region.name.replace(' ', '_')}_{litellm_team}", + duration=f"{days_left_in_period}d", + max_budget=max_max_spend, + rpm_limit=max_rpm_limit ) # Create response object diff --git a/app/core/resource_limits.py b/app/core/resource_limits.py index b960fcd..b21180f 100644 --- a/app/core/resource_limits.py +++ b/app/core/resource_limits.py @@ -2,6 +2,10 @@ from app.db.models import DBTeam, DBUser, DBPrivateAIKey from fastapi import HTTPException, status from typing import Optional +from datetime import datetime, UTC +import logging + +logger = logging.getLogger(__name__) # Default limits across all customers and products DEFAULT_USER_COUNT = 1 @@ -9,7 +13,7 @@ DEFAULT_TOTAL_KEYS = 2 DEFAULT_SERVICE_KEYS = 1 DEFAULT_VECTOR_DB_COUNT = 1 -DEFAULT_KEY_DURATION = "30d" +DEFAULT_KEY_DURATION = 30 DEFAULT_MAX_SPEND = 20.0 DEFAULT_RPM_PER_KEY = 500 @@ -157,3 +161,34 @@ def check_vector_db_limits(db: Session, team_id: int) -> None: status_code=status.HTTP_402_PAYMENT_REQUIRED, detail=f"Team has reached the maximum vector DB limit of {max_vector_db_count} databases" ) + +def get_token_restrictions(db: Session, team_id: int) -> tuple[int, float, int]: + """ + Get the token restrictions for a team. + """ + team = db.query(DBTeam).filter(DBTeam.id == team_id).first() + if not team: + logger.error(f"Team not found for team_id: {team_id}") + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Team not found") + + max_key_duration = max( + (product.renewal_period_days for team_product in team.active_products + for product in [team_product.product] if product.renewal_period_days), + default=DEFAULT_KEY_DURATION + ) + if team.last_payment is None: + days_left_in_period = max_key_duration + else: + days_left_in_period = max_key_duration - (datetime.now(UTC) - max(team.created_at, team.last_payment)).days + max_max_spend = max( + (product.max_budget_per_key for team_product in team.active_products + for product in [team_product.product] if product.max_budget_per_key), + default=DEFAULT_MAX_SPEND + ) + max_rpm_limit = max( + (product.rpm_per_key for team_product in team.active_products + for product in [team_product.product] if product.rpm_per_key), + default=DEFAULT_RPM_PER_KEY + ) + + return days_left_in_period, max_max_spend, max_rpm_limit \ No newline at end of file diff --git a/app/core/worker.py b/app/core/worker.py index 7144c53..9d22daa 100644 --- a/app/core/worker.py +++ b/app/core/worker.py @@ -4,6 +4,7 @@ from app.services.litellm import LiteLLMService import logging from collections import defaultdict +from app.core.resource_limits import get_token_restrictions logger = logging.getLogger(__name__) @@ -49,6 +50,9 @@ async def apply_product_for_team(db: Session, customer_id: str, product_id: str) product_id=product.id ) db.add(team_product) + db.commit() # Commit the product association + + days_left_in_period, max_max_spend, max_rpm_limit = get_token_restrictions(db, team.id) # Get all keys for the team with their regions team_users = db.query(DBUser).filter(DBUser.team_id == team.id).all() @@ -81,17 +85,12 @@ async def apply_product_for_team(db: Session, customer_id: str, product_id: str) # Update each key's duration and budget via LiteLLM for key in keys: try: - # Update key duration - await litellm_service.update_key_duration( - litellm_token=key.litellm_token, - duration=f"{product.renewal_period_days}d" - ) - - # Update key budget - await litellm_service.update_budget( + await litellm_service.set_key_restrictions( litellm_token=key.litellm_token, - budget_duration=f"{product.renewal_period_days}d", - budget_amount=product.max_budget_per_key + duration=f"{days_left_in_period}d", + budget_duration=f"{days_left_in_period}d", + budget_amount=max_max_spend, + rpm_limit=max_rpm_limit ) except Exception as e: logger.error(f"Failed to update key {key.id} via LiteLLM: {str(e)}") diff --git a/app/services/litellm.py b/app/services/litellm.py index 3128b16..4780b08 100644 --- a/app/services/litellm.py +++ b/app/services/litellm.py @@ -3,6 +3,7 @@ import logging from app.core.resource_limits import DEFAULT_KEY_DURATION, DEFAULT_MAX_SPEND, DEFAULT_RPM_PER_KEY from app.core.config import settings +from typing import Optional logger = logging.getLogger(__name__) @@ -16,7 +17,7 @@ def __init__(self, api_url: str, api_key: str): if not self.master_key: raise ValueError("LiteLLM API key is required") - async def create_key(self, email: str, name: str, user_id: int, team_id: str) -> str: + async def create_key(self, email: str, name: str, user_id: int, team_id: str, duration: str = f"{DEFAULT_KEY_DURATION}d", max_budget: float = DEFAULT_MAX_SPEND, rpm_limit: int = DEFAULT_RPM_PER_KEY) -> str: """Create a new API key for LiteLLM""" try: logger.info(f"Creating new LiteLLM API key for email: {email}, name: {name}, user_id: {user_id}, team_id: {team_id}") @@ -41,10 +42,10 @@ async def create_key(self, email: str, name: str, user_id: int, team_id: str) -> request_data["team_id"] = team_id if settings.ENABLE_LIMITS: - request_data["duration"] = DEFAULT_KEY_DURATION - request_data["budget_duration"] = DEFAULT_KEY_DURATION - request_data["max_budget"] = DEFAULT_MAX_SPEND - request_data["rpm_limit"] = DEFAULT_RPM_PER_KEY + request_data["duration"] = duration + request_data["budget_duration"] = duration + request_data["max_budget"] = max_budget + request_data["rpm_limit"] = rpm_limit else: request_data["duration"] = "365d" @@ -123,19 +124,23 @@ async def get_key_info(self, litellm_token: str) -> dict: detail=f"Failed to get LiteLLM key information: {error_msg}" ) - async def update_budget(self, litellm_token: str, budget_duration: str): + async def update_budget(self, litellm_token: str, budget_duration: str, budget_amount: Optional[float] = None): """Update the budget for a LiteLLM API key""" try: # Update budget period in LiteLLM + request_data = { + "key": litellm_token, + "budget_duration": budget_duration + } + if budget_amount: + request_data["max_budget"] = budget_amount + response = requests.post( f"{self.api_url}/key/update", headers={ "Authorization": f"Bearer {self.master_key}" }, - json={ - "key": litellm_token, - "budget_duration": budget_duration - } + json=request_data ) response.raise_for_status() except requests.exceptions.RequestException as e: @@ -177,3 +182,33 @@ async def update_key_duration(self, litellm_token: str, duration: str): status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to update LiteLLM key duration: {error_msg}" ) + + async def set_key_restrictions(self, litellm_token: str, duration: str, budget_amount: float, rpm_limit: int, budget_duration: Optional[str] = None): + """Set the restrictions for a LiteLLM API key""" + try: + response = requests.post( + f"{self.api_url}/key/update", + headers={ + "Authorization": f"Bearer {self.master_key}" + }, + json={ + "key": litellm_token, + "duration": duration, + "budget_duration": budget_duration, + "max_budget": budget_amount, + "rpm_limit": rpm_limit + } + ) + response.raise_for_status() + except requests.exceptions.RequestException as e: + error_msg = str(e) + if hasattr(e, 'response') and e.response is not None: + try: + error_details = e.response.json() + error_msg = f"Status {e.response.status_code}: {error_details}" + except ValueError: + error_msg = f"Status {e.response.status_code}: {e.response.text}" + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to set LiteLLM key restrictions: {error_msg}" + ) diff --git a/tests/test_private_ai.py b/tests/test_private_ai.py index 784ee2a..0a89334 100644 --- a/tests/test_private_ai.py +++ b/tests/test_private_ai.py @@ -224,6 +224,9 @@ def test_create_private_ai_key_without_owner_or_team(mock_post, client, admin_to data = response.json() assert data["region"] == test_region.name assert data["litellm_token"] == "test-private-key-123" + assert data["owner_id"] is not None # Should be set to admin's ID + assert data["team_id"] is None + # Verify that the LiteLLM API was called mock_post.assert_called_once() diff --git a/tests/test_resource_limits.py b/tests/test_resource_limits.py index 7bb61ee..dae1454 100644 --- a/tests/test_resource_limits.py +++ b/tests/test_resource_limits.py @@ -1,8 +1,8 @@ import pytest from app.db.models import DBUser, DBProduct, DBTeamProduct, DBPrivateAIKey -from datetime import datetime, UTC +from datetime import datetime, UTC, timedelta from fastapi import HTTPException -from app.core.resource_limits import check_key_limits, check_team_user_limit, check_vector_db_limits +from app.core.resource_limits import check_key_limits, check_team_user_limit, check_vector_db_limits, get_token_restrictions, DEFAULT_KEY_DURATION, DEFAULT_MAX_SPEND, DEFAULT_RPM_PER_KEY def test_add_user_within_product_limit(db, test_team, test_product): """Test adding a user when within product user limit""" @@ -569,3 +569,119 @@ def test_create_vector_db_with_user_owned_key(db, test_team, test_region, test_t check_vector_db_limits(db, test_team.id) assert exc_info.value.status_code == 402 assert "Team has reached the maximum vector DB limit of 1 databases" in str(exc_info.value.detail) + +def test_get_token_restrictions_default_limits(db, test_team): + """Test getting token restrictions when team has no products (using default limits)""" + days_left, max_spend, rpm_limit = get_token_restrictions(db, test_team.id) + + # Should use default values since team has no products + assert days_left == DEFAULT_KEY_DURATION # 30 days + assert max_spend == DEFAULT_MAX_SPEND # 20.0 + assert rpm_limit == DEFAULT_RPM_PER_KEY # 500 + +def test_get_token_restrictions_with_product(db, test_team, test_product): + """Test getting token restrictions when team has a product""" + # Add product to team + team_product = DBTeamProduct( + team_id=test_team.id, + product_id=test_product.id + ) + db.add(team_product) + db.commit() + + days_left, max_spend, rpm_limit = get_token_restrictions(db, test_team.id) + + # Should use product values + assert days_left == test_product.renewal_period_days # 30 days + assert max_spend == test_product.max_budget_per_key # 50.0 + assert rpm_limit == test_product.rpm_per_key # 1000 + +def test_get_token_restrictions_with_multiple_products(db, test_team): + """Test getting token restrictions when team has multiple products with different limits""" + # Create two products with different limits + product1 = DBProduct( + id="prod_test1", + name="Test Product 1", + user_count=3, + keys_per_user=2, + total_key_count=10, + service_key_count=2, + max_budget_per_key=50.0, + rpm_per_key=1000, + vector_db_count=1, + vector_db_storage=100, + renewal_period_days=30, + active=True, + created_at=datetime.now(UTC) + ) + product2 = DBProduct( + id="prod_test2", + name="Test Product 2", + user_count=3, + keys_per_user=2, + total_key_count=10, + service_key_count=2, + max_budget_per_key=75.0, # Higher budget + rpm_per_key=2000, # Higher RPM + vector_db_count=1, + vector_db_storage=100, + renewal_period_days=60, # Longer duration + active=True, + created_at=datetime.now(UTC) + ) + db.add(product1) + db.add(product2) + db.commit() + + # Add both products to team + team_product1 = DBTeamProduct( + team_id=test_team.id, + product_id=product1.id + ) + team_product2 = DBTeamProduct( + team_id=test_team.id, + product_id=product2.id + ) + db.add(team_product1) + db.add(team_product2) + db.commit() + + days_left, max_spend, rpm_limit = get_token_restrictions(db, test_team.id) + + # Should use the maximum values from both products + assert days_left == product2.renewal_period_days # 60 days + assert max_spend == product2.max_budget_per_key # 75.0 + assert rpm_limit == product2.rpm_per_key # 2000 + +def test_get_token_restrictions_with_payment_history(db, test_team, test_product): + """Test getting token restrictions when team has payment history""" + # Add product to team + team_product = DBTeamProduct( + team_id=test_team.id, + product_id=test_product.id + ) + db.add(team_product) + db.commit() + + # Set created_at to 30 days ago and last_payment to 15 days ago + now = datetime.now(UTC) + test_team.created_at = now - timedelta(days=30) + test_team.last_payment = now - timedelta(days=15) + db.commit() + + days_left, max_spend, rpm_limit = get_token_restrictions(db, test_team.id) + + # Should have 15 days left (30 - 15) + assert days_left == 15 + assert max_spend == test_product.max_budget_per_key + assert rpm_limit == test_product.rpm_per_key + +def test_get_token_restrictions_team_not_found(db): + """Test getting token restrictions for non-existent team""" + from app.core.resource_limits import get_token_restrictions + from fastapi import HTTPException + + with pytest.raises(HTTPException) as exc_info: + get_token_restrictions(db, 99999) # Non-existent team ID + assert exc_info.value.status_code == 404 + assert "Team not found" in str(exc_info.value.detail) diff --git a/tests/test_worker.py b/tests/test_worker.py index e03dbc6..78c6ec2 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -148,7 +148,7 @@ async def test_apply_product_extends_keys_and_sets_budget(mock_litellm, db, test """ Test that applying a product extends keys and sets max budget correctly. - GIVEN: A team with users and keys (both team-owned and user-owned), and a product which specifies a max_budget of $20 per key + GIVEN: A team with users and keys (both team-owned and user-owned), and a product which specifies a max_budget of $50 per key with a renewal period of 30 days WHEN: The product is applied to the team THEN: All keys for the team and users in the team are extended and the max_budget is set correctly @@ -194,8 +194,7 @@ async def test_apply_product_extends_keys_and_sets_budget(mock_litellm, db, test # Setup mock instance mock_instance = mock_litellm.return_value - mock_instance.update_key_duration = AsyncMock() - mock_instance.update_budget = AsyncMock() + mock_instance.set_key_restrictions = AsyncMock() # Apply product to team await apply_product_for_team(db, test_team.stripe_customer_id, test_product.id) @@ -208,23 +207,18 @@ async def test_apply_product_extends_keys_and_sets_budget(mock_litellm, db, test # Verify LiteLLM service was called for all keys (both team and user owned) all_keys = team_keys + user_keys - assert mock_instance.update_key_duration.call_count == len(all_keys) - assert mock_instance.update_budget.call_count == len(all_keys) + assert mock_instance.set_key_restrictions.call_count == len(all_keys) # Verify each key was updated with correct duration and budget for key in all_keys: - # Verify duration update - duration_calls = [call for call in mock_instance.update_key_duration.call_args_list + # Verify key restrictions update + restriction_calls = [call for call in mock_instance.set_key_restrictions.call_args_list if call[1]['litellm_token'] == key.litellm_token] - assert len(duration_calls) == 1 - assert duration_calls[0][1]['duration'] == f"{test_product.renewal_period_days}d" - - # Verify budget update - budget_calls = [call for call in mock_instance.update_budget.call_args_list - if call[1]['litellm_token'] == key.litellm_token] - assert len(budget_calls) == 1 - assert budget_calls[0][1]['budget_duration'] == f"{test_product.renewal_period_days}d" - assert budget_calls[0][1]['budget_amount'] == test_product.max_budget_per_key + assert len(restriction_calls) == 1 + assert restriction_calls[0][1]['duration'] == f"{test_product.renewal_period_days}d" + assert restriction_calls[0][1]['budget_duration'] == f"{test_product.renewal_period_days}d" + assert restriction_calls[0][1]['budget_amount'] == test_product.max_budget_per_key + assert restriction_calls[0][1]['rpm_limit'] == test_product.rpm_per_key # Verify team was updated correctly db.refresh(test_team) From 91218f80877cd6a082a2e0ba6e38609d64d56533 Mon Sep 17 00:00:00 2001 From: Pippa H Date: Fri, 30 May 2025 10:01:10 +0200 Subject: [PATCH 18/23] Add frontend for purchase from pricing table Also update event management and improve billing APIs --- app/api/billing.py | 26 ++++-- app/core/resource_limits.py | 8 +- app/schemas/models.py | 4 + app/services/stripe.py | 28 +++---- frontend/src/app/admin/products/page.tsx | 2 +- frontend/src/app/team-admin/pricing/page.tsx | 86 ++++++++++++++++++++ frontend/src/components/sidebar-layout.tsx | 3 +- tests/test_billing.py | 26 +++--- 8 files changed, 144 insertions(+), 39 deletions(-) create mode 100644 frontend/src/app/team-admin/pricing/page.tsx diff --git a/app/api/billing.py b/app/api/billing.py index 31a1ec0..18985b8 100644 --- a/app/api/billing.py +++ b/app/api/billing.py @@ -5,7 +5,7 @@ from app.db.database import get_db from app.core.security import check_specific_team_admin from app.db.models import DBTeam, DBSystemSecret -from app.schemas.models import CheckoutSessionCreate +from app.schemas.models import CheckoutSessionCreate, PricingTableSession from app.services.stripe import ( create_checkout_session, decode_stripe_event, @@ -13,8 +13,7 @@ get_product_id_from_subscription, get_product_id_from_session, create_stripe_customer, - get_pricing_table_session, - get_customer_from_pi + get_pricing_table_secret, ) from app.core.worker import apply_product_for_team, remove_product_from_team @@ -52,6 +51,11 @@ async def checkout( ) try: + if not team.stripe_customer_id: + team.stripe_customer_id = await create_stripe_customer(team) + db.add(team) + db.commit() + # Get the frontend URL from environment frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000") @@ -201,7 +205,9 @@ async def get_portal( try: # Create Stripe customer if one doesn't exist if not team.stripe_customer_id: - team.stripe_customer_id = await create_stripe_customer(team, db) + team.stripe_customer_id = await create_stripe_customer(team) + db.add(team) + db.commit() # Get the frontend URL from environment frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000") @@ -221,7 +227,7 @@ async def get_portal( detail="Error creating portal session" ) -@router.get("/teams/{team_id}/pricing-table-session", dependencies=[Depends(check_specific_team_admin)]) +@router.get("/teams/{team_id}/pricing-table-session", dependencies=[Depends(check_specific_team_admin)], response_model=PricingTableSession) async def get_pricing_table_session( team_id: int, db: Session = Depends(get_db) @@ -247,12 +253,16 @@ async def get_pricing_table_session( try: # Create Stripe customer if one doesn't exist if not team.stripe_customer_id: - team.stripe_customer_id = await create_stripe_customer(team, db) + logger.info(f"Creating Stripe customer for team {team.id}") + team.stripe_customer_id = await create_stripe_customer(team) + db.add(team) + db.commit() + logger.info(f"Stripe ID is {team.stripe_customer_id}") # Create customer session using the service - client_secret = await get_pricing_table_session(team.stripe_customer_id) + client_secret = await get_pricing_table_secret(team.stripe_customer_id) - return {"client_secret": client_secret} + return PricingTableSession(client_secret=client_secret) except Exception as e: logger.error(f"Error creating customer session: {str(e)}") raise HTTPException( diff --git a/app/core/resource_limits.py b/app/core/resource_limits.py index b21180f..daf81a7 100644 --- a/app/core/resource_limits.py +++ b/app/core/resource_limits.py @@ -179,7 +179,7 @@ def get_token_restrictions(db: Session, team_id: int) -> tuple[int, float, int]: if team.last_payment is None: days_left_in_period = max_key_duration else: - days_left_in_period = max_key_duration - (datetime.now(UTC) - max(team.created_at, team.last_payment)).days + days_left_in_period = max_key_duration - (datetime.now(UTC) - max(team.created_at.replace(tzinfo=UTC), team.last_payment.replace(tzinfo=UTC))).days max_max_spend = max( (product.max_budget_per_key for team_product in team.active_products for product in [team_product.product] if product.max_budget_per_key), @@ -191,4 +191,8 @@ def get_token_restrictions(db: Session, team_id: int) -> tuple[int, float, int]: default=DEFAULT_RPM_PER_KEY ) - return days_left_in_period, max_max_spend, max_rpm_limit \ No newline at end of file + return days_left_in_period, max_max_spend, max_rpm_limit + +def get_team_limits(db: Session, team_id: int): + # TODO: Go through all products, and create a master list of the limits on all fields for this team. + pass \ No newline at end of file diff --git a/app/schemas/models.py b/app/schemas/models.py index 42c6a3d..4b7da8f 100644 --- a/app/schemas/models.py +++ b/app/schemas/models.py @@ -307,4 +307,8 @@ class ProductUpdate(BaseModel): class Product(ProductBase): created_at: datetime updated_at: Optional[datetime] = None + model_config = ConfigDict(from_attributes=True) + +class PricingTableSession(BaseModel): + client_secret: str model_config = ConfigDict(from_attributes=True) \ No newline at end of file diff --git a/app/services/stripe.py b/app/services/stripe.py index 66429a5..1690ff2 100644 --- a/app/services/stripe.py +++ b/app/services/stripe.py @@ -15,7 +15,9 @@ stripe.api_key = os.getenv("STRIPE_SECRET_KEY") async def create_checkout_session( - team: DBTeam, + team_name: str, + admin_email: str, + team_id: int, price_lookup_token: str, frontend_url: str ) -> str: @@ -48,18 +50,18 @@ async def create_checkout_session( # Create the checkout session checkout_session = stripe.checkout.Session.create( - customer_email=team.admin_email, - success_url=f"{frontend_url}/teams/{team.id}/dashboard?session_id={{CHECKOUT_SESSION_ID}}", - cancel_url=f"{frontend_url}/teams/{team.id}/pricing", + customer_email=admin_email, + success_url=f"{frontend_url}/teams/{team_id}/dashboard?session_id={{CHECKOUT_SESSION_ID}}", + cancel_url=f"{frontend_url}/teams/{team_id}/pricing", mode="subscription", line_items=[{ "price": subscription_price.id, "quantity": 1, }], metadata={ - "team_id": team.id, - "team_name": team.name, - "admin_email": team.admin_email + "team_id": team_id, + "team_name": team_name, + "admin_email": admin_email } ) @@ -200,15 +202,13 @@ async def setup_stripe_webhook(webhook_key: str, webhook_route: str, db: Session ) async def create_stripe_customer( - team: DBTeam, - db: Session + team: DBTeam ) -> str: """ - Create a Stripe customer for a team and save the customer ID. + Create a Stripe customer for a team. Args: team: The team to create a Stripe customer for - db: Database session Returns: str: The Stripe customer ID @@ -231,10 +231,6 @@ async def create_stripe_customer( } ) - # Save customer ID to team - team.stripe_customer_id = customer.id - db.commit() - return customer.id except Exception as e: @@ -286,7 +282,7 @@ async def get_customer_from_pi(payment_intent: str) -> str: logger.info(f"Payment intent is:\n{payment_intent}") return payment_intent.customer -async def get_pricing_table_session(customer_id: str) -> str: +async def get_pricing_table_secret(customer_id: str) -> str: """ Create a Stripe Customer Session client secret for a customer. diff --git a/frontend/src/app/admin/products/page.tsx b/frontend/src/app/admin/products/page.tsx index 24929b3..2399499 100644 --- a/frontend/src/app/admin/products/page.tsx +++ b/frontend/src/app/admin/products/page.tsx @@ -53,7 +53,7 @@ export default function ProductsPage() { }; // Queries - const { data: products = [], isLoading } = useQuery({ + const { data: products = [] } = useQuery({ queryKey: ['products'], queryFn: async () => { const response = await get('/products'); diff --git a/frontend/src/app/team-admin/pricing/page.tsx b/frontend/src/app/team-admin/pricing/page.tsx new file mode 100644 index 0000000..0cbd0b1 --- /dev/null +++ b/frontend/src/app/team-admin/pricing/page.tsx @@ -0,0 +1,86 @@ +'use client'; + +import { useEffect, useState } from 'react'; +import { useAuth } from '@/hooks/use-auth'; +import { get, post } from '@/utils/api'; +import Script from 'next/script'; + +declare module 'react' { + interface HTMLAttributes extends AriaAttributes, DOMAttributes { + 'pricing-table-id'?: string; + 'publishable-key'?: string; + 'customer-session-client-secret'?: string; + } +} + +declare module 'react/jsx-runtime' { + interface Element { + 'stripe-pricing-table': any; + } +} + +declare global { + interface HTMLElementTagNameMap { + 'stripe-pricing-table': HTMLElement; + } +} + +export default function PricingPage() { + const { user } = useAuth(); + const [clientSecret, setClientSecret] = useState(null); + const [error, setError] = useState(null); + + useEffect(() => { + const fetchSessionToken = async () => { + try { + if (!user?.team_id) return; + const response = await get(`/billing/teams/${user.team_id}/pricing-table-session`); + const data = await response.json(); + setClientSecret(data.client_secret); + } catch (err) { + setError('Failed to load pricing table. Please try again later.'); + console.error('Error fetching pricing table session:', err); + } + }; + + fetchSessionToken(); + }, [user?.team_id]); + + const handleManageSubscription = async () => { + try { + const response = await post(`/billing/teams/${user?.team_id}/portal`, {}); + if (response.redirected) { + window.location.href = response.url; + } + } catch (error) { + console.error('Error accessing portal:', error); + } + }; + + if (error) { + return
{error}
; + } + + return ( +
+
+

Subscription Plans

+ +
+