103 changes: 55 additions & 48 deletions backend/onyx/chat/process_message.py
@@ -3,7 +3,9 @@
import traceback
from collections.abc import Callable
from collections.abc import Iterator
from typing import Any
from typing import cast
from typing import Dict
from typing import Protocol
from uuid import UUID

@@ -26,12 +28,13 @@
from onyx.chat.models import QADocsResponse
from onyx.chat.models import StreamingError
from onyx.chat.models import UserKnowledgeFilePacket
from onyx.chat.packet_proccessing.process_streamed_packets import (
process_streamed_packets,
)
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
from onyx.chat.turn import fast_chat_turn
from onyx.chat.turn.infra.chat_turn_event_stream import convert_to_packet_obj
from onyx.chat.turn.models import DependenciesToMaybeRemove
from onyx.chat.turn.models import RunDependencies
from onyx.chat.user_files.parse_user_files import parse_user_files
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
@@ -44,7 +47,6 @@
from onyx.context.search.enums import OptionalSearchSetting
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import RetrievalDetails
from onyx.context.search.models import SavedSearchDoc
from onyx.context.search.retrieval.search_runner import (
inference_sections_from_ids,
)
@@ -83,11 +85,7 @@
from onyx.llm.utils import litellm_exception_to_error_msg
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.server.query_and_chat.streaming_models import CitationDelta
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.server.query_and_chat.streaming_models import MessageDelta
from onyx.server.query_and_chat.streaming_models import MessageStart
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.utils import get_json_line
from onyx.tools.force import ForceUseTool
from onyx.tools.models import SearchToolOverrideKwargs
@@ -778,10 +776,29 @@ def stream_chat_message_objects(
skip_gen_ai_answer_generation=new_msg_req.skip_gen_ai_answer_generation,
project_instructions=project_instructions,
)

# Process streamed packets using the new packet processing module
yield from process_streamed_packets(
answer_processed_output=answer.processed_streamed_output,
type_to_role = {
"human": "user",
"assistant": "assistant",
"system": "system",
"function": "function",
}
other_messages = [
{"role": type_to_role[message.type], "content": message.content}
for message in answer.graph_inputs.prompt_builder.build()
if message.type != "system"
]
yield from fast_chat_turn.fast_chat_turn(
messages=other_messages,
dependencies=RunDependencies(
llm=answer.graph_tooling.primary_llm,
search_tool=answer.graph_tooling.search_tool,
db_session=db_session,
dependencies_to_maybe_remove=DependenciesToMaybeRemove(
chat_session_id=chat_session_id,
message_id=reserved_message_id,
research_type=answer.graph_config.behavior.research_type,
),
),
)

except ValueError as e:
@@ -843,7 +860,15 @@ def stream_chat_message(
document_retrieval_latency = time.time() - start_time
logger.debug(f"First doc time: {document_retrieval_latency}")

yield get_json_line(obj.model_dump())
# Convert Pydantic models to dictionaries for JSON serialization
if hasattr(obj, "model_dump"):
obj_dict = obj.model_dump()
elif hasattr(obj, "dict"):
obj_dict = obj.dict()
else:
obj_dict = obj

yield get_json_line(obj_dict)


def remove_answer_citations(answer: str) -> str:
@@ -854,46 +879,28 @@ def remove_answer_citations(answer: str) -> str:

@log_function_time()
def gather_stream(
packets: AnswerStream,
packets: Iterator[Dict[str, Any]],
) -> ChatBasicResponse:
answer = ""
citations: list[CitationInfo] = []
error_msg: str | None = None
message_id: int | None = None
top_documents: list[SavedSearchDoc] = []

for packet in packets:
if isinstance(packet, Packet):
# Handle the different packet object types
if isinstance(packet.obj, MessageStart):
# MessageStart contains the initial content and final documents
if packet.obj.content:
answer += packet.obj.content
if packet.obj.final_documents:
top_documents = packet.obj.final_documents
elif isinstance(packet.obj, MessageDelta):
# MessageDelta contains incremental content updates
if packet.obj.content:
answer += packet.obj.content
elif isinstance(packet.obj, CitationDelta):
# CitationDelta contains citation information
if packet.obj.citations:
citations.extend(packet.obj.citations)
elif isinstance(packet, StreamingError):
error_msg = packet.error
elif isinstance(packet, MessageResponseIDInfo):
message_id = packet.reserved_assistant_message_id

if message_id is None:
raise ValueError("Message ID is required")
if packet != {"type": "event"}:
print(packet)

⚠️ Code Issue

Lines: 886-887

        if packet != {"type": "event"}:
            print(packet)

This print statement is debug code and should not be present in production. It should be removed or replaced with a proper logging call.
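A minimal sketch of the suggested fix, reusing the module-level logger that process_message.py already defines:

    # Log stray packets at debug level instead of printing to stdout
    if packet != {"type": "event"}:
        logger.debug(f"Unexpected packet: {packet}")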

# Convert packet to PacketObj when possible
packet_obj = convert_to_packet_obj(packet)
if packet_obj:
# Handle PacketObj types that contain text content
if hasattr(packet_obj, "content") and packet_obj.content:
answer += packet_obj.content
elif "text" in packet:
# Fallback for legacy packet format
answer += packet["text"]

return ChatBasicResponse(
answer=answer,
answer_citationless=remove_answer_citations(answer),
cited_documents={
citation.citation_num: citation.document_id for citation in citations
},
message_id=message_id,
error_msg=error_msg,
top_documents=top_documents,
cited_documents={},
message_id=0,
error_msg=None,
top_documents=[],
)

⚠️ Code Issue

Lines: 902-905

        cited_documents={},
        message_id=0,
        error_msg=None,
        top_documents=[],

These hardcoded default values indicate a significant loss of functionality. The gather_stream function no longer extracts citations, the message ID, the error message, or the top documents from the stream, all of which were previously handled. This is a critical functional regression.
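For reference, a rough sketch of how those fields might be re-derived from the new dict-based stream; the packet shapes below are illustrative assumptions, not a confirmed schema:

    # Hypothetical packet keys -- verify against convert_to_packet_obj and the
    # emitter in fast_chat_turn before relying on them.
    error_msg: str | None = None
    message_id: int | None = None
    for packet in packets:
        if packet.get("type") == "error":  # assumed error packet shape
            error_msg = packet.get("error")
        elif packet.get("type") == "message_id":  # assumed ID packet shape
            message_id = packet.get("reserved_assistant_message_id")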
1 change: 1 addition & 0 deletions backend/onyx/chat/turn/__init__.py
@@ -0,0 +1 @@
# Turn module for chat functionality

⚠️ Code Issue

Lines: 1-1

# Turn module for chat functionality

This comment could be replaced by a package docstring for better documentation practices and accessibility. While not strictly problematic, it's a missed opportunity for more robust documentation.
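A minimal sketch of that suggestion:

    """Turn module for chat functionality."""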

219 changes: 219 additions & 0 deletions backend/onyx/chat/turn/fast_chat_turn.py
@@ -0,0 +1,219 @@
import re
from typing import cast
from uuid import UUID

from agents import Agent
from agents import ModelSettings
from agents import RunItemStreamEvent
from agents.extensions.models.litellm_model import LitellmModel
from sqlalchemy.orm import Session

from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import AggregatedDRContext
from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
GeneratedImageFullResult,
)
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
from onyx.chat.turn.infra.chat_turn_event_stream import OnyxRunner
from onyx.chat.turn.infra.chat_turn_orchestration import unified_event_stream
from onyx.chat.turn.infra.packet_translation import default_packet_translation
from onyx.chat.turn.models import MyContext
from onyx.chat.turn.models import RunDependencies
from onyx.context.search.models import InferenceSection
from onyx.db.chat import create_search_doc_from_inference_section
from onyx.db.chat import update_db_session_with_messages
from onyx.db.models import ChatMessage__SearchDoc
from onyx.db.models import ResearchAgentIteration
from onyx.db.models import ResearchAgentIterationSubStep
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.tools.tool_implementations_v2.internal_search import internal_search
from onyx.tools.tool_implementations_v2.web import web_fetch
from onyx.tools.tool_implementations_v2.web import web_search


# TODO: Dependency injection?
@unified_event_stream
def fast_chat_turn(messages: list[dict], dependencies: RunDependencies) -> None:
ctx = MyContext(
run_dependencies=dependencies,
aggregated_context=AggregatedDRContext(
context="context",
cited_documents=[],
is_internet_marker_dict={},
global_iteration_responses=[], # TODO: the only field that matters for now
),
iteration_instructions=[],
)
agent = Agent(
name="Assistant",
model=LitellmModel(
model=dependencies.llm.config.model_name,
api_key=dependencies.llm.config.api_key,
),
tools=[web_search, web_fetch, internal_search],
model_settings=ModelSettings(
temperature=0.0,
include_usage=True,
),
)

bridge = OnyxRunner().run_streamed(agent, messages, context=ctx, max_turns=100)
final_answer = "filler final answer"
for ev in bridge.events():
ctx.current_run_step
obj = default_packet_translation(ev)
print(ev)
# TODO this obviously won't work for cancellation
if isinstance(ev, RunItemStreamEvent):
ev = cast(RunItemStreamEvent, ev)
if ev.name == "message_output_created":
final_answer = ev.item.raw_item.content[0].text
if obj:
dependencies.emitter.emit(Packet(ind=ctx.current_run_step, obj=obj))
save_iteration(
db_session=dependencies.db_session,
message_id=dependencies.dependencies_to_maybe_remove.message_id,
chat_session_id=dependencies.dependencies_to_maybe_remove.chat_session_id,
research_type=dependencies.dependencies_to_maybe_remove.research_type,
ctx=ctx,
final_answer=final_answer,
all_cited_documents=[],
)
# TODO: Error handling
# Should there be a timeout and some error on the queue?
dependencies.emitter.emit(
Packet(ind=ctx.current_run_step, obj=OverallStop(type="stop"))
)


# TODO: Figure out a way to persist information is robust to cancellation,
# modular so easily testable in unit tests and evals [likely injecting some higher
# level session manager and span sink], potentially has some robustness off the critical path,
# and promotes clean separation of concerns.
def save_iteration(
db_session: Session,
message_id: int,
chat_session_id: UUID,
research_type: ResearchType,
ctx: MyContext,
final_answer: str,
all_cited_documents: list[InferenceSection],
) -> None:
# first, insert the search_docs
is_internet_marker_dict = {}
search_docs = [
create_search_doc_from_inference_section(
inference_section=inference_section,
is_internet=is_internet_marker_dict.get(
inference_section.center_chunk.document_id, False
), # TODO: revisit
db_session=db_session,
commit=False,
)
for inference_section in all_cited_documents
]

# then, map_search_docs to message
_insert_chat_message_search_doc_pair(
message_id, [search_doc.id for search_doc in search_docs], db_session
)

# lastly, insert the citations
citation_dict: dict[int, int] = {}
cited_doc_nrs = _extract_citation_numbers(final_answer)
if search_docs:
for cited_doc_nr in cited_doc_nrs:
citation_dict[cited_doc_nr] = search_docs[cited_doc_nr - 1].id

# Update the chat message and its parent message in database
update_db_session_with_messages(
db_session=db_session,
chat_message_id=message_id,
chat_session_id=chat_session_id,
is_agentic=research_type == ResearchType.DEEP,
message=final_answer,
citations=citation_dict,
research_type=research_type,
research_plan={},
final_documents=search_docs,
update_parent_message=True,
research_answer_purpose=ResearchAnswerPurpose.ANSWER,
token_count=0,
)

for iteration_preparation in ctx.iteration_instructions:
research_agent_iteration_step = ResearchAgentIteration(
primary_question_id=message_id,
reasoning=iteration_preparation.reasoning,
purpose=iteration_preparation.purpose,
iteration_nr=iteration_preparation.iteration_nr,
)
db_session.add(research_agent_iteration_step)

for iteration_answer in ctx.aggregated_context.global_iteration_responses:

retrieved_search_docs = convert_inference_sections_to_search_docs(
list(iteration_answer.cited_documents.values())
)

# Convert SavedSearchDoc objects to JSON-serializable format
serialized_search_docs = [doc.model_dump() for doc in retrieved_search_docs]

research_agent_iteration_sub_step = ResearchAgentIterationSubStep(
primary_question_id=message_id,
iteration_nr=iteration_answer.iteration_nr,
iteration_sub_step_nr=iteration_answer.parallelization_nr,
sub_step_instructions=iteration_answer.question,
sub_step_tool_id=iteration_answer.tool_id,
sub_answer=iteration_answer.answer,
reasoning=iteration_answer.reasoning,
claims=iteration_answer.claims,
cited_doc_results=serialized_search_docs,
generated_images=(
GeneratedImageFullResult(images=iteration_answer.generated_images)
if iteration_answer.generated_images
else None
),
additional_data=iteration_answer.additional_data,
)
db_session.add(research_agent_iteration_sub_step)

db_session.commit()


def _insert_chat_message_search_doc_pair(
message_id: int, search_doc_ids: list[int], db_session: Session
) -> None:
"""
Insert (message_id, search_doc_id) pairs into the chat_message__search_doc table.

Args:
message_id: The ID of the chat message
search_doc_ids: The IDs of the search documents
db_session: The database session
"""
for search_doc_id in search_doc_ids:
chat_message_search_doc = ChatMessage__SearchDoc(
chat_message_id=message_id, search_doc_id=search_doc_id
)
db_session.add(chat_message_search_doc)


def _extract_citation_numbers(text: str) -> list[int]:
"""
Extract all citation numbers from text in the format [[<number>]] or [[<number_1>, <number_2>, ...]].
Returns a list of all unique citation numbers found.
"""
# Pattern to match [[number]] or [[number1, number2, ...]]
pattern = r"\[\[(\d+(?:,\s*\d+)*)\]\]"
matches = re.findall(pattern, text)

cited_numbers = []
for match in matches:
# Split by comma and extract all numbers
numbers = [int(num.strip()) for num in match.split(",")]
cited_numbers.extend(numbers)

return list(set(cited_numbers)) # Return unique numbers
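A quick illustration of the expected behavior (sorted here, since the set() conversion leaves the returned order unspecified):

    >>> sorted(_extract_citation_numbers("See [[1]] and [[2, 3]]."))
    [1, 2, 3]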
1 change: 1 addition & 0 deletions backend/onyx/chat/turn/infra/__init__.py
@@ -0,0 +1 @@
# Infrastructure module for chat turn orchestration

⚠️ Code Issue

Lines: 1-1

# Infrastructure module for chat turn orchestration

Using a # comment for module-level documentation is less standard than using a module docstring. Docstrings are accessible via help() and are used by documentation generation tools, improving maintainability and discoverability.
