Skip to content

Commit 507248f

Browse files
committed
feat: session record & replay
resolves #101
1 parent ce601a3 commit 507248f

File tree

16 files changed

+1323
-45
lines changed

16 files changed

+1323
-45
lines changed

packages/api/src/flux0_api/dependency_injection.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import inflection
44
from fastapi import HTTPException, Request
55
from flux0_core.agents import AgentStore
6+
from flux0_core.recordings import RecordingStore
67
from flux0_core.sessions import SessionStore
78
from flux0_core.users import UserStore
89
from flux0_stream.emitter.api import EventEmitter
@@ -77,5 +78,9 @@ def get_user_store(request: Request) -> UserStore:
7778
return resolve_dependency(request, UserStore)
7879

7980

81+
def get_recording_store(request: Request) -> RecordingStore:
82+
return resolve_dependency(request, RecordingStore)
83+
84+
8085
def get_event_emitter(request: Request) -> EventEmitter:
8186
return resolve_dependency(request, EventEmitter)

packages/api/src/flux0_api/session_service.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,15 @@
99
from flux0_core.contextual_correlator import ContextualCorrelator
1010
from flux0_core.ids import gen_id
1111
from flux0_core.logging import Logger
12+
from flux0_core.recordings import RecordingStore
1213
from flux0_core.sessions import (
1314
Event,
1415
EventSource,
1516
EventType,
1617
MessageEventData,
1718
Session,
1819
SessionId,
20+
SessionMode,
1921
SessionStore,
2022
StatusEventData,
2123
ToolEventData,
@@ -32,6 +34,7 @@ def __init__(
3234
logger: Logger,
3335
agent_store: AgentStore,
3436
session_store: SessionStore,
37+
recording_store: RecordingStore,
3538
background_task_service: BackgroundTaskService,
3639
agent_runner_factory: AgentRunnerFactory,
3740
event_emitter: EventEmitter,
@@ -40,6 +43,7 @@ def __init__(
4043
self._logger = logger
4144
self._agent_store = agent_store
4245
self._session_store = session_store
46+
self._recording_store = recording_store
4347
self._background_task_service = background_task_service
4448
self._agent_runner_factory = agent_runner_factory
4549
self._event_emitter = event_emitter
@@ -51,6 +55,7 @@ async def create_user_session(
5155
id: Optional[SessionId] = None,
5256
title: Optional[str] = None,
5357
allow_greeting: bool = False,
58+
mode: Optional[SessionMode] = None,
5459
metadata: Optional[Mapping[str, JSONSerializable]] = None,
5560
) -> Session:
5661
session = await self._session_store.create_session(
@@ -59,6 +64,7 @@ async def create_user_session(
5964
id=id,
6065
title=title,
6166
created_at=datetime.now(timezone.utc),
67+
mode=mode,
6268
metadata=metadata,
6369
)
6470

@@ -145,6 +151,7 @@ async def _process_session(self, session: Session, agent: Agent) -> None:
145151
event_emitter=self._event_emitter,
146152
agent_store=self._agent_store,
147153
session_store=self._session_store,
154+
recording_store=self._recording_store,
148155
),
149156
)
150157

packages/api/src/flux0_api/sessions.py

Lines changed: 124 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,28 @@
11
import asyncio
22
import json
3-
from typing import Any, AsyncGenerator, Callable, Coroutine, Optional, Sequence, Set, Union, cast
3+
from datetime import datetime, timezone
4+
from typing import (
5+
Any,
6+
AsyncGenerator,
7+
Callable,
8+
Coroutine,
9+
Optional,
10+
Sequence,
11+
Set,
12+
Union,
13+
cast,
14+
)
415

5-
from fastapi import APIRouter, Depends, HTTPException, status
16+
from fastapi import APIRouter, Depends, HTTPException, Response, status
617
from fastapi.responses import StreamingResponse
718
from flux0_core.agents import AgentStore
19+
from flux0_core.ids import gen_id
20+
from flux0_core.recordings import (
21+
RecordedChunkPayload,
22+
RecordedEmittedPayload,
23+
RecordedEvent,
24+
RecordingStore,
25+
)
826
from flux0_core.sessions import (
927
ContentPart,
1028
Event,
@@ -24,6 +42,7 @@
2442
from flux0_api.dependency_injection import (
2543
get_agent_store,
2644
get_event_emitter,
45+
get_recording_store,
2746
get_session_service,
2847
get_session_store,
2948
get_user_store,
@@ -60,7 +79,15 @@
6079
def mount_create_session_route(
6180
router: APIRouter,
6281
) -> Callable[
63-
[AuthedUser, SessionCreationParamsDTO, AgentStore, SessionService, AllowGreetingQuery],
82+
[
83+
AuthedUser,
84+
Response,
85+
SessionCreationParamsDTO,
86+
AgentStore,
87+
SessionService,
88+
RecordingStore,
89+
AllowGreetingQuery,
90+
],
6491
Coroutine[Any, Any, SessionDTO],
6592
]:
6693
@router.post(
@@ -87,9 +114,11 @@ def mount_create_session_route(
87114
)
88115
async def create_session_route(
89116
authedUser: AuthedUser,
117+
response: Response,
90118
params: SessionCreationParamsDTO,
91119
agent_store: AgentStore = Depends(get_agent_store),
92120
session_service: SessionService = Depends(get_session_service),
121+
recording_store: RecordingStore = Depends(get_recording_store),
93122
allow_greeting: AllowGreetingQuery = False,
94123
) -> SessionDTO:
95124
"""
@@ -111,19 +140,28 @@ async def create_session_route(
111140
detail=f"Agent type {agent.type} is not supported by the server",
112141
)
113142

143+
session_id = params.id if params.id else SessionId(gen_id())
144+
if params.mode == "record":
145+
recorded_event_header = await recording_store.create_recording(
146+
source_session_id=session_id,
147+
)
148+
response.headers["X-Flux0-Recording-Id"] = str(recorded_event_header.recording_id)
149+
114150
session = await session_service.create_user_session(
115-
id=params.id,
151+
id=session_id,
116152
user_id=authedUser.id,
117153
agent=agent,
118154
title=params.title,
119155
allow_greeting=allow_greeting,
156+
mode=params.mode,
120157
metadata=params.metadata,
121158
)
122159

123160
return SessionDTO(
124161
id=session.id,
125162
agent_id=session.agent_id,
126163
user_id=session.user_id,
164+
mode=session.mode,
127165
title=session.title,
128166
consumption_offsets=ConsumptionOffsetsDTO(client=session.consumption_offsets["client"]),
129167
created_at=session.created_at,
@@ -169,12 +207,12 @@ async def retrieve_session(
169207
id=session.id,
170208
agent_id=session.agent_id,
171209
user_id=session.user_id,
210+
mode=session.mode,
172211
title=session.title,
173212
consumption_offsets=ConsumptionOffsetsDTO(
174213
client=session.consumption_offsets["client"],
175214
),
176215
created_at=session.created_at,
177-
metadata=session.metadata,
178216
)
179217

180218
return retrieve_session
@@ -218,11 +256,11 @@ async def list_sessions(
218256
agent_id=s.agent_id,
219257
title=s.title,
220258
user_id=s.user_id,
259+
mode=s.mode,
221260
consumption_offsets=ConsumptionOffsetsDTO(
222261
client=s.consumption_offsets["client"],
223262
),
224263
created_at=s.created_at,
225-
metadata=s.metadata,
226264
)
227265
for s in sessions
228266
]
@@ -239,6 +277,7 @@ async def event_stream(
239277
session_service: SessionService,
240278
event_emitter: EventEmitter,
241279
subscription_ready: asyncio.Event,
280+
recording_store: RecordingStore,
242281
) -> AsyncGenerator[str, None]:
243282
queue: asyncio.Queue[Union[ChunkEvent, EmittedEvent]] = asyncio.Queue()
244283

@@ -251,11 +290,33 @@ async def subscriber_final(emitted_event: EmittedEvent) -> None:
251290
await queue.put(emitted_event)
252291

253292
# Subscribe to processed event updates
254-
print("subscribed to correlation_id", correlation_id)
293+
# print("subscribed to correlation_id", correlation_id)
255294
event_emitter.subscribe_processed(correlation_id, subscriber)
256295
event_emitter.subscribe_final(correlation_id, subscriber_final)
257296
subscription_ready.set()
258297

298+
recording = await recording_store.read_header_by_source_session_id(session_id)
299+
recording_id = recording.recording_id if recording else None
300+
recording_aborted = False
301+
302+
async def maybe_record_emitted(payload: RecordedEmittedPayload, created_at: datetime) -> None:
303+
nonlocal recording_aborted
304+
if not recording_id or recording_aborted:
305+
return
306+
try:
307+
await recording_store.append_emitted(recording_id, payload, created_at=created_at)
308+
except Exception:
309+
recording_aborted = True # stop further attempts
310+
311+
async def maybe_record_chunk(payload: RecordedChunkPayload, created_at: datetime) -> None:
312+
nonlocal recording_aborted
313+
if not recording_id or recording_aborted:
314+
return
315+
try:
316+
await recording_store.append_chunk(recording_id, payload, created_at=created_at)
317+
except Exception:
318+
recording_aborted = True
319+
259320
try:
260321
while True:
261322
event = await queue.get()
@@ -267,6 +328,17 @@ async def subscriber_final(emitted_event: EmittedEvent) -> None:
267328
if event.type == "status":
268329
ed = cast(StatusEventData, event.data)
269330
if ed["status"] == "completed":
331+
await maybe_record_emitted(
332+
RecordedEmittedPayload(
333+
id=event.id,
334+
source=event.source,
335+
type="status",
336+
correlation_id=event.correlation_id,
337+
data=ed,
338+
metadata=event.metadata or {},
339+
),
340+
created_at=datetime.fromtimestamp(event.timestamp, tz=timezone.utc),
341+
)
270342
break
271343
await session_store.create_event(
272344
correlation_id=event.correlation_id,
@@ -276,6 +348,18 @@ async def subscriber_final(emitted_event: EmittedEvent) -> None:
276348
data=ed,
277349
metadata=event.metadata,
278350
)
351+
await maybe_record_emitted(
352+
RecordedEmittedPayload(
353+
id=event.id,
354+
source=event.source,
355+
type="status",
356+
correlation_id=event.correlation_id,
357+
data=ed,
358+
metadata=event.metadata or {},
359+
),
360+
created_at=datetime.fromtimestamp(event.timestamp, tz=timezone.utc),
361+
)
362+
279363
yield f"event: {event_type}\nid: {event.id}\ndata: {json.dumps(event.__dict__)}\n\n"
280364
elif event.type == "message":
281365
event_type = event.type
@@ -306,13 +390,23 @@ async def subscriber_final(emitted_event: EmittedEvent) -> None:
306390
else:
307391
# this is a chunk event
308392
event_type = "chunk"
393+
await maybe_record_chunk(
394+
RecordedChunkPayload(
395+
correlation_id=event.correlation_id,
396+
event_id=event.event_id,
397+
seq=event.seq,
398+
patches=event.patches,
399+
metadata=event.metadata or {},
400+
),
401+
created_at=datetime.fromtimestamp(event.timestamp, tz=timezone.utc),
402+
)
309403
yield f"event: {event_type}\ndata: {json.dumps(event.__dict__)}\n\n"
310404
except asyncio.CancelledError:
311405
await session_service.cancel_processing_session_task(session_id)
312406
return
313407
finally:
314408
# Unsubscribe when client disconnects
315-
print("unsubscribed from correlation_id", correlation_id)
409+
# print("unsubscribed from correlation_id", correlation_id)
316410
event_emitter.unsubscribe_processed(correlation_id, subscriber)
317411
# Explicitly send a termination event before closing
318412
# yield "event: close\ndata: {}\n\n"
@@ -339,6 +433,8 @@ async def _add_user_message(
339433
agent_store: AgentStore,
340434
session_store: SessionStore,
341435
session_service: SessionService,
436+
rec_header: Optional[RecordedEvent],
437+
recording_store: RecordingStore,
342438
# moderation: Moderation = Moderation.NONE,
343439
subscription_ready: asyncio.Event,
344440
) -> EventDTO:
@@ -399,6 +495,19 @@ async def _add_user_message(
399495
wait_event=subscription_ready,
400496
)
401497

498+
if rec_header:
499+
await recording_store.append_emitted(
500+
recording_id=rec_header.recording_id,
501+
payload=RecordedEmittedPayload(
502+
id=event.id,
503+
source=event.source,
504+
type=event.type,
505+
correlation_id=event.correlation_id,
506+
data=event.data,
507+
metadata=event.metadata or {},
508+
),
509+
)
510+
402511
return event_to_dto(event)
403512

404513

@@ -413,6 +522,7 @@ def mount_create_event_and_stream_route(
413522
SessionStore,
414523
UserStore,
415524
AgentStore,
525+
RecordingStore,
416526
EventEmitter,
417527
],
418528
Coroutine[Any, Any, StreamingResponse],
@@ -474,6 +584,7 @@ async def create_event_and_stream(
474584
session_store: SessionStore = Depends(get_session_store),
475585
user_store: UserStore = Depends(get_user_store),
476586
agent_store: AgentStore = Depends(get_agent_store),
587+
recording_store: RecordingStore = Depends(get_recording_store),
477588
event_emitter: EventEmitter = Depends(get_event_emitter),
478589
# moderation: ModerationQuery = Moderation.NONE,
479590
) -> StreamingResponse:
@@ -487,6 +598,7 @@ async def create_event_and_stream(
487598
detail="Only message events can currently be added manually",
488599
)
489600

601+
rec_header = await recording_store.read_header_by_source_session_id(session_id)
490602
subscription_ready = asyncio.Event()
491603
if params.source == EventSourceDTO.USER:
492604
event = await _add_user_message(
@@ -496,9 +608,12 @@ async def create_event_and_stream(
496608
agent_store,
497609
session_store,
498610
session_service,
611+
rec_header,
612+
recording_store,
499613
# moderation,
500614
subscription_ready,
501615
)
616+
502617
return StreamingResponse(
503618
event_stream(
504619
session_id,
@@ -507,6 +622,7 @@ async def create_event_and_stream(
507622
session_service,
508623
event_emitter,
509624
subscription_ready,
625+
recording_store,
510626
),
511627
media_type="text/event-stream",
512628
)

0 commit comments

Comments
 (0)