Skip to content

Commit 753293c

Browse files
authored
Basic multi tenant api key (#3004)
* basic multi tenant api key * organization * nit * clean
1 parent 6d543f3 commit 753293c

File tree

3 files changed

+84
-34
lines changed

3 files changed

+84
-34
lines changed

backend/ee/danswer/auth/api_key.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import secrets
22
import uuid
3+
from urllib.parse import quote
4+
from urllib.parse import unquote
35

46
from fastapi import Request
57
from passlib.hash import sha256_crypt
@@ -30,8 +32,35 @@ class ApiKeyDescriptor(BaseModel):
3032
user_id: uuid.UUID
3133

3234

33-
def generate_api_key() -> str:
34-
return _API_KEY_PREFIX + secrets.token_urlsafe(_API_KEY_LEN)
35+
def generate_api_key(tenant_id: str | None = None) -> str:
36+
# For backwards compatibility, if no tenant_id, generate old style key
37+
if not tenant_id:
38+
return _API_KEY_PREFIX + secrets.token_urlsafe(_API_KEY_LEN)
39+
40+
encoded_tenant = quote(tenant_id) # URL encode the tenant ID
41+
return f"{_API_KEY_PREFIX}{encoded_tenant}.{secrets.token_urlsafe(_API_KEY_LEN)}"
42+
43+
44+
def extract_tenant_from_api_key_header(request: Request) -> str | None:
45+
"""Extract tenant ID from request. Returns None if auth is disabled or invalid format."""
46+
raw_api_key_header = request.headers.get(
47+
_API_KEY_HEADER_ALTERNATIVE_NAME
48+
) or request.headers.get(_API_KEY_HEADER_NAME)
49+
50+
if not raw_api_key_header or not raw_api_key_header.startswith(_BEARER_PREFIX):
51+
return None
52+
53+
api_key = raw_api_key_header[len(_BEARER_PREFIX) :].strip()
54+
55+
if not api_key.startswith(_API_KEY_PREFIX):
56+
return None
57+
58+
parts = api_key[len(_API_KEY_PREFIX) :].split(".", 1)
59+
if len(parts) != 2:
60+
return None
61+
62+
tenant_id = parts[0]
63+
return unquote(tenant_id) if tenant_id else None
3564

3665

3766
def hash_api_key(api_key: str) -> str:

backend/ee/danswer/db/api_key.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
from ee.danswer.auth.api_key import generate_api_key
1616
from ee.danswer.auth.api_key import hash_api_key
1717
from ee.danswer.server.api_key.models import APIKeyArgs
18+
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
19+
from shared_configs.configs import MULTI_TENANT
1820

1921

2022
def get_api_key_email_pattern() -> str:
@@ -64,7 +66,11 @@ def insert_api_key(
6466
db_session: Session, api_key_args: APIKeyArgs, user_id: uuid.UUID | None
6567
) -> ApiKeyDescriptor:
6668
std_password_helper = PasswordHelper()
67-
api_key = generate_api_key()
69+
70+
# Get tenant_id from context var (will be default schema for single tenant)
71+
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
72+
73+
api_key = generate_api_key(tenant_id if MULTI_TENANT else None)
6874
api_key_user_id = uuid.uuid4()
6975

7076
display_name = api_key_args.name or UNNAMED_KEY_PLACEHOLDER

backend/ee/danswer/server/middleware/tenant_tracking.py

Lines changed: 46 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from danswer.configs.app_configs import USER_AUTH_SECRET
1212
from danswer.db.engine import is_valid_schema_name
13+
from ee.danswer.auth.api_key import extract_tenant_from_api_key_header
1314
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
1415
from shared_configs.configs import MULTI_TENANT
1516
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
@@ -21,40 +22,54 @@ async def set_tenant_id(
2122
request: Request, call_next: Callable[[Request], Awaitable[Response]]
2223
) -> Response:
2324
try:
24-
if not MULTI_TENANT:
25-
tenant_id = POSTGRES_DEFAULT_SCHEMA
26-
else:
27-
token = request.cookies.get("fastapiusersauth")
28-
29-
if token:
30-
try:
31-
payload = jwt.decode(
32-
token,
33-
USER_AUTH_SECRET,
34-
audience=["fastapi-users:auth"],
35-
algorithms=["HS256"],
36-
)
37-
tenant_id = payload.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
38-
if not is_valid_schema_name(tenant_id):
39-
raise HTTPException(
40-
status_code=400, detail="Invalid tenant ID format"
41-
)
42-
except jwt.InvalidTokenError:
43-
tenant_id = POSTGRES_DEFAULT_SCHEMA
44-
except Exception as e:
45-
logger.error(
46-
f"Unexpected error in set_tenant_id_middleware: {str(e)}"
47-
)
48-
raise HTTPException(
49-
status_code=500, detail="Internal server error"
50-
)
51-
else:
52-
tenant_id = POSTGRES_DEFAULT_SCHEMA
25+
tenant_id = POSTGRES_DEFAULT_SCHEMA
26+
27+
if MULTI_TENANT:
28+
tenant_id = _get_tenant_id_from_request(request, logger)
5329

5430
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
55-
response = await call_next(request)
56-
return response
31+
return await call_next(request)
5732

5833
except Exception as e:
5934
logger.error(f"Error in tenant ID middleware: {str(e)}")
6035
raise
36+
37+
38+
def _get_tenant_id_from_request(request: Request, logger: logging.LoggerAdapter) -> str:
39+
# First check for API key
40+
tenant_id = extract_tenant_from_api_key_header(request)
41+
if tenant_id is not None:
42+
return tenant_id
43+
44+
# Check for cookie-based auth
45+
token = request.cookies.get("fastapiusersauth")
46+
if not token:
47+
return POSTGRES_DEFAULT_SCHEMA
48+
49+
try:
50+
payload = jwt.decode(
51+
token,
52+
USER_AUTH_SECRET,
53+
audience=["fastapi-users:auth"],
54+
algorithms=["HS256"],
55+
)
56+
tenant_id_from_payload = payload.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
57+
58+
# Since payload.get() can return None, ensure we have a string
59+
tenant_id = (
60+
str(tenant_id_from_payload)
61+
if tenant_id_from_payload is not None
62+
else POSTGRES_DEFAULT_SCHEMA
63+
)
64+
65+
if not is_valid_schema_name(tenant_id):
66+
raise HTTPException(status_code=400, detail="Invalid tenant ID format")
67+
68+
return tenant_id
69+
70+
except jwt.InvalidTokenError:
71+
return POSTGRES_DEFAULT_SCHEMA
72+
73+
except Exception as e:
74+
logger.error(f"Unexpected error in set_tenant_id_middleware: {str(e)}")
75+
raise HTTPException(status_code=500, detail="Internal server error")

0 commit comments

Comments
 (0)