|
33 | 33 | from fastapi_users.openapi import OpenAPIResponseType
|
34 | 34 | from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase
|
35 | 35 | from sqlalchemy import select
|
| 36 | +from sqlalchemy.orm import attributes |
36 | 37 | from sqlalchemy.orm import Session
|
37 | 38 |
|
38 | 39 | from danswer.auth.invited_users import get_invited_users
|
@@ -298,114 +299,102 @@ async def oauth_callback(
|
298 | 299 | self.user_db = tenant_user_db
|
299 | 300 | self.database = tenant_user_db
|
300 | 301 |
|
301 |
| - # verify_email_in_whitelist(account_email) |
302 |
| - # verify_email_domain(account_email) |
| 302 | + logger.info(f"Starting OAuth callback process for email: {account_email}") |
| 303 | + oauth_account_dict = { |
| 304 | + "oauth_name": oauth_name, |
| 305 | + "access_token": access_token, |
| 306 | + "account_id": account_id, |
| 307 | + "account_email": account_email, |
| 308 | + "expires_at": expires_at, |
| 309 | + "refresh_token": refresh_token, |
| 310 | + } |
| 311 | + logger.debug(f"OAuth account dict created: {oauth_account_dict}") |
303 | 312 |
|
304 | 313 | try:
|
305 | 314 | logger.info(
|
306 |
| - f"Starting OAuth callback process for email: {account_email}" |
| 315 | + f"Attempting to get user by OAuth account: {oauth_name}, {account_id}" |
| 316 | + ) |
| 317 | + user = await self.get_by_oauth_account(oauth_name, account_id) |
| 318 | + logger.info(f"User found by OAuth account: {user.id}") |
| 319 | + except exceptions.UserNotExists: |
| 320 | + logger.info( |
| 321 | + f"User not found by OAuth account, attempting to get by email: {account_email}" |
307 | 322 | )
|
308 |
| - oauth_account_dict = { |
309 |
| - "oauth_name": oauth_name, |
310 |
| - "access_token": access_token, |
311 |
| - "account_id": account_id, |
312 |
| - "account_email": account_email, |
313 |
| - "expires_at": expires_at, |
314 |
| - "refresh_token": refresh_token, |
315 |
| - } |
316 |
| - logger.debug(f"OAuth account dict created: {oauth_account_dict}") |
317 |
| - |
318 | 323 | try:
|
319 |
| - logger.info( |
320 |
| - f"Attempting to get user by OAuth account: {oauth_name}, {account_id}" |
| 324 | + user = await self.get_by_email(account_email) |
| 325 | + logger.info(f"User found by email: {user.id}") |
| 326 | + if not associate_by_email: |
| 327 | + logger.warning( |
| 328 | + f"User already exists but associate_by_email is False: {account_email}" |
| 329 | + ) |
| 330 | + raise exceptions.UserAlreadyExists() |
| 331 | + logger.info(f"Adding OAuth account to existing user: {user.id}") |
| 332 | + user = await self.user_db.add_oauth_account( |
| 333 | + user, oauth_account_dict |
321 | 334 | )
|
322 |
| - user = await self.get_by_oauth_account(oauth_name, account_id) |
323 |
| - logger.info(f"User found by OAuth account: {user.id}") |
| 335 | + logger.info(f"OAuth account added to user: {user.id}") |
324 | 336 | except exceptions.UserNotExists:
|
325 | 337 | logger.info(
|
326 |
| - f"User not found by OAuth account, attempting to get by email: {account_email}" |
| 338 | + f"User not found, creating new account for: {account_email}" |
327 | 339 | )
|
328 |
| - try: |
329 |
| - # Associate account |
330 |
| - user = await self.get_by_email(account_email) |
331 |
| - logger.info(f"User found by email: {user.id}") |
332 |
| - if not associate_by_email: |
333 |
| - logger.warning( |
334 |
| - f"User already exists but associate_by_email is False: {account_email}" |
335 |
| - ) |
336 |
| - raise exceptions.UserAlreadyExists() |
337 |
| - logger.info(f"Adding OAuth account to existing user: {user.id}") |
338 |
| - user = await self.user_db.add_oauth_account( |
339 |
| - user, oauth_account_dict |
340 |
| - ) |
341 |
| - logger.info(f"OAuth account added to user: {user.id}") |
342 |
| - except exceptions.UserNotExists: |
| 340 | + password = self.password_helper.generate() |
| 341 | + user_dict = { |
| 342 | + "email": account_email, |
| 343 | + "hashed_password": self.password_helper.hash(password), |
| 344 | + "is_verified": is_verified_by_default, |
| 345 | + } |
| 346 | + logger.debug(f"Creating new user with dict: {user_dict}") |
| 347 | + user = await self.user_db.create(user_dict) |
| 348 | + logger.info(f"New user created: {user.id}") |
| 349 | + user = await self.user_db.add_oauth_account( |
| 350 | + user, oauth_account_dict |
| 351 | + ) |
| 352 | + logger.info(f"OAuth account added to new user: {user.id}") |
| 353 | + await self.on_after_register(user, request) |
| 354 | + else: |
| 355 | + logger.info(f"Updating OAuth account for existing user: {user.id}") |
| 356 | + for existing_oauth_account in user.oauth_accounts: |
| 357 | + if ( |
| 358 | + existing_oauth_account.account_id == account_id |
| 359 | + and existing_oauth_account.oauth_name == oauth_name |
| 360 | + ): |
343 | 361 | logger.info(
|
344 |
| - f"User not found, creating new account for: {account_email}" |
| 362 | + f"Updating OAuth account: {oauth_name}, {account_id}" |
345 | 363 | )
|
346 |
| - # Create account |
347 |
| - password = self.password_helper.generate() |
348 |
| - user_dict = { |
349 |
| - "email": account_email, |
350 |
| - "hashed_password": self.password_helper.hash(password), |
351 |
| - "is_verified": is_verified_by_default, |
352 |
| - } |
353 |
| - logger.debug(f"Creating new user with dict: {user_dict}") |
354 |
| - user = await self.user_db.create(user_dict) |
355 |
| - logger.info(f"New user created: {user.id}") |
356 |
| - logger.info(f"Adding OAuth account to new user: {user.id}") |
357 |
| - user = await self.user_db.add_oauth_account( |
358 |
| - user, oauth_account_dict |
| 364 | + user = await self.user_db.update_oauth_account( |
| 365 | + user, existing_oauth_account, oauth_account_dict |
359 | 366 | )
|
360 |
| - logger.info(f"OAuth account added to new user: {user.id}") |
361 |
| - logger.info( |
362 |
| - f"Calling on_after_register for new user: {user.id}" |
363 |
| - ) |
364 |
| - await self.on_after_register(user, request) |
365 |
| - else: |
366 |
| - # Update oauth |
367 |
| - logger.info(f"Updating OAuth account for existing user: {user.id}") |
368 |
| - for existing_oauth_account in user.oauth_accounts: |
369 |
| - if ( |
370 |
| - existing_oauth_account.account_id == account_id |
371 |
| - and existing_oauth_account.oauth_name == oauth_name |
372 |
| - ): |
373 |
| - logger.info( |
374 |
| - f"Updating OAuth account: {oauth_name}, {account_id}" |
375 |
| - ) |
376 |
| - user = await self.user_db.update_oauth_account( |
377 |
| - user, existing_oauth_account, oauth_account_dict |
378 |
| - ) |
379 |
| - logger.info(f"OAuth account updated for user: {user.id}") |
| 367 | + logger.info(f"OAuth account updated for user: {user.id}") |
380 | 368 |
|
381 |
| - except Exception as e: |
382 |
| - logger.exception(f"Error in oauth_callback: {str(e)}") |
| 369 | + logger.info("OAuth callback completed") |
383 | 370 |
|
384 |
| - print("OAUTH CALLBACK COMPLETED") |
385 |
| - # NOTE: Most IdPs have very short expiry times, and we don't want to force the user to |
386 |
| - # re-authenticate that frequently, so by default this is disabled |
387 |
| - if expires_at and TRACK_EXTERNAL_IDP_EXPIRY: |
388 |
| - oidc_expiry = datetime.fromtimestamp(expires_at, tz=timezone.utc) |
389 |
| - await self.user_db.update( |
390 |
| - user, update_dict={"oidc_expiry": oidc_expiry} |
391 |
| - ) |
392 |
| - |
393 |
| - # this is needed if an organization goes from `TRACK_EXTERNAL_IDP_EXPIRY=true` to `false` |
394 |
| - # otherwise, the oidc expiry will always be old, and the user will never be able to login |
395 |
| - if user.oidc_expiry and not TRACK_EXTERNAL_IDP_EXPIRY: |
396 |
| - await self.user_db.update(user, update_dict={"oidc_expiry": None}) |
397 |
| - |
398 |
| - # Handle case where user has used product outside of web and is now creating an account through web |
399 |
| - if not user.has_web_login: |
400 |
| - await self.user_db.update( |
401 |
| - user, |
402 |
| - update_dict={ |
| 371 | + try: |
| 372 | + if not user.has_web_login: |
| 373 | + update_dict = { |
403 | 374 | "is_verified": is_verified_by_default,
|
404 | 375 | "has_web_login": True,
|
405 |
| - }, |
406 |
| - ) |
407 |
| - user.is_verified = is_verified_by_default |
408 |
| - user.has_web_login = True |
| 376 | + } |
| 377 | + await self.user_db.update(user, update_dict) |
| 378 | + user.is_verified = is_verified_by_default |
| 379 | + user.has_web_login = True |
| 380 | + |
| 381 | + if expires_at and TRACK_EXTERNAL_IDP_EXPIRY: |
| 382 | + oidc_expiry = datetime.fromtimestamp(expires_at, tz=timezone.utc) |
| 383 | + await self.user_db.update( |
| 384 | + user, update_dict={"oidc_expiry": oidc_expiry} |
| 385 | + ) |
| 386 | + |
| 387 | + if ( |
| 388 | + hasattr(user, "oidc_expiry") |
| 389 | + and user.oidc_expiry is not None |
| 390 | + and not TRACK_EXTERNAL_IDP_EXPIRY |
| 391 | + ): |
| 392 | + update_dict = {"oidc_expiry": None} |
| 393 | + await self.user_db.update(user, update_dict) |
| 394 | + user.oidc_expiry = None |
| 395 | + |
| 396 | + except Exception as e: |
| 397 | + logger.exception(f"Error in oauth_callback: {str(e)}") |
409 | 398 |
|
410 | 399 | return user
|
411 | 400 |
|
@@ -462,7 +451,9 @@ async def authenticate(
|
462 | 451 | self.password_helper.hash(credentials.password)
|
463 | 452 | return None
|
464 | 453 |
|
465 |
| - if not user.has_web_login: |
| 454 | + has_web_login = attributes.get_attribute(user, "has_web_login") |
| 455 | + |
| 456 | + if not has_web_login: |
466 | 457 | raise HTTPException(
|
467 | 458 | status_code=status.HTTP_403_FORBIDDEN,
|
468 | 459 | detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD",
|
|
0 commit comments