Commit 75ef603

James Sun authored and facebook-github-bot committed
simplify aggregator (meta-pytorch#812)
Summary: no async, no buffer; whenever a log line is received, just aggregate it directly.

Reviewed By: pablorfb-meta

Differential Revision: D79933065
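The sketch below is a minimal, standalone illustration of the approach described in the summary, using plain Rust types as stand-ins for the actual hyperactor_mesh actor, OutputTarget, and Aggregator (the names and structure here are assumptions for illustration, not the real API): each incoming line is aggregated synchronously as it arrives, and a flush happens once the aggregation window since the last flush has elapsed.

// Minimal sketch of "no async, no buffer": aggregate each line directly on
// receipt, flush when the window since the last flush has elapsed.
use std::collections::HashMap;
use std::time::{Duration, SystemTime};

#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
enum OutputTarget {
    Stdout,
    Stderr,
}

// Stand-in for the real Aggregator, which summarizes similar lines;
// here we only collect them for illustration.
#[derive(Default)]
struct Aggregator {
    lines: Vec<String>,
}

struct LogClient {
    aggregate_window_sec: Option<u64>,
    aggregators: HashMap<OutputTarget, Aggregator>,
    last_flush_time: SystemTime,
}

impl LogClient {
    fn new(window_sec: Option<u64>) -> Self {
        Self {
            aggregate_window_sec: window_sec,
            aggregators: HashMap::new(),
            last_flush_time: SystemTime::now(),
        }
    }

    // Called for every received log line; aggregates directly, no channel buffer.
    fn log(&mut self, target: OutputTarget, line: &str) {
        match self.aggregate_window_sec {
            // Aggregation disabled: print immediately.
            None => println!("[{:?}] {}", target, line),
            // Aggregation enabled: record the line, flush once the window elapses.
            Some(window) => {
                self.aggregators
                    .entry(target)
                    .or_default()
                    .lines
                    .push(line.to_string());
                if self.last_flush_time + Duration::from_secs(window) <= SystemTime::now() {
                    self.flush();
                }
            }
        }
    }

    fn flush(&mut self) {
        for (target, agg) in self.aggregators.iter_mut() {
            if !agg.lines.is_empty() {
                println!("[{:?}] {} aggregated line(s)", target, agg.lines.len());
                agg.lines.clear();
            }
        }
        self.last_flush_time = SystemTime::now();
    }
}

fn main() {
    let mut client = LogClient::new(Some(1));
    client.log(OutputTarget::Stdout, "hello from rank 0");
    client.log(OutputTarget::Stderr, "hello from rank 1");
    client.flush(); // final flush, mirroring Drop for LogClientActor
}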
1 parent ca2e7b9 commit 75ef603

3 files changed: +267 −113 lines changed

hyperactor_mesh/src/logging.rs

Lines changed: 78 additions & 95 deletions
@@ -12,13 +12,11 @@ use std::path::Path;
 use std::path::PathBuf;
 use std::pin::Pin;
 use std::sync::Arc;
-use std::sync::RwLock;
 use std::task::Context as TaskContext;
 use std::task::Poll;
 use std::time::Duration;
 use std::time::SystemTime;
 
-use anyhow::Error;
 use anyhow::Result;
 use async_trait::async_trait;
 use chrono::DateTime;
@@ -50,10 +48,7 @@ use hyperactor_telemetry::log_file_path;
 use serde::Deserialize;
 use serde::Serialize;
 use tokio::io;
-use tokio::sync::mpsc;
-use tokio::sync::watch;
 use tokio::sync::watch::Receiver;
-use tokio::task::JoinHandle;
 
 use crate::bootstrap::BOOTSTRAP_LOG_CHANNEL;
 
@@ -264,6 +259,9 @@ pub enum LogMessage {
         /// The log payload as bytes
         payload: Serialized,
     },
+
+    /// Flush the log
+    Flush {},
 }
 
 /// Messages that can be sent to the LogClient locally.
@@ -663,20 +661,15 @@ fn deserialize_message_lines(
     handlers = [LogMessage, LogClientMessage],
 )]
 pub struct LogClientActor {
-    log_tx: mpsc::Sender<(OutputTarget, String)>,
-    #[allow(unused)]
-    aggregator_handle: JoinHandle<Result<(), Error>>,
-    /// The watch sender for the aggregation window in seconds
-    aggregate_window_tx: watch::Sender<u64>,
-    should_aggregate: bool,
-    // Store aggregators directly in the actor for access in Drop
-    aggregators: Arc<RwLock<HashMap<OutputTarget, Aggregator>>>,
+    aggregate_window_sec: Option<u64>,
+    aggregators: HashMap<OutputTarget, Aggregator>,
+    last_flush_time: SystemTime,
+    next_flush_deadline: Option<SystemTime>,
 }
 
 impl LogClientActor {
-    fn print_aggregators(aggregators: &RwLock<HashMap<OutputTarget, Aggregator>>) {
-        let mut aggregators_guard = aggregators.write().unwrap();
-        for (output_target, aggregator) in aggregators_guard.iter_mut() {
+    fn print_aggregators(&mut self) {
+        for (output_target, aggregator) in self.aggregators.iter_mut() {
             if aggregator.is_empty() {
                 continue;
             }
@@ -693,6 +686,14 @@ impl LogClientActor {
             aggregator.reset();
         }
     }
+
+    fn print_log_line(hostname: &str, pid: u32, output_target: OutputTarget, line: String) {
+        let message = format!("[{} {}] {}", hostname, pid, line);
+        match output_target {
+            OutputTarget::Stdout => println!("{}", message),
+            OutputTarget::Stderr => eprintln!("{}", message),
+        }
+    }
 }
 
 #[async_trait]
@@ -701,114 +702,96 @@ impl Actor for LogClientActor {
     type Params = ();
 
     async fn new(_: ()) -> Result<Self, anyhow::Error> {
-        // Create mpsc channel for log messages
-        let (log_tx, log_rx) = mpsc::channel::<(OutputTarget, String)>(1000);
-
-        // Create a watch channel for the aggregation window
-        let (aggregate_window_tx, aggregate_window_rx) =
-            watch::channel(DEFAULT_AGGREGATE_WINDOW_SEC);
-
         // Initialize aggregators
         let mut aggregators = HashMap::new();
        aggregators.insert(OutputTarget::Stderr, Aggregator::new());
         aggregators.insert(OutputTarget::Stdout, Aggregator::new());
-        let aggregators = Arc::new(RwLock::new(aggregators));
-
-        // Clone aggregators for the aggregator task
-        let aggregators_for_task = Arc::clone(&aggregators);
-
-        // Start the loggregator
-        let aggregator_handle = tokio::spawn(async move {
-            start_aggregator(log_rx, aggregate_window_rx, aggregators_for_task).await
-        });
 
         Ok(Self {
-            log_tx,
-            aggregator_handle,
-            aggregate_window_tx,
-            should_aggregate: true,
+            aggregate_window_sec: Some(DEFAULT_AGGREGATE_WINDOW_SEC),
             aggregators,
+            last_flush_time: RealClock.system_time_now(),
+            next_flush_deadline: None,
         })
     }
 }
 
 impl Drop for LogClientActor {
     fn drop(&mut self) {
         // Flush the remaining logs before shutting down
-        Self::print_aggregators(&self.aggregators);
+        self.print_aggregators();
     }
 }
 
-async fn start_aggregator(
-    mut log_rx: mpsc::Receiver<(OutputTarget, String)>,
-    mut interval_sec_rx: watch::Receiver<u64>,
-    aggregators: Arc<RwLock<HashMap<OutputTarget, Aggregator>>>,
-) -> anyhow::Result<()> {
-    let mut interval =
-        tokio::time::interval(tokio::time::Duration::from_secs(*interval_sec_rx.borrow()));
-
-    // Start the event loop
-    loop {
-        tokio::select! {
-            // Process incoming log messages
-            Some((output_target, log_line)) = log_rx.recv() => {
-                let mut aggregators_guard = aggregators.write().unwrap();
-                if let Some(aggregator) = aggregators_guard.get_mut(&output_target) {
-                    if let Err(e) = aggregator.add_line(&log_line) {
-                        tracing::error!("error adding log line: {}", e);
-                    }
-                } else {
-                    tracing::error!("unknown output target: {:?}", output_target);
-                }
-            }
-            // Watch for changes in the interval
-            Ok(_) = interval_sec_rx.changed() => {
-                interval = tokio::time::interval(tokio::time::Duration::from_secs(*interval_sec_rx.borrow()));
-            }
-
-            // Every interval tick, print and reset the aggregator
-            _ = interval.tick() => {
-                LogClientActor::print_aggregators(&aggregators);
-            }
-
-            // Exit if the channel is closed
-            else => {
-                tracing::error!("log channel closed, exiting aggregator");
-                // Print final aggregated logs before shutting down
-                LogClientActor::print_aggregators(&aggregators);
-                break;
-            }
-        }
-    }
-
-    Ok(())
-}
-
 #[async_trait]
 #[hyperactor::forward(LogMessage)]
 impl LogMessageHandler for LogClientActor {
     async fn log(
         &mut self,
-        _cx: &Context<Self>,
+        cx: &Context<Self>,
         hostname: String,
         pid: u32,
         output_target: OutputTarget,
         payload: Serialized,
     ) -> Result<(), anyhow::Error> {
         // Deserialize the message and process line by line with UTF-8
         let message_lines = deserialize_message_lines(&payload)?;
+        let hostname = hostname.as_str();
 
-        for line in message_lines {
-            if self.should_aggregate {
-                self.log_tx.send((output_target, line)).await?;
-            } else {
-                let message = format!("[{} {}] {}", hostname, pid, line);
-                match output_target {
-                    OutputTarget::Stdout => println!("{}", message),
-                    OutputTarget::Stderr => eprintln!("{}", message),
+        match self.aggregate_window_sec {
+            None => {
+                for line in message_lines {
+                    Self::print_log_line(hostname, pid, output_target, line);
+                }
+                self.last_flush_time = RealClock.system_time_now();
+            }
+            Some(window) => {
+                for line in message_lines {
+                    if let Some(aggregator) = self.aggregators.get_mut(&output_target) {
+                        if let Err(e) = aggregator.add_line(&line) {
+                            tracing::error!("error adding log line: {}", e);
+                            // For the sake of completeness, flush the log lines.
+                            Self::print_log_line(hostname, pid, output_target, line);
+                        }
+                    } else {
+                        tracing::error!("unknown output target: {:?}", output_target);
+                        // For the sake of completeness, flush the log lines.
+                        Self::print_log_line(hostname, pid, output_target, line);
+                    }
+                }
+
+                let new_deadline = self.last_flush_time + Duration::from_secs(window);
+                let now = RealClock.system_time_now();
+                if new_deadline <= now {
+                    self.flush(cx).await?;
+                } else {
+                    let delay = new_deadline.duration_since(now)?;
+                    match self.next_flush_deadline {
+                        None => {
+                            self.next_flush_deadline = Some(new_deadline);
+                            cx.self_message_with_delay(LogMessage::Flush {}, delay)?;
+                        }
+                        Some(deadline) => {
+                            // Some early log lines have already triggered the flush.
+                            if new_deadline < deadline {
+                                // This can happen if the user has adjusted the aggregation window.
+                                self.next_flush_deadline = Some(new_deadline);
+                                cx.self_message_with_delay(LogMessage::Flush {}, delay)?;
+                            }
+                        }
+                    }
                 }
             }
         }
+
+        Ok(())
+    }
+
+    async fn flush(&mut self, _cx: &Context<Self>) -> Result<(), anyhow::Error> {
+        self.print_aggregators();
+        self.last_flush_time = RealClock.system_time_now();
+        self.next_flush_deadline = None;
 
         Ok(())
     }
 }
@@ -821,11 +804,11 @@ impl LogClientMessageHandler for LogClientActor {
         _cx: &Context<Self>,
         aggregate_window_sec: Option<u64>,
     ) -> Result<(), anyhow::Error> {
-        if let Some(window) = aggregate_window_sec {
-            // Send the new value through the watch channel
-            self.aggregate_window_tx.send(window)?;
+        if self.aggregate_window_sec.is_some() && aggregate_window_sec.is_none() {
+            // Make sure we flush whatever is in the aggregators before disabling aggregation.
+            self.print_aggregators();
         }
-        self.should_aggregate = aggregate_window_sec.is_some();
+        self.aggregate_window_sec = aggregate_window_sec;
         Ok(())
     }
 }
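The diff above arms at most one pending Flush self-message per window: a delayed flush is scheduled only when no deadline is pending, or when a shortened aggregation window moves the new deadline earlier than the pending one; a deadline that has already passed triggers an immediate flush instead. The sketch below isolates that scheduling decision as a free function with assumed names (flush_delay is not part of the actor); the immediate-flush case is left to the caller, and SystemTime::now() stands in for RealClock.system_time_now().

use std::time::{Duration, SystemTime};

// Returns the delay to arm a new Flush self-message with, or None when an
// earlier-or-equal flush is already pending.
fn flush_delay(
    last_flush_time: SystemTime,
    next_flush_deadline: &mut Option<SystemTime>,
    window_sec: u64,
    now: SystemTime,
) -> Option<Duration> {
    let new_deadline = last_flush_time + Duration::from_secs(window_sec);
    match *next_flush_deadline {
        // Nothing pending yet: arm a flush for the new deadline.
        None => {
            *next_flush_deadline = Some(new_deadline);
            new_deadline.duration_since(now).ok()
        }
        // A flush is pending but the window shrank: re-arm for the earlier deadline.
        Some(pending) if new_deadline < pending => {
            *next_flush_deadline = Some(new_deadline);
            new_deadline.duration_since(now).ok()
        }
        // Otherwise the pending flush already covers these lines.
        Some(_) => None,
    }
}

fn main() {
    let now = SystemTime::now();
    let mut pending = None;
    // First batch of lines with a 5-second window: schedule a flush in ~5s.
    assert!(flush_delay(now, &mut pending, 5, now).is_some());
    // A later batch within the same window: nothing new to schedule.
    assert!(flush_delay(now, &mut pending, 5, now).is_none());
}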

python/tests/python_actor_test_binary.py

Lines changed: 3 additions & 4 deletions
@@ -8,7 +8,6 @@
 
 import asyncio
 import logging
-import sys
 
 import click
 
@@ -27,12 +26,12 @@ def __init__(self) -> None:
     @endpoint
     async def print(self, content: str) -> None:
         print(f"{content}", flush=True)
-        sys.stdout.flush()
-        sys.stderr.flush()
 
 
 async def _flush_logs() -> None:
-    pm = await proc_mesh(gpus=2)
+    # Create a lot of processes to stress test the logging
+    pm = await proc_mesh(gpus=32)
+
     # never flush
     await pm.logging_option(aggregate_window_sec=1000)
     am = await pm.spawn("printer", Printer)
