Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 23 additions & 34 deletions backend/ee/onyx/background/celery/apps/heavy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from ee.onyx.background.task_name_builders import query_history_task_name
from ee.onyx.server.query_history.api import fetch_and_process_chat_session_history
from ee.onyx.server.query_history.api import ONYX_ANONYMIZED_EMAIL
from ee.onyx.server.query_history.models import ChatSessionSnapshot
from ee.onyx.server.query_history.models import QuestionAnswerPairSnapshot
from onyx.background.celery.apps.heavy import celery_app
from onyx.background.task_utils import construct_query_history_report_name
Expand Down Expand Up @@ -45,6 +44,13 @@ def export_query_history_task(self: Task, *, start: datetime, end: datetime) ->
task_id = self.request.id
start_time = datetime.now(tz=timezone.utc)

stream = io.StringIO()
writer = csv.DictWriter(
stream,
fieldnames=list(QuestionAnswerPairSnapshot.model_fields.keys()),
)
writer.writeheader()

with get_session_with_current_tenant() as db_session:
try:
register_task(
Expand All @@ -55,15 +61,23 @@ def export_query_history_task(self: Task, *, start: datetime, end: datetime) ->
start_time=start_time,
)

complete_chat_session_history: list[ChatSessionSnapshot] = (
fetch_and_process_chat_session_history(
db_session=db_session,
start=start,
end=end,
feedback_type=None,
limit=None,
)
snapshot_generator = fetch_and_process_chat_session_history(
db_session=db_session,
start=start,
end=end,
)

for snapshot in snapshot_generator:
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
snapshot.user_email = ONYX_ANONYMIZED_EMAIL

writer.writerows(
qa_pair.to_json()
for qa_pair in QuestionAnswerPairSnapshot.from_chat_session_snapshot(
snapshot
)
)

except Exception:
logger.exception(f"Failed to export query history with {task_id=}")
mark_task_as_finished_with_id(
Expand All @@ -73,31 +87,6 @@ def export_query_history_task(self: Task, *, start: datetime, end: datetime) ->
)
raise

if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
complete_chat_session_history = [
ChatSessionSnapshot(
**chat_session_snapshot.model_dump(), user_email=ONYX_ANONYMIZED_EMAIL
)
for chat_session_snapshot in complete_chat_session_history
]

qa_pairs: list[QuestionAnswerPairSnapshot] = [
qa_pair
for chat_session_snapshot in complete_chat_session_history
for qa_pair in QuestionAnswerPairSnapshot.from_chat_session_snapshot(
chat_session_snapshot
)
]

stream = io.StringIO()
writer = csv.DictWriter(
stream,
fieldnames=list(QuestionAnswerPairSnapshot.model_fields.keys()),
)
writer.writeheader()
for row in qa_pairs:
writer.writerow(row.to_json())

report_name = construct_query_history_report_name(task_id)
with get_session_with_current_tenant() as db_session:
try:
Expand Down
68 changes: 40 additions & 28 deletions backend/ee/onyx/server/query_history/api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
from http import HTTPStatus
Expand All @@ -10,7 +11,6 @@
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session

from ee.onyx.db.query_history import fetch_chat_sessions_eagerly_by_time
from ee.onyx.db.query_history import get_all_query_history_export_tasks
from ee.onyx.db.query_history import get_page_of_chat_sessions
from ee.onyx.db.query_history import get_total_filtered_chat_sessions_count
Expand Down Expand Up @@ -45,6 +45,7 @@
from onyx.server.documents.models import PaginatedReturn
from onyx.server.query_and_chat.models import ChatSessionDetails
from onyx.server.query_and_chat.models import ChatSessionsResponse
from onyx.utils.threadpool_concurrency import parallel_yield

router = APIRouter()

Expand All @@ -61,41 +62,52 @@ def ensure_query_history_is_enabled(
)


def yield_snapshot_from_chat_session(
    chat_session: ChatSession,
    db_session: Session,
) -> Generator[ChatSessionSnapshot | None]:
    """Adapt a single snapshot lookup into a one-item generator.

    Wraps snapshot_from_chat_session so that each chat session becomes an
    independent generator, suitable for interleaving via parallel_yield.
    The single yielded value may be None when no snapshot could be built.
    """
    snapshot = snapshot_from_chat_session(
        chat_session=chat_session,
        db_session=db_session,
    )
    yield snapshot


def fetch_and_process_chat_session_history(
    db_session: Session,
    start: datetime,
    end: datetime,
) -> Generator[ChatSessionSnapshot]:
    """Stream chat-session snapshots for the time window [start, end].

    Pages through chat sessions PAGE_SIZE at a time and lazily yields one
    ChatSessionSnapshot per session, skipping sessions for which no snapshot
    could be produced. Streaming keeps memory bounded; the earlier
    materialize-everything approach was observed to be slow at a scale of
    8192 sessions with ~4 messages per session.

    Args:
        db_session: active SQLAlchemy session used for all queries.
        start: lower bound of the session time window.
        end: upper bound of the session time window.

    Yields:
        ChatSessionSnapshot for each session in the window that produced
        a non-None snapshot.
    """
    PAGE_SIZE = 100

    page = 0
    while True:
        paged_chat_sessions = get_page_of_chat_sessions(
            start_time=start,
            end_time=end,
            db_session=db_session,
            page_num=page,
            page_size=PAGE_SIZE,
        )

        # One single-item generator per session, interleaved by
        # parallel_yield — presumably snapshot construction runs
        # concurrently; confirm against onyx.utils.threadpool_concurrency.
        paged_snapshots = parallel_yield(
            [
                yield_snapshot_from_chat_session(
                    db_session=db_session,
                    chat_session=chat_session,
                )
                for chat_session in paged_chat_sessions
            ]
        )

        for snapshot in paged_snapshots:
            # snapshot_from_chat_session may return None; drop those
            # explicitly rather than relying on truthiness.
            if snapshot is not None:
                yield snapshot

        # A short page means we've reached the end of the pagination
        # sequence; stop rather than issuing one more empty query.
        if len(paged_chat_sessions) < PAGE_SIZE:
            break

        page += 1


def snapshot_from_chat_session(
Expand Down
Loading