|
93 | 93 | from danswer.utils.telemetry import optional_telemetry
|
94 | 94 | from danswer.utils.telemetry import RecordType
|
95 | 95 | from danswer.utils.variable_functionality import fetch_versioned_implementation
|
96 |
| -from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR |
97 | 96 | from shared_configs.configs import MULTI_TENANT
|
98 | 97 | from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
| 98 | +from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR |
99 | 99 |
|
100 | 100 |
|
101 | 101 | logger = setup_logger()
|
@@ -510,19 +510,23 @@ async def get_user_manager(
|
510 | 510 |
|
511 | 511 | # This strategy is used to add tenant_id to the JWT token
|
512 | 512 | class TenantAwareJWTStrategy(JWTStrategy):
|
513 |
| - async def write_token(self, user: User) -> str: |
| 513 | + async def _create_token_data(self, user: User, impersonate: bool = False) -> dict: |
514 | 514 | tenant_id = get_tenant_id_for_email(user.email)
|
515 | 515 | data = {
|
516 | 516 | "sub": str(user.id),
|
517 | 517 | "aud": self.token_audience,
|
518 | 518 | "tenant_id": tenant_id,
|
519 | 519 | }
|
| 520 | + return data |
| 521 | + |
| 522 | + async def write_token(self, user: User) -> str: |
| 523 | + data = await self._create_token_data(user) |
520 | 524 | return generate_jwt(
|
521 | 525 | data, self.encode_key, self.lifetime_seconds, algorithm=self.algorithm
|
522 | 526 | )
|
523 | 527 |
|
524 | 528 |
|
525 |
| -def get_jwt_strategy() -> JWTStrategy: |
| 529 | +def get_jwt_strategy() -> TenantAwareJWTStrategy: |
526 | 530 | return TenantAwareJWTStrategy(
|
527 | 531 | secret=USER_AUTH_SECRET,
|
528 | 532 | lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS,
|
|
0 commit comments