diff --git a/backend/onyx/chat/answer_cli.py b/backend/onyx/chat/answer_cli.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/backend/onyx/chat/answer_scratchpad.py b/backend/onyx/chat/answer_scratchpad.py new file mode 100644 index 00000000000..9cc64cfb227 --- /dev/null +++ b/backend/onyx/chat/answer_scratchpad.py @@ -0,0 +1,720 @@ +from __future__ import annotations + +import asyncio +import contextvars +import json +import os +import queue +import threading +from collections.abc import Generator +from collections.abc import Iterator +from dataclasses import dataclass +from queue import Queue +from typing import Any +from typing import cast +from typing import Dict +from typing import List +from typing import Optional + +import litellm +from agents import Agent +from agents import AgentHooks +from agents import function_tool +from agents import ModelSettings +from agents import RunContextWrapper +from agents import Runner +from agents.extensions.handoff_prompt import prompt_with_handoff_instructions +from agents.extensions.handoff_prompt import RECOMMENDED_PROMPT_PREFIX +from agents.extensions.models.litellm_model import LitellmModel +from agents.handoffs import HandoffInputData +from agents.stream_events import RawResponsesStreamEvent +from agents.stream_events import RunItemStreamEvent +from braintrust import traced +from openai.types import Reasoning +from pydantic import BaseModel + +from onyx.agents.agent_search.dr.constants import MAX_CHAT_HISTORY_MESSAGES +from onyx.agents.agent_search.dr.dr_prompt_builder import ( + get_dr_prompt_orchestration_templates, +) +from onyx.agents.agent_search.dr.enums import ResearchType +from onyx.agents.agent_search.dr.models import DRPromptPurpose +from onyx.agents.agent_search.dr.sub_agents.web_search.clients.exa_client import ( + ExaClient, +) +from onyx.agents.agent_search.dr.utils import get_chat_history_string +from onyx.agents.agent_search.models import GraphConfig +from onyx.db.engine.sql_engine import get_session_with_current_tenant +from onyx.llm.interfaces import ( + LLM, +) +from onyx.tools.models import SearchToolOverrideKwargs +from onyx.tools.tool_implementations.search.search_tool import ( + SEARCH_RESPONSE_SUMMARY_ID, +) +from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary +from onyx.tools.tool_implementations.search.search_tool import SearchTool +from onyx.utils.logger import setup_logger + +logger = setup_logger() + + +@dataclass +class RunDependencies: + emitter: Emitter + llm: LLM + search_tool: SearchTool | None = None + + +@dataclass +class MyContext: + """Context class to hold search tool and other dependencies""" + + run_dependencies: RunDependencies | None = None + needs_compaction: bool = False + + +def short_tag(link: str, i: int) -> str: + # Stable, readable; index keeps it deterministic across a batch + return f"S{i+1}" + + +@function_tool +def web_search(query: str) -> str: + """Search the web for information. 
This tool provides urls and short snippets, + but does not fetch the full content of the urls.""" + exa_client = ExaClient() + hits = exa_client.search(query) + results = [] + for i, r in enumerate(hits): + results.append( + { + "tag": short_tag(r.link, i), # <-- add a tag + "title": r.title, + "link": r.link, + "snippet": r.snippet, + "author": r.author, + "published_date": ( + r.published_date.isoformat() if r.published_date else None + ), + } + ) + return json.dumps({"results": results}) + + +@function_tool +def web_fetch(urls: List[str]) -> str: + """Fetch the full contents of a list of URLs.""" + exa_client = ExaClient() + docs = exa_client.contents(urls) + out = [] + for i, d in enumerate(docs): + out.append( + { + "tag": short_tag(d.link, i), # <-- add a tag + "title": d.title, + "link": d.link, + "full_content": d.full_content, + "published_date": ( + d.published_date.isoformat() if d.published_date else None + ), + } + ) + return json.dumps({"results": out}) + + +@traced(name="llm_completion", type="llm") +def llm_completion( + model_name: str, + temperature: float, + messages: List[Dict[str, Any]], + stream: bool = False, +) -> litellm.ModelResponse: + return litellm.responses( + model=model_name, + input=messages, + tools=[], + stream=stream, + reasoning=litellm.Reasoning(effort="medium", summary="detailed"), + ) + + +@function_tool +def internal_search(context_wrapper: RunContextWrapper[MyContext], query: str) -> str: + """Search internal company vector database for information. Sources + include: + - Fireflies (internal company call transcripts) + - Google Drive (internal company documents) + - Gmail (internal company emails) + - Linear (internal company issues) + - Slack (internal company messages) + """ + context_wrapper.context.run_dependencies.emitter.emit( + kind="tool-progress", data={"progress": "Searching internal database"} + ) + search_tool = context_wrapper.context.run_dependencies.search_tool + if search_tool is None: + raise RuntimeError("Search tool not available in context") + + with get_session_with_current_tenant() as search_db_session: + for tool_response in search_tool.run( + query=query, + override_kwargs=SearchToolOverrideKwargs( + force_no_rerank=True, + alternate_db_session=search_db_session, + skip_query_analysis=True, + original_query=query, + ), + ): + # get retrieved docs to send to the rest of the graph + if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID: + response = cast(SearchResponseSummary, tool_response.response) + retrieved_docs = response.top_sections + + break + return retrieved_docs + + +def _convert_to_packet_obj(packet: Dict[str, Any]) -> Any | None: + """Convert a packet dictionary to PacketObj when possible. 
+ + Args: + packet: Dictionary containing packet data + + Returns: + PacketObj instance if conversion is possible, None otherwise + """ + if not isinstance(packet, dict) or "type" not in packet: + return None + + packet_type = packet.get("type") + if not packet_type: + return None + + try: + # Import here to avoid circular imports + from onyx.server.query_and_chat.streaming_models import ( + MessageStart, + MessageDelta, + OverallStop, + ) + + if packet_type == "response.output_item.added": + return MessageStart( + type="message_start", + content="", + final_documents=None, + ) + elif packet_type == "response.output_text.delta": + return MessageDelta(type="message_delta", content=packet["delta"]) + elif packet_type == "response.completed": + return OverallStop(type="stop") + + except Exception as e: + # Log the error but don't fail the entire process + logger.debug(f"Failed to convert packet to PacketObj: {e}") + + return None + + +# stream_bus.py +@dataclass +class StreamPacket: + kind: str # "agent" | "tool-progress" | "done" + payload: Dict[str, Any] = None + + +class Emitter: + """Use this inside tools to emit arbitrary UI progress.""" + + def __init__(self, bus: Queue): + self.bus = bus + + def emit(self, kind: str, data: Dict[str, Any]) -> None: + self.bus.put(StreamPacket(kind=kind, payload=data)) + + +# If we want durable execution in the future, we can replace this with a temporal call +def start_run_in_thread( + agent: Agent, + messages: List[Dict[str, Any]], + cfg: GraphConfig, + llm: LLM, + emitter: Emitter, + search_tool: SearchTool | None = None, +) -> threading.Thread: + def worker(): + async def amain(): + ctx = MyContext( + run_dependencies=RunDependencies( + search_tool=search_tool, + emitter=emitter, + llm=llm, + ) + ) + # 1) start the streamed run (async) + streamed = Runner.run_streamed(agent, messages, context=ctx) + + # 2) forward the agent’s async event stream + async for ev in streamed.stream_events(): + if isinstance(ev, RunItemStreamEvent): + pass + elif isinstance(ev, RawResponsesStreamEvent): + emitter.emit(kind="agent", data=ev.data.model_dump()) + + emitter.emit(kind="done", data={"ok": True}) + + # run the async main inside this thread + asyncio.run(amain()) + + t = threading.Thread(target=worker, daemon=True) + t.start() + return t + + +class ResearchScratchpad(BaseModel): + notes: List[dict] = [] + + +scratchpad = ResearchScratchpad() + + +@function_tool +def add_note(note: str, source_url: str | None = None): + """Store a factual note you want to cite later.""" + scratchpad.notes.append({"note": note, "source_url": source_url}) + return {"ok": True, "count": len(scratchpad.notes)} + + +@function_tool +def finalize_report(): + """Signal you're done researching. Return a structured, citation-rich report.""" + # The model should *compose* the report as the tool *result*, using notes in scratchpad. + # Some teams have the model return the full report as this tool's return value + # so the UI can detect completion cleanly. 
+ return { + "status": "ready_to_render", + "notes_index": scratchpad.notes, # the model can read these to assemble citations + } + + +class CompactionHooks(AgentHooks[Any]): + async def on_llm_start( + self, + context: RunContextWrapper[MyContext], + agent: Agent[Any], + system_prompt: Optional[str], + input_items: List[dict], + ) -> None: + print(f"[{agent.name}] LLM start") + print("system_prompt:", system_prompt) + print("usage so far:", context.usage.total_tokens) + usage = context.usage.total_tokens + if usage > 10000: + context.context.needs_compaction = True + + +def compaction_input_filter(input_data: HandoffInputData): + filtered_messages = [] + for msg in input_data.input_history[:-1]: + if isinstance(msg, dict) and msg.get("content") is not None: + # Convert tool messages to user messages to avoid API errors + if msg.get("role") == "tool": + filtered_msg = { + "role": "user", + "content": f"Tool response: {msg.get('content', '')}", + } + filtered_messages.append(filtered_msg) + else: + filtered_messages.append(msg) + + # Only proceed with compaction if we have valid messages + if filtered_messages: + return [filtered_messages[-1]] + + +def construct_deep_research_agent(llm: LLM) -> Agent: + litellm_model = LitellmModel( + # If you have access, prefer OpenAI’s deep research-capable models: + # "o3-deep-research" or "o4-mini-deep-research" + # otherwise keep your current model and lean on the prompt + tools + model=llm.config.model_name, + api_key=llm.config.api_key, + ) + + DR_INSTRUCTIONS = f""" + {RECOMMENDED_PROMPT_PREFIX} +You are a deep-research agent. Work in explicit iterations: +1) PLAN: Decompose the user’s query into sub-questions and a step-by-step plan. +2) SEARCH: Use web_search to explore multiple angles, fanning out and searching in parallel. +3) FETCH: Use web_fetch for any promising URLs to extract specifics and quotes. +4) NOTE: After each useful find, call add_note(note, source_url) to save key facts. +5) REVISE: If evidence contradicts earlier assumptions, update your plan and continue. +6) FINALIZE: When confident, call finalize_report(). Your final answer must include: + - Clear, structured conclusions + - A short “How I searched” summary + - Inline citations to sources (with URLs) + - A bullet list of limitations/open questions +Guidelines: +- Prefer breadth-first exploration before deep dives. +- Compare sources and dates; prioritize recency for time-sensitive topics. +- Minimize redundancy by skimming before fetching. +- Think out loud in a compact way, but keep reasoning crisp. +- If context exceeds 10000 tokens, handoff to the compactor agent. 
+""" + return Agent( + name="Researcher", + instructions=DR_INSTRUCTIONS, + model=litellm_model, + tools=[web_search, web_fetch, add_note, finalize_report, internal_search], + model_settings=ModelSettings( + temperature=llm.config.temperature, + include_usage=True, + parallel_tool_calls=True, + # optional: let model choose tools freely + # tool_choice="auto", # if supported by your LitellmModel wrapper + ), + hooks=CompactionHooks(), + ) + + +def unified_event_stream( + messages: List[Dict[str, Any]], + cfg: GraphConfig, + llm: LLM, + emitter: Emitter, + search_tool: SearchTool | None = None, +) -> Generator[Dict[str, Any], None, None]: + bus: Queue = Queue() + emitter = Emitter(bus) + current_context = contextvars.copy_context() + t = threading.Thread( + target=current_context.run, + args=( + # thread_worker_dr_turn, + thread_worker_simple_turn, + messages, + cfg, + llm, + emitter, + search_tool, + ), # eval_context=None for now + daemon=True, + ) + t.start() + done = False + while not done: + pkt: StreamPacket = emitter.bus.get() + if pkt.kind == "done": + done = True + else: + # Convert packet to PacketObj when possible + packet_obj = _convert_to_packet_obj(pkt.payload) + if packet_obj: + # Convert PacketObj back to dict for compatibility + yield packet_obj.model_dump() + else: + # Fallback to original payload + yield pkt.payload + + +# This should be close to the API +def stream_chat_sync( + messages: List[Dict[str, Any]], + cfg: GraphConfig, + llm: LLM, + search_tool: SearchTool | None = None, +) -> Generator[Dict[str, Any], None, None]: + bus: Queue = Queue() + emitter = Emitter(bus) + return unified_event_stream( + messages=messages, + cfg=cfg, + llm=llm, + emitter=emitter, + search_tool=search_tool, + ) + + +def construct_simple_agent( + llm: LLM, +) -> Agent: + litellm_model = LitellmModel( + model="o3-mini", + api_key=llm.config.api_key, + ) + return Agent( + name="Assistant", + instructions=""" + You are a helpful assistant that can search the web, fetch content from URLs, + and search internal databases. Please do some reasoning and then return your answer. + """, + model=litellm_model, + tools=[web_search, web_fetch, internal_search], + model_settings=ModelSettings( + temperature=0.0, + include_usage=True, # Track usage metrics + reasoning=Reasoning( + effort="medium", summary="detailed", generate_summary="detailed" + ), + verbose=True, + ), + ) + + +def thread_worker_dr_turn(messages, cfg, llm, emitter, search_tool): + """ + Worker function for deep research turn that runs in a separate thread. 
+ + Args: + messages: List of messages for the conversation + cfg: Graph configuration + llm: Language model instance + emitter: Event emitter for streaming responses + search_tool: Search tool instance (optional) + eval_context: Evaluation context to be propagated to the worker thread + """ + try: + dr_turn(messages, cfg, llm, emitter, search_tool) + except Exception as e: + logger.error(f"Error in dr_turn: {e}", exc_info=e, stack_info=True) + emitter.emit(kind="done", data={"ok": False}) + + +def thread_worker_simple_turn(messages, cfg, llm, emitter, search_tool): + try: + simple_turn( + messages=messages, + cfg=cfg, + llm=llm, + turn_event_stream_emitter=emitter, + search_tool=search_tool, + ) + except Exception as e: + logger.error(f"Error in simple_turn: {e}", exc_info=e, stack_info=True) + emitter.emit(kind="done", data={"ok": False}) + + +SENTINEL = object() + + +class StreamBridge: + """ + Spins up an asyncio loop in a background thread, starts Runner.run_streamed there, + consumes its async event stream, and exposes a blocking .events() iterator. + """ + + def __init__(self, agent, messages, ctx, max_turns: int = 100): + self.agent = agent + self.messages = messages + self.ctx = ctx + self.max_turns = max_turns + + self._q: "queue.Queue[object]" = queue.Queue() + self._loop: Optional[asyncio.AbstractEventLoop] = None + self._thread: Optional[threading.Thread] = None + self._streamed = None + + def start(self): + def worker(): + async def run_and_consume(): + # Create the streamed run *inside* the loop thread + self._streamed = Runner.run_streamed( + self.agent, + self.messages, + context=self.ctx, + max_turns=self.max_turns, + ) + try: + async for ev in self._streamed.stream_events(): + self._q.put(ev) + finally: + self._q.put(SENTINEL) + + # Each thread needs its own loop + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) + try: + self._loop.run_until_complete(run_and_consume()) + finally: + self._loop.close() + + self._thread = threading.Thread(target=worker, daemon=True) + self._thread.start() + return self + + def events(self) -> Iterator[object]: + while True: + ev = self._q.get() + if ev is SENTINEL: + break + yield ev + + def cancel(self): + # Post a cancellation to the loop thread safely + if self._loop and self._streamed: + + def _do_cancel(): + try: + self._streamed.cancel() + except Exception: + pass + + self._loop.call_soon_threadsafe(_do_cancel) + + +def simple_turn( + messages: List[Dict[str, Any]], + cfg: GraphConfig, + llm: LLM, + turn_event_stream_emitter: Emitter, + search_tool: SearchTool | None = None, +) -> None: + llm_response = llm_completion( + model_name="gpt-5-mini", + temperature=0.0, + messages=messages, + stream=True, + ) + llm_response.json() + simple_agent = construct_simple_agent(llm) + ctx = MyContext( + run_dependencies=RunDependencies( + search_tool=search_tool, emitter=turn_event_stream_emitter, llm=llm + ) + ) + bridge = StreamBridge(simple_agent, messages, ctx, max_turns=100).start() + for ev in bridge.events(): + if isinstance(ev, RunItemStreamEvent): + print("RUN ITEM STREAM EVENT!") + if ev.name == "reasoning_item_created": + print("REASONING!") + turn_event_stream_emitter.emit( + kind="reasoning", data=ev.item.raw_item.model_dump() + ) + elif isinstance(ev, RawResponsesStreamEvent): + print("RAW RESPONSES STREAM EVENT!") + print(ev.type) + turn_event_stream_emitter.emit(kind="agent", data=ev.data.model_dump()) + turn_event_stream_emitter.emit(kind="done", data={"ok": True}) + + +def dr_turn( + messages: 
List[Dict[str, Any]], + cfg: GraphConfig, + llm: LLM, + turn_event_stream_emitter: Emitter, # TurnEventStream is the primary output of the turn + search_tool: SearchTool | None = None, +) -> None: + """ + Execute a deep research turn with evaluation context support. + + Args: + messages: List of messages for the conversation + cfg: Graph configuration + llm: Language model instance + turn_event_stream_emitter: Event emitter for streaming responses + search_tool: Search tool instance (optional) + eval_context: Evaluation context for the turn (optional) + """ + clarification = get_clarification( + messages, cfg, llm, turn_event_stream_emitter, search_tool + ) + output = json.loads(clarification.choices[0].message.content) + clarification_output = ClarificationOutput(**output) + if clarification_output.clarification_needed: + turn_event_stream_emitter.emit( + kind="agent", data=clarification_output.clarification_question + ) + turn_event_stream_emitter.emit(kind="done", data={"ok": True}) + return + dr_agent = construct_deep_research_agent(llm) + ctx = MyContext( + run_dependencies=RunDependencies( + search_tool=search_tool, + emitter=turn_event_stream_emitter, + llm=llm, + ) + ) + bridge = StreamBridge(dr_agent, messages, ctx, max_turns=100).start() + for ev in bridge.events(): + if isinstance(ev, RunItemStreamEvent): + pass + elif isinstance(ev, RawResponsesStreamEvent): + turn_event_stream_emitter.emit(kind="agent", data=ev.data.model_dump()) + + turn_event_stream_emitter.emit(kind="done", data={"ok": True}) + + +class ClarificationOutput(BaseModel): + clarification_question: str + clarification_needed: bool + + +def get_clarification( + messages: List[Dict[str, Any]], + cfg: GraphConfig, + llm: LLM, + emitter: Emitter, + search_tool: SearchTool | None = None, +) -> litellm.ModelResponse: + chat_history_string = ( + get_chat_history_string( + cfg.inputs.prompt_builder.message_history, + MAX_CHAT_HISTORY_MESSAGES, + ) + or "(No chat history yet available)" + ) + base_clarification_prompt = get_dr_prompt_orchestration_templates( + DRPromptPurpose.CLARIFICATION, + research_type=ResearchType.DEEP, + entity_types_string=None, + relationship_types_string=None, + available_tools={}, + ) + clarification_prompt = base_clarification_prompt.build( + question=messages[-1]["content"], + chat_history_string=chat_history_string, + ) + clarifier_prompt = prompt_with_handoff_instructions(clarification_prompt) + llm_response = llm_completion( + model_name=llm.config.model_name, + temperature=llm.config.temperature, + messages=[{"role": "user", "content": clarifier_prompt}], + stream=False, + ) + return llm_response + + +if __name__ == "__main__": + messages = [ + { + "role": "user", + "content": """ + Let $N$ denote the number of ordered triples of positive integers $(a, b, c)$ such that $a, b, c + \\leq 3^6$ and $a^3 + b^3 + c^3$ is a multiple of $3^7$. Find the remainder when $N$ is divided by $1000$. + """, + } + ] + # OpenAI reasoning is not supported yet due to: https://github.com/BerriAI/litellm/pull/14117 + reasoning_agent = Agent( + name="Reasoning", + instructions="You are a reasoning agent. 
You are given a question and you need to reason about it.", + model=LitellmModel( + model="gpt-5-mini", + api_key=os.getenv("OPENAI_API_KEY"), + ), + tools=[], + model_settings=ModelSettings( + temperature=0.0, + reasoning=Reasoning(effort="medium", summary="detailed"), + ), + ) + llm_response = llm_completion( + model_name="gpt-5-mini", + temperature=0.0, + messages=messages, + stream=False, + ) + x = llm_response.json() + print(x) diff --git a/backend/onyx/chat/process_message.py b/backend/onyx/chat/process_message.py index 69f85a9b1eb..32e95e7d2b0 100644 --- a/backend/onyx/chat/process_message.py +++ b/backend/onyx/chat/process_message.py @@ -3,13 +3,16 @@ import traceback from collections.abc import Callable from collections.abc import Iterator +from typing import Any from typing import cast +from typing import Dict from typing import Protocol from sqlalchemy.orm import Session from onyx.agents.agent_search.orchestration.nodes.call_tool import ToolCallException from onyx.chat.answer import Answer +from onyx.chat.answer_scratchpad import stream_chat_sync from onyx.chat.chat_utils import create_chat_chain from onyx.chat.chat_utils import create_temporary_persona from onyx.chat.chat_utils import process_kg_commands @@ -24,9 +27,6 @@ from onyx.chat.models import QADocsResponse from onyx.chat.models import StreamingError from onyx.chat.models import UserKnowledgeFilePacket -from onyx.chat.packet_proccessing.process_streamed_packets import ( - process_streamed_packets, -) from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message @@ -41,7 +41,6 @@ from onyx.context.search.enums import OptionalSearchSetting from onyx.context.search.models import InferenceSection from onyx.context.search.models import RetrievalDetails -from onyx.context.search.models import SavedSearchDoc from onyx.context.search.retrieval.search_runner import ( inference_sections_from_ids, ) @@ -76,11 +75,7 @@ from onyx.llm.utils import litellm_exception_to_error_msg from onyx.natural_language_processing.utils import get_tokenizer from onyx.server.query_and_chat.models import CreateChatMessageRequest -from onyx.server.query_and_chat.streaming_models import CitationDelta from onyx.server.query_and_chat.streaming_models import CitationInfo -from onyx.server.query_and_chat.streaming_models import MessageDelta -from onyx.server.query_and_chat.streaming_models import MessageStart -from onyx.server.query_and_chat.streaming_models import Packet from onyx.server.utils import get_json_line from onyx.tools.force import ForceUseTool from onyx.tools.models import SearchToolOverrideKwargs @@ -671,10 +666,49 @@ def stream_chat_message_objects( use_agentic_search=new_msg_req.use_agentic_search, skip_gen_ai_answer_generation=new_msg_req.skip_gen_ai_answer_generation, ) - - # Process streamed packets using the new packet processing module - yield from process_streamed_packets( - answer_processed_output=answer.processed_streamed_output, + type_to_role = { + "human": "user", + "assistant": "assistant", + "system": "system", + "function": "function", + } + SYSTEM_PROMPT = """ + You are a highly capable, thoughtful, and precise assistant. 
Your goal is to deeply understand the \ + user's intent, ask clarifying questions when needed, think step-by-step through complex problems, \ + provide clear and accurate answers, and proactively anticipate helpful follow-up information. Always \ + prioritize being truthful, nuanced, insightful, and efficient. + The current date is September 18, 2025. + + You use different text styles, bolding, emojis (sparingly), block quotes, and other formatting to make \ + your responses more readable and engaging. + You use proper Markdown and LaTeX to format your responses for math, scientific, and chemical formulas, \ + symbols, etc.: '$$\\n[expression]\\n$$' for standalone cases and '\\( [expression] \\)' when inline. + For code you prefer to use Markdown and specify the language. + You can use Markdown horizontal rules (---) to separate sections of your responses. + You can use Markdown tables to format your responses for data, lists, and other structured information. + + You must cite inline using tags from tool results. + + Rules: + - Only cite sources provided by the tools (use each item’s "tag" field). + - Place the citation immediately after the claim it supports, like this: "... result [S1](https://linkforS1)" or + "... results [S1](https://linkforS1)[S3](https://linkforS3)". + - If multiple sentences in a row are supported by the same source, cite the first sentence; + then omit repeats until the source changes. + - Never invent tags. If no source supports a claim, say so. + - Do not add a separate “Sources” section unless asked. + """ + system_message = [{"role": "system", "content": SYSTEM_PROMPT}] + other_messages = [ + {"role": type_to_role[message.type], "content": message.content} + for message in answer.graph_inputs.prompt_builder.build() + if message.type != "system" + ] + yield from stream_chat_sync( + messages=system_message + other_messages, + cfg=answer.graph_config, + llm=answer.graph_tooling.primary_llm, + search_tool=answer.graph_tooling.search_tool, ) except ValueError as e: @@ -736,7 +770,15 @@ def stream_chat_message( document_retrieval_latency = time.time() - start_time logger.debug(f"First doc time: {document_retrieval_latency}") - yield get_json_line(obj.model_dump()) + # Convert Pydantic models to dictionaries for JSON serialization + if hasattr(obj, "model_dump"): + obj_dict = obj.model_dump() + elif hasattr(obj, "dict"): + obj_dict = obj.dict() + else: + obj_dict = obj + + yield get_json_line(obj_dict) def remove_answer_citations(answer: str) -> str: @@ -745,48 +787,98 @@ def remove_answer_citations(answer: str) -> str: return re.sub(pattern, "", answer) +def _convert_to_packet_obj(packet: Dict[str, Any]) -> Any | None: + """Convert a packet dictionary to PacketObj when possible. 
+ + Args: + packet: Dictionary containing packet data + + Returns: + PacketObj instance if conversion is possible, None otherwise + """ + if not isinstance(packet, dict) or "type" not in packet: + return None + + packet_type = packet.get("type") + if not packet_type: + return None + + try: + # Import here to avoid circular imports + from onyx.server.query_and_chat.streaming_models import ( + MessageStart, + MessageDelta, + OverallStop, + SectionEnd, + SearchToolStart, + SearchToolDelta, + ImageGenerationToolStart, + ImageGenerationToolDelta, + ImageGenerationToolHeartbeat, + CustomToolStart, + CustomToolDelta, + ReasoningStart, + ReasoningDelta, + CitationStart, + CitationDelta, + ) + + # Map packet types to their corresponding classes + type_mapping = { + "message_start": MessageStart, + "message_delta": MessageDelta, + "stop": OverallStop, + "section_end": SectionEnd, + "internal_search_tool_start": SearchToolStart, + "internal_search_tool_delta": SearchToolDelta, + "image_generation_tool_start": ImageGenerationToolStart, + "image_generation_tool_delta": ImageGenerationToolDelta, + "image_generation_tool_heartbeat": ImageGenerationToolHeartbeat, + "custom_tool_start": CustomToolStart, + "custom_tool_delta": CustomToolDelta, + "reasoning_start": ReasoningStart, + "reasoning_delta": ReasoningDelta, + "citation_start": CitationStart, + "citation_delta": CitationDelta, + } + + packet_class = type_mapping.get(packet_type) + if packet_class: + # Create instance using the packet data, filtering out None values + filtered_data = {k: v for k, v in packet.items() if v is not None} + return packet_class(**filtered_data) + + except Exception as e: + # Log the error but don't fail the entire process + logger.debug(f"Failed to convert packet to PacketObj: {e}") + + return None + + @log_function_time() def gather_stream( - packets: AnswerStream, + packets: Iterator[Dict[str, Any]], ) -> ChatBasicResponse: answer = "" - citations: list[CitationInfo] = [] - error_msg: str | None = None - message_id: int | None = None - top_documents: list[SavedSearchDoc] = [] - for packet in packets: - if isinstance(packet, Packet): - # Handle the different packet object types - if isinstance(packet.obj, MessageStart): - # MessageStart contains the initial content and final documents - if packet.obj.content: - answer += packet.obj.content - if packet.obj.final_documents: - top_documents = packet.obj.final_documents - elif isinstance(packet.obj, MessageDelta): - # MessageDelta contains incremental content updates - if packet.obj.content: - answer += packet.obj.content - elif isinstance(packet.obj, CitationDelta): - # CitationDelta contains citation information - if packet.obj.citations: - citations.extend(packet.obj.citations) - elif isinstance(packet, StreamingError): - error_msg = packet.error - elif isinstance(packet, MessageResponseIDInfo): - message_id = packet.reserved_assistant_message_id - - if message_id is None: - raise ValueError("Message ID is required") + if packet != {"type": "event"}: + print(packet) + + # Convert packet to PacketObj when possible + packet_obj = _convert_to_packet_obj(packet) + if packet_obj: + # Handle PacketObj types that contain text content + if hasattr(packet_obj, "content") and packet_obj.content: + answer += packet_obj.content + elif "text" in packet: + # Fallback for legacy packet format + answer += packet["text"] return ChatBasicResponse( answer=answer, answer_citationless=remove_answer_citations(answer), - cited_documents={ - citation.citation_num: citation.document_id for 
citation in citations - }, - message_id=message_id, - error_msg=error_msg, - top_documents=top_documents, + cited_documents={}, + message_id=0, + error_msg=None, + top_documents=[], ) diff --git a/backend/onyx/evals/demo_agent.py b/backend/onyx/evals/demo_agent.py new file mode 100644 index 00000000000..ba8220cd6bf --- /dev/null +++ b/backend/onyx/evals/demo_agent.py @@ -0,0 +1,79 @@ +import asyncio +import os + +from agents import ModelSettings +from agents import run_demo_loop +from agents.agent import Agent +from agents.extensions.handoff_prompt import prompt_with_handoff_instructions +from agents.extensions.models.litellm_model import LitellmModel +from pydantic import BaseModel + +from onyx.agents.agent_search.dr.dr_prompt_builder import ( + get_dr_prompt_orchestration_templates, +) +from onyx.agents.agent_search.dr.enums import ResearchType +from onyx.agents.agent_search.dr.models import DRPromptPurpose + + +def construct_simple_agent() -> Agent: + litellm_model = LitellmModel( + model="gpt-4.1", + api_key=os.getenv("OPENAI_API_KEY"), + ) + return Agent( + name="Assistant", + instructions=""" + You are a helpful assistant that can search the web, fetch content from URLs, + and search internal databases. + """, + model=litellm_model, + tools=[], + model_settings=ModelSettings( + temperature=0.0, + include_usage=True, # Track usage metrics + ), + ) + + +class ClarificationOutput(BaseModel): + clarification_question: str + clarification_needed: bool + + +def construct_dr_agent() -> Agent: + simple_agent = construct_simple_agent() + litellm_model = LitellmModel( + model="gpt-4.1", + api_key=os.getenv("OPENAI_API_KEY"), + ) + base_clarification_prompt = get_dr_prompt_orchestration_templates( + DRPromptPurpose.CLARIFICATION, + research_type=ResearchType.DEEP, + entity_types_string=None, + relationship_types_string=None, + available_tools={}, + ) + clarification_prompt = base_clarification_prompt.build( + question="", + chat_history_string="", + ) + clarifier_prompt = prompt_with_handoff_instructions(clarification_prompt) + clarifier_agent = Agent( + name="Clarifier", + instructions=clarifier_prompt, + model=litellm_model, + tools=[], + output_type=ClarificationOutput, + handoffs=[simple_agent], + model_settings=ModelSettings(tool_choice="required"), + ) + return clarifier_agent + + +async def main() -> None: + agent = construct_dr_agent() + await run_demo_loop(agent) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/backend/onyx/evals/eval.py b/backend/onyx/evals/eval.py index 310af7b0ea1..a032c9e6827 100644 --- a/backend/onyx/evals/eval.py +++ b/backend/onyx/evals/eval.py @@ -100,6 +100,7 @@ def run_eval( data: list[dict[str, dict[str, str]]] | None = None, remote_dataset_name: str | None = None, provider: EvalProvider = get_default_provider(), + no_send_logs: bool = False, ) -> EvalationAck: if data is not None and remote_dataset_name is not None: raise ValueError("Cannot specify both data and remote_dataset_name") @@ -112,4 +113,5 @@ def run_eval( configuration=configuration, data=data, remote_dataset_name=remote_dataset_name, + no_send_logs=no_send_logs, ) diff --git a/backend/onyx/evals/eval_cli.py b/backend/onyx/evals/eval_cli.py index de4b87dd1b5..26eeb4fd06f 100644 --- a/backend/onyx/evals/eval_cli.py +++ b/backend/onyx/evals/eval_cli.py @@ -42,6 +42,7 @@ def run_local( local_data_path: str | None, remote_dataset_name: str | None, search_permissions_email: str | None = None, + no_send_logs: bool = False, ) -> EvalationAck: """ Run evaluation with local 
configurations.
@@ -67,7 +68,9 @@ def run_local(
 
     if remote_dataset_name:
         score = run_eval(
-            configuration=configuration, remote_dataset_name=remote_dataset_name
+            configuration=configuration,
+            remote_dataset_name=remote_dataset_name,
+            no_send_logs=no_send_logs,
         )
     else:
         if local_data_path is None:
@@ -75,7 +78,9 @@ def run_local(
                 "local_data_path or remote_dataset_name is required for local evaluation"
             )
         data = load_data_local(local_data_path)
-        score = run_eval(configuration=configuration, data=data)
+        score = run_eval(
+            configuration=configuration, data=data, no_send_logs=no_send_logs
+        )
 
     return score
 
@@ -172,6 +177,13 @@ def main() -> None:
         help="Email address to impersonate for the evaluation",
     )
 
+    parser.add_argument(
+        "--no-send-logs",
+        action="store_true",
+        help="Do not send logs to the remote server",
+        default=False,
+    )
+
     args = parser.parse_args()
 
     if args.local_data_path:
@@ -215,6 +227,7 @@ def main() -> None:
         local_data_path=args.local_data_path,
         remote_dataset_name=args.remote_dataset_name,
         search_permissions_email=args.search_permissions_email,
+        no_send_logs=args.no_send_logs,
     )
diff --git a/backend/onyx/evals/models.py b/backend/onyx/evals/models.py
index 91800ce68d6..b4b2f648ac2 100644
--- a/backend/onyx/evals/models.py
+++ b/backend/onyx/evals/models.py
@@ -78,5 +78,6 @@ def eval(
         configuration: EvalConfigurationOptions,
         data: list[dict[str, dict[str, str]]] | None = None,
         remote_dataset_name: str | None = None,
+        no_send_logs: bool = False,
     ) -> EvalationAck:
         pass
diff --git a/backend/onyx/evals/one_off/create_braintrust_dataset.py b/backend/onyx/evals/one_off/create_braintrust_dataset.py
index 9739ee67c21..9da5f2647b4 100644
--- a/backend/onyx/evals/one_off/create_braintrust_dataset.py
+++ b/backend/onyx/evals/one_off/create_braintrust_dataset.py
@@ -109,8 +109,7 @@ def parse_csv_file(csv_path: str) -> List[Dict[str, Any]]:
         records.extend(
             [
                 {
-                    "question": question
-                    + ". All info is contained in the quesiton. DO NOT ask any clarifying questions.",
+                    "question": question,
                     "research_type": "DEEP",
                     "categories": categories,
                     "expected_depth": expected_depth,
diff --git a/backend/onyx/evals/providers/braintrust.py b/backend/onyx/evals/providers/braintrust.py
index 18a325e9521..8ae64834031 100644
--- a/backend/onyx/evals/providers/braintrust.py
+++ b/backend/onyx/evals/providers/braintrust.py
@@ -1,5 +1,6 @@
 from collections.abc import Callable
 
+from autoevals.llm import LLMClassifier
 from braintrust import Eval
 from braintrust import EvalCase
 from braintrust import init_dataset
@@ -11,6 +12,33 @@
 from onyx.evals.models import EvalProvider
 
 
+quality_classifier = LLMClassifier(
+    name="quality",
+    prompt_template="""
+    You are a customer doing a trial of the product Onyx. Onyx provides a UI for users to chat with an LLM
+    and search for information, similar to ChatGPT. You think ChatGPT's answer quality is great, and
+    you want to rate Onyx's response relative to ChatGPT's response.\n
+    [Question]: {{input}}\n
+    [ChatGPT Answer]: {{expected}}\n
+    [Onyx Answer]: {{output}}\n
+
+    Please rate the quality of the Onyx answer relative to the ChatGPT answer on a scale of A to E:
+    A: The Onyx answer is great and is as good or better than the ChatGPT answer.
+    B: The Onyx answer is good and comparable to the ChatGPT answer.
+    C: The Onyx answer is fair.
+    D: The Onyx answer is poor and is worse than the ChatGPT answer.
+    E: The Onyx answer is terrible and is much worse than the ChatGPT answer.
+ """, + choice_scores={ + "A": 1, + "B": 0.75, + "C": 0.5, + "D": 0.25, + "E": 0, + }, +) + + class BraintrustEvalProvider(EvalProvider): def eval( self, @@ -18,6 +46,7 @@ def eval( configuration: EvalConfigurationOptions, data: list[dict[str, dict[str, str]]] | None = None, remote_dataset_name: str | None = None, + no_send_logs: bool = False, ) -> EvalationAck: if data is not None and remote_dataset_name is not None: raise ValueError("Cannot specify both data and remote_dataset_name") @@ -35,6 +64,7 @@ def eval( scores=[], metadata={**configuration.model_dump()}, max_concurrency=BRAINTRUST_MAX_CONCURRENCY, + no_send_logs=no_send_logs, ) else: if data is None: @@ -51,5 +81,6 @@ def eval( scores=[], metadata={**configuration.model_dump()}, max_concurrency=BRAINTRUST_MAX_CONCURRENCY, + no_send_logs=no_send_logs, ) return EvalationAck(success=True) diff --git a/backend/onyx/evals/tracing.py b/backend/onyx/evals/tracing.py index 1df631e6a2f..2383ae3f9f3 100644 --- a/backend/onyx/evals/tracing.py +++ b/backend/onyx/evals/tracing.py @@ -1,13 +1,17 @@ +import os from typing import Any import braintrust +from agents import set_trace_processors +from braintrust import init_logger +from braintrust.wrappers.openai import BraintrustTracingProcessor from braintrust_langchain import set_global_handler from braintrust_langchain.callbacks import BraintrustCallbackHandler from onyx.configs.app_configs import BRAINTRUST_API_KEY from onyx.configs.app_configs import BRAINTRUST_PROJECT -MASKING_LENGTH = 20000 +MASKING_LENGTH = int(os.environ.get("BRAINTRUST_MASKING_LENGTH", "20000")) def _truncate_str(s: str) -> str: @@ -33,3 +37,4 @@ def setup_braintrust() -> None: braintrust.set_masking_function(_mask) handler = BraintrustCallbackHandler() set_global_handler(handler) + set_trace_processors([BraintrustTracingProcessor(init_logger(BRAINTRUST_PROJECT))]) diff --git a/backend/onyx/llm/interfaces.py b/backend/onyx/llm/interfaces.py index 787c0315ea6..03bfafdeabe 100644 --- a/backend/onyx/llm/interfaces.py +++ b/backend/onyx/llm/interfaces.py @@ -4,7 +4,6 @@ from braintrust import traced from langchain.schema.language_model import LanguageModelInput -from langchain_core.messages import AIMessageChunk from langchain_core.messages import BaseMessage from pydantic import BaseModel @@ -34,29 +33,30 @@ class LLMConfig(BaseModel): def log_prompt(prompt: LanguageModelInput) -> None: - if isinstance(prompt, list): - for ind, msg in enumerate(prompt): - if isinstance(msg, AIMessageChunk): - if msg.content: - log_msg = msg.content - elif msg.tool_call_chunks: - log_msg = "Tool Calls: " + str( - [ - { - key: value - for key, value in tool_call.items() - if key != "index" - } - for tool_call in msg.tool_call_chunks - ] - ) - else: - log_msg = "" - logger.debug(f"Message {ind}:\n{log_msg}") - else: - logger.debug(f"Message {ind}:\n{msg.content}") - if isinstance(prompt, str): - logger.debug(f"Prompt:\n{prompt}") + # if isinstance(prompt, list): + # for ind, msg in enumerate(prompt): + # if isinstance(msg, AIMessageChunk): + # if msg.content: + # log_msg = msg.content + # elif msg.tool_call_chunks: + # log_msg = "Tool Calls: " + str( + # [ + # { + # key: value + # for key, value in tool_call.items() + # if key != "index" + # } + # for tool_call in msg.tool_call_chunks + # ] + # ) + # else: + # log_msg = "" + # logger.debug(f"Message {ind}:\n{log_msg}") + # else: + # logger.debug(f"Message {ind}:\n{msg.content}") + # if isinstance(prompt, str): + # logger.debug(f"Prompt:\n{prompt}") + pass class LLM(abc.ABC): diff --git 
a/backend/onyx/prompts/dr_prompts.py b/backend/onyx/prompts/dr_prompts.py index 49d547e5776..1070090cafb 100644 --- a/backend/onyx/prompts/dr_prompts.py +++ b/backend/onyx/prompts/dr_prompts.py @@ -1160,7 +1160,7 @@ GET_CLARIFICATION_PROMPT = PromptTemplate( - f"""\ + """\ You are great at asking clarifying questions in case \ a base question is not as clear enough. Your task is to ask necessary clarification \ questions to the user, before the question is sent to the deep research agent. @@ -1183,17 +1183,6 @@ The tools and the entity and relationship types in the knowledge graph are simply provided \ as context for determining whether the question requires clarification. -Here is the question the user asked: -{SEPARATOR_LINE} ----question--- -{SEPARATOR_LINE} - -Here is the previous chat history (if any), which may contain relevant information \ -to answer the question: -{SEPARATOR_LINE} ----chat_history_string--- -{SEPARATOR_LINE} - NOTES: - you have to reason over this purely based on your intrinsic knowledge. - if clarifications are required, fill in 'true' for the "feedback_needed" field and \ diff --git a/backend/onyx/server/query_and_chat/chat_backend.py b/backend/onyx/server/query_and_chat/chat_backend.py index d2c430ddd1d..4c37afaca49 100644 --- a/backend/onyx/server/query_and_chat/chat_backend.py +++ b/backend/onyx/server/query_and_chat/chat_backend.py @@ -21,7 +21,6 @@ from onyx.auth.users import current_chat_accessible_user from onyx.auth.users import current_user -from onyx.chat.chat_utils import create_chat_chain from onyx.chat.chat_utils import extract_headers from onyx.chat.process_message import stream_chat_message from onyx.chat.prompt_builder.citations_prompt import ( @@ -63,13 +62,8 @@ from onyx.file_processing.extract_file_text import docx_to_txt_filename from onyx.file_store.file_store import get_default_file_store from onyx.file_store.models import FileDescriptor -from onyx.llm.exceptions import GenAIDisabledException -from onyx.llm.factory import get_default_llms from onyx.llm.factory import get_llms_for_persona from onyx.natural_language_processing.utils import get_tokenizer -from onyx.secondary_llm_flows.chat_session_naming import ( - get_renamed_conversation_name, -) from onyx.server.documents.models import ConnectorBase from onyx.server.documents.models import CredentialBase from onyx.server.query_and_chat.chat_utils import mime_type_to_chat_file_type @@ -305,45 +299,44 @@ def rename_chat_session( user: User | None = Depends(current_user), db_session: Session = Depends(get_session), ) -> RenameChatSessionResponse: - name = rename_req.name - chat_session_id = rename_req.chat_session_id - user_id = user.id if user is not None else None - - if name: - update_chat_session( - db_session=db_session, - user_id=user_id, - chat_session_id=chat_session_id, - description=name, - ) - return RenameChatSessionResponse(new_name=name) - - final_msg, history_msgs = create_chat_chain( - chat_session_id=chat_session_id, db_session=db_session - ) - full_history = history_msgs + [final_msg] - - try: - llm, _ = get_default_llms( - additional_headers=extract_headers( - request.headers, LITELLM_PASS_THROUGH_HEADERS - ) - ) - except GenAIDisabledException: - # This may be longer than what the LLM tends to produce but is the most - # clear thing we can do - return RenameChatSessionResponse(new_name=full_history[0].message) - - new_name = get_renamed_conversation_name(full_history=full_history, llm=llm) - - update_chat_session( - db_session=db_session, - user_id=user_id, - 
chat_session_id=chat_session_id,
-        description=new_name,
-    )
-
-    return RenameChatSessionResponse(new_name=new_name)
+    # name = rename_req.name
+    # chat_session_id = rename_req.chat_session_id
+    # user_id = user.id if user is not None else None
+
+    # if name:
+    #     update_chat_session(
+    #         db_session=db_session,
+    #         user_id=user_id,
+    #         chat_session_id=chat_session_id,
+    #         description=name,
+    #     )
+    #     return RenameChatSessionResponse(new_name=name)
+
+    # final_msg, history_msgs = create_chat_chain(
+    #     chat_session_id=chat_session_id, db_session=db_session
+    # )
+    # full_history = history_msgs + [final_msg]
+
+    # try:
+    #     llm, _ = get_default_llms(
+    #         additional_headers=extract_headers(
+    #             request.headers, LITELLM_PASS_THROUGH_HEADERS
+    #         )
+    #     )
+    # except GenAIDisabledException:
+    #     # This may be longer than what the LLM tends to produce but is the most
+    #     # clear thing we can do
+    #     return RenameChatSessionResponse(new_name=full_history[0].message)
+
+    # new_name = get_renamed_conversation_name(full_history=full_history, llm=llm)
+
+    # update_chat_session(
+    #     db_session=db_session,
+    #     user_id=user_id,
+    #     chat_session_id=chat_session_id,
+    #     description=new_name,
+    # )
+    return RenameChatSessionResponse(new_name="hi")
 
 
 @router.patch("/chat-session/{session_id}")
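
Reviewer note (not part of the patch): a minimal consumer sketch for the new streaming entry point. It assumes the packet shapes produced by _convert_to_packet_obj in this diff ("message_start" and "message_delta" carrying incremental text in "content", and "stop" to terminate); the helper name collect_answer is hypothetical and only illustrates how the generator returned by stream_chat_sync could be drained outside of gather_stream.

from typing import Any, Dict, Iterator


def collect_answer(packets: Iterator[Dict[str, Any]]) -> str:
    """Accumulate streamed message text into a single answer string (sketch)."""
    parts: list[str] = []
    for packet in packets:
        packet_type = packet.get("type")
        if packet_type in ("message_start", "message_delta"):
            # Both packet types carry incremental text in "content".
            content = packet.get("content")
            if content:
                parts.append(content)
        elif packet_type == "stop":
            # OverallStop marks the end of the turn.
            break
        # Tool-progress and raw model events are ignored in this sketch.
    return "".join(parts)


# Hypothetical wiring, mirroring the call site in stream_chat_message_objects above:
# packets = stream_chat_sync(
#     messages=system_message + other_messages,
#     cfg=answer.graph_config,
#     llm=answer.graph_tooling.primary_llm,
#     search_tool=answer.graph_tooling.search_tool,
# )
# answer_text = collect_answer(packets)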