Skip to content

Commit 37237d0

Browse files
committed
clean up user manager class
1 parent 9a2fefa commit 37237d0

File tree

1 file changed

+84
-93
lines changed

1 file changed

+84
-93
lines changed

backend/danswer/auth/users.py

Lines changed: 84 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from fastapi_users.openapi import OpenAPIResponseType
3434
from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase
3535
from sqlalchemy import select
36+
from sqlalchemy.orm import attributes
3637
from sqlalchemy.orm import Session
3738

3839
from danswer.auth.invited_users import get_invited_users
@@ -298,114 +299,102 @@ async def oauth_callback(
298299
self.user_db = tenant_user_db
299300
self.database = tenant_user_db
300301

301-
# verify_email_in_whitelist(account_email)
302-
# verify_email_domain(account_email)
302+
logger.info(f"Starting OAuth callback process for email: {account_email}")
303+
oauth_account_dict = {
304+
"oauth_name": oauth_name,
305+
"access_token": access_token,
306+
"account_id": account_id,
307+
"account_email": account_email,
308+
"expires_at": expires_at,
309+
"refresh_token": refresh_token,
310+
}
311+
logger.debug(f"OAuth account dict created: {oauth_account_dict}")
303312

304313
try:
305314
logger.info(
306-
f"Starting OAuth callback process for email: {account_email}"
315+
f"Attempting to get user by OAuth account: {oauth_name}, {account_id}"
316+
)
317+
user = await self.get_by_oauth_account(oauth_name, account_id)
318+
logger.info(f"User found by OAuth account: {user.id}")
319+
except exceptions.UserNotExists:
320+
logger.info(
321+
f"User not found by OAuth account, attempting to get by email: {account_email}"
307322
)
308-
oauth_account_dict = {
309-
"oauth_name": oauth_name,
310-
"access_token": access_token,
311-
"account_id": account_id,
312-
"account_email": account_email,
313-
"expires_at": expires_at,
314-
"refresh_token": refresh_token,
315-
}
316-
logger.debug(f"OAuth account dict created: {oauth_account_dict}")
317-
318323
try:
319-
logger.info(
320-
f"Attempting to get user by OAuth account: {oauth_name}, {account_id}"
324+
user = await self.get_by_email(account_email)
325+
logger.info(f"User found by email: {user.id}")
326+
if not associate_by_email:
327+
logger.warning(
328+
f"User already exists but associate_by_email is False: {account_email}"
329+
)
330+
raise exceptions.UserAlreadyExists()
331+
logger.info(f"Adding OAuth account to existing user: {user.id}")
332+
user = await self.user_db.add_oauth_account(
333+
user, oauth_account_dict
321334
)
322-
user = await self.get_by_oauth_account(oauth_name, account_id)
323-
logger.info(f"User found by OAuth account: {user.id}")
335+
logger.info(f"OAuth account added to user: {user.id}")
324336
except exceptions.UserNotExists:
325337
logger.info(
326-
f"User not found by OAuth account, attempting to get by email: {account_email}"
338+
f"User not found, creating new account for: {account_email}"
327339
)
328-
try:
329-
# Associate account
330-
user = await self.get_by_email(account_email)
331-
logger.info(f"User found by email: {user.id}")
332-
if not associate_by_email:
333-
logger.warning(
334-
f"User already exists but associate_by_email is False: {account_email}"
335-
)
336-
raise exceptions.UserAlreadyExists()
337-
logger.info(f"Adding OAuth account to existing user: {user.id}")
338-
user = await self.user_db.add_oauth_account(
339-
user, oauth_account_dict
340-
)
341-
logger.info(f"OAuth account added to user: {user.id}")
342-
except exceptions.UserNotExists:
340+
password = self.password_helper.generate()
341+
user_dict = {
342+
"email": account_email,
343+
"hashed_password": self.password_helper.hash(password),
344+
"is_verified": is_verified_by_default,
345+
}
346+
logger.debug(f"Creating new user with dict: {user_dict}")
347+
user = await self.user_db.create(user_dict)
348+
logger.info(f"New user created: {user.id}")
349+
user = await self.user_db.add_oauth_account(
350+
user, oauth_account_dict
351+
)
352+
logger.info(f"OAuth account added to new user: {user.id}")
353+
await self.on_after_register(user, request)
354+
else:
355+
logger.info(f"Updating OAuth account for existing user: {user.id}")
356+
for existing_oauth_account in user.oauth_accounts:
357+
if (
358+
existing_oauth_account.account_id == account_id
359+
and existing_oauth_account.oauth_name == oauth_name
360+
):
343361
logger.info(
344-
f"User not found, creating new account for: {account_email}"
362+
f"Updating OAuth account: {oauth_name}, {account_id}"
345363
)
346-
# Create account
347-
password = self.password_helper.generate()
348-
user_dict = {
349-
"email": account_email,
350-
"hashed_password": self.password_helper.hash(password),
351-
"is_verified": is_verified_by_default,
352-
}
353-
logger.debug(f"Creating new user with dict: {user_dict}")
354-
user = await self.user_db.create(user_dict)
355-
logger.info(f"New user created: {user.id}")
356-
logger.info(f"Adding OAuth account to new user: {user.id}")
357-
user = await self.user_db.add_oauth_account(
358-
user, oauth_account_dict
364+
user = await self.user_db.update_oauth_account(
365+
user, existing_oauth_account, oauth_account_dict
359366
)
360-
logger.info(f"OAuth account added to new user: {user.id}")
361-
logger.info(
362-
f"Calling on_after_register for new user: {user.id}"
363-
)
364-
await self.on_after_register(user, request)
365-
else:
366-
# Update oauth
367-
logger.info(f"Updating OAuth account for existing user: {user.id}")
368-
for existing_oauth_account in user.oauth_accounts:
369-
if (
370-
existing_oauth_account.account_id == account_id
371-
and existing_oauth_account.oauth_name == oauth_name
372-
):
373-
logger.info(
374-
f"Updating OAuth account: {oauth_name}, {account_id}"
375-
)
376-
user = await self.user_db.update_oauth_account(
377-
user, existing_oauth_account, oauth_account_dict
378-
)
379-
logger.info(f"OAuth account updated for user: {user.id}")
367+
logger.info(f"OAuth account updated for user: {user.id}")
380368

381-
except Exception as e:
382-
logger.exception(f"Error in oauth_callback: {str(e)}")
369+
logger.info("OAuth callback completed")
383370

384-
print("OAUTH CALLBACK COMPLETED")
385-
# NOTE: Most IdPs have very short expiry times, and we don't want to force the user to
386-
# re-authenticate that frequently, so by default this is disabled
387-
if expires_at and TRACK_EXTERNAL_IDP_EXPIRY:
388-
oidc_expiry = datetime.fromtimestamp(expires_at, tz=timezone.utc)
389-
await self.user_db.update(
390-
user, update_dict={"oidc_expiry": oidc_expiry}
391-
)
392-
393-
# this is needed if an organization goes from `TRACK_EXTERNAL_IDP_EXPIRY=true` to `false`
394-
# otherwise, the oidc expiry will always be old, and the user will never be able to login
395-
if user.oidc_expiry and not TRACK_EXTERNAL_IDP_EXPIRY:
396-
await self.user_db.update(user, update_dict={"oidc_expiry": None})
397-
398-
# Handle case where user has used product outside of web and is now creating an account through web
399-
if not user.has_web_login:
400-
await self.user_db.update(
401-
user,
402-
update_dict={
371+
try:
372+
if not user.has_web_login:
373+
update_dict = {
403374
"is_verified": is_verified_by_default,
404375
"has_web_login": True,
405-
},
406-
)
407-
user.is_verified = is_verified_by_default
408-
user.has_web_login = True
376+
}
377+
await self.user_db.update(user, update_dict)
378+
user.is_verified = is_verified_by_default
379+
user.has_web_login = True
380+
381+
if expires_at and TRACK_EXTERNAL_IDP_EXPIRY:
382+
oidc_expiry = datetime.fromtimestamp(expires_at, tz=timezone.utc)
383+
await self.user_db.update(
384+
user, update_dict={"oidc_expiry": oidc_expiry}
385+
)
386+
387+
if (
388+
hasattr(user, "oidc_expiry")
389+
and user.oidc_expiry is not None
390+
and not TRACK_EXTERNAL_IDP_EXPIRY
391+
):
392+
update_dict = {"oidc_expiry": None}
393+
await self.user_db.update(user, update_dict)
394+
user.oidc_expiry = None
395+
396+
except Exception as e:
397+
logger.exception(f"Error in oauth_callback: {str(e)}")
409398

410399
return user
411400

@@ -462,7 +451,9 @@ async def authenticate(
462451
self.password_helper.hash(credentials.password)
463452
return None
464453

465-
if not user.has_web_login:
454+
has_web_login = attributes.get_attribute(user, "has_web_login")
455+
456+
if not has_web_login:
466457
raise HTTPException(
467458
status_code=status.HTTP_403_FORBIDDEN,
468459
detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD",

0 commit comments

Comments
 (0)