Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 77 additions & 29 deletions backend/onyx/db/persona.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections.abc import Sequence
from datetime import datetime
from enum import Enum
from uuid import UUID

from fastapi import HTTPException
Expand All @@ -10,7 +11,6 @@
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import aliased
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session

Expand Down Expand Up @@ -38,17 +38,25 @@
from onyx.db.models import UserGroup
from onyx.db.notification import create_notification
from onyx.server.features.persona.models import FullPersonaSnapshot
from onyx.server.features.persona.models import MinimalPersonaSnapshot
from onyx.server.features.persona.models import PersonaSharedNotificationData
from onyx.server.features.persona.models import PersonaSnapshot
from onyx.server.features.persona.models import PersonaUpsertRequest
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_versioned_implementation

logger = setup_logger()


class PersonaLoadType(Enum):
NONE = "none"
MINIMAL = "minimal"
FULL = "full"


def _add_user_filters(
stmt: Select, user: User | None, get_editable: bool = True
) -> Select:
stmt: Select[tuple[Persona]], user: User | None, get_editable: bool = True
) -> Select[tuple[Persona]]:
# If user is None and auth is disabled, assume the user is an admin
if (user is None and DISABLE_AUTH) or (user and user.role == UserRole.ADMIN):
return stmt
Expand Down Expand Up @@ -322,43 +330,83 @@ def update_persona_public_status(
db_session.commit()


def get_personas_for_user(
def _build_persona_filters(
stmt: Select[tuple[Persona]],
include_default: bool,
include_slack_bot_personas: bool,
include_deleted: bool,
) -> Select[tuple[Persona]]:
if not include_default:
stmt = stmt.where(Persona.builtin_persona.is_(False))
if not include_slack_bot_personas:
stmt = stmt.where(not_(Persona.name.startswith(SLACK_BOT_PERSONA_PREFIX)))
if not include_deleted:
stmt = stmt.where(Persona.deleted.is_(False))
return stmt


def get_minimal_persona_snapshots_for_user(
user: User | None,
db_session: Session,
get_editable: bool = True,
include_default: bool = True,
include_slack_bot_personas: bool = False,
include_deleted: bool = False,
) -> list[MinimalPersonaSnapshot]:
stmt = select(Persona)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rule violated: Everything Python should be as strictly typed as possible

  Variable assignment inside function lacks explicit type annotation, violating strict-typing rule that requires typed variables at function scope.
Suggested change
stmt = select(Persona)
stmt: Select[tuple[Persona]] = select(Persona)

stmt = _add_user_filters(stmt, user, get_editable)
stmt = _build_persona_filters(
stmt, include_default, include_slack_bot_personas, include_deleted
)
stmt = stmt.options(
selectinload(Persona.tools),
selectinload(Persona.labels),
selectinload(Persona.document_sets),
selectinload(Persona.user),
)
results = db_session.scalars(stmt).all()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rule violated: Prefer pure stateless functions

  Performs a database query inside the function, introducing I/O and external state dependency, in direct violation of the "Prefer pure stateless functions" rule.

return [MinimalPersonaSnapshot.from_model(persona) for persona in results]


def get_persona_snapshots_for_user(
# if user is `None` assume the user is an admin or auth is disabled
user: User | None,
db_session: Session,
get_editable: bool = True,
include_default: bool = True,
include_slack_bot_personas: bool = False,
include_deleted: bool = False,
joinedload_all: bool = False,
# a bit jank
include_prompt: bool = True,
) -> Sequence[Persona]:
) -> list[PersonaSnapshot]:
stmt = select(Persona)
stmt = _add_user_filters(stmt, user, get_editable)
stmt = _build_persona_filters(
stmt, include_default, include_slack_bot_personas, include_deleted
)
stmt = stmt.options(
selectinload(Persona.tools),
selectinload(Persona.labels),
selectinload(Persona.document_sets),
selectinload(Persona.user),
)

if not include_default:
stmt = stmt.where(Persona.builtin_persona.is_(False))
if not include_slack_bot_personas:
stmt = stmt.where(not_(Persona.name.startswith(SLACK_BOT_PERSONA_PREFIX)))
if not include_deleted:
stmt = stmt.where(Persona.deleted.is_(False))
results = db_session.scalars(stmt).all()
return [PersonaSnapshot.from_model(persona) for persona in results]

if joinedload_all:
stmt = stmt.options(
selectinload(Persona.tools),
selectinload(Persona.document_sets),
selectinload(Persona.groups),
selectinload(Persona.users),
selectinload(Persona.labels),
selectinload(Persona.user_files),
selectinload(Persona.user_folders),
)
if include_prompt:
stmt = stmt.options(selectinload(Persona.prompts))

results = db_session.execute(stmt).scalars().all()
return results
def get_raw_personas_for_user(
user: User | None,
db_session: Session,
get_editable: bool = True,
include_default: bool = True,
include_slack_bot_personas: bool = False,
include_deleted: bool = False,
) -> Sequence[Persona]:
stmt = select(Persona)
stmt = _add_user_filters(stmt, user, get_editable)
stmt = _build_persona_filters(
stmt, include_default, include_slack_bot_personas, include_deleted
)
return db_session.scalars(stmt).all()


def get_personas(db_session: Session) -> Sequence[Persona]:
Expand Down Expand Up @@ -826,7 +874,7 @@ def delete_persona_label(label_id: int, db_session: Session) -> None:
def persona_has_search_tool(persona_id: int, db_session: Session) -> bool:
persona = (
db_session.query(Persona)
.options(joinedload(Persona.tools))
.options(selectinload(Persona.tools))
.filter(Persona.id == persona_id)
.one_or_none()
)
Expand Down
5 changes: 4 additions & 1 deletion backend/onyx/seeding/load_yamls.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
from onyx.db.prompts import get_prompt_by_name
from onyx.db.prompts import upsert_prompt
from onyx.db.user_documents import upsert_user_folder
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)


def load_user_folders_from_yaml(
Expand Down Expand Up @@ -136,7 +139,7 @@ def load_personas_from_yaml(
if persona.get("image_generation"):
image_gen_tool = (
db_session.query(ToolDBModel)
.filter(ToolDBModel.name == "ImageGenerationTool")
.filter(ToolDBModel.name == ImageGenerationTool.__name__)
.first()
)
if image_gen_tool:
Expand Down
2 changes: 1 addition & 1 deletion backend/onyx/server/features/document_set/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def list_document_sets_for_user(
document_sets = fetch_all_document_sets_for_user(
db_session=db_session, user=user, get_editable=get_editable
)
return [DocumentSetSummary.from_document_set(ds) for ds in document_sets]
return [DocumentSetSummary.from_model(ds) for ds in document_sets]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rule violated: Prefer long, explicit variable names

  Variable name "ds" is an overly abbreviated identifier and violates the explicit naming rule; use a descriptive name with at least 3 words and 12 characters.
Suggested change
return [DocumentSetSummary.from_model(ds) for ds in document_sets]
return [DocumentSetSummary.from_model(document_set) for document_set in document_sets]



@router.get("/document-set-public")
Expand Down
98 changes: 48 additions & 50 deletions backend/onyx/server/features/document_set/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,56 +49,6 @@ def from_federated_connector_mapping(
)


class DocumentSetSummary(BaseModel):
"""Simplified document set model with minimal data for list views"""

id: int
name: str
description: str | None
cc_pair_summaries: list[CCPairSummary]
is_up_to_date: bool
is_public: bool
users: list[UUID]
groups: list[int]
federated_connector_summaries: list[FederatedConnectorSummary] = Field(
default_factory=list
)

@classmethod
def from_document_set(
cls, document_set: DocumentSetDBModel
) -> "DocumentSetSummary":
"""Create a summary from a DocumentSet database model"""
return cls(
id=document_set.id,
name=document_set.name,
description=document_set.description,
cc_pair_summaries=[
CCPairSummary(
id=cc_pair.id,
name=cc_pair.name,
source=cc_pair.connector.source,
access_type=cc_pair.access_type,
)
for cc_pair in document_set.connector_credential_pairs
],
is_up_to_date=document_set.is_up_to_date,
is_public=document_set.is_public,
users=[user.id for user in document_set.users],
groups=[group.id for group in document_set.groups],
federated_connector_summaries=[
FederatedConnectorSummary(
id=fc_mapping.federated_connector_id,
name=f"{fc_mapping.federated_connector.source.replace('_', ' ').title()}",
source=fc_mapping.federated_connector.source,
entities=fc_mapping.entities,
)
for fc_mapping in document_set.federated_connectors
if fc_mapping.federated_connector is not None
],
)


class DocumentSetCreationRequest(BaseModel):
name: str
description: str
Expand Down Expand Up @@ -181,3 +131,51 @@ def from_model(cls, document_set_model: DocumentSetDBModel) -> "DocumentSet":
for fc_mapping in document_set_model.federated_connectors
],
)


class DocumentSetSummary(BaseModel):
"""Simplified document set model with minimal data for list views"""

id: int
name: str
description: str | None
cc_pair_summaries: list[CCPairSummary]
is_up_to_date: bool
is_public: bool
users: list[UUID]
groups: list[int]
federated_connector_summaries: list[FederatedConnectorSummary] = Field(
default_factory=list
)

@classmethod
def from_model(cls, document_set: DocumentSetDBModel) -> "DocumentSetSummary":
"""Create a summary from a DocumentSet database model"""
return cls(
id=document_set.id,
name=document_set.name,
description=document_set.description,
cc_pair_summaries=[
CCPairSummary(
id=cc_pair.id,
name=cc_pair.name,
source=cc_pair.connector.source,
access_type=cc_pair.access_type,
)
for cc_pair in document_set.connector_credential_pairs
],
is_up_to_date=document_set.is_up_to_date,
is_public=document_set.is_public,
users=[user.id for user in document_set.users],
groups=[group.id for group in document_set.groups],
federated_connector_summaries=[
FederatedConnectorSummary(
id=fc_mapping.federated_connector_id,
name=f"{fc_mapping.federated_connector.source.replace('_', ' ').title()}",
source=fc_mapping.federated_connector.source,
entities=fc_mapping.entities,
)
for fc_mapping in document_set.federated_connectors
if fc_mapping.federated_connector is not None
],
)
35 changes: 18 additions & 17 deletions backend/onyx/server/features/persona/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@
from onyx.db.persona import create_update_persona
from onyx.db.persona import delete_persona_label
from onyx.db.persona import get_assistant_labels
from onyx.db.persona import get_minimal_persona_snapshots_for_user
from onyx.db.persona import get_persona_by_id
from onyx.db.persona import get_personas_for_user
from onyx.db.persona import get_persona_snapshots_for_user
from onyx.db.persona import mark_persona_as_deleted
from onyx.db.persona import mark_persona_as_not_deleted
from onyx.db.persona import update_all_personas_display_priority
Expand All @@ -45,6 +46,7 @@
from onyx.server.features.persona.models import FullPersonaSnapshot
from onyx.server.features.persona.models import GenerateStarterMessageRequest
from onyx.server.features.persona.models import ImageGenerationToolStatus
from onyx.server.features.persona.models import MinimalPersonaSnapshot
from onyx.server.features.persona.models import PersonaLabelCreate
from onyx.server.features.persona.models import PersonaLabelResponse
from onyx.server.features.persona.models import PersonaSharedNotificationData
Expand All @@ -53,6 +55,9 @@
from onyx.server.features.persona.models import PromptSnapshot
from onyx.server.models import DisplayPriorityRequest
from onyx.server.settings.store import load_settings
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
from onyx.tools.utils import is_image_generation_available
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import create_milestone_and_report
Expand Down Expand Up @@ -165,16 +170,12 @@ def list_personas_admin(
include_deleted: bool = False,
get_editable: bool = Query(False, description="If true, return editable personas"),
) -> list[PersonaSnapshot]:
return [
PersonaSnapshot.from_model(persona)
for persona in get_personas_for_user(
db_session=db_session,
user=user,
get_editable=get_editable,
include_deleted=include_deleted,
joinedload_all=True,
)
]
return get_persona_snapshots_for_user(
user=user,
db_session=db_session,
get_editable=get_editable,
include_deleted=include_deleted,
)


@admin_router.patch("/{persona_id}/undelete")
Expand Down Expand Up @@ -414,14 +415,12 @@ def list_personas(
db_session: Session = Depends(get_session),
include_deleted: bool = False,
persona_ids: list[int] = Query(None),
) -> list[PersonaSnapshot]:
personas = get_personas_for_user(
) -> list[MinimalPersonaSnapshot]:
personas = get_minimal_persona_snapshots_for_user(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rule violated: Everything Python should be as strictly typed as possible

  Variable 'personas' is assigned a non-literal value without an explicit type annotation, violating the strict typing rule that requires all variable declarations to be fully typed.
Suggested change
personas = get_minimal_persona_snapshots_for_user(
personas: list[MinimalPersonaSnapshot] = get_minimal_persona_snapshots_for_user(

user=user,
include_deleted=include_deleted,
db_session=db_session,
get_editable=False,
joinedload_all=True,
include_prompt=False,
)

if persona_ids:
Expand All @@ -432,12 +431,14 @@ def list_personas(
p
for p in personas
if not (
any(tool.in_code_tool_id == "ImageGenerationTool" for tool in p.tools)
any(
tool.in_code_tool_id == ImageGenerationTool.__name__ for tool in p.tools
)
and not is_image_generation_available(db_session=db_session)
)
]

return [PersonaSnapshot.from_model(p) for p in personas]
return personas


@basic_router.get("/{persona_id}")
Expand Down
Loading
Loading