6 changes: 3 additions & 3 deletions backend/onyx/chat/process_message.py
@@ -794,10 +794,10 @@ def stream_chat_message_objects(
                 final_msg.prompt,
                 prompt_override=prompt_override,
             )
-        elif final_msg.prompt:
-            prompt_config = PromptConfig.from_model(final_msg.prompt)
         else:
-            prompt_config = PromptConfig.from_model(persona.prompts[0])
+            prompt_config = PromptConfig.from_model(
+                final_msg.prompt or persona.prompts[0]
+            )
Comment on lines +798 to +800
Contributor
logic: Ensure persona.prompts is non-empty to avoid index errors if final_msg.prompt is falsy.
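
A minimal sketch of the guard this comment asks for (hypothetical placement, assuming an empty persona.prompts list is actually reachable here):

# Hypothetical guard: fail with a clear error rather than an IndexError
# when neither the message nor the persona supplies a prompt.
if not final_msg.prompt and not persona.prompts:
    raise ValueError("Persona has no prompts and message has no prompt")
prompt_config = PromptConfig.from_model(final_msg.prompt or persona.prompts[0])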


         answer_style_config = AnswerStyleConfig(
             citation_config=CitationConfig(
13 changes: 13 additions & 0 deletions backend/onyx/chat/prompt_builder/answer_prompt_builder.py
@@ -155,6 +155,19 @@ def get_user_message_content(self) -> str:
         query, _ = message_to_prompt_and_imgs(self.user_message_and_token_cnt[0])
         return query
 
+    def get_message_history(self) -> list[PreviousMessage]:
+        """
+        Get the message history as a list of PreviousMessage objects.
+        """
+        ret = []
+        if self.system_message_and_token_cnt:
+            tmp = PreviousMessage.from_langchain_msg(*self.system_message_and_token_cnt)
+            ret.append(tmp)
+        for i, msg in enumerate(self.message_history):
+            tmp = PreviousMessage.from_langchain_msg(msg, self.history_token_cnts[i])
+            ret.append(tmp)
+        return ret
+
     def build(self) -> list[BaseMessage]:
         if not self.user_message_and_token_cnt:
             raise ValueError("User message must be set before building prompt")
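For orientation, a minimal usage sketch of the new helper (hypothetical call site; assumes the builder already holds a system message and chat history):

# Hypothetical: replay everything accumulated in the builder as
# PreviousMessage objects (system message first, then prior turns).
prev_messages = prompt_builder.get_message_history()
for prev in prev_messages:
    print(prev.message_type, prev.token_count)
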
18 changes: 12 additions & 6 deletions backend/onyx/chat/tool_handling/tool_response_handler.py
@@ -162,6 +162,12 @@ def get_tool_call_for_non_tool_calling_llm_impl(
     prompt_builder: AnswerPromptBuilder | PromptSnapshot,
     llm: LLM,
 ) -> tuple[Tool, dict] | None:
+    user_query = prompt_builder.raw_user_query
+    history = prompt_builder.raw_message_history
+    if isinstance(prompt_builder, AnswerPromptBuilder):
+        user_query = prompt_builder.get_user_message_content()
+        history = prompt_builder.get_message_history()
+
     if force_use_tool.force_use:
         # if we are forcing a tool, we don't need to check which tools to run
         tool = get_tool_by_name(tools, force_use_tool.tool_name)
@@ -170,8 +176,8 @@ def get_tool_call_for_non_tool_calling_llm_impl(
             force_use_tool.args
             if force_use_tool.args is not None
             else tool.get_args_for_non_tool_calling_llm(
-                query=prompt_builder.raw_user_query,
-                history=prompt_builder.raw_message_history,
+                query=user_query,
+                history=history,
                 llm=llm,
                 force_run=True,
             )
@@ -188,8 +194,8 @@
     else:
         tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
             tools=tools,
-            query=prompt_builder.raw_user_query,
-            history=prompt_builder.raw_message_history,
+            query=user_query,
+            history=history,
             llm=llm,
         )
 
@@ -206,8 +212,8 @@
     chosen_tool_and_args = (
         select_single_tool_for_non_tool_calling_llm(
             tools_and_args=available_tools_and_args,
-            history=prompt_builder.raw_message_history,
-            query=prompt_builder.raw_user_query,
+            history=history,
+            query=user_query,
             llm=llm,
         )
         if available_tools_and_args
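The five added lines at the top of this function dispatch on the builder type: a PromptSnapshot carries only the raw query and history, while a live AnswerPromptBuilder can reconstruct full PreviousMessage objects. A hypothetical extraction of that dispatch into a helper (not part of this PR; names taken from the diff):

# Hypothetical helper, assuming PromptSnapshot exposes raw_user_query and
# raw_message_history as used above.
def _query_and_history(
    prompt_builder: AnswerPromptBuilder | PromptSnapshot,
) -> tuple[str, list[PreviousMessage]]:
    if isinstance(prompt_builder, AnswerPromptBuilder):
        return (
            prompt_builder.get_user_message_content(),
            prompt_builder.get_message_history(),
        )
    return prompt_builder.raw_user_query, prompt_builder.raw_message_history
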
20 changes: 20 additions & 0 deletions backend/onyx/llm/models.py
@@ -9,6 +9,7 @@
 from onyx.configs.constants import MessageType
 from onyx.file_store.models import InMemoryChatFile
 from onyx.llm.utils import build_content_with_imgs
+from onyx.llm.utils import message_to_string
 from onyx.tools.models import ToolCallFinalResult
 
 if TYPE_CHECKING:
@@ -59,3 +60,22 @@ def to_langchain_msg(self) -> BaseMessage:
             return AIMessage(content=content)
         else:
             return SystemMessage(content=content)
+
+    @classmethod
+    def from_langchain_msg(
+        cls, msg: BaseMessage, token_count: int
+    ) -> "PreviousMessage":
+        message_type = MessageType.SYSTEM
+        if isinstance(msg, HumanMessage):
+            message_type = MessageType.USER
+        elif isinstance(msg, AIMessage):
+            message_type = MessageType.ASSISTANT
+        message = message_to_string(msg)
+        return cls(
+            message=message,
+            token_count=token_count,
+            message_type=message_type,
+            files=[],
+            tool_call=None,
+            refined_answer_improvement=None,
+        )
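
A quick sanity sketch of the mapping this classmethod performs (assumes the HumanMessage/AIMessage imports already present in this module):

# Hypothetical checks: HumanMessage -> USER, AIMessage -> ASSISTANT,
# anything else (e.g. SystemMessage) -> SYSTEM.
prev = PreviousMessage.from_langchain_msg(HumanMessage(content="hi"), token_count=1)
assert prev.message_type == MessageType.USER
prev = PreviousMessage.from_langchain_msg(AIMessage(content="hello"), token_count=1)
assert prev.message_type == MessageType.ASSISTANT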
7 changes: 6 additions & 1 deletion backend/tests/unit/onyx/chat/test_answer.py
@@ -323,9 +323,14 @@ def test_answer_with_search_no_tool_calling(
         == mock_search_tool.build_next_prompt.return_value.build.return_value
     )
 
+    user_message = (
+        answer_instance.graph_inputs.prompt_builder.get_user_message_content()
+    )
+
+    prev_messages = answer_instance.graph_inputs.prompt_builder.get_message_history()
     # Verify that get_args_for_non_tool_calling_llm was called on the mock_search_tool
     mock_search_tool.get_args_for_non_tool_calling_llm.assert_called_once_with(
-        QUERY, [], answer_instance.graph_config.tooling.primary_llm
+        user_message, prev_messages, answer_instance.graph_config.tooling.primary_llm
     )
 
     # Verify that the search tool's run method was called