1
1
import asyncio
2
2
import json
3
- from typing import Any , AsyncGenerator , Callable , Coroutine , Optional , Sequence , Set , Union , cast
3
+ from datetime import datetime , timezone
4
+ from typing import (
5
+ Any ,
6
+ AsyncGenerator ,
7
+ Callable ,
8
+ Coroutine ,
9
+ Optional ,
10
+ Sequence ,
11
+ Set ,
12
+ Union ,
13
+ cast ,
14
+ )
4
15
5
- from fastapi import APIRouter , Depends , HTTPException , status
16
+ from fastapi import APIRouter , Depends , HTTPException , Response , status
6
17
from fastapi .responses import StreamingResponse
7
18
from flux0_core .agents import AgentStore
19
+ from flux0_core .ids import gen_id
20
+ from flux0_core .recordings import (
21
+ RecordedChunkPayload ,
22
+ RecordedEmittedPayload ,
23
+ RecordedEvent ,
24
+ RecordingStore ,
25
+ )
8
26
from flux0_core .sessions import (
9
27
ContentPart ,
10
28
Event ,
24
42
from flux0_api .dependency_injection import (
25
43
get_agent_store ,
26
44
get_event_emitter ,
45
+ get_recording_store ,
27
46
get_session_service ,
28
47
get_session_store ,
29
48
get_user_store ,
60
79
def mount_create_session_route (
61
80
router : APIRouter ,
62
81
) -> Callable [
63
- [AuthedUser , SessionCreationParamsDTO , AgentStore , SessionService , AllowGreetingQuery ],
82
+ [
83
+ AuthedUser ,
84
+ Response ,
85
+ SessionCreationParamsDTO ,
86
+ AgentStore ,
87
+ SessionService ,
88
+ RecordingStore ,
89
+ AllowGreetingQuery ,
90
+ ],
64
91
Coroutine [Any , Any , SessionDTO ],
65
92
]:
66
93
@router .post (
@@ -87,9 +114,11 @@ def mount_create_session_route(
87
114
)
88
115
async def create_session_route (
89
116
authedUser : AuthedUser ,
117
+ response : Response ,
90
118
params : SessionCreationParamsDTO ,
91
119
agent_store : AgentStore = Depends (get_agent_store ),
92
120
session_service : SessionService = Depends (get_session_service ),
121
+ recording_store : RecordingStore = Depends (get_recording_store ),
93
122
allow_greeting : AllowGreetingQuery = False ,
94
123
) -> SessionDTO :
95
124
"""
@@ -111,19 +140,30 @@ async def create_session_route(
111
140
detail = f"Agent type { agent .type } is not supported by the server" ,
112
141
)
113
142
143
+ session_id = params .id if params .id else SessionId (gen_id ())
144
+ if params .mode == "record" :
145
+ recorded_event_header = await recording_store .create_recording (
146
+ source_session_id = session_id ,
147
+ )
148
+ if params .metadata is None or not isinstance (params .metadata , dict ):
149
+ params .metadata = {}
150
+ params .metadata ["recording" ] = {"recording_id" : str (recorded_event_header .recording_id )}
151
+
114
152
session = await session_service .create_user_session (
115
- id = params . id ,
153
+ id = session_id ,
116
154
user_id = authedUser .id ,
117
155
agent = agent ,
118
156
title = params .title ,
119
157
allow_greeting = allow_greeting ,
158
+ mode = params .mode ,
120
159
metadata = params .metadata ,
121
160
)
122
161
123
162
return SessionDTO (
124
163
id = session .id ,
125
164
agent_id = session .agent_id ,
126
165
user_id = session .user_id ,
166
+ mode = session .mode ,
127
167
title = session .title ,
128
168
consumption_offsets = ConsumptionOffsetsDTO (client = session .consumption_offsets ["client" ]),
129
169
created_at = session .created_at ,
@@ -169,6 +209,7 @@ async def retrieve_session(
169
209
id = session .id ,
170
210
agent_id = session .agent_id ,
171
211
user_id = session .user_id ,
212
+ mode = session .mode ,
172
213
title = session .title ,
173
214
consumption_offsets = ConsumptionOffsetsDTO (
174
215
client = session .consumption_offsets ["client" ],
@@ -211,13 +252,15 @@ async def list_sessions(
211
252
user_id = authedUser .id ,
212
253
agent_id = agent_id ,
213
254
)
255
+
214
256
return SessionsDTO (
215
257
data = [
216
258
SessionDTO (
217
259
id = s .id ,
218
260
agent_id = s .agent_id ,
219
261
title = s .title ,
220
262
user_id = s .user_id ,
263
+ mode = s .mode ,
221
264
consumption_offsets = ConsumptionOffsetsDTO (
222
265
client = s .consumption_offsets ["client" ],
223
266
),
@@ -239,6 +282,7 @@ async def event_stream(
239
282
session_service : SessionService ,
240
283
event_emitter : EventEmitter ,
241
284
subscription_ready : asyncio .Event ,
285
+ recording_store : RecordingStore ,
242
286
) -> AsyncGenerator [str , None ]:
243
287
queue : asyncio .Queue [Union [ChunkEvent , EmittedEvent ]] = asyncio .Queue ()
244
288
@@ -251,11 +295,33 @@ async def subscriber_final(emitted_event: EmittedEvent) -> None:
251
295
await queue .put (emitted_event )
252
296
253
297
# Subscribe to processed event updates
254
- print ("subscribed to correlation_id" , correlation_id )
298
+ # print("subscribed to correlation_id", correlation_id)
255
299
event_emitter .subscribe_processed (correlation_id , subscriber )
256
300
event_emitter .subscribe_final (correlation_id , subscriber_final )
257
301
subscription_ready .set ()
258
302
303
+ recording = await recording_store .read_header_by_source_session_id (session_id )
304
+ recording_id = recording .recording_id if recording else None
305
+ recording_aborted = False
306
+
307
+ async def maybe_record_emitted (payload : RecordedEmittedPayload , created_at : datetime ) -> None :
308
+ nonlocal recording_aborted
309
+ if not recording_id or recording_aborted :
310
+ return
311
+ try :
312
+ await recording_store .append_emitted (recording_id , payload , created_at = created_at )
313
+ except Exception :
314
+ recording_aborted = True # stop further attempts
315
+
316
+ async def maybe_record_chunk (payload : RecordedChunkPayload , created_at : datetime ) -> None :
317
+ nonlocal recording_aborted
318
+ if not recording_id or recording_aborted :
319
+ return
320
+ try :
321
+ await recording_store .append_chunk (recording_id , payload , created_at = created_at )
322
+ except Exception :
323
+ recording_aborted = True
324
+
259
325
try :
260
326
while True :
261
327
event = await queue .get ()
@@ -267,6 +333,17 @@ async def subscriber_final(emitted_event: EmittedEvent) -> None:
267
333
if event .type == "status" :
268
334
ed = cast (StatusEventData , event .data )
269
335
if ed ["status" ] == "completed" :
336
+ await maybe_record_emitted (
337
+ RecordedEmittedPayload (
338
+ id = event .id ,
339
+ source = event .source ,
340
+ type = "status" ,
341
+ correlation_id = event .correlation_id ,
342
+ data = ed ,
343
+ metadata = event .metadata or {},
344
+ ),
345
+ created_at = datetime .fromtimestamp (event .timestamp , tz = timezone .utc ),
346
+ )
270
347
break
271
348
await session_store .create_event (
272
349
correlation_id = event .correlation_id ,
@@ -276,6 +353,18 @@ async def subscriber_final(emitted_event: EmittedEvent) -> None:
276
353
data = ed ,
277
354
metadata = event .metadata ,
278
355
)
356
+ await maybe_record_emitted (
357
+ RecordedEmittedPayload (
358
+ id = event .id ,
359
+ source = event .source ,
360
+ type = "status" ,
361
+ correlation_id = event .correlation_id ,
362
+ data = ed ,
363
+ metadata = event .metadata or {},
364
+ ),
365
+ created_at = datetime .fromtimestamp (event .timestamp , tz = timezone .utc ),
366
+ )
367
+
279
368
yield f"event: { event_type } \n id: { event .id } \n data: { json .dumps (event .__dict__ )} \n \n "
280
369
elif event .type == "message" :
281
370
event_type = event .type
@@ -306,13 +395,23 @@ async def subscriber_final(emitted_event: EmittedEvent) -> None:
306
395
else :
307
396
# this is a chunk event
308
397
event_type = "chunk"
398
+ await maybe_record_chunk (
399
+ RecordedChunkPayload (
400
+ correlation_id = event .correlation_id ,
401
+ event_id = event .event_id ,
402
+ seq = event .seq ,
403
+ patches = event .patches ,
404
+ metadata = event .metadata or {},
405
+ ),
406
+ created_at = datetime .fromtimestamp (event .timestamp , tz = timezone .utc ),
407
+ )
309
408
yield f"event: { event_type } \n data: { json .dumps (event .__dict__ )} \n \n "
310
409
except asyncio .CancelledError :
311
410
await session_service .cancel_processing_session_task (session_id )
312
411
return
313
412
finally :
314
413
# Unsubscribe when client disconnects
315
- print ("unsubscribed from correlation_id" , correlation_id )
414
+ # print("unsubscribed from correlation_id", correlation_id)
316
415
event_emitter .unsubscribe_processed (correlation_id , subscriber )
317
416
event_emitter .unsubscribe_final (correlation_id , subscriber_final )
318
417
# Explicitly send a termination event before closing
@@ -340,6 +439,8 @@ async def _add_user_message(
340
439
agent_store : AgentStore ,
341
440
session_store : SessionStore ,
342
441
session_service : SessionService ,
442
+ rec_header : Optional [RecordedEvent ],
443
+ recording_store : RecordingStore ,
343
444
# moderation: Moderation = Moderation.NONE,
344
445
subscription_ready : asyncio .Event ,
345
446
) -> EventDTO :
@@ -400,6 +501,19 @@ async def _add_user_message(
400
501
wait_event = subscription_ready ,
401
502
)
402
503
504
+ if rec_header :
505
+ await recording_store .append_emitted (
506
+ recording_id = rec_header .recording_id ,
507
+ payload = RecordedEmittedPayload (
508
+ id = event .id ,
509
+ source = event .source ,
510
+ type = event .type ,
511
+ correlation_id = event .correlation_id ,
512
+ data = event .data ,
513
+ metadata = event .metadata or {},
514
+ ),
515
+ )
516
+
403
517
return event_to_dto (event )
404
518
405
519
@@ -414,6 +528,7 @@ def mount_create_event_and_stream_route(
414
528
SessionStore ,
415
529
UserStore ,
416
530
AgentStore ,
531
+ RecordingStore ,
417
532
EventEmitter ,
418
533
],
419
534
Coroutine [Any , Any , StreamingResponse ],
@@ -475,6 +590,7 @@ async def create_event_and_stream(
475
590
session_store : SessionStore = Depends (get_session_store ),
476
591
user_store : UserStore = Depends (get_user_store ),
477
592
agent_store : AgentStore = Depends (get_agent_store ),
593
+ recording_store : RecordingStore = Depends (get_recording_store ),
478
594
event_emitter : EventEmitter = Depends (get_event_emitter ),
479
595
# moderation: ModerationQuery = Moderation.NONE,
480
596
) -> StreamingResponse :
@@ -488,6 +604,7 @@ async def create_event_and_stream(
488
604
detail = "Only message events can currently be added manually" ,
489
605
)
490
606
607
+ rec_header = await recording_store .read_header_by_source_session_id (session_id )
491
608
subscription_ready = asyncio .Event ()
492
609
if params .source == EventSourceDTO .USER :
493
610
event = await _add_user_message (
@@ -497,9 +614,12 @@ async def create_event_and_stream(
497
614
agent_store ,
498
615
session_store ,
499
616
session_service ,
617
+ rec_header ,
618
+ recording_store ,
500
619
# moderation,
501
620
subscription_ready ,
502
621
)
622
+
503
623
return StreamingResponse (
504
624
event_stream (
505
625
session_id ,
@@ -508,6 +628,7 @@ async def create_event_and_stream(
508
628
session_service ,
509
629
event_emitter ,
510
630
subscription_ready ,
631
+ recording_store ,
511
632
),
512
633
media_type = "text/event-stream" ,
513
634
)
0 commit comments