7
7
*/
8
8
9
9
use std:: collections:: HashMap ;
10
+ use std:: fs;
11
+ use std:: path:: PathBuf ;
10
12
use std:: sync:: Arc ;
11
13
use std:: sync:: Mutex ;
12
14
@@ -20,11 +22,19 @@ use serde_json::Value as JValue;
20
22
use serde_rusqlite:: * ;
21
23
use tracing:: Event ;
22
24
use tracing:: Subscriber ;
23
- use tracing:: level_filters:: LevelFilter ;
24
25
use tracing_subscriber:: Layer ;
25
- use tracing_subscriber:: filter :: Targets ;
26
+ use tracing_subscriber:: Registry ;
26
27
use tracing_subscriber:: prelude:: * ;
28
+ use tracing_subscriber:: reload;
27
29
30
+ pub type SqliteReloadHandle = reload:: Handle < Option < SqliteLayer > , Registry > ;
31
+
32
+ lazy_static ! {
33
+ // Reload handle allows us to include a no-op layer during init, but load
34
+ // the layer dynamically during tests.
35
+ static ref RELOAD_HANDLE : Mutex <Option <SqliteReloadHandle >> =
36
+ Mutex :: new( None ) ;
37
+ }
28
38
pub trait TableDef {
29
39
fn name ( & self ) -> & ' static str ;
30
40
fn columns ( & self ) -> & ' static [ & ' static str ] ;
@@ -221,7 +231,15 @@ macro_rules! insert_event {
221
231
impl SqliteLayer {
222
232
pub fn new ( ) -> Result < Self > {
223
233
let conn = Connection :: open_in_memory ( ) ?;
234
+ Self :: setup_connection ( conn)
235
+ }
236
+
237
+ pub fn new_with_file ( db_path : & str ) -> Result < Self > {
238
+ let conn = Connection :: open ( db_path) ?;
239
+ Self :: setup_connection ( conn)
240
+ }
224
241
242
+ fn setup_connection ( conn : Connection ) -> Result < Self > {
225
243
for table in ALL_TABLES . iter ( ) {
226
244
conn. execute ( & table. create_table_stmt , [ ] ) ?;
227
245
}
@@ -323,21 +341,89 @@ fn print_table(conn: &Connection, table_name: TableName) -> Result<()> {
323
341
Ok ( ( ) )
324
342
}
325
343
326
- pub fn with_tracing_db ( ) -> Arc < Mutex < Connection > > {
327
- let layer = SqliteLayer :: new ( ) . unwrap ( ) ;
328
- let conn = layer. connection ( ) ;
329
-
330
- let layer = layer. with_filter (
331
- Targets :: new ( )
332
- . with_default ( LevelFilter :: TRACE )
333
- . with_targets ( vec ! [
334
- ( "tokio" , LevelFilter :: OFF ) ,
335
- ( "opentelemetry" , LevelFilter :: OFF ) ,
336
- ( "runtime" , LevelFilter :: OFF ) ,
337
- ] ) ,
338
- ) ;
339
- tracing_subscriber:: registry ( ) . with ( layer) . init ( ) ;
340
- conn
344
+ fn init_tracing_subscriber ( layer : SqliteLayer ) {
345
+ let handle = RELOAD_HANDLE . lock ( ) . unwrap ( ) ;
346
+ if let Some ( reload_handle) = handle. as_ref ( ) {
347
+ let _ = reload_handle. reload ( layer) ;
348
+ } else {
349
+ tracing_subscriber:: registry ( ) . with ( layer) . init ( ) ;
350
+ }
351
+ }
352
+
353
+ // === API ===
354
+
355
+ // Creates a new reload handler and no-op layer for initialization
356
+ pub fn get_reloadable_sqlite_layer ( ) -> Result < reload:: Layer < Option < SqliteLayer > , Registry > > {
357
+ let ( layer, reload_handle) = reload:: Layer :: new ( None ) ;
358
+ let mut handle = RELOAD_HANDLE . lock ( ) . unwrap ( ) ;
359
+ * handle = Some ( reload_handle) ;
360
+ Ok ( layer)
361
+ }
362
+
363
+ /// RAII guard for SQLite tracing database
364
+ pub struct SqliteTracing {
365
+ db_path : Option < PathBuf > ,
366
+ connection : Arc < Mutex < Connection > > ,
367
+ }
368
+
369
+ impl SqliteTracing {
370
+ /// Create a new SqliteTracing with a temporary file
371
+ pub fn new ( ) -> Result < Self > {
372
+ let temp_dir = std:: env:: temp_dir ( ) ;
373
+ let file_name = format ! ( "hyperactor_trace_{}.db" , std:: process:: id( ) ) ;
374
+ let db_path = temp_dir. join ( file_name) ;
375
+
376
+ let db_path_str = db_path. to_string_lossy ( ) ;
377
+ let layer = SqliteLayer :: new_with_file ( & db_path_str) ?;
378
+ let connection = layer. connection ( ) ;
379
+
380
+ init_tracing_subscriber ( layer) ;
381
+
382
+ Ok ( Self {
383
+ db_path : Some ( db_path) ,
384
+ connection,
385
+ } )
386
+ }
387
+
388
+ /// Create a new SqliteTracing with in-memory database
389
+ pub fn new_in_memory ( ) -> Result < Self > {
390
+ let layer = SqliteLayer :: new ( ) ?;
391
+ let connection = layer. connection ( ) ;
392
+
393
+ init_tracing_subscriber ( layer) ;
394
+
395
+ Ok ( Self {
396
+ db_path : None ,
397
+ connection,
398
+ } )
399
+ }
400
+
401
+ /// Get the path to the temporary database file (None for in-memory)
402
+ pub fn db_path ( & self ) -> Option < & PathBuf > {
403
+ self . db_path . as_ref ( )
404
+ }
405
+
406
+ /// Get a reference to the database connection
407
+ pub fn connection ( & self ) -> Arc < Mutex < Connection > > {
408
+ self . connection . clone ( )
409
+ }
410
+ }
411
+
412
+ impl Drop for SqliteTracing {
413
+ fn drop ( & mut self ) {
414
+ // Reset the layer to None
415
+ let handle = RELOAD_HANDLE . lock ( ) . unwrap ( ) ;
416
+ if let Some ( reload_handle) = handle. as_ref ( ) {
417
+ let _ = reload_handle. reload ( None ) ;
418
+ }
419
+
420
+ // Delete the temporary file if it exists
421
+ if let Some ( db_path) = & self . db_path {
422
+ if db_path. exists ( ) {
423
+ let _ = fs:: remove_file ( db_path) ;
424
+ }
425
+ }
426
+ }
341
427
}
342
428
343
429
#[ cfg( test) ]
@@ -347,8 +433,9 @@ mod tests {
347
433
use super :: * ;
348
434
349
435
#[ test]
350
- fn test_sqlite_layer ( ) -> Result < ( ) > {
351
- let conn = with_tracing_db ( ) ;
436
+ fn test_sqlite_tracing_with_file ( ) -> Result < ( ) > {
437
+ let tracing = SqliteTracing :: new ( ) ?;
438
+ let conn = tracing. connection ( ) ;
352
439
353
440
info ! ( target: "messages" , test_field = "test_value" , "Test msg" ) ;
354
441
info ! ( target: "events" , test_field = "test_value" , "Test event" ) ;
@@ -359,6 +446,87 @@ mod tests {
359
446
. query_row ( "SELECT COUNT(*) FROM messages" , [ ] , |row| row. get ( 0 ) ) ?;
360
447
print_table ( & conn. lock ( ) . unwrap ( ) , TableName :: Events ) ?;
361
448
assert ! ( count > 0 ) ;
449
+
450
+ // Verify we have a file path
451
+ assert ! ( tracing. db_path( ) . is_some( ) ) ;
452
+ let db_path = tracing. db_path ( ) . unwrap ( ) ;
453
+ assert ! ( db_path. exists( ) ) ;
454
+
455
+ Ok ( ( ) )
456
+ }
457
+
458
+ #[ test]
459
+ fn test_sqlite_tracing_in_memory ( ) -> Result < ( ) > {
460
+ let tracing = SqliteTracing :: new_in_memory ( ) ?;
461
+ let conn = tracing. connection ( ) ;
462
+
463
+ info ! ( target: "messages" , test_field = "test_value" , "Test event in memory" ) ;
464
+
465
+ let count: i64 =
466
+ conn. lock ( )
467
+ . unwrap ( )
468
+ . query_row ( "SELECT COUNT(*) FROM messages" , [ ] , |row| row. get ( 0 ) ) ?;
469
+ print_table ( & conn. lock ( ) . unwrap ( ) , TableName :: Messages ) ?;
470
+ assert ! ( count > 0 ) ;
471
+
472
+ // Verify we don't have a file path for in-memory
473
+ assert ! ( tracing. db_path( ) . is_none( ) ) ;
474
+
475
+ Ok ( ( ) )
476
+ }
477
+
478
+ #[ test]
479
+ fn test_sqlite_tracing_cleanup ( ) -> Result < ( ) > {
480
+ let db_path = {
481
+ let tracing = SqliteTracing :: new ( ) ?;
482
+ let conn = tracing. connection ( ) ;
483
+
484
+ info ! ( target: "events" , test_field = "cleanup_test" , "Test cleanup event" ) ;
485
+
486
+ let count: i64 =
487
+ conn. lock ( )
488
+ . unwrap ( )
489
+ . query_row ( "SELECT COUNT(*) FROM events" , [ ] , |row| row. get ( 0 ) ) ?;
490
+ assert ! ( count > 0 ) ;
491
+
492
+ tracing. db_path ( ) . unwrap ( ) . clone ( )
493
+ } ; // tracing goes out of scope here, triggering Drop
494
+
495
+ // File should be cleaned up after Drop
496
+ assert ! ( !db_path. exists( ) ) ;
497
+
498
+ Ok ( ( ) )
499
+ }
500
+
501
+ #[ test]
502
+ fn test_sqlite_tracing_different_targets ( ) -> Result < ( ) > {
503
+ let tracing = SqliteTracing :: new_in_memory ( ) ?;
504
+ let conn = tracing. connection ( ) ;
505
+
506
+ // Test different event targets
507
+ info ! ( target: "messages" , src = "actor1" , dest = "actor2" , payload = "test_message" , "Message event" ) ;
508
+ info ! ( target: "actor_lifecycle" , actor_id = "123" , actor = "TestActor" , name = "test" , "Lifecycle event" ) ;
509
+ info ! ( target: "events" , test_field = "general_event" , "General event" ) ;
510
+
511
+ // Check that events went to the right tables
512
+ let message_count: i64 =
513
+ conn. lock ( )
514
+ . unwrap ( )
515
+ . query_row ( "SELECT COUNT(*) FROM messages" , [ ] , |row| row. get ( 0 ) ) ?;
516
+ assert_eq ! ( message_count, 1 ) ;
517
+
518
+ let lifecycle_count: i64 =
519
+ conn. lock ( )
520
+ . unwrap ( )
521
+ . query_row ( "SELECT COUNT(*) FROM actor_lifecycle" , [ ] , |row| row. get ( 0 ) ) ?;
522
+ assert_eq ! ( lifecycle_count, 1 ) ;
523
+
524
+ let events_count: i64 =
525
+ conn. lock ( )
526
+ . unwrap ( )
527
+ . query_row ( "SELECT COUNT(*) FROM events" , [ ] , |row| row. get ( 0 ) ) ?;
528
+ assert_eq ! ( events_count, 1 ) ;
529
+
362
530
Ok ( ( ) )
363
531
}
364
532
}
0 commit comments