Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions crates/cli/runner/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,6 @@ tokio = { workspace = true, features = ["macros", "rt-multi-thread", "signal"] }

# misc
tracing.workspace = true

[dev-dependencies]
tokio = { workspace = true, features = ["time"] }
236 changes: 217 additions & 19 deletions crates/cli/runner/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

use reth_tasks::{TaskExecutor, TaskManager};
use std::{future::Future, pin::pin, sync::mpsc, time::Duration};
use tokio::runtime::{Handle, Runtime};
use tracing::{debug, error, trace};

/// Executes CLI commands.
Expand All @@ -20,21 +21,69 @@ use tracing::{debug, error, trace};
#[derive(Debug)]
#[non_exhaustive]
pub struct CliRunner {
tokio_runtime: tokio::runtime::Runtime,
executor: RuntimeOrHandle,
}

impl CliRunner {
/// Attempts to create a new [`CliRunner`] using the default tokio
/// [`Runtime`](tokio::runtime::Runtime).
/// [`Runtime`].
///
/// The default tokio runtime is multi-threaded, with both I/O and time drivers enabled.
pub fn try_default_runtime() -> Result<Self, std::io::Error> {
Ok(Self { tokio_runtime: tokio_runtime()? })
Ok(Self { executor: RuntimeOrHandle::Runtime(tokio_runtime()?) })
}

/// Create a new [`CliRunner`] from a provided tokio [`Runtime`].
pub const fn from_runtime(tokio_runtime: Runtime) -> Self {
Self { executor: RuntimeOrHandle::Runtime(tokio_runtime) }
}

/// Create a new [`CliRunner`] from the current tokio runtime handle.
/// Panics if not called from a tokio runtime context.
pub fn current() -> Self {
Self { executor: RuntimeOrHandle::Handle(Handle::current()) }
}

/// Create a new [`CliRunner`] from a provided tokio [`Runtime`](tokio::runtime::Runtime).
pub const fn from_runtime(tokio_runtime: tokio::runtime::Runtime) -> Self {
Self { tokio_runtime }
/// Try to create a new [`CliRunner`] from the current tokio runtime handle.
/// It does not panic if not called from a tokio runtime context.
pub fn try_current() -> Option<Self> {
Handle::try_current().ok().map(|handle| Self { executor: RuntimeOrHandle::Handle(handle) })
}

/// Create a new [`CliRunner`] from a tokio [`Handle`].
///
/// # Warning
///
/// When using a [`Handle`], some operations may panic if called from within
/// the same runtime context.
///
/// Prefer using [`Self::from_runtime`] when possible.
///
/// # Example
/// ```ignore
/// // Use from a separate thread to avoid async context issues
/// std::thread::spawn(move || {
/// let runner = CliRunner::from_handle(handle);
/// runner.run_until_ctrl_c_with_handle(fut);
/// });
/// ```
pub const fn from_handle(handle: Handle) -> Self {
Self { executor: RuntimeOrHandle::Handle(handle) }
}

/// Returns the handle reference, regardless of whether this contains a runtime or handle
pub fn handle(&self) -> &Handle {
self.executor.handle()
}

/// Executes a regular future until completion or until external signal received
pub fn run_until_ctrl_c_with_handle<F, E>(self, fut: F) -> Result<(), E>
where
F: Future<Output = Result<(), E>> + Send + 'static,
E: Send + Sync + From<std::io::Error> + 'static,
{
self.executor.block_on(run_until_ctrl_c(fut)).map_err(E::from)??;
Ok(())
}
}

Expand All @@ -54,8 +103,9 @@ impl CliRunner {
F: Future<Output = Result<(), E>>,
E: Send + Sync + From<std::io::Error> + From<reth_tasks::PanickedTaskError> + 'static,
{
let tokio_runtime = self.executor.into_runtime("async commands")?;
let AsyncCliRunner { context, mut task_manager, tokio_runtime } =
AsyncCliRunner::new(self.tokio_runtime);
AsyncCliRunner::new(tokio_runtime);

// Executes the command until it finished or ctrl-c was fired
let command_res = tokio_runtime.block_on(run_to_completion_or_panic(
Expand Down Expand Up @@ -99,7 +149,8 @@ impl CliRunner {
F: Future<Output = Result<(), E>>,
E: Send + Sync + From<std::io::Error> + 'static,
{
self.tokio_runtime.block_on(run_until_ctrl_c(fut))?;
let tokio_runtime = self.executor.into_runtime("async commands")?;
tokio_runtime.block_on(run_until_ctrl_c(fut))?;
Ok(())
}

Expand All @@ -112,19 +163,21 @@ impl CliRunner {
F: Future<Output = Result<(), E>> + Send + 'static,
E: Send + Sync + From<std::io::Error> + 'static,
{
let tokio_runtime = self.tokio_runtime;
let handle = tokio_runtime.handle().clone();
let fut = tokio_runtime.handle().spawn_blocking(move || handle.block_on(fut));
tokio_runtime
.block_on(run_until_ctrl_c(async move { fut.await.expect("Failed to join task") }))?;
let fut = self.executor.spawn_blocking_task(fut);

self.executor
.block_on(run_until_ctrl_c(async move { fut.await.expect("Failed to join task") }))
.map_err(E::from)??;

// drop the tokio runtime on a separate thread because drop blocks until its pools
// (including blocking pool) are shutdown. In other words `drop(tokio_runtime)` would block
// the current thread but we want to exit right away.
std::thread::Builder::new()
.name("tokio-runtime-shutdown".to_string())
.spawn(move || drop(tokio_runtime))
.unwrap();
if let RuntimeOrHandle::Runtime(tokio_runtime) = self.executor {
std::thread::Builder::new()
.name("tokio-runtime-shutdown".to_string())
.spawn(move || drop(tokio_runtime))
.unwrap();
}

Ok(())
}
Expand All @@ -140,7 +193,7 @@ struct AsyncCliRunner {
// === impl AsyncCliRunner ===

impl AsyncCliRunner {
/// Given a tokio [`Runtime`](tokio::runtime::Runtime), creates additional context required to
/// Given a tokio [`Runtime`], creates additional context required to
/// execute commands asynchronously.
fn new(tokio_runtime: tokio::runtime::Runtime) -> Self {
let task_manager = TaskManager::new(tokio_runtime.handle().clone());
Expand All @@ -156,7 +209,7 @@ pub struct CliContext {
pub task_executor: TaskExecutor,
}

/// Creates a new default tokio multi-thread [Runtime](tokio::runtime::Runtime) with all features
/// Creates a new default tokio multi-thread [Runtime] with all features
/// enabled
pub fn tokio_runtime() -> Result<tokio::runtime::Runtime, std::io::Error> {
tokio::runtime::Builder::new_multi_thread().enable_all().build()
Expand Down Expand Up @@ -228,3 +281,148 @@ where

Ok(())
}

/// A tokio runtime or handle.
#[derive(Debug)]
enum RuntimeOrHandle {
/// Owned runtime that can be used for blocking operations
Runtime(Runtime),
/// Handle to an existing runtime
Handle(Handle),
}

impl RuntimeOrHandle {
/// Returns a reference to the inner tokio runtime handle.
fn handle(&self) -> &Handle {
match self {
Self::Runtime(rt) => rt.handle(),
Self::Handle(handle) => handle,
}
}

/// Attempts to extract the runtime, returning an error if only a handle is available.
fn into_runtime(self, operation: &str) -> Result<Runtime, std::io::Error> {
let (rt, _handle) = self.into_runtime_or_handle();
rt.ok_or_else(|| {
std::io::Error::other(
format!("A tokio runtime is required to run {}. Please create a CliRunner with an owned runtime.", operation)
)
})
}

/// Chooses to return the owned runtime if it exists, otherwise returns `None` and the handle.
fn into_runtime_or_handle(self) -> (Option<Runtime>, Handle) {
match self {
Self::Runtime(runtime) => {
let handle = runtime.handle().clone();
(Some(runtime), handle)
}
Self::Handle(handle) => (None, handle),
}
}

/// Block on a future, handling both `Runtime` and `Handle` cases.
///
/// # Example
/// ```ignore
/// // Safe: Called from outside async context
/// std::thread::spawn(move || {
/// let result = handle.block_on(async { "ok" });
/// });
///
/// // Unsafe: Would panic if called directly in async context
/// // async fn bad() {
/// // handle.block_on(async { "panic!" }); // Don't do this!
/// // }
/// ```
fn block_on<F>(&self, fut: F) -> Result<F::Output, std::io::Error>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
match self {
Self::Runtime(rt) => Ok(rt.block_on(fut)),
Self::Handle(handle) => {
// Check if we're in an async context to spawn a thread to avoid panic
if Handle::try_current().is_ok() {
let handle = handle.clone();
std::thread::spawn(move || handle.block_on(fut))
.join()
.map_err(|_| std::io::Error::other("Failed to join blocking thread"))
} else {
Ok(handle.block_on(fut))
}
}
}
}

/// Spawn a blocking task that runs a future
fn spawn_blocking_task<F>(&self, fut: F) -> tokio::task::JoinHandle<F::Output>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
let handle = self.handle().clone();
self.handle().spawn_blocking(move || handle.block_on(fut))
}
}

#[cfg(test)]
mod tests {
use super::*;
use std::time::Duration;
use tokio::time::sleep;

#[test]
fn test_runtime_with_run_until_ctrl_c() {
let runner = CliRunner::try_default_runtime().unwrap();
let result = runner.run_until_ctrl_c(async {
sleep(Duration::from_millis(5)).await;
Ok::<(), std::io::Error>(())
});
assert!(result.is_ok());
}

#[test]
fn test_handle_with_run_until_ctrl_c_with_handle() {
let rt = tokio_runtime().unwrap();
let handle = rt.handle().clone();

// Separate thread is used to avoid async context
let result = std::thread::spawn(move || {
let runner = CliRunner::from_handle(handle);
runner.run_until_ctrl_c_with_handle(async { Ok::<(), std::io::Error>(()) })
})
.join()
.unwrap();

assert!(result.is_ok());
}

#[test]
fn test_handle_with_run_blocking_until_ctrl_c() {
let rt = tokio_runtime().unwrap();
let handle = rt.handle().clone();

let result = std::thread::spawn(move || {
let runner = CliRunner::from_handle(handle);
runner.run_blocking_until_ctrl_c(async { Ok::<(), std::io::Error>(()) })
})
.join()
.unwrap();

assert!(result.is_ok());
}

#[test]
fn test_handle_fails_with_run_until_ctrl_c() {
let rt = tokio_runtime().unwrap();
let runner = CliRunner::from_handle(rt.handle().clone());

// This should fail because `run_until_ctrl_c` needs an owned runtime
let result = runner.run_until_ctrl_c(async { Ok::<(), std::io::Error>(()) });

assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("tokio runtime is required"));
}
}
Loading