Skip to content

Commit fc9dcd3

Browse files
generic_internal tools
1 parent 1dd19fa commit fc9dcd3

15 files changed

+474
-38
lines changed

backend/onyx/agents/agent_search/dr/enums.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,6 @@ class DRPath(str, Enum):
2525
KNOWLEDGE_GRAPH = "Knowledge Graph"
2626
INTERNET_SEARCH = "Internet Search"
2727
IMAGE_GENERATION = "Image Generation"
28+
GENERIC_INTERNAL_TOOL = "Generic Internal Tool"
2829
CLOSER = "Closer"
2930
END = "End"

backend/onyx/agents/agent_search/dr/graph_builder.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_graph_builder import (
1616
dr_custom_tool_graph_builder,
1717
)
18+
from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_graph_builder import (
19+
dr_generic_internal_tool_graph_builder,
20+
)
1821
from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_graph_builder import (
1922
dr_image_generation_graph_builder,
2023
)
@@ -59,6 +62,9 @@ def dr_graph_builder() -> StateGraph:
5962
custom_tool_graph = dr_custom_tool_graph_builder().compile()
6063
graph.add_node(DRPath.GENERIC_TOOL, custom_tool_graph)
6164

65+
generic_internal_tool_graph = dr_generic_internal_tool_graph_builder().compile()
66+
graph.add_node(DRPath.GENERIC_INTERNAL_TOOL, generic_internal_tool_graph)
67+
6268
graph.add_node(DRPath.CLOSER, closer)
6369

6470
### Add edges ###
@@ -74,6 +80,7 @@ def dr_graph_builder() -> StateGraph:
7480
graph.add_edge(start_key=DRPath.INTERNET_SEARCH, end_key=DRPath.ORCHESTRATOR)
7581
graph.add_edge(start_key=DRPath.IMAGE_GENERATION, end_key=DRPath.ORCHESTRATOR)
7682
graph.add_edge(start_key=DRPath.GENERIC_TOOL, end_key=DRPath.ORCHESTRATOR)
83+
graph.add_edge(start_key=DRPath.GENERIC_INTERNAL_TOOL, end_key=DRPath.ORCHESTRATOR)
7784

7885
graph.add_conditional_edges(DRPath.CLOSER, completeness_router)
7986

backend/onyx/agents/agent_search/dr/models.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ class IterationAnswer(BaseModel):
9090
background_info: str | None = None
9191
claims: list[str] | None = None
9292
additional_data: dict[str, str] | None = None
93+
response_type: str | None = None
94+
data: dict | list | str | int | float | bool | None = None
95+
file_ids: list[str] | None = None
9396

9497

9598
class AggregatedDRContext(BaseModel):

backend/onyx/agents/agent_search/dr/nodes/dr_a0_clarification.py

Lines changed: 49 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from langchain_core.messages import merge_content
77
from langchain_core.runnables import RunnableConfig
88
from langgraph.types import StreamWriter
9+
from sqlalchemy.orm import Session
910

1011
from onyx.agents.agent_search.basic.utils import process_llm_stream
1112
from onyx.agents.agent_search.dr.constants import AVERAGE_TOOL_COSTS
@@ -38,6 +39,8 @@
3839
from onyx.configs.constants import DocumentSourceDescription
3940
from onyx.configs.constants import TMP_DRALPHA_PERSONA_NAME
4041
from onyx.db.connector import fetch_unique_document_sources
42+
from onyx.db.models import Tool
43+
from onyx.db.tools import get_tools
4144
from onyx.file_store.models import ChatFileType
4245
from onyx.file_store.models import InMemoryChatFile
4346
from onyx.kg.utils.extraction_utils import get_entity_types_str
@@ -57,7 +60,6 @@
5760
from onyx.server.query_and_chat.streaming_models import MessageStart
5861
from onyx.server.query_and_chat.streaming_models import OverallStop
5962
from onyx.server.query_and_chat.streaming_models import SectionEnd
60-
from onyx.tools.tool_implementations.custom.custom_tool import CustomTool
6163
from onyx.tools.tool_implementations.images.image_generation_tool import (
6264
ImageGenerationTool,
6365
)
@@ -82,6 +84,7 @@ def _format_tool_name(tool_name: str) -> str:
8284

8385

8486
def _get_available_tools(
87+
db_session: Session,
8588
graph_config: GraphConfig,
8689
kg_enabled: bool,
8790
active_source_types: list[DocumentSource],
@@ -97,49 +100,59 @@ def _get_available_tools(
97100
else:
98101
include_kg = False
99102

103+
tool_dict: dict[int, Tool] = {tool.id: tool for tool in get_tools(db_session)}
104+
100105
for tool in graph_config.tooling.tools:
101-
tool_info = OrchestratorTool(
102-
tool_id=tool.id,
103-
name=tool.name,
104-
llm_path=_format_tool_name(tool.name),
105-
path=DRPath.GENERIC_TOOL,
106-
description=tool.description,
107-
metadata={},
108-
cost=1.0,
109-
tool_object=tool,
110-
)
111106

112-
if isinstance(tool, CustomTool):
113-
# tool_info.metadata["summary_signature"] = CUSTOM_TOOL_RESPONSE_ID
114-
pass
115-
elif isinstance(tool, InternetSearchTool):
116-
# tool_info.metadata["summary_signature"] = (
117-
# INTERNET_SEARCH_RESPONSE_SUMMARY_ID
118-
# )
119-
tool_info.llm_path = DRPath.INTERNET_SEARCH.value
120-
tool_info.path = DRPath.INTERNET_SEARCH
107+
tool_db_info = tool_dict.get(tool.id)
108+
if tool_db_info:
109+
incode_tool_id = tool_db_info.in_code_tool_id
110+
else:
111+
raise ValueError(f"Tool {tool.name} is not found in the database")
112+
113+
if isinstance(tool, InternetSearchTool):
114+
llm_path = DRPath.INTERNET_SEARCH.value
115+
path = DRPath.INTERNET_SEARCH
121116
elif isinstance(tool, SearchTool) and len(active_source_types) > 0:
122117
# tool_info.metadata["summary_signature"] = SEARCH_RESPONSE_SUMMARY_ID
123-
tool_info.llm_path = DRPath.INTERNAL_SEARCH.value
124-
tool_info.path = DRPath.INTERNAL_SEARCH
118+
llm_path = DRPath.INTERNAL_SEARCH.value
119+
path = DRPath.INTERNAL_SEARCH
125120
elif (
126121
isinstance(tool, KnowledgeGraphTool)
127122
and include_kg
128123
and len(active_source_types) > 0
129124
):
130-
tool_info.llm_path = DRPath.KNOWLEDGE_GRAPH.value
131-
tool_info.path = DRPath.KNOWLEDGE_GRAPH
125+
llm_path = DRPath.KNOWLEDGE_GRAPH.value
126+
path = DRPath.KNOWLEDGE_GRAPH
132127
elif isinstance(tool, ImageGenerationTool):
133-
tool_info.llm_path = DRPath.IMAGE_GENERATION.value
134-
tool_info.path = DRPath.IMAGE_GENERATION
128+
llm_path = DRPath.IMAGE_GENERATION.value
129+
path = DRPath.IMAGE_GENERATION
130+
elif incode_tool_id:
131+
# if incode tool id is found, it is a generic internal tool
132+
llm_path = DRPath.GENERIC_INTERNAL_TOOL.value
133+
path = DRPath.GENERIC_INTERNAL_TOOL
135134
else:
136-
logger.warning(
137-
f"Tool {tool.name} ({type(tool)}) is not supported/available"
138-
)
139-
continue
135+
# otherwise it is a custom tool
136+
llm_path = DRPath.GENERIC_TOOL.value
137+
path = DRPath.GENERIC_TOOL
138+
139+
if path not in {DRPath.GENERIC_INTERNAL_TOOL, DRPath.GENERIC_TOOL}:
140+
description = TOOL_DESCRIPTION.get(path, tool.description)
141+
cost = AVERAGE_TOOL_COSTS[path]
142+
else:
143+
description = tool.description
144+
cost = 1.0
140145

141-
tool_info.description = TOOL_DESCRIPTION.get(tool_info.path, tool.description)
142-
tool_info.cost = AVERAGE_TOOL_COSTS[tool_info.path]
146+
tool_info = OrchestratorTool(
147+
tool_id=tool.id,
148+
name=tool.name,
149+
llm_path=llm_path,
150+
path=path,
151+
description=description,
152+
metadata={},
153+
cost=cost,
154+
tool_object=tool,
155+
)
143156

144157
# TODO: handle custom tools with same name as other tools (e.g., CLOSER)
145158
available_tools[tool_info.llm_path] = tool_info
@@ -236,8 +249,9 @@ def _get_existing_clarification_request(
236249
"type": "function",
237250
"function": {
238251
"name": "run_any_knowledge_retrieval_and_any_action_tool",
239-
"description": "Use this tool to get any external information \
240-
that is relevant to the question, or for any action to be taken, including image generation.",
252+
"description": "Use this tool to get ANY external information \
253+
that is relevant to the question, or for any action to be taken, including image generation. In fact, \
254+
ANY tool mentioned can be accessed through this generic tool.",
241255
"parameters": {
242256
"type": "object",
243257
"properties": {
@@ -292,7 +306,7 @@ def clarifier(
292306
active_source_types = fetch_unique_document_sources(db_session)
293307

294308
available_tools = _get_available_tools(
295-
graph_config, kg_enabled, active_source_types
309+
db_session, graph_config, kg_enabled, active_source_types
296310
)
297311

298312
available_tool_descriptions_str = "\n -" + "\n -".join(

backend/onyx/agents/agent_search/dr/nodes/dr_a1_orchestrator.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from onyx.agents.agent_search.utils import create_question_prompt
3333
from onyx.kg.utils.extraction_utils import get_entity_types_str
3434
from onyx.kg.utils.extraction_utils import get_relationship_types_str
35+
from onyx.prompts.dr_prompts import DEFAULLT_DECISION_PROMPT
3536
from onyx.prompts.dr_prompts import SUFFICIENT_INFORMATION_STRING
3637
from onyx.server.query_and_chat.streaming_models import ReasoningDelta
3738
from onyx.server.query_and_chat.streaming_models import ReasoningStart
@@ -62,9 +63,15 @@ def orchestrator(
6263
clarification = state.clarification
6364
assistant_system_prompt = state.assistant_system_prompt
6465

65-
decision_system_prompt = _DECISION_SYSTEM_PROMPT_PREFIX + assistant_system_prompt
66+
if assistant_system_prompt:
67+
decision_system_prompt: str = (
68+
DEFAULLT_DECISION_PROMPT
69+
+ _DECISION_SYSTEM_PROMPT_PREFIX
70+
+ assistant_system_prompt
71+
)
72+
else:
73+
decision_system_prompt = DEFAULLT_DECISION_PROMPT
6674

67-
state.assistant_task_prompt
6875
iteration_nr = state.iteration_nr + 1
6976
current_step_nr = state.current_step_nr
7077

backend/onyx/agents/agent_search/dr/sub_agents/custom_tool/dr_custom_tool_2_act.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,9 @@ def custom_tool_act(
102102
raise ValueError("Custom tool did not return a valid response summary")
103103

104104
# summarise tool result
105+
if not response_summary.response_type:
106+
raise ValueError("Response type is not returned.")
107+
105108
if response_summary.response_type == "json":
106109
tool_result_str = json.dumps(response_summary.tool_result, ensure_ascii=False)
107110
elif response_summary.response_type in {"image", "csv"}:
@@ -124,6 +127,13 @@ def custom_tool_act(
124127
).content
125128
).strip()
126129

130+
# get file_ids:
131+
file_ids = None
132+
if response_summary.response_type in {"image", "csv"} and hasattr(
133+
response_summary.tool_result, "file_ids"
134+
):
135+
file_ids = response_summary.tool_result.file_ids
136+
127137
logger.debug(
128138
f"Tool call end for {custom_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
129139
)
@@ -141,6 +151,9 @@ def custom_tool_act(
141151
cited_documents={},
142152
reasoning="",
143153
additional_data=None,
154+
response_type=response_summary.response_type,
155+
data=response_summary.tool_result,
156+
file_ids=file_ids,
144157
)
145158
],
146159
log_messages=[

backend/onyx/agents/agent_search/dr/sub_agents/custom_tool/dr_custom_tool_3_reduce.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@
88
from onyx.agents.agent_search.shared_graph_utils.utils import (
99
get_langgraph_node_log_string,
1010
)
11+
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
12+
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
13+
from onyx.server.query_and_chat.streaming_models import CustomToolStart
14+
from onyx.server.query_and_chat.streaming_models import SectionEnd
1115
from onyx.utils.logger import setup_logger
1216

1317

@@ -25,13 +29,47 @@ def custom_tool_reducer(
2529

2630
node_start_time = datetime.now()
2731

32+
current_step_nr = state.current_step_nr
33+
2834
branch_updates = state.branch_iteration_responses
2935
current_iteration = state.iteration_nr
3036

3137
new_updates = [
3238
update for update in branch_updates if update.iteration_nr == current_iteration
3339
]
3440

41+
for new_update in new_updates:
42+
43+
if not new_update.response_type:
44+
raise ValueError("Response type is not returned.")
45+
46+
write_custom_event(
47+
current_step_nr,
48+
CustomToolStart(
49+
tool_name=new_update.tool,
50+
),
51+
writer,
52+
)
53+
54+
write_custom_event(
55+
current_step_nr,
56+
CustomToolDelta(
57+
tool_name=new_update.tool,
58+
response_type=new_update.response_type,
59+
data=new_update.data,
60+
file_ids=new_update.file_ids,
61+
),
62+
writer,
63+
)
64+
65+
write_custom_event(
66+
current_step_nr,
67+
SectionEnd(),
68+
writer,
69+
)
70+
71+
current_step_nr += 1
72+
3573
return SubAgentUpdate(
3674
iteration_responses=new_updates,
3775
log_messages=[
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from datetime import datetime
2+
3+
from langchain_core.runnables import RunnableConfig
4+
from langgraph.types import StreamWriter
5+
6+
from onyx.agents.agent_search.dr.states import LoggerUpdate
7+
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
8+
from onyx.agents.agent_search.shared_graph_utils.utils import (
9+
get_langgraph_node_log_string,
10+
)
11+
from onyx.utils.logger import setup_logger
12+
13+
logger = setup_logger()
14+
15+
16+
def generic_internal_tool_branch(
17+
state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
18+
) -> LoggerUpdate:
19+
"""
20+
LangGraph node to perform a generic tool call as part of the DR process.
21+
"""
22+
23+
node_start_time = datetime.now()
24+
iteration_nr = state.iteration_nr
25+
26+
logger.debug(f"Search start for Generic Tool {iteration_nr} at {datetime.now()}")
27+
28+
return LoggerUpdate(
29+
log_messages=[
30+
get_langgraph_node_log_string(
31+
graph_component="generic_internal_tool",
32+
node_name="branching",
33+
node_start_time=node_start_time,
34+
)
35+
],
36+
)

0 commit comments

Comments
 (0)