Skip to content

Commit c508791

Browse files
: proxy mesh process test (meta-pytorch#809)
Summary: this self-bootstrapping program spawns a `ProxyActor` on a remote process, which then spawns a `TestActor` on another remote process, creating a 3-level process hierarchy. Differential Revision: D79930332
1 parent 36b26f1 commit c508791

File tree

2 files changed

+279
-1
lines changed

2 files changed

+279
-1
lines changed

hyperactor_mesh/src/actor_mesh.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ impl<'a, A: RemoteActor> RootActorMesh<'a, A> {
280280
}
281281

282282
/// Open a port on this ActorMesh.
283-
pub(crate) fn open_port<M: Message>(&self) -> (PortHandle<M>, PortReceiver<M>) {
283+
pub fn open_port<M: Message>(&self) -> (PortHandle<M>, PortReceiver<M>) {
284284
self.proc_mesh.client().open_port()
285285
}
286286

Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,278 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
use std::env;
10+
use std::fmt;
11+
use std::path::PathBuf;
12+
use std::sync::Arc;
13+
use std::sync::OnceLock;
14+
15+
use anyhow::Result;
16+
use async_trait::async_trait;
17+
use clap::Parser;
18+
use hyperactor::Actor;
19+
use hyperactor::Context;
20+
use hyperactor::Handler;
21+
use hyperactor::Named;
22+
use hyperactor::PortRef;
23+
use hyperactor_mesh::Mesh;
24+
use hyperactor_mesh::ProcMesh;
25+
use hyperactor_mesh::RootActorMesh;
26+
use hyperactor_mesh::alloc::AllocSpec;
27+
use hyperactor_mesh::alloc::Allocator;
28+
use hyperactor_mesh::alloc::ProcessAllocator;
29+
use ndslice::extent;
30+
use serde::Deserialize;
31+
use serde::Serialize;
32+
use tokio::process::Command;
33+
34+
pub fn initialize() {
35+
let subscriber = tracing_subscriber::fmt()
36+
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
37+
.finish();
38+
tracing::subscriber::set_global_default(subscriber).expect("failed to set subscriber");
39+
40+
static INITIALIZED: OnceLock<()> = OnceLock::new();
41+
INITIALIZED.get_or_init(|| {
42+
#[cfg(target_os = "linux")]
43+
linux::initialize();
44+
});
45+
}
46+
47+
#[cfg(target_os = "linux")]
48+
mod linux {
49+
use std::backtrace::Backtrace;
50+
use std::process;
51+
52+
use nix::sys::signal::SigHandler;
53+
use nix::unistd::getpid;
54+
use tokio::signal::unix::SignalKind;
55+
use tokio::signal::unix::signal;
56+
57+
pub(crate) fn initialize() {
58+
// Safety: Because I want to
59+
unsafe {
60+
extern "C" fn handle_fatal_signal(signo: libc::c_int) {
61+
let bt = Backtrace::force_capture();
62+
let signame = nix::sys::signal::Signal::try_from(signo).expect("unknown signal");
63+
tracing::error!("stacktrace"= %bt, "fatal signal {signo}:{signame} received");
64+
std::process::exit(1);
65+
}
66+
nix::sys::signal::signal(
67+
nix::sys::signal::SIGABRT,
68+
SigHandler::Handler(handle_fatal_signal),
69+
)
70+
.expect("unable to register signal handler");
71+
nix::sys::signal::signal(
72+
nix::sys::signal::SIGSEGV,
73+
SigHandler::Handler(handle_fatal_signal),
74+
)
75+
.expect("unable to register signal handler");
76+
}
77+
78+
// Set up the async signal handler FIRST
79+
let rt = tokio::runtime::Handle::current();
80+
rt.spawn(async {
81+
// Set up signal handler before prctl
82+
let mut sigusr1 = match signal(SignalKind::user_defined1()) {
83+
Ok(s) => s,
84+
Err(err) => {
85+
eprintln!("failed to set up SIGUSR1 signal handler: {:?}", err);
86+
return;
87+
}
88+
};
89+
90+
// SAFETY: Now set PDEATHSIG after handler is ready. This
91+
// is unsafe.
92+
unsafe {
93+
if libc::prctl(
94+
libc::PR_SET_PDEATHSIG,
95+
nix::sys::signal::SIGUSR1 as libc::c_ulong,
96+
) != 0
97+
{
98+
eprintln!(
99+
"prctl(PR_SET_PDEATHSIG) failed: {}",
100+
std::io::Error::last_os_error()
101+
);
102+
return;
103+
}
104+
105+
// Close the race: if parent already died, we are now orphaned.
106+
if libc::getppid() == 1 {
107+
tracing::error!(
108+
"hyperactor[{}]: parent already dead on startup; exiting",
109+
getpid()
110+
);
111+
std::process::exit(1);
112+
}
113+
}
114+
115+
// Wait for the signal
116+
sigusr1.recv().await;
117+
tracing::error!(
118+
"hyperactor[{}]: parent process died (SIGUSR1 received); exiting",
119+
getpid()
120+
);
121+
process::exit(1);
122+
});
123+
}
124+
}
125+
126+
#[derive(Parser)]
127+
struct Args {
128+
/// Run bootstrap logic
129+
#[arg(long)]
130+
bootstrap: bool,
131+
}
132+
133+
// -- TestActor
134+
135+
#[derive(Debug)]
136+
#[hyperactor::export(
137+
spawn = true,
138+
handlers = [
139+
Echo,
140+
],
141+
)]
142+
pub struct TestActor;
143+
144+
#[async_trait]
145+
impl Actor for TestActor {
146+
type Params = ();
147+
148+
async fn new(_params: Self::Params) -> Result<Self, anyhow::Error> {
149+
Ok(Self)
150+
}
151+
}
152+
153+
#[derive(Debug, Serialize, Deserialize, Named, Clone)]
154+
pub struct Echo(pub String, pub PortRef<String>);
155+
156+
#[async_trait]
157+
impl Handler<Echo> for TestActor {
158+
async fn handle(&mut self, cx: &Context<Self>, message: Echo) -> Result<(), anyhow::Error> {
159+
let Echo(message, reply_port) = message;
160+
reply_port.send(cx, message)?;
161+
Ok(())
162+
}
163+
}
164+
165+
// -- ProxyActor
166+
167+
#[hyperactor::export(
168+
spawn = true,
169+
handlers = [
170+
Echo,
171+
],
172+
)]
173+
pub struct ProxyActor {
174+
#[allow(dead_code)]
175+
proc_mesh: Arc<ProcMesh>,
176+
actor_mesh: RootActorMesh<'static, TestActor>,
177+
}
178+
179+
impl fmt::Debug for ProxyActor {
180+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
181+
f.debug_struct("ProxyActor")
182+
.field("proc_mesh", &"...")
183+
.field("actor_mesh", &"...")
184+
.finish()
185+
}
186+
}
187+
188+
#[async_trait]
189+
impl Actor for ProxyActor {
190+
type Params = String;
191+
192+
async fn new(exe_path: Self::Params) -> anyhow::Result<Self, anyhow::Error> {
193+
let mut cmd = Command::new(PathBuf::from(&exe_path));
194+
cmd.arg("--bootstrap");
195+
196+
let mut allocator = ProcessAllocator::new(cmd);
197+
198+
let alloc = allocator
199+
.allocate(AllocSpec {
200+
extent: extent! { replica = 1 },
201+
constraints: Default::default(),
202+
})
203+
.await
204+
.unwrap();
205+
let proc_mesh = Arc::new(ProcMesh::allocate(alloc).await.unwrap());
206+
let leaked: &'static Arc<ProcMesh> = Box::leak(Box::new(proc_mesh));
207+
let actor_mesh: RootActorMesh<'static, TestActor> =
208+
leaked.spawn("echo", &()).await.unwrap();
209+
Ok(Self {
210+
proc_mesh: Arc::clone(leaked),
211+
actor_mesh,
212+
})
213+
}
214+
}
215+
216+
#[async_trait]
217+
impl Handler<Echo> for ProxyActor {
218+
async fn handle(&mut self, cx: &Context<Self>, message: Echo) -> Result<(), anyhow::Error> {
219+
let actor = self.actor_mesh.get(0).unwrap();
220+
221+
let (tx, mut rx) = cx.open_port();
222+
actor.send(cx, Echo(message.0, tx.bind()))?;
223+
message.1.send(cx, rx.recv().await.unwrap())?;
224+
225+
Ok(())
226+
}
227+
}
228+
229+
async fn run_client(exe_path: PathBuf) -> Result<(), anyhow::Error> {
230+
let mut cmd = Command::new(PathBuf::from(&exe_path));
231+
cmd.arg("--bootstrap");
232+
233+
let mut allocator = ProcessAllocator::new(cmd);
234+
let alloc = allocator
235+
.allocate(AllocSpec {
236+
extent: extent! { replica = 1 },
237+
constraints: Default::default(),
238+
})
239+
.await
240+
.unwrap();
241+
242+
let mut proc_mesh = ProcMesh::allocate(alloc).await?;
243+
let actor_mesh: RootActorMesh<'_, ProxyActor> = proc_mesh
244+
.spawn("proxy", &exe_path.to_str().unwrap().to_string())
245+
.await?;
246+
let proxy_actor = actor_mesh.get(0).unwrap();
247+
let (tx, mut rx) = actor_mesh.open_port::<String>();
248+
proxy_actor.send(proc_mesh.client(), Echo("hello!".to_owned(), tx.bind()))?;
249+
250+
let msg = rx.recv().await?;
251+
println!("{}", msg);
252+
assert_eq!(msg, "hello!");
253+
254+
let mut alloc = proc_mesh.events().unwrap().into_alloc();
255+
alloc.stop_and_wait().await?;
256+
drop(alloc);
257+
258+
Ok(())
259+
}
260+
261+
#[tokio::main]
262+
async fn main() -> Result<(), anyhow::Error> {
263+
// Logs are written to /tmp/$USER/monarch_log*.
264+
initialize();
265+
266+
let args = Args::parse();
267+
if args.bootstrap {
268+
hyperactor_mesh::bootstrap_or_die().await;
269+
} else {
270+
let exe_path: PathBuf = env::current_exe().unwrap_or_else(|e| {
271+
eprintln!("Failed to get current executable path: {}", e);
272+
std::process::exit(1);
273+
});
274+
run_client(exe_path).await?;
275+
}
276+
277+
Ok(())
278+
}

0 commit comments

Comments
 (0)