Skip to content

Commit 4d9d0e5

Browse files
authored
Merge pull request #40 from sustech-cs304/backend
[feat] Authentication scheme
2 parents 763febc + 1921f39 commit 4d9d0e5

File tree

13 files changed

+497
-117
lines changed

13 files changed

+497
-117
lines changed

backend/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,4 @@ COPY app/ ./app
77
RUN uv sync --frozen
88
ENV PATH="/app/.venv/bin:$PATH"
99

10-
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "5000"]
10+
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "5000", "--workers", "4"]

backend/app/auth/__init__.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
from fastapi import APIRouter, Depends, HTTPException, status, Response
2+
from fastapi.security import HTTPBearer
3+
from sqlalchemy.orm import Session
4+
from typing import Annotated
5+
import uuid
6+
from datetime import timedelta, datetime
7+
8+
from app.auth.middleware import get_db, get_current_user
9+
from app.models.user import (
10+
TokenSchema,
11+
UserCreate,
12+
UserLogin,
13+
UserResponse,
14+
User,
15+
AuthToken,
16+
)
17+
from app.auth.utils import (
18+
verify_password,
19+
get_password_hash,
20+
create_access_token,
21+
create_auth_token,
22+
get_user_by_id,
23+
invalidate_session,
24+
ACCESS_TOKEN_EXPIRE_MINUTES,
25+
)
26+
27+
router = APIRouter()
28+
security = HTTPBearer()
29+
30+
31+
@router.post(
32+
"/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED
33+
)
34+
async def register(user_data: UserCreate, db: Session = Depends(get_db)):
35+
# Check if user exists
36+
db_user = get_user_by_id(db, user_data.user_id)
37+
if db_user:
38+
raise HTTPException(
39+
status_code=status.HTTP_400_BAD_REQUEST, detail="User ID already registered"
40+
)
41+
42+
# Create new user
43+
hashed_password = get_password_hash(user_data.password)
44+
db_user = User(
45+
user_id=user_data.user_id,
46+
name=user_data.name,
47+
password=hashed_password,
48+
is_teacher=user_data.is_teacher,
49+
courses=[],
50+
)
51+
52+
db.add(db_user)
53+
db.commit()
54+
db.refresh(db_user)
55+
56+
return db_user
57+
58+
59+
@router.post("/login", response_model=TokenSchema)
60+
async def login(
61+
response: Response, user_data: UserLogin, db: Session = Depends(get_db)
62+
):
63+
# Verify user
64+
user = get_user_by_id(db, user_data.user_id)
65+
if not user or not verify_password(user_data.password, user.password):
66+
raise HTTPException(
67+
status_code=status.HTTP_401_UNAUTHORIZED,
68+
detail="Incorrect username or password",
69+
headers={"WWW-Authenticate": "Bearer"},
70+
)
71+
72+
# Create access token
73+
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
74+
access_token = create_access_token(
75+
data={"sub": user.user_id}, expires_delta=access_token_expires
76+
)
77+
78+
# Create auth token with session
79+
session_id = create_auth_token(db, user.user_id, access_token)
80+
81+
# Set cookie
82+
response.set_cookie(
83+
key="session_id",
84+
value=session_id,
85+
httponly=True,
86+
max_age=ACCESS_TOKEN_EXPIRE_MINUTES * 60,
87+
samesite="lax",
88+
)
89+
90+
return {
91+
"access_token": access_token,
92+
"token_type": "bearer",
93+
"session_id": session_id,
94+
}
95+
96+
97+
@router.post("/logout")
98+
async def logout(
99+
response: Response,
100+
current_user: User = Depends(get_current_user),
101+
session_id: str = Depends(lambda request: request.cookies.get("session_id")),
102+
db: Session = Depends(get_db),
103+
):
104+
# Invalidate session and related tokens
105+
if session_id:
106+
invalidate_session(db, session_id)
107+
108+
# Revoke all tokens for this user
109+
db.query(AuthToken).filter(
110+
AuthToken.user_id == current_user.user_id,
111+
AuthToken.is_revoked == False,
112+
AuthToken.expires > datetime.now(),
113+
).update({"is_revoked": True})
114+
db.commit()
115+
116+
# Clear cookie
117+
response.delete_cookie(key="session_id")
118+
119+
return {"message": "Successfully logged out"}
120+
121+
122+
@router.get("/whoami", response_model=UserResponse)
123+
async def get_user_me(current_user: User = Depends(get_current_user)):
124+
return current_user

backend/app/auth/middleware.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
from fastapi import Depends, HTTPException, status, Request, Cookie
2+
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
3+
from jose import jwt, JWTError
4+
from sqlalchemy.orm import Session
5+
from typing import Optional
6+
from datetime import datetime
7+
8+
from app.db import SessionLocal
9+
from app.auth.utils import SECRET_KEY, ALGORITHM, get_user_by_session, verify_token
10+
from app.models.user import TokenData
11+
12+
security = HTTPBearer()
13+
14+
15+
def get_db():
16+
db = SessionLocal()
17+
try:
18+
yield db
19+
finally:
20+
db.close()
21+
22+
23+
async def get_current_user(
24+
credentials: HTTPAuthorizationCredentials = Depends(security),
25+
session_id: Optional[str] = Cookie(None, alias="session_id"),
26+
db: Session = Depends(get_db),
27+
):
28+
credentials_exception = HTTPException(
29+
status_code=status.HTTP_401_UNAUTHORIZED,
30+
detail="Could not validate credentials",
31+
headers={"WWW-Authenticate": "Bearer"},
32+
)
33+
34+
try:
35+
# First check the JWT token
36+
token = credentials.credentials
37+
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
38+
user_id: str = payload.get("sub")
39+
if user_id is None:
40+
raise credentials_exception
41+
token_data = TokenData(user_id=user_id)
42+
43+
# Verify token exists and is not revoked in database
44+
if not verify_token(db, token):
45+
raise credentials_exception
46+
except JWTError:
47+
raise credentials_exception
48+
49+
# Then verify the session
50+
if not session_id:
51+
raise credentials_exception
52+
53+
user = get_user_by_session(db, session_id)
54+
if user is None or user.user_id != token_data.user_id:
55+
raise credentials_exception
56+
57+
return user
58+
59+
60+
async def get_optional_user(request: Request, db: Session = Depends(get_db)):
61+
session_id = request.cookies.get("session_id")
62+
if not session_id:
63+
return None
64+
65+
return get_user_by_session(db, session_id)

backend/app/auth/utils.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
from datetime import datetime, timedelta
2+
import os
3+
import uuid
4+
from typing import Optional
5+
from jose import jwt, JWTError
6+
from passlib.context import CryptContext
7+
from sqlalchemy.orm import Session
8+
from app.models.user import User, AuthToken
9+
from fastapi import HTTPException, status
10+
11+
# Password hashing
12+
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
13+
14+
# JWT settings
15+
SECRET_KEY = os.getenv("SECRET_KEY", "supersecretkey")
16+
ALGORITHM = "HS256"
17+
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 # 1 day
18+
19+
20+
def verify_password(plain_password, hashed_password):
21+
return pwd_context.verify(plain_password, hashed_password)
22+
23+
24+
def get_password_hash(password):
25+
return pwd_context.hash(password)
26+
27+
28+
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
29+
to_encode = data.copy()
30+
expire = datetime.utcnow() + (
31+
expires_delta or timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
32+
)
33+
to_encode.update({"exp": expire})
34+
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
35+
return encoded_jwt
36+
37+
38+
def create_auth_token(db: Session, user_id: str, token: str) -> str:
39+
# Create a session with native UUID
40+
session_id = uuid.uuid4()
41+
expires = datetime.now() + timedelta(days=1)
42+
43+
# Create auth token in database
44+
auth_token = AuthToken(
45+
session_id=session_id, token=token, user_id=user_id, expires=expires
46+
)
47+
48+
db.add(auth_token)
49+
db.commit()
50+
db.refresh(auth_token)
51+
52+
# Return string representation for HTTP use
53+
return str(session_id)
54+
55+
56+
def get_user_by_id(db: Session, user_id: str):
57+
return db.query(User).filter(User.user_id == user_id).first()
58+
59+
60+
def get_user_by_session(db: Session, session_id: str):
61+
try:
62+
# Convert string session_id to UUID for database query
63+
uuid_obj = uuid.UUID(session_id)
64+
auth_token = (
65+
db.query(AuthToken)
66+
.filter(
67+
AuthToken.session_id == uuid_obj,
68+
AuthToken.is_revoked == False,
69+
AuthToken.expires > datetime.now(),
70+
)
71+
.first()
72+
)
73+
74+
if not auth_token:
75+
return None
76+
77+
return get_user_by_id(db, auth_token.user_id)
78+
except ValueError:
79+
# Invalid UUID
80+
return None
81+
82+
83+
def invalidate_session(db: Session, session_id: str):
84+
try:
85+
uuid_obj = uuid.UUID(session_id)
86+
db.query(AuthToken).filter(AuthToken.session_id == uuid_obj).update(
87+
{"is_revoked": True}
88+
)
89+
db.commit()
90+
except ValueError:
91+
# Invalid UUID
92+
pass
93+
94+
95+
def verify_token(db: Session, token: str) -> Optional[AuthToken]:
96+
"""Verify if a token exists and is valid"""
97+
return (
98+
db.query(AuthToken)
99+
.filter(
100+
AuthToken.token == token,
101+
AuthToken.is_revoked == False,
102+
AuthToken.expires > datetime.now(),
103+
)
104+
.first()
105+
)

backend/app/main.py

Lines changed: 39 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,60 @@
11
# app/main.py
22
from fastapi import FastAPI
3+
from contextlib import asynccontextmanager
4+
from fastapi.middleware.cors import CORSMiddleware
5+
from fastapi.security import HTTPBearer
36

47
from app.db import Base, engine
5-
from app.models import (
6-
assignment,
7-
bookmarklist,
8-
code_snippet,
9-
comment,
10-
course,
11-
material,
12-
note,
13-
section,
14-
user,
15-
)
8+
from app.models import *
169

1710
from app.slide.comment import router as comment_router
1811
from app.slide.material import router as material_router
1912
from app.slide.note import router as note_router
2013
from app.slide.code_snippet import router as code_snippet_router
2114
from app.slide.bookmarklist import router as bookmarklist_router
15+
from app.auth import router as auth_router
16+
17+
18+
@asynccontextmanager
19+
async def lifespan(app: FastAPI):
20+
# Startup: setup database
21+
Base.metadata.reflect(bind=engine)
22+
Base.metadata.drop_all(bind=engine)
23+
Base.metadata.create_all(bind=engine)
24+
yield
25+
# Shutdown: no cleanup needed
26+
27+
28+
app = FastAPI(
29+
title="PeachIDE API",
30+
description="API for the PeachIDE platform",
31+
version="1.0.0",
32+
lifespan=lifespan,
33+
)
2234

23-
app = FastAPI()
35+
# Configure CORS
36+
app.add_middleware(
37+
CORSMiddleware,
38+
allow_origins=["*"], # Allows all origins in development
39+
allow_credentials=True,
40+
allow_methods=["*"], # Allows all methods
41+
allow_headers=["*"], # Allows all headers
42+
)
2443

44+
# Register routers
2545
app.include_router(comment_router, tags=["comments"], prefix="/api")
2646
app.include_router(material_router, tags=["materials"], prefix="/api")
2747
app.include_router(note_router, tags=["notes"], prefix="/api")
2848
app.include_router(code_snippet_router, tags=["code_snippets"], prefix="/api")
2949
app.include_router(bookmarklist_router, tags=["bookmarklists"], prefix="/api")
50+
app.include_router(
51+
auth_router,
52+
tags=["authentication"],
53+
prefix="/api",
54+
responses={401: {"description": "Unauthorized"}},
55+
)
3056

3157

32-
@app.on_event("startup")
33-
def startup():
34-
Base.metadata.reflect(bind=engine)
35-
Base.metadata.drop_all(bind=engine)
36-
Base.metadata.create_all(bind=engine)
37-
38-
39-
@app.get("/api")
58+
@app.get("/api", tags=["health"])
4059
async def root():
4160
return {"message": "Hello World"}

backend/app/models/code_snippet.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@ class CodeSnippet(Base):
77
__tablename__ = "code_snippets"
88

99
snippet_id = Column(String, primary_key=True, index=True)
10-
material_id = Column(String, ForeignKey("materials.material_id"), index=True, nullable=False)
10+
material_id = Column(
11+
String, ForeignKey("materials.material_id"), index=True, nullable=False
12+
)
1113
lang = Column(String, nullable=False)
1214
page = Column(Integer, nullable=False)
1315
content = Column(String, nullable=False)

0 commit comments

Comments
 (0)