simplify aggregator #812

Closed
173 changes: 78 additions & 95 deletions hyperactor_mesh/src/logging.rs
@@ -12,13 +12,11 @@ use std::path::Path;
use std::path::PathBuf;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::RwLock;
use std::task::Context as TaskContext;
use std::task::Poll;
use std::time::Duration;
use std::time::SystemTime;

use anyhow::Error;
use anyhow::Result;
use async_trait::async_trait;
use chrono::DateTime;
@@ -50,10 +48,7 @@ use hyperactor_telemetry::log_file_path;
use serde::Deserialize;
use serde::Serialize;
use tokio::io;
use tokio::sync::mpsc;
use tokio::sync::watch;
use tokio::sync::watch::Receiver;
use tokio::task::JoinHandle;

use crate::bootstrap::BOOTSTRAP_LOG_CHANNEL;

@@ -264,6 +259,9 @@ pub enum LogMessage {
/// The log payload as bytes
payload: Serialized,
},

/// Flush the log
Flush {},
}

/// Messages that can be sent to the LogClient locally.
@@ -663,20 +661,15 @@ fn deserialize_message_lines(
handlers = [LogMessage, LogClientMessage],
)]
pub struct LogClientActor {
log_tx: mpsc::Sender<(OutputTarget, String)>,
#[allow(unused)]
aggregator_handle: JoinHandle<Result<(), Error>>,
/// The watch sender for the aggregation window in seconds
aggregate_window_tx: watch::Sender<u64>,
should_aggregate: bool,
// Store aggregators directly in the actor for access in Drop
aggregators: Arc<RwLock<HashMap<OutputTarget, Aggregator>>>,
aggregate_window_sec: Option<u64>,
aggregators: HashMap<OutputTarget, Aggregator>,
last_flush_time: SystemTime,
next_flush_deadline: Option<SystemTime>,
}

impl LogClientActor {
fn print_aggregators(aggregators: &RwLock<HashMap<OutputTarget, Aggregator>>) {
let mut aggregators_guard = aggregators.write().unwrap();
for (output_target, aggregator) in aggregators_guard.iter_mut() {
fn print_aggregators(&mut self) {
for (output_target, aggregator) in self.aggregators.iter_mut() {
if aggregator.is_empty() {
continue;
}
@@ -693,6 +686,14 @@ impl LogClientActor {
aggregator.reset();
}
}

fn print_log_line(hostname: &str, pid: u32, output_target: OutputTarget, line: String) {
let message = format!("[{} {}] {}", hostname, pid, line);
match output_target {
OutputTarget::Stdout => println!("{}", message),
OutputTarget::Stderr => eprintln!("{}", message),
}
}
}

#[async_trait]
@@ -701,114 +702,96 @@ impl Actor for LogClientActor {
type Params = ();

async fn new(_: ()) -> Result<Self, anyhow::Error> {
// Create mpsc channel for log messages
let (log_tx, log_rx) = mpsc::channel::<(OutputTarget, String)>(1000);

// Create a watch channel for the aggregation window
let (aggregate_window_tx, aggregate_window_rx) =
watch::channel(DEFAULT_AGGREGATE_WINDOW_SEC);

// Initialize aggregators
let mut aggregators = HashMap::new();
aggregators.insert(OutputTarget::Stderr, Aggregator::new());
aggregators.insert(OutputTarget::Stdout, Aggregator::new());
let aggregators = Arc::new(RwLock::new(aggregators));

// Clone aggregators for the aggregator task
let aggregators_for_task = Arc::clone(&aggregators);

// Start the log aggregator
let aggregator_handle = tokio::spawn(async move {
start_aggregator(log_rx, aggregate_window_rx, aggregators_for_task).await
});

Ok(Self {
log_tx,
aggregator_handle,
aggregate_window_tx,
should_aggregate: true,
aggregate_window_sec: Some(DEFAULT_AGGREGATE_WINDOW_SEC),
aggregators,
last_flush_time: RealClock.system_time_now(),
next_flush_deadline: None,
})
}
}

impl Drop for LogClientActor {
fn drop(&mut self) {
// Flush the remaining logs before shutting down
Self::print_aggregators(&self.aggregators);
self.print_aggregators();
}
}

async fn start_aggregator(
mut log_rx: mpsc::Receiver<(OutputTarget, String)>,
mut interval_sec_rx: watch::Receiver<u64>,
aggregators: Arc<RwLock<HashMap<OutputTarget, Aggregator>>>,
) -> anyhow::Result<()> {
let mut interval =
tokio::time::interval(tokio::time::Duration::from_secs(*interval_sec_rx.borrow()));

// Start the event loop
loop {
tokio::select! {
// Process incoming log messages
Some((output_target, log_line)) = log_rx.recv() => {
let mut aggregators_guard = aggregators.write().unwrap();
if let Some(aggregator) = aggregators_guard.get_mut(&output_target) {
if let Err(e) = aggregator.add_line(&log_line) {
tracing::error!("error adding log line: {}", e);
}
} else {
tracing::error!("unknown output target: {:?}", output_target);
}
}
// Watch for changes in the interval
Ok(_) = interval_sec_rx.changed() => {
interval = tokio::time::interval(tokio::time::Duration::from_secs(*interval_sec_rx.borrow()));
}

// Every interval tick, print and reset the aggregator
_ = interval.tick() => {
LogClientActor::print_aggregators(&aggregators);
}

// Exit if the channel is closed
else => {
tracing::error!("log channel closed, exiting aggregator");
// Print final aggregated logs before shutting down
LogClientActor::print_aggregators(&aggregators);
break;
}
}
}

Ok(())
}

#[async_trait]
#[hyperactor::forward(LogMessage)]
impl LogMessageHandler for LogClientActor {
async fn log(
&mut self,
_cx: &Context<Self>,
cx: &Context<Self>,
hostname: String,
pid: u32,
output_target: OutputTarget,
payload: Serialized,
) -> Result<(), anyhow::Error> {
// Deserialize the message and process line by line with UTF-8
let message_lines = deserialize_message_lines(&payload)?;
let hostname = hostname.as_str();

for line in message_lines {
if self.should_aggregate {
self.log_tx.send((output_target, line)).await?;
} else {
let message = format!("[{} {}] {}", hostname, pid, line);
match output_target {
OutputTarget::Stdout => println!("{}", message),
OutputTarget::Stderr => eprintln!("{}", message),
match self.aggregate_window_sec {
None => {
for line in message_lines {
Self::print_log_line(hostname, pid, output_target, line);
}
self.last_flush_time = RealClock.system_time_now();
}
Some(window) => {
for line in message_lines {
if let Some(aggregator) = self.aggregators.get_mut(&output_target) {
if let Err(e) = aggregator.add_line(&line) {
tracing::error!("error adding log line: {}", e);
// Fall back to printing the line directly so it is not lost.
Self::print_log_line(hostname, pid, output_target, line);
}
} else {
tracing::error!("unknown output target: {:?}", output_target);
// Fall back to printing the line directly so it is not lost.
Self::print_log_line(hostname, pid, output_target, line);
}
}

let new_deadline = self.last_flush_time + Duration::from_secs(window);
let now = RealClock.system_time_now();
if new_deadline <= now {
self.flush(cx).await?;
} else {
let delay = new_deadline.duration_since(now)?;
match self.next_flush_deadline {
None => {
self.next_flush_deadline = Some(new_deadline);
cx.self_message_with_delay(LogMessage::Flush {}, delay)?;
}
Some(deadline) => {
// Some earlier log lines have already scheduled a flush.
if new_deadline < deadline {
// This can happen if the user has adjusted the aggregation window.
self.next_flush_deadline = Some(new_deadline);
cx.self_message_with_delay(LogMessage::Flush {}, delay)?;
}
}
}
}
}
}

Ok(())
}

async fn flush(&mut self, _cx: &Context<Self>) -> Result<(), anyhow::Error> {
self.print_aggregators();
self.last_flush_time = RealClock.system_time_now();
self.next_flush_deadline = None;

Ok(())
}
}
@@ -821,11 +804,11 @@ impl LogClientMessageHandler for LogClientActor {
_cx: &Context<Self>,
aggregate_window_sec: Option<u64>,
) -> Result<(), anyhow::Error> {
if let Some(window) = aggregate_window_sec {
// Send the new value through the watch channel
self.aggregate_window_tx.send(window)?;
if self.aggregate_window_sec.is_some() && aggregate_window_sec.is_none() {
// Make sure we flush whatever is in the aggregators before disabling aggregation.
self.print_aggregators();
}
self.should_aggregate = aggregate_window_sec.is_some();
self.aggregate_window_sec = aggregate_window_sec;
Ok(())
}
}
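
The core of this change replaces the spawned `start_aggregator` task (and its `mpsc`/`watch` channels) with delayed self-messages: each `log` call computes the next flush deadline from `last_flush_time` plus the aggregation window, and schedules a `Flush {}` message to itself only when no earlier flush is already pending. Below is a minimal, self-contained sketch of that debounce pattern; the `FlushState` struct, the `Vec<String>` buffer, and the returned `Option<Duration>` (standing in for `cx.self_message_with_delay`) are illustrative assumptions, not the crate's actual API.

```rust
use std::time::{Duration, SystemTime};

// Sketch only: a Vec stands in for the per-target Aggregator, and the
// returned Option<Duration> models cx.self_message_with_delay(Flush, delay).
struct FlushState {
    buffer: Vec<String>,
    last_flush_time: SystemTime,
    next_flush_deadline: Option<SystemTime>,
    window: Duration,
}

impl FlushState {
    /// Buffer a line; return Some(delay) when a new flush should be scheduled.
    fn add_line(&mut self, line: String) -> Option<Duration> {
        self.buffer.push(line);
        let deadline = self.last_flush_time + self.window;
        let now = SystemTime::now();
        if deadline <= now {
            // The window has already elapsed: flush immediately.
            self.flush();
            return None;
        }
        let delay = deadline.duration_since(now).ok()?;
        match self.next_flush_deadline {
            // Nothing scheduled yet: schedule a flush at the deadline.
            None => {
                self.next_flush_deadline = Some(deadline);
                Some(delay)
            }
            // A flush is already pending; reschedule only if the new
            // deadline is earlier (e.g. the window was just shrunk).
            Some(pending) if deadline < pending => {
                self.next_flush_deadline = Some(deadline);
                Some(delay)
            }
            Some(_) => None,
        }
    }

    fn flush(&mut self) {
        for line in self.buffer.drain(..) {
            println!("{}", line);
        }
        self.last_flush_time = SystemTime::now();
        self.next_flush_deadline = None;
    }
}

fn main() {
    let mut state = FlushState {
        buffer: Vec::new(),
        last_flush_time: SystemTime::now(),
        next_flush_deadline: None,
        window: Duration::from_secs(5),
    };
    if let Some(delay) = state.add_line("hello".to_string()) {
        println!("would schedule a flush in {:?}", delay);
    }
}
```

One consequence worth noting: because `aggregators` is now a plain `HashMap` owned by the actor rather than an `Arc<RwLock<...>>` shared with a background task, flushes are serialized through the actor's mailbox along with incoming log messages, and `Drop` can flush without taking a lock.
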
7 changes: 3 additions & 4 deletions python/tests/python_actor_test_binary.py
@@ -8,7 +8,6 @@

import asyncio
import logging
import sys

import click

@@ -27,12 +26,12 @@ def __init__(self) -> None:
@endpoint
async def print(self, content: str) -> None:
print(f"{content}", flush=True)
sys.stdout.flush()
sys.stderr.flush()


async def _flush_logs() -> None:
pm = await proc_mesh(gpus=2)
# Create a lot of processes to stress test the logging
pm = await proc_mesh(gpus=32)

# use a window large enough that logs are effectively never flushed during the test
await pm.logging_option(aggregate_window_sec=1000)
am = await pm.spawn("printer", Printer)