Skip to content

Commit 94ab39f

Browse files
committed
merging code with dev
1 parent 3bed88b commit 94ab39f

File tree

20 files changed

+1486
-70
lines changed

20 files changed

+1486
-70
lines changed
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
"""add_chat_feedback_timestamp
2+
3+
Revision ID: 7d819bf06aab
4+
Revises: 0ba3b4464c70
5+
Create Date: 2025-08-14 03:36:05.961388
6+
7+
"""
8+
9+
from alembic import op
10+
import sqlalchemy as sa
11+
from sqlalchemy.dialects import postgresql
12+
13+
# revision identifiers, used by Alembic.
14+
revision = "7d819bf06aab"
15+
down_revision = "0ba3b4464c70"
16+
branch_labels = None
17+
depends_on = None
18+
19+
20+
def upgrade() -> None:
21+
# ### commands auto generated by Alembic - please adjust! ###
22+
op.add_column(
23+
"chat_feedback",
24+
sa.Column(
25+
"feedback_timestamp",
26+
sa.DateTime(timezone=True),
27+
server_default=sa.text("now()"),
28+
nullable=True,
29+
),
30+
)
31+
32+
def downgrade() -> None:
33+
op.drop_column("chat_feedback", "feedback_timestamp")

backend/onyx/auth/unified_oauth_callback.py

Lines changed: 148 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@
322322

323323
# return router
324324

325-
325+
from onyx.db.engine.async_sql_engine import get_async_session
326326
import httpx
327327
from fastapi import APIRouter, Depends, Request, HTTPException, status
328328
from fastapi.responses import RedirectResponse
@@ -353,12 +353,17 @@
353353
is_domain_allowed_for_oauth,
354354
get_allowed_oauth_domains
355355
)
356+
from onyx.db.models import User
357+
from onyx.auth.sso_data_db import get_sso_configurations_from_db
356358
import jwt
357359
import json
360+
from sqlalchemy import select
358361

359362
STATE_TOKEN_AUDIENCE = "fastapi-users:oauth-state"
360363
logger = setup_logger()
361364

365+
GRAPH_ROOT = "https://graph.microsoft.com/v1.0"
366+
362367
async def get_google_user_info(access_token: str) -> dict:
363368
async with httpx.AsyncClient() as client:
364369
response = await client.get(
@@ -367,17 +372,83 @@ async def get_google_user_info(access_token: str) -> dict:
367372
)
368373
response.raise_for_status()
369374
return response.json()
370-
371-
async def get_microsoft_user_info(access_token: str) -> dict:
372-
logger.info(f"accesssssssssssssssssssss{access_token}")
375+
async def _make_graph_request(access_token: str, endpoint: str) -> dict:
376+
"""Helper function to make Microsoft Graph API requests"""
377+
logger.info(f"Making request to {endpoint} with token: {access_token}")
378+
373379
async with httpx.AsyncClient() as client:
374380
response = await client.get(
375-
"https://graph.microsoft.com/v1.0/me",
381+
f"{GRAPH_ROOT}/{endpoint}",
376382
headers={"Authorization": f"Bearer {access_token}"}
377383
)
378384
response.raise_for_status()
379-
logger.info(f"response:{response.json()}")
380-
return response.json()
385+
result = response.json()
386+
logger.info(f"Response: {result}")
387+
return result
388+
389+
390+
async def get_microsoft_user_info(access_token: str) -> dict:
391+
return await _make_graph_request(access_token, "me")
392+
393+
394+
async def check_microsoft_account(access_token: str) -> dict:
395+
return await _make_graph_request(access_token, "me/appRoleAssignments")
396+
397+
398+
async def check_entra_account_details(access_token: str) -> dict:
399+
from onyx.db.engine.async_sql_engine import get_async_session
400+
db_generator = get_async_session()
401+
db = await anext(db_generator)
402+
client_id = None
403+
404+
sso_configs = await get_sso_configurations_from_db(db)
405+
for provider, creds in sso_configs.items():
406+
if creds.get("client_id") and creds.get("client_secret"):
407+
client_id = creds["client_id"]
408+
break # stop after finding the first match
409+
user_account_details = await _make_graph_request(access_token, "me/appRoleAssignments")
410+
411+
account_details = await _make_graph_request(
412+
access_token,
413+
f"servicePrincipals?$filter=appId eq '{client_id}'"
414+
)
415+
user_roles = user_account_details.get("value", [])
416+
accounts = account_details.get("value", [])
417+
logger.info(f"User account details 3 : f{user_roles} f{accounts}")
418+
if not accounts:
419+
return {"account_exists": False, "account_role_check": False}
420+
421+
first_account = accounts[0]
422+
423+
# Check if any user role matches the account resourceId
424+
account_exists = [
425+
role for role in user_roles if role.get("resourceId") == first_account.get("id")
426+
]
427+
428+
account_role_check = account_exists # Same filter as above
429+
logger.info(f"User account details 3 : f{account_role_check} ")
430+
if account_exists and account_role_check:
431+
# Fetch role type (appRole value)
432+
app_roles = first_account.get("appRoles", [])
433+
role_type = "User"
434+
if app_roles and account_role_check:
435+
for role in app_roles:
436+
if role.get("id") == account_role_check[0].get("appRoleId"):
437+
role_type = role.get("value")
438+
break
439+
logger.info(f"User account details6: {account_exists}")
440+
return {
441+
"account_exists": True,
442+
"account_role_check": True,
443+
"account_details": first_account,
444+
"role_type": role_type
445+
}
446+
447+
return {
448+
"account_exists": False,
449+
"account_role_check": False
450+
}
451+
381452

382453
def get_unified_oauth_callback_router(
383454
auth_backend: AuthenticationBackend,
@@ -594,11 +665,21 @@ async def callback(
594665
account_email = user_info.get("email")
595666
account_id = user_info.get("sub")
596667
else: # microsoft
597-
user_info = await get_microsoft_user_info(token["access_token"])
668+
user_info = await get_microsoft_user_info(token["access_token"])
669+
logger.info(f"Fetching Microsoft user info... {token['access_token']}")
670+
671+
entra_account_details = await check_entra_account_details(token["access_token"])
672+
logger.info(f"User account details: {entra_account_details}")
673+
# Validate account existence
674+
if not entra_account_details or not entra_account_details.get("account_exists", False):
675+
raise Exception("Account does not exist in Microsoft Entra")
676+
677+
# Extract account details
598678
account_email = user_info.get("userPrincipalName") or user_info.get("mail")
599679
account_id = user_info.get("id")
680+
600681

601-
logger.info(f"Retrieved user info - email: {account_email}, id: {account_id}")
682+
logger.info(f"Retrieved user info - email: {account_email}, id: {account_id}")
602683

603684
except Exception as e:
604685
logger.error(f"Failed to fetch user info: {str(e)}")
@@ -618,7 +699,16 @@ async def callback(
618699
# Check if this OAuth registration should be allowed
619700
oauth_allowed = True if provider == "microsoft" else should_allow_oauth_registration(account_email, provider)
620701

621-
702+
role_Type_check1 = (
703+
"BASIC"
704+
if provider == "google"
705+
else "ADMIN" if entra_account_details.get("role_type") == "Admin" else "BASIC"
706+
)
707+
role_Type_check = (
708+
"basic"
709+
if provider == "google"
710+
else "admin" if entra_account_details.get("role_type") == "Admin" else "basic"
711+
)
622712
if not oauth_allowed:
623713
logger.error(f"OAuth registration not allowed for {account_email} with provider {provider}")
624714

@@ -653,7 +743,9 @@ async def callback(
653743
request.state.referral_source = referral_source
654744
request.state.oauth_provider = provider
655745
request.state.is_oauth_flow = True
656-
request.state.user_role = "basic" # Set default role for OAuth users
746+
# request.state.user_role = "admin" if entra_account_details.get("role_type") == "Admin" else "basic"
747+
request.state.user_role = role_Type_check
748+
# Set default role for OAuth users
657749

658750
# Perform OAuth callback to create/login user
659751
try:
@@ -670,9 +762,17 @@ async def callback(
670762
request=request,
671763
associate_by_email=True,
672764
is_verified_by_default=True,
765+
# role_type=role_Type_check
673766
)
674-
logger.info(f"OAuth callback successful for user: {user.email}")
767+
logger.info(f"OAuth callback successful for user: {user.email} ")
768+
logger.info(f"OAuth user: {user} ")
769+
role_update_success = await update_user_role_in_db(role_Type_check1, account_email)
675770

771+
if role_update_success:
772+
logger.info(f"Successfully updated user {account_email} role to {role_Type_check}")
773+
else:
774+
logger.warning(f"Failed to update role for user {account_email}, but continuing with login")
775+
676776
except UserAlreadyExists:
677777
logger.error("User already exists")
678778
raise HTTPException(status_code=400, detail="User already exists")
@@ -756,3 +856,39 @@ async def callback(
756856
raise HTTPException(status_code=500, detail="Internal server error during OAuth callback")
757857

758858
return router
859+
860+
async def update_user_role_in_db(role: str, email: str):
861+
862+
try:
863+
db_generator = get_async_session()
864+
db = await anext(db_generator)
865+
866+
try:
867+
result = await db.execute(
868+
select(User).where(User.email == email)
869+
)
870+
logger.info(f"Role type {result}")
871+
user = result.unique().scalar_one_or_none()
872+
873+
if not user:
874+
logger.error(f"User with ID {User.email} not found for role update")
875+
return False
876+
877+
old_role = user.role
878+
user.role = role
879+
880+
await db.commit()
881+
882+
logger.info(f"Successfully updated user {email} role from {old_role} to {role}")
883+
return True
884+
885+
except Exception as e:
886+
logger.error(f"Failed to update user role for {email}: {str(e)}")
887+
await db.rollback()
888+
return False
889+
finally:
890+
await db.close()
891+
892+
except Exception as e:
893+
logger.error(f"Database session error when updating role for {email}: {str(e)}")
894+
return False

backend/onyx/auth/users.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -401,7 +401,9 @@ async def oauth_callback(
401401
*,
402402
associate_by_email: bool = False,
403403
is_verified_by_default: bool = False,
404+
role_type: str = UserRole.BASIC,
404405
) -> User:
406+
logger.info(f"role_typeeeeeeee {role_type}")
405407
referral_source = (
406408
getattr(request.state, "referral_source", None) if request else None
407409
)
@@ -442,6 +444,7 @@ async def oauth_callback(
442444
"account_email": account_email,
443445
"expires_at": expires_at,
444446
"refresh_token": refresh_token,
447+
# "role": role_type,
445448
}
446449

447450
user: User | None = None
@@ -473,6 +476,7 @@ async def oauth_callback(
473476
"email": account_email,
474477
"hashed_password": self.password_helper.hash(password),
475478
"is_verified": is_verified_by_default,
479+
# "role": role_type,
476480
}
477481

478482
user = await self.user_db.create(user_dict)
@@ -518,7 +522,8 @@ async def oauth_callback(
518522
user,
519523
{
520524
"is_verified": is_verified_by_default,
521-
"role": UserRole.BASIC,
525+
# "role": UserRole.BASIC,
526+
# "role": role_type,
522527
},
523528
)
524529

backend/onyx/db/chat.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1030,12 +1030,34 @@ def get_retrieval_docs_from_search_docs(
10301030
if sort_by_score:
10311031
top_documents = sorted(top_documents, key=lambda doc: doc.score, reverse=True) # type: ignore
10321032
return RetrievalDocs(top_documents=top_documents)
1033+
from onyx.db.models import ChatMessageFeedback
10331034

10341035

10351036
def translate_db_message_to_chat_message_detail(
10361037
chat_message: ChatMessage,
10371038
remove_doc_content: bool = False,
1039+
db_session: Session = None, # ADD THIS PARAMETER
10381040
) -> ChatMessageDetail:
1041+
1042+
# ADD THIS: Get chat feedback for this message
1043+
chat_feedback = []
1044+
if db_session:
1045+
# Replace 'ChatFeedback' with your actual feedback model name
1046+
feedback_records = db_session.query(ChatMessageFeedback).filter(
1047+
ChatMessageFeedback.chat_message_id == chat_message.id
1048+
).all()
1049+
1050+
chat_feedback = [
1051+
{
1052+
"id": feedback.id,
1053+
"is_positive": feedback.is_positive,
1054+
"feedback_text": feedback.feedback_text,
1055+
"chat_message_id": feedback.chat_message_id,
1056+
"feedback_timestamp":feedback.feedback_timestamp
1057+
}
1058+
for feedback in feedback_records
1059+
]
1060+
10391061
chat_msg_detail = ChatMessageDetail(
10401062
chat_session_id=chat_message.chat_session_id,
10411063
message_id=chat_message.id,
@@ -1067,6 +1089,7 @@ def translate_db_message_to_chat_message_detail(
10671089
refined_answer_improvement=chat_message.refined_answer_improvement,
10681090
is_agentic=chat_message.is_agentic,
10691091
error=chat_message.error,
1092+
chat_feedback=chat_feedback, # ADD THIS LINE
10701093
)
10711094

10721095
return chat_msg_detail

backend/onyx/db/models.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from typing import NotRequired
66
from typing import Optional
77
from uuid import uuid4
8-
98
from pydantic import BaseModel
109
from sqlalchemy.orm import validates
1110
from typing_extensions import TypedDict # noreorder
@@ -2336,6 +2335,11 @@ class ChatMessageFeedback(Base):
23362335
required_followup: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
23372336
feedback_text: Mapped[str | None] = mapped_column(Text, nullable=True)
23382337
predefined_feedback: Mapped[str | None] = mapped_column(String, nullable=True)
2338+
feedback_timestamp: Mapped[datetime.datetime | None] = mapped_column(
2339+
DateTime(timezone=True),
2340+
nullable=True,
2341+
server_default=func.now() # Set at database level
2342+
)
23392343

23402344
chat_message: Mapped[ChatMessage] = relationship(
23412345
"ChatMessage",

0 commit comments

Comments
 (0)