Skip to content

Commit 7b37e72

Browse files
committed
almost working custom tools
1 parent 09d672f commit 7b37e72

21 files changed

+188
-320
lines changed

backend/onyx/agents/agent_search/dr/dr_prompt_builder.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,22 +19,23 @@
1919
def get_dr_prompt_orchestration_templates(
2020
purpose: DRPromptPurpose,
2121
research_type: ResearchType,
22-
available_tools: list[OrchestratorTool],
22+
available_tools: dict[str, OrchestratorTool],
2323
entity_types_string: str | None = None,
2424
relationship_types_string: str | None = None,
2525
reasoning_result: str | None = None,
2626
tool_calls_string: str | None = None,
2727
) -> PromptTemplate:
28-
available_tools = available_tools or []
29-
tool_names = [tool.llm_path for tool in available_tools]
28+
available_tools = available_tools or {}
29+
tool_names = list(available_tools.keys())
3030
tool_description_str = "\n\n".join(
31-
f"- {tool.llm_path}: {tool.description}" for tool in available_tools
31+
f"- {tool_name}: {tool.description}"
32+
for tool_name, tool in available_tools.items()
3233
)
3334
tool_cost_str = "\n".join(
34-
f"{tool.llm_path}: {tool.cost}" for tool in available_tools
35+
f"{tool_name}: {tool.cost}" for tool_name, tool in available_tools.items()
3536
)
3637

37-
available_tool_paths = [tool.path for tool in available_tools]
38+
available_tool_paths = [tool.path for tool in available_tools.values()]
3839

3940
tool_differentiations: list[str] = []
4041
for tool_1 in available_tool_paths:

backend/onyx/agents/agent_search/dr/models.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from onyx.agents.agent_search.dr.enums import DRPath
66
from onyx.context.search.models import InferenceSection
7+
from onyx.tools.tool import Tool
78

89

910
class OrchestratorStep(BaseModel):
@@ -59,6 +60,10 @@ class OrchestratorTool(BaseModel):
5960
description: str
6061
metadata: dict[str, str]
6162
cost: float
63+
tool_object: Tool | None = None # None for CLOSER
64+
65+
class Config:
66+
arbitrary_types_allowed = True
6267

6368

6469
class IterationInstructions(BaseModel):
@@ -71,11 +76,10 @@ class IterationInstructions(BaseModel):
7176
class GenericToolAnswer(BaseModel):
7277
reasoning: str
7378
answer: str
74-
background_info: str
7579

7680

7781
class IterationAnswer(BaseModel):
78-
tool: DRPath
82+
tool: str
7983
tool_id: int
8084
iteration_nr: int
8185
parallelization_nr: int

backend/onyx/agents/agent_search/dr/nodes/dr_a0_clarification.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,9 @@ def _format_tool_name(tool_name: str) -> str:
6767

6868
def _get_available_tools(
6969
graph_config: GraphConfig, kg_enabled: bool
70-
) -> list[OrchestratorTool]:
70+
) -> dict[str, OrchestratorTool]:
7171

72-
available_tools: list[OrchestratorTool] = []
72+
available_tools: dict[str, OrchestratorTool] = {}
7373
for tool in graph_config.tooling.tools:
7474
tool_info = OrchestratorTool(
7575
tool_id=tool.id,
@@ -79,6 +79,7 @@ def _get_available_tools(
7979
description=tool.description,
8080
metadata={},
8181
cost=1.0,
82+
tool_object=tool,
8283
)
8384

8485
if isinstance(tool, CustomTool):
@@ -107,29 +108,28 @@ def _get_available_tools(
107108
tool_info.description = TOOL_DESCRIPTION.get(tool_info.path, tool.description)
108109
tool_info.cost = AVERAGE_TOOL_COSTS[tool_info.path]
109110

110-
available_tools.append(tool_info)
111+
# TODO: handle custom tools with same name as other tools (e.g., CLOSER)
112+
available_tools[tool_info.llm_path] = tool_info
111113

112114
# make sure KG isn't enabled without internal search
113-
available_paths = [tool.path for tool in available_tools]
114115
if (
115-
DRPath.KNOWLEDGE_GRAPH in available_paths
116-
and DRPath.INTERNAL_SEARCH not in available_paths
116+
DRPath.KNOWLEDGE_GRAPH.value in available_tools
117+
and DRPath.INTERNAL_SEARCH.value not in available_tools
117118
):
118119
raise ValueError(
119120
"The Knowledge Graph is not supported without internal search tool"
120121
)
121122

122123
# add CLOSER tool, which is always available
123-
available_tools.append(
124-
OrchestratorTool(
125-
tool_id=-1,
126-
name="closer",
127-
llm_path=DRPath.CLOSER.value,
128-
path=DRPath.CLOSER,
129-
description=TOOL_DESCRIPTION[DRPath.CLOSER],
130-
metadata={},
131-
cost=0.0,
132-
)
124+
available_tools[DRPath.CLOSER.value] = OrchestratorTool(
125+
tool_id=-1,
126+
name="closer",
127+
llm_path=DRPath.CLOSER.value,
128+
path=DRPath.CLOSER,
129+
description=TOOL_DESCRIPTION[DRPath.CLOSER],
130+
metadata={},
131+
cost=0.0,
132+
tool_object=None,
133133
)
134134

135135
return available_tools

backend/onyx/agents/agent_search/dr/nodes/dr_a1_orchestrator.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,7 @@ def orchestrator(
7474
).context
7575
or "(No answer history yet available)"
7676
)
77-
available_tools = state.available_tools or []
78-
available_tool_map = {tool.llm_path: tool for tool in available_tools}
77+
available_tools = state.available_tools or {}
7978

8079
questions = [
8180
f"{iteration_response.tool}: {iteration_response.question}"
@@ -217,7 +216,7 @@ def orchestrator(
217216
logger.error(f"Error in approach extraction: {e}")
218217
raise e
219218

220-
remaining_time_budget -= available_tool_map[next_tool].cost
219+
remaining_time_budget -= available_tools[next_tool].cost
221220
else:
222221
if iteration_nr == 1 and not plan_of_record:
223222
# by default, we start a new iteration, but if there is a feedback request,
@@ -304,7 +303,7 @@ def orchestrator(
304303
logger.error(f"Error in approach extraction: {e}")
305304
raise e
306305

307-
remaining_time_budget -= available_tool_map[next_tool].cost
306+
remaining_time_budget -= available_tools[next_tool].cost
308307
else:
309308
reasoning_result = "Time to wrap up."
310309

backend/onyx/agents/agent_search/dr/states.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,11 @@ class OrchestrationUpdate(LoggerUpdate):
2626
tools_used: Annotated[list[str], add] = []
2727
query_list: list[str] = []
2828
iteration_nr: int = 0
29-
plan_of_record: OrchestrationPlan | None = None # None for FAST TimeBudget
29+
plan_of_record: OrchestrationPlan | None = None # None for Thoughtful
3030
remaining_time_budget: float = 2.0 # set by default to about 2 searches
3131
clarification: OrchestrationClarificationInfo | None = None
32-
available_tools: list[OrchestratorTool] | None = None
33-
num_closer_suggestions: int = (
34-
0 # how many times the closer was suggested. (Closer can send back now.)
35-
)
32+
available_tools: dict[str, OrchestratorTool] | None = None
33+
num_closer_suggestions: int = 0 # how many times the closer was suggested
3634
gaps: list[str] = (
3735
[]
3836
) # gaps that may be identified by the closer before being able to answer the question.

backend/onyx/agents/agent_search/dr/sub_agents/basic_search/dr_basic_search_2_act.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from langchain_core.runnables import RunnableConfig
66
from langgraph.types import StreamWriter
77

8-
from onyx.agents.agent_search.dr.enums import DRPath
98
from onyx.agents.agent_search.dr.enums import ResearchType
109
from onyx.agents.agent_search.dr.models import BaseSearchProcessingResponse
1110
from onyx.agents.agent_search.dr.models import IterationAnswer
@@ -30,6 +29,7 @@
3029
SEARCH_RESPONSE_SUMMARY_ID,
3130
)
3231
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
32+
from onyx.tools.tool_implementations.search.search_tool import SearchTool
3333
from onyx.utils.logger import setup_logger
3434

3535
logger = setup_logger()
@@ -56,13 +56,17 @@ def basic_search(
5656
base_question = graph_config.inputs.prompt_builder.raw_user_query
5757
research_type = graph_config.behavior.research_type
5858

59-
search_tool = graph_config.tooling.search_tool
59+
if not state.available_tools:
60+
raise ValueError("available_tools is not set")
6061

61-
if search_tool is None:
62-
raise ValueError("search_tool must be provided for agentic search")
62+
search_tool_info = state.available_tools[state.tools_used[-1]]
63+
search_tool = cast(SearchTool, search_tool_info.tool_object)
6364

64-
# rewrite query and identify source types
65+
# sanity check
66+
if search_tool != graph_config.tooling.search_tool:
67+
raise ValueError("search_tool does not match the configured search tool")
6568

69+
# rewrite query and identify source types
6670
active_source_types_str = ", ".join(
6771
[source.value for source in state.active_source_types or []]
6872
)
@@ -228,8 +232,8 @@ def basic_search(
228232
return AnswerUpdate(
229233
iteration_responses=[
230234
IterationAnswer(
231-
tool=DRPath.INTERNAL_SEARCH,
232-
tool_id=search_tool.id,
235+
tool=search_tool_info.llm_path,
236+
tool_id=search_tool_info.tool_id,
233237
iteration_nr=iteration_nr,
234238
parallelization_nr=parallelization_nr,
235239
question=branch_query,
@@ -242,8 +246,8 @@ def basic_search(
242246
],
243247
log_messages=[
244248
get_langgraph_node_log_string(
245-
graph_component="main",
246-
node_name="search",
249+
graph_component="basic_search",
250+
node_name="searching",
247251
node_start_time=node_start_time,
248252
)
249253
],

backend/onyx/agents/agent_search/dr/sub_agents/basic_search/dr_basic_search_3_reduce.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def is_reducer(
3636
iteration_responses=new_updates,
3737
log_messages=[
3838
get_langgraph_node_log_string(
39-
graph_component="internet_search",
39+
graph_component="basic_search",
4040
node_name="consolidation",
4141
node_start_time=node_start_time,
4242
)

backend/onyx/agents/agent_search/dr/sub_agents/basic_search/dr_basic_search_conditional_edges.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
1717
branch_question=query,
1818
context="",
1919
active_source_types=state.active_source_types,
20+
tools_used=state.tools_used,
21+
available_tools=state.available_tools,
2022
),
2123
)
2224
for parallelization_nr, query in enumerate(

backend/onyx/agents/agent_search/dr/sub_agents/custom_tool/dr_custom_tool_1_branch.py

Lines changed: 8 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,8 @@
33
from langchain_core.runnables import RunnableConfig
44
from langgraph.types import StreamWriter
55

6-
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_states import (
7-
CustomToolSubAgentInput,
8-
)
9-
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_states import (
10-
CustomToolSubAgentPrepareUpdate,
11-
)
6+
from onyx.agents.agent_search.dr.states import LoggerUpdate
7+
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
128
from onyx.agents.agent_search.shared_graph_utils.utils import (
139
get_langgraph_node_log_string,
1410
)
@@ -18,32 +14,21 @@
1814

1915

2016
def custom_tool_branch(
21-
state: CustomToolSubAgentInput,
22-
config: RunnableConfig,
23-
writer: StreamWriter = lambda _: None,
24-
) -> CustomToolSubAgentPrepareUpdate:
17+
state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
18+
) -> LoggerUpdate:
2519
"""
2620
LangGraph node to perform a generic tool call as part of the DR process.
2721
"""
2822

2923
node_start_time = datetime.now()
30-
tool_name = state.query_path[-1]
31-
32-
if not state.available_tools:
33-
raise ValueError("available_tools is not set")
24+
iteration_nr = state.iteration_nr
3425

35-
tool_dict: dict[str, str] = {}
36-
for available_tool_dict in state.available_tools:
37-
if available_tool_dict["name"] == tool_name:
38-
tool_dict = available_tool_dict
39-
break
26+
logger.debug(f"Search start for Generic Tool {iteration_nr} at {datetime.now()}")
4027

41-
return CustomToolSubAgentPrepareUpdate(
42-
tool_name=tool_name,
43-
tool_dict=tool_dict,
28+
return LoggerUpdate(
4429
log_messages=[
4530
get_langgraph_node_log_string(
46-
graph_component="custom_tool_sub_agent",
31+
graph_component="custom_tool",
4732
node_name="branching",
4833
node_start_time=node_start_time,
4934
)

0 commit comments

Comments
 (0)