48
48
from httpx_oauth .oauth2 import BaseOAuth2
49
49
from httpx_oauth .oauth2 import OAuth2Token
50
50
from pydantic import BaseModel
51
- from sqlalchemy import select
52
51
from sqlalchemy import text
53
52
from sqlalchemy .orm import attributes
54
53
from sqlalchemy .orm import Session
83
82
from danswer .db .engine import get_async_session_with_tenant
84
83
from danswer .db .engine import get_session
85
84
from danswer .db .engine import get_session_with_tenant
86
- from danswer .db .engine import get_sqlalchemy_engine
87
85
from danswer .db .models import AccessToken
88
86
from danswer .db .models import OAuthAccount
89
87
from danswer .db .models import User
90
- from danswer .db .models import UserTenantMapping
91
88
from danswer .db .users import get_user_by_email
92
89
from danswer .utils .logger import setup_logger
93
90
from danswer .utils .telemetry import optional_telemetry
94
91
from danswer .utils .telemetry import RecordType
95
92
from danswer .utils .variable_functionality import fetch_versioned_implementation
93
+ from ee .danswer .server .tenants .provisioning import get_or_create_tenant_id
94
+ from ee .danswer .server .tenants .user_mapping import get_tenant_id_for_email
96
95
from shared_configs .configs import MULTI_TENANT
97
- from shared_configs .configs import POSTGRES_DEFAULT_SCHEMA
98
96
from shared_configs .contextvars import CURRENT_TENANT_ID_CONTEXTVAR
99
97
100
-
101
98
logger = setup_logger ()
102
99
103
100
@@ -190,20 +187,6 @@ def verify_email_domain(email: str) -> None:
190
187
)
191
188
192
189
193
- def get_tenant_id_for_email (email : str ) -> str :
194
- if not MULTI_TENANT :
195
- return POSTGRES_DEFAULT_SCHEMA
196
- # Implement logic to get tenant_id from the mapping table
197
- with Session (get_sqlalchemy_engine ()) as db_session :
198
- result = db_session .execute (
199
- select (UserTenantMapping .tenant_id ).where (UserTenantMapping .email == email )
200
- )
201
- tenant_id = result .scalar_one_or_none ()
202
- if tenant_id is None :
203
- raise exceptions .UserNotExists ()
204
- return tenant_id
205
-
206
-
207
190
def send_user_verification_email (
208
191
user_email : str ,
209
192
token : str ,
@@ -238,19 +221,7 @@ async def create(
238
221
safe : bool = False ,
239
222
request : Optional [Request ] = None ,
240
223
) -> User :
241
- try :
242
- tenant_id = (
243
- get_tenant_id_for_email (user_create .email )
244
- if MULTI_TENANT
245
- else POSTGRES_DEFAULT_SCHEMA
246
- )
247
- except exceptions .UserNotExists :
248
- raise HTTPException (status_code = 401 , detail = "User not found" )
249
-
250
- if not tenant_id :
251
- raise HTTPException (
252
- status_code = 401 , detail = "User does not belong to an organization"
253
- )
224
+ tenant_id = await get_or_create_tenant_id (user_create .email )
254
225
255
226
async with get_async_session_with_tenant (tenant_id ) as db_session :
256
227
token = CURRENT_TENANT_ID_CONTEXTVAR .set (tenant_id )
@@ -271,7 +242,7 @@ async def create(
271
242
user_create .role = UserRole .ADMIN
272
243
else :
273
244
user_create .role = UserRole .BASIC
274
- user = None
245
+
275
246
try :
276
247
user = await super ().create (user_create , safe = safe , request = request ) # type: ignore
277
248
except exceptions .UserAlreadyExists :
@@ -292,7 +263,9 @@ async def create(
292
263
else :
293
264
raise exceptions .UserAlreadyExists ()
294
265
295
- CURRENT_TENANT_ID_CONTEXTVAR .reset (token )
266
+ finally :
267
+ CURRENT_TENANT_ID_CONTEXTVAR .reset (token )
268
+
296
269
return user
297
270
298
271
async def oauth_callback (
@@ -308,19 +281,12 @@ async def oauth_callback(
308
281
associate_by_email : bool = False ,
309
282
is_verified_by_default : bool = False ,
310
283
) -> models .UOAP :
311
- # Get tenant_id from mapping table
312
- try :
313
- tenant_id = (
314
- get_tenant_id_for_email (account_email )
315
- if MULTI_TENANT
316
- else POSTGRES_DEFAULT_SCHEMA
317
- )
318
- except exceptions .UserNotExists :
319
- raise HTTPException (status_code = 401 , detail = "User not found" )
284
+ tenant_id = await get_or_create_tenant_id (account_email )
320
285
321
286
if not tenant_id :
322
287
raise HTTPException (status_code = 401 , detail = "User not found" )
323
288
289
+ # Proceed with the tenant context
324
290
token = None
325
291
async with get_async_session_with_tenant (tenant_id ) as db_session :
326
292
token = CURRENT_TENANT_ID_CONTEXTVAR .set (tenant_id )
@@ -371,9 +337,9 @@ async def oauth_callback(
371
337
# Explicitly set the Postgres schema for this session to ensure
372
338
# OAuth account creation happens in the correct tenant schema
373
339
await db_session .execute (text (f'SET search_path = "{ tenant_id } "' ))
374
- user = await self . user_db . add_oauth_account (
375
- user , oauth_account_dict
376
- )
340
+
341
+ # Add OAuth account
342
+ await self . user_db . add_oauth_account ( user , oauth_account_dict )
377
343
await self .on_after_register (user , request )
378
344
379
345
else :
0 commit comments