Skip to content

WIP: Rewrite auth for generic OIDC logins #55

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .env.template
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,12 @@ KEYCLOAK_ADMIN=admin
KEYCLOAK_ADMIN_PASSWORD=admin123

# Fill this after you have created a realm and client in keycloak
OIDC_REALM=your_realm
OIDC_CLIENT_ID=your_client_id
OIDC_CLIENT_SECRET=your_client_secret
# DISCOVERY URL for your OIDC provider
# Example for Keycloak: http://<KEYCLOAK-ENDPOINT>/realms/<REALM-ID>/.well-known/openid-configuration
# Example for Authentik: http://<AUTHENTIK-ENDPOINT>/application/o/<PROVIDER-SLUG>/.well-known/openid-configuration
OIDC_DISCOVERY_URL=your_discovery_url

# Docker group id for coder, get it with: getent group docker | cut -d: -f3
DOCKER_GROUP_ID=your_docker_group_id
Expand Down
24 changes: 24 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [

{
"name": "Python Debugger: FastAPI",
"type": "debugpy",
"request": "launch",
"module": "uvicorn",
"cwd": "${workspaceFolder}/src/backend",
"args": [
"main:app",
"--reload",
"--host", "0.0.0.0",
"--port", "8000"
],
"jinja": true,
"envFile": "${workspaceFolder}/.env"
}
]
}
9 changes: 2 additions & 7 deletions src/backend/config.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
import os
import json
import time
import httpx
import jwt
from jwt.jwks_client import PyJWKClient
from typing import Optional, Dict, Any, Tuple
from dotenv import load_dotenv
from cache import RedisClient

# Load environment variables once
load_dotenv()
Expand All @@ -28,10 +23,9 @@
POSTHOG_HOST = os.getenv("VITE_PUBLIC_POSTHOG_HOST")

# ===== OIDC Configuration =====
OIDC_DISCOVERY_URL = os.getenv("OIDC_DISCOVERY_URL")
OIDC_CLIENT_ID = os.getenv('OIDC_CLIENT_ID')
OIDC_CLIENT_SECRET = os.getenv('OIDC_CLIENT_SECRET')
OIDC_SERVER_URL = os.getenv('OIDC_SERVER_URL')
OIDC_REALM = os.getenv('OIDC_REALM')
OIDC_REDIRECT_URI = os.getenv('REDIRECT_URI')

default_pad = {}
Expand All @@ -48,6 +42,7 @@
# Cache for JWKS client
_jwks_client = None

# TODO deprecate this in favor of the newer implementation in dependencies.py
def get_jwks_client():
"""Get or create a PyJWKClient for token verification"""
global _jwks_client
Expand Down
22 changes: 17 additions & 5 deletions src/backend/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,21 @@
import httpx
import jwt
from typing import Optional, Dict, Any, Tuple
from typing import Optional, Tuple
from uuid import UUID
import os
import asyncio
from sqlalchemy.ext.asyncio import AsyncSession

from fastapi import Request, HTTPException, Depends

from cache import RedisClient
from domain.session import Session
from domain.user import User
from domain.pad import Pad
from coder import CoderAPI
from database.database import get_session

# oidc_config for session creation and user sessions
oidc_config = {
'server_url': os.getenv('OIDC_SERVER_URL'),
'realm': os.getenv('OIDC_REALM'),
'discovery_url': os.getenv('OIDC_DISCOVERY_URL'),
'client_id': os.getenv('OIDC_CLIENT_ID'),
'client_secret': os.getenv('OIDC_CLIENT_SECRET'),
'redirect_uri': os.getenv('REDIRECT_URI')
Expand All @@ -26,6 +24,20 @@
async def get_session_domain() -> Session:
"""Get a Session domain instance for the current request."""
redis_client = await RedisClient.get_instance()

# TODO Optimize this to avoid fetching OIDC config on every request
async with httpx.AsyncClient() as client:
oidc_response = await client.get(oidc_config['discovery_url'])
if oidc_response.status_code != 200:
raise HTTPException(
status_code=500,
detail="Failed to fetch OIDC configuration"
)
oidc_config['authorization_endpoint'] = oidc_response.json().get('authorization_endpoint')
oidc_config['token_endpoint'] = oidc_response.json().get('token_endpoint')
oidc_config['end_session_endpoint'] = oidc_response.json().get('end_session_endpoint')
oidc_config['jwks_uri'] = oidc_response.json().get('jwks_uri')

return Session(redis_client, oidc_config)

class UserSession:
Expand Down
9 changes: 4 additions & 5 deletions src/backend/domain/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,12 @@ def get_auth_url(self) -> str:
Returns:
The authentication URL
"""
auth_url = f"{self.oidc_config['server_url']}/realms/{self.oidc_config['realm']}/protocol/openid-connect/auth"
auth_url = self.oidc_config['authorization_endpoint']
params = {
'client_id': self.oidc_config['client_id'],
'response_type': 'code',
'redirect_uri': self.oidc_config['redirect_uri'],
'scope': 'openid profile email'
'scope': 'openid profile email offline_access'
}
return f"{auth_url}?{'&'.join(f'{k}={v}' for k,v in params.items())}"

Expand All @@ -104,7 +104,7 @@ def get_token_url(self) -> str:
Returns:
The token endpoint URL
"""
return f"{self.oidc_config['server_url']}/realms/{self.oidc_config['realm']}/protocol/openid-connect/token"
return self.oidc_config['token_endpoint']

def is_token_expired(self, token_data: Dict[str, Any], buffer_seconds: int = 30) -> bool:
"""
Expand Down Expand Up @@ -195,8 +195,7 @@ def _get_jwks_client(self) -> PyJWKClient:
The JWKs client
"""
if self._jwks_client is None:
jwks_url = f"{self.oidc_config['server_url']}/realms/{self.oidc_config['realm']}/protocol/openid-connect/certs"
self._jwks_client = PyJWKClient(jwks_url)
self._jwks_client = PyJWKClient(self.oidc_config['jwks_uri'])
return self._jwks_client

async def track_event(self, session_id: str, event_type: str, metadata: Dict[str, Any] = None) -> bool:
Expand Down
6 changes: 5 additions & 1 deletion src/backend/domain/user.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from random import Random
from uuid import UUID
from typing import Dict, Any, Optional, List
from datetime import datetime
Expand Down Expand Up @@ -167,7 +168,10 @@ async def get_open_pads(cls, session: AsyncSession, user_id: UUID) -> List[Dict[
@classmethod
async def ensure_exists(cls, session: AsyncSession, user_info: dict) -> 'User':
"""Ensure a user exists in the database, creating them if they don't"""
user_id = UUID(user_info['sub'])
# Certain OIDC don't provide 'sub' in user_info as UUID.
# So we have to generate a UUID based on the user 'sub' to ensure consistency
rng = Random(user_info['sub'])
user_id = UUID(int=rng.getrandbits(123))
user = await cls.get_by_id(session, user_id)

if not user:
Expand Down
164 changes: 164 additions & 0 deletions src/backend/routers/auth_router.old.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
import secrets
import httpx
from fastapi import APIRouter, Depends, Request, HTTPException
from fastapi.responses import RedirectResponse, FileResponse, JSONResponse
import os

from config import (get_auth_url, get_token_url, set_session, delete_session, get_session,
FRONTEND_URL, OIDC_CLIENT_ID, OIDC_CLIENT_SECRET, OIDC_SERVER_URL, OIDC_REALM, OIDC_REDIRECT_URI, STATIC_DIR)
from dependencies import get_coder_api
from coder import CoderAPI
from dependencies import optional_auth, UserSession
from domain.session import Session
from database.database import async_session
from domain.user import User

auth_router = APIRouter()

@auth_router.get("/login")
async def login(request: Request, kc_idp_hint: str = None, popup: str = None):

session_id = secrets.token_urlsafe(32)

auth_url = get_auth_url()
state = "popup" if popup == "1" else "default"
if kc_idp_hint:
auth_url = f"{auth_url}&kc_idp_hint={kc_idp_hint}"
# Add state param to OIDC URL
auth_url = f"{auth_url}&state={state}"
response = RedirectResponse(auth_url)
response.set_cookie("session_id", session_id)


@auth_router.get("/login")
async def newLogin(request: Request, popup: str = None):
session_id = secrets.token_urlsafe(32)

authorization_url = get_auth_url()
if popup == "1":
authorization_url += "&state=popup"

response = RedirectResponse(authorization_url)
response.set_cookie("session_id", session_id)
return response


@auth_router.get("/callback")
async def callback(
request: Request,
code: str,
state: str = "default",
coder_api: CoderAPI = Depends(get_coder_api)
):
session_id = request.cookies.get("session_id")
if not session_id:
raise HTTPException(status_code=400, detail="No session")

# Exchange authorization code for access token
async with httpx.AsyncClient() as client:
token_response = await client.post(
get_token_url(),
data={
'grant_type': 'authorization_code',
'client_id': OIDC_CLIENT_ID,
'client_secret': OIDC_CLIENT_SECRET,
'code': code,
'redirect_uri': OIDC_REDIRECT_URI
}
)

if token_response.status_code != 200:
raise HTTPException(status_code=400, detail="Auth failed")

token_data = token_response.json()
expiry = token_data['expires_in']
set_session(session_id, token_data, expiry)
access_token = token_data['access_token']
user_info = jwt.decode(access_token, options={"verify_signature": False})

try:
user_data, _ = coder_api.ensure_user_exists(user_info)
coder_api.ensure_workspace_exists(user_data["username"])
except Exception as e:
print(f"Error in user/workspace setup: {str(e)}")
# Continue with login even if Coder API fails

if state == "popup":
return FileResponse(os.path.join(STATIC_DIR, "auth/popup-close.html"))
else:
return RedirectResponse("/")


@auth_router.get("/logout")
async def logout(request: Request):
session_id = request.cookies.get('session_id')

session_data = get_session(session_id)
if not session_data:
return RedirectResponse('/')

id_token = session_data.get('id_token', '')

# Delete the session from Redis
delete_session(session_id)

# Create the Keycloak logout URL with redirect back to our app
logout_url = f"{OIDC_SERVER_URL}/realms/{OIDC_REALM}/protocol/openid-connect/logout"
full_logout_url = f"{logout_url}?id_token_hint={id_token}&post_logout_redirect_uri={FRONTEND_URL}"

# Create a redirect response to Keycloak's logout endpoint
response = JSONResponse({"status": "success", "logout_url": full_logout_url})

return response

@auth_router.get("/status")
async def auth_status(
user_session: Optional[UserSession] = Depends(optional_auth)
):
"""Check if the user is authenticated and return session information"""
if not user_session:
return JSONResponse({
"authenticated": False,
"message": "Not authenticated"
})

try:
expires_in = user_session.token_data.get('exp') - time.time()

return JSONResponse({
"authenticated": True,
"user": {
"id": str(user_session.id),
"username": user_session.username,
"email": user_session.email,
"name": user_session.name
},
"expires_in": expires_in
})
except Exception as e:
return JSONResponse({
"authenticated": False,
"message": f"Error processing session: {str(e)}"
})

@auth_router.post("/refresh")
async def refresh_session(request: Request, session_domain: Session = Depends(get_session_domain)):
"""Refresh the current session's access token"""
session_id = request.cookies.get('session_id')
if not session_id:
raise HTTPException(status_code=401, detail="No session found")

session_data = await session_domain.get(session_id)
if not session_data:
raise HTTPException(status_code=401, detail="Invalid session")

# Try to refresh the token
success, new_token_data = await session_domain.refresh_token(session_id, session_data)
if not success:
raise HTTPException(status_code=401, detail="Failed to refresh session")

# Return the new expiry time
return JSONResponse({
"expires_in": new_token_data.get('expires_in'),
"authenticated": True
})
9 changes: 6 additions & 3 deletions src/backend/routers/auth_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ async def login(
auth_url = session_domain.get_auth_url()
state = "popup" if popup == "1" else "default"

# TODO: Handle kc_idp_hint properly for other identity providers
if kc_idp_hint:
auth_url = f"{auth_url}&kc_idp_hint={kc_idp_hint}"

# Add state param to OIDC URL
auth_url = f"{auth_url}&state={state}"

Expand Down Expand Up @@ -70,7 +70,10 @@ async def callback(
raise HTTPException(status_code=400, detail="Auth failed")

token_data = token_response.json()
expiry = token_data['refresh_expires_in']
# TODO OAuth2 spec doesn’t require providers to expose refresh token lifespan to clients
# expiry = token_data['refresh_expires_in']
# for now we default to expires_in if refresh_expires_in is not available
expiry = token_data.get('refresh_expires_in', token_data.get('expires_in'))

# Store the token data in Redis
success = await session_domain.set(session_id, token_data, expiry)
Expand Down Expand Up @@ -130,7 +133,7 @@ async def logout(request: Request, session_domain: Session = Depends(get_session
print(f"Warning: Failed to delete session {session_id}")

# Create the Keycloak logout URL with redirect back to our app
logout_url = f"{session_domain.oidc_config['server_url']}/realms/{session_domain.oidc_config['realm']}/protocol/openid-connect/logout"
logout_url = session_domain.oidc_config['end_session_endpoint']
full_logout_url = f"{logout_url}?id_token_hint={id_token}&post_logout_redirect_uri={FRONTEND_URL}"

# Create a response with the logout URL and clear the session cookie
Expand Down