Skip to content

Commit 3ffe044

Browse files
committed
feat: session record & replay
resolves #101
1 parent d3dea98 commit 3ffe044

File tree

16 files changed

+1326
-43
lines changed

16 files changed

+1326
-43
lines changed

packages/api/src/flux0_api/dependency_injection.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import inflection
44
from fastapi import HTTPException, Request
55
from flux0_core.agents import AgentStore
6+
from flux0_core.recordings import RecordingStore
67
from flux0_core.sessions import SessionStore
78
from flux0_core.users import UserStore
89
from flux0_stream.emitter.api import EventEmitter
@@ -77,5 +78,9 @@ def get_user_store(request: Request) -> UserStore:
7778
return resolve_dependency(request, UserStore)
7879

7980

81+
def get_recording_store(request: Request) -> RecordingStore:
82+
return resolve_dependency(request, RecordingStore)
83+
84+
8085
def get_event_emitter(request: Request) -> EventEmitter:
8186
return resolve_dependency(request, EventEmitter)

packages/api/src/flux0_api/session_service.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,15 @@
99
from flux0_core.contextual_correlator import ContextualCorrelator
1010
from flux0_core.ids import gen_id
1111
from flux0_core.logging import Logger
12+
from flux0_core.recordings import RecordingStore
1213
from flux0_core.sessions import (
1314
Event,
1415
EventSource,
1516
EventType,
1617
MessageEventData,
1718
Session,
1819
SessionId,
20+
SessionMode,
1921
SessionStore,
2022
StatusEventData,
2123
ToolEventData,
@@ -32,6 +34,7 @@ def __init__(
3234
logger: Logger,
3335
agent_store: AgentStore,
3436
session_store: SessionStore,
37+
recording_store: RecordingStore,
3538
background_task_service: BackgroundTaskService,
3639
agent_runner_factory: AgentRunnerFactory,
3740
event_emitter: EventEmitter,
@@ -40,6 +43,7 @@ def __init__(
4043
self._logger = logger
4144
self._agent_store = agent_store
4245
self._session_store = session_store
46+
self._recording_store = recording_store
4347
self._background_task_service = background_task_service
4448
self._agent_runner_factory = agent_runner_factory
4549
self._event_emitter = event_emitter
@@ -51,6 +55,7 @@ async def create_user_session(
5155
id: Optional[SessionId] = None,
5256
title: Optional[str] = None,
5357
allow_greeting: bool = False,
58+
mode: Optional[SessionMode] = None,
5459
metadata: Optional[Mapping[str, JSONSerializable]] = None,
5560
) -> Session:
5661
session = await self._session_store.create_session(
@@ -59,6 +64,7 @@ async def create_user_session(
5964
id=id,
6065
title=title,
6166
created_at=datetime.now(timezone.utc),
67+
mode=mode,
6268
metadata=metadata,
6369
)
6470

@@ -145,6 +151,7 @@ async def _process_session(self, session: Session, agent: Agent) -> None:
145151
event_emitter=self._event_emitter,
146152
agent_store=self._agent_store,
147153
session_store=self._session_store,
154+
recording_store=self._recording_store,
148155
),
149156
)
150157

packages/api/src/flux0_api/sessions.py

Lines changed: 127 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,28 @@
11
import asyncio
22
import json
3-
from typing import Any, AsyncGenerator, Callable, Coroutine, Optional, Sequence, Set, Union, cast
3+
from datetime import datetime, timezone
4+
from typing import (
5+
Any,
6+
AsyncGenerator,
7+
Callable,
8+
Coroutine,
9+
Optional,
10+
Sequence,
11+
Set,
12+
Union,
13+
cast,
14+
)
415

5-
from fastapi import APIRouter, Depends, HTTPException, status
16+
from fastapi import APIRouter, Depends, HTTPException, Response, status
617
from fastapi.responses import StreamingResponse
718
from flux0_core.agents import AgentStore
19+
from flux0_core.ids import gen_id
20+
from flux0_core.recordings import (
21+
RecordedChunkPayload,
22+
RecordedEmittedPayload,
23+
RecordedEvent,
24+
RecordingStore,
25+
)
826
from flux0_core.sessions import (
927
ContentPart,
1028
Event,
@@ -24,6 +42,7 @@
2442
from flux0_api.dependency_injection import (
2543
get_agent_store,
2644
get_event_emitter,
45+
get_recording_store,
2746
get_session_service,
2847
get_session_store,
2948
get_user_store,
@@ -60,7 +79,15 @@
6079
def mount_create_session_route(
6180
router: APIRouter,
6281
) -> Callable[
63-
[AuthedUser, SessionCreationParamsDTO, AgentStore, SessionService, AllowGreetingQuery],
82+
[
83+
AuthedUser,
84+
Response,
85+
SessionCreationParamsDTO,
86+
AgentStore,
87+
SessionService,
88+
RecordingStore,
89+
AllowGreetingQuery,
90+
],
6491
Coroutine[Any, Any, SessionDTO],
6592
]:
6693
@router.post(
@@ -87,9 +114,11 @@ def mount_create_session_route(
87114
)
88115
async def create_session_route(
89116
authedUser: AuthedUser,
117+
response: Response,
90118
params: SessionCreationParamsDTO,
91119
agent_store: AgentStore = Depends(get_agent_store),
92120
session_service: SessionService = Depends(get_session_service),
121+
recording_store: RecordingStore = Depends(get_recording_store),
93122
allow_greeting: AllowGreetingQuery = False,
94123
) -> SessionDTO:
95124
"""
@@ -111,19 +140,30 @@ async def create_session_route(
111140
detail=f"Agent type {agent.type} is not supported by the server",
112141
)
113142

143+
session_id = params.id if params.id else SessionId(gen_id())
144+
if params.mode == "record":
145+
recorded_event_header = await recording_store.create_recording(
146+
source_session_id=session_id,
147+
)
148+
if params.metadata is None or not isinstance(params.metadata, dict):
149+
params.metadata = {}
150+
params.metadata["recording"] = {"recording_id": str(recorded_event_header.recording_id)}
151+
114152
session = await session_service.create_user_session(
115-
id=params.id,
153+
id=session_id,
116154
user_id=authedUser.id,
117155
agent=agent,
118156
title=params.title,
119157
allow_greeting=allow_greeting,
158+
mode=params.mode,
120159
metadata=params.metadata,
121160
)
122161

123162
return SessionDTO(
124163
id=session.id,
125164
agent_id=session.agent_id,
126165
user_id=session.user_id,
166+
mode=session.mode,
127167
title=session.title,
128168
consumption_offsets=ConsumptionOffsetsDTO(client=session.consumption_offsets["client"]),
129169
created_at=session.created_at,
@@ -169,6 +209,7 @@ async def retrieve_session(
169209
id=session.id,
170210
agent_id=session.agent_id,
171211
user_id=session.user_id,
212+
mode=session.mode,
172213
title=session.title,
173214
consumption_offsets=ConsumptionOffsetsDTO(
174215
client=session.consumption_offsets["client"],
@@ -211,13 +252,15 @@ async def list_sessions(
211252
user_id=authedUser.id,
212253
agent_id=agent_id,
213254
)
255+
214256
return SessionsDTO(
215257
data=[
216258
SessionDTO(
217259
id=s.id,
218260
agent_id=s.agent_id,
219261
title=s.title,
220262
user_id=s.user_id,
263+
mode=s.mode,
221264
consumption_offsets=ConsumptionOffsetsDTO(
222265
client=s.consumption_offsets["client"],
223266
),
@@ -239,6 +282,7 @@ async def event_stream(
239282
session_service: SessionService,
240283
event_emitter: EventEmitter,
241284
subscription_ready: asyncio.Event,
285+
recording_store: RecordingStore,
242286
) -> AsyncGenerator[str, None]:
243287
queue: asyncio.Queue[Union[ChunkEvent, EmittedEvent]] = asyncio.Queue()
244288

@@ -251,11 +295,33 @@ async def subscriber_final(emitted_event: EmittedEvent) -> None:
251295
await queue.put(emitted_event)
252296

253297
# Subscribe to processed event updates
254-
print("subscribed to correlation_id", correlation_id)
298+
# print("subscribed to correlation_id", correlation_id)
255299
event_emitter.subscribe_processed(correlation_id, subscriber)
256300
event_emitter.subscribe_final(correlation_id, subscriber_final)
257301
subscription_ready.set()
258302

303+
recording = await recording_store.read_header_by_source_session_id(session_id)
304+
recording_id = recording.recording_id if recording else None
305+
recording_aborted = False
306+
307+
async def maybe_record_emitted(payload: RecordedEmittedPayload, created_at: datetime) -> None:
308+
nonlocal recording_aborted
309+
if not recording_id or recording_aborted:
310+
return
311+
try:
312+
await recording_store.append_emitted(recording_id, payload, created_at=created_at)
313+
except Exception:
314+
recording_aborted = True # stop further attempts
315+
316+
async def maybe_record_chunk(payload: RecordedChunkPayload, created_at: datetime) -> None:
317+
nonlocal recording_aborted
318+
if not recording_id or recording_aborted:
319+
return
320+
try:
321+
await recording_store.append_chunk(recording_id, payload, created_at=created_at)
322+
except Exception:
323+
recording_aborted = True
324+
259325
try:
260326
while True:
261327
event = await queue.get()
@@ -267,6 +333,17 @@ async def subscriber_final(emitted_event: EmittedEvent) -> None:
267333
if event.type == "status":
268334
ed = cast(StatusEventData, event.data)
269335
if ed["status"] == "completed":
336+
await maybe_record_emitted(
337+
RecordedEmittedPayload(
338+
id=event.id,
339+
source=event.source,
340+
type="status",
341+
correlation_id=event.correlation_id,
342+
data=ed,
343+
metadata=event.metadata or {},
344+
),
345+
created_at=datetime.fromtimestamp(event.timestamp, tz=timezone.utc),
346+
)
270347
break
271348
await session_store.create_event(
272349
correlation_id=event.correlation_id,
@@ -276,6 +353,18 @@ async def subscriber_final(emitted_event: EmittedEvent) -> None:
276353
data=ed,
277354
metadata=event.metadata,
278355
)
356+
await maybe_record_emitted(
357+
RecordedEmittedPayload(
358+
id=event.id,
359+
source=event.source,
360+
type="status",
361+
correlation_id=event.correlation_id,
362+
data=ed,
363+
metadata=event.metadata or {},
364+
),
365+
created_at=datetime.fromtimestamp(event.timestamp, tz=timezone.utc),
366+
)
367+
279368
yield f"event: {event_type}\nid: {event.id}\ndata: {json.dumps(event.__dict__)}\n\n"
280369
elif event.type == "message":
281370
event_type = event.type
@@ -306,13 +395,23 @@ async def subscriber_final(emitted_event: EmittedEvent) -> None:
306395
else:
307396
# this is a chunk event
308397
event_type = "chunk"
398+
await maybe_record_chunk(
399+
RecordedChunkPayload(
400+
correlation_id=event.correlation_id,
401+
event_id=event.event_id,
402+
seq=event.seq,
403+
patches=event.patches,
404+
metadata=event.metadata or {},
405+
),
406+
created_at=datetime.fromtimestamp(event.timestamp, tz=timezone.utc),
407+
)
309408
yield f"event: {event_type}\ndata: {json.dumps(event.__dict__)}\n\n"
310409
except asyncio.CancelledError:
311410
await session_service.cancel_processing_session_task(session_id)
312411
return
313412
finally:
314413
# Unsubscribe when client disconnects
315-
print("unsubscribed from correlation_id", correlation_id)
414+
# print("unsubscribed from correlation_id", correlation_id)
316415
event_emitter.unsubscribe_processed(correlation_id, subscriber)
317416
event_emitter.unsubscribe_final(correlation_id, subscriber_final)
318417
# Explicitly send a termination event before closing
@@ -340,6 +439,8 @@ async def _add_user_message(
340439
agent_store: AgentStore,
341440
session_store: SessionStore,
342441
session_service: SessionService,
442+
rec_header: Optional[RecordedEvent],
443+
recording_store: RecordingStore,
343444
# moderation: Moderation = Moderation.NONE,
344445
subscription_ready: asyncio.Event,
345446
) -> EventDTO:
@@ -400,6 +501,19 @@ async def _add_user_message(
400501
wait_event=subscription_ready,
401502
)
402503

504+
if rec_header:
505+
await recording_store.append_emitted(
506+
recording_id=rec_header.recording_id,
507+
payload=RecordedEmittedPayload(
508+
id=event.id,
509+
source=event.source,
510+
type=event.type,
511+
correlation_id=event.correlation_id,
512+
data=event.data,
513+
metadata=event.metadata or {},
514+
),
515+
)
516+
403517
return event_to_dto(event)
404518

405519

@@ -414,6 +528,7 @@ def mount_create_event_and_stream_route(
414528
SessionStore,
415529
UserStore,
416530
AgentStore,
531+
RecordingStore,
417532
EventEmitter,
418533
],
419534
Coroutine[Any, Any, StreamingResponse],
@@ -475,6 +590,7 @@ async def create_event_and_stream(
475590
session_store: SessionStore = Depends(get_session_store),
476591
user_store: UserStore = Depends(get_user_store),
477592
agent_store: AgentStore = Depends(get_agent_store),
593+
recording_store: RecordingStore = Depends(get_recording_store),
478594
event_emitter: EventEmitter = Depends(get_event_emitter),
479595
# moderation: ModerationQuery = Moderation.NONE,
480596
) -> StreamingResponse:
@@ -488,6 +604,7 @@ async def create_event_and_stream(
488604
detail="Only message events can currently be added manually",
489605
)
490606

607+
rec_header = await recording_store.read_header_by_source_session_id(session_id)
491608
subscription_ready = asyncio.Event()
492609
if params.source == EventSourceDTO.USER:
493610
event = await _add_user_message(
@@ -497,9 +614,12 @@ async def create_event_and_stream(
497614
agent_store,
498615
session_store,
499616
session_service,
617+
rec_header,
618+
recording_store,
500619
# moderation,
501620
subscription_ready,
502621
)
622+
503623
return StreamingResponse(
504624
event_stream(
505625
session_id,
@@ -508,6 +628,7 @@ async def create_event_and_stream(
508628
session_service,
509629
event_emitter,
510630
subscription_ready,
631+
recording_store,
511632
),
512633
media_type="text/event-stream",
513634
)

0 commit comments

Comments
 (0)