10
10
11
11
from danswer .configs .app_configs import USER_AUTH_SECRET
12
12
from danswer .db .engine import is_valid_schema_name
13
+ from ee .danswer .auth .api_key import extract_tenant_from_api_key_header
13
14
from shared_configs .configs import CURRENT_TENANT_ID_CONTEXTVAR
14
15
from shared_configs .configs import MULTI_TENANT
15
16
from shared_configs .configs import POSTGRES_DEFAULT_SCHEMA
@@ -21,40 +22,54 @@ async def set_tenant_id(
21
22
request : Request , call_next : Callable [[Request ], Awaitable [Response ]]
22
23
) -> Response :
23
24
try :
24
- if not MULTI_TENANT :
25
- tenant_id = POSTGRES_DEFAULT_SCHEMA
26
- else :
27
- token = request .cookies .get ("fastapiusersauth" )
28
-
29
- if token :
30
- try :
31
- payload = jwt .decode (
32
- token ,
33
- USER_AUTH_SECRET ,
34
- audience = ["fastapi-users:auth" ],
35
- algorithms = ["HS256" ],
36
- )
37
- tenant_id = payload .get ("tenant_id" , POSTGRES_DEFAULT_SCHEMA )
38
- if not is_valid_schema_name (tenant_id ):
39
- raise HTTPException (
40
- status_code = 400 , detail = "Invalid tenant ID format"
41
- )
42
- except jwt .InvalidTokenError :
43
- tenant_id = POSTGRES_DEFAULT_SCHEMA
44
- except Exception as e :
45
- logger .error (
46
- f"Unexpected error in set_tenant_id_middleware: { str (e )} "
47
- )
48
- raise HTTPException (
49
- status_code = 500 , detail = "Internal server error"
50
- )
51
- else :
52
- tenant_id = POSTGRES_DEFAULT_SCHEMA
25
+ tenant_id = POSTGRES_DEFAULT_SCHEMA
26
+
27
+ if MULTI_TENANT :
28
+ tenant_id = _get_tenant_id_from_request (request , logger )
53
29
54
30
CURRENT_TENANT_ID_CONTEXTVAR .set (tenant_id )
55
- response = await call_next (request )
56
- return response
31
+ return await call_next (request )
57
32
58
33
except Exception as e :
59
34
logger .error (f"Error in tenant ID middleware: { str (e )} " )
60
35
raise
36
+
37
+
38
+ def _get_tenant_id_from_request (request : Request , logger : logging .LoggerAdapter ) -> str :
39
+ # First check for API key
40
+ tenant_id = extract_tenant_from_api_key_header (request )
41
+ if tenant_id is not None :
42
+ return tenant_id
43
+
44
+ # Check for cookie-based auth
45
+ token = request .cookies .get ("fastapiusersauth" )
46
+ if not token :
47
+ return POSTGRES_DEFAULT_SCHEMA
48
+
49
+ try :
50
+ payload = jwt .decode (
51
+ token ,
52
+ USER_AUTH_SECRET ,
53
+ audience = ["fastapi-users:auth" ],
54
+ algorithms = ["HS256" ],
55
+ )
56
+ tenant_id_from_payload = payload .get ("tenant_id" , POSTGRES_DEFAULT_SCHEMA )
57
+
58
+ # Since payload.get() can return None, ensure we have a string
59
+ tenant_id = (
60
+ str (tenant_id_from_payload )
61
+ if tenant_id_from_payload is not None
62
+ else POSTGRES_DEFAULT_SCHEMA
63
+ )
64
+
65
+ if not is_valid_schema_name (tenant_id ):
66
+ raise HTTPException (status_code = 400 , detail = "Invalid tenant ID format" )
67
+
68
+ return tenant_id
69
+
70
+ except jwt .InvalidTokenError :
71
+ return POSTGRES_DEFAULT_SCHEMA
72
+
73
+ except Exception as e :
74
+ logger .error (f"Unexpected error in set_tenant_id_middleware: { str (e )} " )
75
+ raise HTTPException (status_code = 500 , detail = "Internal server error" )
0 commit comments