Skip to content

Commit dc4d14b

Browse files
benjipelletierfacebook-github-bot
authored andcommitted
Add sqlite tracing recorder and sql assertions (#779)
Summary: New tracer subscriber to be used for testing (e.g., script or simulator) 1. New logging layer for use in tests that writes all log messages to a series of sqlite tables 2. Add capability to do sql based assertions for script tests or simulation tests 3. New trace level logging events on actor lifecycle events Next diffs will: * Get this working for our PAFT simulator tests so we can easily assert * Support custom columns Differential Revision: D73512355
1 parent 707f705 commit dc4d14b

File tree

4 files changed

+368
-1
lines changed

4 files changed

+368
-1
lines changed

hyperactor/src/mailbox.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1043,7 +1043,7 @@ impl MailboxSender for MailboxClient {
10431043
return_handle: PortHandle<Undeliverable<MessageEnvelope>>,
10441044
) {
10451045
// tracing::trace!(name = "post", "posting message to {}", envelope.dest);
1046-
tracing::event!(target:"message", tracing::Level::DEBUG, "crc"=envelope.data.crc(), "size"=envelope.data.len(), "sender"= %envelope.sender, "dest" = %envelope.dest.0, "port"= envelope.dest.1, "message_type" = envelope.data.typename().unwrap_or("unknown"), "send_message");
1046+
tracing::event!(target:"messages", tracing::Level::DEBUG, "crc"=envelope.data.crc(), "size"=envelope.data.len(), "sender"= %envelope.sender, "dest" = %envelope.dest.0, "port"= envelope.dest.1, "message_type" = envelope.data.typename().unwrap_or("unknown"), "send_message");
10471047

10481048
if let Err(mpsc::error::SendError((envelope, return_handle))) =
10491049
self.buffer.send((envelope, return_handle))

hyperactor_telemetry/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,11 @@ lazy_static = "1.5"
2020
opentelemetry = "0.29"
2121
opentelemetry_sdk = { version = "0.29.0", features = ["rt-tokio"] }
2222
rand = { version = "0.8", features = ["small_rng"] }
23+
rusqlite = { version = "0.36.0", features = ["backup", "blob", "bundled", "column_decltype", "functions", "limits", "modern_sqlite", "serde_json"] }
2324
scuba = { version = "0.1.0", git = "https://github.com/facebookexperimental/rust-shed.git", branch = "main", optional = true }
2425
serde = { version = "1.0.219", features = ["derive", "rc"] }
2526
serde_json = { version = "1.0.140", features = ["alloc", "float_roundtrip", "unbounded_depth"] }
27+
serde_rusqlite = "0.39.3"
2628
tokio = { version = "1.46.1", features = ["full", "test-util", "tracing"] }
2729
tracing = { version = "0.1.41", features = ["attributes", "valuable"] }
2830
tracing-appender = "0.2.3"

hyperactor_telemetry/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ mod otel;
3333
mod pool;
3434
pub mod recorder;
3535
mod spool;
36+
pub mod sqlite;
3637
use std::io::IsTerminal;
3738
use std::io::Write;
3839
use std::str::FromStr;

hyperactor_telemetry/src/sqlite.rs

Lines changed: 364 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,364 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
use std::collections::HashMap;
10+
use std::sync::Arc;
11+
use std::sync::Mutex;
12+
13+
use anyhow::Result;
14+
use anyhow::anyhow;
15+
use lazy_static::lazy_static;
16+
use rusqlite::Connection;
17+
use rusqlite::functions::FunctionFlags;
18+
use serde::Serialize;
19+
use serde_json::Value as JValue;
20+
use serde_rusqlite::*;
21+
use tracing::Event;
22+
use tracing::Subscriber;
23+
use tracing::level_filters::LevelFilter;
24+
use tracing_subscriber::Layer;
25+
use tracing_subscriber::filter::Targets;
26+
use tracing_subscriber::prelude::*;
27+
28+
pub trait TableDef {
29+
fn name(&self) -> &'static str;
30+
fn columns(&self) -> &'static [&'static str];
31+
fn create_table_stmt(&self) -> String {
32+
let name = self.name();
33+
let columns = self
34+
.columns()
35+
.iter()
36+
.map(|col| format!("{col} TEXT "))
37+
.collect::<Vec<String>>()
38+
.join(",");
39+
format!("create table if not exists {name} (seq INTEGER primary key, {columns})")
40+
}
41+
fn insert_stmt(&self) -> String {
42+
let name = self.name();
43+
let columns = self.columns().join(", ");
44+
let params = self
45+
.columns()
46+
.iter()
47+
.map(|c| format!(":{c}"))
48+
.collect::<Vec<String>>()
49+
.join(", ");
50+
format!("insert into {name} ({columns}) values ({params})")
51+
}
52+
}
53+
54+
impl TableDef for (&'static str, &'static [&'static str]) {
55+
fn name(&self) -> &'static str {
56+
self.0
57+
}
58+
59+
fn columns(&self) -> &'static [&'static str] {
60+
self.1
61+
}
62+
}
63+
64+
#[derive(Clone, Debug)]
65+
pub struct Table {
66+
pub columns: &'static [&'static str],
67+
pub create_table_stmt: String,
68+
pub insert_stmt: String,
69+
}
70+
71+
impl From<(&'static str, &'static [&'static str])> for Table {
72+
fn from(value: (&'static str, &'static [&'static str])) -> Self {
73+
Self {
74+
columns: value.columns(),
75+
create_table_stmt: value.create_table_stmt(),
76+
insert_stmt: value.insert_stmt(),
77+
}
78+
}
79+
}
80+
81+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
82+
pub enum TableName {
83+
ActorLifecycle,
84+
Messages,
85+
Events,
86+
}
87+
88+
impl TableName {
89+
pub const ACTOR_LIFECYCLE_STR: &'static str = "actor_lifecycle";
90+
pub const MESSAGES_STR: &'static str = "messages";
91+
pub const EVENTS_STR: &'static str = "events";
92+
93+
pub fn as_str(&self) -> &'static str {
94+
match self {
95+
TableName::ActorLifecycle => Self::ACTOR_LIFECYCLE_STR,
96+
TableName::Messages => Self::MESSAGES_STR,
97+
TableName::Events => Self::EVENTS_STR,
98+
}
99+
}
100+
101+
pub fn get_table(&self) -> &'static Table {
102+
match self {
103+
TableName::ActorLifecycle => &ACTOR_LIFECYCLE,
104+
TableName::Messages => &MESSAGES,
105+
TableName::Events => &EVENTS,
106+
}
107+
}
108+
}
109+
110+
lazy_static! {
111+
static ref ACTOR_LIFECYCLE: Table = (
112+
TableName::ActorLifecycle.as_str(),
113+
[
114+
"actor_id",
115+
"actor",
116+
"name",
117+
"supervised_actor",
118+
"actor_status",
119+
"module_path",
120+
"line",
121+
"file",
122+
]
123+
.as_slice()
124+
)
125+
.into();
126+
static ref MESSAGES: Table = (
127+
TableName::Messages.as_str(),
128+
[
129+
"span_id",
130+
"time_us",
131+
"src",
132+
"dest",
133+
"payload",
134+
"module_path",
135+
"line",
136+
"file",
137+
]
138+
.as_slice()
139+
)
140+
.into();
141+
static ref EVENTS: Table = (
142+
TableName::Events.as_str(),
143+
[
144+
"span_id",
145+
"time_us",
146+
"name",
147+
"message",
148+
"actor_id",
149+
"level",
150+
"line",
151+
"file",
152+
"module_path",
153+
]
154+
.as_slice()
155+
)
156+
.into();
157+
static ref ALL_TABLES: Vec<Table> =
158+
vec![ACTOR_LIFECYCLE.clone(), MESSAGES.clone(), EVENTS.clone()];
159+
}
160+
161+
pub struct SqliteLayer {
162+
conn: Arc<Mutex<Connection>>,
163+
}
164+
use tracing::field::Visit;
165+
166+
#[derive(Debug, Clone, Default, Serialize)]
167+
struct SqlVisitor(HashMap<String, JValue>);
168+
169+
impl Visit for SqlVisitor {
170+
fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) {
171+
self.0.insert(
172+
field.name().to_string(),
173+
JValue::String(format!("{:?}", value)),
174+
);
175+
}
176+
177+
fn record_str(&mut self, field: &tracing::field::Field, value: &str) {
178+
self.0
179+
.insert(field.name().to_string(), JValue::String(value.to_string()));
180+
}
181+
182+
fn record_i64(&mut self, field: &tracing::field::Field, value: i64) {
183+
self.0
184+
.insert(field.name().to_string(), JValue::Number(value.into()));
185+
}
186+
187+
fn record_f64(&mut self, field: &tracing::field::Field, value: f64) {
188+
let n = serde_json::Number::from_f64(value).unwrap();
189+
self.0.insert(field.name().to_string(), JValue::Number(n));
190+
}
191+
192+
fn record_u64(&mut self, field: &tracing::field::Field, value: u64) {
193+
self.0
194+
.insert(field.name().to_string(), JValue::Number(value.into()));
195+
}
196+
197+
fn record_bool(&mut self, field: &tracing::field::Field, value: bool) {
198+
self.0.insert(field.name().to_string(), JValue::Bool(value));
199+
}
200+
}
201+
202+
macro_rules! insert_event {
203+
($table:expr, $conn:ident, $event:ident) => {
204+
let mut v: SqlVisitor = Default::default();
205+
$event.record(&mut v);
206+
let meta = $event.metadata();
207+
v.0.insert(
208+
"module_path".to_string(),
209+
meta.module_path().map(String::from).into(),
210+
);
211+
v.0.insert("line".to_string(), meta.line().into());
212+
v.0.insert("file".to_string(), meta.file().map(String::from).into());
213+
$conn.prepare_cached(&$table.insert_stmt)?.execute(
214+
serde_rusqlite::to_params_named_with_fields(v, $table.columns)?
215+
.to_slice()
216+
.as_slice(),
217+
)?;
218+
};
219+
}
220+
221+
impl SqliteLayer {
222+
pub fn new() -> Result<Self> {
223+
let conn = Connection::open_in_memory()?;
224+
225+
for table in ALL_TABLES.iter() {
226+
conn.execute(&table.create_table_stmt, [])?;
227+
}
228+
conn.create_scalar_function(
229+
"assert",
230+
2,
231+
FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
232+
move |ctx| {
233+
let condition: bool = ctx.get(0)?;
234+
let message: String = ctx.get(1)?;
235+
236+
if !condition {
237+
return Err(rusqlite::Error::UserFunctionError(
238+
anyhow!("assertion failed:{condition} {message}",).into(),
239+
));
240+
}
241+
242+
Ok(condition)
243+
},
244+
)?;
245+
246+
Ok(Self {
247+
conn: Arc::new(Mutex::new(conn)),
248+
})
249+
}
250+
251+
fn insert_event(&self, event: &Event<'_>) -> Result<()> {
252+
let conn = self.conn.lock().unwrap();
253+
match (event.metadata().target(), event.metadata().name()) {
254+
(TableName::MESSAGES_STR, _) => {
255+
insert_event!(TableName::Messages.get_table(), conn, event);
256+
}
257+
(TableName::ACTOR_LIFECYCLE_STR, _) => {
258+
insert_event!(TableName::ActorLifecycle.get_table(), conn, event);
259+
}
260+
_ => {
261+
insert_event!(TableName::Events.get_table(), conn, event);
262+
}
263+
}
264+
Ok(())
265+
}
266+
267+
pub fn connection(&self) -> Arc<Mutex<Connection>> {
268+
self.conn.clone()
269+
}
270+
}
271+
272+
impl<S: Subscriber> Layer<S> for SqliteLayer {
273+
fn on_event(&self, event: &Event<'_>, _ctx: tracing_subscriber::layer::Context<'_, S>) {
274+
self.insert_event(event).unwrap();
275+
}
276+
}
277+
278+
#[allow(dead_code)]
279+
fn print_table(conn: &Connection, table_name: TableName) -> Result<()> {
280+
let table_name_str = table_name.as_str();
281+
282+
// Get column names
283+
let mut stmt = conn.prepare(&format!("PRAGMA table_info({})", table_name_str))?;
284+
let column_info = stmt.query_map([], |row| {
285+
row.get::<_, String>(1) // Column name is at index 1
286+
})?;
287+
288+
let columns: Vec<String> = column_info.collect::<Result<Vec<_>, _>>()?;
289+
290+
// Print header
291+
println!("=== {} ===", table_name_str.to_uppercase());
292+
println!("{}", columns.join(" | "));
293+
println!("{}", "-".repeat(columns.len() * 10));
294+
295+
// Print rows
296+
let mut stmt = conn.prepare(&format!("SELECT * FROM {}", table_name_str))?;
297+
let rows = stmt.query_map([], |row| {
298+
let mut values = Vec::new();
299+
for (i, column) in columns.iter().enumerate() {
300+
// Handle different column types properly
301+
let value = if i == 0 && *column == "seq" {
302+
// First column is always the INTEGER seq column
303+
match row.get::<_, Option<i64>>(i)? {
304+
Some(v) => v.to_string(),
305+
None => "NULL".to_string(),
306+
}
307+
} else {
308+
// All other columns are TEXT
309+
match row.get::<_, Option<String>>(i)? {
310+
Some(v) => v,
311+
None => "NULL".to_string(),
312+
}
313+
};
314+
values.push(value);
315+
}
316+
Ok(values.join(" | "))
317+
})?;
318+
319+
for row in rows {
320+
println!("{}", row?);
321+
}
322+
println!();
323+
Ok(())
324+
}
325+
326+
pub fn with_tracing_db() -> Arc<Mutex<Connection>> {
327+
let layer = SqliteLayer::new().unwrap();
328+
let conn = layer.connection();
329+
330+
let layer = layer.with_filter(
331+
Targets::new()
332+
.with_default(LevelFilter::TRACE)
333+
.with_targets(vec![
334+
("tokio", LevelFilter::OFF),
335+
("opentelemetry", LevelFilter::OFF),
336+
("runtime", LevelFilter::OFF),
337+
]),
338+
);
339+
tracing_subscriber::registry().with(layer).init();
340+
conn
341+
}
342+
343+
#[cfg(test)]
344+
mod tests {
345+
use tracing::info;
346+
347+
use super::*;
348+
349+
#[test]
350+
fn test_sqlite_layer() -> Result<()> {
351+
let conn = with_tracing_db();
352+
353+
info!(target:"messages", test_field = "test_value", "Test msg");
354+
info!(target:"events", test_field = "test_value", "Test event");
355+
356+
let count: i64 =
357+
conn.lock()
358+
.unwrap()
359+
.query_row("SELECT COUNT(*) FROM messages", [], |row| row.get(0))?;
360+
print_table(&conn.lock().unwrap(), TableName::Events)?;
361+
assert!(count > 0);
362+
Ok(())
363+
}
364+
}

0 commit comments

Comments
 (0)