Skip to content

Commit accbd7a

Browse files
committed
.
1 parent 0940c8d commit accbd7a

File tree

5 files changed

+91
-31
lines changed

5 files changed

+91
-31
lines changed

backend/onyx/chat/turn/fast_chat_turn.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@
88
from onyx.chat.turn.infra.chat_turn_orchestration import unified_event_stream
99
from onyx.chat.turn.models import MyContext
1010
from onyx.chat.turn.models import RunDependencies
11+
from onyx.server.query_and_chat.streaming_models import MessageDelta
12+
from onyx.server.query_and_chat.streaming_models import MessageStart
13+
from onyx.server.query_and_chat.streaming_models import OverallStop
14+
from onyx.server.query_and_chat.streaming_models import Packet
15+
from onyx.server.query_and_chat.streaming_models import SectionEnd
1116
from onyx.tools.tool_implementations_v2.internal_search import internal_search
1217
from onyx.tools.tool_implementations_v2.web import web_fetch
1318
from onyx.tools.tool_implementations_v2.web import web_search
@@ -33,14 +38,25 @@ def fast_chat_turn(messages: list[dict], dependencies: RunDependencies) -> None:
3338
)
3439

3540
bridge = OnyxRunner().run_streamed(agent, messages, context=ctx, max_turns=100)
36-
try:
37-
for ev in bridge.events():
38-
if isinstance(ev, RunItemStreamEvent):
39-
pass
40-
elif isinstance(ev, RawResponsesStreamEvent):
41-
# TODO: use very standardized schema for the emitter that is close to
42-
# front end schema
43-
dependencies.emitter.emit(kind="agent", data=ev.data.model_dump())
44-
finally:
45-
# TODO: Handle done signal more reliably?
46-
dependencies.emitter.emit(kind="done", data={"ok": True})
41+
for ev in bridge.events():
42+
if isinstance(ev, RunItemStreamEvent):
43+
pass
44+
elif isinstance(ev, RawResponsesStreamEvent):
45+
obj = None
46+
if ev.data.type == "response.created":
47+
obj = MessageStart(
48+
type="message_start", content="", final_documents=None
49+
)
50+
elif ev.data.type == "response.output_text.delta":
51+
obj = MessageDelta(type="message_delta", content=ev.data.delta)
52+
elif ev.data.type == "response.completed":
53+
obj = OverallStop(type="stop")
54+
elif ev.data.type == "response.output_item.done":
55+
obj = SectionEnd(type="section_end")
56+
if obj:
57+
dependencies.emitter.emit(Packet(ind=ctx.current_run_step, obj=obj))
58+
# TODO: Error handling
59+
# Should there be a timeout and some error on the queue?
60+
dependencies.emitter.emit(
61+
Packet(ind=ctx.current_run_step, obj=OverallStop(type="stop"))
62+
)

backend/onyx/chat/turn/infra/chat_turn_event_stream.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,13 @@
55
from collections.abc import Iterator
66
from queue import Queue
77
from typing import Any
8-
from typing import Dict
98
from typing import Optional
109

1110
from agents import Agent
1211
from agents import Runner
1312
from agents import TContext
14-
from pydantic import BaseModel
15-
from pydantic import Field
13+
14+
from onyx.server.query_and_chat.streaming_models import Packet
1615

1716

1817
logger = logging.getLogger(__name__)
@@ -85,12 +84,7 @@ def _do_cancel():
8584
self._loop.call_soon_threadsafe(_do_cancel)
8685

8786

88-
class StreamPacket(BaseModel):
89-
kind: str # "agent" | "tool-progress" | "done"
90-
payload: Dict[str, Any] = Field(default_factory=dict)
91-
92-
93-
def convert_to_packet_obj(packet: StreamPacket) -> Any | None:
87+
def convert_to_packet_obj(packet: dict[str, Any]) -> Any | None:
9488
"""Convert a packet dictionary to PacketObj when possible.
9589
9690
Args:
@@ -176,5 +170,5 @@ class Emitter:
176170
def __init__(self, bus: Queue):
177171
self.bus = bus
178172

179-
def emit(self, kind: str, data: Dict[str, Any]) -> None:
180-
self.bus.put(StreamPacket(kind=kind, payload=data))
173+
def emit(self, packet: Packet) -> None:
174+
self.bus.put(packet)

backend/onyx/chat/turn/infra/chat_turn_orchestration.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,9 @@
77
from typing import Dict
88
from typing import List
99

10-
from onyx.chat.turn.infra.chat_turn_event_stream import convert_to_packet_obj
1110
from onyx.chat.turn.infra.chat_turn_event_stream import Emitter
12-
from onyx.chat.turn.infra.chat_turn_event_stream import StreamPacket
1311
from onyx.chat.turn.models import RunDependencies
12+
from onyx.server.query_and_chat.streaming_models import OverallStop
1413
from onyx.server.query_and_chat.streaming_models import Packet
1514

1615

@@ -48,12 +47,10 @@ def wrapper(
4847
)
4948
t.start()
5049
while True:
51-
pkt: StreamPacket = emitter.bus.get()
52-
if pkt.kind == "done":
50+
pkt: Packet = emitter.bus.get()
51+
if pkt.obj == OverallStop(type="stop"):
5352
break
5453
else:
55-
packet_obj = convert_to_packet_obj(pkt.payload)
56-
if packet_obj:
57-
yield Packet(ind=0, obj=packet_obj)
54+
yield pkt
5855

5956
return wrapper

backend/onyx/chat/turn/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,4 @@ class MyContext:
1818

1919
run_dependencies: RunDependencies | None = None
2020
needs_compaction: bool = False
21+
current_run_step: int = 0

backend/onyx/tools/tool_implementations_v2/web.py

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@
88
get_default_provider,
99
)
1010
from onyx.chat.turn.models import MyContext
11+
from onyx.configs.constants import DocumentSource
12+
from onyx.server.query_and_chat.streaming_models import Packet
13+
from onyx.server.query_and_chat.streaming_models import SavedSearchDoc
14+
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
15+
from onyx.server.query_and_chat.streaming_models import SearchToolStart
1116

1217

1318
def short_tag(link: str, i: int) -> str:
@@ -28,7 +33,22 @@ def web_search(run_context: RunContextWrapper[MyContext], query: str) -> str:
2833
query: The natural-language search query.
2934
"""
3035
search_provider = get_default_provider()
31-
run_context.run_dependencies.emitter.emit(kind="web-search", data={"query": query})
36+
run_context.context.run_dependencies.emitter.emit(
37+
Packet(
38+
ind=run_context.context.current_run_step + 1,
39+
obj=SearchToolStart(
40+
type="internal_search_tool_start", is_internet_search=True
41+
),
42+
)
43+
)
44+
run_context.context.run_dependencies.emitter.emit(
45+
Packet(
46+
ind=run_context.context.current_run_step + 1,
47+
obj=SearchToolDelta(
48+
type="internal_search_tool_delta", queries=[query], documents=None
49+
),
50+
)
51+
)
3252
hits = search_provider.search(query)
3353
results = []
3454
for i, r in enumerate(hits):
@@ -44,6 +64,7 @@ def web_search(run_context: RunContextWrapper[MyContext], query: str) -> str:
4464
),
4565
}
4666
)
67+
run_context.context.current_run_step += 2
4768
return json.dumps({"results": results})
4869

4970

@@ -60,7 +81,38 @@ def web_fetch(run_context: RunContextWrapper[MyContext], urls: List[str]) -> str
6081
urls: The full URLs of the pages to retrieve.
6182
"""
6283
search_provider = get_default_provider()
63-
run_context.run_dependencies.emitter.emit(kind="web-fetch", data={"urls": urls})
84+
saved_search_docs = [
85+
SavedSearchDoc(
86+
document_id=url,
87+
chunk_ind=0,
88+
semantic_identifier=url,
89+
link=url,
90+
blurb=url,
91+
source_type=DocumentSource.WEB,
92+
boost=1,
93+
hidden=False,
94+
metadata={},
95+
score=0.0,
96+
is_relevant=None,
97+
relevance_explanation=None,
98+
match_highlights=[],
99+
updated_at=None,
100+
primary_owners=None,
101+
secondary_owners=None,
102+
is_internet=True,
103+
)
104+
for url in urls
105+
]
106+
run_context.context.run_dependencies.emitter.emit(
107+
Packet(
108+
ind=run_context.context.current_run_step + 1,
109+
obj=SearchToolDelta(
110+
type="internal_search_tool_delta",
111+
is_internet_search=True,
112+
documents=saved_search_docs,
113+
),
114+
)
115+
)
64116
docs = search_provider.contents(urls)
65117
out = []
66118
for i, d in enumerate(docs):

0 commit comments

Comments
 (0)