Skip to content
4 changes: 4 additions & 0 deletions backend/onyx/agents/agent_search/dr/constants.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os

from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.enums import ResearchType

Expand All @@ -12,6 +14,8 @@
0 # how many times the closer can send back to the orchestrator
)

DR_BASIC_SEARCH_MAX_DOCS = int(os.environ.get("DR_BASIC_SEARCH_MAX_DOCS", 15))

CLARIFICATION_REQUEST_PREFIX = "PLEASE CLARIFY:"
HIGH_LEVEL_PLAN_PREFIX = "The Plan:"

Expand Down
187 changes: 135 additions & 52 deletions backend/onyx/agents/agent_search/dr/nodes/dr_a0_clarification.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import re
from datetime import datetime
from typing import Any
from typing import cast

from langchain_core.messages import AIMessage
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_content
from langchain_core.messages import SystemMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from sqlalchemy.orm import Session
Expand All @@ -25,6 +26,7 @@
from onyx.agents.agent_search.dr.process_llm_stream import process_llm_stream
from onyx.agents.agent_search.dr.states import MainState
from onyx.agents.agent_search.dr.states import OrchestrationSetup
from onyx.agents.agent_search.dr.utils import get_chat_history_messages
from onyx.agents.agent_search.dr.utils import get_chat_history_string
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
Expand Down Expand Up @@ -53,9 +55,11 @@
from onyx.llm.utils import get_max_input_tokens
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.prompts.dr_prompts import ANSWER_PROMPT_WO_TOOL_CALLING
from onyx.prompts.dr_prompts import BASE_SYSTEM_MESSAGE_TEMPLATE
from onyx.prompts.dr_prompts import DECISION_PROMPT_W_TOOL_CALLING
from onyx.prompts.dr_prompts import DECISION_PROMPT_WO_TOOL_CALLING
from onyx.prompts.dr_prompts import DEFAULT_DR_SYSTEM_PROMPT
from onyx.prompts.dr_prompts import QUESTION_CONFIRMATION
from onyx.prompts.dr_prompts import REPEAT_PROMPT
from onyx.prompts.dr_prompts import TOOL_DESCRIPTION
from onyx.prompts.prompt_template import PromptTemplate
Expand All @@ -79,13 +83,14 @@

logger = setup_logger()

_ANSWER_COMMENT_PROMPT = "I will now answer your question directly."

def _format_tool_name(tool_name: str) -> str:
"""Convert tool name to LLM-friendly format."""
name = tool_name.replace(" ", "_")
# take care of camel case like GetAPIKey -> GET_API_KEY for LLM readability
name = re.sub(r"(?<=[a-z0-9])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])", "_", name)
return name.upper()
_CONSIDER_TOOLS_PROMPT = "I will now concier the tools and sub-agents that are available to answer your question."


def _is_kg_tool_available(available_tools: dict[str, OrchestratorTool]) -> bool:
"""Check if the Knowledge Graph tool is available in the provided tools."""
return DRPath.KNOWLEDGE_GRAPH.value in available_tools


def _get_available_tools(
Expand Down Expand Up @@ -193,18 +198,42 @@ def _get_available_tools(
return available_tools


def _construct_uploaded_text_context(files: list[InMemoryChatFile]) -> str:
"""Construct the uploaded context from the files."""
file_contents = []
for file in files:
def _construct_uploaded_text_context(
files: list[InMemoryChatFile], max_chars_per_file: int = 8000
) -> str:
"""Construct the uploaded context from the files with better formatting."""
if not files:
return ""

file_sections = []
for i, file in enumerate(files, 1):
if file.file_type in (
ChatFileType.DOC,
ChatFileType.PLAIN_TEXT,
ChatFileType.CSV,
):
file_contents.append(file.content.decode("utf-8"))
if len(file_contents) > 0:
return "Uploaded context:\n\n\n" + "\n\n".join(file_contents)
file_type_name = {
ChatFileType.DOC: "Document",
ChatFileType.PLAIN_TEXT: "Text File",
ChatFileType.CSV: "CSV File",
}.get(file.file_type, "File")

file_name = getattr(file, "file_name", f"file_{i}")
content = file.content.decode("utf-8").strip()

# Truncate if too long
if len(content) > max_chars_per_file:
content = (
content[:max_chars_per_file]
+ f"\n\n[Content truncated - showing first {max_chars_per_file} characters of {len(content)} total]"
)

# Add file header with metadata
file_section = f"=== {file_type_name}: {file_name} ===\n\n{content}"
file_sections.append(file_section)

if file_sections:
return "Uploaded Files:\n\n" + "\n\n---\n\n".join(file_sections)
return ""


Expand Down Expand Up @@ -384,7 +413,8 @@ def clarifier(
)

kg_config = get_kg_config_settings()
if kg_config.KG_ENABLED and kg_config.KG_EXPOSED:
kg_tool_used = _is_kg_tool_available(available_tools)
if kg_config.KG_ENABLED and kg_config.KG_EXPOSED and kg_tool_used:
all_entity_types = get_entity_types_str(active=True)
all_relationship_types = get_relationship_types_str(active=True)
else:
Expand Down Expand Up @@ -421,12 +451,20 @@ def clarifier(
assistant_system_prompt = PromptTemplate(DEFAULT_DR_SYSTEM_PROMPT).build()
assistant_task_prompt = ""

chat_history_string = (
get_chat_history_string(
graph_config.inputs.prompt_builder.message_history,
MAX_CHAT_HISTORY_MESSAGES,
)
or "(No chat history yet available)"
# chat_history_string = (
# get_chat_history_string(
# graph_config.inputs.prompt_builder.message_history,
# MAX_CHAT_HISTORY_MESSAGES,
# )
# or "(No chat history yet available)"
# )

chat_history_messages = get_chat_history_messages(
graph_config.inputs.prompt_builder.raw_message_history,
MAX_CHAT_HISTORY_MESSAGES,
max_tokens=int(
0.7 * max_input_tokens
), # limit chat history to 70% of max input tokens
)

uploaded_text_context = (
Expand All @@ -435,6 +473,8 @@ def clarifier(
else ""
)

# File content will be integrated into the user message instead of separate messages

uploaded_context_tokens = check_number_of_tokens(
uploaded_text_context, llm_tokenizer.encode
)
Expand All @@ -449,25 +489,68 @@ def clarifier(
graph_config.inputs.files
)

message_history_for_continuation: list[SystemMessage | HumanMessage | AIMessage] = (
[]
)

base_system_message = BASE_SYSTEM_MESSAGE_TEMPLATE.build(
assistant_system_prompt=assistant_system_prompt,
active_source_type_descriptions_str=active_source_type_descriptions_str,
entity_types_string=all_entity_types,
relationship_types_string=all_relationship_types,
available_tool_descriptions_str=available_tool_descriptions_str,
)

message_history_for_continuation.append(SystemMessage(content=base_system_message))
message_history_for_continuation.extend(chat_history_messages)

# Create message content that includes text, files, and any available images
user_message_text = original_question
if uploaded_text_context:
# Count the number of files for better messaging
files: list[InMemoryChatFile] = graph_config.inputs.files or []
file_count = len(
[
f
for f in files
if f.file_type
in (ChatFileType.DOC, ChatFileType.PLAIN_TEXT, ChatFileType.CSV)
]
)
file_word = "file" if file_count == 1 else "files"
user_message_text += f"\n\n[I have uploaded {file_count} {file_word} for reference]\n\n{uploaded_text_context}"

message_content: list[dict[str, Any]] = [
{"type": "text", "text": user_message_text}
]
if uploaded_image_context:
message_content.extend(uploaded_image_context)

# If we only have text, use string content for backwards compatibility
if len(message_content) == 1 and not uploaded_text_context:
message_history_for_continuation.append(HumanMessage(content=original_question))
else:
message_history_for_continuation.append(
HumanMessage(content=cast(list[str | dict[Any, Any]], message_content))
)
message_history_for_continuation.append(AIMessage(content=QUESTION_CONFIRMATION))

if not (force_use_tool and force_use_tool.force_use):

if assistant_task_prompt:
reminder = """REMINDER:\n\n""" + assistant_task_prompt
else:
reminder = ""

if not use_tool_calling_llm or len(available_tools) == 1:
if len(available_tools) > 1:
decision_prompt = DECISION_PROMPT_WO_TOOL_CALLING.build(
question=original_question,
chat_history_string=chat_history_string,
uploaded_context=uploaded_text_context or "",
active_source_type_descriptions_str=active_source_type_descriptions_str,
available_tool_descriptions_str=available_tool_descriptions_str,
message_history_for_continuation.append(
HumanMessage(content=DECISION_PROMPT_WO_TOOL_CALLING)
)

llm_decision = invoke_llm_json(
llm=graph_config.tooling.primary_llm,
prompt=create_question_prompt(
assistant_system_prompt,
decision_prompt,
uploaded_image_context=uploaded_image_context,
),
prompt=message_history_for_continuation,
schema=DecisionResponse,
)
else:
Expand All @@ -486,22 +569,22 @@ def clarifier(
)

answer_prompt = ANSWER_PROMPT_WO_TOOL_CALLING.build(
question=original_question,
chat_history_string=chat_history_string,
uploaded_context=uploaded_text_context or "",
active_source_type_descriptions_str=active_source_type_descriptions_str,
available_tool_descriptions_str=available_tool_descriptions_str,
reminder=reminder,
)

message_history_for_continuation.append(
AIMessage(content=_ANSWER_COMMENT_PROMPT)
)

message_history_for_continuation.append(
HumanMessage(content=answer_prompt)
)

answer_tokens, _, _ = run_with_timeout(
TF_DR_TIMEOUT_LONG,
lambda: stream_llm_answer(
llm=graph_config.tooling.primary_llm,
prompt=create_question_prompt(
assistant_system_prompt,
answer_prompt + assistant_task_prompt,
uploaded_image_context=uploaded_image_context,
),
prompt=message_history_for_continuation,
event_name="basic_response",
writer=writer,
answer_piece=StreamingType.MESSAGE_DELTA.value,
Expand Down Expand Up @@ -556,19 +639,14 @@ def clarifier(

else:

decision_prompt = DECISION_PROMPT_W_TOOL_CALLING.build(
question=original_question,
chat_history_string=chat_history_string,
uploaded_context=uploaded_text_context or "",
active_source_type_descriptions_str=active_source_type_descriptions_str,
decision_prompt = DECISION_PROMPT_W_TOOL_CALLING.build(reminder=reminder)

message_history_for_continuation.append(
HumanMessage(content=decision_prompt)
)

stream = graph_config.tooling.primary_llm.stream(
prompt=create_question_prompt(
assistant_system_prompt,
decision_prompt + assistant_task_prompt,
uploaded_image_context=uploaded_image_context,
),
prompt=message_history_for_continuation,
tools=([_ARTIFICIAL_ALL_ENCOMPASSING_TOOL]),
tool_choice=(None),
structured_response_format=graph_config.inputs.structured_response_format,
Expand Down Expand Up @@ -758,6 +836,8 @@ def clarifier(
else:
next_tool = DRPath.ORCHESTRATOR.value

message_history_for_continuation.append(AIMessage(content=_CONSIDER_TOOLS_PROMPT))

return OrchestrationSetup(
original_question=original_question,
chat_history_string=chat_history_string,
Expand All @@ -780,4 +860,7 @@ def clarifier(
assistant_task_prompt=assistant_task_prompt,
uploaded_test_context=uploaded_text_context,
uploaded_image_context=uploaded_image_context,
all_entity_types=all_entity_types,
all_relationship_types=all_relationship_types,
orchestration_llm_messages=message_history_for_continuation,
)
Loading
Loading