Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions backend/ee/danswer/auth/api_key.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import secrets
import uuid
from urllib.parse import quote
from urllib.parse import unquote

from fastapi import Request
from passlib.hash import sha256_crypt
Expand Down Expand Up @@ -30,8 +32,35 @@ class ApiKeyDescriptor(BaseModel):
user_id: uuid.UUID


def generate_api_key() -> str:
return _API_KEY_PREFIX + secrets.token_urlsafe(_API_KEY_LEN)
def generate_api_key(tenant_id: str | None = None) -> str:
# For backwards compatibility, if no tenant_id, generate old style key
if not tenant_id:
return _API_KEY_PREFIX + secrets.token_urlsafe(_API_KEY_LEN)

encoded_tenant = quote(tenant_id) # URL encode the tenant ID
return f"{_API_KEY_PREFIX}{encoded_tenant}.{secrets.token_urlsafe(_API_KEY_LEN)}"


def extract_tenant_from_api_key_header(request: Request) -> str | None:
"""Extract tenant ID from request. Returns None if auth is disabled or invalid format."""
raw_api_key_header = request.headers.get(
_API_KEY_HEADER_ALTERNATIVE_NAME
) or request.headers.get(_API_KEY_HEADER_NAME)

if not raw_api_key_header or not raw_api_key_header.startswith(_BEARER_PREFIX):
return None

api_key = raw_api_key_header[len(_BEARER_PREFIX) :].strip()

if not api_key.startswith(_API_KEY_PREFIX):
return None

parts = api_key[len(_API_KEY_PREFIX) :].split(".", 1)
if len(parts) != 2:
return None

tenant_id = parts[0]
return unquote(tenant_id) if tenant_id else None


def hash_api_key(api_key: str) -> str:
Expand Down
8 changes: 7 additions & 1 deletion backend/ee/danswer/db/api_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from ee.danswer.auth.api_key import generate_api_key
from ee.danswer.auth.api_key import hash_api_key
from ee.danswer.server.api_key.models import APIKeyArgs
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import MULTI_TENANT


def is_api_key_email_address(email: str) -> bool:
Expand Down Expand Up @@ -60,7 +62,11 @@ def insert_api_key(
db_session: Session, api_key_args: APIKeyArgs, user_id: uuid.UUID | None
) -> ApiKeyDescriptor:
std_password_helper = PasswordHelper()
api_key = generate_api_key()

# Get tenant_id from context var (will be default schema for single tenant)
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()

api_key = generate_api_key(tenant_id if MULTI_TENANT else None)
api_key_user_id = uuid.uuid4()

display_name = api_key_args.name or UNNAMED_KEY_PLACEHOLDER
Expand Down
77 changes: 46 additions & 31 deletions backend/ee/danswer/server/middleware/tenant_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from danswer.configs.app_configs import USER_AUTH_SECRET
from danswer.db.engine import is_valid_schema_name
from ee.danswer.auth.api_key import extract_tenant_from_api_key_header
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
Expand All @@ -21,40 +22,54 @@ async def set_tenant_id(
request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
try:
if not MULTI_TENANT:
tenant_id = POSTGRES_DEFAULT_SCHEMA
else:
token = request.cookies.get("fastapiusersauth")

if token:
try:
payload = jwt.decode(
token,
USER_AUTH_SECRET,
audience=["fastapi-users:auth"],
algorithms=["HS256"],
)
tenant_id = payload.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
if not is_valid_schema_name(tenant_id):
raise HTTPException(
status_code=400, detail="Invalid tenant ID format"
)
except jwt.InvalidTokenError:
tenant_id = POSTGRES_DEFAULT_SCHEMA
except Exception as e:
logger.error(
f"Unexpected error in set_tenant_id_middleware: {str(e)}"
)
raise HTTPException(
status_code=500, detail="Internal server error"
)
else:
tenant_id = POSTGRES_DEFAULT_SCHEMA
tenant_id = POSTGRES_DEFAULT_SCHEMA

if MULTI_TENANT:
tenant_id = _get_tenant_id_from_request(request, logger)

CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
response = await call_next(request)
return response
return await call_next(request)

except Exception as e:
logger.error(f"Error in tenant ID middleware: {str(e)}")
raise


def _get_tenant_id_from_request(request: Request, logger: logging.LoggerAdapter) -> str:
# First check for API key
tenant_id = extract_tenant_from_api_key_header(request)
if tenant_id is not None:
return tenant_id

# Check for cookie-based auth
token = request.cookies.get("fastapiusersauth")
if not token:
return POSTGRES_DEFAULT_SCHEMA

try:
payload = jwt.decode(
token,
USER_AUTH_SECRET,
audience=["fastapi-users:auth"],
algorithms=["HS256"],
)
tenant_id_from_payload = payload.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)

# Since payload.get() can return None, ensure we have a string
tenant_id = (
str(tenant_id_from_payload)
if tenant_id_from_payload is not None
else POSTGRES_DEFAULT_SCHEMA
)

if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID format")

return tenant_id

except jwt.InvalidTokenError:
return POSTGRES_DEFAULT_SCHEMA

except Exception as e:
logger.error(f"Unexpected error in set_tenant_id_middleware: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error")
Loading