From 2fb8b4ef503d18c32aaeba8323f6cd70d11ca9b3 Mon Sep 17 00:00:00 2001 From: Pippa H Date: Fri, 9 May 2025 14:20:28 +0200 Subject: [PATCH 1/9] Add Prometheus and Grafana integration for metrics monitoring - Updated docker-compose.yml to include Prometheus and Grafana services. - Added environment variable ENABLE_METRICS to enable Prometheus metrics in the backend. - Implemented Prometheus middleware for FastAPI to track HTTP requests and audit events. - Enhanced audit logging to record metrics for events and durations. - Created Prometheus configuration file for scraping metrics from the FastAPI application. - Added Grafana provisioning files for dashboards and data sources to visualize metrics. - Updated requirements.txt to include Prometheus client libraries. --- app/main.py | 23 ++ app/middleware/audit.py | 26 +- app/middleware/prometheus.py | 108 +++++ docker-compose.yml | 49 ++- grafana/provisioning/dashboards/fastapi.json | 375 ++++++++++++++++++ grafana/provisioning/dashboards/fastapi.yml | 14 + .../provisioning/datasources/prometheus.yml | 14 + prometheus/prometheus.yml | 10 + requirements.txt | 2 + 9 files changed, 618 insertions(+), 3 deletions(-) create mode 100644 app/middleware/prometheus.py create mode 100644 grafana/provisioning/dashboards/fastapi.json create mode 100644 grafana/provisioning/dashboards/fastapi.yml create mode 100644 grafana/provisioning/datasources/prometheus.yml create mode 100644 prometheus/prometheus.yml diff --git a/app/main.py b/app/main.py index 85c47d0..64aef2e 100644 --- a/app/main.py +++ b/app/main.py @@ -6,6 +6,7 @@ from starlette.middleware.base import BaseHTTPMiddleware from fastapi.openapi.docs import get_swagger_ui_html from fastapi.openapi.utils import get_openapi +from prometheus_fastapi_instrumentator import Instrumentator, metrics import os import logging @@ -27,6 +28,7 @@ async def dispatch(self, request, call_next): from app.core.config import settings from app.db.database import get_db from app.middleware.audit import AuditLogMiddleware +from app.middleware.prometheus import PrometheusMiddleware app = FastAPI( title="Private AI Keys as a Service", @@ -79,6 +81,9 @@ async def dispatch(self, request, call_next): # Add HTTPS redirect middleware first app.add_middleware(HTTPSRedirectMiddleware) +# Add Prometheus middleware +app.add_middleware(PrometheusMiddleware) + # Configure CORS app.add_middleware( CORSMiddleware, @@ -96,6 +101,24 @@ async def dispatch(self, request, call_next): app.add_middleware(AuditLogMiddleware, db=next(get_db())) +# Setup Prometheus instrumentation +instrumentator = Instrumentator( + should_group_status_codes=False, + should_ignore_untemplated=True, + should_respect_env_var=True, + should_instrument_requests_inprogress=True, + excluded_handlers=["/metrics"], + env_var_name="ENABLE_METRICS", + inprogress_name="fastapi_inprogress", + inprogress_labels=True, +) + +# Add default metrics +instrumentator.add(metrics.default()) + +# Instrument the app +instrumentator.instrument(app).expose(app) + @app.get("/health") async def health_check(): return {"status": "healthy"} diff --git a/app/middleware/audit.py b/app/middleware/audit.py index 95173a9..7ea4a32 100644 --- a/app/middleware/audit.py +++ b/app/middleware/audit.py @@ -4,8 +4,10 @@ from app.db.models import DBAuditLog from app.api.auth import get_current_user_from_auth from app.db.database import get_db +from app.middleware.prometheus import audit_events_total, audit_event_duration_seconds import json import logging +import time from fastapi import Cookie, Header from typing import Optional @@ -18,9 +20,11 @@ def __init__(self, app, db: Session): async def dispatch(self, request: Request, call_next): # Skip audit logging for certain paths - if request.url.path in ["/health", "/docs", "/openapi.json", "/audit/logs", "/auth/me"]: + if request.url.path in ["/health", "/docs", "/openapi.json", "/audit/logs", "/auth/me", "/metrics"]: return await call_next(request) + start_time = time.time() + # Get the response response = await call_next(request) @@ -69,11 +73,14 @@ async def dispatch(self, request: Request, call_next): # If no origin/referer and has auth header, likely direct API call request_source = "api" if auth_header else None + # Get resource type from path + resource_type = request.url.path.split("/")[1] # First path segment + # Create audit log entry audit_log = DBAuditLog( user_id=user_id, event_type=request.method, - resource_type=request.url.path.split("/")[1], # First path segment + resource_type=resource_type, resource_id=str(resource_id) if resource_id else None, action=f"{request.method} {request.url.path}", details={ @@ -89,6 +96,21 @@ async def dispatch(self, request: Request, call_next): db.add(audit_log) db.commit() + # Record audit metrics + audit_events_total.labels( + event_type=request.method, + resource_type=resource_type, + request_source=request_source or "unknown", + status_code=response.status_code + ).inc() + + # Record audit event duration + duration = time.time() - start_time + audit_event_duration_seconds.labels( + event_type=request.method, + resource_type=resource_type + ).observe(duration) + except Exception as e: logger.error(f"Failed to create audit log: {str(e)}", exc_info=True) # Don't re-raise the exception - we don't want to break the request if audit logging fails diff --git a/app/middleware/prometheus.py b/app/middleware/prometheus.py new file mode 100644 index 0000000..fd89859 --- /dev/null +++ b/app/middleware/prometheus.py @@ -0,0 +1,108 @@ +from prometheus_client import Counter, Histogram, Gauge +from prometheus_fastapi_instrumentator import Instrumentator, metrics +from prometheus_fastapi_instrumentator.metrics import Info +from fastapi import Request +from starlette.middleware.base import BaseHTTPMiddleware +import time +import logging + +logger = logging.getLogger(__name__) + +# RED Metrics +http_requests_total = Counter( + "http_requests_total", + "Total number of HTTP requests", + ["method", "endpoint", "status_code"] +) + +http_request_duration_seconds = Histogram( + "http_request_duration_seconds", + "HTTP request duration in seconds", + ["method", "endpoint"], + buckets=(0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 30.0, 60.0, float("inf")) +) + +# Audit Metrics +audit_events_total = Counter( + "audit_events_total", + "Total number of audit events", + ["event_type", "resource_type", "request_source", "status_code"] +) + +audit_event_duration_seconds = Histogram( + "audit_event_duration_seconds", + "Audit event processing duration in seconds", + ["event_type", "resource_type"], + buckets=(0.01, 0.05, 0.1, 0.5, 1.0, 2.0, 5.0, float("inf")) +) + +# User Metrics +requests_per_user = Counter( + "requests_per_user", + "Number of requests per user", + ["user_id", "method", "endpoint"] +) + +class PrometheusMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + # Skip metrics for certain paths + if request.url.path in ["/metrics", "/health", "/docs", "/openapi.json"]: + return await call_next(request) + + start_time = time.time() + + # Process the request + response = await call_next(request) + + # Calculate duration + duration = time.time() - start_time + + # Record RED metrics + http_requests_total.labels( + method=request.method, + endpoint=request.url.path, + status_code=response.status_code + ).inc() + + http_request_duration_seconds.labels( + method=request.method, + endpoint=request.url.path + ).observe(duration) + + # Get user ID from request if available + user_id = None + try: + # Get access token from cookie or authorization header + cookies = request.cookies + headers = request.headers + access_token = cookies.get("access_token") + auth_header = headers.get("authorization") + + if auth_header: + parts = auth_header.split() + if len(parts) == 2 and parts[0].lower() == "bearer": + access_token = parts[1] + + if access_token: + from app.api.auth import get_current_user_from_auth + from app.db.database import get_db + db = next(get_db()) + user = await get_current_user_from_auth( + access_token=access_token if access_token else None, + authorization=auth_header if auth_header else None, + db=db + ) + user_id = str(user.id) if user else "anonymous" + db.close() + except Exception as e: + logger.debug(f"Could not get user for metrics: {str(e)}") + user_id = "anonymous" + + # Record requests per user + requests_per_user.labels( + user_id=user_id, + method=request.method, + endpoint=request.url.path + ).inc() + + return response \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml index 9d7be83..6b0ea51 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -25,6 +25,7 @@ services: DATABASE_URL: postgresql://postgres:postgres@postgres/postgres_service SECRET_KEY: "dKq2BK3pqGQfNqC7SK8ZxNCdqJnGV4F9" # More secure key for development ENV_SUFFIX: "local" + ENABLE_METRICS: "true" # Enable Prometheus metrics ports: - "8800:8800" volumes: @@ -88,8 +89,54 @@ services: labels: lagoon.type: none + prometheus: + image: prom/prometheus:latest + ports: + - "9090:9090" + volumes: + - ./prometheus:/etc/prometheus + - prometheus_data:/prometheus + command: + - '--config.file=/etc/prometheus/prometheus.yml' + - '--storage.tsdb.path=/prometheus' + - '--web.console.libraries=/usr/share/prometheus/console_libraries' + - '--web.console.templates=/usr/share/prometheus/consoles' + depends_on: + - backend + labels: + lagoon.type: none + + grafana: + image: grafana/grafana:latest + ports: + - "3001:3000" + volumes: + - ./grafana/provisioning:/etc/grafana/provisioning + - grafana_data:/var/lib/grafana + environment: + - GF_SECURITY_ADMIN_USER=admin + - GF_SECURITY_ADMIN_PASSWORD=admin123 + - GF_USERS_ALLOW_SIGN_UP=false + - GF_AUTH_ANONYMOUS_ENABLED=false + - GF_AUTH_BASIC_ENABLED=true + - GF_AUTH_DISABLE_LOGIN_FORM=false + - GF_AUTH_DISABLE_SIGNOUT_MENU=false + - GF_INSTALL_PLUGINS=grafana-piechart-panel + depends_on: + prometheus: + condition: service_started + healthcheck: + test: ["CMD-SHELL", "wget --spider --quiet http://localhost:3000/api/health || exit 1"] + interval: 10s + timeout: 5s + retries: 5 + start_period: 30s + labels: + lagoon.type: none volumes: postgres_data: litellm_postgres_data: - name: litellm_postgres_data # Named volume for Postgres data persistence \ No newline at end of file + name: litellm_postgres_data # Named volume for Postgres data persistence + prometheus_data: + grafana_data: \ No newline at end of file diff --git a/grafana/provisioning/dashboards/fastapi.json b/grafana/provisioning/dashboards/fastapi.json new file mode 100644 index 0000000..50781d8 --- /dev/null +++ b/grafana/provisioning/dashboards/fastapi.json @@ -0,0 +1,375 @@ +{ + "annotations": { + "list": [] + }, + "editable": true, + "fiscalYearStartMonth": 0, + "graphTooltip": 0, + "links": [], + "liveNow": false, + "panels": [ + { + "datasource": { + "type": "prometheus", + "uid": "prometheus" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 10, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "never", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + } + ] + }, + "unit": "short" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 0 + }, + "id": 1, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "prometheus" + }, + "expr": "rate(http_requests_total[5m])", + "legendFormat": "{{method}} {{endpoint}}", + "refId": "A" + } + ], + "title": "Request Rate", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "prometheus" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 10, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "never", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + } + ] + }, + "unit": "s" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 0 + }, + "id": 2, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "prometheus" + }, + "expr": "rate(http_request_duration_seconds_sum[5m]) / rate(http_request_duration_seconds_count[5m])", + "legendFormat": "{{method}} {{endpoint}}", + "refId": "A" + } + ], + "title": "Request Duration", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "prometheus" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 10, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "never", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 8 + }, + "id": 3, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "prometheus" + }, + "expr": "rate(audit_events_total[5m])", + "legendFormat": "{{event_type}} {{resource_type}}", + "refId": "A" + } + ], + "title": "Audit Events Rate", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "prometheus" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 10, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "never", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 8 + }, + "id": 4, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "prometheus" + }, + "expr": "rate(requests_per_user_total[5m])", + "legendFormat": "User {{user_id}} - {{method}} {{endpoint}}", + "refId": "A" + } + ], + "title": "Requests per User", + "type": "timeseries" + } + ], + "refresh": "5s", + "schemaVersion": 38, + "style": "dark", + "tags": [], + "templating": { + "list": [] + }, + "time": { + "from": "now-1h", + "to": "now" + }, + "timepicker": {}, + "timezone": "", + "title": "FastAPI Dashboard", + "uid": "fastapi", + "version": 1, + "weekStart": "" +} \ No newline at end of file diff --git a/grafana/provisioning/dashboards/fastapi.yml b/grafana/provisioning/dashboards/fastapi.yml new file mode 100644 index 0000000..9ac0b64 --- /dev/null +++ b/grafana/provisioning/dashboards/fastapi.yml @@ -0,0 +1,14 @@ +apiVersion: 1 + +providers: + - name: 'FastAPI' + orgId: 1 + folder: '' + type: file + disableDeletion: false + editable: true + updateIntervalSeconds: 10 + allowUiUpdates: true + options: + path: /etc/grafana/provisioning/dashboards + foldersFromFilesStructure: true \ No newline at end of file diff --git a/grafana/provisioning/datasources/prometheus.yml b/grafana/provisioning/datasources/prometheus.yml new file mode 100644 index 0000000..ea52740 --- /dev/null +++ b/grafana/provisioning/datasources/prometheus.yml @@ -0,0 +1,14 @@ +apiVersion: 1 + +datasources: + - name: Prometheus + type: prometheus + access: proxy + url: http://prometheus:9090 + isDefault: true + editable: false + uid: prometheus + jsonData: + timeInterval: 15s + queryTimeout: 30s + httpMethod: GET \ No newline at end of file diff --git a/prometheus/prometheus.yml b/prometheus/prometheus.yml new file mode 100644 index 0000000..0c0021a --- /dev/null +++ b/prometheus/prometheus.yml @@ -0,0 +1,10 @@ +global: + scrape_interval: 15s + evaluation_interval: 15s + +scrape_configs: + - job_name: 'fastapi' + static_configs: + - targets: ['backend:8800'] + metrics_path: '/metrics' + scheme: 'http' \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 4101a22..6a85b14 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,3 +14,5 @@ alembic==1.15.2 boto3==1.35.1 markdown==3.5.2 email-validator==2.1.2 +prometheus-client==0.19.0 +prometheus-fastapi-instrumentator==6.1.0 From 3ed5c8837eb700b1fe79341da2df6a6945d6c9e9 Mon Sep 17 00:00:00 2001 From: Pippa H Date: Mon, 12 May 2025 09:05:29 +0200 Subject: [PATCH 2/9] Enhance authentication flow and metrics tracking - Updated docker-compose.yml to include an env_file for environment variables. - Refactored login and registration endpoints in auth.py to utilize a new metrics tracking function for authentication attempts. - Introduced a new auth.py module for tracking authentication metrics with Prometheus. - Simplified error handling and improved user feedback during login and registration processes. - Updated Grafana dashboard configuration to visualize authentication metrics. --- app/api/auth.py | 206 +++++-------- app/metrics/auth.py | 35 +++ app/middleware/prometheus.py | 4 +- docker-compose.yml | 17 +- grafana/provisioning/dashboards/fastapi.json | 307 ++++++++++++++++++- 5 files changed, 406 insertions(+), 163 deletions(-) create mode 100644 app/metrics/auth.py diff --git a/app/api/auth.py b/app/api/auth.py index 301831c..f99e235 100644 --- a/app/api/auth.py +++ b/app/api/auth.py @@ -36,6 +36,7 @@ from app.services.dynamodb import DynamoDBService from app.services.ses import SESService from app.api.teams import register_team +from app.metrics.auth import track_auth_attempt router = APIRouter( tags=["Authentication"] @@ -127,43 +128,19 @@ def create_and_set_access_token(response: Response, user_email: str) -> Token: return {"access_token": access_token, "token_type": "bearer"} -@router.post("/login", response_model=Token) -async def login( - request: Request, - response: Response, - login_data: Optional[LoginData] = Depends(get_login_data), - db: Session = Depends(get_db) -): - """ - Login to get access to the API. - - Accepts both application/x-www-form-urlencoded and application/json formats. - - Form data: - - **username**: Your email address - - **password**: Your password - - JSON data: - - **username**: Your email address - - **password**: Your password - - On successful login, an access token will be set as an HTTP-only cookie and also returned in the response. - Use this token for subsequent authenticated requests. - """ - if not login_data: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Invalid login data. Please provide username and password in either form data or JSON format." - ) - - user = db.query(DBUser).filter(DBUser.email == login_data.username).first() - if not user or not verify_password(login_data.password, user.hashed_password): - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Incorrect email or password" - ) +@router.post("/login") +async def login(request: Request, user_data: LoginData, db: Session = Depends(get_db)): + try: + user = db.query(DBUser).filter(DBUser.email == user_data.username).first() + if not user or not verify_password(user_data.password, user.hashed_password): + await track_auth_attempt(request, user_data.username, "failure") + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect email or password") - return create_and_set_access_token(response, user.email) + await track_auth_attempt(request, user_data.username, "success") + return create_and_set_access_token(request.response, user.email) + except Exception as e: + await track_auth_attempt(request, user_data.username, "failure") + raise e @router.post("/logout") async def logout(response: Response): @@ -242,35 +219,28 @@ async def update_user_me( db.refresh(current_user) return current_user -@router.post("/register", response_model=User) -async def register(user: UserCreate, db: Session = Depends(get_db)): - """ - Register a new user account. - - - **email**: Your email address - - **password**: A secure password (minimum 8 characters) - - After registration, you'll need to login to get an access token. - """ - # Check if user with this email exists - db_user = db.query(DBUser).filter(DBUser.email == user.email).first() - if db_user: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Email already registered" +@router.post("/register") +async def register(request: Request, user_data: UserCreate, db: Session = Depends(get_db)): + try: + db_user = db.query(DBUser).filter(DBUser.email == user_data.email).first() + if db_user: + await track_auth_attempt(request, user_data.email, "failure") + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Email already registered") + + hashed_password = get_password_hash(user_data.password) + db_user = DBUser( + email=user_data.email, + hashed_password=hashed_password, + is_admin=False # Force is_admin to be False for all new registrations ) - - # Create new user - hashed_password = get_password_hash(user.password) - db_user = DBUser( - email=user.email, - hashed_password=hashed_password, - is_admin=False # Force is_admin to be False for all new registrations - ) - db.add(db_user) - db.commit() - db.refresh(db_user) - return db_user + db.add(db_user) + db.commit() + db.refresh(db_user) + await track_auth_attempt(request, user_data.email, "success") + return db_user + except Exception as e: + await track_auth_attempt(request, user_data.email, "failure") + raise e @router.post("/validate-email") async def validate_email( @@ -410,71 +380,47 @@ async def delete_token( db.commit() return {"message": "Token deleted successfully"} -@router.post("/sign-in", response_model=Token) -async def sign_in( - request: Request, - response: Response, - sign_in_data: Optional[SignInData] = Depends(get_sign_in_data), - db: Session = Depends(get_db) -): - """ - Sign in using a verification code instead of a password. - - Accepts both application/x-www-form-urlencoded and application/json formats. - - Form data: - - **username**: Your email address - - **verification_code**: The verification code sent to your email - - JSON data: - - **username**: Your email address - - **verification_code**: The verification code sent to your email - - On successful sign in, an access token will be set as an HTTP-only cookie and also returned in the response. - Use this token for subsequent authenticated requests. - - If the user doesn't exist, they will be automatically registered and a new team will be created - with them as the admin. - """ - if not sign_in_data: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Invalid sign in data. Please provide username and verification code in either form data or JSON format." - ) - - # Verify the code using DynamoDB first - dynamodb_service = DynamoDBService() - stored_code = dynamodb_service.read_validation_code(sign_in_data.username) - - if not stored_code or stored_code.get('code').upper() != sign_in_data.verification_code.upper(): - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Incorrect email or verification code" - ) +@router.post("/sign-in") +async def sign_in(request: Request, user_data: SignInData, db: Session = Depends(get_db)): + try: + # Verify the code using DynamoDB first + dynamodb_service = DynamoDBService() + stored_code = dynamodb_service.read_validation_code(user_data.username) - # Get user from database after verifying the code - user = db.query(DBUser).filter(DBUser.email == sign_in_data.username).first() - - # If user doesn't exist, create a new user and team - if not user: - # First create the team - team_data = TeamCreate( - name=f"Team {sign_in_data.username}", - admin_email=sign_in_data.username, - phone="", # Required by schema but not used for auto-created teams - billing_address="" # Required by schema but not used for auto-created teams - ) - team = await register_team(team_data, db) - - # Create new user without password since they're using verification code - user = DBUser( - email=sign_in_data.username, - hashed_password="", # Empty password since they'll use verification code - role="admin", # Set role to admin for new users - team_id=team.id # Associate user with the team - ) - db.add(user) - db.commit() - db.refresh(user) + if not stored_code or stored_code.get('code').upper() != user_data.verification_code.upper(): + await track_auth_attempt(request, user_data.username, "failure") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect email or verification code" + ) - return create_and_set_access_token(response, user.email) \ No newline at end of file + # Get user from database after verifying the code + user = db.query(DBUser).filter(DBUser.email == user_data.username).first() + + # If user doesn't exist, create a new user and team + if not user: + # First create the team + team_data = TeamCreate( + name=f"Team {user_data.username}", + admin_email=user_data.username, + phone="", # Required by schema but not used for auto-created teams + billing_address="" # Required by schema but not used for auto-created teams + ) + team = await register_team(team_data, db) + + # Create new user without password since they're using verification code + user = DBUser( + email=user_data.username, + hashed_password="", # Empty password since they'll use verification code + role="admin", # Set role to admin for new users + team_id=team.id # Associate user with the team + ) + db.add(user) + db.commit() + db.refresh(user) + + await track_auth_attempt(request, user_data.username, "success") + return create_and_set_access_token(request.response, user.email) + except Exception as e: + await track_auth_attempt(request, user_data.username, "failure") + raise e diff --git a/app/metrics/auth.py b/app/metrics/auth.py new file mode 100644 index 0000000..71c9c5c --- /dev/null +++ b/app/metrics/auth.py @@ -0,0 +1,35 @@ +from prometheus_client import Counter +from fastapi import Request, Depends +import logging + +logger = logging.getLogger(__name__) + +# Auth Metrics +auth_attempts_total = Counter( + "auth_attempts_total", + "Total number of authentication attempts", + ["action", "email", "status"] +) + +async def track_auth_attempt(request: Request, email: str, status: str): + """Track authentication attempts using Prometheus metrics.""" + try: + # Map the endpoint to the appropriate action + action_map = { + "/auth/login": "login", + "/auth/register": "register", + "/auth/validate-email": "validate_email", + "/auth/sign-in": "sign_in" + } + action = action_map.get(request.url.path, "unknown") + + # Log the metric increment + logger.info(f"Incrementing auth_attempts_total for {action} - {email} - {status}") + + auth_attempts_total.labels( + action=action, + email=email, + status=status + ).inc() + except Exception as e: + logger.error(f"Could not track auth attempt: {str(e)}", exc_info=True) \ No newline at end of file diff --git a/app/middleware/prometheus.py b/app/middleware/prometheus.py index fd89859..7fa8f9e 100644 --- a/app/middleware/prometheus.py +++ b/app/middleware/prometheus.py @@ -3,6 +3,8 @@ from prometheus_fastapi_instrumentator.metrics import Info from fastapi import Request from starlette.middleware.base import BaseHTTPMiddleware +from app.core.security import get_current_user_from_auth +from app.db.database import get_db import time import logging @@ -84,8 +86,6 @@ async def dispatch(self, request: Request, call_next): access_token = parts[1] if access_token: - from app.api.auth import get_current_user_from_auth - from app.db.database import get_db db = next(get_db()) user = await get_current_user_from_auth( access_token=access_token if access_token else None, diff --git a/docker-compose.yml b/docker-compose.yml index 6b0ea51..69cb3e4 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -24,7 +24,6 @@ services: environment: DATABASE_URL: postgresql://postgres:postgres@postgres/postgres_service SECRET_KEY: "dKq2BK3pqGQfNqC7SK8ZxNCdqJnGV4F9" # More secure key for development - ENV_SUFFIX: "local" ENABLE_METRICS: "true" # Enable Prometheus metrics ports: - "8800:8800" @@ -115,22 +114,10 @@ services: - grafana_data:/var/lib/grafana environment: - GF_SECURITY_ADMIN_USER=admin - - GF_SECURITY_ADMIN_PASSWORD=admin123 + - GF_SECURITY_ADMIN_PASSWORD=admin - GF_USERS_ALLOW_SIGN_UP=false - - GF_AUTH_ANONYMOUS_ENABLED=false - - GF_AUTH_BASIC_ENABLED=true - - GF_AUTH_DISABLE_LOGIN_FORM=false - - GF_AUTH_DISABLE_SIGNOUT_MENU=false - - GF_INSTALL_PLUGINS=grafana-piechart-panel depends_on: - prometheus: - condition: service_started - healthcheck: - test: ["CMD-SHELL", "wget --spider --quiet http://localhost:3000/api/health || exit 1"] - interval: 10s - timeout: 5s - retries: 5 - start_period: 30s + - prometheus labels: lagoon.type: none diff --git a/grafana/provisioning/dashboards/fastapi.json b/grafana/provisioning/dashboards/fastapi.json index 50781d8..255e50b 100644 --- a/grafana/provisioning/dashboards/fastapi.json +++ b/grafana/provisioning/dashboards/fastapi.json @@ -71,10 +71,12 @@ "id": 1, "options": { "legend": { - "calcs": [], - "displayMode": "list", - "placement": "bottom", - "showLegend": true + "calcs": ["max"], + "displayMode": "table", + "placement": "right", + "showLegend": true, + "sortBy": "Max", + "sortDesc": true }, "tooltip": { "mode": "single", @@ -158,10 +160,12 @@ "id": 2, "options": { "legend": { - "calcs": [], - "displayMode": "list", - "placement": "bottom", - "showLegend": true + "calcs": ["max"], + "displayMode": "table", + "placement": "right", + "showLegend": true, + "sortBy": "Max", + "sortDesc": true }, "tooltip": { "mode": "single", @@ -244,10 +248,12 @@ "id": 3, "options": { "legend": { - "calcs": [], - "displayMode": "list", - "placement": "bottom", - "showLegend": true + "calcs": ["max"], + "displayMode": "table", + "placement": "right", + "showLegend": true, + "sortBy": "Max", + "sortDesc": true }, "tooltip": { "mode": "single", @@ -330,10 +336,12 @@ "id": 4, "options": { "legend": { - "calcs": [], - "displayMode": "list", - "placement": "bottom", - "showLegend": true + "calcs": ["max"], + "displayMode": "table", + "placement": "right", + "showLegend": true, + "sortBy": "Max", + "sortDesc": true }, "tooltip": { "mode": "single", @@ -353,6 +361,273 @@ ], "title": "Requests per User", "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "prometheus" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 10, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "never", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + } + ] + }, + "unit": "short" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 16 + }, + "id": 5, + "options": { + "legend": { + "calcs": ["max"], + "displayMode": "table", + "placement": "right", + "showLegend": true, + "sortBy": "Max", + "sortDesc": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "prometheus" + }, + "expr": "rate(http_requests_total{status_code=~\"4..\"}[5m])", + "legendFormat": "{{method}} {{endpoint}} - {{status_code}}", + "refId": "A" + } + ], + "title": "4xx Client Error Rate", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "prometheus" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 10, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "never", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + } + ] + }, + "unit": "short" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 16 + }, + "id": 6, + "options": { + "legend": { + "calcs": ["max"], + "displayMode": "table", + "placement": "right", + "showLegend": true, + "sortBy": "Max", + "sortDesc": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "prometheus" + }, + "expr": "rate(http_requests_total{status_code=~\"5..\"}[5m])", + "legendFormat": "{{method}} {{endpoint}} - {{status_code}}", + "refId": "A" + } + ], + "title": "5xx Server Error Rate", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "prometheus" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 10, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "never", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + } + ] + }, + "unit": "short" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 24, + "x": 0, + "y": 24 + }, + "id": 7, + "options": { + "legend": { + "calcs": ["max"], + "displayMode": "table", + "placement": "right", + "showLegend": true, + "sortBy": "Max", + "sortDesc": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "prometheus" + }, + "expr": "rate(auth_attempts_total[5m])", + "legendFormat": "{{action}} - {{endpoint}} ({{status}})", + "refId": "A" + } + ], + "title": "Authentication Attempts", + "type": "timeseries" } ], "refresh": "5s", From 7d49a1126ff438f80c34550a9fe8955d352f2021 Mon Sep 17 00:00:00 2001 From: Pippa H Date: Tue, 13 May 2025 11:09:54 +0200 Subject: [PATCH 3/9] Refactor authentication endpoints and improve user feedback - Updated the login and registration endpoints in auth.py to enhance user experience with detailed documentation and error handling. - Removed metrics tracking for authentication attempts to streamline the login process. - Added support for both form data and JSON formats in login and registration requests. - Improved the sign-in process to automatically register users and create teams if they do not exist. - Enhanced error messages for invalid login and sign-in data. --- app/api/auth.py | 206 ++++++++++++++++++++++++++++++------------------ 1 file changed, 130 insertions(+), 76 deletions(-) diff --git a/app/api/auth.py b/app/api/auth.py index f99e235..301831c 100644 --- a/app/api/auth.py +++ b/app/api/auth.py @@ -36,7 +36,6 @@ from app.services.dynamodb import DynamoDBService from app.services.ses import SESService from app.api.teams import register_team -from app.metrics.auth import track_auth_attempt router = APIRouter( tags=["Authentication"] @@ -128,19 +127,43 @@ def create_and_set_access_token(response: Response, user_email: str) -> Token: return {"access_token": access_token, "token_type": "bearer"} -@router.post("/login") -async def login(request: Request, user_data: LoginData, db: Session = Depends(get_db)): - try: - user = db.query(DBUser).filter(DBUser.email == user_data.username).first() - if not user or not verify_password(user_data.password, user.hashed_password): - await track_auth_attempt(request, user_data.username, "failure") - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect email or password") +@router.post("/login", response_model=Token) +async def login( + request: Request, + response: Response, + login_data: Optional[LoginData] = Depends(get_login_data), + db: Session = Depends(get_db) +): + """ + Login to get access to the API. + + Accepts both application/x-www-form-urlencoded and application/json formats. + + Form data: + - **username**: Your email address + - **password**: Your password + + JSON data: + - **username**: Your email address + - **password**: Your password + + On successful login, an access token will be set as an HTTP-only cookie and also returned in the response. + Use this token for subsequent authenticated requests. + """ + if not login_data: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid login data. Please provide username and password in either form data or JSON format." + ) + + user = db.query(DBUser).filter(DBUser.email == login_data.username).first() + if not user or not verify_password(login_data.password, user.hashed_password): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect email or password" + ) - await track_auth_attempt(request, user_data.username, "success") - return create_and_set_access_token(request.response, user.email) - except Exception as e: - await track_auth_attempt(request, user_data.username, "failure") - raise e + return create_and_set_access_token(response, user.email) @router.post("/logout") async def logout(response: Response): @@ -219,28 +242,35 @@ async def update_user_me( db.refresh(current_user) return current_user -@router.post("/register") -async def register(request: Request, user_data: UserCreate, db: Session = Depends(get_db)): - try: - db_user = db.query(DBUser).filter(DBUser.email == user_data.email).first() - if db_user: - await track_auth_attempt(request, user_data.email, "failure") - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Email already registered") - - hashed_password = get_password_hash(user_data.password) - db_user = DBUser( - email=user_data.email, - hashed_password=hashed_password, - is_admin=False # Force is_admin to be False for all new registrations +@router.post("/register", response_model=User) +async def register(user: UserCreate, db: Session = Depends(get_db)): + """ + Register a new user account. + + - **email**: Your email address + - **password**: A secure password (minimum 8 characters) + + After registration, you'll need to login to get an access token. + """ + # Check if user with this email exists + db_user = db.query(DBUser).filter(DBUser.email == user.email).first() + if db_user: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Email already registered" ) - db.add(db_user) - db.commit() - db.refresh(db_user) - await track_auth_attempt(request, user_data.email, "success") - return db_user - except Exception as e: - await track_auth_attempt(request, user_data.email, "failure") - raise e + + # Create new user + hashed_password = get_password_hash(user.password) + db_user = DBUser( + email=user.email, + hashed_password=hashed_password, + is_admin=False # Force is_admin to be False for all new registrations + ) + db.add(db_user) + db.commit() + db.refresh(db_user) + return db_user @router.post("/validate-email") async def validate_email( @@ -380,47 +410,71 @@ async def delete_token( db.commit() return {"message": "Token deleted successfully"} -@router.post("/sign-in") -async def sign_in(request: Request, user_data: SignInData, db: Session = Depends(get_db)): - try: - # Verify the code using DynamoDB first - dynamodb_service = DynamoDBService() - stored_code = dynamodb_service.read_validation_code(user_data.username) +@router.post("/sign-in", response_model=Token) +async def sign_in( + request: Request, + response: Response, + sign_in_data: Optional[SignInData] = Depends(get_sign_in_data), + db: Session = Depends(get_db) +): + """ + Sign in using a verification code instead of a password. - if not stored_code or stored_code.get('code').upper() != user_data.verification_code.upper(): - await track_auth_attempt(request, user_data.username, "failure") - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Incorrect email or verification code" - ) + Accepts both application/x-www-form-urlencoded and application/json formats. - # Get user from database after verifying the code - user = db.query(DBUser).filter(DBUser.email == user_data.username).first() - - # If user doesn't exist, create a new user and team - if not user: - # First create the team - team_data = TeamCreate( - name=f"Team {user_data.username}", - admin_email=user_data.username, - phone="", # Required by schema but not used for auto-created teams - billing_address="" # Required by schema but not used for auto-created teams - ) - team = await register_team(team_data, db) - - # Create new user without password since they're using verification code - user = DBUser( - email=user_data.username, - hashed_password="", # Empty password since they'll use verification code - role="admin", # Set role to admin for new users - team_id=team.id # Associate user with the team - ) - db.add(user) - db.commit() - db.refresh(user) - - await track_auth_attempt(request, user_data.username, "success") - return create_and_set_access_token(request.response, user.email) - except Exception as e: - await track_auth_attempt(request, user_data.username, "failure") - raise e + Form data: + - **username**: Your email address + - **verification_code**: The verification code sent to your email + + JSON data: + - **username**: Your email address + - **verification_code**: The verification code sent to your email + + On successful sign in, an access token will be set as an HTTP-only cookie and also returned in the response. + Use this token for subsequent authenticated requests. + + If the user doesn't exist, they will be automatically registered and a new team will be created + with them as the admin. + """ + if not sign_in_data: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid sign in data. Please provide username and verification code in either form data or JSON format." + ) + + # Verify the code using DynamoDB first + dynamodb_service = DynamoDBService() + stored_code = dynamodb_service.read_validation_code(sign_in_data.username) + + if not stored_code or stored_code.get('code').upper() != sign_in_data.verification_code.upper(): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect email or verification code" + ) + + # Get user from database after verifying the code + user = db.query(DBUser).filter(DBUser.email == sign_in_data.username).first() + + # If user doesn't exist, create a new user and team + if not user: + # First create the team + team_data = TeamCreate( + name=f"Team {sign_in_data.username}", + admin_email=sign_in_data.username, + phone="", # Required by schema but not used for auto-created teams + billing_address="" # Required by schema but not used for auto-created teams + ) + team = await register_team(team_data, db) + + # Create new user without password since they're using verification code + user = DBUser( + email=sign_in_data.username, + hashed_password="", # Empty password since they'll use verification code + role="admin", # Set role to admin for new users + team_id=team.id # Associate user with the team + ) + db.add(user) + db.commit() + db.refresh(user) + + return create_and_set_access_token(response, user.email) \ No newline at end of file From dd11dde0e8c89ef1ca30c12fc2abc790fd66ba24 Mon Sep 17 00:00:00 2001 From: Pippa H Date: Tue, 13 May 2025 13:28:28 +0200 Subject: [PATCH 4/9] Implement authentication request tracking and update metrics - Added a new function to track authentication requests using Prometheus metrics in auth.py. - Updated login, registration, and email validation endpoints to call the tracking function. - Refactored Prometheus middleware to handle error responses and record metrics accurately. - Modified Grafana dashboard configuration to reflect changes in authentication metrics tracking. --- app/api/auth.py | 50 +++++++++++++++- app/main.py | 4 +- app/middleware/prometheus.py | 62 +++++++++++++++----- grafana/provisioning/dashboards/fastapi.json | 8 +-- 4 files changed, 100 insertions(+), 24 deletions(-) diff --git a/app/api/auth.py b/app/api/auth.py index 301831c..07beb4e 100644 --- a/app/api/auth.py +++ b/app/api/auth.py @@ -36,6 +36,38 @@ from app.services.dynamodb import DynamoDBService from app.services.ses import SESService from app.api.teams import register_team +from app.middleware.prometheus import auth_requests_total + +def track_auth_request(request: Request, identifier: Optional[str] = None) -> None: + """ + Track authentication requests using Prometheus metrics. + + Args: + request: The FastAPI request object + identifier: Optional identifier (username/email) for the request + """ + # Get the path from the request + path = request.url.path + + # If no identifier provided, try to extract from request + if not identifier: + try: + # Try to get from form data + form_data = request.form() + identifier = form_data.get("username") or form_data.get("email") + except: + try: + # Try to get from JSON body + body = request.json() + identifier = body.get("username") or body.get("email") + except: + identifier = "unknown" + + # Increment the counter + auth_requests_total.labels( + endpoint=path, + identifier=identifier or "unknown" + ).inc() router = APIRouter( tags=["Authentication"] @@ -156,6 +188,9 @@ async def login( detail="Invalid login data. Please provide username and password in either form data or JSON format." ) + # Track the auth request + track_auth_request(request, login_data.username) + user = db.query(DBUser).filter(DBUser.email == login_data.username).first() if not user or not verify_password(login_data.password, user.hashed_password): raise HTTPException( @@ -243,7 +278,11 @@ async def update_user_me( return current_user @router.post("/register", response_model=User) -async def register(user: UserCreate, db: Session = Depends(get_db)): +async def register( + request: Request, + user: UserCreate, + db: Session = Depends(get_db) +): """ Register a new user account. @@ -252,6 +291,9 @@ async def register(user: UserCreate, db: Session = Depends(get_db)): After registration, you'll need to login to get an access token. """ + # Track the auth request + track_auth_request(request, user.email) + # Check if user with this email exists db_user = db.query(DBUser).filter(DBUser.email == user.email).first() if db_user: @@ -309,6 +351,9 @@ async def validate_email( detail="Email is required" ) + # Track the auth request + track_auth_request(request, email) + try: email_validator.validate_email(email, check_deliverability=False) except EmailNotValidError as e: @@ -442,6 +487,9 @@ async def sign_in( detail="Invalid sign in data. Please provide username and verification code in either form data or JSON format." ) + # Track the auth request + track_auth_request(request, sign_in_data.username) + # Verify the code using DynamoDB first dynamodb_service = DynamoDBService() stored_code = dynamodb_service.read_validation_code(sign_in_data.username) diff --git a/app/main.py b/app/main.py index 64aef2e..fb882d9 100644 --- a/app/main.py +++ b/app/main.py @@ -1,6 +1,4 @@ -from fastapi import FastAPI, Depends, HTTPException -from fastapi.security import OAuth2PasswordBearer -from sqlalchemy.orm import Session +from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.trustedhost import TrustedHostMiddleware from starlette.middleware.base import BaseHTTPMiddleware diff --git a/app/middleware/prometheus.py b/app/middleware/prometheus.py index 7fa8f9e..6da3c86 100644 --- a/app/middleware/prometheus.py +++ b/app/middleware/prometheus.py @@ -1,7 +1,7 @@ from prometheus_client import Counter, Histogram, Gauge from prometheus_fastapi_instrumentator import Instrumentator, metrics from prometheus_fastapi_instrumentator.metrics import Info -from fastapi import Request +from fastapi import Request, HTTPException, status from starlette.middleware.base import BaseHTTPMiddleware from app.core.security import get_current_user_from_auth from app.db.database import get_db @@ -45,6 +45,13 @@ ["user_id", "method", "endpoint"] ) +# Auth Metrics +auth_requests_total = Counter( + "auth_requests_total", + "Total number of authentication requests", + ["endpoint", "identifier"] +) + class PrometheusMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): # Skip metrics for certain paths @@ -52,18 +59,33 @@ async def dispatch(self, request: Request, call_next): return await call_next(request) start_time = time.time() + response = None + duration = 0 + is_error = False - # Process the request - response = await call_next(request) - - # Calculate duration - duration = time.time() - start_time + try: + # Process the request + response = await call_next(request) + duration = time.time() - start_time + except Exception as e: + logger.warning(f"Request failed: {e}") + # Record the actual duration of the failed request + duration = time.time() - start_time + # Capture the error response to be raised later + is_error = True + if not isinstance(e, HTTPException): + error_response = HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=str(e) + ) + else: + error_response = e # Record RED metrics http_requests_total.labels( method=request.method, endpoint=request.url.path, - status_code=response.status_code + status_code=response.status_code if response else 500 ).inc() http_request_duration_seconds.labels( @@ -86,14 +108,19 @@ async def dispatch(self, request: Request, call_next): access_token = parts[1] if access_token: - db = next(get_db()) - user = await get_current_user_from_auth( - access_token=access_token if access_token else None, - authorization=auth_header if auth_header else None, - db=db - ) - user_id = str(user.id) if user else "anonymous" - db.close() + try: + db = next(get_db()) + user = await get_current_user_from_auth( + access_token=access_token if access_token else None, + authorization=auth_header if auth_header else None, + db=db + ) + user_id = str(user.id) if user else "anonymous" + except Exception as e: + logger.debug(f"Could not get user for metrics: {str(e)}") + user_id = "anonymous" + finally: + db.close() except Exception as e: logger.debug(f"Could not get user for metrics: {str(e)}") user_id = "anonymous" @@ -105,4 +132,7 @@ async def dispatch(self, request: Request, call_next): endpoint=request.url.path ).inc() - return response \ No newline at end of file + if is_error: + raise error_response + else: + return response diff --git a/grafana/provisioning/dashboards/fastapi.json b/grafana/provisioning/dashboards/fastapi.json index 255e50b..41f4c47 100644 --- a/grafana/provisioning/dashboards/fastapi.json +++ b/grafana/provisioning/dashboards/fastapi.json @@ -621,8 +621,8 @@ "type": "prometheus", "uid": "prometheus" }, - "expr": "rate(auth_attempts_total[5m])", - "legendFormat": "{{action}} - {{endpoint}} ({{status}})", + "expr": "rate(auth_requests_total[5m])", + "legendFormat": "{{endpoint}} - {{identifier}} ({{status_code}})", "refId": "A" } ], @@ -643,8 +643,8 @@ }, "timepicker": {}, "timezone": "", - "title": "FastAPI Dashboard", - "uid": "fastapi", + "title": "amazee.ai Backend", + "uid": "amazeeai-backend", "version": 1, "weekStart": "" } \ No newline at end of file From 46230230941cf09afeb2eec7a881431ed006d9ec Mon Sep 17 00:00:00 2001 From: Pippa H Date: Tue, 13 May 2025 13:54:13 +0200 Subject: [PATCH 5/9] Delete superfluous auth tracking code --- app/metrics/auth.py | 35 ----------------------------------- 1 file changed, 35 deletions(-) delete mode 100644 app/metrics/auth.py diff --git a/app/metrics/auth.py b/app/metrics/auth.py deleted file mode 100644 index 71c9c5c..0000000 --- a/app/metrics/auth.py +++ /dev/null @@ -1,35 +0,0 @@ -from prometheus_client import Counter -from fastapi import Request, Depends -import logging - -logger = logging.getLogger(__name__) - -# Auth Metrics -auth_attempts_total = Counter( - "auth_attempts_total", - "Total number of authentication attempts", - ["action", "email", "status"] -) - -async def track_auth_attempt(request: Request, email: str, status: str): - """Track authentication attempts using Prometheus metrics.""" - try: - # Map the endpoint to the appropriate action - action_map = { - "/auth/login": "login", - "/auth/register": "register", - "/auth/validate-email": "validate_email", - "/auth/sign-in": "sign_in" - } - action = action_map.get(request.url.path, "unknown") - - # Log the metric increment - logger.info(f"Incrementing auth_attempts_total for {action} - {email} - {status}") - - auth_attempts_total.labels( - action=action, - email=email, - status=status - ).inc() - except Exception as e: - logger.error(f"Could not track auth attempt: {str(e)}", exc_info=True) \ No newline at end of file From f1369598db7ab9f861ed81e13c2a04e6ce42c577 Mon Sep 17 00:00:00 2001 From: Pippa H Date: Wed, 14 May 2025 11:51:46 +0200 Subject: [PATCH 6/9] Implement logging for authentication events and update metrics tracking - Added a logging mechanism for authentication events in auth.py, including login attempts, registration, and email validation. - Configured a TimedRotatingFileHandler to manage log files for authentication events. - Updated Prometheus middleware to track authentication request metrics with success and failure statuses. - Enhanced Grafana dashboard configuration to reflect changes in authentication metrics and improve visualization. - Cleaned up unused code related to previous metrics tracking. --- .gitignore | 3 + app/api/auth.py | 93 ++++---- app/middleware/prometheus.py | 108 +++------ grafana/provisioning/dashboards/fastapi.json | 231 +++++++++---------- grafana/provisioning/dashboards/fastapi.yml | 2 +- 5 files changed, 191 insertions(+), 246 deletions(-) diff --git a/.gitignore b/.gitignore index 7af86f2..d135cdb 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,9 @@ npm-debug.log* yarn-debug.log* yarn-error.log* +**/logs/** +*.log +*.log.* # Environment variables .env diff --git a/app/api/auth.py b/app/api/auth.py index 07beb4e..bc0632f 100644 --- a/app/api/auth.py +++ b/app/api/auth.py @@ -4,11 +4,38 @@ from email_validator import validate_email, EmailNotValidError from sqlalchemy.orm import Session import logging +from logging.handlers import TimedRotatingFileHandler import secrets import os from urllib.parse import urlparse +from pathlib import Path + +# Configure auth logger +auth_logger = logging.getLogger("auth") +auth_logger.setLevel(logging.INFO) + +# Create logs directory if it doesn't exist +log_dir = Path("logs") +log_dir.mkdir(exist_ok=True) + +# Configure file handler with daily rotation +file_handler = TimedRotatingFileHandler( + filename=log_dir / "auth.log", + when="midnight", + interval=1, + backupCount=30, # Keep logs for 30 days + encoding="utf-8" +) + +# Configure formatter +formatter = logging.Formatter( + "%(asctime)s - %(levelname)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S" +) +file_handler.setFormatter(formatter) -logger = logging.getLogger(__name__) +# Add handler to logger +auth_logger.addHandler(file_handler) from app.db.database import get_db from app.schemas.models import ( @@ -36,38 +63,6 @@ from app.services.dynamodb import DynamoDBService from app.services.ses import SESService from app.api.teams import register_team -from app.middleware.prometheus import auth_requests_total - -def track_auth_request(request: Request, identifier: Optional[str] = None) -> None: - """ - Track authentication requests using Prometheus metrics. - - Args: - request: The FastAPI request object - identifier: Optional identifier (username/email) for the request - """ - # Get the path from the request - path = request.url.path - - # If no identifier provided, try to extract from request - if not identifier: - try: - # Try to get from form data - form_data = request.form() - identifier = form_data.get("username") or form_data.get("email") - except: - try: - # Try to get from JSON body - body = request.json() - identifier = body.get("username") or body.get("email") - except: - identifier = "unknown" - - # Increment the counter - auth_requests_total.labels( - endpoint=path, - identifier=identifier or "unknown" - ).inc() router = APIRouter( tags=["Authentication"] @@ -188,16 +183,16 @@ async def login( detail="Invalid login data. Please provide username and password in either form data or JSON format." ) - # Track the auth request - track_auth_request(request, login_data.username) - + auth_logger.info(f"Login attempt for user: {login_data.username}") user = db.query(DBUser).filter(DBUser.email == login_data.username).first() if not user or not verify_password(login_data.password, user.hashed_password): + auth_logger.warning(f"Failed login attempt for user: {login_data.username}") raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect email or password" ) + auth_logger.info(f"Successful login for user: {login_data.username}") return create_and_set_access_token(response, user.email) @router.post("/logout") @@ -291,12 +286,11 @@ async def register( After registration, you'll need to login to get an access token. """ - # Track the auth request - track_auth_request(request, user.email) - + auth_logger.info(f"Registration attempt for user: {user.email}") # Check if user with this email exists db_user = db.query(DBUser).filter(DBUser.email == user.email).first() if db_user: + auth_logger.warning(f"Registration failed - email already exists: {user.email}") raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Email already registered" @@ -312,6 +306,7 @@ async def register( db.add(db_user) db.commit() db.refresh(db_user) + auth_logger.info(f"Successfully registered new user: {user.email}") return db_user @router.post("/validate-email") @@ -346,17 +341,17 @@ async def validate_email( pass if not email: + auth_logger.warning("Email validation attempt with no email provided") raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Email is required" ) - # Track the auth request - track_auth_request(request, email) - + auth_logger.info(f"Email validation attempt for: {email}") try: email_validator.validate_email(email, check_deliverability=False) except EmailNotValidError as e: + auth_logger.warning(f"Invalid email format for {email}: {e}") raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid email format: {e}" @@ -367,8 +362,10 @@ async def validate_email( user = db.query(DBUser).filter(DBUser.email == email).first() if user: email_template = 'returning-user-code' + auth_logger.info(f"Sending validation code to existing user: {email}") else: email_template = 'new-user-code' + auth_logger.info(f"Sending validation code to new user: {email}") # Send the validation code via email ses_service = SESService() @@ -381,12 +378,13 @@ async def validate_email( ) if not email_sent: - logger.error(f"Failed to send validation code email to {email}") + auth_logger.error(f"Failed to send validation code email to {email}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to send validation code email" ) + auth_logger.info(f"Successfully sent validation code to: {email}") return { "message": "Validation code has been generated and sent" } @@ -482,19 +480,19 @@ async def sign_in( with them as the admin. """ if not sign_in_data: + auth_logger.warning("Sign-in attempt with invalid data format") raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid sign in data. Please provide username and verification code in either form data or JSON format." ) - # Track the auth request - track_auth_request(request, sign_in_data.username) - + auth_logger.info(f"Sign-in attempt for user: {sign_in_data.username}") # Verify the code using DynamoDB first dynamodb_service = DynamoDBService() stored_code = dynamodb_service.read_validation_code(sign_in_data.username) if not stored_code or stored_code.get('code').upper() != sign_in_data.verification_code.upper(): + auth_logger.warning(f"Invalid verification code for user: {sign_in_data.username}") raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect email or verification code" @@ -505,6 +503,7 @@ async def sign_in( # If user doesn't exist, create a new user and team if not user: + auth_logger.info(f"Creating new user and team for: {sign_in_data.username}") # First create the team team_data = TeamCreate( name=f"Team {sign_in_data.username}", @@ -524,5 +523,7 @@ async def sign_in( db.add(user) db.commit() db.refresh(user) + auth_logger.info(f"Successfully created new user and team for: {sign_in_data.username}") + auth_logger.info(f"Successful sign-in for user: {sign_in_data.username}") return create_and_set_access_token(response, user.email) \ No newline at end of file diff --git a/app/middleware/prometheus.py b/app/middleware/prometheus.py index 6da3c86..ebd3626 100644 --- a/app/middleware/prometheus.py +++ b/app/middleware/prometheus.py @@ -1,28 +1,15 @@ -from prometheus_client import Counter, Histogram, Gauge -from prometheus_fastapi_instrumentator import Instrumentator, metrics -from prometheus_fastapi_instrumentator.metrics import Info -from fastapi import Request, HTTPException, status +from prometheus_client import Counter, Histogram +from fastapi import Request from starlette.middleware.base import BaseHTTPMiddleware from app.core.security import get_current_user_from_auth from app.db.database import get_db -import time import logging logger = logging.getLogger(__name__) -# RED Metrics -http_requests_total = Counter( - "http_requests_total", - "Total number of HTTP requests", - ["method", "endpoint", "status_code"] -) - -http_request_duration_seconds = Histogram( - "http_request_duration_seconds", - "HTTP request duration in seconds", - ["method", "endpoint"], - buckets=(0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 30.0, 60.0, float("inf")) -) +def normalize_path(path: str) -> str: + """Replace numeric segments in path with {id}.""" + return '/'.join('{id}' if segment.isdigit() else segment for segment in path.split('/')) # Audit Metrics audit_events_total = Counter( @@ -38,18 +25,18 @@ buckets=(0.01, 0.05, 0.1, 0.5, 1.0, 2.0, 5.0, float("inf")) ) -# User Metrics -requests_per_user = Counter( - "requests_per_user", - "Number of requests per user", - ["user_id", "method", "endpoint"] +# User Metrics - grouped by user type and endpoint +requests_by_user_type = Counter( + "requests_by_user_type", + "Number of requests grouped by user type", + ["user_type", "endpoint", "method"] ) -# Auth Metrics +# Auth Metrics - simplified to track success/failure auth_requests_total = Counter( "auth_requests_total", "Total number of authentication requests", - ["endpoint", "identifier"] + ["endpoint", "status"] # status will be "success" or "failure" ) class PrometheusMiddleware(BaseHTTPMiddleware): @@ -58,43 +45,24 @@ async def dispatch(self, request: Request, call_next): if request.url.path in ["/metrics", "/health", "/docs", "/openapi.json"]: return await call_next(request) - start_time = time.time() - response = None - duration = 0 - is_error = False + # Track auth requests for specific endpoints + is_auth_endpoint = request.url.path in [ + "/auth/login", + "/auth/register", + "/auth/validate-email", + "/auth/sign-in" + ] - try: - # Process the request - response = await call_next(request) - duration = time.time() - start_time - except Exception as e: - logger.warning(f"Request failed: {e}") - # Record the actual duration of the failed request - duration = time.time() - start_time - # Capture the error response to be raised later - is_error = True - if not isinstance(e, HTTPException): - error_response = HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=str(e) - ) - else: - error_response = e - - # Record RED metrics - http_requests_total.labels( - method=request.method, - endpoint=request.url.path, - status_code=response.status_code if response else 500 - ).inc() + response = await call_next(request) - http_request_duration_seconds.labels( - method=request.method, - endpoint=request.url.path - ).observe(duration) + if is_auth_endpoint: + auth_requests_total.labels( + endpoint=request.url.path, + status="success" + ).inc() - # Get user ID from request if available - user_id = None + # Get user type from request if available + user_type = "anonymous" try: # Get access token from cookie or authorization header cookies = request.cookies @@ -115,24 +83,22 @@ async def dispatch(self, request: Request, call_next): authorization=auth_header if auth_header else None, db=db ) - user_id = str(user.id) if user else "anonymous" + if user: + # Group users by their role or type + user_type = user.role if hasattr(user, 'role') else "authenticated" except Exception as e: logger.debug(f"Could not get user for metrics: {str(e)}") - user_id = "anonymous" finally: db.close() except Exception as e: logger.debug(f"Could not get user for metrics: {str(e)}") - user_id = "anonymous" - # Record requests per user - requests_per_user.labels( - user_id=user_id, - method=request.method, - endpoint=request.url.path + # Record requests by user type with normalized path + normalized_path = normalize_path(request.url.path) + requests_by_user_type.labels( + user_type=user_type, + endpoint=normalized_path, + method=request.method ).inc() - if is_error: - raise error_response - else: - return response + return response diff --git a/grafana/provisioning/dashboards/fastapi.json b/grafana/provisioning/dashboards/fastapi.json index 41f4c47..43919ba 100644 --- a/grafana/provisioning/dashboards/fastapi.json +++ b/grafana/provisioning/dashboards/fastapi.json @@ -1,12 +1,25 @@ { "annotations": { - "list": [] + "list": [ + { + "builtIn": 1, + "datasource": { + "type": "grafana", + "uid": "-- Grafana --" + }, + "enable": true, + "hide": true, + "iconColor": "rgba(0, 211, 255, 1)", + "name": "Annotations & Alerts", + "type": "dashboard" + } + ] }, "editable": true, "fiscalYearStartMonth": 0, "graphTooltip": 0, + "id": 4, "links": [], - "liveNow": false, "panels": [ { "datasource": { @@ -19,11 +32,13 @@ "mode": "palette-classic" }, "custom": { + "axisBorderShow": false, "axisCenteredZero": false, "axisColorMode": "text", "axisLabel": "", "axisPlacement": "auto", "barAlignment": 0, + "barWidthFactor": 0.6, "drawStyle": "line", "fillOpacity": 10, "gradientMode": "none", @@ -32,6 +47,7 @@ "tooltip": false, "viz": false }, + "insertNulls": false, "lineInterpolation": "linear", "lineWidth": 1, "pointSize": 5, @@ -53,8 +69,7 @@ "mode": "absolute", "steps": [ { - "color": "green", - "value": null + "color": "green" } ] }, @@ -71,7 +86,9 @@ "id": 1, "options": { "legend": { - "calcs": ["max"], + "calcs": [ + "max" + ], "displayMode": "table", "placement": "right", "showLegend": true, @@ -79,18 +96,22 @@ "sortDesc": true }, "tooltip": { + "hideZeros": false, "mode": "single", "sort": "none" } }, + "pluginVersion": "12.0.0", "targets": [ { "datasource": { "type": "prometheus", "uid": "prometheus" }, + "editorMode": "builder", "expr": "rate(http_requests_total[5m])", - "legendFormat": "{{method}} {{endpoint}}", + "legendFormat": "{{method}} - {{handler}}", + "range": true, "refId": "A" } ], @@ -108,11 +129,13 @@ "mode": "palette-classic" }, "custom": { + "axisBorderShow": false, "axisCenteredZero": false, "axisColorMode": "text", "axisLabel": "", "axisPlacement": "auto", "barAlignment": 0, + "barWidthFactor": 0.6, "drawStyle": "line", "fillOpacity": 10, "gradientMode": "none", @@ -121,6 +144,7 @@ "tooltip": false, "viz": false }, + "insertNulls": false, "lineInterpolation": "linear", "lineWidth": 1, "pointSize": 5, @@ -142,8 +166,7 @@ "mode": "absolute", "steps": [ { - "color": "green", - "value": null + "color": "green" } ] }, @@ -160,7 +183,9 @@ "id": 2, "options": { "legend": { - "calcs": ["max"], + "calcs": [ + "max" + ], "displayMode": "table", "placement": "right", "showLegend": true, @@ -168,18 +193,22 @@ "sortDesc": true }, "tooltip": { + "hideZeros": false, "mode": "single", "sort": "none" } }, + "pluginVersion": "12.0.0", "targets": [ { "datasource": { "type": "prometheus", "uid": "prometheus" }, + "editorMode": "code", "expr": "rate(http_request_duration_seconds_sum[5m]) / rate(http_request_duration_seconds_count[5m])", - "legendFormat": "{{method}} {{endpoint}}", + "legendFormat": "{{method}} - {{handler}}", + "range": true, "refId": "A" } ], @@ -197,11 +226,13 @@ "mode": "palette-classic" }, "custom": { + "axisBorderShow": false, "axisCenteredZero": false, "axisColorMode": "text", "axisLabel": "", "axisPlacement": "auto", "barAlignment": 0, + "barWidthFactor": 0.6, "drawStyle": "line", "fillOpacity": 10, "gradientMode": "none", @@ -210,6 +241,7 @@ "tooltip": false, "viz": false }, + "insertNulls": false, "lineInterpolation": "linear", "lineWidth": 1, "pointSize": 5, @@ -231,8 +263,7 @@ "mode": "absolute", "steps": [ { - "color": "green", - "value": null + "color": "green" } ] } @@ -248,7 +279,9 @@ "id": 3, "options": { "legend": { - "calcs": ["max"], + "calcs": [ + "max" + ], "displayMode": "table", "placement": "right", "showLegend": true, @@ -256,22 +289,24 @@ "sortDesc": true }, "tooltip": { + "hideZeros": false, "mode": "single", "sort": "none" } }, + "pluginVersion": "12.0.0", "targets": [ { "datasource": { "type": "prometheus", "uid": "prometheus" }, - "expr": "rate(audit_events_total[5m])", - "legendFormat": "{{event_type}} {{resource_type}}", + "expr": "rate(requests_by_user_type_total[5m])", + "legendFormat": "{{user_type}} - {{endpoint}}", "refId": "A" } ], - "title": "Audit Events Rate", + "title": "Requests by User Type", "type": "timeseries" }, { @@ -285,11 +320,13 @@ "mode": "palette-classic" }, "custom": { + "axisBorderShow": false, "axisCenteredZero": false, "axisColorMode": "text", "axisLabel": "", "axisPlacement": "auto", "barAlignment": 0, + "barWidthFactor": 0.6, "drawStyle": "line", "fillOpacity": 10, "gradientMode": "none", @@ -298,6 +335,7 @@ "tooltip": false, "viz": false }, + "insertNulls": false, "lineInterpolation": "linear", "lineWidth": 1, "pointSize": 5, @@ -319,8 +357,7 @@ "mode": "absolute", "steps": [ { - "color": "green", - "value": null + "color": "green" } ] } @@ -336,7 +373,9 @@ "id": 4, "options": { "legend": { - "calcs": ["max"], + "calcs": [ + "max" + ], "displayMode": "table", "placement": "right", "showLegend": true, @@ -344,22 +383,24 @@ "sortDesc": true }, "tooltip": { + "hideZeros": false, "mode": "single", "sort": "none" } }, + "pluginVersion": "12.0.0", "targets": [ { "datasource": { "type": "prometheus", "uid": "prometheus" }, - "expr": "rate(requests_per_user_total[5m])", - "legendFormat": "User {{user_id}} - {{method}} {{endpoint}}", + "expr": "rate(auth_requests_total[5m])", + "legendFormat": "{{endpoint}} - {{status}}", "refId": "A" } ], - "title": "Requests per User", + "title": "Auth Request Rate", "type": "timeseries" }, { @@ -373,11 +414,13 @@ "mode": "palette-classic" }, "custom": { + "axisBorderShow": false, "axisCenteredZero": false, "axisColorMode": "text", "axisLabel": "", "axisPlacement": "auto", "barAlignment": 0, + "barWidthFactor": 0.6, "drawStyle": "line", "fillOpacity": 10, "gradientMode": "none", @@ -386,6 +429,7 @@ "tooltip": false, "viz": false }, + "insertNulls": false, "lineInterpolation": "linear", "lineWidth": 1, "pointSize": 5, @@ -407,8 +451,7 @@ "mode": "absolute", "steps": [ { - "color": "green", - "value": null + "color": "green" } ] }, @@ -425,7 +468,9 @@ "id": 5, "options": { "legend": { - "calcs": ["max"], + "calcs": [ + "max" + ], "displayMode": "table", "placement": "right", "showLegend": true, @@ -433,19 +478,27 @@ "sortDesc": true }, "tooltip": { + "hideZeros": false, "mode": "single", "sort": "none" } }, + "pluginVersion": "12.0.0", "targets": [ { "datasource": { "type": "prometheus", "uid": "prometheus" }, - "expr": "rate(http_requests_total{status_code=~\"4..\"}[5m])", - "legendFormat": "{{method}} {{endpoint}} - {{status_code}}", - "refId": "A" + "disableTextWrap": false, + "editorMode": "builder", + "expr": "rate(http_requests_total{status=~\"4..\"}[5m])", + "fullMetaSearch": false, + "includeNullMetadata": true, + "legendFormat": "{{method}} {{handler}} - {{status}}", + "range": true, + "refId": "A", + "useBackend": false } ], "title": "4xx Client Error Rate", @@ -462,11 +515,13 @@ "mode": "palette-classic" }, "custom": { + "axisBorderShow": false, "axisCenteredZero": false, "axisColorMode": "text", "axisLabel": "", "axisPlacement": "auto", "barAlignment": 0, + "barWidthFactor": 0.6, "drawStyle": "line", "fillOpacity": 10, "gradientMode": "none", @@ -475,6 +530,7 @@ "tooltip": false, "viz": false }, + "insertNulls": false, "lineInterpolation": "linear", "lineWidth": 1, "pointSize": 5, @@ -496,8 +552,7 @@ "mode": "absolute", "steps": [ { - "color": "green", - "value": null + "color": "green" } ] }, @@ -514,7 +569,9 @@ "id": 6, "options": { "legend": { - "calcs": ["max"], + "calcs": [ + "max" + ], "displayMode": "table", "placement": "right", "showLegend": true, @@ -522,117 +579,36 @@ "sortDesc": true }, "tooltip": { + "hideZeros": false, "mode": "single", "sort": "none" } }, + "pluginVersion": "12.0.0", "targets": [ { "datasource": { "type": "prometheus", "uid": "prometheus" }, - "expr": "rate(http_requests_total{status_code=~\"5..\"}[5m])", - "legendFormat": "{{method}} {{endpoint}} - {{status_code}}", - "refId": "A" + "disableTextWrap": false, + "editorMode": "builder", + "expr": "rate(http_requests_total{status=~\"5..\"}[5m])", + "fullMetaSearch": false, + "includeNullMetadata": true, + "legendFormat": "{{method}} {{handler}} - {{status}}", + "range": true, + "refId": "A", + "useBackend": false } ], "title": "5xx Server Error Rate", "type": "timeseries" - }, - { - "datasource": { - "type": "prometheus", - "uid": "prometheus" - }, - "fieldConfig": { - "defaults": { - "color": { - "mode": "palette-classic" - }, - "custom": { - "axisCenteredZero": false, - "axisColorMode": "text", - "axisLabel": "", - "axisPlacement": "auto", - "barAlignment": 0, - "drawStyle": "line", - "fillOpacity": 10, - "gradientMode": "none", - "hideFrom": { - "legend": false, - "tooltip": false, - "viz": false - }, - "lineInterpolation": "linear", - "lineWidth": 1, - "pointSize": 5, - "scaleDistribution": { - "type": "linear" - }, - "showPoints": "never", - "spanNulls": false, - "stacking": { - "group": "A", - "mode": "none" - }, - "thresholdsStyle": { - "mode": "off" - } - }, - "mappings": [], - "thresholds": { - "mode": "absolute", - "steps": [ - { - "color": "green", - "value": null - } - ] - }, - "unit": "short" - }, - "overrides": [] - }, - "gridPos": { - "h": 8, - "w": 24, - "x": 0, - "y": 24 - }, - "id": 7, - "options": { - "legend": { - "calcs": ["max"], - "displayMode": "table", - "placement": "right", - "showLegend": true, - "sortBy": "Max", - "sortDesc": true - }, - "tooltip": { - "mode": "single", - "sort": "none" - } - }, - "targets": [ - { - "datasource": { - "type": "prometheus", - "uid": "prometheus" - }, - "expr": "rate(auth_requests_total[5m])", - "legendFormat": "{{endpoint}} - {{identifier}} ({{status_code}})", - "refId": "A" - } - ], - "title": "Authentication Attempts", - "type": "timeseries" } ], - "refresh": "5s", - "schemaVersion": 38, - "style": "dark", + "preload": false, + "refresh": "30s", + "schemaVersion": 41, "tags": [], "templating": { "list": [] @@ -645,6 +621,5 @@ "timezone": "", "title": "amazee.ai Backend", "uid": "amazeeai-backend", - "version": 1, - "weekStart": "" + "version": 6 } \ No newline at end of file diff --git a/grafana/provisioning/dashboards/fastapi.yml b/grafana/provisioning/dashboards/fastapi.yml index 9ac0b64..1db346a 100644 --- a/grafana/provisioning/dashboards/fastapi.yml +++ b/grafana/provisioning/dashboards/fastapi.yml @@ -1,7 +1,7 @@ apiVersion: 1 providers: - - name: 'FastAPI' + - name: 'amazee.ai Backend' orgId: 1 folder: '' type: file From 404f6d3974b19e96cde7ea026b546e17be74539f Mon Sep 17 00:00:00 2001 From: Pippa H Date: Wed, 14 May 2025 12:24:50 +0200 Subject: [PATCH 7/9] Add authentication middleware and refactor user retrieval logic - Introduced AuthMiddleware to handle user authentication and store user data in request state. - Updated existing middleware (AuditLogMiddleware and PrometheusMiddleware) to utilize user data from request state instead of directly querying for user information. - Enhanced get_current_user_from_auth function to check for user in request state, improving efficiency. - Cleaned up unused authentication code in middleware for better maintainability. --- app/core/security.py | 15 ++++++++-- app/main.py | 4 +++ app/middleware/audit.py | 28 ++---------------- app/middleware/auth.py | 55 ++++++++++++++++++++++++++++++++++++ app/middleware/prometheus.py | 34 +++------------------- 5 files changed, 79 insertions(+), 57 deletions(-) create mode 100644 app/middleware/auth.py diff --git a/app/core/security.py b/app/core/security.py index 3db38e0..952d40d 100644 --- a/app/core/security.py +++ b/app/core/security.py @@ -2,7 +2,7 @@ from typing import Optional, Literal, Dict from jose import JWTError, jwt from passlib.context import CryptContext -from fastapi import Depends, HTTPException, status, Cookie, Header +from fastapi import Depends, HTTPException, status, Cookie, Header, Request from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials import logging from app.core.config import settings @@ -83,9 +83,20 @@ async def get_current_user( async def get_current_user_from_auth( access_token: Optional[str] = Cookie(None, alias="access_token"), authorization: Optional[str] = Header(None), - db: Session = Depends(get_db) + db: Session = Depends(get_db), + request: Request = None ) -> DBUser: """Get current user from either JWT token (in cookie or Authorization header) or API token.""" + # First check if user is already in request state (set by AuthMiddleware) + if request and hasattr(request.state, 'user') and request.state.user is not None: + # If we have a dict from middleware, load the full user object + if isinstance(request.state.user, dict): + user = db.query(DBUser).filter(DBUser.id == request.state.user["id"]).first() + if user: + return user + else: + return request.state.user + if not access_token and not authorization: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, diff --git a/app/main.py b/app/main.py index fb882d9..c514898 100644 --- a/app/main.py +++ b/app/main.py @@ -27,6 +27,7 @@ async def dispatch(self, request, call_next): from app.db.database import get_db from app.middleware.audit import AuditLogMiddleware from app.middleware.prometheus import PrometheusMiddleware +from app.middleware.auth import AuthMiddleware app = FastAPI( title="Private AI Keys as a Service", @@ -79,6 +80,9 @@ async def dispatch(self, request, call_next): # Add HTTPS redirect middleware first app.add_middleware(HTTPSRedirectMiddleware) +# Add Auth middleware (must be before Prometheus and Audit middleware) +app.add_middleware(AuthMiddleware) + # Add Prometheus middleware app.add_middleware(PrometheusMiddleware) diff --git a/app/middleware/audit.py b/app/middleware/audit.py index 7ea4a32..ba7308d 100644 --- a/app/middleware/audit.py +++ b/app/middleware/audit.py @@ -32,30 +32,8 @@ async def dispatch(self, request: Request, call_next): # Get a fresh database session for each request db = next(get_db()) - # Try to get the current user from cookies or authorization header - user_id = None - try: - # Get access token from cookie or authorization header - cookies = request.cookies - headers = request.headers - access_token = cookies.get("access_token") - auth_header = headers.get("authorization") - - if auth_header: - parts = auth_header.split() - if len(parts) == 2 and parts[0].lower() == "bearer": - access_token = parts[1] - - if access_token: - user = await get_current_user_from_auth( - access_token=access_token if access_token else None, - authorization=auth_header if auth_header else None, - db=db - ) - user_id = user.id if user else None - except Exception as e: - logger.debug(f"Could not get user for audit log: {str(e)}") - user_id = None + # Get user_id from request state (set by AuthMiddleware) + user_id = request.state.user.id if hasattr(request.state, 'user') and request.state.user else None # Extract path parameters path_params = request.path_params @@ -71,7 +49,7 @@ async def dispatch(self, request: Request, call_next): request_source = "frontend" else: # If no origin/referer and has auth header, likely direct API call - request_source = "api" if auth_header else None + request_source = "api" if request.headers.get("authorization") else None # Get resource type from path resource_type = request.url.path.split("/")[1] # First path segment diff --git a/app/middleware/auth.py b/app/middleware/auth.py new file mode 100644 index 0000000..4272ec5 --- /dev/null +++ b/app/middleware/auth.py @@ -0,0 +1,55 @@ +from fastapi import Request +from starlette.middleware.base import BaseHTTPMiddleware +from app.core.security import get_current_user_from_auth +from app.db.database import get_db +import logging +from typing import Optional, Dict, Any + +logger = logging.getLogger(__name__) + +class AuthMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + # Skip auth for certain paths + if request.url.path in ["/health", "/docs", "/openapi.json", "/metrics"]: + return await call_next(request) + + # Initialize user as None + request.state.user = None + + try: + # Get access token from cookie or authorization header + cookies = request.cookies + headers = request.headers + access_token = cookies.get("access_token") + auth_header = headers.get("authorization") + + if auth_header: + parts = auth_header.split() + if len(parts) == 2 and parts[0].lower() == "bearer": + access_token = parts[1] + + if access_token: + # Get a fresh database session + db = next(get_db()) + try: + user = await get_current_user_from_auth( + access_token=access_token if access_token else None, + authorization=auth_header if auth_header else None, + db=db + ) + # Store essential user data instead of the full SQLAlchemy object + request.state.user = { + "id": user.id, + "email": user.email, + "is_admin": user.is_admin, + "role": user.role, + "team_id": user.team_id + } + except Exception as e: + logger.debug(f"Could not get user for request: {str(e)}") + finally: + db.close() + except Exception as e: + logger.debug(f"Error in auth middleware: {str(e)}") + + return await call_next(request) \ No newline at end of file diff --git a/app/middleware/prometheus.py b/app/middleware/prometheus.py index ebd3626..c454462 100644 --- a/app/middleware/prometheus.py +++ b/app/middleware/prometheus.py @@ -61,37 +61,11 @@ async def dispatch(self, request: Request, call_next): status="success" ).inc() - # Get user type from request if available + # Get user type from request state (set by AuthMiddleware) user_type = "anonymous" - try: - # Get access token from cookie or authorization header - cookies = request.cookies - headers = request.headers - access_token = cookies.get("access_token") - auth_header = headers.get("authorization") - - if auth_header: - parts = auth_header.split() - if len(parts) == 2 and parts[0].lower() == "bearer": - access_token = parts[1] - - if access_token: - try: - db = next(get_db()) - user = await get_current_user_from_auth( - access_token=access_token if access_token else None, - authorization=auth_header if auth_header else None, - db=db - ) - if user: - # Group users by their role or type - user_type = user.role if hasattr(user, 'role') else "authenticated" - except Exception as e: - logger.debug(f"Could not get user for metrics: {str(e)}") - finally: - db.close() - except Exception as e: - logger.debug(f"Could not get user for metrics: {str(e)}") + if hasattr(request.state, 'user') and request.state.user: + # Group users by their role or type + user_type = request.state.user.role if hasattr(request.state.user, 'role') else "authenticated" # Record requests by user type with normalized path normalized_path = normalize_path(request.url.path) From 6d9397dc9eaba4b0a7d2116a6a4b5d7876a92177 Mon Sep 17 00:00:00 2001 From: Pippa H Date: Wed, 14 May 2025 13:00:19 +0200 Subject: [PATCH 8/9] Bug-fix budget update logic - Updated the `update_budget_period` function to retrieve spend information from the `info` key in the response, improving data handling. - Added error logging to capture exceptions during budget period updates, enhancing traceability. - Refactored user ID retrieval in `AuditLogMiddleware` to handle different user data structures, improving robustness. - Introduced a new test for updating budget duration as a team admin, ensuring proper functionality and API interaction. --- app/api/private_ai_keys.py | 7 ++-- app/middleware/audit.py | 7 +++- tests/test_private_ai.py | 73 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 84 insertions(+), 3 deletions(-) diff --git a/app/api/private_ai_keys.py b/app/api/private_ai_keys.py index 1a19706..9264da0 100644 --- a/app/api/private_ai_keys.py +++ b/app/api/private_ai_keys.py @@ -556,14 +556,17 @@ async def update_budget_period( # Get updated spend information spend_data = await litellm_service.get_key_info(private_ai_key.litellm_token) + info = spend_data.get("info", {}) + # Only set default for spend field spend_info = { - "spend": spend_data.get("spend", 0.0), - **spend_data + "spend": info.get("spend", 0.0), + **info } return PrivateAIKeySpend.model_validate(spend_info) except Exception as e: + logger.error(f"Failed to update budget period: {str(e)}", exc_info=True) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to update budget period: {str(e)}" diff --git a/app/middleware/audit.py b/app/middleware/audit.py index ba7308d..f37cb90 100644 --- a/app/middleware/audit.py +++ b/app/middleware/audit.py @@ -33,7 +33,12 @@ async def dispatch(self, request: Request, call_next): db = next(get_db()) # Get user_id from request state (set by AuthMiddleware) - user_id = request.state.user.id if hasattr(request.state, 'user') and request.state.user else None + user_id = None + if hasattr(request.state, 'user') and request.state.user: + if isinstance(request.state.user, dict): + user_id = request.state.user.get('id') + else: + user_id = request.state.user.id # Extract path parameters path_params = request.path_params diff --git a/tests/test_private_ai.py b/tests/test_private_ai.py index 4b72c0a..6a7ba4d 100644 --- a/tests/test_private_ai.py +++ b/tests/test_private_ai.py @@ -979,6 +979,79 @@ def test_update_budget_period_as_key_creator(mock_post, client, team_key_creator db.delete(test_key) db.commit() +@patch("app.services.litellm.requests.post") +@patch("app.services.litellm.requests.get") +def test_update_budget_duration_as_team_admin(mock_get, mock_post, client, team_admin_token, test_region, mock_litellm_response, db, test_team): + """Test that a team admin can update the budget duration for a team-owned key""" + # Mock the LiteLLM API responses + mock_post.return_value.status_code = 200 + mock_post.return_value.json.return_value = mock_litellm_response + mock_post.return_value.raise_for_status.return_value = None + + # Mock the key info response + mock_get.return_value.status_code = 200 + mock_get.return_value.json.return_value = { + "info": { + "spend": 0.0, + "expires": "2024-12-31T23:59:59Z", + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-02T00:00:00Z", + "max_budget": 100.0, + "budget_duration": "monthly", + "budget_reset_at": "2024-02-01T00:00:00Z" + } + } + mock_get.return_value.raise_for_status.return_value = None + + # Create a test key owned by the team + test_key = DBPrivateAIKey( + database_name="test-db-team", + name="Test Team Key", + database_host="test-host", + database_username="test-user", + database_password="test-pass", + litellm_token="test-token-team", + litellm_api_url="https://test-litellm.com", + team_id=test_team.id, + region_id=test_region.id + ) + db.add(test_key) + db.commit() + db.refresh(test_key) + + # Update the budget duration as team admin + response = client.put( + f"/private-ai-keys/{test_key.id}/budget-period", + headers={"Authorization": f"Bearer {team_admin_token}"}, + json={"budget_duration": "monthly"} + ) + + # Verify the response + assert response.status_code == 200 + data = response.json() + assert data["budget_duration"] == "monthly" + + # Verify that the LiteLLM API was called with the correct parameters + mock_post.assert_called_with( + f"{test_region.litellm_api_url}/key/update", + headers={"Authorization": f"Bearer {test_region.litellm_api_key}"}, + json={ + "key": test_key.litellm_token, + "budget_duration": "monthly" + } + ) + + # Verify that the key info was checked + mock_get.assert_called_with( + f"{test_region.litellm_api_url}/key/info", + headers={"Authorization": f"Bearer {test_region.litellm_api_key}"}, + params={"key": test_key.litellm_token} + ) + + # Clean up the test key + db.delete(test_key) + db.commit() + @patch("app.services.litellm.requests.post") def test_create_llm_token_as_system_admin(mock_post, client, admin_token, test_region, mock_litellm_response): """Test that a system admin can create an LLM token for themselves""" From fbeeb67e7cc74f97bbded9995ccf74d17a54c71d Mon Sep 17 00:00:00 2001 From: Pippa H Date: Wed, 14 May 2025 13:08:24 +0200 Subject: [PATCH 9/9] Enhance middleware path handling with configurable public paths - Added a new PUBLIC_PATHS setting in config.py to centralize the definition of public endpoints. - Updated AuditLogMiddleware, AuthMiddleware, and PrometheusMiddleware to utilize the PUBLIC_PATHS setting for path checks, improving maintainability and consistency across middleware. --- app/core/config.py | 1 + app/middleware/audit.py | 7 ++----- app/middleware/auth.py | 4 ++-- app/middleware/prometheus.py | 5 ++--- 4 files changed, 7 insertions(+), 10 deletions(-) diff --git a/app/core/config.py b/app/core/config.py index b8cc50f..260d849 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -19,6 +19,7 @@ class Settings(BaseSettings): "http://localhost:8800" ] ALLOWED_HOSTS: list[str] = ["*"] # In production, restrict this + PUBLIC_PATHS: list[str] = ["/health", "/docs", "/openapi.json", "/metrics"] model_config = ConfigDict(env_file=".env") diff --git a/app/middleware/audit.py b/app/middleware/audit.py index f37cb90..8d4d7ac 100644 --- a/app/middleware/audit.py +++ b/app/middleware/audit.py @@ -2,14 +2,11 @@ from starlette.middleware.base import BaseHTTPMiddleware from sqlalchemy.orm import Session from app.db.models import DBAuditLog -from app.api.auth import get_current_user_from_auth from app.db.database import get_db from app.middleware.prometheus import audit_events_total, audit_event_duration_seconds -import json import logging import time -from fastapi import Cookie, Header -from typing import Optional +from app.core.config import settings logger = logging.getLogger(__name__) @@ -20,7 +17,7 @@ def __init__(self, app, db: Session): async def dispatch(self, request: Request, call_next): # Skip audit logging for certain paths - if request.url.path in ["/health", "/docs", "/openapi.json", "/audit/logs", "/auth/me", "/metrics"]: + if request.url.path in {*settings.PUBLIC_PATHS, "/audit/logs", "/auth/me"}: return await call_next(request) start_time = time.time() diff --git a/app/middleware/auth.py b/app/middleware/auth.py index 4272ec5..21f3046 100644 --- a/app/middleware/auth.py +++ b/app/middleware/auth.py @@ -3,14 +3,14 @@ from app.core.security import get_current_user_from_auth from app.db.database import get_db import logging -from typing import Optional, Dict, Any +from app.core.config import settings logger = logging.getLogger(__name__) class AuthMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): # Skip auth for certain paths - if request.url.path in ["/health", "/docs", "/openapi.json", "/metrics"]: + if request.url.path in settings.PUBLIC_PATHS: return await call_next(request) # Initialize user as None diff --git a/app/middleware/prometheus.py b/app/middleware/prometheus.py index c454462..4d555f6 100644 --- a/app/middleware/prometheus.py +++ b/app/middleware/prometheus.py @@ -1,9 +1,8 @@ from prometheus_client import Counter, Histogram from fastapi import Request from starlette.middleware.base import BaseHTTPMiddleware -from app.core.security import get_current_user_from_auth -from app.db.database import get_db import logging +from app.core.config import settings logger = logging.getLogger(__name__) @@ -42,7 +41,7 @@ def normalize_path(path: str) -> str: class PrometheusMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): # Skip metrics for certain paths - if request.url.path in ["/metrics", "/health", "/docs", "/openapi.json"]: + if request.url.path in settings.PUBLIC_PATHS: return await call_next(request) # Track auth requests for specific endpoints