Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions backend/alembic/versions/1c3f8a7b5d4e_add_analysis_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""add_analysis_tool

Revision ID: 1c3f8a7b5d4e
Revises: 505c488f6662
Create Date: 2025-02-14 00:00:00

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "1c3f8a7b5d4e"
down_revision = "505c488f6662"
branch_labels = None
depends_on = None


# Seed row for the Analysis Tool ("Code Interpreter") entry in the `tool` table.
# The whole mapping is passed as bound parameters to the UPDATE/INSERT
# statements in upgrade(), so its keys must match the :name, :display_name,
# :description and :in_code_tool_id placeholders used there.
ANALYSIS_TOOL = {
    "name": "AnalysisTool",
    "display_name": "Code Interpreter",
    "description": (
        "The Code Interpreter Action lets assistants execute Python in an isolated runtime. "
        "It can process staged files, read and write artifacts, stream stdout and stderr, "
        "and return generated outputs for the chat session."
    ),
    "in_code_tool_id": "AnalysisTool",
}


def upgrade() -> None:
    """Insert or update the Analysis Tool row in the ``tool`` table.

    The operation is idempotent: if a row with the same ``in_code_tool_id``
    already exists its metadata is refreshed, otherwise a new row is inserted.

    Alembic already runs each migration inside a transaction, so no explicit
    BEGIN/COMMIT/ROLLBACK is issued here — emitting our own ``BEGIN`` on the
    migration connection can fail (a transaction is already in progress) and a
    manual ``COMMIT`` fights Alembic's own transaction management.  On error
    the exception simply propagates and Alembic rolls the migration back.
    """
    conn = op.get_bind()

    # The lookup only needs the identifier, so bind just that parameter.
    existing = conn.execute(
        sa.text("SELECT id FROM tool WHERE in_code_tool_id = :in_code_tool_id"),
        {"in_code_tool_id": ANALYSIS_TOOL["in_code_tool_id"]},
    ).fetchone()

    if existing:
        # Keep an existing row in sync with the current tool metadata.
        conn.execute(
            sa.text(
                """
                UPDATE tool
                SET name = :name,
                    display_name = :display_name,
                    description = :description
                WHERE in_code_tool_id = :in_code_tool_id
                """
            ),
            ANALYSIS_TOOL,
        )
    else:
        conn.execute(
            sa.text(
                """
                INSERT INTO tool (name, display_name, description, in_code_tool_id)
                VALUES (:name, :display_name, :description, :in_code_tool_id)
                """
            ),
            ANALYSIS_TOOL,
        )


def downgrade() -> None:
    """No-op downgrade.

    The tool row is intentionally left in place: keeping it is safe and
    makes re-running the migrations idempotent.
    """
    return None
1 change: 1 addition & 0 deletions backend/onyx/agents/agent_search/dr/conditional_edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def decision_router(state: MainState) -> list[Send | Hashable] | DRPath | str:
DRPath.WEB_SEARCH,
DRPath.KNOWLEDGE_GRAPH,
DRPath.IMAGE_GENERATION,
DRPath.ANALYSIS_TOOL,
)
and len(state.query_list) == 0
):
Expand Down
1 change: 1 addition & 0 deletions backend/onyx/agents/agent_search/dr/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
DRPath.WEB_SEARCH: 1.5,
DRPath.IMAGE_GENERATION: 3.0,
DRPath.GENERIC_TOOL: 1.5, # TODO: see todo in OrchestratorTool
DRPath.ANALYSIS_TOOL: 2.0,
DRPath.CLOSER: 0.0,
}

Expand Down
1 change: 1 addition & 0 deletions backend/onyx/agents/agent_search/dr/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class DRPath(str, Enum):
WEB_SEARCH = "Web Search"
IMAGE_GENERATION = "Image Generation"
GENERIC_INTERNAL_TOOL = "Generic Internal Tool"
ANALYSIS_TOOL = "Analysis Tool"
CLOSER = "Closer"
LOGGER = "Logger"
END = "End"
13 changes: 10 additions & 3 deletions backend/onyx/agents/agent_search/dr/graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from onyx.agents.agent_search.dr.nodes.dr_a3_logger import logging
from onyx.agents.agent_search.dr.states import MainInput
from onyx.agents.agent_search.dr.states import MainState
from onyx.agents.agent_search.dr.sub_agents.analysis_tool.dr_analysis_tool_graph_builder import (
dr_analysis_tool_graph_builder,
)
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_graph_builder import (
dr_basic_search_graph_builder,
)
Expand Down Expand Up @@ -58,12 +61,15 @@ def dr_graph_builder() -> StateGraph:
image_generation_graph = dr_image_generation_graph_builder().compile()
graph.add_node(DRPath.IMAGE_GENERATION, image_generation_graph)

custom_tool_graph = dr_custom_tool_graph_builder().compile()
graph.add_node(DRPath.GENERIC_TOOL, custom_tool_graph)

generic_internal_tool_graph = dr_generic_internal_tool_graph_builder().compile()
graph.add_node(DRPath.GENERIC_INTERNAL_TOOL, generic_internal_tool_graph)

analysis_tool_graph = dr_analysis_tool_graph_builder().compile()
graph.add_node(DRPath.ANALYSIS_TOOL, analysis_tool_graph)

custom_tool_graph = dr_custom_tool_graph_builder().compile()
graph.add_node(DRPath.GENERIC_TOOL, custom_tool_graph)

graph.add_node(DRPath.CLOSER, closer)
graph.add_node(DRPath.LOGGER, logging)

Expand All @@ -81,6 +87,7 @@ def dr_graph_builder() -> StateGraph:
graph.add_edge(start_key=DRPath.IMAGE_GENERATION, end_key=DRPath.ORCHESTRATOR)
graph.add_edge(start_key=DRPath.GENERIC_TOOL, end_key=DRPath.ORCHESTRATOR)
graph.add_edge(start_key=DRPath.GENERIC_INTERNAL_TOOL, end_key=DRPath.ORCHESTRATOR)
graph.add_edge(start_key=DRPath.ANALYSIS_TOOL, end_key=DRPath.ORCHESTRATOR)

graph.add_conditional_edges(DRPath.CLOSER, completeness_router)
graph.add_edge(start_key=DRPath.LOGGER, end_key=END)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.server.query_and_chat.streaming_models import StreamingType
from onyx.tools.tool_implementations.analysis.analysis_tool import AnalysisTool
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
Expand Down Expand Up @@ -134,6 +135,9 @@ def _get_available_tools(
continue
llm_path = DRPath.KNOWLEDGE_GRAPH.value
path = DRPath.KNOWLEDGE_GRAPH
elif isinstance(tool, AnalysisTool):
llm_path = DRPath.ANALYSIS_TOOL.value
path = DRPath.ANALYSIS_TOOL
elif isinstance(tool, ImageGenerationTool):
llm_path = DRPath.IMAGE_GENERATION.value
path = DRPath.IMAGE_GENERATION
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Analysis Tool sub-agent for deep research."""
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from datetime import datetime

from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter

from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.utils.logger import setup_logger

logger = setup_logger()


def analysis_tool_branch(
    state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
    """Log the beginning of an Analysis Tool branch.

    Args:
        state: Sub-agent input carrying the current iteration number.
        config: LangGraph runnable config (unused; required by the node signature).
        writer: Stream writer for incremental output (unused here).

    Returns:
        A ``LoggerUpdate`` with a single timing/log entry for this node.
    """
    node_start_time = datetime.now()

    # Lazy %-style args keep string formatting off the debug path and match
    # the logging style used by the act node of this sub-agent.  Reuse the
    # captured start time instead of calling datetime.now() a second time.
    logger.debug(
        "Analysis Tool branch start for iteration %s at %s",
        state.iteration_nr,
        node_start_time,
    )

    return LoggerUpdate(
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="analysis_tool",
                node_name="branching",
                node_start_time=node_start_time,
            )
        ],
    )
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
import json
from datetime import datetime
from typing import Any
from typing import cast

from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter

from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.dr.sub_agents.states import IterationAnswer
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
from onyx.prompts.dr_prompts import CUSTOM_TOOL_PREP_PROMPT
from onyx.prompts.dr_prompts import CUSTOM_TOOL_USE_PROMPT
from onyx.tools.tool_implementations.analysis.analysis_tool import AnalysisTool
from onyx.tools.tool_implementations.analysis.analysis_tool import AnalysisToolResult
from onyx.utils.logger import setup_logger

logger = setup_logger()


def _resolve_tool_args(
    branch_query: str,
    base_question: str,
    analysis_tool: AnalysisTool,
    tool_description: str,
    graph_config: GraphConfig,
) -> dict[str, Any]:
    """Obtain tool-call arguments from the LLM.

    Tries the tool-calling path first (when available), then falls back to
    prompting a non-tool-calling LLM.  Raises ``ValueError`` if neither path
    yields arguments.  Any LLM-supplied "files" argument is dropped, since
    files are always injected from the user's actual uploads.
    """
    tool_args: dict[str, Any] | None = None
    if graph_config.tooling.using_tool_calling_llm:
        tool_use_prompt = CUSTOM_TOOL_PREP_PROMPT.build(
            query=branch_query,
            base_question=base_question,
            tool_description=tool_description,
        )
        tool_calling_msg = graph_config.tooling.primary_llm.invoke(
            tool_use_prompt,
            tools=[analysis_tool.tool_definition()],
            tool_choice="required",
            timeout_override=TF_DR_TIMEOUT_SHORT,
        )

        if isinstance(tool_calling_msg, AIMessage) and tool_calling_msg.tool_calls:
            tool_args = tool_calling_msg.tool_calls[0].get("args")
        else:
            logger.warning(
                "Tool-calling LLM did not emit a tool call for Analysis Tool"
            )

    if tool_args is None:
        tool_args = analysis_tool.get_args_for_non_tool_calling_llm(
            query=branch_query,
            history=[],
            llm=graph_config.tooling.primary_llm,
            force_run=True,
        )

    if tool_args is None:
        raise ValueError("Failed to obtain tool arguments from LLM")

    # Build a filtered copy rather than mutating in place: the dict may be
    # owned by the AIMessage tool call and should not be modified.
    return {key: value for key, value in tool_args.items() if key != "files"}


def _first_analysis_result(tool_responses: list[Any]) -> AnalysisToolResult:
    """Return the first ``AnalysisToolResult`` among the tool responses.

    Raises ``ValueError`` when the tool produced no structured result.
    """
    for response in tool_responses:
        if isinstance(response.response, AnalysisToolResult):
            return response.response
    raise ValueError("Analysis tool did not return a valid result")


def analysis_tool_act(
    state: BranchInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> BranchUpdate:
    """Execute the Analysis Tool with any files supplied by the user.

    Resolves tool arguments via the LLM, runs the tool with the user's
    uploaded files, summarizes the JSON result with the primary LLM, and
    records the answer (plus any generated artifact file ids) as an
    ``IterationAnswer`` in a ``BranchUpdate``.

    Raises:
        ValueError: if required state (tools, branch question) is missing,
            the LLM yields no tool arguments, or the tool returns no result.
    """
    node_start_time = datetime.now()
    iteration_nr = state.iteration_nr
    parallelization_nr = state.parallelization_nr

    if not state.available_tools:
        raise ValueError("available_tools is not set")

    # The orchestrator appends the chosen tool key last; look it up there.
    tool_key = state.tools_used[-1]
    analysis_tool_info = state.available_tools[tool_key]
    analysis_tool = cast(AnalysisTool | None, analysis_tool_info.tool_object)

    if analysis_tool is None:
        raise ValueError("analysis_tool is not set")

    branch_query = state.branch_question
    if not branch_query:
        raise ValueError("branch_query is not set")

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    base_question = graph_config.inputs.prompt_builder.raw_user_query

    logger.debug(
        "Tool call start for %s %s.%s at %s",
        analysis_tool.llm_name,
        iteration_nr,
        parallelization_nr,
        datetime.now(),
    )

    tool_args = _resolve_tool_args(
        branch_query=branch_query,
        base_question=base_question,
        analysis_tool=analysis_tool,
        tool_description=analysis_tool_info.description,
        graph_config=graph_config,
    )

    # Files always come from the user's actual uploads, never from the LLM.
    override_kwargs = {"files": graph_config.inputs.files or []}

    tool_responses = list(
        analysis_tool.run(override_kwargs=override_kwargs, **tool_args)
    )

    analysis_result_obj = _first_analysis_result(tool_responses)

    final_result = analysis_tool.final_result(*tool_responses)
    tool_result_str = json.dumps(final_result, ensure_ascii=False)

    tool_str = (
        f"Tool used: {analysis_tool.display_name}\n"
        f"Description: {analysis_tool_info.description}\n"
        f"Result: {tool_result_str}"
    )

    # Let the LLM turn the raw JSON result into a natural-language answer.
    tool_summary_prompt = CUSTOM_TOOL_USE_PROMPT.build(
        query=branch_query,
        base_question=base_question,
        tool_response=tool_str,
    )
    answer_string = str(
        graph_config.tooling.primary_llm.invoke(
            tool_summary_prompt,
            timeout_override=TF_DR_TIMEOUT_SHORT,
        ).content
    ).strip()

    # Surface any generated artifacts (plots, files) to the chat session.
    artifact_file_ids = [
        artifact.file_id
        for artifact in analysis_result_obj.artifacts
        if artifact.file_id
    ]

    logger.debug(
        "Tool call end for %s %s.%s at %s",
        analysis_tool.llm_name,
        iteration_nr,
        parallelization_nr,
        datetime.now(),
    )

    return BranchUpdate(
        branch_iteration_responses=[
            IterationAnswer(
                tool=analysis_tool.llm_name,
                tool_id=analysis_tool_info.tool_id,
                iteration_nr=iteration_nr,
                parallelization_nr=parallelization_nr,
                question=branch_query,
                answer=answer_string,
                claims=[],
                cited_documents={},
                reasoning="",
                additional_data=None,
                response_type="json",
                data=final_result,
                file_ids=artifact_file_ids or None,
            )
        ],
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="analysis_tool",
                node_name="tool_calling",
                node_start_time=node_start_time,
            )
        ],
    )
Loading
Loading