Skip to content

Commit 74ed3c1

Browse files
committed
.
1 parent d59c85a commit 74ed3c1

File tree

1 file changed

+87
-33
lines changed

1 file changed

+87
-33
lines changed

backend/onyx/chat/answer_scratchpad.py

Lines changed: 87 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,17 @@
22

33
import asyncio
44
import json
5+
import queue
56
import threading
67
from collections.abc import Generator
8+
from collections.abc import Iterator
79
from dataclasses import dataclass
810
from queue import Queue
911
from typing import Any
1012
from typing import cast
1113
from typing import Dict
1214
from typing import List
15+
from typing import Optional
1316

1417
import litellm
1518
from agents import Agent
@@ -69,7 +72,6 @@ def short_tag(link: str, i: int) -> str:
6972

7073

7174
@function_tool
72-
@traced(name="web_search")
7375
def web_search(query: str) -> str:
7476
"""Search the web for information. This tool provides urls and short snippets,
7577
but does not fetch the full content of the urls."""
@@ -93,7 +95,6 @@ def web_search(query: str) -> str:
9395

9496

9597
@function_tool
96-
@traced(name="web_fetch")
9798
def web_fetch(urls: List[str]) -> str:
9899
"""Fetch the full contents of a list of URLs."""
99100
exa_client = ExaClient()
@@ -114,18 +115,6 @@ def web_fetch(urls: List[str]) -> str:
114115
return json.dumps({"results": out})
115116

116117

117-
@function_tool
118-
@traced(name="reasoning")
119-
def reasoning() -> str:
120-
"""Use this tool for reasoning. Powerful for complex questions and
121-
tasks, or questions that require multiple steps to answer."""
122-
# Note: This is a simplified version. In the full implementation,
123-
# we would need to pass the context through the agent's context system
124-
return (
125-
"Reasoning tool - this would need to be implemented with proper context access"
126-
)
127-
128-
129118
@traced(name="llm_completion", type="llm")
130119
def llm_completion(
131120
model_name: str,
@@ -143,7 +132,6 @@ def llm_completion(
143132

144133

145134
@function_tool
146-
@traced(name="internal_search")
147135
def internal_search(context_wrapper: RunContextWrapper[MyContext], query: str) -> str:
148136
"""Search internal company vector database for information. Sources
149137
include:
@@ -283,15 +271,13 @@ class ResearchScratchpad(BaseModel):
283271

284272

285273
@function_tool
def add_note(note: str, source_url: str | None = None):
    """Store a factual note you want to cite later."""
    # Record the note on the shared scratchpad (presumably the module-level
    # ResearchScratchpad instance — confirm against surrounding file).
    scratchpad.notes.append({"note": note, "source_url": source_url})
    # Ack with the running note count so the calling agent can track progress.
    return {"ok": True, "count": len(scratchpad.notes)}
291278

292279

293280
@function_tool
294-
@traced(name="finalize_report")
295281
def finalize_report():
296282
"""Signal you're done researching. Return a structured, citation-rich report."""
297283
# The model should *compose* the report as the tool *result*, using notes in scratchpad.
@@ -330,7 +316,6 @@ def construct_deep_research_agent(llm: LLM) -> Agent:
330316
- Minimize redundancy by skimming before fetching.
331317
- Think out loud in a compact way, but keep reasoning crisp.
332318
"""
333-
334319
return Agent(
335320
name="Researcher",
336321
instructions=DR_INSTRUCTIONS,
@@ -420,7 +405,7 @@ def construct_simple_agent(
420405
and search internal databases.
421406
""",
422407
model=litellm_model,
423-
tools=[web_search, web_fetch, reasoning, internal_search],
408+
tools=[web_search, web_fetch, internal_search],
424409
model_settings=ModelSettings(
425410
temperature=llm.config.temperature,
426411
include_usage=True, # Track usage metrics
@@ -430,45 +415,114 @@ def construct_simple_agent(
430415

431416
def thread_worker_dr_turn(messages, cfg, llm, emitter, search_tool):
    """Thread entry point: run a single deep-research turn.

    Any exception escaping dr_turn is logged with a full traceback and
    reported to the consumer as a failed 'done' event, so the stream is
    always terminated even on error.
    """
    try:
        dr_turn(messages, cfg, llm, emitter, search_tool)
    except Exception as e:
        logger.error(f"Error in dr_turn: {e}", exc_info=e, stack_info=True)
        emitter.emit(kind="done", data={"ok": False})
437422

438423

439-
async def dr_turn(
424+
# Unique end-of-stream marker; identity-compared, so it can never collide
# with a real event object.
SENTINEL = object()


class StreamBridge:
    """
    Spins up an asyncio loop in a background thread, starts Runner.run_streamed
    there, consumes its async event stream into a thread-safe queue, and
    exposes a blocking .events() iterator to the calling thread.

    If the streamed run raises, the exception is captured and re-raised from
    .events() in the consuming thread instead of dying silently in the
    daemon worker.
    """

    def __init__(self, agent, messages, ctx, max_turns: int = 100):
        self.agent = agent
        self.messages = messages
        self.ctx = ctx
        self.max_turns = max_turns

        self._q: "queue.Queue[object]" = queue.Queue()
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        self._thread: Optional[threading.Thread] = None
        self._streamed = None
        # Exception raised in the worker thread, if any; re-raised in events().
        self._exc: Optional[BaseException] = None

    def start(self) -> "StreamBridge":
        """Spawn the worker thread and return self for chaining."""

        def worker():
            async def run_and_consume():
                # Create the streamed run *inside* the loop thread so its
                # internals bind to this loop.
                self._streamed = Runner.run_streamed(
                    self.agent,
                    self.messages,
                    context=self.ctx,
                    max_turns=self.max_turns,
                )
                async for ev in self._streamed.stream_events():
                    self._q.put(ev)

            # Each thread needs its own event loop.
            self._loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self._loop)
            try:
                self._loop.run_until_complete(run_and_consume())
            except BaseException as e:
                # Surface the failure to the consuming thread; a silent death
                # here would otherwise leave events() blocked forever.
                self._exc = e
            finally:
                # Always unblock the consumer — even if Runner.run_streamed
                # itself raised before a single event was produced (the
                # original only guarded the async-for, which could deadlock
                # the consumer on an early failure).
                self._q.put(SENTINEL)
                self._loop.close()

        self._thread = threading.Thread(target=worker, daemon=True)
        self._thread.start()
        return self

    def events(self) -> Iterator[object]:
        """Block on the queue, yielding events until the stream finishes.

        Re-raises any exception that occurred in the worker thread after the
        stream is drained.
        """
        while True:
            ev = self._q.get()
            if ev is SENTINEL:
                break
            yield ev
        if self._exc is not None:
            raise self._exc

    def cancel(self):
        """Request cancellation of the streamed run, thread-safely."""
        if self._loop and self._streamed:

            def _do_cancel():
                try:
                    self._streamed.cancel()
                except Exception:
                    pass

            # The streamed run lives on the worker's loop; never call into it
            # directly from another thread.
            self._loop.call_soon_threadsafe(_do_cancel)
490+
491+
492+
def dr_turn(
    messages: List[Dict[str, Any]],
    cfg: GraphConfig,
    llm: LLM,
    turn_event_stream_emitter: Emitter,  # TurnEventStream is the primary output of the turn
    search_tool: SearchTool | None = None,
) -> None:
    """Run one deep-research turn.

    First asks the clarification model; if a clarifying question is needed it
    is emitted and the turn ends. Otherwise the deep-research agent runs and
    its raw response events are forwarded to the emitter, followed by a
    terminal 'done' event.
    """
    clarification = get_clarification(
        messages, cfg, llm, turn_event_stream_emitter, search_tool
    )
    parsed = json.loads(clarification.choices[0].message.content)
    clarification_output = ClarificationOutput(**parsed)

    if clarification_output.clarification_needed:
        # Surface the clarifying question and finish the turn immediately.
        turn_event_stream_emitter.emit(
            kind="agent", data=clarification_output.clarification_question
        )
        turn_event_stream_emitter.emit(kind="done", data={"ok": True})
        return

    deps = RunDependencies(
        search_tool=search_tool,
        emitter=turn_event_stream_emitter,
    )
    ctx = MyContext(run_dependencies=deps)
    agent = construct_deep_research_agent(llm)

    # Bridge the agent's async event stream onto this synchronous thread.
    bridge = StreamBridge(agent, messages, ctx, max_turns=100).start()
    for ev in bridge.events():
        if isinstance(ev, RunItemStreamEvent):
            pass  # item-level events are intentionally not forwarded
        elif isinstance(ev, RawResponsesStreamEvent):
            turn_event_stream_emitter.emit(kind="agent", data=ev.data.model_dump())

    turn_event_stream_emitter.emit(kind="done", data={"ok": True})
472526

473527

474528
class ClarificationOutput(BaseModel):

0 commit comments

Comments
 (0)