322
322
323
323
# return router
324
324
325
-
325
+ from onyx . db . engine . async_sql_engine import get_async_session
326
326
import httpx
327
327
from fastapi import APIRouter , Depends , Request , HTTPException , status
328
328
from fastapi .responses import RedirectResponse
353
353
is_domain_allowed_for_oauth ,
354
354
get_allowed_oauth_domains
355
355
)
356
+ from onyx .db .models import User
357
+ from onyx .auth .sso_data_db import get_sso_configurations_from_db
356
358
import jwt
357
359
import json
360
+ from sqlalchemy import select
358
361
359
362
STATE_TOKEN_AUDIENCE = "fastapi-users:oauth-state"
360
363
logger = setup_logger ()
361
364
365
+ GRAPH_ROOT = "https://graph.microsoft.com/v1.0"
366
+
362
367
async def get_google_user_info (access_token : str ) -> dict :
363
368
async with httpx .AsyncClient () as client :
364
369
response = await client .get (
@@ -367,17 +372,83 @@ async def get_google_user_info(access_token: str) -> dict:
367
372
)
368
373
response .raise_for_status ()
369
374
return response .json ()
370
-
371
- async def get_microsoft_user_info (access_token : str ) -> dict :
372
- logger .info (f"accesssssssssssssssssssss{ access_token } " )
375
+ async def _make_graph_request (access_token : str , endpoint : str ) -> dict :
376
+ """Helper function to make Microsoft Graph API requests"""
377
+ logger .info (f"Making request to { endpoint } with token: { access_token } " )
378
+
373
379
async with httpx .AsyncClient () as client :
374
380
response = await client .get (
375
- "https://graph.microsoft.com/v1.0/me " ,
381
+ f" { GRAPH_ROOT } / { endpoint } " ,
376
382
headers = {"Authorization" : f"Bearer { access_token } " }
377
383
)
378
384
response .raise_for_status ()
379
- logger .info (f"response:{ response .json ()} " )
380
- return response .json ()
385
+ result = response .json ()
386
+ logger .info (f"Response: { result } " )
387
+ return result
388
+
389
+
390
+ async def get_microsoft_user_info (access_token : str ) -> dict :
391
+ return await _make_graph_request (access_token , "me" )
392
+
393
+
394
+ async def check_microsoft_account (access_token : str ) -> dict :
395
+ return await _make_graph_request (access_token , "me/appRoleAssignments" )
396
+
397
+
398
+ async def check_entra_account_details (access_token : str ) -> dict :
399
+ from onyx .db .engine .async_sql_engine import get_async_session
400
+ db_generator = get_async_session ()
401
+ db = await anext (db_generator )
402
+ client_id = None
403
+
404
+ sso_configs = await get_sso_configurations_from_db (db )
405
+ for provider , creds in sso_configs .items ():
406
+ if creds .get ("client_id" ) and creds .get ("client_secret" ):
407
+ client_id = creds ["client_id" ]
408
+ break # stop after finding the first match
409
+ user_account_details = await _make_graph_request (access_token , "me/appRoleAssignments" )
410
+
411
+ account_details = await _make_graph_request (
412
+ access_token ,
413
+ f"servicePrincipals?$filter=appId eq '{ client_id } '"
414
+ )
415
+ user_roles = user_account_details .get ("value" , [])
416
+ accounts = account_details .get ("value" , [])
417
+ logger .info (f"User account details 3 : f{ user_roles } f{ accounts } " )
418
+ if not accounts :
419
+ return {"account_exists" : False , "account_role_check" : False }
420
+
421
+ first_account = accounts [0 ]
422
+
423
+ # Check if any user role matches the account resourceId
424
+ account_exists = [
425
+ role for role in user_roles if role .get ("resourceId" ) == first_account .get ("id" )
426
+ ]
427
+
428
+ account_role_check = account_exists # Same filter as above
429
+ logger .info (f"User account details 3 : f{ account_role_check } " )
430
+ if account_exists and account_role_check :
431
+ # Fetch role type (appRole value)
432
+ app_roles = first_account .get ("appRoles" , [])
433
+ role_type = "User"
434
+ if app_roles and account_role_check :
435
+ for role in app_roles :
436
+ if role .get ("id" ) == account_role_check [0 ].get ("appRoleId" ):
437
+ role_type = role .get ("value" )
438
+ break
439
+ logger .info (f"User account details6: { account_exists } " )
440
+ return {
441
+ "account_exists" : True ,
442
+ "account_role_check" : True ,
443
+ "account_details" : first_account ,
444
+ "role_type" : role_type
445
+ }
446
+
447
+ return {
448
+ "account_exists" : False ,
449
+ "account_role_check" : False
450
+ }
451
+
381
452
382
453
def get_unified_oauth_callback_router (
383
454
auth_backend : AuthenticationBackend ,
@@ -594,11 +665,21 @@ async def callback(
594
665
account_email = user_info .get ("email" )
595
666
account_id = user_info .get ("sub" )
596
667
else : # microsoft
597
- user_info = await get_microsoft_user_info (token ["access_token" ])
668
+ user_info = await get_microsoft_user_info (token ["access_token" ])
669
+ logger .info (f"Fetching Microsoft user info... { token ['access_token' ]} " )
670
+
671
+ entra_account_details = await check_entra_account_details (token ["access_token" ])
672
+ logger .info (f"User account details: { entra_account_details } " )
673
+ # Validate account existence
674
+ if not entra_account_details or not entra_account_details .get ("account_exists" , False ):
675
+ raise Exception ("Account does not exist in Microsoft Entra" )
676
+
677
+ # Extract account details
598
678
account_email = user_info .get ("userPrincipalName" ) or user_info .get ("mail" )
599
679
account_id = user_info .get ("id" )
680
+
600
681
601
- logger .info (f"Retrieved user info - email: { account_email } , id: { account_id } " )
682
+ logger .info (f"Retrieved user info - email: { account_email } , id: { account_id } " )
602
683
603
684
except Exception as e :
604
685
logger .error (f"Failed to fetch user info: { str (e )} " )
@@ -618,7 +699,16 @@ async def callback(
618
699
# Check if this OAuth registration should be allowed
619
700
oauth_allowed = True if provider == "microsoft" else should_allow_oauth_registration (account_email , provider )
620
701
621
-
702
+ role_Type_check1 = (
703
+ "BASIC"
704
+ if provider == "google"
705
+ else "ADMIN" if entra_account_details .get ("role_type" ) == "Admin" else "BASIC"
706
+ )
707
+ role_Type_check = (
708
+ "basic"
709
+ if provider == "google"
710
+ else "admin" if entra_account_details .get ("role_type" ) == "Admin" else "basic"
711
+ )
622
712
if not oauth_allowed :
623
713
logger .error (f"OAuth registration not allowed for { account_email } with provider { provider } " )
624
714
@@ -653,7 +743,9 @@ async def callback(
653
743
request .state .referral_source = referral_source
654
744
request .state .oauth_provider = provider
655
745
request .state .is_oauth_flow = True
656
- request .state .user_role = "basic" # Set default role for OAuth users
746
+ # request.state.user_role = "admin" if entra_account_details.get("role_type") == "Admin" else "basic"
747
+ request .state .user_role = role_Type_check
748
+ # Set default role for OAuth users
657
749
658
750
# Perform OAuth callback to create/login user
659
751
try :
@@ -670,9 +762,17 @@ async def callback(
670
762
request = request ,
671
763
associate_by_email = True ,
672
764
is_verified_by_default = True ,
765
+ # role_type=role_Type_check
673
766
)
674
- logger .info (f"OAuth callback successful for user: { user .email } " )
767
+ logger .info (f"OAuth callback successful for user: { user .email } " )
768
+ logger .info (f"OAuth user: { user } " )
769
+ role_update_success = await update_user_role_in_db (role_Type_check1 , account_email )
675
770
771
+ if role_update_success :
772
+ logger .info (f"Successfully updated user { account_email } role to { role_Type_check } " )
773
+ else :
774
+ logger .warning (f"Failed to update role for user { account_email } , but continuing with login" )
775
+
676
776
except UserAlreadyExists :
677
777
logger .error ("User already exists" )
678
778
raise HTTPException (status_code = 400 , detail = "User already exists" )
@@ -756,3 +856,39 @@ async def callback(
756
856
raise HTTPException (status_code = 500 , detail = "Internal server error during OAuth callback" )
757
857
758
858
return router
859
+
860
+ async def update_user_role_in_db (role : str , email : str ):
861
+
862
+ try :
863
+ db_generator = get_async_session ()
864
+ db = await anext (db_generator )
865
+
866
+ try :
867
+ result = await db .execute (
868
+ select (User ).where (User .email == email )
869
+ )
870
+ logger .info (f"Role type { result } " )
871
+ user = result .unique ().scalar_one_or_none ()
872
+
873
+ if not user :
874
+ logger .error (f"User with ID { User .email } not found for role update" )
875
+ return False
876
+
877
+ old_role = user .role
878
+ user .role = role
879
+
880
+ await db .commit ()
881
+
882
+ logger .info (f"Successfully updated user { email } role from { old_role } to { role } " )
883
+ return True
884
+
885
+ except Exception as e :
886
+ logger .error (f"Failed to update user role for { email } : { str (e )} " )
887
+ await db .rollback ()
888
+ return False
889
+ finally :
890
+ await db .close ()
891
+
892
+ except Exception as e :
893
+ logger .error (f"Database session error when updating role for { email } : { str (e )} " )
894
+ return False
0 commit comments