From 4c2eace3b689ee77b620dbc337b15eb9ca5dd617 Mon Sep 17 00:00:00 2001 From: Charles Dixon Date: Tue, 30 Jul 2024 15:19:40 +0100 Subject: [PATCH] Update memdx channel usage to use callbacks where possible Motivation ----------- Using channels for things like orphan handler requires us to spin up new tasks, the flow is also a bit easier to follow with callbacks. Changes ------- Update orphan handling and connection close handling to use callbacks. Rewrite kvclientpool to make this possible. --- sdk/couchbase-core/Cargo.toml | 1 + sdk/couchbase-core/src/configwatcher.rs | 21 +- sdk/couchbase-core/src/crudcomponent.rs | 19 +- sdk/couchbase-core/src/kvclient.rs | 78 ++-- sdk/couchbase-core/src/kvclientmanager.rs | 7 +- sdk/couchbase-core/src/kvclientpool.rs | 427 +++++++++++---------- sdk/couchbase-core/src/lib.rs | 3 + sdk/couchbase-core/src/memdx/client.rs | 68 +--- sdk/couchbase-core/src/memdx/dispatcher.rs | 11 +- 9 files changed, 285 insertions(+), 350 deletions(-) diff --git a/sdk/couchbase-core/Cargo.toml b/sdk/couchbase-core/Cargo.toml index 2db04363..acd0e975 100644 --- a/sdk/couchbase-core/Cargo.toml +++ b/sdk/couchbase-core/Cargo.toml @@ -25,6 +25,7 @@ async-trait = "0.1.80" tokio-io = { version = "0.2.0-alpha.6", features = ["util"] } crc32fast = "1.4.2" serde_json = "1.0.120" +arc-swap = "1.7" [dev-dependencies] env_logger = "0.11" diff --git a/sdk/couchbase-core/src/configwatcher.rs b/sdk/couchbase-core/src/configwatcher.rs index 75c1bb11..6c213b6d 100644 --- a/sdk/couchbase-core/src/configwatcher.rs +++ b/sdk/couchbase-core/src/configwatcher.rs @@ -200,7 +200,6 @@ mod tests { use std::time::Duration; use tokio::sync::broadcast; - use tokio::sync::mpsc::unbounded_channel; use tokio::time::sleep; use crate::authenticator::PasswordAuthenticator; @@ -213,9 +212,8 @@ mod tests { }; use crate::kvclientpool::NaiveKvClientPool; use crate::memdx::client::Client; - use crate::memdx::packet::ResponsePacket; - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + #[tokio::test] async fn fetches_configs() { let client_config = KvClientConfig { address: "192.168.107.128:11210" @@ -245,28 +243,13 @@ mod tests { clients: client_configs, }; - let (orphan_tx, mut orphan_rx) = unbounded_channel::(); - - tokio::spawn(async move { - loop { - match orphan_rx.recv().await { - Some(resp) => { - dbg!("unexpected orphan", resp); - } - None => { - return; - } - } - } - }); - let manager: StdKvClientManager>> = StdKvClientManager::new( manger_config, KvClientManagerOptions { connect_timeout: Default::default(), connect_throttle_period: Default::default(), - orphan_handler: Arc::new(orphan_tx), + orphan_handler: Arc::new(|_| {}), }, ) .await diff --git a/sdk/couchbase-core/src/crudcomponent.rs b/sdk/couchbase-core/src/crudcomponent.rs index 537adcc3..aac25b9a 100644 --- a/sdk/couchbase-core/src/crudcomponent.rs +++ b/sdk/couchbase-core/src/crudcomponent.rs @@ -138,7 +138,6 @@ mod tests { use std::sync::Arc; use std::time::Duration; - use tokio::sync::mpsc::unbounded_channel; use tokio::time::Instant; use crate::authenticator::PasswordAuthenticator; @@ -151,7 +150,6 @@ mod tests { }; use crate::kvclientpool::NaiveKvClientPool; use crate::memdx::client::Client; - use crate::memdx::packet::ResponsePacket; use crate::vbucketmap::VbucketMap; use crate::vbucketrouter::{ NotMyVbucketConfigHandler, StdVbucketRouter, VbucketRouterOptions, VbucketRoutingInfo, @@ -171,21 +169,6 @@ mod tests { let instant = Instant::now().add(Duration::new(7, 0)); - let (orphan_tx, mut orphan_rx) = unbounded_channel::(); - - 
tokio::spawn(async move { - loop { - match orphan_rx.recv().await { - Some(resp) => { - dbg!("unexpected orphan", resp); - } - None => { - return; - } - } - } - }); - let client_config = KvClientConfig { address: "192.168.107.128:11210" .parse() @@ -220,7 +203,7 @@ mod tests { KvClientManagerOptions { connect_timeout: Default::default(), connect_throttle_period: Default::default(), - orphan_handler: Arc::new(orphan_tx), + orphan_handler: Arc::new(|_| {}), }, ) .await diff --git a/sdk/couchbase-core/src/kvclient.rs b/sdk/couchbase-core/src/kvclient.rs index 74fe5665..1f391359 100644 --- a/sdk/couchbase-core/src/kvclient.rs +++ b/sdk/couchbase-core/src/kvclient.rs @@ -1,12 +1,12 @@ use std::future::Future; use std::net::SocketAddr; -use std::ops::{Add, Deref}; +use std::ops::{Add, AsyncFn, Deref}; use std::sync::Arc; use std::sync::atomic::{AtomicBool, Ordering}; use std::time::Duration; -use tokio::sync::{Mutex, oneshot}; -use tokio::sync::mpsc::UnboundedSender; +use futures::future::BoxFuture; +use tokio::sync::Mutex; use tokio::time::Instant; use tokio_rustls::rustls::RootCertStore; use uuid::Uuid; @@ -16,11 +16,10 @@ use crate::error::Error; use crate::error::Result; use crate::memdx::auth_mechanism::AuthMechanism; use crate::memdx::connection::{Connection, ConnectOptions}; -use crate::memdx::dispatcher::{Dispatcher, DispatcherOptions}; +use crate::memdx::dispatcher::{Dispatcher, DispatcherOptions, OrphanResponseHandler}; use crate::memdx::hello_feature::HelloFeature; use crate::memdx::op_auth_saslauto::SASLAuthAutoOptions; use crate::memdx::op_bootstrap::BootstrapOptions; -use crate::memdx::packet::ResponsePacket; use crate::memdx::request::{GetErrorMapRequest, HelloRequest, SelectBucketRequest}; use crate::service_type::ServiceType; @@ -53,9 +52,12 @@ impl PartialEq for KvClientConfig { } } +pub(crate) type OnKvClientCloseHandler = + Arc BoxFuture<'static, ()> + Send + Sync>; + pub(crate) struct KvClientOptions { - pub orphan_handler: Arc>, - pub on_close_tx: Option>, + pub orphan_handler: OrphanResponseHandler, + pub on_close: OnKvClientCloseHandler, } pub(crate) trait KvClient: Sized + PartialEq + Send + Sync { @@ -178,16 +180,26 @@ where )); } - let (connection_close_tx, mut connection_close_rx) = - oneshot::channel::>(); + let closed = Arc::new(AtomicBool::new(false)); + let closed_clone = closed.clone(); + let id = Uuid::new_v4().to_string(); + let read_id = id.clone(); + + let on_close = opts.on_close.clone(); let memdx_client_opts = DispatcherOptions { - on_connection_close_handler: Some(connection_close_tx), + on_connection_close_handler: Arc::new(move || { + // There's not much to do when the connection closes so just mark us as closed. + closed_clone.store(true, Ordering::SeqCst); + let on_close = on_close.clone(); + let read_id = read_id.clone(); + + Box::pin(async move { + on_close(read_id).await; + }) + }), orphan_handler: opts.orphan_handler, }; - let closed = Arc::new(AtomicBool::new(false)); - let closed_clone = closed.clone(); - let conn = Connection::connect( config.address, ConnectOptions { @@ -205,7 +217,6 @@ where let local_addr = *conn.local_addr(); let mut cli = D::new(conn, memdx_client_opts); - let id = Uuid::new_v4().to_string(); let mut kv_cli = StdKvClient { remote_addr, @@ -218,18 +229,6 @@ where id: id.clone(), }; - tokio::spawn(async move { - // There's not much to do when the connection closes so just mark us as closed. 
- if connection_close_rx.await.is_ok() { - closed_clone.store(true, Ordering::SeqCst); - }; - - if let Some(mut tx) = opts.on_close_tx { - // TODO: Probably log on failure. - tx.send(id).unwrap_or_default(); - } - }); - if should_bootstrap { if let Some(b) = &bootstrap_select_bucket { let mut guard = kv_cli.selected_bucket.lock().await; @@ -309,7 +308,7 @@ where .await { Ok(_) => {} - Err(e) => { + Err(_e) => { let mut current_bucket = self.selected_bucket.lock().await; *current_bucket = None; drop(current_bucket); @@ -362,37 +361,20 @@ mod tests { use std::sync::Arc; use std::time::Duration; - use tokio::sync::mpsc::unbounded_channel; use tokio::time::Instant; use crate::authenticator::PasswordAuthenticator; use crate::kvclient::{KvClient, KvClientConfig, KvClientOptions, StdKvClient}; use crate::kvclient_ops::KvClientOps; use crate::memdx::client::Client; - use crate::memdx::packet::ResponsePacket; use crate::memdx::request::{GetRequest, SetRequest}; - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + #[tokio::test] async fn roundtrip_a_request() { let _ = env_logger::try_init(); let instant = Instant::now().add(Duration::new(7, 0)); - let (orphan_tx, mut orphan_rx) = unbounded_channel::(); - - tokio::spawn(async move { - loop { - match orphan_rx.recv().await { - Some(resp) => { - dbg!("unexpected orphan", resp); - } - None => { - return; - } - } - } - }); - let client_config = KvClientConfig { address: "192.168.107.128:11210" .parse() @@ -416,8 +398,10 @@ mod tests { let mut client = StdKvClient::::new( client_config, KvClientOptions { - orphan_handler: Arc::new(orphan_tx), - on_close_tx: None, + orphan_handler: Arc::new(|packet| { + dbg!("unexpected orphan", packet); + }), + on_close: Arc::new(|id| Box::pin(async {})), }, ) .await diff --git a/sdk/couchbase-core/src/kvclientmanager.rs b/sdk/couchbase-core/src/kvclientmanager.rs index 8cb00047..b6ad52f3 100644 --- a/sdk/couchbase-core/src/kvclientmanager.rs +++ b/sdk/couchbase-core/src/kvclientmanager.rs @@ -3,7 +3,6 @@ use std::future::Future; use std::sync::Arc; use std::time::Duration; -use tokio::sync::mpsc::UnboundedSender; use tokio::sync::Mutex; use crate::error::ErrorKind; @@ -11,7 +10,7 @@ use crate::error::Result; use crate::kvclient::{KvClient, KvClientConfig}; use crate::kvclient_ops::KvClientOps; use crate::kvclientpool::{KvClientPool, KvClientPoolConfig, KvClientPoolOptions}; -use crate::memdx::packet::ResponsePacket; +use crate::memdx::dispatcher::OrphanResponseHandler; pub(crate) type KvClientManagerClientType = <::Pool as KvClientPool>::Client; @@ -46,11 +45,11 @@ pub(crate) struct KvClientManagerConfig { pub clients: HashMap, } -#[derive(Debug, Clone)] +#[derive(Clone)] pub(crate) struct KvClientManagerOptions { pub connect_timeout: Duration, pub connect_throttle_period: Duration, - pub orphan_handler: Arc>, + pub orphan_handler: OrphanResponseHandler, } #[derive(Debug)] diff --git a/sdk/couchbase-core/src/kvclientpool.rs b/sdk/couchbase-core/src/kvclientpool.rs index f5eed630..2a859bb3 100644 --- a/sdk/couchbase-core/src/kvclientpool.rs +++ b/sdk/couchbase-core/src/kvclientpool.rs @@ -1,19 +1,21 @@ use std::future::Future; use std::ops::{Deref, Sub}; use std::sync::Arc; -use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::time::Duration; -use tokio::sync::{mpsc, Mutex}; -use tokio::sync::mpsc::UnboundedSender; +use arc_swap::ArcSwap; +use tokio::sync::{Mutex, Notify}; use tokio::time::{Instant, sleep}; use crate::error::{Error, ErrorKind}; 
use crate::error::Result; -use crate::kvclient::{KvClient, KvClientConfig, KvClientOptions}; +use crate::kvclient::{KvClient, KvClientConfig, KvClientOptions, OnKvClientCloseHandler}; use crate::kvclient_ops::KvClientOps; -use crate::memdx::dispatcher::Dispatcher; -use crate::memdx::packet::ResponsePacket; +use crate::memdx::dispatcher::{Dispatcher, OrphanResponseHandler}; + +// TODO: This needs some work, some more thought should go into the locking strategy as it's possible +// there are still races in this. Additionally it's extremely easy to write in deadlocks. pub(crate) trait KvClientPool: Sized + Send + Sync { type Client: KvClient + KvClientOps + Send + Sync; @@ -37,75 +39,112 @@ pub(crate) struct KvClientPoolConfig { pub(crate) struct KvClientPoolOptions { pub connect_timeout: Duration, pub connect_throttle_period: Duration, - pub orphan_handler: Arc>, + pub orphan_handler: OrphanResponseHandler, } -struct NaiveKvClientPoolInner -where - K: KvClient, -{ +#[derive(Debug, Clone)] +struct ConnectionError { + pub connect_error: Error, + pub connect_error_time: Instant, +} + +struct KvClientPoolClientSpawner { connect_timeout: Duration, connect_throttle_period: Duration, - config: KvClientPoolConfig, + config: Arc>, - clients: Vec>, + connection_error: Mutex>, - client_idx: usize, + orphan_handler: OrphanResponseHandler, + on_client_close: OnKvClientCloseHandler, +} - connect_error: Option, - connect_error_time: Option, +struct KvClientPoolClientHandler { + num_connections: AtomicUsize, + clients: Arc>>>, + fast_map: ArcSwap>>, - orphan_handler: Arc>, + spawner: Mutex, + client_idx: AtomicUsize, - on_client_close_tx: UnboundedSender, + new_client_watcher_notif: Notify, closed: AtomicBool, } pub(crate) struct NaiveKvClientPool { - inner: Arc>>, + clients: Arc>, } -impl NaiveKvClientPoolInner +impl KvClientPoolClientHandler where - K: KvClient + PartialEq + Sync + Send + 'static, + K: KvClient + KvClientOps + PartialEq + Sync + Send + 'static, { - pub async fn new( - config: KvClientPoolConfig, - opts: KvClientPoolOptions, - on_client_close_tx: UnboundedSender, - ) -> Self { - // TODO: is unbounded the right option? - let mut inner = NaiveKvClientPoolInner:: { - connect_timeout: opts.connect_timeout, - connect_throttle_period: opts.connect_throttle_period, - config, - closed: AtomicBool::new(false), - on_client_close_tx, - orphan_handler: opts.orphan_handler, + pub async fn get_client(&self) -> Result> { + let fm = self.fast_map.load(); - clients: vec![], - connect_error: None, - connect_error_time: None, - client_idx: 0, - }; + if !fm.is_empty() { + let idx = self.client_idx.fetch_add(1, Ordering::SeqCst); + // TODO: is this unwrap ok? It should be... 
+ let client = fm.get(idx % fm.len()).unwrap(); + return Ok(client.clone()); + } + + self.get_client_slow().await + } + + pub async fn close(&self) -> Result<()> { + if self.closed.swap(true, Ordering::SeqCst) { + return Err(ErrorKind::Shutdown.into()); + } - inner.check_connections().await; + let clients = self.clients.lock().await; + for mut client in clients.iter() { + // TODO: probably log + client.close().await.unwrap_or_default(); + } - inner + Ok(()) } - async fn check_connections(&mut self) { - let num_wanted_clients = self.config.num_connections; - let num_active_clients = self.clients.len(); + pub async fn reconfigure(&self, config: KvClientPoolConfig) -> Result<()> { + let mut old_clients = self.clients.lock().await; + let mut new_clients = vec![]; + for client in old_clients.iter() { + if let Err(e) = client.reconfigure(config.client_config.clone()).await { + // TODO: log here. + dbg!(e); + client.close().await.unwrap_or_default(); + continue; + }; + + new_clients.push(client.clone()); + } + self.spawner + .lock() + .await + .reconfigure(config.client_config) + .await; + + drop(old_clients); + self.check_connections().await; + + Ok(()) + } + + async fn check_connections(&self) { + let num_wanted_clients = self.num_connections.load(Ordering::SeqCst); + + let mut clients = self.clients.lock().await; + let num_active_clients = clients.len(); if num_active_clients > num_wanted_clients { let mut num_excess_clients = num_active_clients - num_wanted_clients; let mut num_closed_clients = 0; while num_excess_clients > 0 { - let client_to_close = self.clients.remove(0); + let client_to_close = clients.remove(0); self.shutdown_client(client_to_close).await; num_excess_clients -= 1; @@ -116,146 +155,141 @@ where if num_wanted_clients > num_active_clients { let mut num_needed_clients = num_wanted_clients - num_active_clients; while num_needed_clients > 0 { - match self.start_new_client().await { - Ok(client) => { - self.connect_error = None; - self.connect_error_time = None; - self.clients.push(Arc::new(client)); - num_needed_clients -= 1; - } - Err(e) => { - self.connect_error_time = Some(Instant::now()); - self.connect_error = Some(e); + if let Some(client) = self.spawner.lock().await.start_new_client::().await { + if self.closed.load(Ordering::SeqCst) { + client.close().await.unwrap_or_default(); } - } - } - } - } - - async fn start_new_client(&mut self) -> Result { - loop { - if let Some(error_time) = self.connect_error_time { - let connect_wait_period = - self.connect_throttle_period - Instant::now().sub(error_time); - if !connect_wait_period.is_zero() { - sleep(connect_wait_period).await; - continue; + clients.push(Arc::new(client)); + num_needed_clients -= 1; } } - break; } - let mut client_result = K::new( - self.config.client_config.clone(), - KvClientOptions { - orphan_handler: self.orphan_handler.clone(), - on_close_tx: Some(self.on_client_close_tx.clone()), - }, - ) - .await; - - if self.closed.load(Ordering::SeqCst) { - if let Ok(mut client) = client_result { - // TODO: Log something? 
- client.close().await.unwrap_or_default(); - } + drop(clients); - return Err(ErrorKind::Shutdown.into()); - } - - client_result + self.rebuild_fast_map().await; } - async fn get_client_slow(&mut self) -> Result> { + async fn get_client_slow(&self) -> Result> { if self.closed.load(Ordering::SeqCst) { return Err(ErrorKind::Shutdown.into()); } - if !self.clients.is_empty() { - let idx = self.client_idx; - self.client_idx += 1; + let clients = self.clients.lock().await; + if !clients.is_empty() { + let idx = self.client_idx.fetch_add(1, Ordering::SeqCst); // TODO: is this unwrap ok? It should be... - let client = self.clients.get(idx % self.clients.len()).unwrap(); + let client = clients.get(idx % clients.len()).unwrap(); return Ok(client.clone()); } - if let Some(e) = &self.connect_error { - return Err(e.clone()); + let spawner = self.spawner.lock().await; + if let Some(e) = spawner.error().await { + return Err(e.connect_error); } - self.check_connections().await; - Box::pin(self.get_client_slow()).await - } + drop(clients); - pub async fn get_client(&mut self) -> Result> { - if !self.clients.is_empty() { - let idx = self.client_idx; - self.client_idx += 1; - // TODO: is this unwrap ok? It should be... - let client = self.clients.get(idx % self.clients.len()).unwrap(); - return Ok(client.clone()); - } - - self.get_client_slow().await - } - - pub async fn shutdown_client(&mut self, client: Arc) { - let idx = self.clients.iter().position(|x| *x == client); - if let Some(idx) = idx { - self.clients.remove(idx); - } - - // TODO: Should log - client.close().await.unwrap_or_default(); + self.new_client_watcher_notif.notified(); + Box::pin(self.get_client_slow()).await } - pub async fn handle_client_close(&mut self, client_id: String) { + pub async fn handle_client_close(&self, client_id: String) { // TODO: not sure the ordering of close leading to here is great. if self.closed.load(Ordering::SeqCst) { return; } - let idx = self.clients.iter().position(|x| x.id() == client_id); + let mut clients = self.clients.lock().await; + let idx = clients.iter().position(|x| x.id() == client_id); if let Some(idx) = idx { - self.clients.remove(idx); + clients.remove(idx); } + drop(clients); self.check_connections().await; } - pub async fn close(&mut self) -> Result<()> { - if self.closed.swap(true, Ordering::SeqCst) { - return Err(ErrorKind::Shutdown.into()); - } + async fn rebuild_fast_map(&self) { + let clients = self.clients.lock().await; + let mut new_map = Vec::new(); + new_map.clone_from(clients.deref()); + self.fast_map.store(Arc::from(new_map)); - for mut client in &self.clients { - // TODO: probably log - client.close().await.unwrap_or_default(); + self.new_client_watcher_notif.notify_waiters(); + } + + pub async fn shutdown_client(&self, client: Arc) { + let mut clients = self.clients.lock().await; + let idx = clients.iter().position(|x| *x == client); + if let Some(idx) = idx { + clients.remove(idx); } - Ok(()) + drop(clients); + self.rebuild_fast_map().await; + + // TODO: Should log + client.close().await.unwrap_or_default(); } +} - pub async fn reconfigure(&mut self, config: KvClientPoolConfig) -> Result<()> { - let mut old_clients = self.clients.clone(); - let mut new_clients = vec![]; - for client in old_clients { - if let Err(e) = client.reconfigure(config.client_config.clone()).await { - // TODO: log here. 
- dbg!(e); - client.close().await.unwrap_or_default(); - continue; - }; +impl KvClientPoolClientSpawner { + async fn reconfigure(&self, config: KvClientConfig) { + let mut guard = self.config.lock().await; + *guard = config; + } - new_clients.push(client.clone()); - } - self.clients = new_clients; - self.config = config; + async fn error(&self) -> Option { + let err = self.connection_error.lock().await; + err.clone() + } - self.check_connections().await; + async fn start_new_client(&self) -> Option + where + K: KvClient + KvClientOps + PartialEq + Sync + Send + 'static, + { + loop { + let err = self.connection_error.lock().await; + if let Some(error) = err.deref() { + let connect_wait_period = + self.connect_throttle_period - Instant::now().sub(error.connect_error_time); - Ok(()) + if !connect_wait_period.is_zero() { + drop(err); + sleep(connect_wait_period).await; + continue; + } + } + break; + } + + let config = self.config.lock().await; + match K::new( + config.clone(), + KvClientOptions { + orphan_handler: self.orphan_handler.clone(), + on_close: self.on_client_close.clone(), + }, + ) + .await + { + Ok(r) => { + let mut e = self.connection_error.lock().await; + *e = None; + Some(r) + } + Err(e) => { + let mut err = self.connection_error.lock().await; + *err = Some(ConnectionError { + connect_error: e, + connect_error_time: Instant::now(), + }); + + None + } + } } } @@ -266,47 +300,52 @@ where type Client = K; async fn new(config: KvClientPoolConfig, opts: KvClientPoolOptions) -> Self { - // TODO: is unbounded the right option? - let (on_client_close_tx, mut on_client_close_rx) = mpsc::unbounded_channel(); - - let clients = Arc::new(Mutex::new( - NaiveKvClientPoolInner::::new(config, opts, on_client_close_tx).await, - )); - - let reader_clients = clients.clone(); - tokio::spawn(async move { - loop { - if let Some(id) = on_client_close_rx.recv().await { - reader_clients.lock().await.handle_client_close(id).await; - } else { - return; - } - } + let mut clients = Arc::new(KvClientPoolClientHandler { + num_connections: AtomicUsize::new(config.num_connections), + clients: Arc::new(Default::default()), + client_idx: AtomicUsize::new(0), + fast_map: ArcSwap::from_pointee(vec![]), + + spawner: Mutex::new(KvClientPoolClientSpawner { + connect_timeout: opts.connect_timeout, + connect_throttle_period: opts.connect_throttle_period, + orphan_handler: opts.orphan_handler.clone(), + connection_error: Mutex::new(None), + on_client_close: Arc::new(|id| Box::pin(async {})), + config: Arc::new(Mutex::new(config.client_config)), + }), + + new_client_watcher_notif: Notify::new(), + closed: AtomicBool::new(false), }); - NaiveKvClientPool { inner: clients } - } + let clients_clone = clients.clone(); + let mut spawner = clients.spawner.lock().await; + spawner.on_client_close = Arc::new(move |id| { + let clients_clone = clients_clone.clone(); + Box::pin(async move { clients_clone.handle_client_close(id).await }) + }); + drop(spawner); - async fn get_client(&self) -> Result> { - let mut clients = self.inner.lock().await; + clients.check_connections().await; - clients.get_client().await + NaiveKvClientPool { clients } } - async fn shutdown_client(&self, client: Arc) { - let mut clients = self.inner.lock().await; + async fn get_client(&self) -> Result> { + self.clients.get_client().await + } - clients.shutdown_client(client).await; + async fn shutdown_client(&self, client: Arc) { + self.clients.shutdown_client(client).await; } async fn close(&self) -> Result<()> { - let mut inner = self.inner.lock().await; - 
inner.close().await + self.clients.close().await } async fn reconfigure(&self, config: KvClientPoolConfig) -> Result<()> { - let mut inner = self.inner.lock().await; - inner.reconfigure(config).await + self.clients.reconfigure(config).await } } @@ -316,7 +355,6 @@ mod tests { use std::sync::Arc; use std::time::Duration; - use tokio::sync::mpsc::unbounded_channel; use tokio::time::Instant; use crate::authenticator::PasswordAuthenticator; @@ -326,30 +364,14 @@ mod tests { KvClientPool, KvClientPoolConfig, KvClientPoolOptions, NaiveKvClientPool, }; use crate::memdx::client::Client; - use crate::memdx::packet::ResponsePacket; use crate::memdx::request::{GetRequest, SetRequest}; - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + #[tokio::test] async fn roundtrip_a_request() { let _ = env_logger::try_init(); let instant = Instant::now().add(Duration::new(7, 0)); - let (orphan_tx, mut orphan_rx) = unbounded_channel::(); - - tokio::spawn(async move { - loop { - match orphan_rx.recv().await { - Some(resp) => { - dbg!("unexpected orphan", resp); - } - None => { - return; - } - } - } - }); - let client_config = KvClientConfig { address: "192.168.107.128:11210" .parse() @@ -380,7 +402,9 @@ mod tests { KvClientPoolOptions { connect_timeout: Default::default(), connect_throttle_period: Default::default(), - orphan_handler: Arc::new(orphan_tx), + orphan_handler: Arc::new(|packet| { + dbg!("unexpected orphan", packet); + }), }, ) .await; @@ -424,27 +448,12 @@ mod tests { pool.close().await.unwrap(); } - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + #[tokio::test] async fn reconfigure() { let _ = env_logger::try_init(); let instant = Instant::now().add(Duration::new(7, 0)); - let (orphan_tx, mut orphan_rx) = unbounded_channel::(); - - tokio::spawn(async move { - loop { - match orphan_rx.recv().await { - Some(resp) => { - dbg!("unexpected orphan", resp); - } - None => { - return; - } - } - } - }); - let client_config = KvClientConfig { address: "192.168.107.128:11210" .parse() @@ -475,7 +484,9 @@ mod tests { KvClientPoolOptions { connect_timeout: Default::default(), connect_throttle_period: Default::default(), - orphan_handler: Arc::new(orphan_tx), + orphan_handler: Arc::new(|packet| { + dbg!("unexpected orphan", packet); + }), }, ) .await; diff --git a/sdk/couchbase-core/src/lib.rs b/sdk/couchbase-core/src/lib.rs index 38706c6b..2ae0042d 100644 --- a/sdk/couchbase-core/src/lib.rs +++ b/sdk/couchbase-core/src/lib.rs @@ -1,4 +1,7 @@ #![feature(async_closure)] +// #![feature(unboxed_closures)] +#![feature(async_fn_traits)] +#![feature(unboxed_closures)] pub mod authenticator; pub mod cbconfig; diff --git a/sdk/couchbase-core/src/memdx/client.rs b/sdk/couchbase-core/src/memdx/client.rs index e3a45178..02997e44 100644 --- a/sdk/couchbase-core/src/memdx/client.rs +++ b/sdk/couchbase-core/src/memdx/client.rs @@ -1,8 +1,9 @@ +use std::{env, mem}; use std::cell::RefCell; use std::collections::HashMap; -use std::env; use std::io::empty; use std::net::SocketAddr; +use std::pin::pin; use std::sync::Arc; use std::sync::atomic::{AtomicBool, AtomicU32, Ordering}; use std::thread::spawn; @@ -25,7 +26,9 @@ use uuid::Uuid; use crate::memdx::client_response::ClientResponse; use crate::memdx::codec::KeyValueCodec; use crate::memdx::connection::{Connection, ConnectionType}; -use crate::memdx::dispatcher::{Dispatcher, DispatcherOptions}; +use crate::memdx::dispatcher::{ + Dispatcher, DispatcherOptions, OnConnectionCloseHandler, OrphanResponseHandler, +}; use crate::memdx::error; use 
crate::memdx::error::{CancellationErrorKind, Error, ErrorKind}; use crate::memdx::packet::{RequestPacket, ResponsePacket}; @@ -34,13 +37,12 @@ use crate::memdx::pendingop::ClientPendingOp; pub(crate) type ResponseSender = Sender>; pub(crate) type OpaqueMap = HashMap>; -#[derive(Debug)] struct ReadLoopOptions { pub client_id: String, pub local_addr: Option, pub peer_addr: Option, - pub orphan_handler: Arc>, - pub on_connection_close_tx: Option>>, + pub orphan_handler: OrphanResponseHandler, + pub on_connection_close_tx: OnConnectionCloseHandler, pub on_client_close_rx: Receiver<()>, } @@ -99,15 +101,13 @@ impl Client { async fn on_read_loop_close( stream: FramedRead, KeyValueCodec>, opaque_map: MutexGuard<'_, OpaqueMap>, - on_connection_close_tx: Option>>, + on_connection_close: OnConnectionCloseHandler, ) { drop(stream); Self::drain_opaque_map(opaque_map).await; - if let Some(handler) = on_connection_close_tx { - handler.send(Ok(())).unwrap(); - } + on_connection_close().await; } async fn read_loop( @@ -178,15 +178,7 @@ impl Client { } else { drop(map); let opaque = packet.opaque; - match opts.orphan_handler.send(packet) { - Ok(_) => {} - Err(_) => { - warn!( - "{} failed to send packet to orphan handler {}", - opts.client_id, opaque - ); - } - }; + (opts.orphan_handler)(packet); } drop(requests); } @@ -334,8 +326,6 @@ mod tests { use std::sync::Arc; use std::time::Duration; - use tokio::sync::mpsc::unbounded_channel; - use tokio::sync::oneshot; use tokio::time::Instant; use crate::memdx::auth_mechanism::AuthMechanism::{ScramSha1, ScramSha256, ScramSha512}; @@ -347,7 +337,6 @@ mod tests { use crate::memdx::op_bootstrap::{BootstrapOptions, OpBootstrap}; use crate::memdx::ops_core::OpsCore; use crate::memdx::ops_crud::OpsCrud; - use crate::memdx::packet::ResponsePacket; use crate::memdx::request::{ GetClusterConfigRequest, GetErrorMapRequest, GetRequest, HelloRequest, SelectBucketRequest, SetRequest, @@ -355,7 +344,7 @@ mod tests { use crate::memdx::response::{GetResponse, SetResponse}; use crate::memdx::sync_helpers::sync_unary_call; - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + #[tokio::test] async fn roundtrip_a_request() { let _ = env_logger::try_init(); @@ -371,36 +360,17 @@ mod tests { .await .expect("Could not connect"); - let (orphan_tx, mut orphan_rx) = unbounded_channel::(); - let (close_tx, mut close_rx) = oneshot::channel::>(); - - tokio::spawn(async move { - loop { - match orphan_rx.recv().await { - Some(resp) => { - dbg!("unexpected orphan", resp); - } - None => { - return; - } - } - } - }); - - tokio::spawn(async move { - loop { - if let Ok(resp) = close_rx.try_recv() { - dbg!("closed"); - return; - } - } - }); - let mut client = Client::new( conn, DispatcherOptions { - on_connection_close_handler: Some(close_tx), - orphan_handler: Arc::new(orphan_tx), + on_connection_close_handler: Arc::new(|| { + Box::pin(async { + dbg!("closed"); + }) + }), + orphan_handler: Arc::new(|packet| { + dbg!("unexpected orphan", packet); + }), }, ); diff --git a/sdk/couchbase-core/src/memdx/dispatcher.rs b/sdk/couchbase-core/src/memdx/dispatcher.rs index 98d0e9e7..0a36fea4 100644 --- a/sdk/couchbase-core/src/memdx/dispatcher.rs +++ b/sdk/couchbase-core/src/memdx/dispatcher.rs @@ -1,18 +1,19 @@ use std::sync::Arc; use async_trait::async_trait; -use tokio::sync::mpsc::UnboundedSender; -use tokio::sync::oneshot; +use futures::future::BoxFuture; use crate::memdx::connection::Connection; use crate::memdx::error::Result; use crate::memdx::packet::{RequestPacket, ResponsePacket}; use 
crate::memdx::pendingop::ClientPendingOp;
 
-#[derive(Debug)]
+pub type OrphanResponseHandler = Arc<dyn Fn(ResponsePacket) + Send + Sync>;
+pub type OnConnectionCloseHandler = Arc<dyn Fn() -> BoxFuture<'static, ()> + Send + Sync>;
+
 pub struct DispatcherOptions {
-    pub orphan_handler: Arc<UnboundedSender<ResponsePacket>>,
-    pub on_connection_close_handler: Option<oneshot::Sender<Result<()>>>,
+    pub orphan_handler: OrphanResponseHandler,
+    pub on_connection_close_handler: OnConnectionCloseHandler,
 }
 
 #[async_trait]
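
A note on wiring the new callback types: both handlers are plain `Arc`'d closures (the
orphan handler is synchronous, the close handler returns a boxed future), so call sites
build them inline instead of spawning a reader task around a channel. A minimal sketch,
assuming the type aliases from memdx/dispatcher.rs above; the function name and the
handler bodies (which just mirror the dbg! logging used in the tests) are illustrative
only, not part of the patch:

    use std::sync::Arc;

    use crate::memdx::dispatcher::{
        DispatcherOptions, OnConnectionCloseHandler, OrphanResponseHandler,
    };
    use crate::memdx::packet::ResponsePacket;

    fn example_dispatcher_options() -> DispatcherOptions {
        // Orphaned responses are delivered straight to this callback from the read
        // loop; no channel and no dedicated task are needed any more.
        let orphan_handler: OrphanResponseHandler = Arc::new(|packet: ResponsePacket| {
            dbg!("unexpected orphan", packet);
        });

        // Connection close is an async callback: the read loop awaits the returned
        // boxed future as part of shutting down.
        let on_connection_close: OnConnectionCloseHandler = Arc::new(|| {
            Box::pin(async {
                dbg!("connection closed");
            })
        });

        DispatcherOptions {
            orphan_handler,
            on_connection_close_handler: on_connection_close,
        }
    }

Because the orphan callback now runs on the read loop itself, it should stay cheap and
non-blocking; anything heavier can be handed off to a task from inside the callback.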
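The kvclientpool rewrite follows the same idea on the pool side: the client list lives
behind a Mutex as the source of truth, and a lock-free snapshot (the ArcSwap fast_map)
is republished after every mutation so get_client can round-robin without taking the
lock. A stripped-down sketch of that pattern, using hypothetical names (RoundRobin,
items, rebuild) rather than the pool's own:

    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use arc_swap::ArcSwap;
    use tokio::sync::Mutex;

    struct RoundRobin<T> {
        // Source of truth; only locked when the set of items changes.
        items: Mutex<Vec<Arc<T>>>,
        // Lock-free snapshot used on the hot path.
        fast_map: ArcSwap<Vec<Arc<T>>>,
        idx: AtomicUsize,
    }

    impl<T> RoundRobin<T> {
        fn get(&self) -> Option<Arc<T>> {
            let snapshot = self.fast_map.load();
            if snapshot.is_empty() {
                return None;
            }
            let idx = self.idx.fetch_add(1, Ordering::SeqCst);
            Some(snapshot[idx % snapshot.len()].clone())
        }

        // Called after any change to `items`, mirroring rebuild_fast_map in the pool.
        async fn rebuild(&self) {
            let items = self.items.lock().await;
            self.fast_map.store(Arc::new((*items).clone()));
        }
    }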