Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
npm-debug.log*
yarn-debug.log*
yarn-error.log*
**/logs/**
*.log
*.log.*

# Environment variables
.env
Expand Down
55 changes: 52 additions & 3 deletions app/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,38 @@
from email_validator import validate_email, EmailNotValidError
from sqlalchemy.orm import Session
import logging
from logging.handlers import TimedRotatingFileHandler
import secrets
import os
from urllib.parse import urlparse
from pathlib import Path

# Configure auth logger
auth_logger = logging.getLogger("auth")
auth_logger.setLevel(logging.INFO)

# Create logs directory if it doesn't exist
log_dir = Path("logs")
log_dir.mkdir(exist_ok=True)

# Configure file handler with daily rotation
file_handler = TimedRotatingFileHandler(
filename=log_dir / "auth.log",
when="midnight",
interval=1,
backupCount=30, # Keep logs for 30 days
encoding="utf-8"
)

# Configure formatter
formatter = logging.Formatter(
"%(asctime)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S"
)
file_handler.setFormatter(formatter)

logger = logging.getLogger(__name__)
# Add handler to logger
auth_logger.addHandler(file_handler)

from app.db.database import get_db
from app.schemas.models import (
Expand Down Expand Up @@ -156,13 +183,16 @@ async def login(
detail="Invalid login data. Please provide username and password in either form data or JSON format."
)

auth_logger.info(f"Login attempt for user: {login_data.username}")
user = db.query(DBUser).filter(DBUser.email == login_data.username).first()
if not user or not verify_password(login_data.password, user.hashed_password):
auth_logger.warning(f"Failed login attempt for user: {login_data.username}")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect email or password"
)

auth_logger.info(f"Successful login for user: {login_data.username}")
return create_and_set_access_token(response, user.email)

@router.post("/logout")
Expand Down Expand Up @@ -243,7 +273,11 @@ async def update_user_me(
return current_user

@router.post("/register", response_model=User)
async def register(user: UserCreate, db: Session = Depends(get_db)):
async def register(
request: Request,
user: UserCreate,
db: Session = Depends(get_db)
):
"""
Register a new user account.

Expand All @@ -252,9 +286,11 @@ async def register(user: UserCreate, db: Session = Depends(get_db)):

After registration, you'll need to login to get an access token.
"""
auth_logger.info(f"Registration attempt for user: {user.email}")
# Check if user with this email exists
db_user = db.query(DBUser).filter(DBUser.email == user.email).first()
if db_user:
auth_logger.warning(f"Registration failed - email already exists: {user.email}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Email already registered"
Expand All @@ -270,6 +306,7 @@ async def register(user: UserCreate, db: Session = Depends(get_db)):
db.add(db_user)
db.commit()
db.refresh(db_user)
auth_logger.info(f"Successfully registered new user: {user.email}")
return db_user

@router.post("/validate-email")
Expand Down Expand Up @@ -304,14 +341,17 @@ async def validate_email(
pass

if not email:
auth_logger.warning("Email validation attempt with no email provided")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Email is required"
)

auth_logger.info(f"Email validation attempt for: {email}")
try:
email_validator.validate_email(email, check_deliverability=False)
except EmailNotValidError as e:
auth_logger.warning(f"Invalid email format for {email}: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid email format: {e}"
Expand All @@ -322,8 +362,10 @@ async def validate_email(
user = db.query(DBUser).filter(DBUser.email == email).first()
if user:
email_template = 'returning-user-code'
auth_logger.info(f"Sending validation code to existing user: {email}")
else:
email_template = 'new-user-code'
auth_logger.info(f"Sending validation code to new user: {email}")

# Send the validation code via email
ses_service = SESService()
Expand All @@ -336,12 +378,13 @@ async def validate_email(
)

if not email_sent:
logger.error(f"Failed to send validation code email to {email}")
auth_logger.error(f"Failed to send validation code email to {email}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to send validation code email"
)

auth_logger.info(f"Successfully sent validation code to: {email}")
return {
"message": "Validation code has been generated and sent"
}
Expand Down Expand Up @@ -437,16 +480,19 @@ async def sign_in(
with them as the admin.
"""
if not sign_in_data:
auth_logger.warning("Sign-in attempt with invalid data format")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid sign in data. Please provide username and verification code in either form data or JSON format."
)

auth_logger.info(f"Sign-in attempt for user: {sign_in_data.username}")
# Verify the code using DynamoDB first
dynamodb_service = DynamoDBService()
stored_code = dynamodb_service.read_validation_code(sign_in_data.username)

if not stored_code or stored_code.get('code').upper() != sign_in_data.verification_code.upper():
auth_logger.warning(f"Invalid verification code for user: {sign_in_data.username}")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect email or verification code"
Expand All @@ -457,6 +503,7 @@ async def sign_in(

# If user doesn't exist, create a new user and team
if not user:
auth_logger.info(f"Creating new user and team for: {sign_in_data.username}")
# First create the team
team_data = TeamCreate(
name=f"Team {sign_in_data.username}",
Expand All @@ -476,5 +523,7 @@ async def sign_in(
db.add(user)
db.commit()
db.refresh(user)
auth_logger.info(f"Successfully created new user and team for: {sign_in_data.username}")

auth_logger.info(f"Successful sign-in for user: {sign_in_data.username}")
return create_and_set_access_token(response, user.email)
15 changes: 13 additions & 2 deletions app/core/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Optional, Literal, Dict
from jose import JWTError, jwt
from passlib.context import CryptContext
from fastapi import Depends, HTTPException, status, Cookie, Header
from fastapi import Depends, HTTPException, status, Cookie, Header, Request
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
import logging
from app.core.config import settings
Expand Down Expand Up @@ -83,9 +83,20 @@ async def get_current_user(
async def get_current_user_from_auth(
access_token: Optional[str] = Cookie(None, alias="access_token"),
authorization: Optional[str] = Header(None),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
request: Request = None
) -> DBUser:
"""Get current user from either JWT token (in cookie or Authorization header) or API token."""
# First check if user is already in request state (set by AuthMiddleware)
if request and hasattr(request.state, 'user') and request.state.user is not None:
# If we have a dict from middleware, load the full user object
if isinstance(request.state.user, dict):
user = db.query(DBUser).filter(DBUser.id == request.state.user["id"]).first()
if user:
return user
else:
return request.state.user

if not access_token and not authorization:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
Expand Down
31 changes: 28 additions & 3 deletions app/main.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from fastapi import FastAPI, Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.orm import Session
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from starlette.middleware.base import BaseHTTPMiddleware
from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.openapi.utils import get_openapi
from prometheus_fastapi_instrumentator import Instrumentator, metrics
import os
import logging

Expand All @@ -27,6 +26,8 @@ async def dispatch(self, request, call_next):
from app.core.config import settings
from app.db.database import get_db
from app.middleware.audit import AuditLogMiddleware
from app.middleware.prometheus import PrometheusMiddleware
from app.middleware.auth import AuthMiddleware

app = FastAPI(
title="Private AI Keys as a Service",
Expand Down Expand Up @@ -79,6 +80,12 @@ async def dispatch(self, request, call_next):
# Add HTTPS redirect middleware first
app.add_middleware(HTTPSRedirectMiddleware)

# Add Auth middleware (must be before Prometheus and Audit middleware)
app.add_middleware(AuthMiddleware)

# Add Prometheus middleware
app.add_middleware(PrometheusMiddleware)

# Configure CORS
app.add_middleware(
CORSMiddleware,
Expand All @@ -96,6 +103,24 @@ async def dispatch(self, request, call_next):

app.add_middleware(AuditLogMiddleware, db=next(get_db()))

# Setup Prometheus instrumentation
instrumentator = Instrumentator(
should_group_status_codes=False,
should_ignore_untemplated=True,
should_respect_env_var=True,
should_instrument_requests_inprogress=True,
excluded_handlers=["/metrics"],
env_var_name="ENABLE_METRICS",
inprogress_name="fastapi_inprogress",
inprogress_labels=True,
)

# Add default metrics
instrumentator.add(metrics.default())

# Instrument the app
instrumentator.instrument(app).expose(app)

@app.get("/health")
async def health_check():
return {"status": "healthy"}
Expand Down
54 changes: 27 additions & 27 deletions app/middleware/audit.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
from app.db.models import DBAuditLog
from app.api.auth import get_current_user_from_auth
from app.db.database import get_db
from app.middleware.prometheus import audit_events_total, audit_event_duration_seconds
import json
import logging
import time
from fastapi import Cookie, Header
from typing import Optional

Expand All @@ -18,40 +20,20 @@ def __init__(self, app, db: Session):

async def dispatch(self, request: Request, call_next):
# Skip audit logging for certain paths
if request.url.path in ["/health", "/docs", "/openapi.json", "/audit/logs", "/auth/me"]:
if request.url.path in ["/health", "/docs", "/openapi.json", "/audit/logs", "/auth/me", "/metrics"]:
return await call_next(request)

start_time = time.time()

# Get the response
response = await call_next(request)

try:
# Get a fresh database session for each request
db = next(get_db())

# Try to get the current user from cookies or authorization header
user_id = None
try:
# Get access token from cookie or authorization header
cookies = request.cookies
headers = request.headers
access_token = cookies.get("access_token")
auth_header = headers.get("authorization")

if auth_header:
parts = auth_header.split()
if len(parts) == 2 and parts[0].lower() == "bearer":
access_token = parts[1]

if access_token:
user = await get_current_user_from_auth(
access_token=access_token if access_token else None,
authorization=auth_header if auth_header else None,
db=db
)
user_id = user.id if user else None
except Exception as e:
logger.debug(f"Could not get user for audit log: {str(e)}")
user_id = None
# Get user_id from request state (set by AuthMiddleware)
user_id = request.state.user.id if hasattr(request.state, 'user') and request.state.user else None

# Extract path parameters
path_params = request.path_params
Expand All @@ -67,13 +49,16 @@ async def dispatch(self, request: Request, call_next):
request_source = "frontend"
else:
# If no origin/referer and has auth header, likely direct API call
request_source = "api" if auth_header else None
request_source = "api" if request.headers.get("authorization") else None

# Get resource type from path
resource_type = request.url.path.split("/")[1] # First path segment

# Create audit log entry
audit_log = DBAuditLog(
user_id=user_id,
event_type=request.method,
resource_type=request.url.path.split("/")[1], # First path segment
resource_type=resource_type,
resource_id=str(resource_id) if resource_id else None,
action=f"{request.method} {request.url.path}",
details={
Expand All @@ -89,6 +74,21 @@ async def dispatch(self, request: Request, call_next):
db.add(audit_log)
db.commit()

# Record audit metrics
audit_events_total.labels(
event_type=request.method,
resource_type=resource_type,
request_source=request_source or "unknown",
status_code=response.status_code
).inc()

# Record audit event duration
duration = time.time() - start_time
audit_event_duration_seconds.labels(
event_type=request.method,
resource_type=resource_type
).observe(duration)

except Exception as e:
logger.error(f"Failed to create audit log: {str(e)}", exc_info=True)
# Don't re-raise the exception - we don't want to break the request if audit logging fails
Expand Down
Loading