diff --git a/.github/workflows/vss-integration.yml b/.github/workflows/vss-integration.yml index 2a6c63704..d265d682b 100644 --- a/.github/workflows/vss-integration.yml +++ b/.github/workflows/vss-integration.yml @@ -75,7 +75,7 @@ jobs: cd ldk-node export TEST_VSS_BASE_URL="http://localhost:8080/vss" RUSTFLAGS="--cfg vss_test" cargo build --verbose --color always - RUSTFLAGS="--cfg vss_test" cargo test --test integration_tests_vss + RUSTFLAGS="--cfg vss_test --cfg tokio_unstable" cargo test --test integration_tests_vss -- --nocapture - name: Cleanup run: | diff --git a/Cargo.toml b/Cargo.toml index 5ce10d6ad..6f812920c 100755 --- a/Cargo.toml +++ b/Cargo.toml @@ -84,7 +84,8 @@ serde = { version = "1.0.210", default-features = false, features = ["std", "der serde_json = { version = "1.0.128", default-features = false, features = ["std"] } log = { version = "0.4.22", default-features = false, features = ["std"]} -vss-client = "0.3" +#vss-client = "0.3" +vss-client = { git = "https://github.com/tnull/vss-rust-client", branch = "2025-05-try-client-timeout" } prost = { version = "0.11.6", default-features = false} [target.'cfg(windows)'.dependencies] diff --git a/bindings/ldk_node.udl b/bindings/ldk_node.udl index c2f0166c8..9ee66d2f1 100644 --- a/bindings/ldk_node.udl +++ b/bindings/ldk_node.udl @@ -64,7 +64,7 @@ dictionary LogRecord { [Trait, WithForeign] interface LogWriter { - void log(LogRecord record); + void log(LogRecord record); }; interface Builder { @@ -160,8 +160,8 @@ interface Node { [Enum] interface Bolt11InvoiceDescription { - Hash(string hash); - Direct(string description); + Hash(string hash); + Direct(string description); }; interface Bolt11Payment { @@ -330,6 +330,7 @@ enum BuildError { "InvalidListeningAddresses", "InvalidAnnouncementAddresses", "InvalidNodeAlias", + "RuntimeSetupFailed", "ReadFailed", "WriteFailed", "StoragePathAccessFailed", diff --git a/src/builder.rs b/src/builder.rs index 31a0fee45..63c41f227 100644 --- a/src/builder.rs +++ 
b/src/builder.rs @@ -27,6 +27,7 @@ use crate::liquidity::{ use crate::logger::{log_error, log_info, LdkLogger, LogLevel, LogWriter, Logger}; use crate::message_handler::NodeCustomMessageHandler; use crate::peer_store::PeerStore; +use crate::runtime::Runtime; use crate::tx_broadcaster::TransactionBroadcaster; use crate::types::{ ChainMonitor, ChannelManager, DynStore, GossipSync, Graph, KeysManager, MessageRouter, @@ -154,6 +155,8 @@ pub enum BuildError { InvalidAnnouncementAddresses, /// The provided alias is invalid. InvalidNodeAlias, + /// An attempt to setup a runtime has failed. + RuntimeSetupFailed, /// We failed to read data from the [`KVStore`]. /// /// [`KVStore`]: lightning::util::persist::KVStore @@ -191,6 +194,7 @@ impl fmt::Display for BuildError { Self::InvalidAnnouncementAddresses => { write!(f, "Given announcement addresses are invalid.") }, + Self::RuntimeSetupFailed => write!(f, "Failed to setup a runtime."), Self::ReadFailed => write!(f, "Failed to read from store."), Self::WriteFailed => write!(f, "Failed to write to store."), Self::StoragePathAccessFailed => write!(f, "Failed to access the given storage path."), @@ -222,6 +226,7 @@ pub struct NodeBuilder { gossip_source_config: Option, liquidity_source_config: Option, log_writer_config: Option, + runtime_handle: Option, } impl NodeBuilder { @@ -238,6 +243,7 @@ impl NodeBuilder { let gossip_source_config = None; let liquidity_source_config = None; let log_writer_config = None; + let runtime_handle = None; Self { config, entropy_source_config, @@ -245,9 +251,20 @@ impl NodeBuilder { gossip_source_config, liquidity_source_config, log_writer_config, + runtime_handle, } } + /// Configures the [`Node`] instance to (re-)use a specific `tokio` runtime. + /// + /// If not provided, the node will spawn its own runtime or reuse any outer runtime context it + /// can detect. 
+ #[cfg_attr(feature = "uniffi", allow(dead_code))] + pub fn set_runtime(&mut self, runtime_handle: tokio::runtime::Handle) -> &mut Self { + self.runtime_handle = Some(runtime_handle); + self + } + /// Configures the [`Node`] instance to source its wallet entropy from a seed file on disk. /// /// If the given file does not exist a new random seed file will be generated and @@ -582,6 +599,15 @@ impl NodeBuilder { ) -> Result { let logger = setup_logger(&self.log_writer_config, &self.config)?; + let runtime = if let Some(handle) = self.runtime_handle.as_ref() { + Arc::new(Runtime::with_handle(handle.clone())) + } else { + Arc::new(Runtime::new().map_err(|e| { + log_error!(logger, "Failed to setup tokio runtime: {}", e); + BuildError::RuntimeSetupFailed + })?) + }; + let seed_bytes = seed_bytes_from_config( &self.config, self.entropy_source_config.as_ref(), @@ -600,16 +626,14 @@ impl NodeBuilder { let vss_seed_bytes: [u8; 32] = vss_xprv.private_key.secret_bytes(); let vss_store = - VssStore::new(vss_url, store_id, vss_seed_bytes, header_provider).map_err(|e| { - log_error!(logger, "Failed to setup VssStore: {}", e); - BuildError::KVStoreSetupFailed - })?; + VssStore::new(vss_url, store_id, vss_seed_bytes, header_provider, Arc::clone(&runtime)); build_with_store_internal( config, self.chain_data_source_config.as_ref(), self.gossip_source_config.as_ref(), self.liquidity_source_config.as_ref(), seed_bytes, + runtime, logger, Arc::new(vss_store), ) @@ -619,6 +643,15 @@ impl NodeBuilder { pub fn build_with_store(&self, kv_store: Arc) -> Result { let logger = setup_logger(&self.log_writer_config, &self.config)?; + let runtime = if let Some(handle) = self.runtime_handle.as_ref() { + Arc::new(Runtime::with_handle(handle.clone())) + } else { + Arc::new(Runtime::new().map_err(|e| { + log_error!(logger, "Failed to setup tokio runtime: {}", e); + BuildError::RuntimeSetupFailed + })?) 
+ }; + let seed_bytes = seed_bytes_from_config( &self.config, self.entropy_source_config.as_ref(), @@ -632,6 +665,7 @@ impl NodeBuilder { self.gossip_source_config.as_ref(), self.liquidity_source_config.as_ref(), seed_bytes, + runtime, logger, kv_store, ) @@ -934,7 +968,7 @@ fn build_with_store_internal( config: Arc, chain_data_source_config: Option<&ChainDataSourceConfig>, gossip_source_config: Option<&GossipSourceConfig>, liquidity_source_config: Option<&LiquiditySourceConfig>, seed_bytes: [u8; 64], - logger: Arc, kv_store: Arc, + runtime: Arc, logger: Arc, kv_store: Arc, ) -> Result { if let Err(err) = may_announce_channel(&config) { if config.announcement_addresses.is_some() { @@ -1101,8 +1135,6 @@ fn build_with_store_internal( }, }; - let runtime = Arc::new(RwLock::new(None)); - // Initialize the ChainMonitor let chain_monitor: Arc = Arc::new(chainmonitor::ChainMonitor::new( Some(Arc::clone(&chain_source)), @@ -1495,6 +1527,8 @@ fn build_with_store_internal( let (stop_sender, _) = tokio::sync::watch::channel(()); let (event_handling_stopped_sender, _) = tokio::sync::watch::channel(()); + let is_running = Arc::new(RwLock::new(false)); + Ok(Node { runtime, stop_sender, @@ -1520,6 +1554,7 @@ fn build_with_store_internal( scorer, peer_store, payment_store, + is_running, is_listening, node_metrics, }) diff --git a/src/chain/electrum.rs b/src/chain/electrum.rs index 6e62d9c08..dd6a49296 100644 --- a/src/chain/electrum.rs +++ b/src/chain/electrum.rs @@ -15,6 +15,7 @@ use crate::fee_estimator::{ ConfirmationTarget, }; use crate::logger::{log_bytes, log_error, log_info, log_trace, LdkLogger, Logger}; +use crate::runtime::Runtime; use lightning::chain::{Confirm, Filter, WatchedOutput}; use lightning::util::ser::Writeable; @@ -46,15 +47,14 @@ pub(crate) struct ElectrumRuntimeClient { electrum_client: Arc, bdk_electrum_client: Arc>, tx_sync: Arc>>, - runtime: Arc, + runtime: Arc, config: Arc, logger: Arc, } impl ElectrumRuntimeClient { pub(crate) fn new( - server_url: 
String, runtime: Arc, config: Arc, - logger: Arc, + server_url: String, runtime: Arc, config: Arc, logger: Arc, ) -> Result { let electrum_config = ElectrumConfigBuilder::new() .retry(ELECTRUM_CLIENT_NUM_RETRIES) @@ -187,7 +187,6 @@ impl ElectrumRuntimeClient { let spawn_fut = self.runtime.spawn_blocking(move || electrum_client.transaction_broadcast(&tx)); - let timeout_fut = tokio::time::timeout(Duration::from_secs(TX_BROADCAST_TIMEOUT_SECS), spawn_fut); diff --git a/src/chain/mod.rs b/src/chain/mod.rs index 62627797e..31eb5f53d 100644 --- a/src/chain/mod.rs +++ b/src/chain/mod.rs @@ -24,6 +24,7 @@ use crate::fee_estimator::{ }; use crate::io::utils::write_node_metrics; use crate::logger::{log_bytes, log_error, log_info, log_trace, LdkLogger, Logger}; +use crate::runtime::Runtime; use crate::types::{Broadcaster, ChainMonitor, ChannelManager, DynStore, Sweeper, Wallet}; use crate::{Error, NodeMetrics}; @@ -126,7 +127,7 @@ impl ElectrumRuntimeStatus { } pub(crate) fn start( - &mut self, server_url: String, runtime: Arc, config: Arc, + &mut self, server_url: String, runtime: Arc, config: Arc, logger: Arc, ) -> Result<(), Error> { match self { @@ -311,7 +312,7 @@ impl ChainSource { } } - pub(crate) fn start(&self, runtime: Arc) -> Result<(), Error> { + pub(crate) fn start(&self, runtime: Arc) -> Result<(), Error> { match self { Self::Electrum { server_url, electrum_runtime_status, config, logger, .. 
} => { electrum_runtime_status.write().unwrap().start( diff --git a/src/event.rs b/src/event.rs index e95983710..21770d91f 100644 --- a/src/event.rs +++ b/src/event.rs @@ -29,6 +29,8 @@ use crate::io::{ }; use crate::logger::{log_debug, log_error, log_info, LdkLogger}; +use crate::runtime::Runtime; + use lightning::events::bump_transaction::BumpTransactionEvent; use lightning::events::{ClosureReason, PaymentPurpose, ReplayEvent}; use lightning::events::{Event as LdkEvent, PaymentFailureReason}; @@ -53,7 +55,7 @@ use core::future::Future; use core::task::{Poll, Waker}; use std::collections::VecDeque; use std::ops::Deref; -use std::sync::{Arc, Condvar, Mutex, RwLock}; +use std::sync::{Arc, Condvar, Mutex}; use std::time::Duration; /// An event emitted by [`Node`], which should be handled by the user. @@ -451,7 +453,7 @@ where liquidity_source: Option>>>, payment_store: Arc, peer_store: Arc>, - runtime: Arc>>>, + runtime: Arc, logger: L, config: Arc, } @@ -466,8 +468,8 @@ where channel_manager: Arc, connection_manager: Arc>, output_sweeper: Arc, network_graph: Arc, liquidity_source: Option>>>, - payment_store: Arc, peer_store: Arc>, - runtime: Arc>>>, logger: L, config: Arc, + payment_store: Arc, peer_store: Arc>, runtime: Arc, + logger: L, config: Arc, ) -> Self { Self { event_queue, @@ -1049,17 +1051,14 @@ where let forwarding_channel_manager = self.channel_manager.clone(); let min = time_forwardable.as_millis() as u64; - let runtime_lock = self.runtime.read().unwrap(); - debug_assert!(runtime_lock.is_some()); + let future = async move { + let millis_to_sleep = thread_rng().gen_range(min..min * 5) as u64; + tokio::time::sleep(Duration::from_millis(millis_to_sleep)).await; - if let Some(runtime) = runtime_lock.as_ref() { - runtime.spawn(async move { - let millis_to_sleep = thread_rng().gen_range(min..min * 5) as u64; - tokio::time::sleep(Duration::from_millis(millis_to_sleep)).await; + forwarding_channel_manager.process_pending_htlc_forwards(); + }; - 
forwarding_channel_manager.process_pending_htlc_forwards(); - }); - } + self.runtime.spawn(future); }, LdkEvent::SpendableOutputs { outputs, channel_id } => { match self.output_sweeper.track_spendable_outputs(outputs, channel_id, true, None) { @@ -1419,31 +1418,27 @@ where debug_assert!(false, "We currently don't handle BOLT12 invoices manually, so this event should never be emitted."); }, LdkEvent::ConnectionNeeded { node_id, addresses } => { - let runtime_lock = self.runtime.read().unwrap(); - debug_assert!(runtime_lock.is_some()); - - if let Some(runtime) = runtime_lock.as_ref() { - let spawn_logger = self.logger.clone(); - let spawn_cm = Arc::clone(&self.connection_manager); - runtime.spawn(async move { - for addr in &addresses { - match spawn_cm.connect_peer_if_necessary(node_id, addr.clone()).await { - Ok(()) => { - return; - }, - Err(e) => { - log_error!( - spawn_logger, - "Failed to establish connection to peer {}@{}: {}", - node_id, - addr, - e - ); - }, - } + let spawn_logger = self.logger.clone(); + let spawn_cm = Arc::clone(&self.connection_manager); + let future = async move { + for addr in &addresses { + match spawn_cm.connect_peer_if_necessary(node_id, addr.clone()).await { + Ok(()) => { + return; + }, + Err(e) => { + log_error!( + spawn_logger, + "Failed to establish connection to peer {}@{}: {}", + node_id, + addr, + e + ); + }, } - }); - } + } + }; + self.runtime.spawn(future); }, LdkEvent::BumpTransaction(bte) => { match bte { diff --git a/src/gossip.rs b/src/gossip.rs index a8a6e3831..1185f0718 100644 --- a/src/gossip.rs +++ b/src/gossip.rs @@ -7,7 +7,8 @@ use crate::chain::ChainSource; use crate::config::RGS_SYNC_TIMEOUT_SECS; -use crate::logger::{log_error, log_trace, LdkLogger, Logger}; +use crate::logger::{log_trace, LdkLogger, Logger}; +use crate::runtime::Runtime; use crate::types::{GossipSync, Graph, P2PGossipSync, PeerManager, RapidGossipSync, UtxoLookup}; use crate::Error; @@ -15,13 +16,12 @@ use 
lightning_block_sync::gossip::{FutureSpawner, GossipVerifier}; use std::future::Future; use std::sync::atomic::{AtomicU32, Ordering}; -use std::sync::{Arc, RwLock}; +use std::sync::Arc; use std::time::Duration; pub(crate) enum GossipSource { P2PNetwork { gossip_sync: Arc, - logger: Arc, }, RapidGossipSync { gossip_sync: Arc, @@ -38,7 +38,7 @@ impl GossipSource { None::>, Arc::clone(&logger), )); - Self::P2PNetwork { gossip_sync, logger } + Self::P2PNetwork { gossip_sync } } pub fn new_rgs( @@ -63,12 +63,12 @@ impl GossipSource { pub(crate) fn set_gossip_verifier( &self, chain_source: Arc, peer_manager: Arc, - runtime: Arc>>>, + runtime: Arc, ) { match self { - Self::P2PNetwork { gossip_sync, logger } => { + Self::P2PNetwork { gossip_sync } => { if let Some(utxo_source) = chain_source.as_utxo_source() { - let spawner = RuntimeSpawner::new(Arc::clone(&runtime), Arc::clone(&logger)); + let spawner = RuntimeSpawner::new(Arc::clone(&runtime)); let gossip_verifier = Arc::new(GossipVerifier::new( utxo_source, spawner, @@ -133,28 +133,17 @@ impl GossipSource { } pub(crate) struct RuntimeSpawner { - runtime: Arc>>>, - logger: Arc, + runtime: Arc, } impl RuntimeSpawner { - pub(crate) fn new( - runtime: Arc>>>, logger: Arc, - ) -> Self { - Self { runtime, logger } + pub(crate) fn new(runtime: Arc) -> Self { + Self { runtime } } } impl FutureSpawner for RuntimeSpawner { fn spawn + Send + 'static>(&self, future: T) { - let rt_lock = self.runtime.read().unwrap(); - if rt_lock.is_none() { - log_error!(self.logger, "Tried spawing a future while the runtime wasn't available. This should never happen."); - debug_assert!(false, "Tried spawing a future while the runtime wasn't available. 
This should never happen."); - return; - } - - let runtime = rt_lock.as_ref().unwrap(); - runtime.spawn(future); + self.runtime.spawn(future); } } diff --git a/src/io/vss_store.rs b/src/io/vss_store.rs index 296eaabe3..0b0c02c40 100644 --- a/src/io/vss_store.rs +++ b/src/io/vss_store.rs @@ -6,6 +6,8 @@ // accordance with one or both of these licenses. use crate::io::utils::check_namespace_key_validity; +use crate::runtime::Runtime; + use bitcoin::hashes::{sha256, Hash, HashEngine, Hmac, HmacEngine}; use lightning::io::{self, Error, ErrorKind}; use lightning::util::persist::KVStore; @@ -15,7 +17,6 @@ use rand::RngCore; use std::panic::RefUnwindSafe; use std::sync::Arc; use std::time::Duration; -use tokio::runtime::Runtime; use vss_client::client::VssClient; use vss_client::error::VssError; use vss_client::headers::VssHeaderProvider; @@ -41,7 +42,7 @@ type CustomRetryPolicy = FilteredRetryPolicy< pub struct VssStore { client: VssClient, store_id: String, - runtime: Runtime, + runtime: Arc, storable_builder: StorableBuilder, key_obfuscator: KeyObfuscator, } @@ -49,9 +50,8 @@ pub struct VssStore { impl VssStore { pub(crate) fn new( base_url: String, store_id: String, vss_seed: [u8; 32], - header_provider: Arc, - ) -> io::Result { - let runtime = tokio::runtime::Builder::new_multi_thread().enable_all().build()?; + header_provider: Arc, runtime: Arc, + ) -> Self { let (data_encryption_key, obfuscation_master_key) = derive_data_encryption_and_obfuscation_keys(&vss_seed); let key_obfuscator = KeyObfuscator::new(obfuscation_master_key); @@ -70,7 +70,7 @@ impl VssStore { }) as _); let client = VssClient::new_with_headers(base_url, retry_policy, header_provider); - Ok(Self { client, store_id, runtime, storable_builder, key_obfuscator }) + Self { client, store_id, runtime, storable_builder, key_obfuscator } } fn build_key( @@ -137,18 +137,18 @@ impl KVStore for VssStore { key: self.build_key(primary_namespace, secondary_namespace, key)?, }; - let resp = - 
tokio::task::block_in_place(|| self.runtime.block_on(self.client.get_object(&request))) - .map_err(|e| { - let msg = format!( - "Failed to read from key {}/{}/{}: {}", - primary_namespace, secondary_namespace, key, e - ); - match e { - VssError::NoSuchKeyError(..) => Error::new(ErrorKind::NotFound, msg), - _ => Error::new(ErrorKind::Other, msg), - } - })?; + println!("READ: {}/{}/{}", primary_namespace, secondary_namespace, key); + let resp = self.runtime.block_on(self.client.get_object(&request)).map_err(|e| { + let msg = format!( + "Failed to read from key {}/{}/{}: {}", + primary_namespace, secondary_namespace, key, e + ); + match e { + VssError::NoSuchKeyError(..) => Error::new(ErrorKind::NotFound, msg), + _ => Error::new(ErrorKind::Other, msg), + } + })?; + println!("READ DONE: {}/{}/{}", primary_namespace, secondary_namespace, key); // unwrap safety: resp.value must be always present for a non-erroneous VSS response, otherwise // it is an API-violation which is converted to [`VssError::InternalServerError`] in [`VssClient`] let storable = Storable::decode(&resp.value.unwrap().value[..]).map_err(|e| { @@ -179,15 +179,28 @@ impl KVStore for VssStore { delete_items: vec![], }; - tokio::task::block_in_place(|| self.runtime.block_on(self.client.put_object(&request))) - .map_err(|e| { - let msg = format!( - "Failed to write to key {}/{}/{}: {}", - primary_namespace, secondary_namespace, key, e - ); - Error::new(ErrorKind::Other, msg) - })?; + println!("WRITE: {}/{}/{}", primary_namespace, secondary_namespace, key); + let res = self.runtime.block_on(async move { + tokio::time::timeout(Duration::from_secs(100), self.client.put_object(&request)) + .await + .map_err(|e| { + let msg = format!( + "Failed to write to key {}/{}/{}: {}", + primary_namespace, secondary_namespace, key, e + ); + Error::new(ErrorKind::Other, msg) + }) + }); + + println!("WRITE DONE: {}/{}/{}: {:?}", primary_namespace, secondary_namespace, key, res); + res?.map_err(|e| { + let msg = format!( + 
"Failed to write to key {}/{}/{}: {}", + primary_namespace, secondary_namespace, key, e + ); + Error::new(ErrorKind::Other, msg) + })?; Ok(()) } @@ -204,30 +217,33 @@ impl KVStore for VssStore { }), }; - tokio::task::block_in_place(|| self.runtime.block_on(self.client.delete_object(&request))) - .map_err(|e| { - let msg = format!( - "Failed to delete key {}/{}/{}: {}", - primary_namespace, secondary_namespace, key, e - ); - Error::new(ErrorKind::Other, msg) - })?; + println!("REMOVE: {}/{}/{}", primary_namespace, secondary_namespace, key); + self.runtime.block_on(self.client.delete_object(&request)).map_err(|e| { + let msg = format!( + "Failed to delete key {}/{}/{}: {}", + primary_namespace, secondary_namespace, key, e + ); + Error::new(ErrorKind::Other, msg) + })?; + println!("REMOVE DONE: {}/{}/{}", primary_namespace, secondary_namespace, key); Ok(()) } fn list(&self, primary_namespace: &str, secondary_namespace: &str) -> io::Result> { check_namespace_key_validity(primary_namespace, secondary_namespace, None, "list")?; - let keys = tokio::task::block_in_place(|| { - self.runtime.block_on(self.list_all_keys(primary_namespace, secondary_namespace)) - }) - .map_err(|e| { - let msg = format!( - "Failed to retrieve keys in namespace: {}/{} : {}", - primary_namespace, secondary_namespace, e - ); - Error::new(ErrorKind::Other, msg) - })?; + println!("LIST: {}/{}", primary_namespace, secondary_namespace); + let keys = self + .runtime + .block_on(self.list_all_keys(primary_namespace, secondary_namespace)) + .map_err(|e| { + let msg = format!( + "Failed to retrieve keys in namespace: {}/{} : {}", + primary_namespace, secondary_namespace, e + ); + Error::new(ErrorKind::Other, msg) + })?; + println!("LIST DONE: {}/{}", primary_namespace, secondary_namespace); Ok(keys) } @@ -266,10 +282,27 @@ mod tests { use rand::distributions::Alphanumeric; use rand::{thread_rng, Rng, RngCore}; use std::collections::HashMap; + use tokio::runtime; use vss_client::headers::FixedHeaders; 
#[test] - fn read_write_remove_list_persist() { + fn vss_read_write_remove_list_persist() { + let runtime = Arc::new(Runtime::new().unwrap()); + let vss_base_url = std::env::var("TEST_VSS_BASE_URL").unwrap(); + let mut rng = thread_rng(); + let rand_store_id: String = (0..7).map(|_| rng.sample(Alphanumeric) as char).collect(); + let mut vss_seed = [0u8; 32]; + rng.fill_bytes(&mut vss_seed); + let header_provider = Arc::new(FixedHeaders::new(HashMap::new())); + let vss_store = + VssStore::new(vss_base_url, rand_store_id, vss_seed, header_provider, runtime); + + do_read_write_remove_list_persist(&vss_store); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn vss_read_write_remove_list_persist_in_runtime_context() { + let runtime = Arc::new(Runtime::new().unwrap()); let vss_base_url = std::env::var("TEST_VSS_BASE_URL").unwrap(); let mut rng = thread_rng(); let rand_store_id: String = (0..7).map(|_| rng.sample(Alphanumeric) as char).collect(); let mut vss_seed = [0u8; 32]; @@ -277,8 +310,9 @@ mod tests { rng.fill_bytes(&mut vss_seed); let header_provider = Arc::new(FixedHeaders::new(HashMap::new())); let vss_store = - VssStore::new(vss_base_url, rand_store_id, vss_seed, header_provider).unwrap(); + VssStore::new(vss_base_url, rand_store_id, vss_seed, header_provider, runtime); do_read_write_remove_list_persist(&vss_store); + drop(vss_store) } } diff --git a/src/lib.rs b/src/lib.rs index c3bfe16d8..778587d71 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -93,6 +93,7 @@ pub mod logger; mod message_handler; pub mod payment; mod peer_store; +mod runtime; mod sweep; mod tx_broadcaster; mod types; @@ -106,6 +107,7 @@ pub use lightning; pub use lightning_invoice; pub use lightning_liquidity; pub use lightning_types; +pub use tokio; pub use vss_client; pub use balance::{BalanceDetails, LightningBalance, PendingSweepBalance}; @@ -141,6 +143,7 @@ use payment::{ UnifiedQrPayment, }; use peer_store::{PeerInfo, PeerStore}; +use runtime::Runtime; use types::{
Broadcaster, BumpTransactionEventHandler, ChainMonitor, ChannelManager, DynStore, Graph, KeysManager, OnionMessenger, PaymentStore, PeerManager, Router, Scorer, Sweeper, Wallet, @@ -176,7 +179,7 @@ uniffi::include_scaffolding!("ldk_node"); /// /// Needs to be initialized and instantiated through [`Builder::build`]. pub struct Node { - runtime: Arc>>>, + runtime: Arc, stop_sender: tokio::sync::watch::Sender<()>, event_handling_stopped_sender: tokio::sync::watch::Sender<()>, config: Arc, @@ -200,6 +203,7 @@ pub struct Node { scorer: Arc>, peer_store: Arc>>, payment_store: Arc, + is_running: Arc>, is_listening: Arc, node_metrics: Arc>, } @@ -208,27 +212,15 @@ impl Node { /// Starts the necessary background tasks, such as handling events coming from user input, /// LDK/BDK, and the peer-to-peer network. /// - /// After this returns, the [`Node`] instance can be controlled via the provided API methods in - /// a thread-safe manner. - pub fn start(&self) -> Result<(), Error> { - let runtime = - Arc::new(tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap()); - self.start_with_runtime(runtime) - } - - /// Starts the necessary background tasks (such as handling events coming from user input, - /// LDK/BDK, and the peer-to-peer network) on the the given `runtime`. - /// - /// This allows to have LDK Node reuse an outer pre-existing runtime, e.g., to avoid stacking Tokio - /// runtime contexts. + /// This will try to auto-detect an outer pre-existing runtime, e.g., to avoid stacking Tokio + /// runtime contexts. Note we require the outer runtime to be of the `multithreaded` flavor. /// /// After this returns, the [`Node`] instance can be controlled via the provided API methods in /// a thread-safe manner. - pub fn start_with_runtime(&self, runtime: Arc) -> Result<(), Error> { + pub fn start(&self) -> Result<(), Error> { // Acquire a run lock and hold it until we're setup. 
- let mut runtime_lock = self.runtime.write().unwrap(); - if runtime_lock.is_some() { - // We're already running. + let mut is_running_lock = self.is_running.write().unwrap(); + if *is_running_lock { return Err(Error::AlreadyRunning); } @@ -240,17 +232,14 @@ impl Node { ); // Start up any runtime-dependant chain sources (e.g. Electrum) - self.chain_source.start(Arc::clone(&runtime)).map_err(|e| { + self.chain_source.start(Arc::clone(&self.runtime)).map_err(|e| { log_error!(self.logger, "Failed to start chain syncing: {}", e); e })?; // Block to ensure we update our fee rate cache once on startup let chain_source = Arc::clone(&self.chain_source); - let runtime_ref = &runtime; - tokio::task::block_in_place(move || { - runtime_ref.block_on(async move { chain_source.update_fee_rate_estimates().await }) - })?; + self.runtime.block_on(async move { chain_source.update_fee_rate_estimates().await })?; // Spawn background task continuously syncing onchain, lightning, and fee rate cache. let stop_sync_receiver = self.stop_sender.subscribe(); @@ -258,7 +247,7 @@ impl Node { let sync_cman = Arc::clone(&self.channel_manager); let sync_cmon = Arc::clone(&self.chain_monitor); let sync_sweeper = Arc::clone(&self.output_sweeper); - runtime.spawn(async move { + self.runtime.spawn(async move { chain_source .continuously_sync_wallets(stop_sync_receiver, sync_cman, sync_cmon, sync_sweeper) .await; @@ -270,7 +259,7 @@ impl Node { let gossip_sync_logger = Arc::clone(&self.logger); let gossip_node_metrics = Arc::clone(&self.node_metrics); let mut stop_gossip_sync = self.stop_sender.subscribe(); - runtime.spawn(async move { + self.runtime.spawn(async move { let mut interval = tokio::time::interval(RGS_SYNC_INTERVAL); loop { tokio::select! 
{ @@ -337,7 +326,7 @@ impl Node { bind_addrs.extend(resolved_address); } - runtime.spawn(async move { + self.runtime.spawn(async move { { let listener = tokio::net::TcpListener::bind(&*bind_addrs).await @@ -384,7 +373,7 @@ impl Node { let connect_logger = Arc::clone(&self.logger); let connect_peer_store = Arc::clone(&self.peer_store); let mut stop_connect = self.stop_sender.subscribe(); - runtime.spawn(async move { + self.runtime.spawn(async move { let mut interval = tokio::time::interval(PEER_RECONNECTION_INTERVAL); interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); loop { @@ -424,7 +413,7 @@ impl Node { let mut stop_bcast = self.stop_sender.subscribe(); let node_alias = self.config.node_alias.clone(); if may_announce_channel(&self.config).is_ok() { - runtime.spawn(async move { + self.runtime.spawn(async move { // We check every 30 secs whether our last broadcast is NODE_ANN_BCAST_INTERVAL away. #[cfg(not(test))] let mut interval = tokio::time::interval(Duration::from_secs(30)); @@ -501,7 +490,7 @@ impl Node { let mut stop_tx_bcast = self.stop_sender.subscribe(); let chain_source = Arc::clone(&self.chain_source); let tx_bcast_logger = Arc::clone(&self.logger); - runtime.spawn(async move { + self.runtime.spawn(async move { // Every second we try to clear our broadcasting queue. 
let mut interval = tokio::time::interval(Duration::from_secs(1)); interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); @@ -578,7 +567,7 @@ impl Node { let background_stop_logger = Arc::clone(&self.logger); let event_handling_stopped_sender = self.event_handling_stopped_sender.clone(); - runtime.spawn(async move { + self.runtime.spawn(async move { process_events_async( background_persister, |e| background_event_handler.handle_event(e), @@ -617,7 +606,7 @@ impl Node { let mut stop_liquidity_handler = self.stop_sender.subscribe(); let liquidity_handler = Arc::clone(&liquidity_source); let liquidity_logger = Arc::clone(&self.logger); - runtime.spawn(async move { + self.runtime.spawn(async move { loop { tokio::select! { _ = stop_liquidity_handler.changed() => { @@ -633,9 +622,8 @@ impl Node { }); } - *runtime_lock = Some(runtime); - log_info!(self.logger, "Startup complete."); + *is_running_lock = true; Ok(()) } @@ -643,9 +631,10 @@ impl Node { /// /// After this returns most API methods will return [`Error::NotRunning`]. pub fn stop(&self) -> Result<(), Error> { - let runtime = self.runtime.write().unwrap().take().ok_or(Error::NotRunning)?; - #[cfg(tokio_unstable)] - let metrics_runtime = Arc::clone(&runtime); + let mut is_running_lock = self.is_running.write().unwrap(); + if !*is_running_lock { + return Err(Error::NotRunning); + } log_info!(self.logger, "Shutting down LDK Node with node ID {}...", self.node_id()); @@ -675,14 +664,12 @@ impl Node { // FIXME: For now, we wait up to 100 secs (BDK_WALLET_SYNC_TIMEOUT_SECS + 10) to allow // event handling to exit gracefully even if it was blocked on the BDK wallet syncing. We // should drop this considerably post upgrading to BDK 1.0. 
- let timeout_res = tokio::task::block_in_place(move || { - runtime.block_on(async { - tokio::time::timeout( - Duration::from_secs(100), - event_handling_stopped_receiver.changed(), - ) - .await - }) + let timeout_res = self.runtime.block_on(async { + tokio::time::timeout( + Duration::from_secs(100), + event_handling_stopped_receiver.changed(), + ) + .await }); match timeout_res { @@ -708,20 +695,22 @@ impl Node { #[cfg(tokio_unstable)] { + let runtime_handle = self.runtime.handle(); log_trace!( self.logger, "Active runtime tasks left prior to shutdown: {}", - metrics_runtime.metrics().active_tasks_count() + runtime_handle.metrics().active_tasks_count() ); } log_info!(self.logger, "Shutdown complete."); + *is_running_lock = false; Ok(()) } /// Returns the status of the [`Node`]. pub fn status(&self) -> NodeStatus { - let is_running = self.runtime.read().unwrap().is_some(); + let is_running = *self.is_running.read().unwrap(); let is_listening = self.is_listening.load(Ordering::Acquire); let current_best_block = self.channel_manager.current_best_block().into(); let locked_node_metrics = self.node_metrics.read().unwrap(); @@ -842,6 +831,7 @@ impl Node { Arc::clone(&self.payment_store), Arc::clone(&self.peer_store), Arc::clone(&self.config), + Arc::clone(&self.is_running), Arc::clone(&self.logger), ) } @@ -859,6 +849,7 @@ impl Node { Arc::clone(&self.payment_store), Arc::clone(&self.peer_store), Arc::clone(&self.config), + Arc::clone(&self.is_running), Arc::clone(&self.logger), )) } @@ -869,9 +860,9 @@ impl Node { #[cfg(not(feature = "uniffi"))] pub fn bolt12_payment(&self) -> Bolt12Payment { Bolt12Payment::new( - Arc::clone(&self.runtime), Arc::clone(&self.channel_manager), Arc::clone(&self.payment_store), + Arc::clone(&self.is_running), Arc::clone(&self.logger), ) } @@ -882,9 +873,9 @@ impl Node { #[cfg(feature = "uniffi")] pub fn bolt12_payment(&self) -> Arc { Arc::new(Bolt12Payment::new( - Arc::clone(&self.runtime), Arc::clone(&self.channel_manager), 
Arc::clone(&self.payment_store), + Arc::clone(&self.is_running), Arc::clone(&self.logger), )) } @@ -893,11 +884,11 @@ impl Node { #[cfg(not(feature = "uniffi"))] pub fn spontaneous_payment(&self) -> SpontaneousPayment { SpontaneousPayment::new( - Arc::clone(&self.runtime), Arc::clone(&self.channel_manager), Arc::clone(&self.keys_manager), Arc::clone(&self.payment_store), Arc::clone(&self.config), + Arc::clone(&self.is_running), Arc::clone(&self.logger), ) } @@ -906,11 +897,11 @@ impl Node { #[cfg(feature = "uniffi")] pub fn spontaneous_payment(&self) -> Arc { Arc::new(SpontaneousPayment::new( - Arc::clone(&self.runtime), Arc::clone(&self.channel_manager), Arc::clone(&self.keys_manager), Arc::clone(&self.payment_store), Arc::clone(&self.config), + Arc::clone(&self.is_running), Arc::clone(&self.logger), )) } @@ -919,10 +910,10 @@ impl Node { #[cfg(not(feature = "uniffi"))] pub fn onchain_payment(&self) -> OnchainPayment { OnchainPayment::new( - Arc::clone(&self.runtime), Arc::clone(&self.wallet), Arc::clone(&self.channel_manager), Arc::clone(&self.config), + Arc::clone(&self.is_running), Arc::clone(&self.logger), ) } @@ -931,10 +922,10 @@ impl Node { #[cfg(feature = "uniffi")] pub fn onchain_payment(&self) -> Arc { Arc::new(OnchainPayment::new( - Arc::clone(&self.runtime), Arc::clone(&self.wallet), Arc::clone(&self.channel_manager), Arc::clone(&self.config), + Arc::clone(&self.is_running), Arc::clone(&self.logger), )) } @@ -1012,11 +1003,9 @@ impl Node { pub fn connect( &self, node_id: PublicKey, address: SocketAddress, persist: bool, ) -> Result<(), Error> { - let rt_lock = self.runtime.read().unwrap(); - if rt_lock.is_none() { + if !*self.is_running.read().unwrap() { return Err(Error::NotRunning); } - let runtime = rt_lock.as_ref().unwrap(); let peer_info = PeerInfo { node_id, address }; @@ -1026,10 +1015,8 @@ impl Node { // We need to use our main runtime here as a local runtime might not be around to poll // connection futures going forward. 
- tokio::task::block_in_place(move || { - runtime.block_on(async move { - con_cm.connect_peer_if_necessary(con_node_id, con_addr).await - }) + self.runtime.block_on(async move { + con_cm.connect_peer_if_necessary(con_node_id, con_addr).await })?; log_info!(self.logger, "Connected to peer {}@{}. ", peer_info.node_id, peer_info.address); @@ -1046,8 +1033,7 @@ impl Node { /// Will also remove the peer from the peer store, i.e., after this has been called we won't /// try to reconnect on restart. pub fn disconnect(&self, counterparty_node_id: PublicKey) -> Result<(), Error> { - let rt_lock = self.runtime.read().unwrap(); - if rt_lock.is_none() { + if !*self.is_running.read().unwrap() { return Err(Error::NotRunning); } @@ -1069,11 +1055,9 @@ impl Node { push_to_counterparty_msat: Option, channel_config: Option, announce_for_forwarding: bool, ) -> Result { - let rt_lock = self.runtime.read().unwrap(); - if rt_lock.is_none() { + if !*self.is_running.read().unwrap() { return Err(Error::NotRunning); } - let runtime = rt_lock.as_ref().unwrap(); let peer_info = PeerInfo { node_id, address }; @@ -1097,10 +1081,8 @@ impl Node { // We need to use our main runtime here as a local runtime might not be around to poll // connection futures going forward. - tokio::task::block_in_place(move || { - runtime.block_on(async move { - con_cm.connect_peer_if_necessary(con_node_id, con_addr).await - }) + self.runtime.block_on(async move { + con_cm.connect_peer_if_necessary(con_node_id, con_addr).await })?; // Fail if we have less than the channel value + anchor reserve available (if applicable). 
@@ -1249,8 +1231,7 @@ impl Node { /// /// [`EsploraSyncConfig::background_sync_config`]: crate::config::EsploraSyncConfig::background_sync_config pub fn sync_wallets(&self) -> Result<(), Error> { - let rt_lock = self.runtime.read().unwrap(); - if rt_lock.is_none() { + if !*self.is_running.read().unwrap() { return Err(Error::NotRunning); } @@ -1284,9 +1265,29 @@ impl Node { }, } Ok(()) - }, - ) + }) }) + //self.runtime.block_on(async move { + // match chain_source.as_ref() { + // ChainSource::Esplora { .. } => { + // chain_source.update_fee_rate_estimates().await?; + // chain_source.sync_lightning_wallet(sync_cman, sync_cmon, sync_sweeper).await?; + // chain_source.sync_onchain_wallet().await?; + // }, + // ChainSource::Electrum { .. } => { + // chain_source.update_fee_rate_estimates().await?; + // chain_source.sync_lightning_wallet(sync_cman, sync_cmon, sync_sweeper).await?; + // chain_source.sync_onchain_wallet().await?; + // }, + // ChainSource::BitcoindRpc { .. } => { + // chain_source.update_fee_rate_estimates().await?; + // chain_source + // .poll_and_update_listeners(sync_cman, sync_cmon, sync_sweeper) + // .await?; + // }, + // } + // Ok(()) + //}) } /// Close a previously opened channel. 
diff --git a/src/liquidity.rs b/src/liquidity.rs index 47f3dcce4..e93f33abf 100644 --- a/src/liquidity.rs +++ b/src/liquidity.rs @@ -10,6 +10,7 @@ use crate::chain::ChainSource; use crate::connection::ConnectionManager; use crate::logger::{log_debug, log_error, log_info, LdkLogger, Logger}; +use crate::runtime::Runtime; use crate::types::{ChannelManager, KeysManager, LiquidityManager, PeerManager, Wallet}; use crate::{total_anchor_channels_reserve_sats, Config, Error}; @@ -1388,7 +1389,7 @@ pub(crate) struct LSPS2BuyResponse { /// [`Bolt11Payment::receive_via_jit_channel`]: crate::payment::Bolt11Payment::receive_via_jit_channel #[derive(Clone)] pub struct LSPS1Liquidity { - runtime: Arc>>>, + runtime: Arc, wallet: Arc, connection_manager: Arc>>, liquidity_source: Option>>>, @@ -1397,7 +1398,7 @@ pub struct LSPS1Liquidity { impl LSPS1Liquidity { pub(crate) fn new( - runtime: Arc>>>, wallet: Arc, + runtime: Arc, wallet: Arc, connection_manager: Arc>>, liquidity_source: Option>>>, logger: Arc, ) -> Self { @@ -1418,19 +1419,14 @@ impl LSPS1Liquidity { let (lsp_node_id, lsp_address) = liquidity_source.get_lsps1_lsp_details().ok_or(Error::LiquiditySourceUnavailable)?; - let rt_lock = self.runtime.read().unwrap(); - let runtime = rt_lock.as_ref().unwrap(); - let con_node_id = lsp_node_id; let con_addr = lsp_address.clone(); let con_cm = Arc::clone(&self.connection_manager); // We need to use our main runtime here as a local runtime might not be around to poll // connection futures going forward. - tokio::task::block_in_place(move || { - runtime.block_on(async move { - con_cm.connect_peer_if_necessary(con_node_id, con_addr).await - }) + self.runtime.block_on(async move { + con_cm.connect_peer_if_necessary(con_node_id, con_addr).await })?; log_info!(self.logger, "Connected to LSP {}@{}. 
", lsp_node_id, lsp_address); @@ -1438,18 +1434,16 @@ impl LSPS1Liquidity { let refund_address = self.wallet.get_new_address()?; let liquidity_source = Arc::clone(&liquidity_source); - let response = tokio::task::block_in_place(move || { - runtime.block_on(async move { - liquidity_source - .lsps1_request_channel( - lsp_balance_sat, - client_balance_sat, - channel_expiry_blocks, - announce_channel, - refund_address, - ) - .await - }) + let response = self.runtime.block_on(async move { + liquidity_source + .lsps1_request_channel( + lsp_balance_sat, + client_balance_sat, + channel_expiry_blocks, + announce_channel, + refund_address, + ) + .await })?; Ok(response) @@ -1463,27 +1457,20 @@ impl LSPS1Liquidity { let (lsp_node_id, lsp_address) = liquidity_source.get_lsps1_lsp_details().ok_or(Error::LiquiditySourceUnavailable)?; - let rt_lock = self.runtime.read().unwrap(); - let runtime = rt_lock.as_ref().unwrap(); - let con_node_id = lsp_node_id; let con_addr = lsp_address.clone(); let con_cm = Arc::clone(&self.connection_manager); // We need to use our main runtime here as a local runtime might not be around to poll // connection futures going forward. 
- tokio::task::block_in_place(move || { - runtime.block_on(async move { - con_cm.connect_peer_if_necessary(con_node_id, con_addr).await - }) + self.runtime.block_on(async move { + con_cm.connect_peer_if_necessary(con_node_id, con_addr).await })?; let liquidity_source = Arc::clone(&liquidity_source); - let response = tokio::task::block_in_place(move || { - runtime - .block_on(async move { liquidity_source.lsps1_check_order_status(order_id).await }) - })?; - + let response = self + .runtime + .block_on(async move { liquidity_source.lsps1_check_order_status(order_id).await })?; Ok(response) } } diff --git a/src/logger.rs b/src/logger.rs index 073aa92bc..d357f018d 100644 --- a/src/logger.rs +++ b/src/logger.rs @@ -13,7 +13,8 @@ pub(crate) use lightning::{log_bytes, log_debug, log_error, log_info, log_trace} pub use lightning::util::logger::Level as LogLevel; use chrono::Utc; -use log::{debug, error, info, trace, warn}; +use log::Level as LogFacadeLevel; +use log::Record as LogFacadeRecord; #[cfg(not(feature = "uniffi"))] use core::fmt; @@ -139,20 +140,32 @@ impl LogWriter for Writer { .expect("Failed to write to log file") }, Writer::LogFacadeWriter => { - macro_rules! 
log_with_level { - ($log_level:expr, $target: expr, $($args:tt)*) => { - match $log_level { - LogLevel::Gossip | LogLevel::Trace => trace!(target: $target, $($args)*), - LogLevel::Debug => debug!(target: $target, $($args)*), - LogLevel::Info => info!(target: $target, $($args)*), - LogLevel::Warn => warn!(target: $target, $($args)*), - LogLevel::Error => error!(target: $target, $($args)*), - } - }; - } - - let target = format!("[{}:{}]", record.module_path, record.line); - log_with_level!(record.level, &target, " {}", record.args) + let mut builder = LogFacadeRecord::builder(); + + match record.level { + LogLevel::Gossip | LogLevel::Trace => builder.level(LogFacadeLevel::Trace), + LogLevel::Debug => builder.level(LogFacadeLevel::Debug), + LogLevel::Info => builder.level(LogFacadeLevel::Info), + LogLevel::Warn => builder.level(LogFacadeLevel::Warn), + LogLevel::Error => builder.level(LogFacadeLevel::Error), + }; + + #[cfg(not(feature = "uniffi"))] + log::logger().log( + &builder + .module_path(Some(record.module_path)) + .line(Some(record.line)) + .args(format_args!("{}", record.args)) + .build(), + ); + #[cfg(feature = "uniffi")] + log::logger().log( + &builder + .module_path(Some(&record.module_path)) + .line(Some(record.line)) + .args(format_args!("{}", record.args)) + .build(), + ); }, Writer::CustomWriter(custom_logger) => custom_logger.log(record), } diff --git a/src/payment/bolt11.rs b/src/payment/bolt11.rs index 052571818..e526fc055 100644 --- a/src/payment/bolt11.rs +++ b/src/payment/bolt11.rs @@ -21,6 +21,7 @@ use crate::payment::store::{ }; use crate::payment::SendingParameters; use crate::peer_store::{PeerInfo, PeerStore}; +use crate::runtime::Runtime; use crate::types::{ChannelManager, PaymentStore}; use lightning::ln::bolt11_payment; @@ -87,24 +88,24 @@ macro_rules! 
maybe_convert_description { /// [BOLT 11]: https://github.com/lightning/bolts/blob/master/11-payment-encoding.md /// [`Node::bolt11_payment`]: crate::Node::bolt11_payment pub struct Bolt11Payment { - runtime: Arc>>>, + runtime: Arc, channel_manager: Arc, connection_manager: Arc>>, liquidity_source: Option>>>, payment_store: Arc, peer_store: Arc>>, config: Arc, + is_running: Arc>, logger: Arc, } impl Bolt11Payment { pub(crate) fn new( - runtime: Arc>>>, - channel_manager: Arc, + runtime: Arc, channel_manager: Arc, connection_manager: Arc>>, liquidity_source: Option>>>, payment_store: Arc, peer_store: Arc>>, - config: Arc, logger: Arc, + config: Arc, is_running: Arc>, logger: Arc, ) -> Self { Self { runtime, @@ -114,6 +115,7 @@ impl Bolt11Payment { payment_store, peer_store, config, + is_running, logger, } } @@ -126,8 +128,7 @@ impl Bolt11Payment { &self, invoice: &Bolt11Invoice, sending_parameters: Option, ) -> Result { let invoice = maybe_convert_invoice(invoice); - let rt_lock = self.runtime.read().unwrap(); - if rt_lock.is_none() { + if !*self.is_running.read().unwrap() { return Err(Error::NotRunning); } @@ -235,8 +236,7 @@ impl Bolt11Payment { sending_parameters: Option, ) -> Result { let invoice = maybe_convert_invoice(invoice); - let rt_lock = self.runtime.read().unwrap(); - if rt_lock.is_none() { + if !*self.is_running.read().unwrap() { return Err(Error::NotRunning); } @@ -649,9 +649,6 @@ impl Bolt11Payment { let (node_id, address) = liquidity_source.get_lsps2_lsp_details().ok_or(Error::LiquiditySourceUnavailable)?; - let rt_lock = self.runtime.read().unwrap(); - let runtime = rt_lock.as_ref().unwrap(); - let peer_info = PeerInfo { node_id, address }; let con_node_id = peer_info.node_id; @@ -660,39 +657,35 @@ impl Bolt11Payment { // We need to use our main runtime here as a local runtime might not be around to poll // connection futures going forward. 
- tokio::task::block_in_place(move || { - runtime.block_on(async move { - con_cm.connect_peer_if_necessary(con_node_id, con_addr).await - }) + self.runtime.block_on(async move { + con_cm.connect_peer_if_necessary(con_node_id, con_addr).await })?; log_info!(self.logger, "Connected to LSP {}@{}. ", peer_info.node_id, peer_info.address); let liquidity_source = Arc::clone(&liquidity_source); let (invoice, lsp_total_opening_fee, lsp_prop_opening_fee) = - tokio::task::block_in_place(move || { - runtime.block_on(async move { - if let Some(amount_msat) = amount_msat { - liquidity_source - .lsps2_receive_to_jit_channel( - amount_msat, - description, - expiry_secs, - max_total_lsp_fee_limit_msat, - ) - .await - .map(|(invoice, total_fee)| (invoice, Some(total_fee), None)) - } else { - liquidity_source - .lsps2_receive_variable_amount_to_jit_channel( - description, - expiry_secs, - max_proportional_lsp_fee_limit_ppm_msat, - ) - .await - .map(|(invoice, prop_fee)| (invoice, None, Some(prop_fee))) - } - }) + self.runtime.block_on(async move { + if let Some(amount_msat) = amount_msat { + liquidity_source + .lsps2_receive_to_jit_channel( + amount_msat, + description, + expiry_secs, + max_total_lsp_fee_limit_msat, + ) + .await + .map(|(invoice, total_fee)| (invoice, Some(total_fee), None)) + } else { + liquidity_source + .lsps2_receive_variable_amount_to_jit_channel( + description, + expiry_secs, + max_proportional_lsp_fee_limit_ppm_msat, + ) + .await + .map(|(invoice, prop_fee)| (invoice, None, Some(prop_fee))) + } })?; // Register payment in payment store. @@ -742,12 +735,12 @@ impl Bolt11Payment { /// amount times [`Config::probing_liquidity_limit_multiplier`] won't be used to send /// pre-flight probes. 
pub fn send_probes(&self, invoice: &Bolt11Invoice) -> Result<(), Error> { - let invoice = maybe_convert_invoice(invoice); - let rt_lock = self.runtime.read().unwrap(); - if rt_lock.is_none() { + if !*self.is_running.read().unwrap() { return Err(Error::NotRunning); } + let invoice = maybe_convert_invoice(invoice); + let (_payment_hash, _recipient_onion, route_params) = bolt11_payment::payment_parameters_from_invoice(&invoice).map_err(|_| { log_error!(self.logger, "Failed to send probes due to the given invoice being \"zero-amount\". Please use send_probes_using_amount instead."); Error::InvalidInvoice @@ -775,12 +768,12 @@ impl Bolt11Payment { pub fn send_probes_using_amount( &self, invoice: &Bolt11Invoice, amount_msat: u64, ) -> Result<(), Error> { - let invoice = maybe_convert_invoice(invoice); - let rt_lock = self.runtime.read().unwrap(); - if rt_lock.is_none() { + if !*self.is_running.read().unwrap() { return Err(Error::NotRunning); } + let invoice = maybe_convert_invoice(invoice); + let (_payment_hash, _recipient_onion, route_params) = if let Some(invoice_amount_msat) = invoice.amount_milli_satoshis() { diff --git a/src/payment/bolt12.rs b/src/payment/bolt12.rs index 8006f4bb9..8024d3f7c 100644 --- a/src/payment/bolt12.rs +++ b/src/payment/bolt12.rs @@ -35,19 +35,18 @@ use std::time::{Duration, SystemTime, UNIX_EPOCH}; /// [BOLT 12]: https://github.com/lightning/bolts/blob/master/12-offer-encoding.md /// [`Node::bolt12_payment`]: crate::Node::bolt12_payment pub struct Bolt12Payment { - runtime: Arc>>>, channel_manager: Arc, payment_store: Arc, + is_running: Arc>, logger: Arc, } impl Bolt12Payment { pub(crate) fn new( - runtime: Arc>>>, channel_manager: Arc, payment_store: Arc, - logger: Arc, + is_running: Arc>, logger: Arc, ) -> Self { - Self { runtime, channel_manager, payment_store, logger } + Self { channel_manager, payment_store, is_running, logger } } /// Send a payment given an offer. 
@@ -59,10 +58,10 @@ impl Bolt12Payment { pub fn send( &self, offer: &Offer, quantity: Option, payer_note: Option, ) -> Result { - let rt_lock = self.runtime.read().unwrap(); - if rt_lock.is_none() { + if !*self.is_running.read().unwrap() { return Err(Error::NotRunning); } + let mut random_bytes = [0u8; 32]; rand::thread_rng().fill_bytes(&mut random_bytes); let payment_id = PaymentId(random_bytes); @@ -160,8 +159,7 @@ impl Bolt12Payment { pub fn send_using_amount( &self, offer: &Offer, amount_msat: u64, quantity: Option, payer_note: Option, ) -> Result { - let rt_lock = self.runtime.read().unwrap(); - if rt_lock.is_none() { + if !*self.is_running.read().unwrap() { return Err(Error::NotRunning); } diff --git a/src/payment/onchain.rs b/src/payment/onchain.rs index 046d66c69..2614e55ce 100644 --- a/src/payment/onchain.rs +++ b/src/payment/onchain.rs @@ -41,19 +41,19 @@ macro_rules! maybe_map_fee_rate_opt { /// /// [`Node::onchain_payment`]: crate::Node::onchain_payment pub struct OnchainPayment { - runtime: Arc>>>, wallet: Arc, channel_manager: Arc, config: Arc, + is_running: Arc>, logger: Arc, } impl OnchainPayment { pub(crate) fn new( - runtime: Arc>>>, wallet: Arc, - channel_manager: Arc, config: Arc, logger: Arc, + wallet: Arc, channel_manager: Arc, config: Arc, + is_running: Arc>, logger: Arc, ) -> Self { - Self { runtime, wallet, channel_manager, config, logger } + Self { wallet, channel_manager, config, is_running, logger } } /// Retrieve a new on-chain/funding address. 
@@ -75,8 +75,7 @@ impl OnchainPayment { pub fn send_to_address( &self, address: &bitcoin::Address, amount_sats: u64, fee_rate: Option, ) -> Result { - let rt_lock = self.runtime.read().unwrap(); - if rt_lock.is_none() { + if !*self.is_running.read().unwrap() { return Err(Error::NotRunning); } @@ -106,8 +105,7 @@ impl OnchainPayment { pub fn send_all_to_address( &self, address: &bitcoin::Address, retain_reserves: bool, fee_rate: Option, ) -> Result { - let rt_lock = self.runtime.read().unwrap(); - if rt_lock.is_none() { + if !*self.is_running.read().unwrap() { return Err(Error::NotRunning); } diff --git a/src/payment/spontaneous.rs b/src/payment/spontaneous.rs index 1508b6cd8..cdac80ff7 100644 --- a/src/payment/spontaneous.rs +++ b/src/payment/spontaneous.rs @@ -33,21 +33,21 @@ const LDK_DEFAULT_FINAL_CLTV_EXPIRY_DELTA: u32 = 144; /// /// [`Node::spontaneous_payment`]: crate::Node::spontaneous_payment pub struct SpontaneousPayment { - runtime: Arc>>>, channel_manager: Arc, keys_manager: Arc, payment_store: Arc, config: Arc, + is_running: Arc>, logger: Arc, } impl SpontaneousPayment { pub(crate) fn new( - runtime: Arc>>>, channel_manager: Arc, keys_manager: Arc, - payment_store: Arc, config: Arc, logger: Arc, + payment_store: Arc, config: Arc, is_running: Arc>, + logger: Arc, ) -> Self { - Self { runtime, channel_manager, keys_manager, payment_store, config, logger } + Self { channel_manager, keys_manager, payment_store, config, is_running, logger } } /// Send a spontaneous aka. "keysend", payment. 
@@ -72,8 +72,7 @@ impl SpontaneousPayment { &self, amount_msat: u64, node_id: PublicKey, sending_parameters: Option, custom_tlvs: Option>, ) -> Result { - let rt_lock = self.runtime.read().unwrap(); - if rt_lock.is_none() { + if !*self.is_running.read().unwrap() { return Err(Error::NotRunning); } @@ -180,8 +179,7 @@ impl SpontaneousPayment { /// /// [`Bolt11Payment::send_probes`]: crate::payment::Bolt11Payment pub fn send_probes(&self, amount_msat: u64, node_id: PublicKey) -> Result<(), Error> { - let rt_lock = self.runtime.read().unwrap(); - if rt_lock.is_none() { + if !*self.is_running.read().unwrap() { return Err(Error::NotRunning); } diff --git a/src/runtime.rs b/src/runtime.rs new file mode 100644 index 000000000..dcda74e63 --- /dev/null +++ b/src/runtime.rs @@ -0,0 +1,81 @@ +// This file is Copyright its original authors, visible in version control history. +// +// This file is licensed under the Apache License, Version 2.0 or the MIT license , at your option. You may not use this file except in +// accordance with one or both of these licenses. 
+ +use tokio::task::JoinHandle; + +use std::future::Future; + +pub(crate) struct Runtime { + mode: RuntimeMode, +} + +impl Runtime { + pub fn new() -> Result { + let mode = match tokio::runtime::Handle::try_current() { + Ok(handle) => RuntimeMode::Handle(handle), + Err(_) => { + let rt = tokio::runtime::Builder::new_multi_thread().enable_all().build()?; + RuntimeMode::Owned(rt) + }, + }; + Ok(Self { mode }) + } + + pub fn with_handle(handle: tokio::runtime::Handle) -> Self { + let mode = RuntimeMode::Handle(handle); + Self { mode } + } + + pub fn spawn(&self, future: F) -> JoinHandle + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + let handle = self.handle(); + handle.spawn(future) + } + + pub fn spawn_blocking(&self, func: F) -> JoinHandle + where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, + { + let handle = self.handle(); + handle.spawn_blocking(func) + } + + pub fn block_on(&self, future: F) -> F::Output { + // While we generally decided not to overthink via which call graph users would enter our + // runtime context, we'd still try to reuse whatever current context would be present + // during `block_on`, as this is the context `block_in_place` would operate on. So we try + // to detect the outer context here, and otherwise use whatever was set during + // initialization. 
+ let handle = tokio::runtime::Handle::try_current().unwrap_or(self.handle()); + #[cfg(tokio_unstable)] + { + println!("Tokio blocking queue depth: {}", handle.metrics().blocking_queue_depth()); + println!( + "Tokio num_blocking_threads {} / idle_blocking_threads {}", + handle.metrics().num_blocking_threads(), + handle.metrics().num_idle_blocking_threads() + ); + } + tokio::task::block_in_place(move || handle.block_on(future)) + } + + pub fn handle(&self) -> tokio::runtime::Handle { + match &self.mode { + RuntimeMode::Owned(rt) => rt.handle().clone(), + RuntimeMode::Handle(handle) => handle.clone(), + } + } +} + +enum RuntimeMode { + Owned(tokio::runtime::Runtime), + Handle(tokio::runtime::Handle), +} diff --git a/tests/common/logging.rs b/tests/common/logging.rs index 6bceac29a..6db6b3082 100644 --- a/tests/common/logging.rs +++ b/tests/common/logging.rs @@ -22,11 +22,12 @@ impl Default for TestLogWriter { pub(crate) struct MockLogFacadeLogger { logs: Arc>>, + prefix: String, } impl MockLogFacadeLogger { - pub fn new() -> Self { - Self { logs: Arc::new(Mutex::new(Vec::new())) } + pub fn new(prefix: String) -> Self { + Self { prefix, logs: Arc::new(Mutex::new(Vec::new())) } } pub fn retrieve_logs(&self) -> Vec { @@ -48,6 +49,7 @@ impl LogFacadeLog for MockLogFacadeLogger { record.line().unwrap(), record.args() ); + println!("{}: {}", self.prefix, message); self.logs.lock().unwrap().push(message); } @@ -95,8 +97,10 @@ impl<'a> From> for LogFacadeRecord<'a> { } } -pub(crate) fn init_log_logger(level: LogFacadeLevelFilter) -> Arc { - let logger = Arc::new(MockLogFacadeLogger::new()); +pub(crate) fn init_log_logger( + prefix: String, level: LogFacadeLevelFilter, +) -> Arc { + let logger = Arc::new(MockLogFacadeLogger::new(prefix)); log::set_boxed_logger(Box::new(logger.clone())).unwrap(); log::set_max_level(level); diff --git a/tests/integration_tests_rust.rs b/tests/integration_tests_rust.rs index ded88d35c..a2b138a93 100644 --- a/tests/integration_tests_rust.rs +++ 
b/tests/integration_tests_rust.rs @@ -1275,7 +1275,7 @@ fn facade_logging() { let (_bitcoind, electrsd) = setup_bitcoind_and_electrsd(); let chain_source = TestChainSource::Esplora(&electrsd); - let logger = init_log_logger(LevelFilter::Trace); + let logger = init_log_logger("".to_owned(), LevelFilter::Trace); let mut config = random_config(false); config.log_writer = TestLogWriter::LogFacade; @@ -1287,3 +1287,15 @@ fn facade_logging() { validate_log_entry(entry); } } + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn test_start_stop_drop_in_runtime_context() { + let (_bitcoind, electrsd) = setup_bitcoind_and_electrsd(); + let chain_source = TestChainSource::Esplora(&electrsd); + + { + let config = random_config(true); + let node = setup_node(&chain_source, config, None); + node.stop().unwrap(); + } +} diff --git a/tests/integration_tests_vss.rs b/tests/integration_tests_vss.rs index 9d6ec158c..615bf8519 100644 --- a/tests/integration_tests_vss.rs +++ b/tests/integration_tests_vss.rs @@ -9,7 +9,9 @@ mod common; +use common::logging::{init_log_logger, TestLogWriter}; use ldk_node::Builder; +use log::LevelFilter; use std::collections::HashMap; #[test] @@ -17,9 +19,13 @@ fn channel_full_cycle_with_vss_store() { let (bitcoind, electrsd) = common::setup_bitcoind_and_electrsd(); println!("== Node A =="); let esplora_url = format!("http://{}", electrsd.esplora_url.as_ref().unwrap()); - let config_a = common::random_config(true); + let mut config_a = common::random_config(true); + let prefix_a = "A".to_string(); + let logger_a = init_log_logger(prefix_a, LevelFilter::Trace); + config_a.log_writer = TestLogWriter::LogFacade; let mut builder_a = Builder::from_config(config_a.node_config); builder_a.set_chain_source_esplora(esplora_url.clone(), None); + builder_a.set_log_facade_logger(); let vss_base_url = std::env::var("TEST_VSS_BASE_URL").unwrap(); let node_a = builder_a .build_with_vss_store_and_fixed_headers( @@ -31,9 +37,13 @@ fn 
channel_full_cycle_with_vss_store() { node_a.start().unwrap(); println!("\n== Node B =="); - let config_b = common::random_config(true); + let mut config_b = common::random_config(true); + //let prefix_b = "B".to_string(); + //let logger_b = init_log_logger(prefix_b, LevelFilter::Trace); + //config_a.log_writer = TestLogWriter::LogFacade; let mut builder_b = Builder::from_config(config_b.node_config); builder_b.set_chain_source_esplora(esplora_url.clone(), None); + //builder_b.set_log_facade_logger(); let node_b = builder_b .build_with_vss_store_and_fixed_headers( vss_base_url,