Skip to content

Commit 4a76fa8

Browse files
authored
Merge pull request #74 from amazeeio/jwt-re-apply
Revert "Revert "JWT validation and trial expiry""
2 parents ced9ba9 + 9455fd3 commit 4a76fa8

File tree

16 files changed

+2697
-1409
lines changed

16 files changed

+2697
-1409
lines changed

app/api/auth.py

Lines changed: 191 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
1-
from typing import Optional, List, Union
2-
import email_validator
3-
from fastapi import APIRouter, Depends, HTTPException, status, Response, Request, Form
4-
from email_validator import validate_email, EmailNotValidError
5-
from sqlalchemy.orm import Session
61
import logging
7-
from logging.handlers import TimedRotatingFileHandler
82
import secrets
93
import os
4+
import email_validator
5+
6+
from typing import Optional, List, Union
7+
from fastapi import APIRouter, Depends, HTTPException, status, Response, Request, Form
8+
from sqlalchemy.orm import Session
109
from urllib.parse import urlparse
11-
from pathlib import Path
10+
from fastapi import HTTPException, status, Response, Request, Form
11+
from fastapi.security import HTTPAuthorizationCredentials
12+
from jose import JWTError, jwt
13+
from app.core.config import settings
1214

1315
from app.db.database import get_db
1416
from app.schemas.models import (
@@ -32,38 +34,12 @@
3234
get_password_hash,
3335
create_access_token,
3436
get_current_user_from_auth,
37+
get_current_user,
3538
)
3639
from app.services.dynamodb import DynamoDBService
3740
from app.services.ses import SESService
3841
from app.api.teams import register_team
3942

40-
# # Configure auth logger - disabled until we make it work in lagoon
41-
# auth_logger = logging.getLogger("auth")
42-
# auth_logger.setLevel(logging.INFO)
43-
#
44-
# # Create logs directory if it doesn't exist
45-
# log_dir = Path("logs")
46-
# log_dir.mkdir(exist_ok=True)
47-
#
48-
# # Configure file handler with daily rotation
49-
# file_handler = TimedRotatingFileHandler(
50-
# filename=log_dir / "auth.log",
51-
# when="midnight",
52-
# interval=1,
53-
# backupCount=30, # Keep logs for 30 days
54-
# encoding="utf-8"
55-
# )
56-
#
57-
# # Configure formatter
58-
# formatter = logging.Formatter(
59-
# "%(asctime)s - %(levelname)s - %(message)s",
60-
# datefmt="%Y-%m-%d %H:%M:%S"
61-
# )
62-
# file_handler.setFormatter(formatter)
63-
#
64-
# # Add handler to logger
65-
# auth_logger.addHandler(file_handler)
66-
6743
auth_logger = logging.getLogger(__name__)
6844

6945
router = APIRouter(
@@ -160,7 +136,7 @@ def create_and_set_access_token(response: Response, user_email: str) -> Token:
160136
# Set cookie with appropriate settings
161137
response.set_cookie(**cookie_settings)
162138

163-
return {"access_token": access_token, "token_type": "bearer"}
139+
return Token(access_token=access_token, token_type="bearer")
164140

165141
@router.post("/login", response_model=Token)
166142
async def login(
@@ -317,6 +293,64 @@ async def register(
317293
auth_logger.info(f"Successfully registered new user: {user.email}")
318294
return db_user
319295

296+
def generate_validation_token(email: str) -> str:
297+
"""
298+
Generate a validation token for the given email and store it in DynamoDB.
299+
300+
Args:
301+
email (str): The email address to generate a token for
302+
303+
Returns:
304+
str: The generated validation token (8 characters, alphanumeric, uppercase)
305+
"""
306+
# Generate an 8-character alphanumeric code in uppercase
307+
code = ''.join(secrets.choice('ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789') for _ in range(8))
308+
309+
# Store the code in DynamoDB
310+
dynamodb_service = DynamoDBService()
311+
dynamodb_service.write_validation_code(email, code)
312+
313+
return code
314+
315+
def send_validation_code(email: str, db: Session) -> None:
316+
"""
317+
Generate and send a validation code to the specified email address.
318+
319+
Args:
320+
email (str): The email address to send the code to
321+
db (Session): Database session to check if user exists
322+
323+
Raises:
324+
HTTPException: If email sending fails
325+
"""
326+
# Generate and store validation code
327+
code = generate_validation_token(email)
328+
329+
# Determine if user exists to choose appropriate template
330+
user = db.query(DBUser).filter(DBUser.email == email).first()
331+
email_template = 'returning-user-code' if user else 'new-user-code'
332+
333+
auth_logger.info(f"Sending validation code to {'existing' if user else 'new'} user: {email}")
334+
335+
# Send the validation code via email
336+
ses_service = SESService()
337+
email_sent = ses_service.send_email(
338+
to_addresses=[email],
339+
template_name=email_template,
340+
template_data={
341+
'code': code
342+
}
343+
)
344+
345+
if not email_sent:
346+
auth_logger.error(f"Failed to send validation code email to {email}")
347+
raise HTTPException(
348+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
349+
detail="Failed to send validation code email"
350+
)
351+
352+
auth_logger.info(f"Successfully sent validation code to: {email}")
353+
320354
@router.post("/validate-email")
321355
async def validate_email(
322356
request: Request,
@@ -358,109 +392,18 @@ async def validate_email(
358392
auth_logger.info(f"Email validation attempt for: {email}")
359393
try:
360394
email_validator.validate_email(email, check_deliverability=False)
361-
except EmailNotValidError as e:
395+
except email_validator.EmailNotValidError as e:
362396
auth_logger.warning(f"Invalid email format for {email}: {e}")
363397
raise HTTPException(
364398
status_code=status.HTTP_400_BAD_REQUEST,
365399
detail=f"Invalid email format: {e}"
366400
)
367401

368-
# Generate and store validation code
369-
code = generate_validation_token(email)
370-
user = db.query(DBUser).filter(DBUser.email == email).first()
371-
if user:
372-
email_template = 'returning-user-code'
373-
auth_logger.info(f"Sending validation code to existing user: {email}")
374-
else:
375-
email_template = 'new-user-code'
376-
auth_logger.info(f"Sending validation code to new user: {email}")
377-
378-
# Send the validation code via email
379-
ses_service = SESService()
380-
email_sent = ses_service.send_email(
381-
to_addresses=[email],
382-
template_name=email_template,
383-
template_data={
384-
'code': code
385-
}
386-
)
387-
388-
if not email_sent:
389-
auth_logger.error(f"Failed to send validation code email to {email}")
390-
raise HTTPException(
391-
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
392-
detail="Failed to send validation code email"
393-
)
394-
395-
auth_logger.info(f"Successfully sent validation code to: {email}")
402+
send_validation_code(email, db)
396403
return {
397404
"message": "Validation code has been generated and sent"
398405
}
399406

400-
def generate_token() -> str:
401-
return secrets.token_urlsafe(32)
402-
403-
def generate_validation_token(email: str) -> str:
404-
"""
405-
Generate a validation token for the given email and store it in DynamoDB.
406-
407-
Args:
408-
email (str): The email address to generate a token for
409-
410-
Returns:
411-
str: The generated validation token (8 characters, alphanumeric, uppercase)
412-
"""
413-
# Generate an 8-character alphanumeric code in uppercase
414-
code = ''.join(secrets.choice('ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789') for _ in range(8))
415-
416-
# Store the code in DynamoDB
417-
dynamodb_service = DynamoDBService()
418-
dynamodb_service.write_validation_code(email, code)
419-
420-
return code
421-
422-
# API Token routes (as apposed to AI Token routes)
423-
@router.post("/token", response_model=APIToken)
424-
async def create_token(
425-
token_create: APITokenCreate,
426-
current_user = Depends(get_current_user_from_auth),
427-
db: Session = Depends(get_db)
428-
):
429-
db_token = DBAPIToken(
430-
name=token_create.name,
431-
token=generate_token(),
432-
user_id=current_user.id
433-
)
434-
db.add(db_token)
435-
db.commit()
436-
db.refresh(db_token)
437-
return db_token
438-
439-
@router.get("/token", response_model=List[APITokenResponse])
440-
async def list_tokens(
441-
current_user = Depends(get_current_user_from_auth),
442-
db: Session = Depends(get_db)
443-
):
444-
"""List all API tokens for the current user"""
445-
return current_user.api_tokens
446-
447-
@router.delete("/token/{token_id}")
448-
async def delete_token(
449-
token_id: int,
450-
current_user = Depends(get_current_user_from_auth),
451-
db: Session = Depends(get_db)
452-
):
453-
"""Delete an API token"""
454-
token = db.query(DBAPIToken).filter(
455-
DBAPIToken.id == token_id,
456-
DBAPIToken.user_id == current_user.id
457-
).first()
458-
if not token:
459-
raise HTTPException(status_code=404, detail="Token not found")
460-
db.delete(token)
461-
db.commit()
462-
return {"message": "Token deleted successfully"}
463-
464407
@router.post("/sign-in", response_model=Token)
465408
async def sign_in(
466409
request: Request,
@@ -534,4 +477,123 @@ async def sign_in(
534477
auth_logger.info(f"Successfully created new user and team for: {sign_in_data.username}")
535478

536479
auth_logger.info(f"Successful sign-in for user: {sign_in_data.username}")
537-
return create_and_set_access_token(response, user.email)
480+
return create_and_set_access_token(response, user.email)
481+
482+
# API Token routes (as apposed to AI Token routes)
483+
def generate_api_token() -> str:
484+
return secrets.token_urlsafe(32)
485+
486+
@router.post("/token", response_model=APIToken)
487+
async def create_token(
488+
token_create: APITokenCreate,
489+
current_user = Depends(get_current_user_from_auth),
490+
db: Session = Depends(get_db)
491+
):
492+
db_token = DBAPIToken(
493+
name=token_create.name,
494+
token=generate_api_token(),
495+
user_id=current_user.id
496+
)
497+
db.add(db_token)
498+
db.commit()
499+
db.refresh(db_token)
500+
return db_token
501+
502+
@router.get("/token", response_model=List[APITokenResponse])
503+
async def list_tokens(
504+
current_user = Depends(get_current_user_from_auth),
505+
db: Session = Depends(get_db)
506+
):
507+
"""List all API tokens for the current user"""
508+
return current_user.api_tokens
509+
510+
@router.delete("/token/{token_id}")
511+
async def delete_token(
512+
token_id: int,
513+
current_user = Depends(get_current_user_from_auth),
514+
db: Session = Depends(get_db)
515+
):
516+
"""Delete an API token"""
517+
token = db.query(DBAPIToken).filter(
518+
DBAPIToken.id == token_id,
519+
DBAPIToken.user_id == current_user.id
520+
).first()
521+
if not token:
522+
raise HTTPException(status_code=404, detail="Token not found")
523+
db.delete(token)
524+
db.commit()
525+
return {"message": "Token deleted successfully"}
526+
527+
@router.get("/validate-jwt", response_model=Token)
528+
async def validate_jwt(
529+
request: Request,
530+
response: Response,
531+
token: Optional[str] = None,
532+
db: Session = Depends(get_db)
533+
):
534+
"""
535+
Validate a JWT token and either refresh it or send a new validation code.
536+
537+
The token can be provided either:
538+
- As a query parameter: ?token=your_token
539+
- In the Authorization header: Bearer your_token
540+
541+
Returns:
542+
- If token is valid: A new access token with cookies set
543+
- If token is expired: 401 with message about validation code being sent
544+
"""
545+
credentials_exception = HTTPException(
546+
status_code=status.HTTP_401_UNAUTHORIZED,
547+
detail="Could not validate credentials",
548+
headers={"WWW-Authenticate": "Bearer"},
549+
)
550+
551+
# Get token from Authorization header if not provided as parameter
552+
if not token:
553+
auth_header = request.headers.get("Authorization")
554+
if not auth_header or not auth_header.startswith("Bearer "):
555+
raise credentials_exception
556+
token = auth_header.split(" ")[1]
557+
558+
try:
559+
# Try to validate the token
560+
payload = jwt.decode(
561+
token,
562+
settings.SECRET_KEY,
563+
algorithms=[settings.ALGORITHM]
564+
)
565+
email: str = payload.get("sub")
566+
user = db.query(DBUser).filter(DBUser.email == email).first()
567+
if not user:
568+
raise credentials_exception
569+
570+
# Token is valid, create new access token
571+
auth_logger.info(f"Successfully validated JWT for user: {user.email}")
572+
return create_and_set_access_token(response, user.email)
573+
574+
except JWTError as e:
575+
if isinstance(e, jwt.ExpiredSignatureError):
576+
# Token is expired, try to get email from expired token
577+
try:
578+
# Decode without verifying expiration
579+
payload = jwt.decode(
580+
token,
581+
settings.SECRET_KEY,
582+
algorithms=[settings.ALGORITHM],
583+
options={"verify_exp": False}
584+
)
585+
email = payload.get("sub")
586+
587+
if not email:
588+
raise credentials_exception
589+
590+
send_validation_code(email, db)
591+
raise HTTPException(
592+
status_code=status.HTTP_401_UNAUTHORIZED,
593+
detail="Token expired. A new validation code has been sent to your email."
594+
)
595+
596+
except JWTError:
597+
raise credentials_exception
598+
else:
599+
raise credentials_exception

0 commit comments

Comments
 (0)