Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,14 @@ def search_objects(
"""

graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.search_request.query
question = graph_config.inputs.prompt_builder.raw_user_query
search_tool = graph_config.tooling.search_tool

if search_tool is None or graph_config.inputs.search_request.persona is None:
if search_tool is None or graph_config.inputs.persona is None:
raise ValueError("Search tool and persona must be provided for DivCon search")

try:
instructions = graph_config.inputs.search_request.persona.prompts[
0
].system_prompt
instructions = graph_config.inputs.persona.prompts[0].system_prompt

agent_1_instructions = extract_section(
instructions, "Agent Step 1:", "Agent Step 2:"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,15 @@ def research_object_source(
datetime.now()

graph_config = cast(GraphConfig, config["metadata"]["config"])
graph_config.inputs.search_request.query
search_tool = graph_config.tooling.search_tool
question = graph_config.inputs.search_request.query
question = graph_config.inputs.prompt_builder.raw_user_query
object, document_source = state.object_source_combination

if search_tool is None or graph_config.inputs.search_request.persona is None:
if search_tool is None or graph_config.inputs.persona is None:
raise ValueError("Search tool and persona must be provided for DivCon search")

try:
instructions = graph_config.inputs.search_request.persona.prompts[
0
].system_prompt
instructions = graph_config.inputs.persona.prompts[0].system_prompt

agent_2_instructions = extract_section(
instructions, "Agent Step 2:", "Agent Step 3:"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
from collections import defaultdict
from datetime import datetime
from typing import cast
from typing import Dict
from typing import List

Expand All @@ -11,7 +9,6 @@
from onyx.agents.agent_search.dc_search_analysis.states import (
ObjectResearchInformationUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.utils.logger import setup_logger
Expand All @@ -25,10 +22,6 @@ def structure_research_by_object(
"""
LangGraph node to start the agentic search process.
"""
datetime.now()

graph_config = cast(GraphConfig, config["metadata"]["config"])
graph_config.inputs.search_request.query

write_custom_event(
"initial_agent_answer",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,13 @@ def consolidate_object_research(
LangGraph node to start the agentic search process.
"""
graph_config = cast(GraphConfig, config["metadata"]["config"])
graph_config.inputs.search_request.query
search_tool = graph_config.tooling.search_tool
question = graph_config.inputs.search_request.query
question = graph_config.inputs.prompt_builder.raw_user_query

if search_tool is None or graph_config.inputs.search_request.persona is None:
if search_tool is None or graph_config.inputs.persona is None:
raise ValueError("Search tool and persona must be provided for DivCon search")

instructions = graph_config.inputs.search_request.persona.prompts[0].system_prompt
instructions = graph_config.inputs.persona.prompts[0].system_prompt

agent_4_instructions = extract_section(
instructions, "Agent Step 4:", "Agent Step 5:"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,8 @@ def consolidate_research(
"""
LangGraph node to start the agentic search process.
"""
datetime.now()

graph_config = cast(GraphConfig, config["metadata"]["config"])
graph_config.inputs.search_request.query

search_tool = graph_config.tooling.search_tool

Expand All @@ -46,11 +44,11 @@ def consolidate_research(
writer,
)

if search_tool is None or graph_config.inputs.search_request.persona is None:
if search_tool is None or graph_config.inputs.persona is None:
raise ValueError("Search tool and persona must be provided for DivCon search")

# Populate prompt
instructions = graph_config.inputs.search_request.persona.prompts[0].system_prompt
instructions = graph_config.inputs.persona.prompts[0].system_prompt

try:
agent_5_instructions = extract_section(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def generate_sub_answer(
context_docs = dedup_sort_inference_section_list(context_docs)

persona_contextualized_prompt = get_persona_agent_prompt_expressions(
graph_config.inputs.search_request.persona
graph_config.inputs.persona
).contextualized_prompt

if len(context_docs) == 0:
Expand All @@ -106,7 +106,7 @@ def generate_sub_answer(
fast_llm = graph_config.tooling.fast_llm
msg = build_sub_question_answer_prompt(
question=question,
original_question=graph_config.inputs.search_request.query,
original_question=graph_config.inputs.prompt_builder.raw_user_query,
docs=context_docs,
persona_specification=persona_contextualized_prompt,
config=fast_llm.config,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def generate_initial_answer(
node_start_time = datetime.now()

graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.search_request.query
question = graph_config.inputs.prompt_builder.raw_user_query
prompt_enrichment_components = get_prompt_enrichment_components(graph_config)

# get all documents cited in sub-questions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def decompose_orig_question(
node_start_time = datetime.now()

graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.search_request.query
question = graph_config.inputs.prompt_builder.raw_user_query
perform_initial_search_decomposition = (
graph_config.behavior.perform_initial_search_decomposition
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def format_orig_question_search_input(
logger.debug("generate_raw_search_data")
graph_config = cast(GraphConfig, config["metadata"]["config"])
return ExpandedRetrievalInput(
question=graph_config.inputs.search_request.query,
question=graph_config.inputs.prompt_builder.raw_user_query,
base_search=True,
sub_question_id=None, # This graph is always and only used for the original question
log_messages=[],
Expand Down
20 changes: 10 additions & 10 deletions backend/onyx/agents/agent_search/deep_search/main/edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,18 @@ def route_initial_tool_choice(
LangGraph edge to route to agent search.
"""
agent_config = cast(GraphConfig, config["metadata"]["config"])
if state.tool_choice is not None:
if (
agent_config.behavior.use_agentic_search
and agent_config.tooling.search_tool is not None
and state.tool_choice.tool.name == agent_config.tooling.search_tool.name
):
return "start_agent_search"
else:
return "call_tool"
else:
if state.tool_choice is None:
return "logging_node"

if (
agent_config.behavior.use_agentic_search
and agent_config.tooling.search_tool is not None
and state.tool_choice.tool.name == agent_config.tooling.search_tool.name
):
return "start_agent_search"
else:
return "call_tool"


def parallelize_initial_sub_question_answering(
state: MainState,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
test_mode = False


def main_graph_builder(test_mode: bool = False) -> StateGraph:
def agent_search_graph_builder() -> StateGraph:
"""
LangGraph graph builder for the main agent search process.
"""
Expand All @@ -76,7 +76,7 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:

# Choose the initial tool
graph.add_node(
node="initial_tool_choice",
node="choose_tool",
action=choose_tool,
)

Expand Down Expand Up @@ -162,11 +162,11 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:

graph.add_edge(
start_key="prepare_tool_input",
end_key="initial_tool_choice",
end_key="choose_tool",
)

graph.add_conditional_edges(
"initial_tool_choice",
"choose_tool",
route_initial_tool_choice,
["call_tool", "start_agent_search", "logging_node"],
)
Expand Down Expand Up @@ -242,7 +242,7 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
from onyx.llm.factory import get_default_llms
from onyx.context.search.models import SearchRequest

graph = main_graph_builder()
graph = agent_search_graph_builder()
compiled_graph = graph.compile()
primary_llm, fast_llm = get_default_llms()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def compare_answers(
node_start_time = datetime.now()

graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.search_request.query
question = graph_config.inputs.prompt_builder.raw_user_query
initial_answer = state.initial_answer
refined_answer = state.refined_answer

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def create_refined_sub_questions(
ToolCallKickoff(
tool_name="agent_search_1",
tool_args={
"query": graph_config.inputs.search_request.query,
"query": graph_config.inputs.prompt_builder.raw_user_query,
"answer": state.initial_answer,
},
),
Expand All @@ -96,7 +96,7 @@ def create_refined_sub_questions(

agent_refined_start_time = datetime.now()

question = graph_config.inputs.search_request.query
question = graph_config.inputs.prompt_builder.raw_user_query
base_answer = state.initial_answer
history = build_history_prompt(graph_config, question)
# get the entity term extraction dict and properly format it
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def extract_entities_terms(
)

# first four lines duplicates from generate_initial_answer
question = graph_config.inputs.search_request.query
question = graph_config.inputs.prompt_builder.raw_user_query
initial_search_docs = state.exploratory_search_results[:NUM_EXPLORATORY_DOCS]

# start with the entity/term/extraction
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def generate_validate_refined_answer(
node_start_time = datetime.now()

graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.search_request.query
question = graph_config.inputs.prompt_builder.raw_user_query
prompt_enrichment_components = get_prompt_enrichment_components(graph_config)

persona_contextualized_prompt = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ def persist_agent_results(state: MainState, config: RunnableConfig) -> MainOutpu

persona_id = None
graph_config = cast(GraphConfig, config["metadata"]["config"])
if graph_config.inputs.search_request.persona:
persona_id = graph_config.inputs.search_request.persona.id
if graph_config.inputs.persona:
persona_id = graph_config.inputs.persona.id

user_id = None
assert (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def start_agent_search(
node_start_time = datetime.now()

graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.search_request.query
question = graph_config.inputs.prompt_builder.raw_user_query

history = build_history_prompt(graph_config, question)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ def parallel_retrieval_edge(
"""
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = (
state.question if state.question else graph_config.inputs.search_request.query
state.question
if state.question
else graph_config.inputs.prompt_builder.raw_user_query
)

query_expansions = state.expanded_queries + [question]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,16 @@ def rerank_documents(

graph_config = cast(GraphConfig, config["metadata"]["config"])
question = (
state.question if state.question else graph_config.inputs.search_request.query
state.question
if state.question
else graph_config.inputs.prompt_builder.raw_user_query
)
assert (
graph_config.tooling.search_tool
), "search_tool must be provided for agentic search"

# Note that these are passed in values from the API and are overrides which are typically None
rerank_settings = graph_config.inputs.search_request.rerank_settings
rerank_settings = graph_config.inputs.rerank_settings
allow_agent_reranking = graph_config.behavior.allow_agent_reranking

if rerank_settings is None:
Expand Down Expand Up @@ -95,7 +97,7 @@ def rerank_documents(

return DocRerankingUpdate(
reranked_documents=[
doc for doc in reranked_documents if type(doc) == InferenceSection
doc for doc in reranked_documents if isinstance(doc, InferenceSection)
][:AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS],
sub_question_retrieval_stats=fit_scores,
log_messages=[
Expand Down
6 changes: 4 additions & 2 deletions backend/onyx/agents/agent_search/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from sqlalchemy.orm import Session

from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.context.search.models import SearchRequest
from onyx.context.search.models import RerankingDetails
from onyx.db.models import Persona
from onyx.file_store.utils import InMemoryChatFile
from onyx.llm.interfaces import LLM
from onyx.tools.force import ForceUseTool
Expand All @@ -16,7 +17,8 @@
class GraphInputs(BaseModel):
"""Input data required for the graph execution"""

search_request: SearchRequest
persona: Persona | None = None
rerank_settings: RerankingDetails | None = None
prompt_builder: AnswerPromptBuilder
files: list[InMemoryChatFile] | None = None
structured_response_format: dict | None = None
Expand Down
Loading