Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
"""add_multi_user_chat_collaboration

Revision ID: 26a9d522abca
Revises: ffc707a226b4
Create Date: 2025-06-22 16:05:10.592181

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = '26a9d522abca'
down_revision = 'ffc707a226b4'
branch_labels = None
depends_on = None

# Define Enum types for roles and statuses
chat_participant_role = sa.Enum('OWNER', 'COLLABORATOR', name='chatparticipantrole')
chat_invitation_status = sa.Enum('PENDING', 'ACCEPTED', 'DECLINED', name='chatinvitationstatus')


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###

# Create new ENUM types in the database
chat_participant_role.create(op.get_bind(), checkfirst=True)
chat_invitation_status.create(op.get_bind(), checkfirst=True)

# Add collaboration_enabled to chat_session table
op.add_column(
'chat_session',
sa.Column('collaboration_enabled', sa.Boolean(), server_default=sa.text('false'), nullable=False)
)

# Add sender_id to chat_message table
op.add_column(
'chat_message',
sa.Column('sender_id', postgresql.UUID(as_uuid=True), nullable=True)
)
op.create_foreign_key(
'fk_chat_message_sender_id_user',
'chat_message', 'user',
['sender_id'], ['id'],
ondelete='SET NULL'
)

# Create chat_session_participant table
op.create_table('chat_session_participant',
sa.Column('chat_session_id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('role', chat_participant_role, nullable=False),
sa.Column('joined_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('last_read_message_id', sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['chat_session_id'], ['chat_session.id'], name=op.f('fk_chat_session_participant_chat_session_id_chat_session'), ondelete='CASCADE'),
sa.ForeignKeyConstraint(['user_id'], ['user.id'], name=op.f('fk_chat_session_participant_user_id_user'), ondelete='CASCADE'),
sa.ForeignKeyConstraint(['last_read_message_id'], ['chat_message.id'], name=op.f('fk_chat_session_participant_last_read_message_id_chat_message'), ondelete='SET NULL'),
sa.PrimaryKeyConstraint('chat_session_id', 'user_id', name=op.f('pk_chat_session_participant'))
)
op.create_index(op.f('ix_chat_session_participant_chat_session_id'), 'chat_session_participant', ['chat_session_id'], unique=False)
op.create_index(op.f('ix_chat_session_participant_user_id'), 'chat_session_participant', ['user_id'], unique=False)

# Create chat_invitation table
op.create_table('chat_invitation',
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('chat_session_id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('inviter_id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('invitee_id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('status', chat_invitation_status, server_default='PENDING', nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.ForeignKeyConstraint(['chat_session_id'], ['chat_session.id'], name=op.f('fk_chat_invitation_chat_session_id_chat_session'), ondelete='CASCADE'),
sa.ForeignKeyConstraint(['inviter_id'], ['user.id'], name=op.f('fk_chat_invitation_inviter_id_user'), ondelete='CASCADE'),
sa.ForeignKeyConstraint(['invitee_id'], ['user.id'], name=op.f('fk_chat_invitation_invitee_id_user'), ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id', name=op.f('pk_chat_invitation'))
)
op.create_index(op.f('ix_chat_invitation_chat_session_id'), 'chat_invitation', ['chat_session_id'], unique=False)
op.create_index(op.f('ix_chat_invitation_invitee_id'), 'chat_invitation', ['invitee_id'], unique=False)

# Data migration: Populate chat_session_participant with existing owners
bind = op.get_bind()
session = sa.orm.Session(bind=bind)

session.execute(sa.text("""
INSERT INTO chat_session_participant (chat_session_id, user_id, role, joined_at)
SELECT id, user_id, 'OWNER', time_created
FROM chat_session
WHERE user_id IS NOT NULL;
"""))

# Data migration: Populate sender_id for existing messages
# Assuming the message sender is the chat session owner for existing 'USER' type messages
session.execute(sa.text("""
UPDATE chat_message
SET sender_id = cs.user_id
FROM chat_session cs
WHERE chat_message.chat_session_id = cs.id
AND cs.user_id IS NOT NULL
AND chat_message.message_type = 'USER';
"""))

session.commit()
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('chat_invitation')
op.drop_table('chat_session_participant')

op.drop_constraint('fk_chat_message_sender_id_user', 'chat_message', type_='foreignkey')
op.drop_column('chat_message', 'sender_id')

op.drop_column('chat_session', 'collaboration_enabled')

# Drop ENUM types from the database
chat_invitation_status.drop(op.get_bind(), checkfirst=True)
chat_participant_role.drop(op.get_bind(), checkfirst=True)
# ### end Alembic commands ###
73 changes: 73 additions & 0 deletions backend/onyx/access/access.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
from onyx.db.document import get_access_info_for_documents
from onyx.db.models import User
from onyx.utils.variable_functionality import fetch_versioned_implementation
from onyx.db.models import (
ChatSession,
ChatSessionParticipant,
ChatSessionSharedStatus,
)


def _get_access_for_document(
Expand Down Expand Up @@ -107,3 +112,71 @@ def get_acl_for_user(user: User | None, db_session: Session | None = None) -> se
"onyx.access.access", "_get_acl_for_user"
)
return versioned_acl_for_user_fn(user, db_session) # type: ignore

# ---------------------------------------------------------------------------
# Chat-session access helpers (multi-user collaboration aware)
# ---------------------------------------------------------------------------

def _user_can_access_chat_session(
chat_session: ChatSession,
user: User | None,
db_session: Session,
) -> bool:
"""
Determines whether the given `user` can access the provided `chat_session`.

Access rules (in precedence order):
1. Session is PUBLIC -> anyone can access.
2. `user` is None -> cannot access (except #1).
3. `user` is the OWNER -> can access.
4. Collaboration is enabled
*AND* `user` is a participant
(role OWNER or COLLABORATOR) -> can access.
5. Otherwise -> access denied.
"""

# 1. Publicly shared chat sessions are accessible by anyone.
if chat_session.shared_status == ChatSessionSharedStatus.PUBLIC:
return True

# 2. No user present (unauthenticated) – cannot access private sessions.
if user is None:
return False

# 3. The owner of the chat session always has access.
if chat_session.user_id == user.id:
return True

# 4. If collaboration is enabled, check participant list.
if chat_session.collaboration_enabled:
participant_exists = (
db_session.query(ChatSessionParticipant)
.filter(
ChatSessionParticipant.chat_session_id == chat_session.id,
ChatSessionParticipant.user_id == user.id,
)
.first()
is not None
)
if participant_exists:
return True

# 5. All other cases – deny access.
return False


def user_can_access_chat_session(
chat_session: ChatSession,
user: User | None,
db_session: Session,
) -> bool:
"""
Public wrapper that supports EE overrides via `fetch_versioned_implementation`,
mirroring the pattern used elsewhere in this module.
"""

versioned_fn = fetch_versioned_implementation(
"onyx.access.access", "_user_can_access_chat_session"
)
# mypy: ignore dynamic dispatch based on EE overrides
return versioned_fn(chat_session, user, db_session) # type: ignore
34 changes: 32 additions & 2 deletions backend/onyx/db/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from onyx.db.models import ChatMessage__SearchDoc
from onyx.db.models import ChatSession
from onyx.db.models import ChatSessionSharedStatus
from onyx.db.models import ChatSessionParticipant
from onyx.db.models import Prompt
from onyx.db.models import SearchDoc
from onyx.db.models import SearchDoc as DBSearchDoc
Expand Down Expand Up @@ -69,6 +70,7 @@ def get_chat_session_by_id(
include_deleted: bool = False,
is_shared: bool = False,
) -> ChatSession:
# Base select for the chat session
stmt = select(ChatSession).where(ChatSession.id == chat_session_id)

if is_shared:
Expand All @@ -77,8 +79,21 @@ def get_chat_session_by_id(
# if user_id is None, assume this is an admin who should be able
# to view all chat sessions
if user_id is not None:
stmt = stmt.where(
or_(ChatSession.user_id == user_id, ChatSession.user_id.is_(None))
# Allow owners OR collaborators (participants) to access the session.
# Use an outer join to ChatSessionParticipant so the query still works
# when there are no participants yet.
stmt = (
stmt.outerjoin(
ChatSessionParticipant,
ChatSessionParticipant.chat_session_id == ChatSession.id,
)
.where(
or_(
ChatSession.user_id == user_id,
ChatSessionParticipant.user_id == user_id,
ChatSession.user_id.is_(None),
)
)
)

result = db_session.execute(stmt)
Expand Down Expand Up @@ -429,6 +444,13 @@ def get_chat_message(
logger.error(
f"User {user_id} tried to fetch a chat message that does not belong to them"
)
# If collaboration is enabled, allow participants to access.
if chat_message.chat_session.collaboration_enabled:
participant_ids = [
p.user_id for p in chat_message.chat_session.participants
]
if user_id in participant_ids:
return chat_message
raise ValueError("Chat message does not belong to user")

return chat_message
Expand Down Expand Up @@ -644,7 +666,13 @@ def create_new_chat_message(
overridden_model: str | None = None,
refined_answer_improvement: bool | None = None,
is_agentic: bool = False,
# NEW: sender of the message (owner or collaborator)
sender_id: UUID | None = None,
) -> ChatMessage:
# Fallback to chat session owner if sender not specified
if sender_id is None:
sender_id = parent_message.chat_session.user_id

if reserved_message_id is not None:
# Edit existing message
existing_message = db_session.query(ChatMessage).get(reserved_message_id)
Expand All @@ -661,6 +689,7 @@ def create_new_chat_message(
existing_message.citations = citations
existing_message.files = files
existing_message.tool_call = tool_call
existing_message.sender_id = sender_id
existing_message.error = error
existing_message.alternate_assistant_id = alternate_assistant_id
existing_message.overridden_model = overridden_model
Expand All @@ -683,6 +712,7 @@ def create_new_chat_message(
tool_call=tool_call,
error=error,
alternate_assistant_id=alternate_assistant_id,
sender_id=sender_id,
overridden_model=overridden_model,
refined_answer_improvement=refined_answer_improvement,
is_agentic=is_agentic,
Expand Down
Loading
Loading