From 808135038eb94cde31b013714f37ee56d5164f73 Mon Sep 17 00:00:00 2001
From: Evan Lohn
Date: Sun, 30 Mar 2025 19:45:28 -0700
Subject: [PATCH 1/5] use persona info when creating tool args

---
 backend/onyx/chat/process_message.py                 |  6 +++---
 .../onyx/chat/tool_handling/tool_response_handler.py | 10 +++++++---
 2 files changed, 10 insertions(+), 6 deletions(-)

diff --git a/backend/onyx/chat/process_message.py b/backend/onyx/chat/process_message.py
index eea541aca71..bbb7652e0fb 100644
--- a/backend/onyx/chat/process_message.py
+++ b/backend/onyx/chat/process_message.py
@@ -794,10 +794,10 @@ def stream_chat_message_objects(
                 final_msg.prompt,
                 prompt_override=prompt_override,
             )
-        elif final_msg.prompt:
-            prompt_config = PromptConfig.from_model(final_msg.prompt)
         else:
-            prompt_config = PromptConfig.from_model(persona.prompts[0])
+            prompt_config = PromptConfig.from_model(
+                final_msg.prompt or persona.prompts[0]
+            )
 
         answer_style_config = AnswerStyleConfig(
             citation_config=CitationConfig(
diff --git a/backend/onyx/chat/tool_handling/tool_response_handler.py b/backend/onyx/chat/tool_handling/tool_response_handler.py
index cab5d9e08fa..74c169f2ad7 100644
--- a/backend/onyx/chat/tool_handling/tool_response_handler.py
+++ b/backend/onyx/chat/tool_handling/tool_response_handler.py
@@ -162,6 +162,10 @@ def get_tool_call_for_non_tool_calling_llm_impl(
     prompt_builder: AnswerPromptBuilder | PromptSnapshot,
     llm: LLM,
 ) -> tuple[Tool, dict] | None:
+    user_query = prompt_builder.raw_user_query
+    if isinstance(prompt_builder, AnswerPromptBuilder):
+        user_query = prompt_builder.get_user_message_content()
+
     if force_use_tool.force_use:
         # if we are forcing a tool, we don't need to check which tools to run
         tool = get_tool_by_name(tools, force_use_tool.tool_name)
@@ -170,7 +174,7 @@ def get_tool_call_for_non_tool_calling_llm_impl(
             force_use_tool.args
             if force_use_tool.args is not None
             else tool.get_args_for_non_tool_calling_llm(
-                query=prompt_builder.raw_user_query,
+                query=user_query,
                 history=prompt_builder.raw_message_history,
                 llm=llm,
                 force_run=True,
@@ -188,7 +192,7 @@ def get_tool_call_for_non_tool_calling_llm_impl(
     else:
         tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
             tools=tools,
-            query=prompt_builder.raw_user_query,
+            query=user_query,
             history=prompt_builder.raw_message_history,
             llm=llm,
         )
@@ -207,7 +211,7 @@ def get_tool_call_for_non_tool_calling_llm_impl(
             select_single_tool_for_non_tool_calling_llm(
                 tools_and_args=available_tools_and_args,
                 history=prompt_builder.raw_message_history,
-                query=prompt_builder.raw_user_query,
+                query=user_query,
                 llm=llm,
             )
             if available_tools_and_args

From 8e41ca9a9fe3e9a5c418b8d269eb9cbc72498cb4 Mon Sep 17 00:00:00 2001
From: Evan Lohn
Date: Mon, 31 Mar 2025 09:31:29 -0700
Subject: [PATCH 2/5] fixed unit test

---
 backend/tests/unit/onyx/chat/test_answer.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/backend/tests/unit/onyx/chat/test_answer.py b/backend/tests/unit/onyx/chat/test_answer.py
index 1f94e9ef720..0279cdd085d 100644
--- a/backend/tests/unit/onyx/chat/test_answer.py
+++ b/backend/tests/unit/onyx/chat/test_answer.py
@@ -323,9 +323,12 @@ def test_answer_with_search_no_tool_calling(
         == mock_search_tool.build_next_prompt.return_value.build.return_value
     )
 
+    user_message = (
+        answer_instance.graph_inputs.prompt_builder.get_user_message_content()
+    )
     # Verify that get_args_for_non_tool_calling_llm was called on the mock_search_tool
     mock_search_tool.get_args_for_non_tool_calling_llm.assert_called_once_with(
-        QUERY, [], answer_instance.graph_config.tooling.primary_llm
+        user_message, [], answer_instance.graph_config.tooling.primary_llm
     )
 
     # Verify that the search tool's run method was called

From f96f88a8d397b03c8ae3f30a873667c78ca3fda9 Mon Sep 17 00:00:00 2001
From: Evan Lohn
Date: Fri, 4 Apr 2025 19:53:29 -0700
Subject: [PATCH 3/5] include system message

---
 .../prompt_builder/answer_prompt_builder.py | 13 ++++++++++++
 .../tool_handling/tool_response_handler.py  |  8 +++++---
 backend/onyx/llm/models.py                  | 20 +++++++++++++++++++
 3 files changed, 38 insertions(+), 3 deletions(-)

diff --git a/backend/onyx/chat/prompt_builder/answer_prompt_builder.py b/backend/onyx/chat/prompt_builder/answer_prompt_builder.py
index 8e175f576ee..111421779e0 100644
--- a/backend/onyx/chat/prompt_builder/answer_prompt_builder.py
+++ b/backend/onyx/chat/prompt_builder/answer_prompt_builder.py
@@ -155,6 +155,19 @@ def get_user_message_content(self) -> str:
         query, _ = message_to_prompt_and_imgs(self.user_message_and_token_cnt[0])
         return query
 
+    def get_message_history(self) -> list[PreviousMessage]:
+        """
+        Get the message history as a list of PreviousMessage objects.
+        """
+        ret = []
+        if self.system_message_and_token_cnt:
+            tmp = PreviousMessage.from_langchain_msg(*self.system_message_and_token_cnt)
+            ret.append(tmp)
+        for i, msg in enumerate(self.message_history):
+            tmp = PreviousMessage.from_langchain_msg(msg, self.history_token_cnts[i])
+            ret.append(tmp)
+        return ret
+
     def build(self) -> list[BaseMessage]:
         if not self.user_message_and_token_cnt:
             raise ValueError("User message must be set before building prompt")
diff --git a/backend/onyx/chat/tool_handling/tool_response_handler.py b/backend/onyx/chat/tool_handling/tool_response_handler.py
index 74c169f2ad7..d01630a9748 100644
--- a/backend/onyx/chat/tool_handling/tool_response_handler.py
+++ b/backend/onyx/chat/tool_handling/tool_response_handler.py
@@ -163,8 +163,10 @@ def get_tool_call_for_non_tool_calling_llm_impl(
     llm: LLM,
 ) -> tuple[Tool, dict] | None:
     user_query = prompt_builder.raw_user_query
+    history = prompt_builder.raw_message_history
     if isinstance(prompt_builder, AnswerPromptBuilder):
         user_query = prompt_builder.get_user_message_content()
+        history = prompt_builder.get_message_history()
 
     if force_use_tool.force_use:
         # if we are forcing a tool, we don't need to check which tools to run
@@ -175,7 +177,7 @@ def get_tool_call_for_non_tool_calling_llm_impl(
             if force_use_tool.args is not None
             else tool.get_args_for_non_tool_calling_llm(
                 query=user_query,
-                history=prompt_builder.raw_message_history,
+                history=history,
                 llm=llm,
                 force_run=True,
             )
@@ -193,7 +195,7 @@ def get_tool_call_for_non_tool_calling_llm_impl(
         tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
             tools=tools,
             query=user_query,
-            history=prompt_builder.raw_message_history,
+            history=history,
             llm=llm,
         )
 
@@ -210,7 +212,7 @@ def get_tool_call_for_non_tool_calling_llm_impl(
         chosen_tool_and_args = (
             select_single_tool_for_non_tool_calling_llm(
                 tools_and_args=available_tools_and_args,
-                history=prompt_builder.raw_message_history,
+                history=history,
                 query=user_query,
                 llm=llm,
             )
diff --git a/backend/onyx/llm/models.py b/backend/onyx/llm/models.py
index 925c8bc3f9e..b7e85deaa1f 100644
--- a/backend/onyx/llm/models.py
+++ b/backend/onyx/llm/models.py
@@ -9,6 +9,7 @@
 from onyx.configs.constants import MessageType
 from onyx.file_store.models import InMemoryChatFile
 from onyx.llm.utils import build_content_with_imgs
+from onyx.llm.utils import message_to_string
 from onyx.tools.models import ToolCallFinalResult
 
 if TYPE_CHECKING:
@@ -59,3 +60,22 @@ def to_langchain_msg(self) -> BaseMessage:
             return AIMessage(content=content)
         else:
             return SystemMessage(content=content)
+
+    @classmethod
+    def from_langchain_msg(
+        cls, msg: BaseMessage, token_count: int
+    ) -> "PreviousMessage":
+        message_type = MessageType.SYSTEM
+        if isinstance(msg, HumanMessage):
+            message_type = MessageType.USER
+        elif isinstance(msg, AIMessage):
+            message_type = MessageType.ASSISTANT
+        message = message_to_string(msg)
+        return cls(
+            message=message,
+            token_count=token_count,
+            message_type=message_type,
+            files=[],
+            tool_call=None,
+            refined_answer_improvement=None,
+        )

From 59451f24f52be045eb78f8fdc762b1fa185c3311 Mon Sep 17 00:00:00 2001
From: Evan Lohn
Date: Fri, 4 Apr 2025 19:56:59 -0700
Subject: [PATCH 4/5] fix unit test

---
 backend/tests/unit/onyx/chat/test_answer.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/backend/tests/unit/onyx/chat/test_answer.py b/backend/tests/unit/onyx/chat/test_answer.py
index 0279cdd085d..c714427c1c4 100644
--- a/backend/tests/unit/onyx/chat/test_answer.py
+++ b/backend/tests/unit/onyx/chat/test_answer.py
@@ -326,9 +326,11 @@ def test_answer_with_search_no_tool_calling(
     user_message = (
         answer_instance.graph_inputs.prompt_builder.get_user_message_content()
     )
+
+    prev_messages = answer_instance.graph_inputs.prompt_builder.get_message_history()
     # Verify that get_args_for_non_tool_calling_llm was called on the mock_search_tool
     mock_search_tool.get_args_for_non_tool_calling_llm.assert_called_once_with(
-        user_message, [], answer_instance.graph_config.tooling.primary_llm
+        user_message, prev_messages, answer_instance.graph_config.tooling.primary_llm
     )
 
     # Verify that the search tool's run method was called

From 6108d264112f6201d752373c5f64dcc796305298 Mon Sep 17 00:00:00 2001
From: Evan Lohn
Date: Mon, 7 Apr 2025 18:15:26 -0700
Subject: [PATCH 5/5] nit

---
 backend/onyx/chat/prompt_builder/answer_prompt_builder.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/backend/onyx/chat/prompt_builder/answer_prompt_builder.py b/backend/onyx/chat/prompt_builder/answer_prompt_builder.py
index 111421779e0..7b35690a2e3 100644
--- a/backend/onyx/chat/prompt_builder/answer_prompt_builder.py
+++ b/backend/onyx/chat/prompt_builder/answer_prompt_builder.py
@@ -159,14 +159,14 @@ def get_message_history(self) -> list[PreviousMessage]:
         """
         Get the message history as a list of PreviousMessage objects.
         """
-        ret = []
+        message_history = []
         if self.system_message_and_token_cnt:
             tmp = PreviousMessage.from_langchain_msg(*self.system_message_and_token_cnt)
-            ret.append(tmp)
+            message_history.append(tmp)
         for i, msg in enumerate(self.message_history):
             tmp = PreviousMessage.from_langchain_msg(msg, self.history_token_cnts[i])
-            ret.append(tmp)
+            message_history.append(tmp)
-        return ret
+        return message_history
 
     def build(self) -> list[BaseMessage]:
         if not self.user_message_and_token_cnt:
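
A minimal usage sketch of the PreviousMessage.from_langchain_msg helper introduced in PATCH 3/5 (not part of the patches themselves; the langchain import path and the example values are assumptions, only the onyx names come from the diffs above):

    from langchain_core.messages import HumanMessage  # import path assumed

    from onyx.configs.constants import MessageType
    from onyx.llm.models import PreviousMessage

    # from_langchain_msg pairs a LangChain message with a precomputed token count and
    # maps the message class onto MessageType (USER for HumanMessage, ASSISTANT for
    # AIMessage, SYSTEM otherwise). AnswerPromptBuilder.get_message_history() uses it
    # to convert the system message and each prior turn, so the non-tool-calling tool
    # selection in tool_response_handler.py sees the same history the prompt was built from.
    prev = PreviousMessage.from_langchain_msg(HumanMessage(content="what changed in Q3?"), 6)
    assert prev.message_type == MessageType.USER
    assert prev.token_count == 6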