Simple Agent V2 #5501
base: main
Changes from all commits
@@ -3,7 +3,9 @@
 import traceback
 from collections.abc import Callable
+from collections.abc import Iterator
 from typing import Any
 from typing import cast
+from typing import Dict
 from typing import Protocol
 from uuid import UUID
@@ -26,12 +28,13 @@
 from onyx.chat.models import QADocsResponse
 from onyx.chat.models import StreamingError
 from onyx.chat.models import UserKnowledgeFilePacket
-from onyx.chat.packet_proccessing.process_streamed_packets import (
-    process_streamed_packets,
-)
 from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
 from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
 from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
+from onyx.chat.turn import fast_chat_turn
+from onyx.chat.turn.infra.chat_turn_event_stream import convert_to_packet_obj
+from onyx.chat.turn.models import DependenciesToMaybeRemove
+from onyx.chat.turn.models import RunDependencies
 from onyx.chat.user_files.parse_user_files import parse_user_files
 from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
 from onyx.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
@@ -44,7 +47,6 @@
 from onyx.context.search.enums import OptionalSearchSetting
 from onyx.context.search.models import InferenceSection
 from onyx.context.search.models import RetrievalDetails
-from onyx.context.search.models import SavedSearchDoc
 from onyx.context.search.retrieval.search_runner import (
     inference_sections_from_ids,
 )
@@ -83,11 +85,7 @@
 from onyx.llm.utils import litellm_exception_to_error_msg
 from onyx.natural_language_processing.utils import get_tokenizer
 from onyx.server.query_and_chat.models import CreateChatMessageRequest
-from onyx.server.query_and_chat.streaming_models import CitationDelta
-from onyx.server.query_and_chat.streaming_models import CitationInfo
-from onyx.server.query_and_chat.streaming_models import MessageDelta
-from onyx.server.query_and_chat.streaming_models import MessageStart
 from onyx.server.query_and_chat.streaming_models import Packet
 from onyx.server.utils import get_json_line
 from onyx.tools.force import ForceUseTool
 from onyx.tools.models import SearchToolOverrideKwargs
@@ -778,10 +776,29 @@ def stream_chat_message_objects(
             skip_gen_ai_answer_generation=new_msg_req.skip_gen_ai_answer_generation,
             project_instructions=project_instructions,
         )

-        # Process streamed packets using the new packet processing module
-        yield from process_streamed_packets(
-            answer_processed_output=answer.processed_streamed_output,
+        type_to_role = {
+            "human": "user",
+            "assistant": "assistant",
+            "system": "system",
+            "function": "function",
+        }
+        other_messages = [
+            {"role": type_to_role[message.type], "content": message.content}
+            for message in answer.graph_inputs.prompt_builder.build()
+            if message.type != "system"
+        ]
+        yield from fast_chat_turn.fast_chat_turn(
+            messages=other_messages,
+            dependencies=RunDependencies(
+                llm=answer.graph_tooling.primary_llm,
+                search_tool=answer.graph_tooling.search_tool,
+                db_session=db_session,
+                dependencies_to_maybe_remove=DependenciesToMaybeRemove(
+                    chat_session_id=chat_session_id,
+                    message_id=reserved_message_id,
+                    research_type=answer.graph_config.behavior.research_type,
+                ),
+            ),
+        )

     except ValueError as e:
@@ -843,7 +860,15 @@ def stream_chat_message(
             document_retrieval_latency = time.time() - start_time
             logger.debug(f"First doc time: {document_retrieval_latency}")

-        yield get_json_line(obj.model_dump())
+        # Convert Pydantic models to dictionaries for JSON serialization
+        if hasattr(obj, "model_dump"):
+            obj_dict = obj.model_dump()
+        elif hasattr(obj, "dict"):
+            obj_dict = obj.dict()
+        else:
+            obj_dict = obj
+
+        yield get_json_line(obj_dict)


 def remove_answer_citations(answer: str) -> str:
@@ -854,46 +879,28 @@ def remove_answer_citations(answer: str) -> str:

 @log_function_time()
 def gather_stream(
-    packets: AnswerStream,
+    packets: Iterator[Dict[str, Any]],
 ) -> ChatBasicResponse:
     answer = ""
-    citations: list[CitationInfo] = []
-    error_msg: str | None = None
-    message_id: int | None = None
-    top_documents: list[SavedSearchDoc] = []

     for packet in packets:
-        if isinstance(packet, Packet):
-            # Handle the different packet object types
-            if isinstance(packet.obj, MessageStart):
-                # MessageStart contains the initial content and final documents
-                if packet.obj.content:
-                    answer += packet.obj.content
-                if packet.obj.final_documents:
-                    top_documents = packet.obj.final_documents
-            elif isinstance(packet.obj, MessageDelta):
-                # MessageDelta contains incremental content updates
-                if packet.obj.content:
-                    answer += packet.obj.content
-            elif isinstance(packet.obj, CitationDelta):
-                # CitationDelta contains citation information
-                if packet.obj.citations:
-                    citations.extend(packet.obj.citations)
-        elif isinstance(packet, StreamingError):
-            error_msg = packet.error
-        elif isinstance(packet, MessageResponseIDInfo):
-            message_id = packet.reserved_assistant_message_id
-
-    if message_id is None:
-        raise ValueError("Message ID is required")
+        if packet != {"type": "event"}:
+            print(packet)
+
+        # Convert packet to PacketObj when possible
+        packet_obj = convert_to_packet_obj(packet)
+        if packet_obj:
+            # Handle PacketObj types that contain text content
+            if hasattr(packet_obj, "content") and packet_obj.content:
+                answer += packet_obj.content
+        elif "text" in packet:
+            # Fallback for legacy packet format
+            answer += packet["text"]

     return ChatBasicResponse(
         answer=answer,
         answer_citationless=remove_answer_citations(answer),
-        cited_documents={
-            citation.citation_num: citation.document_id for citation in citations
-        },
-        message_id=message_id,
-        error_msg=error_msg,
-        top_documents=top_documents,
+        cited_documents={},
+        message_id=0,
+        error_msg=None,
+        top_documents=[],
     )
@@ -0,0 +1 @@
+# Turn module for chat functionality
@@ -0,0 +1,219 @@
+import re
+from typing import cast
+from uuid import UUID
+
+from agents import Agent
+from agents import ModelSettings
+from agents import RunItemStreamEvent
+from agents.extensions.models.litellm_model import LitellmModel
+from sqlalchemy.orm import Session
+
+from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
+from onyx.agents.agent_search.dr.enums import ResearchType
+from onyx.agents.agent_search.dr.models import AggregatedDRContext
+from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
+    GeneratedImageFullResult,
+)
+from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
+from onyx.chat.turn.infra.chat_turn_event_stream import OnyxRunner
+from onyx.chat.turn.infra.chat_turn_orchestration import unified_event_stream
+from onyx.chat.turn.infra.packet_translation import default_packet_translation
+from onyx.chat.turn.models import MyContext
+from onyx.chat.turn.models import RunDependencies
+from onyx.context.search.models import InferenceSection
+from onyx.db.chat import create_search_doc_from_inference_section
+from onyx.db.chat import update_db_session_with_messages
+from onyx.db.models import ChatMessage__SearchDoc
+from onyx.db.models import ResearchAgentIteration
+from onyx.db.models import ResearchAgentIterationSubStep
+from onyx.server.query_and_chat.streaming_models import OverallStop
+from onyx.server.query_and_chat.streaming_models import Packet
+from onyx.tools.tool_implementations_v2.internal_search import internal_search
+from onyx.tools.tool_implementations_v2.web import web_fetch
+from onyx.tools.tool_implementations_v2.web import web_search
+
+
+# TODO: Dependency injection?
+@unified_event_stream
+def fast_chat_turn(messages: list[dict], dependencies: RunDependencies) -> None:
+    ctx = MyContext(
+        run_dependencies=dependencies,
+        aggregated_context=AggregatedDRContext(
+            context="context",
+            cited_documents=[],
+            is_internet_marker_dict={},
+            global_iteration_responses=[],  # TODO: the only field that matters for now
+        ),
+        iteration_instructions=[],
+    )
+    agent = Agent(
+        name="Assistant",
+        model=LitellmModel(
+            model=dependencies.llm.config.model_name,
+            api_key=dependencies.llm.config.api_key,
+        ),
+        tools=[web_search, web_fetch, internal_search],
+        model_settings=ModelSettings(
+            temperature=0.0,
+            include_usage=True,
+        ),
+    )
+
+    bridge = OnyxRunner().run_streamed(agent, messages, context=ctx, max_turns=100)
+    final_answer = "filler final answer"
+    for ev in bridge.events():
+        ctx.current_run_step
+        obj = default_packet_translation(ev)
+        print(ev)
+        # TODO this obviously won't work for cancellation
+        if isinstance(ev, RunItemStreamEvent):
+            ev = cast(RunItemStreamEvent, ev)
+            if ev.name == "message_output_created":
+                final_answer = ev.item.raw_item.content[0].text
+        if obj:
+            dependencies.emitter.emit(Packet(ind=ctx.current_run_step, obj=obj))
+    save_iteration(
+        db_session=dependencies.db_session,
+        message_id=dependencies.dependencies_to_maybe_remove.message_id,
+        chat_session_id=dependencies.dependencies_to_maybe_remove.chat_session_id,
+        research_type=dependencies.dependencies_to_maybe_remove.research_type,
+        ctx=ctx,
+        final_answer=final_answer,
+        all_cited_documents=[],
+    )
+    # TODO: Error handling
+    # Should there be a timeout and some error on the queue?
+    dependencies.emitter.emit(
+        Packet(ind=ctx.current_run_step, obj=OverallStop(type="stop"))
+    )
+
+
+# TODO: Figure out a way to persist information that is robust to cancellation,
+# modular so it is easily testable in unit tests and evals (likely by injecting some
+# higher-level session manager and span sink), potentially has some robustness off
+# the critical path, and promotes clean separation of concerns.
+def save_iteration(
+    db_session: Session,
+    message_id: int,
+    chat_session_id: UUID,
+    research_type: ResearchType,
+    ctx: MyContext,
+    final_answer: str,
+    all_cited_documents: list[InferenceSection],
+) -> None:
+    # first, insert the search_docs
+    is_internet_marker_dict = {}
+    search_docs = [
+        create_search_doc_from_inference_section(
+            inference_section=inference_section,
+            is_internet=is_internet_marker_dict.get(
+                inference_section.center_chunk.document_id, False
+            ),  # TODO: revisit
+            db_session=db_session,
+            commit=False,
+        )
+        for inference_section in all_cited_documents
+    ]
+
+    # then, map search_docs to the message
+    _insert_chat_message_search_doc_pair(
+        message_id, [search_doc.id for search_doc in search_docs], db_session
+    )
+
+    # lastly, insert the citations
+    citation_dict: dict[int, int] = {}
+    cited_doc_nrs = _extract_citation_numbers(final_answer)
+    if search_docs:
+        for cited_doc_nr in cited_doc_nrs:
+            citation_dict[cited_doc_nr] = search_docs[cited_doc_nr - 1].id
+
+    # Update the chat message and its parent message in the database
+    update_db_session_with_messages(
+        db_session=db_session,
+        chat_message_id=message_id,
+        chat_session_id=chat_session_id,
+        is_agentic=research_type == ResearchType.DEEP,
+        message=final_answer,
+        citations=citation_dict,
+        research_type=research_type,
+        research_plan={},
+        final_documents=search_docs,
+        update_parent_message=True,
+        research_answer_purpose=ResearchAnswerPurpose.ANSWER,
+        token_count=0,
+    )
+
+    for iteration_preparation in ctx.iteration_instructions:
+        research_agent_iteration_step = ResearchAgentIteration(
+            primary_question_id=message_id,
+            reasoning=iteration_preparation.reasoning,
+            purpose=iteration_preparation.purpose,
+            iteration_nr=iteration_preparation.iteration_nr,
+        )
+        db_session.add(research_agent_iteration_step)
+
+    for iteration_answer in ctx.aggregated_context.global_iteration_responses:
+
+        retrieved_search_docs = convert_inference_sections_to_search_docs(
+            list(iteration_answer.cited_documents.values())
+        )
+
+        # Convert SavedSearchDoc objects to JSON-serializable format
+        serialized_search_docs = [doc.model_dump() for doc in retrieved_search_docs]
+
+        research_agent_iteration_sub_step = ResearchAgentIterationSubStep(
+            primary_question_id=message_id,
+            iteration_nr=iteration_answer.iteration_nr,
+            iteration_sub_step_nr=iteration_answer.parallelization_nr,
+            sub_step_instructions=iteration_answer.question,
+            sub_step_tool_id=iteration_answer.tool_id,
+            sub_answer=iteration_answer.answer,
+            reasoning=iteration_answer.reasoning,
+            claims=iteration_answer.claims,
+            cited_doc_results=serialized_search_docs,
+            generated_images=(
+                GeneratedImageFullResult(images=iteration_answer.generated_images)
+                if iteration_answer.generated_images
+                else None
+            ),
+            additional_data=iteration_answer.additional_data,
+        )
+        db_session.add(research_agent_iteration_sub_step)
+
+    db_session.commit()
+
+
+def _insert_chat_message_search_doc_pair(
+    message_id: int, search_doc_ids: list[int], db_session: Session
+) -> None:
+    """
+    Insert (message_id, search_doc_id) pairs into the chat_message__search_doc table.
+
+    Args:
+        message_id: The ID of the chat message
+        search_doc_ids: The IDs of the search documents
+        db_session: The database session
+    """
+    for search_doc_id in search_doc_ids:
+        chat_message_search_doc = ChatMessage__SearchDoc(
+            chat_message_id=message_id, search_doc_id=search_doc_id
+        )
+        db_session.add(chat_message_search_doc)
+
+
+def _extract_citation_numbers(text: str) -> list[int]:
+    """
+    Extract all citation numbers from text in the format [[<number>]] or [[<number_1>, <number_2>, ...]].
+    Returns a list of all unique citation numbers found.
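+    Illustrative example: "see [[1]] and [[2, 3]]" yields [1, 2, 3] (order not guaranteed).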
""" | ||
# Pattern to match [[number]] or [[number1, number2, ...]] | ||
pattern = r"\[\[(\d+(?:,\s*\d+)*)\]\]" | ||
matches = re.findall(pattern, text) | ||
|
||
cited_numbers = [] | ||
for match in matches: | ||
# Split by comma and extract all numbers | ||
numbers = [int(num.strip()) for num in match.split(",")] | ||
cited_numbers.extend(numbers) | ||
|
||
return list(set(cited_numbers)) # Return unique numbers |
@@ -0,0 +1 @@
+# Infrastructure module for chat turn orchestration
Review comment on lines 886-887:

This `print` statement is debug code and should not be present in production. It should be removed or replaced with a proper logging call.
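A minimal sketch of that suggestion, assuming `logger` is the module-level logger this file already uses (as in the "First doc time" debug line); the log level and message are illustrative choices:

# Hypothetical replacement for the debug print in gather_stream;
# assumes the module-level logger already set up in this file.
if packet != {"type": "event"}:
    logger.debug(f"Received streamed packet: {packet}")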