Skip to content

Commit 10f1ac5

Browse files
authored
use persona info when creating tool args (#4397)
* use persona info when creating tool args
* fix unit test
* include system message
* fix unit test
* nit
1 parent 1f80ed1 commit 10f1ac5

File tree

5 files changed

+54
-10
lines changed

5 files changed

+54
-10
lines changed

backend/onyx/chat/process_message.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -794,10 +794,10 @@ def stream_chat_message_objects(
794794
final_msg.prompt,
795795
prompt_override=prompt_override,
796796
)
797-
elif final_msg.prompt:
798-
prompt_config = PromptConfig.from_model(final_msg.prompt)
799797
else:
800-
prompt_config = PromptConfig.from_model(persona.prompts[0])
798+
prompt_config = PromptConfig.from_model(
799+
final_msg.prompt or persona.prompts[0]
800+
)
801801

802802
answer_style_config = AnswerStyleConfig(
803803
citation_config=CitationConfig(

backend/onyx/chat/prompt_builder/answer_prompt_builder.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,19 @@ def get_user_message_content(self) -> str:
155155
query, _ = message_to_prompt_and_imgs(self.user_message_and_token_cnt[0])
156156
return query
157157

158+
def get_message_history(self) -> list[PreviousMessage]:
159+
"""
160+
Get the message history as a list of PreviousMessage objects.
161+
"""
162+
message_history = []
163+
if self.system_message_and_token_cnt:
164+
tmp = PreviousMessage.from_langchain_msg(*self.system_message_and_token_cnt)
165+
message_history.append(tmp)
166+
for i, msg in enumerate(self.message_history):
167+
tmp = PreviousMessage.from_langchain_msg(msg, self.history_token_cnts[i])
168+
message_history.append(tmp)
169+
return message_history
170+
158171
def build(self) -> list[BaseMessage]:
159172
if not self.user_message_and_token_cnt:
160173
raise ValueError("User message must be set before building prompt")

backend/onyx/chat/tool_handling/tool_response_handler.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,12 @@ def get_tool_call_for_non_tool_calling_llm_impl(
162162
prompt_builder: AnswerPromptBuilder | PromptSnapshot,
163163
llm: LLM,
164164
) -> tuple[Tool, dict] | None:
165+
user_query = prompt_builder.raw_user_query
166+
history = prompt_builder.raw_message_history
167+
if isinstance(prompt_builder, AnswerPromptBuilder):
168+
user_query = prompt_builder.get_user_message_content()
169+
history = prompt_builder.get_message_history()
170+
165171
if force_use_tool.force_use:
166172
# if we are forcing a tool, we don't need to check which tools to run
167173
tool = get_tool_by_name(tools, force_use_tool.tool_name)
@@ -170,8 +176,8 @@ def get_tool_call_for_non_tool_calling_llm_impl(
170176
force_use_tool.args
171177
if force_use_tool.args is not None
172178
else tool.get_args_for_non_tool_calling_llm(
173-
query=prompt_builder.raw_user_query,
174-
history=prompt_builder.raw_message_history,
179+
query=user_query,
180+
history=history,
175181
llm=llm,
176182
force_run=True,
177183
)
@@ -188,8 +194,8 @@ def get_tool_call_for_non_tool_calling_llm_impl(
188194
else:
189195
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
190196
tools=tools,
191-
query=prompt_builder.raw_user_query,
192-
history=prompt_builder.raw_message_history,
197+
query=user_query,
198+
history=history,
193199
llm=llm,
194200
)
195201

@@ -206,8 +212,8 @@ def get_tool_call_for_non_tool_calling_llm_impl(
206212
chosen_tool_and_args = (
207213
select_single_tool_for_non_tool_calling_llm(
208214
tools_and_args=available_tools_and_args,
209-
history=prompt_builder.raw_message_history,
210-
query=prompt_builder.raw_user_query,
215+
history=history,
216+
query=user_query,
211217
llm=llm,
212218
)
213219
if available_tools_and_args

backend/onyx/llm/models.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from onyx.configs.constants import MessageType
1010
from onyx.file_store.models import InMemoryChatFile
1111
from onyx.llm.utils import build_content_with_imgs
12+
from onyx.llm.utils import message_to_string
1213
from onyx.tools.models import ToolCallFinalResult
1314

1415
if TYPE_CHECKING:
@@ -59,3 +60,22 @@ def to_langchain_msg(self) -> BaseMessage:
5960
return AIMessage(content=content)
6061
else:
6162
return SystemMessage(content=content)
63+
@classmethod
def from_langchain_msg(
    cls, msg: BaseMessage, token_count: int
) -> "PreviousMessage":
    """Build a PreviousMessage from a LangChain message and its token count.

    The message type is derived from the concrete LangChain class; files,
    tool calls, and refined-answer info are not recoverable from a bare
    LangChain message, so they are left empty.
    """
    if isinstance(msg, HumanMessage):
        message_type = MessageType.USER
    elif isinstance(msg, AIMessage):
        message_type = MessageType.ASSISTANT
    else:
        # Anything else (e.g. SystemMessage) is treated as a system message.
        message_type = MessageType.SYSTEM

    return cls(
        message=message_to_string(msg),
        token_count=token_count,
        message_type=message_type,
        files=[],
        tool_call=None,
        refined_answer_improvement=None,
    )

backend/tests/unit/onyx/chat/test_answer.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -323,9 +323,14 @@ def test_answer_with_search_no_tool_calling(
323323
== mock_search_tool.build_next_prompt.return_value.build.return_value
324324
)
325325

326+
user_message = (
327+
answer_instance.graph_inputs.prompt_builder.get_user_message_content()
328+
)
329+
330+
prev_messages = answer_instance.graph_inputs.prompt_builder.get_message_history()
326331
# Verify that get_args_for_non_tool_calling_llm was called on the mock_search_tool
327332
mock_search_tool.get_args_for_non_tool_calling_llm.assert_called_once_with(
328-
QUERY, [], answer_instance.graph_config.tooling.primary_llm
333+
user_message, prev_messages, answer_instance.graph_config.tooling.primary_llm
329334
)
330335

331336
# Verify that the search tool's run method was called

0 commit comments

Comments
 (0)