Skip to content

Commit 3bb58a3

Browse files
authored
Persona simplification r2 (#5031)
* Revert "Revert "Reduce amount of stuff we fetch on `/persona` (#4988)" (#5024)" This reverts commit f7ed7cd. * Enhancements / fix re-render * re-arrange * greptile
1 parent 4b02fee commit 3bb58a3

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

50 files changed

+523
-341
lines changed

backend/onyx/db/persona.py

Lines changed: 77 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from collections.abc import Sequence
22
from datetime import datetime
3+
from enum import Enum
34
from uuid import UUID
45

56
from fastapi import HTTPException
@@ -10,7 +11,6 @@
1011
from sqlalchemy import select
1112
from sqlalchemy import update
1213
from sqlalchemy.orm import aliased
13-
from sqlalchemy.orm import joinedload
1414
from sqlalchemy.orm import selectinload
1515
from sqlalchemy.orm import Session
1616

@@ -38,17 +38,25 @@
3838
from onyx.db.models import UserGroup
3939
from onyx.db.notification import create_notification
4040
from onyx.server.features.persona.models import FullPersonaSnapshot
41+
from onyx.server.features.persona.models import MinimalPersonaSnapshot
4142
from onyx.server.features.persona.models import PersonaSharedNotificationData
43+
from onyx.server.features.persona.models import PersonaSnapshot
4244
from onyx.server.features.persona.models import PersonaUpsertRequest
4345
from onyx.utils.logger import setup_logger
4446
from onyx.utils.variable_functionality import fetch_versioned_implementation
4547

4648
logger = setup_logger()
4749

4850

51+
class PersonaLoadType(Enum):
52+
NONE = "none"
53+
MINIMAL = "minimal"
54+
FULL = "full"
55+
56+
4957
def _add_user_filters(
50-
stmt: Select, user: User | None, get_editable: bool = True
51-
) -> Select:
58+
stmt: Select[tuple[Persona]], user: User | None, get_editable: bool = True
59+
) -> Select[tuple[Persona]]:
5260
# If user is None and auth is disabled, assume the user is an admin
5361
if (user is None and DISABLE_AUTH) or (user and user.role == UserRole.ADMIN):
5462
return stmt
@@ -322,43 +330,83 @@ def update_persona_public_status(
322330
db_session.commit()
323331

324332

325-
def get_personas_for_user(
333+
def _build_persona_filters(
334+
stmt: Select[tuple[Persona]],
335+
include_default: bool,
336+
include_slack_bot_personas: bool,
337+
include_deleted: bool,
338+
) -> Select[tuple[Persona]]:
339+
if not include_default:
340+
stmt = stmt.where(Persona.builtin_persona.is_(False))
341+
if not include_slack_bot_personas:
342+
stmt = stmt.where(not_(Persona.name.startswith(SLACK_BOT_PERSONA_PREFIX)))
343+
if not include_deleted:
344+
stmt = stmt.where(Persona.deleted.is_(False))
345+
return stmt
346+
347+
348+
def get_minimal_persona_snapshots_for_user(
349+
user: User | None,
350+
db_session: Session,
351+
get_editable: bool = True,
352+
include_default: bool = True,
353+
include_slack_bot_personas: bool = False,
354+
include_deleted: bool = False,
355+
) -> list[MinimalPersonaSnapshot]:
356+
stmt = select(Persona)
357+
stmt = _add_user_filters(stmt, user, get_editable)
358+
stmt = _build_persona_filters(
359+
stmt, include_default, include_slack_bot_personas, include_deleted
360+
)
361+
stmt = stmt.options(
362+
selectinload(Persona.tools),
363+
selectinload(Persona.labels),
364+
selectinload(Persona.document_sets),
365+
selectinload(Persona.user),
366+
)
367+
results = db_session.scalars(stmt).all()
368+
return [MinimalPersonaSnapshot.from_model(persona) for persona in results]
369+
370+
371+
def get_persona_snapshots_for_user(
326372
# if user is `None` assume the user is an admin or auth is disabled
327373
user: User | None,
328374
db_session: Session,
329375
get_editable: bool = True,
330376
include_default: bool = True,
331377
include_slack_bot_personas: bool = False,
332378
include_deleted: bool = False,
333-
joinedload_all: bool = False,
334-
# a bit jank
335-
include_prompt: bool = True,
336-
) -> Sequence[Persona]:
379+
) -> list[PersonaSnapshot]:
337380
stmt = select(Persona)
338381
stmt = _add_user_filters(stmt, user, get_editable)
382+
stmt = _build_persona_filters(
383+
stmt, include_default, include_slack_bot_personas, include_deleted
384+
)
385+
stmt = stmt.options(
386+
selectinload(Persona.tools),
387+
selectinload(Persona.labels),
388+
selectinload(Persona.document_sets),
389+
selectinload(Persona.user),
390+
)
339391

340-
if not include_default:
341-
stmt = stmt.where(Persona.builtin_persona.is_(False))
342-
if not include_slack_bot_personas:
343-
stmt = stmt.where(not_(Persona.name.startswith(SLACK_BOT_PERSONA_PREFIX)))
344-
if not include_deleted:
345-
stmt = stmt.where(Persona.deleted.is_(False))
392+
results = db_session.scalars(stmt).all()
393+
return [PersonaSnapshot.from_model(persona) for persona in results]
346394

347-
if joinedload_all:
348-
stmt = stmt.options(
349-
selectinload(Persona.tools),
350-
selectinload(Persona.document_sets),
351-
selectinload(Persona.groups),
352-
selectinload(Persona.users),
353-
selectinload(Persona.labels),
354-
selectinload(Persona.user_files),
355-
selectinload(Persona.user_folders),
356-
)
357-
if include_prompt:
358-
stmt = stmt.options(selectinload(Persona.prompts))
359395

360-
results = db_session.execute(stmt).scalars().all()
361-
return results
396+
def get_raw_personas_for_user(
397+
user: User | None,
398+
db_session: Session,
399+
get_editable: bool = True,
400+
include_default: bool = True,
401+
include_slack_bot_personas: bool = False,
402+
include_deleted: bool = False,
403+
) -> Sequence[Persona]:
404+
stmt = select(Persona)
405+
stmt = _add_user_filters(stmt, user, get_editable)
406+
stmt = _build_persona_filters(
407+
stmt, include_default, include_slack_bot_personas, include_deleted
408+
)
409+
return db_session.scalars(stmt).all()
362410

363411

364412
def get_personas(db_session: Session) -> Sequence[Persona]:
@@ -826,7 +874,7 @@ def delete_persona_label(label_id: int, db_session: Session) -> None:
826874
def persona_has_search_tool(persona_id: int, db_session: Session) -> bool:
827875
persona = (
828876
db_session.query(Persona)
829-
.options(joinedload(Persona.tools))
877+
.options(selectinload(Persona.tools))
830878
.filter(Persona.id == persona_id)
831879
.one_or_none()
832880
)

backend/onyx/seeding/load_yamls.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
from onyx.db.prompts import get_prompt_by_name
1818
from onyx.db.prompts import upsert_prompt
1919
from onyx.db.user_documents import upsert_user_folder
20+
from onyx.tools.tool_implementations.images.image_generation_tool import (
21+
ImageGenerationTool,
22+
)
2023

2124

2225
def load_user_folders_from_yaml(
@@ -136,7 +139,7 @@ def load_personas_from_yaml(
136139
if persona.get("image_generation"):
137140
image_gen_tool = (
138141
db_session.query(ToolDBModel)
139-
.filter(ToolDBModel.name == "ImageGenerationTool")
142+
.filter(ToolDBModel.name == ImageGenerationTool.__name__)
140143
.first()
141144
)
142145
if image_gen_tool:

backend/onyx/server/features/document_set/api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def list_document_sets_for_user(
158158
document_sets = fetch_all_document_sets_for_user(
159159
db_session=db_session, user=user, get_editable=get_editable
160160
)
161-
return [DocumentSetSummary.from_document_set(ds) for ds in document_sets]
161+
return [DocumentSetSummary.from_model(ds) for ds in document_sets]
162162

163163

164164
@router.get("/document-set-public")

backend/onyx/server/features/document_set/models.py

Lines changed: 48 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -49,56 +49,6 @@ def from_federated_connector_mapping(
4949
)
5050

5151

52-
class DocumentSetSummary(BaseModel):
53-
"""Simplified document set model with minimal data for list views"""
54-
55-
id: int
56-
name: str
57-
description: str | None
58-
cc_pair_summaries: list[CCPairSummary]
59-
is_up_to_date: bool
60-
is_public: bool
61-
users: list[UUID]
62-
groups: list[int]
63-
federated_connector_summaries: list[FederatedConnectorSummary] = Field(
64-
default_factory=list
65-
)
66-
67-
@classmethod
68-
def from_document_set(
69-
cls, document_set: DocumentSetDBModel
70-
) -> "DocumentSetSummary":
71-
"""Create a summary from a DocumentSet database model"""
72-
return cls(
73-
id=document_set.id,
74-
name=document_set.name,
75-
description=document_set.description,
76-
cc_pair_summaries=[
77-
CCPairSummary(
78-
id=cc_pair.id,
79-
name=cc_pair.name,
80-
source=cc_pair.connector.source,
81-
access_type=cc_pair.access_type,
82-
)
83-
for cc_pair in document_set.connector_credential_pairs
84-
],
85-
is_up_to_date=document_set.is_up_to_date,
86-
is_public=document_set.is_public,
87-
users=[user.id for user in document_set.users],
88-
groups=[group.id for group in document_set.groups],
89-
federated_connector_summaries=[
90-
FederatedConnectorSummary(
91-
id=fc_mapping.federated_connector_id,
92-
name=f"{fc_mapping.federated_connector.source.replace('_', ' ').title()}",
93-
source=fc_mapping.federated_connector.source,
94-
entities=fc_mapping.entities,
95-
)
96-
for fc_mapping in document_set.federated_connectors
97-
if fc_mapping.federated_connector is not None
98-
],
99-
)
100-
101-
10252
class DocumentSetCreationRequest(BaseModel):
10353
name: str
10454
description: str
@@ -181,3 +131,51 @@ def from_model(cls, document_set_model: DocumentSetDBModel) -> "DocumentSet":
181131
for fc_mapping in document_set_model.federated_connectors
182132
],
183133
)
134+
135+
136+
class DocumentSetSummary(BaseModel):
137+
"""Simplified document set model with minimal data for list views"""
138+
139+
id: int
140+
name: str
141+
description: str | None
142+
cc_pair_summaries: list[CCPairSummary]
143+
is_up_to_date: bool
144+
is_public: bool
145+
users: list[UUID]
146+
groups: list[int]
147+
federated_connector_summaries: list[FederatedConnectorSummary] = Field(
148+
default_factory=list
149+
)
150+
151+
@classmethod
152+
def from_model(cls, document_set: DocumentSetDBModel) -> "DocumentSetSummary":
153+
"""Create a summary from a DocumentSet database model"""
154+
return cls(
155+
id=document_set.id,
156+
name=document_set.name,
157+
description=document_set.description,
158+
cc_pair_summaries=[
159+
CCPairSummary(
160+
id=cc_pair.id,
161+
name=cc_pair.name,
162+
source=cc_pair.connector.source,
163+
access_type=cc_pair.access_type,
164+
)
165+
for cc_pair in document_set.connector_credential_pairs
166+
],
167+
is_up_to_date=document_set.is_up_to_date,
168+
is_public=document_set.is_public,
169+
users=[user.id for user in document_set.users],
170+
groups=[group.id for group in document_set.groups],
171+
federated_connector_summaries=[
172+
FederatedConnectorSummary(
173+
id=fc_mapping.federated_connector_id,
174+
name=f"{fc_mapping.federated_connector.source.replace('_', ' ').title()}",
175+
source=fc_mapping.federated_connector.source,
176+
entities=fc_mapping.entities,
177+
)
178+
for fc_mapping in document_set.federated_connectors
179+
if fc_mapping.federated_connector is not None
180+
],
181+
)

backend/onyx/server/features/persona/api.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,9 @@
2525
from onyx.db.persona import create_update_persona
2626
from onyx.db.persona import delete_persona_label
2727
from onyx.db.persona import get_assistant_labels
28+
from onyx.db.persona import get_minimal_persona_snapshots_for_user
2829
from onyx.db.persona import get_persona_by_id
29-
from onyx.db.persona import get_personas_for_user
30+
from onyx.db.persona import get_persona_snapshots_for_user
3031
from onyx.db.persona import mark_persona_as_deleted
3132
from onyx.db.persona import mark_persona_as_not_deleted
3233
from onyx.db.persona import update_all_personas_display_priority
@@ -45,6 +46,7 @@
4546
from onyx.server.features.persona.models import FullPersonaSnapshot
4647
from onyx.server.features.persona.models import GenerateStarterMessageRequest
4748
from onyx.server.features.persona.models import ImageGenerationToolStatus
49+
from onyx.server.features.persona.models import MinimalPersonaSnapshot
4850
from onyx.server.features.persona.models import PersonaLabelCreate
4951
from onyx.server.features.persona.models import PersonaLabelResponse
5052
from onyx.server.features.persona.models import PersonaSharedNotificationData
@@ -53,6 +55,9 @@
5355
from onyx.server.features.persona.models import PromptSnapshot
5456
from onyx.server.models import DisplayPriorityRequest
5557
from onyx.server.settings.store import load_settings
58+
from onyx.tools.tool_implementations.images.image_generation_tool import (
59+
ImageGenerationTool,
60+
)
5661
from onyx.tools.utils import is_image_generation_available
5762
from onyx.utils.logger import setup_logger
5863
from onyx.utils.telemetry import create_milestone_and_report
@@ -165,16 +170,12 @@ def list_personas_admin(
165170
include_deleted: bool = False,
166171
get_editable: bool = Query(False, description="If true, return editable personas"),
167172
) -> list[PersonaSnapshot]:
168-
return [
169-
PersonaSnapshot.from_model(persona)
170-
for persona in get_personas_for_user(
171-
db_session=db_session,
172-
user=user,
173-
get_editable=get_editable,
174-
include_deleted=include_deleted,
175-
joinedload_all=True,
176-
)
177-
]
173+
return get_persona_snapshots_for_user(
174+
user=user,
175+
db_session=db_session,
176+
get_editable=get_editable,
177+
include_deleted=include_deleted,
178+
)
178179

179180

180181
@admin_router.patch("/{persona_id}/undelete")
@@ -414,14 +415,12 @@ def list_personas(
414415
db_session: Session = Depends(get_session),
415416
include_deleted: bool = False,
416417
persona_ids: list[int] = Query(None),
417-
) -> list[PersonaSnapshot]:
418-
personas = get_personas_for_user(
418+
) -> list[MinimalPersonaSnapshot]:
419+
personas = get_minimal_persona_snapshots_for_user(
419420
user=user,
420421
include_deleted=include_deleted,
421422
db_session=db_session,
422423
get_editable=False,
423-
joinedload_all=True,
424-
include_prompt=False,
425424
)
426425

427426
if persona_ids:
@@ -432,12 +431,14 @@ def list_personas(
432431
p
433432
for p in personas
434433
if not (
435-
any(tool.in_code_tool_id == "ImageGenerationTool" for tool in p.tools)
434+
any(
435+
tool.in_code_tool_id == ImageGenerationTool.__name__ for tool in p.tools
436+
)
436437
and not is_image_generation_available(db_session=db_session)
437438
)
438439
]
439440

440-
return [PersonaSnapshot.from_model(p) for p in personas]
441+
return personas
441442

442443

443444
@basic_router.get("/{persona_id}")

0 commit comments

Comments
 (0)