
Commit 5bf4dab

perf: Change query-exporting to use generators instead of expanding fully into memory (onyx-dot-app#4729)
* Change query-exporting to use generators instead of expanding fully into memory
* Fix pagination logic
* Add type annotation
* Add early break if list of chat_sessions is empty

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
1 parent 84d267a commit 5bf4dab
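The heart of the change is a pagination loop that yields chat-session snapshots lazily instead of materializing every snapshot in one list. Below is a minimal, generic sketch of that pattern; the paginate_lazily name and the fetch_page callable are illustrative stand-ins, not code from this commit:

from collections.abc import Callable, Generator, Sequence
from typing import TypeVar

T = TypeVar("T")


def paginate_lazily(
    fetch_page: Callable[[int, int], Sequence[T]],
    page_size: int = 100,
) -> Generator[T, None, None]:
    # fetch_page(page_num, page_size) stands in for a paged DB query such as
    # get_page_of_chat_sessions; past the last page it returns an empty sequence.
    page = 0
    while True:
        rows = fetch_page(page, page_size)
        if not rows:
            # Early break: an empty page means there is nothing left to export.
            break
        yield from rows
        if len(rows) < page_size:
            # A short page is the last page; stop without another round trip.
            break
        page += 1


# Example usage over a fake 250-row "table":
# data = list(range(250))
# assert list(paginate_lazily(lambda p, n: data[p * n:(p + 1) * n])) == data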

File tree

  • backend/ee/onyx/background/celery/apps/heavy.py
  • backend/ee/onyx/server/query_history/api.py

2 files changed: +66 −62 lines

backend/ee/onyx/background/celery/apps/heavy.py

Lines changed: 23 additions & 34 deletions

@@ -9,7 +9,6 @@
 from ee.onyx.background.task_name_builders import query_history_task_name
 from ee.onyx.server.query_history.api import fetch_and_process_chat_session_history
 from ee.onyx.server.query_history.api import ONYX_ANONYMIZED_EMAIL
-from ee.onyx.server.query_history.models import ChatSessionSnapshot
 from ee.onyx.server.query_history.models import QuestionAnswerPairSnapshot
 from onyx.background.celery.apps.heavy import celery_app
 from onyx.background.task_utils import construct_query_history_report_name
@@ -45,6 +44,13 @@ def export_query_history_task(self: Task, *, start: datetime, end: datetime) ->
     task_id = self.request.id
     start_time = datetime.now(tz=timezone.utc)

+    stream = io.StringIO()
+    writer = csv.DictWriter(
+        stream,
+        fieldnames=list(QuestionAnswerPairSnapshot.model_fields.keys()),
+    )
+    writer.writeheader()
+
     with get_session_with_current_tenant() as db_session:
         try:
             register_task(
@@ -55,15 +61,23 @@ def export_query_history_task(self: Task, *, start: datetime, end: datetime) ->
                 start_time=start_time,
             )

-            complete_chat_session_history: list[ChatSessionSnapshot] = (
-                fetch_and_process_chat_session_history(
-                    db_session=db_session,
-                    start=start,
-                    end=end,
-                    feedback_type=None,
-                    limit=None,
-                )
+            snapshot_generator = fetch_and_process_chat_session_history(
+                db_session=db_session,
+                start=start,
+                end=end,
             )
+
+            for snapshot in snapshot_generator:
+                if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
+                    snapshot.user_email = ONYX_ANONYMIZED_EMAIL
+
+                writer.writerows(
+                    qa_pair.to_json()
+                    for qa_pair in QuestionAnswerPairSnapshot.from_chat_session_snapshot(
+                        snapshot
+                    )
+                )
+
         except Exception:
             logger.exception(f"Failed to export query history with {task_id=}")
             mark_task_as_finished_with_id(
@@ -73,31 +87,6 @@ def export_query_history_task(self: Task, *, start: datetime, end: datetime) ->
             )
             raise

-    if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
-        complete_chat_session_history = [
-            ChatSessionSnapshot(
-                **chat_session_snapshot.model_dump(), user_email=ONYX_ANONYMIZED_EMAIL
-            )
-            for chat_session_snapshot in complete_chat_session_history
-        ]
-
-    qa_pairs: list[QuestionAnswerPairSnapshot] = [
-        qa_pair
-        for chat_session_snapshot in complete_chat_session_history
-        for qa_pair in QuestionAnswerPairSnapshot.from_chat_session_snapshot(
-            chat_session_snapshot
-        )
-    ]
-
-    stream = io.StringIO()
-    writer = csv.DictWriter(
-        stream,
-        fieldnames=list(QuestionAnswerPairSnapshot.model_fields.keys()),
-    )
-    writer.writeheader()
-    for row in qa_pairs:
-        writer.writerow(row.to_json())
-
     report_name = construct_query_history_report_name(task_id)
     with get_session_with_current_tenant() as db_session:
         try:
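On the consumer side, the Celery task now creates the csv.DictWriter up front and feeds writer.writerows a generator expression, so QA-pair rows are written as snapshots arrive instead of being collected into an intermediate list (the StringIO buffer still holds the finished CSV text, as in the task itself). A small self-contained sketch of that idea, with illustrative names:

import csv
import io
from collections.abc import Iterable


def rows_to_csv(rows: Iterable[dict[str, str]], fieldnames: list[str]) -> str:
    # writerows() consumes the iterable one row at a time, so `rows` can be a
    # generator and no intermediate list of row dicts is ever built.
    stream = io.StringIO()
    writer = csv.DictWriter(stream, fieldnames=fieldnames)
    writer.writeheader()
    writer.writerows(rows)
    return stream.getvalue()


# Example with a generator of rows (field names here are illustrative):
# csv_text = rows_to_csv(
#     ({"question": f"q{i}", "answer": f"a{i}"} for i in range(3)),
#     fieldnames=["question", "answer"],
# )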

backend/ee/onyx/server/query_history/api.py

Lines changed: 43 additions & 28 deletions

@@ -1,3 +1,4 @@
+from collections.abc import Generator
 from datetime import datetime
 from datetime import timezone
 from http import HTTPStatus
@@ -10,7 +11,6 @@
 from fastapi.responses import StreamingResponse
 from sqlalchemy.orm import Session

-from ee.onyx.db.query_history import fetch_chat_sessions_eagerly_by_time
 from ee.onyx.db.query_history import get_all_query_history_export_tasks
 from ee.onyx.db.query_history import get_page_of_chat_sessions
 from ee.onyx.db.query_history import get_total_filtered_chat_sessions_count
@@ -45,6 +45,7 @@
 from onyx.server.documents.models import PaginatedReturn
 from onyx.server.query_and_chat.models import ChatSessionDetails
 from onyx.server.query_and_chat.models import ChatSessionsResponse
+from onyx.utils.threadpool_concurrency import parallel_yield

 router = APIRouter()

@@ -61,41 +62,55 @@ def ensure_query_history_is_enabled(
        )


+def yield_snapshot_from_chat_session(
+    chat_session: ChatSession,
+    db_session: Session,
+) -> Generator[ChatSessionSnapshot | None]:
+    yield snapshot_from_chat_session(chat_session=chat_session, db_session=db_session)
+
+
 def fetch_and_process_chat_session_history(
     db_session: Session,
     start: datetime,
     end: datetime,
-    feedback_type: QAFeedbackType | None,
     limit: int | None = 500,
-) -> list[ChatSessionSnapshot]:
-    # observed to be slow a scale of 8192 sessions and 4 messages per session
+) -> Generator[ChatSessionSnapshot]:
+    PAGE_SIZE = 100
+
+    page = 0
+    while True:
+        paged_chat_sessions = get_page_of_chat_sessions(
+            start_time=start,
+            end_time=end,
+            db_session=db_session,
+            page_num=page,
+            page_size=PAGE_SIZE,
+        )

-    # this is a little slow (5 seconds)
-    chat_sessions = fetch_chat_sessions_eagerly_by_time(
-        start=start, end=end, db_session=db_session, limit=limit
-    )
+        if not paged_chat_sessions:
+            break
+
+        paged_snapshots = parallel_yield(
+            [
+                yield_snapshot_from_chat_session(
+                    db_session=db_session,
+                    chat_session=chat_session,
+                )
+                for chat_session in paged_chat_sessions
+            ]
+        )

-    # this is VERY slow (80 seconds) due to create_chat_chain being called
-    # for each session. Needs optimizing.
-    chat_session_snapshots = [
-        snapshot_from_chat_session(chat_session=chat_session, db_session=db_session)
-        for chat_session in chat_sessions
-    ]
-
-    valid_snapshots = [
-        snapshot for snapshot in chat_session_snapshots if snapshot is not None
-    ]
-
-    if feedback_type:
-        valid_snapshots = [
-            snapshot
-            for snapshot in valid_snapshots
-            if any(
-                message.feedback_type == feedback_type for message in snapshot.messages
-            )
-        ]
+        for snapshot in paged_snapshots:
+            if snapshot:
+                yield snapshot
+
+        # If we've fetched *less* than a `PAGE_SIZE` worth
+        # of data, we have reached the end of the
+        # pagination sequence; break.
+        if len(paged_chat_sessions) < PAGE_SIZE:
+            break

-    return valid_snapshots
+        page += 1


 def snapshot_from_chat_session(
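Each chat session is wrapped in a one-item generator (yield_snapshot_from_chat_session) so that a whole page of sessions can be turned into snapshots through parallel_yield while the caller still consumes a flat stream. The sketch below is an illustrative stand-in for that helper built on concurrent.futures; the real parallel_yield in onyx.utils.threadpool_concurrency may differ in signature, scheduling, and ordering:

from collections.abc import Generator, Iterator
from concurrent.futures import ThreadPoolExecutor
from typing import TypeVar

T = TypeVar("T")


def parallel_yield_sketch(
    generators: list[Generator[T, None, None]],
    max_workers: int = 8,
) -> Iterator[T]:
    # Drain each generator on a worker thread, then flatten the results in
    # submission order. Only an approximation of onyx's parallel_yield.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        for items in pool.map(list, generators):
            yield from items


# Example: two one-item generators, mirroring yield_snapshot_from_chat_session.
def one_item(value: int) -> Generator[int, None, None]:
    yield value


if __name__ == "__main__":
    print(list(parallel_yield_sketch([one_item(1), one_item(2)])))  # [1, 2]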
