Skip to content

Commit ed913d6

Browse files
benjipelletier and facebook-github-bot
authored and committed
Add tracing sqlite db file support so we can connect from python (#807)
Summary: POC of a Rust-created sqlite db accessible from a native python sqlite connection. This allows us to write python tests that test Monarch and assert user (or BE) events using sql queries. * Uses a reloadable layer to inject SqliteLayer into the tracing registry on demand. * Exposes `with_tracing_db_file` to python to create and get the DB file name so we can connect to it. Differential Revision: D79761474
1 parent dc4d14b commit ed913d6

File tree

4 files changed

+310
-25
lines changed

4 files changed

+310
-25
lines changed

hyperactor_telemetry/src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ use tracing_subscriber::fmt::format::Writer;
6464
use tracing_subscriber::registry::LookupSpan;
6565

6666
use crate::recorder::Recorder;
67+
use crate::sqlite::get_reloadable_sqlite_layer;
6768

6869
pub trait TelemetryClock {
6970
fn now(&self) -> tokio::time::Instant;
@@ -563,6 +564,8 @@ pub fn initialize_logging_with_log_prefix(
563564
.with_target("opentelemetry", LevelFilter::OFF), // otel has some log span under debug that we don't care about
564565
);
565566

567+
let sqlite_layer = get_reloadable_sqlite_layer().unwrap();
568+
566569
use tracing_subscriber::Registry;
567570
use tracing_subscriber::layer::SubscriberExt;
568571
use tracing_subscriber::util::SubscriberInitExt;
@@ -574,6 +577,7 @@ pub fn initialize_logging_with_log_prefix(
574577
std::env::var(env_var).unwrap_or_default() != "1"
575578
}
576579
if let Err(err) = Registry::default()
580+
.with(sqlite_layer)
577581
.with(if is_layer_enabled(DISABLE_OTEL_TRACING) {
578582
Some(otel::tracing_layer())
579583
} else {

hyperactor_telemetry/src/sqlite.rs

Lines changed: 187 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
*/
88

99
use std::collections::HashMap;
10+
use std::fs;
11+
use std::path::PathBuf;
1012
use std::sync::Arc;
1113
use std::sync::Mutex;
1214

@@ -20,11 +22,19 @@ use serde_json::Value as JValue;
2022
use serde_rusqlite::*;
2123
use tracing::Event;
2224
use tracing::Subscriber;
23-
use tracing::level_filters::LevelFilter;
2425
use tracing_subscriber::Layer;
25-
use tracing_subscriber::filter::Targets;
26+
use tracing_subscriber::Registry;
2627
use tracing_subscriber::prelude::*;
28+
use tracing_subscriber::reload;
2729

30+
pub type SqliteReloadHandle = reload::Handle<Option<SqliteLayer>, Registry>;
31+
32+
lazy_static! {
33+
// Reload handle allows us to include a no-op layer during init, but load
34+
// the layer dynamically during tests.
35+
static ref RELOAD_HANDLE: Mutex<Option<SqliteReloadHandle>> =
36+
Mutex::new(None);
37+
}
2838
pub trait TableDef {
2939
fn name(&self) -> &'static str;
3040
fn columns(&self) -> &'static [&'static str];
@@ -221,7 +231,15 @@ macro_rules! insert_event {
221231
impl SqliteLayer {
222232
pub fn new() -> Result<Self> {
223233
let conn = Connection::open_in_memory()?;
234+
Self::setup_connection(conn)
235+
}
236+
237+
pub fn new_with_file(db_path: &str) -> Result<Self> {
238+
let conn = Connection::open(db_path)?;
239+
Self::setup_connection(conn)
240+
}
224241

242+
fn setup_connection(conn: Connection) -> Result<Self> {
225243
for table in ALL_TABLES.iter() {
226244
conn.execute(&table.create_table_stmt, [])?;
227245
}
@@ -323,21 +341,89 @@ fn print_table(conn: &Connection, table_name: TableName) -> Result<()> {
323341
Ok(())
324342
}
325343

326-
pub fn with_tracing_db() -> Arc<Mutex<Connection>> {
327-
let layer = SqliteLayer::new().unwrap();
328-
let conn = layer.connection();
329-
330-
let layer = layer.with_filter(
331-
Targets::new()
332-
.with_default(LevelFilter::TRACE)
333-
.with_targets(vec![
334-
("tokio", LevelFilter::OFF),
335-
("opentelemetry", LevelFilter::OFF),
336-
("runtime", LevelFilter::OFF),
337-
]),
338-
);
339-
tracing_subscriber::registry().with(layer).init();
340-
conn
344+
fn init_tracing_subscriber(layer: SqliteLayer) {
345+
let handle = RELOAD_HANDLE.lock().unwrap();
346+
if let Some(reload_handle) = handle.as_ref() {
347+
let _ = reload_handle.reload(layer);
348+
} else {
349+
tracing_subscriber::registry().with(layer).init();
350+
}
351+
}
352+
353+
// === API ===
354+
355+
// Creates a new reload handler and no-op layer for initialization
356+
pub fn get_reloadable_sqlite_layer() -> Result<reload::Layer<Option<SqliteLayer>, Registry>> {
357+
let (layer, reload_handle) = reload::Layer::new(None);
358+
let mut handle = RELOAD_HANDLE.lock().unwrap();
359+
*handle = Some(reload_handle);
360+
Ok(layer)
361+
}
362+
363+
/// RAII guard for SQLite tracing database
364+
pub struct SqliteTracing {
365+
db_path: Option<PathBuf>,
366+
connection: Arc<Mutex<Connection>>,
367+
}
368+
369+
impl SqliteTracing {
370+
/// Create a new SqliteTracing with a temporary file
371+
pub fn new() -> Result<Self> {
372+
let temp_dir = std::env::temp_dir();
373+
let file_name = format!("hyperactor_trace_{}.db", std::process::id());
374+
let db_path = temp_dir.join(file_name);
375+
376+
let db_path_str = db_path.to_string_lossy();
377+
let layer = SqliteLayer::new_with_file(&db_path_str)?;
378+
let connection = layer.connection();
379+
380+
init_tracing_subscriber(layer);
381+
382+
Ok(Self {
383+
db_path: Some(db_path),
384+
connection,
385+
})
386+
}
387+
388+
/// Create a new SqliteTracing with in-memory database
389+
pub fn new_in_memory() -> Result<Self> {
390+
let layer = SqliteLayer::new()?;
391+
let connection = layer.connection();
392+
393+
init_tracing_subscriber(layer);
394+
395+
Ok(Self {
396+
db_path: None,
397+
connection,
398+
})
399+
}
400+
401+
/// Get the path to the temporary database file (None for in-memory)
402+
pub fn db_path(&self) -> Option<&PathBuf> {
403+
self.db_path.as_ref()
404+
}
405+
406+
/// Get a reference to the database connection
407+
pub fn connection(&self) -> Arc<Mutex<Connection>> {
408+
self.connection.clone()
409+
}
410+
}
411+
412+
impl Drop for SqliteTracing {
413+
fn drop(&mut self) {
414+
// Reset the layer to None
415+
let handle = RELOAD_HANDLE.lock().unwrap();
416+
if let Some(reload_handle) = handle.as_ref() {
417+
let _ = reload_handle.reload(None);
418+
}
419+
420+
// Delete the temporary file if it exists
421+
if let Some(db_path) = &self.db_path {
422+
if db_path.exists() {
423+
let _ = fs::remove_file(db_path);
424+
}
425+
}
426+
}
341427
}
342428

343429
#[cfg(test)]
@@ -347,8 +433,9 @@ mod tests {
347433
use super::*;
348434

349435
#[test]
350-
fn test_sqlite_layer() -> Result<()> {
351-
let conn = with_tracing_db();
436+
fn test_sqlite_tracing_with_file() -> Result<()> {
437+
let tracing = SqliteTracing::new()?;
438+
let conn = tracing.connection();
352439

353440
info!(target:"messages", test_field = "test_value", "Test msg");
354441
info!(target:"events", test_field = "test_value", "Test event");
@@ -359,6 +446,87 @@ mod tests {
359446
.query_row("SELECT COUNT(*) FROM messages", [], |row| row.get(0))?;
360447
print_table(&conn.lock().unwrap(), TableName::Events)?;
361448
assert!(count > 0);
449+
450+
// Verify we have a file path
451+
assert!(tracing.db_path().is_some());
452+
let db_path = tracing.db_path().unwrap();
453+
assert!(db_path.exists());
454+
455+
Ok(())
456+
}
457+
458+
#[test]
459+
fn test_sqlite_tracing_in_memory() -> Result<()> {
460+
let tracing = SqliteTracing::new_in_memory()?;
461+
let conn = tracing.connection();
462+
463+
info!(target:"messages", test_field = "test_value", "Test event in memory");
464+
465+
let count: i64 =
466+
conn.lock()
467+
.unwrap()
468+
.query_row("SELECT COUNT(*) FROM messages", [], |row| row.get(0))?;
469+
print_table(&conn.lock().unwrap(), TableName::Messages)?;
470+
assert!(count > 0);
471+
472+
// Verify we don't have a file path for in-memory
473+
assert!(tracing.db_path().is_none());
474+
475+
Ok(())
476+
}
477+
478+
#[test]
479+
fn test_sqlite_tracing_cleanup() -> Result<()> {
480+
let db_path = {
481+
let tracing = SqliteTracing::new()?;
482+
let conn = tracing.connection();
483+
484+
info!(target:"events", test_field = "cleanup_test", "Test cleanup event");
485+
486+
let count: i64 =
487+
conn.lock()
488+
.unwrap()
489+
.query_row("SELECT COUNT(*) FROM events", [], |row| row.get(0))?;
490+
assert!(count > 0);
491+
492+
tracing.db_path().unwrap().clone()
493+
}; // tracing goes out of scope here, triggering Drop
494+
495+
// File should be cleaned up after Drop
496+
assert!(!db_path.exists());
497+
498+
Ok(())
499+
}
500+
501+
#[test]
502+
fn test_sqlite_tracing_different_targets() -> Result<()> {
503+
let tracing = SqliteTracing::new_in_memory()?;
504+
let conn = tracing.connection();
505+
506+
// Test different event targets
507+
info!(target:"messages", src = "actor1", dest = "actor2", payload = "test_message", "Message event");
508+
info!(target:"actor_lifecycle", actor_id = "123", actor = "TestActor", name = "test", "Lifecycle event");
509+
info!(target:"events", test_field = "general_event", "General event");
510+
511+
// Check that events went to the right tables
512+
let message_count: i64 =
513+
conn.lock()
514+
.unwrap()
515+
.query_row("SELECT COUNT(*) FROM messages", [], |row| row.get(0))?;
516+
assert_eq!(message_count, 1);
517+
518+
let lifecycle_count: i64 =
519+
conn.lock()
520+
.unwrap()
521+
.query_row("SELECT COUNT(*) FROM actor_lifecycle", [], |row| row.get(0))?;
522+
assert_eq!(lifecycle_count, 1);
523+
524+
let events_count: i64 =
525+
conn.lock()
526+
.unwrap()
527+
.query_row("SELECT COUNT(*) FROM events", [], |row| row.get(0))?;
528+
assert_eq!(events_count, 1);
529+
362530
Ok(())
363531
}
364532
}

monarch_hyperactor/src/telemetry.rs

Lines changed: 63 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use std::cell::Cell;
1313
use hyperactor::clock::ClockKind;
1414
use hyperactor::clock::RealClock;
1515
use hyperactor::clock::SimClock;
16+
use hyperactor_telemetry::sqlite::SqliteTracing;
1617
use hyperactor_telemetry::swap_telemetry_clock;
1718
use opentelemetry::global;
1819
use opentelemetry::metrics;
@@ -65,7 +66,6 @@ pub fn forward_to_tracing(py: Python, record: PyObject) -> PyResult<()> {
6566
let file = record.getattr(py, "filename")?;
6667
let file: &str = file.extract(py)?;
6768
let level: i32 = record.getattr(py, "levelno")?.extract(py)?;
68-
6969
// Map level number to level name
7070
match level {
7171
40 | 50 => {
@@ -82,6 +82,7 @@ pub fn forward_to_tracing(py: Python, record: PyObject) -> PyResult<()> {
8282
match traceback {
8383
Some(traceback) => {
8484
tracing::error!(
85+
target:"events",
8586
file = file,
8687
lineno = lineno,
8788
stacktrace = traceback,
@@ -93,10 +94,10 @@ pub fn forward_to_tracing(py: Python, record: PyObject) -> PyResult<()> {
9394
}
9495
}
9596
}
96-
30 => tracing::warn!(file = file, lineno = lineno, message),
97-
20 => tracing::info!(file = file, lineno = lineno, message),
98-
10 => tracing::debug!(file = file, lineno = lineno, message),
99-
_ => tracing::info!(file = file, lineno = lineno, message),
97+
30 => tracing::warn!(target:"events", file = file, lineno = lineno, message),
98+
20 => tracing::info!(target:"events", file = file, lineno = lineno, message),
99+
10 => tracing::debug!(target:"events", file = file, lineno = lineno, message),
100+
_ => tracing::info!(target:"events", file = file, lineno = lineno, message),
100101
}
101102
Ok(())
102103
}
@@ -215,6 +216,62 @@ impl PySpan {
215216
}
216217
}
217218

219+
#[pyclass(
220+
subclass,
221+
module = "monarch._rust_bindings.monarch_hyperactor.telemetry"
222+
)]
223+
struct PySqliteTracing {
224+
guard: Option<SqliteTracing>,
225+
}
226+
227+
#[pymethods]
228+
impl PySqliteTracing {
229+
#[new]
230+
#[pyo3(signature = (in_memory = false))]
231+
fn new(in_memory: bool) -> PyResult<Self> {
232+
let guard = if in_memory {
233+
SqliteTracing::new_in_memory()
234+
} else {
235+
SqliteTracing::new()
236+
};
237+
238+
match guard {
239+
Ok(guard) => Ok(Self { guard: Some(guard) }),
240+
Err(e) => Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
241+
"Failed to create SQLite tracing guard: {}",
242+
e
243+
))),
244+
}
245+
}
246+
247+
fn db_path(&self) -> PyResult<Option<String>> {
248+
match &self.guard {
249+
Some(guard) => Ok(guard.db_path().map(|p| p.to_string_lossy().to_string())),
250+
None => Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
251+
"Guard has been closed",
252+
)),
253+
}
254+
}
255+
256+
fn __enter__(slf: PyRefMut<'_, Self>) -> PyResult<PyRefMut<'_, Self>> {
257+
Ok(slf)
258+
}
259+
260+
fn __exit__(
261+
&mut self,
262+
_exc_type: Option<PyObject>,
263+
_exc_value: Option<PyObject>,
264+
_traceback: Option<PyObject>,
265+
) -> PyResult<bool> {
266+
self.guard = None;
267+
Ok(false) // Don't suppress exceptions
268+
}
269+
270+
fn close(&mut self) {
271+
self.guard = None;
272+
}
273+
}
274+
218275
use pyo3::Bound;
219276
use pyo3::types::PyModule;
220277

@@ -267,5 +324,6 @@ pub fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult<()> {
267324
module.add_class::<PyCounter>()?;
268325
module.add_class::<PyHistogram>()?;
269326
module.add_class::<PyUpDownCounter>()?;
327+
module.add_class::<PySqliteTracing>()?;
270328
Ok(())
271329
}

0 commit comments

Comments
 (0)