-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Persona simplification r2 #5031
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,6 @@ | ||
| from collections.abc import Sequence | ||
| from datetime import datetime | ||
| from enum import Enum | ||
| from uuid import UUID | ||
|
|
||
| from fastapi import HTTPException | ||
|
|
@@ -10,7 +11,6 @@ | |
| from sqlalchemy import select | ||
| from sqlalchemy import update | ||
| from sqlalchemy.orm import aliased | ||
| from sqlalchemy.orm import joinedload | ||
| from sqlalchemy.orm import selectinload | ||
| from sqlalchemy.orm import Session | ||
|
|
||
|
|
@@ -38,17 +38,25 @@ | |
| from onyx.db.models import UserGroup | ||
| from onyx.db.notification import create_notification | ||
| from onyx.server.features.persona.models import FullPersonaSnapshot | ||
| from onyx.server.features.persona.models import MinimalPersonaSnapshot | ||
| from onyx.server.features.persona.models import PersonaSharedNotificationData | ||
| from onyx.server.features.persona.models import PersonaSnapshot | ||
| from onyx.server.features.persona.models import PersonaUpsertRequest | ||
| from onyx.utils.logger import setup_logger | ||
| from onyx.utils.variable_functionality import fetch_versioned_implementation | ||
|
|
||
| logger = setup_logger() | ||
|
|
||
|
|
||
| class PersonaLoadType(Enum): | ||
| NONE = "none" | ||
| MINIMAL = "minimal" | ||
| FULL = "full" | ||
|
|
||
|
|
||
| def _add_user_filters( | ||
| stmt: Select, user: User | None, get_editable: bool = True | ||
| ) -> Select: | ||
| stmt: Select[tuple[Persona]], user: User | None, get_editable: bool = True | ||
| ) -> Select[tuple[Persona]]: | ||
| # If user is None and auth is disabled, assume the user is an admin | ||
| if (user is None and DISABLE_AUTH) or (user and user.role == UserRole.ADMIN): | ||
| return stmt | ||
|
|
@@ -322,43 +330,83 @@ def update_persona_public_status( | |
| db_session.commit() | ||
|
|
||
|
|
||
| def get_personas_for_user( | ||
| def _build_persona_filters( | ||
| stmt: Select[tuple[Persona]], | ||
| include_default: bool, | ||
| include_slack_bot_personas: bool, | ||
| include_deleted: bool, | ||
| ) -> Select[tuple[Persona]]: | ||
| if not include_default: | ||
| stmt = stmt.where(Persona.builtin_persona.is_(False)) | ||
| if not include_slack_bot_personas: | ||
| stmt = stmt.where(not_(Persona.name.startswith(SLACK_BOT_PERSONA_PREFIX))) | ||
| if not include_deleted: | ||
| stmt = stmt.where(Persona.deleted.is_(False)) | ||
| return stmt | ||
|
|
||
|
|
||
| def get_minimal_persona_snapshots_for_user( | ||
| user: User | None, | ||
| db_session: Session, | ||
| get_editable: bool = True, | ||
| include_default: bool = True, | ||
| include_slack_bot_personas: bool = False, | ||
| include_deleted: bool = False, | ||
| ) -> list[MinimalPersonaSnapshot]: | ||
| stmt = select(Persona) | ||
| stmt = _add_user_filters(stmt, user, get_editable) | ||
| stmt = _build_persona_filters( | ||
| stmt, include_default, include_slack_bot_personas, include_deleted | ||
| ) | ||
| stmt = stmt.options( | ||
| selectinload(Persona.tools), | ||
| selectinload(Persona.labels), | ||
| selectinload(Persona.document_sets), | ||
| selectinload(Persona.user), | ||
| ) | ||
| results = db_session.scalars(stmt).all() | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Rule violated: Prefer pure stateless functions |
||
| return [MinimalPersonaSnapshot.from_model(persona) for persona in results] | ||
|
|
||
|
|
||
| def get_persona_snapshots_for_user( | ||
| # if user is `None` assume the user is an admin or auth is disabled | ||
| user: User | None, | ||
| db_session: Session, | ||
| get_editable: bool = True, | ||
| include_default: bool = True, | ||
| include_slack_bot_personas: bool = False, | ||
| include_deleted: bool = False, | ||
| joinedload_all: bool = False, | ||
| # a bit jank | ||
| include_prompt: bool = True, | ||
| ) -> Sequence[Persona]: | ||
| ) -> list[PersonaSnapshot]: | ||
| stmt = select(Persona) | ||
| stmt = _add_user_filters(stmt, user, get_editable) | ||
| stmt = _build_persona_filters( | ||
| stmt, include_default, include_slack_bot_personas, include_deleted | ||
| ) | ||
| stmt = stmt.options( | ||
| selectinload(Persona.tools), | ||
| selectinload(Persona.labels), | ||
| selectinload(Persona.document_sets), | ||
| selectinload(Persona.user), | ||
| ) | ||
|
|
||
| if not include_default: | ||
| stmt = stmt.where(Persona.builtin_persona.is_(False)) | ||
| if not include_slack_bot_personas: | ||
| stmt = stmt.where(not_(Persona.name.startswith(SLACK_BOT_PERSONA_PREFIX))) | ||
| if not include_deleted: | ||
| stmt = stmt.where(Persona.deleted.is_(False)) | ||
| results = db_session.scalars(stmt).all() | ||
| return [PersonaSnapshot.from_model(persona) for persona in results] | ||
|
|
||
| if joinedload_all: | ||
| stmt = stmt.options( | ||
| selectinload(Persona.tools), | ||
| selectinload(Persona.document_sets), | ||
| selectinload(Persona.groups), | ||
| selectinload(Persona.users), | ||
| selectinload(Persona.labels), | ||
| selectinload(Persona.user_files), | ||
| selectinload(Persona.user_folders), | ||
| ) | ||
| if include_prompt: | ||
| stmt = stmt.options(selectinload(Persona.prompts)) | ||
|
|
||
| results = db_session.execute(stmt).scalars().all() | ||
| return results | ||
| def get_raw_personas_for_user( | ||
| user: User | None, | ||
| db_session: Session, | ||
| get_editable: bool = True, | ||
| include_default: bool = True, | ||
| include_slack_bot_personas: bool = False, | ||
| include_deleted: bool = False, | ||
| ) -> Sequence[Persona]: | ||
| stmt = select(Persona) | ||
| stmt = _add_user_filters(stmt, user, get_editable) | ||
| stmt = _build_persona_filters( | ||
| stmt, include_default, include_slack_bot_personas, include_deleted | ||
| ) | ||
| return db_session.scalars(stmt).all() | ||
|
|
||
|
|
||
| def get_personas(db_session: Session) -> Sequence[Persona]: | ||
|
|
@@ -826,7 +874,7 @@ def delete_persona_label(label_id: int, db_session: Session) -> None: | |
| def persona_has_search_tool(persona_id: int, db_session: Session) -> bool: | ||
| persona = ( | ||
| db_session.query(Persona) | ||
| .options(joinedload(Persona.tools)) | ||
| .options(selectinload(Persona.tools)) | ||
| .filter(Persona.id == persona_id) | ||
| .one_or_none() | ||
| ) | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -158,7 +158,7 @@ def list_document_sets_for_user( | |||||
| document_sets = fetch_all_document_sets_for_user( | ||||||
| db_session=db_session, user=user, get_editable=get_editable | ||||||
| ) | ||||||
| return [DocumentSetSummary.from_document_set(ds) for ds in document_sets] | ||||||
| return [DocumentSetSummary.from_model(ds) for ds in document_sets] | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Rule violated: Prefer long, explicit variable names
Suggested change
|
||||||
|
|
||||||
|
|
||||||
| @router.get("/document-set-public") | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -25,8 +25,9 @@ | |||||
| from onyx.db.persona import create_update_persona | ||||||
| from onyx.db.persona import delete_persona_label | ||||||
| from onyx.db.persona import get_assistant_labels | ||||||
| from onyx.db.persona import get_minimal_persona_snapshots_for_user | ||||||
| from onyx.db.persona import get_persona_by_id | ||||||
| from onyx.db.persona import get_personas_for_user | ||||||
| from onyx.db.persona import get_persona_snapshots_for_user | ||||||
| from onyx.db.persona import mark_persona_as_deleted | ||||||
| from onyx.db.persona import mark_persona_as_not_deleted | ||||||
| from onyx.db.persona import update_all_personas_display_priority | ||||||
|
|
@@ -45,6 +46,7 @@ | |||||
| from onyx.server.features.persona.models import FullPersonaSnapshot | ||||||
| from onyx.server.features.persona.models import GenerateStarterMessageRequest | ||||||
| from onyx.server.features.persona.models import ImageGenerationToolStatus | ||||||
| from onyx.server.features.persona.models import MinimalPersonaSnapshot | ||||||
| from onyx.server.features.persona.models import PersonaLabelCreate | ||||||
| from onyx.server.features.persona.models import PersonaLabelResponse | ||||||
| from onyx.server.features.persona.models import PersonaSharedNotificationData | ||||||
|
|
@@ -53,6 +55,9 @@ | |||||
| from onyx.server.features.persona.models import PromptSnapshot | ||||||
| from onyx.server.models import DisplayPriorityRequest | ||||||
| from onyx.server.settings.store import load_settings | ||||||
| from onyx.tools.tool_implementations.images.image_generation_tool import ( | ||||||
| ImageGenerationTool, | ||||||
| ) | ||||||
| from onyx.tools.utils import is_image_generation_available | ||||||
| from onyx.utils.logger import setup_logger | ||||||
| from onyx.utils.telemetry import create_milestone_and_report | ||||||
|
|
@@ -165,16 +170,12 @@ def list_personas_admin( | |||||
| include_deleted: bool = False, | ||||||
| get_editable: bool = Query(False, description="If true, return editable personas"), | ||||||
| ) -> list[PersonaSnapshot]: | ||||||
| return [ | ||||||
| PersonaSnapshot.from_model(persona) | ||||||
| for persona in get_personas_for_user( | ||||||
| db_session=db_session, | ||||||
| user=user, | ||||||
| get_editable=get_editable, | ||||||
| include_deleted=include_deleted, | ||||||
| joinedload_all=True, | ||||||
| ) | ||||||
| ] | ||||||
| return get_persona_snapshots_for_user( | ||||||
| user=user, | ||||||
| db_session=db_session, | ||||||
| get_editable=get_editable, | ||||||
| include_deleted=include_deleted, | ||||||
| ) | ||||||
|
|
||||||
|
|
||||||
| @admin_router.patch("/{persona_id}/undelete") | ||||||
|
|
@@ -414,14 +415,12 @@ def list_personas( | |||||
| db_session: Session = Depends(get_session), | ||||||
| include_deleted: bool = False, | ||||||
| persona_ids: list[int] = Query(None), | ||||||
| ) -> list[PersonaSnapshot]: | ||||||
| personas = get_personas_for_user( | ||||||
| ) -> list[MinimalPersonaSnapshot]: | ||||||
| personas = get_minimal_persona_snapshots_for_user( | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Rule violated: Everything Python should be as strictly typed as possible
Suggested change
|
||||||
| user=user, | ||||||
| include_deleted=include_deleted, | ||||||
| db_session=db_session, | ||||||
| get_editable=False, | ||||||
| joinedload_all=True, | ||||||
| include_prompt=False, | ||||||
| ) | ||||||
|
|
||||||
| if persona_ids: | ||||||
|
|
@@ -432,12 +431,14 @@ def list_personas( | |||||
| p | ||||||
| for p in personas | ||||||
| if not ( | ||||||
| any(tool.in_code_tool_id == "ImageGenerationTool" for tool in p.tools) | ||||||
| any( | ||||||
| tool.in_code_tool_id == ImageGenerationTool.__name__ for tool in p.tools | ||||||
| ) | ||||||
| and not is_image_generation_available(db_session=db_session) | ||||||
| ) | ||||||
| ] | ||||||
|
|
||||||
| return [PersonaSnapshot.from_model(p) for p in personas] | ||||||
| return personas | ||||||
|
|
||||||
|
|
||||||
| @basic_router.get("/{persona_id}") | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Rule violated: Everything Python should be as strictly typed as possible