1
1
import asyncio
2
2
import json
3
- from typing import Any , AsyncGenerator , Callable , Coroutine , Optional , Sequence , Set , Union , cast
3
+ from datetime import datetime , timezone
4
+ from typing import (
5
+ Any ,
6
+ AsyncGenerator ,
7
+ Callable ,
8
+ Coroutine ,
9
+ Optional ,
10
+ Sequence ,
11
+ Set ,
12
+ Union ,
13
+ cast ,
14
+ )
4
15
5
- from fastapi import APIRouter , Depends , HTTPException , status
16
+ from fastapi import APIRouter , Depends , HTTPException , Response , status
6
17
from fastapi .responses import StreamingResponse
7
18
from flux0_core .agents import AgentStore
19
+ from flux0_core .ids import gen_id
20
+ from flux0_core .recordings import (
21
+ RecordedChunkPayload ,
22
+ RecordedEmittedPayload ,
23
+ RecordedEvent ,
24
+ RecordingStore ,
25
+ )
8
26
from flux0_core .sessions import (
9
27
ContentPart ,
10
28
Event ,
24
42
from flux0_api .dependency_injection import (
25
43
get_agent_store ,
26
44
get_event_emitter ,
45
+ get_recording_store ,
27
46
get_session_service ,
28
47
get_session_store ,
29
48
get_user_store ,
60
79
def mount_create_session_route (
61
80
router : APIRouter ,
62
81
) -> Callable [
63
- [AuthedUser , SessionCreationParamsDTO , AgentStore , SessionService , AllowGreetingQuery ],
82
+ [
83
+ AuthedUser ,
84
+ Response ,
85
+ SessionCreationParamsDTO ,
86
+ AgentStore ,
87
+ SessionService ,
88
+ RecordingStore ,
89
+ AllowGreetingQuery ,
90
+ ],
64
91
Coroutine [Any , Any , SessionDTO ],
65
92
]:
66
93
@router .post (
@@ -87,9 +114,11 @@ def mount_create_session_route(
87
114
)
88
115
async def create_session_route (
89
116
authedUser : AuthedUser ,
117
+ response : Response ,
90
118
params : SessionCreationParamsDTO ,
91
119
agent_store : AgentStore = Depends (get_agent_store ),
92
120
session_service : SessionService = Depends (get_session_service ),
121
+ recording_store : RecordingStore = Depends (get_recording_store ),
93
122
allow_greeting : AllowGreetingQuery = False ,
94
123
) -> SessionDTO :
95
124
"""
@@ -111,19 +140,28 @@ async def create_session_route(
111
140
detail = f"Agent type { agent .type } is not supported by the server" ,
112
141
)
113
142
143
+ session_id = params .id if params .id else SessionId (gen_id ())
144
+ if params .mode == "record" :
145
+ recorded_event_header = await recording_store .create_recording (
146
+ source_session_id = session_id ,
147
+ )
148
+ response .headers ["X-Flux0-Recording-Id" ] = str (recorded_event_header .recording_id )
149
+
114
150
session = await session_service .create_user_session (
115
- id = params . id ,
151
+ id = session_id ,
116
152
user_id = authedUser .id ,
117
153
agent = agent ,
118
154
title = params .title ,
119
155
allow_greeting = allow_greeting ,
156
+ mode = params .mode ,
120
157
metadata = params .metadata ,
121
158
)
122
159
123
160
return SessionDTO (
124
161
id = session .id ,
125
162
agent_id = session .agent_id ,
126
163
user_id = session .user_id ,
164
+ mode = session .mode ,
127
165
title = session .title ,
128
166
consumption_offsets = ConsumptionOffsetsDTO (client = session .consumption_offsets ["client" ]),
129
167
created_at = session .created_at ,
@@ -169,12 +207,12 @@ async def retrieve_session(
169
207
id = session .id ,
170
208
agent_id = session .agent_id ,
171
209
user_id = session .user_id ,
210
+ mode = session .mode ,
172
211
title = session .title ,
173
212
consumption_offsets = ConsumptionOffsetsDTO (
174
213
client = session .consumption_offsets ["client" ],
175
214
),
176
215
created_at = session .created_at ,
177
- metadata = session .metadata ,
178
216
)
179
217
180
218
return retrieve_session
@@ -218,11 +256,11 @@ async def list_sessions(
218
256
agent_id = s .agent_id ,
219
257
title = s .title ,
220
258
user_id = s .user_id ,
259
+ mode = s .mode ,
221
260
consumption_offsets = ConsumptionOffsetsDTO (
222
261
client = s .consumption_offsets ["client" ],
223
262
),
224
263
created_at = s .created_at ,
225
- metadata = s .metadata ,
226
264
)
227
265
for s in sessions
228
266
]
@@ -239,6 +277,7 @@ async def event_stream(
239
277
session_service : SessionService ,
240
278
event_emitter : EventEmitter ,
241
279
subscription_ready : asyncio .Event ,
280
+ recording_store : RecordingStore ,
242
281
) -> AsyncGenerator [str , None ]:
243
282
queue : asyncio .Queue [Union [ChunkEvent , EmittedEvent ]] = asyncio .Queue ()
244
283
@@ -251,11 +290,33 @@ async def subscriber_final(emitted_event: EmittedEvent) -> None:
251
290
await queue .put (emitted_event )
252
291
253
292
# Subscribe to processed event updates
254
- print ("subscribed to correlation_id" , correlation_id )
293
+ # print("subscribed to correlation_id", correlation_id)
255
294
event_emitter .subscribe_processed (correlation_id , subscriber )
256
295
event_emitter .subscribe_final (correlation_id , subscriber_final )
257
296
subscription_ready .set ()
258
297
298
+ recording = await recording_store .read_header_by_source_session_id (session_id )
299
+ recording_id = recording .recording_id if recording else None
300
+ recording_aborted = False
301
+
302
+ async def maybe_record_emitted (payload : RecordedEmittedPayload , created_at : datetime ) -> None :
303
+ nonlocal recording_aborted
304
+ if not recording_id or recording_aborted :
305
+ return
306
+ try :
307
+ await recording_store .append_emitted (recording_id , payload , created_at = created_at )
308
+ except Exception :
309
+ recording_aborted = True # stop further attempts
310
+
311
+ async def maybe_record_chunk (payload : RecordedChunkPayload , created_at : datetime ) -> None :
312
+ nonlocal recording_aborted
313
+ if not recording_id or recording_aborted :
314
+ return
315
+ try :
316
+ await recording_store .append_chunk (recording_id , payload , created_at = created_at )
317
+ except Exception :
318
+ recording_aborted = True
319
+
259
320
try :
260
321
while True :
261
322
event = await queue .get ()
@@ -267,6 +328,17 @@ async def subscriber_final(emitted_event: EmittedEvent) -> None:
267
328
if event .type == "status" :
268
329
ed = cast (StatusEventData , event .data )
269
330
if ed ["status" ] == "completed" :
331
+ await maybe_record_emitted (
332
+ RecordedEmittedPayload (
333
+ id = event .id ,
334
+ source = event .source ,
335
+ type = "status" ,
336
+ correlation_id = event .correlation_id ,
337
+ data = ed ,
338
+ metadata = event .metadata or {},
339
+ ),
340
+ created_at = datetime .fromtimestamp (event .timestamp , tz = timezone .utc ),
341
+ )
270
342
break
271
343
await session_store .create_event (
272
344
correlation_id = event .correlation_id ,
@@ -276,6 +348,18 @@ async def subscriber_final(emitted_event: EmittedEvent) -> None:
276
348
data = ed ,
277
349
metadata = event .metadata ,
278
350
)
351
+ await maybe_record_emitted (
352
+ RecordedEmittedPayload (
353
+ id = event .id ,
354
+ source = event .source ,
355
+ type = "status" ,
356
+ correlation_id = event .correlation_id ,
357
+ data = ed ,
358
+ metadata = event .metadata or {},
359
+ ),
360
+ created_at = datetime .fromtimestamp (event .timestamp , tz = timezone .utc ),
361
+ )
362
+
279
363
yield f"event: { event_type } \n id: { event .id } \n data: { json .dumps (event .__dict__ )} \n \n "
280
364
elif event .type == "message" :
281
365
event_type = event .type
@@ -306,13 +390,23 @@ async def subscriber_final(emitted_event: EmittedEvent) -> None:
306
390
else :
307
391
# this is a chunk event
308
392
event_type = "chunk"
393
+ await maybe_record_chunk (
394
+ RecordedChunkPayload (
395
+ correlation_id = event .correlation_id ,
396
+ event_id = event .event_id ,
397
+ seq = event .seq ,
398
+ patches = event .patches ,
399
+ metadata = event .metadata or {},
400
+ ),
401
+ created_at = datetime .fromtimestamp (event .timestamp , tz = timezone .utc ),
402
+ )
309
403
yield f"event: { event_type } \n data: { json .dumps (event .__dict__ )} \n \n "
310
404
except asyncio .CancelledError :
311
405
await session_service .cancel_processing_session_task (session_id )
312
406
return
313
407
finally :
314
408
# Unsubscribe when client disconnects
315
- print ("unsubscribed from correlation_id" , correlation_id )
409
+ # print("unsubscribed from correlation_id", correlation_id)
316
410
event_emitter .unsubscribe_processed (correlation_id , subscriber )
317
411
# Explicitly send a termination event before closing
318
412
# yield "event: close\ndata: {}\n\n"
@@ -339,6 +433,8 @@ async def _add_user_message(
339
433
agent_store : AgentStore ,
340
434
session_store : SessionStore ,
341
435
session_service : SessionService ,
436
+ rec_header : Optional [RecordedEvent ],
437
+ recording_store : RecordingStore ,
342
438
# moderation: Moderation = Moderation.NONE,
343
439
subscription_ready : asyncio .Event ,
344
440
) -> EventDTO :
@@ -399,6 +495,19 @@ async def _add_user_message(
399
495
wait_event = subscription_ready ,
400
496
)
401
497
498
+ if rec_header :
499
+ await recording_store .append_emitted (
500
+ recording_id = rec_header .recording_id ,
501
+ payload = RecordedEmittedPayload (
502
+ id = event .id ,
503
+ source = event .source ,
504
+ type = event .type ,
505
+ correlation_id = event .correlation_id ,
506
+ data = event .data ,
507
+ metadata = event .metadata or {},
508
+ ),
509
+ )
510
+
402
511
return event_to_dto (event )
403
512
404
513
@@ -413,6 +522,7 @@ def mount_create_event_and_stream_route(
413
522
SessionStore ,
414
523
UserStore ,
415
524
AgentStore ,
525
+ RecordingStore ,
416
526
EventEmitter ,
417
527
],
418
528
Coroutine [Any , Any , StreamingResponse ],
@@ -474,6 +584,7 @@ async def create_event_and_stream(
474
584
session_store : SessionStore = Depends (get_session_store ),
475
585
user_store : UserStore = Depends (get_user_store ),
476
586
agent_store : AgentStore = Depends (get_agent_store ),
587
+ recording_store : RecordingStore = Depends (get_recording_store ),
477
588
event_emitter : EventEmitter = Depends (get_event_emitter ),
478
589
# moderation: ModerationQuery = Moderation.NONE,
479
590
) -> StreamingResponse :
@@ -487,6 +598,7 @@ async def create_event_and_stream(
487
598
detail = "Only message events can currently be added manually" ,
488
599
)
489
600
601
+ rec_header = await recording_store .read_header_by_source_session_id (session_id )
490
602
subscription_ready = asyncio .Event ()
491
603
if params .source == EventSourceDTO .USER :
492
604
event = await _add_user_message (
@@ -496,9 +608,12 @@ async def create_event_and_stream(
496
608
agent_store ,
497
609
session_store ,
498
610
session_service ,
611
+ rec_header ,
612
+ recording_store ,
499
613
# moderation,
500
614
subscription_ready ,
501
615
)
616
+
502
617
return StreamingResponse (
503
618
event_stream (
504
619
session_id ,
@@ -507,6 +622,7 @@ async def create_event_and_stream(
507
622
session_service ,
508
623
event_emitter ,
509
624
subscription_ready ,
625
+ recording_store ,
510
626
),
511
627
media_type = "text/event-stream" ,
512
628
)
0 commit comments