Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions crates/cli/runner/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,6 @@ tokio = { workspace = true, features = ["macros", "rt-multi-thread", "signal"] }

# misc
tracing.workspace = true

[dev-dependencies]
tokio = { workspace = true, features = ["time"] }
223 changes: 204 additions & 19 deletions crates/cli/runner/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,29 +12,151 @@

use reth_tasks::{TaskExecutor, TaskManager};
use std::{future::Future, pin::pin, sync::mpsc, time::Duration};
use tokio::runtime::{Handle, Runtime};
use tracing::{debug, error, trace};

/// A tokio runtime or handle.
#[derive(Debug)]
enum RuntimeOrHandle {
/// Owned runtime that can be used for blocking operations
Runtime(Runtime),
/// Handle to an existing runtime
Handle(Handle),
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we move this end of file so that the first type remains the pub one?


impl RuntimeOrHandle {
/// Returns a reference to the inner tokio runtime handle.
fn handle(&self) -> &Handle {
match self {
Self::Runtime(rt) => rt.handle(),
Self::Handle(handle) => handle,
}
}

/// Attempts to extract the runtime, returning an error if only a handle is available.
fn into_runtime(self, operation: &str) -> Result<Runtime, std::io::Error> {
let (rt, _handle) = self.into_runtime_or_handle();
rt.ok_or_else(|| {
std::io::Error::other(
format!("A tokio runtime is required to run {}. Please create a CliRunner with an owned runtime.", operation)
)
})
}

/// Chooses to return the owned runtime if it exists, otherwise returns `None` and the handle.
fn into_runtime_or_handle(self) -> (Option<Runtime>, Handle) {
match self {
Self::Runtime(runtime) => {
let handle = runtime.handle().clone();
(Some(runtime), handle)
}
Self::Handle(handle) => (None, handle),
}
}

/// Block on a future, handling both `Runtime` and `Handle` cases.
///
/// # Example
/// ```ignore
/// // Safe: Called from outside async context
/// std::thread::spawn(move || {
/// let result = handle.block_on(async { "ok" });
/// });
///
/// // Unsafe: Would panic if called directly in async context
/// // async fn bad() {
/// // handle.block_on(async { "panic!" }); // Don't do this!
/// // }
/// ```
fn block_on<F>(&self, fut: F) -> Result<F::Output, std::io::Error>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
match self {
Self::Runtime(rt) => Ok(rt.block_on(fut)),
Self::Handle(handle) => {
// Check if we're in an async context to spawn a thread to avoid panic
if Handle::try_current().is_ok() {
let handle = handle.clone();
std::thread::spawn(move || handle.block_on(fut))
.join()
.map_err(|_| std::io::Error::other("Failed to join blocking thread"))
} else {
Ok(handle.block_on(fut))
}
}
}
}

/// Spawn a blocking task that runs a future
fn spawn_blocking_task<F>(&self, fut: F) -> tokio::task::JoinHandle<F::Output>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
let handle = self.handle().clone();
self.handle().spawn_blocking(move || handle.block_on(fut))
}
}

/// Executes CLI commands.
///
/// Provides utilities for running a cli command to completion.
#[derive(Debug)]
#[non_exhaustive]
pub struct CliRunner {
tokio_runtime: tokio::runtime::Runtime,
executor: RuntimeOrHandle,
}

impl CliRunner {
/// Attempts to create a new [`CliRunner`] using the default tokio
/// [`Runtime`](tokio::runtime::Runtime).
/// [`Runtime`].
///
/// The default tokio runtime is multi-threaded, with both I/O and time drivers enabled.
pub fn try_default_runtime() -> Result<Self, std::io::Error> {
Ok(Self { tokio_runtime: tokio_runtime()? })
Ok(Self { executor: RuntimeOrHandle::Runtime(tokio_runtime()?) })
}

/// Create a new [`CliRunner`] from a provided tokio [`Runtime`].
pub const fn from_runtime(tokio_runtime: Runtime) -> Self {
Self { executor: RuntimeOrHandle::Runtime(tokio_runtime) }
}

/// Create a new [`CliRunner`] from a tokio [`Handle`].
///
/// # Warning
///
/// When using a [`Handle`], some operations may panic if called from within
/// the same runtime context.
///
/// Prefer using [`Self::from_runtime`] when possible.
///
/// # Example
/// ```ignore
/// // Use from a separate thread to avoid async context issues
/// std::thread::spawn(move || {
/// let runner = CliRunner::from_handle(handle);
/// runner.run_until_ctrl_c_with_handle(fut);
/// });
/// ```
pub const fn from_handle(handle: Handle) -> Self {
Self { executor: RuntimeOrHandle::Handle(handle) }
}

/// Returns the handle reference, regardless of whether this contains a runtime or handle
pub fn handle(&self) -> &Handle {
self.executor.handle()
}

/// Create a new [`CliRunner`] from a provided tokio [`Runtime`](tokio::runtime::Runtime).
pub const fn from_runtime(tokio_runtime: tokio::runtime::Runtime) -> Self {
Self { tokio_runtime }
/// Executes a regular future until completion or until external signal received
pub fn run_until_ctrl_c_with_handle<F, E>(self, fut: F) -> Result<(), E>
where
F: Future<Output = Result<(), E>> + Send + 'static,
E: Send + Sync + From<std::io::Error> + 'static,
{
self.executor.block_on(run_until_ctrl_c(fut)).map_err(E::from)??;
Ok(())
}
}

Expand All @@ -54,8 +176,9 @@ impl CliRunner {
F: Future<Output = Result<(), E>>,
E: Send + Sync + From<std::io::Error> + From<reth_tasks::PanickedTaskError> + 'static,
{
let tokio_runtime = self.executor.into_runtime("async commands")?;
let AsyncCliRunner { context, mut task_manager, tokio_runtime } =
AsyncCliRunner::new(self.tokio_runtime);
AsyncCliRunner::new(tokio_runtime);

// Executes the command until it finished or ctrl-c was fired
let command_res = tokio_runtime.block_on(run_to_completion_or_panic(
Expand Down Expand Up @@ -99,7 +222,8 @@ impl CliRunner {
F: Future<Output = Result<(), E>>,
E: Send + Sync + From<std::io::Error> + 'static,
{
self.tokio_runtime.block_on(run_until_ctrl_c(fut))?;
let tokio_runtime = self.executor.into_runtime("async commands")?;
tokio_runtime.block_on(run_until_ctrl_c(fut))?;
Ok(())
}

Expand All @@ -112,19 +236,21 @@ impl CliRunner {
F: Future<Output = Result<(), E>> + Send + 'static,
E: Send + Sync + From<std::io::Error> + 'static,
{
let tokio_runtime = self.tokio_runtime;
let handle = tokio_runtime.handle().clone();
let fut = tokio_runtime.handle().spawn_blocking(move || handle.block_on(fut));
tokio_runtime
.block_on(run_until_ctrl_c(async move { fut.await.expect("Failed to join task") }))?;
let fut = self.executor.spawn_blocking_task(fut);

self.executor
.block_on(run_until_ctrl_c(async move { fut.await.expect("Failed to join task") }))
.map_err(E::from)??;

// drop the tokio runtime on a separate thread because drop blocks until its pools
// (including blocking pool) are shutdown. In other words `drop(tokio_runtime)` would block
// the current thread but we want to exit right away.
std::thread::Builder::new()
.name("tokio-runtime-shutdown".to_string())
.spawn(move || drop(tokio_runtime))
.unwrap();
if let RuntimeOrHandle::Runtime(tokio_runtime) = self.executor {
std::thread::Builder::new()
.name("tokio-runtime-shutdown".to_string())
.spawn(move || drop(tokio_runtime))
.unwrap();
}

Ok(())
}
Expand All @@ -140,7 +266,7 @@ struct AsyncCliRunner {
// === impl AsyncCliRunner ===

impl AsyncCliRunner {
/// Given a tokio [`Runtime`](tokio::runtime::Runtime), creates additional context required to
/// Given a tokio [`Runtime`], creates additional context required to
/// execute commands asynchronously.
fn new(tokio_runtime: tokio::runtime::Runtime) -> Self {
let task_manager = TaskManager::new(tokio_runtime.handle().clone());
Expand All @@ -156,7 +282,7 @@ pub struct CliContext {
pub task_executor: TaskExecutor,
}

/// Creates a new default tokio multi-thread [Runtime](tokio::runtime::Runtime) with all features
/// Creates a new default tokio multi-thread [Runtime] with all features
/// enabled
pub fn tokio_runtime() -> Result<tokio::runtime::Runtime, std::io::Error> {
tokio::runtime::Builder::new_multi_thread().enable_all().build()
Expand Down Expand Up @@ -228,3 +354,62 @@ where

Ok(())
}

#[cfg(test)]
mod tests {
use super::*;
use std::time::Duration;
use tokio::time::sleep;

#[test]
fn test_runtime_with_run_until_ctrl_c() {
let runner = CliRunner::try_default_runtime().unwrap();
let result = runner.run_until_ctrl_c(async {
sleep(Duration::from_millis(5)).await;
Ok::<(), std::io::Error>(())
});
assert!(result.is_ok());
}

#[test]
fn test_handle_with_run_until_ctrl_c_with_handle() {
let rt = tokio_runtime().unwrap();
let handle = rt.handle().clone();

// Separate thread is used to avoid async context
let result = std::thread::spawn(move || {
let runner = CliRunner::from_handle(handle);
runner.run_until_ctrl_c_with_handle(async { Ok::<(), std::io::Error>(()) })
})
.join()
.unwrap();

assert!(result.is_ok());
}

#[test]
fn test_handle_with_run_blocking_until_ctrl_c() {
let rt = tokio_runtime().unwrap();
let handle = rt.handle().clone();

let result = std::thread::spawn(move || {
let runner = CliRunner::from_handle(handle);
runner.run_blocking_until_ctrl_c(async { Ok::<(), std::io::Error>(()) })
})
.join()
.unwrap();

assert!(result.is_ok());
}

#[test]
fn test_handle_fails_without_handle_support() {
let rt = tokio_runtime().unwrap();
let runner = CliRunner::from_handle(rt.handle().clone());

let result = runner.run_until_ctrl_c(async { Ok::<(), std::io::Error>(()) });

assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("tokio runtime is required"));
}
}