@@ -1,4 +1,3 @@
from datetime import datetime
from typing import cast

from langchain_core.messages import HumanMessage
@@ -12,6 +11,7 @@
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.prompts.agents.dc_prompts import DC_FORMATTING_NO_BASE_DATA_PROMPT
@@ -113,42 +113,20 @@ def consolidate_research(
)
]

dispatch_timings: list[float] = []

primary_model = graph_config.tooling.primary_llm

def stream_initial_answer() -> list[str]:
response: list[str] = []
for message in primary_model.stream(msg, timeout_override=30, max_tokens=None):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
start_stream_token = datetime.now()

write_custom_event(
"initial_agent_answer",
AgentAnswerPiece(
answer_piece=content,
level=0,
level_question_num=0,
answer_type="agent_level_answer",
),
writer,
)
end_stream_token = datetime.now()
dispatch_timings.append(
(end_stream_token - start_stream_token).microseconds
)
response.append(content)
return response

try:
_ = run_with_timeout(
60,
stream_initial_answer,
lambda: stream_llm_answer(
llm=graph_config.tooling.primary_llm,
prompt=msg,
event_name="initial_agent_answer",
writer=writer,
agent_answer_level=0,
agent_answer_question_num=0,
agent_answer_type="agent_level_answer",
timeout_override=30,
max_tokens=None,
),
)

except Exception as e:
@@ -30,6 +30,7 @@
from onyx.agents.agent_search.shared_graph_utils.constants import (
LLM_ANSWER_ERROR_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import get_answer_citation_ids
@@ -112,44 +113,23 @@ def generate_sub_answer(
config=fast_llm.config,
)

dispatch_timings: list[float] = []
agent_error: AgentErrorLog | None = None
response: list[str] = []

def stream_sub_answer() -> list[str]:
for message in fast_llm.stream(
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION,
max_tokens=AGENT_MAX_TOKENS_SUBANSWER_GENERATION,
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
start_stream_token = datetime.now()
write_custom_event(
"sub_answers",
AgentAnswerPiece(
answer_piece=content,
level=level,
level_question_num=question_num,
answer_type="agent_sub_answer",
),
writer,
)
end_stream_token = datetime.now()
dispatch_timings.append(
(end_stream_token - start_stream_token).microseconds
)
response.append(content)
return response

try:
response = run_with_timeout(
response, _ = run_with_timeout(
AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION,
stream_sub_answer,
lambda: stream_llm_answer(
llm=fast_llm,
prompt=msg,
event_name="sub_answers",
writer=writer,
agent_answer_level=level,
agent_answer_question_num=question_num,
agent_answer_type="agent_sub_answer",
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION,
max_tokens=AGENT_MAX_TOKENS_SUBANSWER_GENERATION,
),
)

except (LLMTimeoutError, TimeoutError):
@@ -37,6 +37,7 @@
from onyx.agents.agent_search.shared_graph_utils.constants import (
AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
@@ -275,46 +276,24 @@ def generate_initial_answer(

agent_error: AgentErrorLog | None = None

def stream_initial_answer() -> list[str]:
response: list[str] = []
for message in model.stream(
msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
max_tokens=(
AGENT_MAX_TOKENS_ANSWER_GENERATION
if _should_restrict_tokens(model.config)
else None
),
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
start_stream_token = datetime.now()

write_custom_event(
"initial_agent_answer",
AgentAnswerPiece(
answer_piece=content,
level=0,
level_question_num=0,
answer_type="agent_level_answer",
),
writer,
)
end_stream_token = datetime.now()
dispatch_timings.append(
(end_stream_token - start_stream_token).microseconds
)
response.append(content)
return response

try:
streamed_tokens = run_with_timeout(
streamed_tokens, dispatch_timings = run_with_timeout(
AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION,
stream_initial_answer,
lambda: stream_llm_answer(
llm=model,
prompt=msg,
event_name="initial_agent_answer",
writer=writer,
agent_answer_level=0,
agent_answer_question_num=0,
agent_answer_type="agent_level_answer",
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
max_tokens=(
AGENT_MAX_TOKENS_ANSWER_GENERATION
if _should_restrict_tokens(model.config)
else None
),
),
)

except (LLMTimeoutError, TimeoutError):
@@ -40,6 +40,7 @@
from onyx.agents.agent_search.shared_graph_utils.constants import (
AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
@@ -63,7 +64,6 @@
remove_document_citations,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import AGENT_ANSWER_GENERATION_BY_FAST_LLM
@@ -301,45 +301,24 @@ def generate_validate_refined_answer(
dispatch_timings: list[float] = []
agent_error: AgentErrorLog | None = None

def stream_refined_answer() -> list[str]:
for message in model.stream(
msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION,
max_tokens=(
AGENT_MAX_TOKENS_ANSWER_GENERATION
if _should_restrict_tokens(model.config)
else None
),
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)

start_stream_token = datetime.now()
write_custom_event(
"refined_agent_answer",
AgentAnswerPiece(
answer_piece=content,
level=1,
level_question_num=0,
answer_type="agent_level_answer",
),
writer,
)
end_stream_token = datetime.now()
dispatch_timings.append(
(end_stream_token - start_stream_token).microseconds
)
streamed_tokens.append(content)
return streamed_tokens

try:
streamed_tokens = run_with_timeout(
streamed_tokens, dispatch_timings = run_with_timeout(
AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION,
stream_refined_answer,
lambda: stream_llm_answer(
llm=model,
prompt=msg,
event_name="refined_agent_answer",
writer=writer,
agent_answer_level=1,
agent_answer_question_num=0,
agent_answer_type="agent_level_answer",
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION,
max_tokens=(
AGENT_MAX_TOKENS_ANSWER_GENERATION
if _should_restrict_tokens(model.config)
else None
),
),
)

except (LLMTimeoutError, TimeoutError):
backend/onyx/agents/agent_search/shared_graph_utils/llm.py (new file: 68 additions, 0 deletions)
@@ -0,0 +1,68 @@
from datetime import datetime
from typing import Literal

from langchain.schema.language_model import LanguageModelInput
from langgraph.types import StreamWriter

from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.llm.interfaces import LLM


def stream_llm_answer(
llm: LLM,
prompt: LanguageModelInput,
event_name: str,
writer: StreamWriter,
agent_answer_level: int,
agent_answer_question_num: int,
agent_answer_type: Literal["agent_level_answer", "agent_sub_answer"],
timeout_override: int | None = None,
max_tokens: int | None = None,
) -> tuple[list[str], list[float]]:
"""Stream the initial answer from the LLM.

Args:
llm: The LLM to use.
prompt: The prompt to use.
event_name: The name of the custom event to write.
writer: The stream writer to dispatch events to.
agent_answer_level: The level of the agent answer.
agent_answer_question_num: The question number within the level.
agent_answer_type: The type of answer ("agent_level_answer" or "agent_sub_answer").
timeout_override: The LLM timeout to use.
max_tokens: The LLM max tokens to use.

Returns:
A tuple of the streamed response chunks and the per-chunk dispatch timings (in microseconds).
"""
response: list[str] = []
dispatch_timings: list[float] = []

for message in llm.stream(
prompt, timeout_override=timeout_override, max_tokens=max_tokens
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)

start_stream_token = datetime.now()
write_custom_event(
event_name,
AgentAnswerPiece(
answer_piece=content,
level=agent_answer_level,
level_question_num=agent_answer_question_num,
answer_type=agent_answer_type,
),
writer,
)
end_stream_token = datetime.now()

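# Note: .microseconds is the sub-second component of the timedelta, not a
# total; each event dispatch here is expected to complete well under a second.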
dispatch_timings.append((end_stream_token - start_stream_token).microseconds)
response.append(content)

return response, dispatch_timings
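
For reference, a minimal usage sketch of the new helper, modeled on the generate_sub_answer call site above (names like fast_llm, msg, level, question_num, and writer come from that hunk; the 60-second budget is illustrative). run_with_timeout bounds the total wall-clock time of the call, while timeout_override is passed through to the LLM client as the connect timeout:

response, dispatch_timings = run_with_timeout(
    60,
    lambda: stream_llm_answer(
        llm=fast_llm,
        prompt=msg,
        event_name="sub_answers",
        writer=writer,
        agent_answer_level=level,
        agent_answer_question_num=question_num,
        agent_answer_type="agent_sub_answer",
        timeout_override=30,
        max_tokens=None,
    ),
)
answer = "".join(response)  # join streamed chunks into the full answer text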