From 81843c83036a82dbfa52225d3d0a14ef827a8192 Mon Sep 17 00:00:00 2001 From: Charles Dixon Date: Wed, 17 Jul 2024 08:24:40 +0100 Subject: [PATCH 1/2] Add initial kvclientmanager --- sdk/couchbase-core/src/error.rs | 31 +- sdk/couchbase-core/src/kvclient.rs | 100 +++-- sdk/couchbase-core/src/kvclient_ops.rs | 23 +- sdk/couchbase-core/src/kvclientmanager.rs | 399 ++++++++++++++++++ sdk/couchbase-core/src/kvclientpool.rs | 191 ++++++++- sdk/couchbase-core/src/lib.rs | 1 + .../src/memdx/auth_mechanism.rs | 9 +- sdk/couchbase-core/src/memdx/client.rs | 28 +- sdk/couchbase-core/src/memdx/codec.rs | 6 +- sdk/couchbase-core/src/memdx/connection.rs | 18 +- sdk/couchbase-core/src/memdx/dispatcher.rs | 8 +- sdk/couchbase-core/src/memdx/error.rs | 10 +- sdk/couchbase-core/src/memdx/magic.rs | 6 +- .../src/memdx/op_auth_saslauto.rs | 18 +- .../src/memdx/op_auth_saslbyname.rs | 6 +- .../src/memdx/op_auth_saslplain.rs | 14 +- .../src/memdx/op_auth_saslscram.rs | 26 +- sdk/couchbase-core/src/memdx/op_bootstrap.rs | 22 +- sdk/couchbase-core/src/memdx/opcode.rs | 6 +- sdk/couchbase-core/src/memdx/ops_core.rs | 40 +- sdk/couchbase-core/src/memdx/ops_crud.rs | 65 +-- sdk/couchbase-core/src/memdx/pendingop.rs | 21 +- sdk/couchbase-core/src/memdx/response.rs | 56 +-- sdk/couchbase-core/src/memdx/sync_helpers.rs | 6 +- 24 files changed, 872 insertions(+), 238 deletions(-) create mode 100644 sdk/couchbase-core/src/kvclientmanager.rs diff --git a/sdk/couchbase-core/src/error.rs b/sdk/couchbase-core/src/error.rs index a1dac554..c6c85323 100644 --- a/sdk/couchbase-core/src/error.rs +++ b/sdk/couchbase-core/src/error.rs @@ -1,22 +1,23 @@ -use std::fmt::{Display, Formatter}; +use std::fmt::Display; -use crate::memdx::error::Error; +use crate::error::CoreError::{Dispatch, PlaceholderMemdxWrapper}; +use crate::memdx::error::MemdxError; -#[derive(Debug)] -pub struct CoreError { - pub msg: String, +#[derive(thiserror::Error, Debug, Eq, PartialEq)] +pub enum CoreError { + #[error("Dispatch error {0}")] + Dispatch(MemdxError), + #[error("Placeholder error {0}")] + Placeholder(String), + #[error("Placeholder memdx wrapper error {0}")] + PlaceholderMemdxWrapper(MemdxError), } -impl From for CoreError { - fn from(value: Error) -> Self { - Self { - msg: value.to_string(), +impl From for CoreError { + fn from(value: MemdxError) -> Self { + match value { + MemdxError::Dispatch(_) => Dispatch(value), + _ => PlaceholderMemdxWrapper(value), } } } - -impl Display for CoreError { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.msg) - } -} diff --git a/sdk/couchbase-core/src/kvclient.rs b/sdk/couchbase-core/src/kvclient.rs index aaa124be..8cb82fc1 100644 --- a/sdk/couchbase-core/src/kvclient.rs +++ b/sdk/couchbase-core/src/kvclient.rs @@ -1,12 +1,12 @@ use std::future::Future; use std::net::SocketAddr; use std::ops::{Add, Deref}; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use std::sync::atomic::{AtomicBool, Ordering}; use std::time::Duration; -use tokio::sync::mpsc::{Sender, UnboundedSender}; -use tokio::sync::oneshot; +use tokio::sync::{Mutex, oneshot}; +use tokio::sync::mpsc::UnboundedSender; use tokio::time::Instant; use tokio_rustls::rustls::RootCertStore; use uuid::Uuid; @@ -60,10 +60,10 @@ pub(crate) struct KvClientOptions { pub(crate) trait KvClient: Sized + PartialEq + Send + Sync { fn new( - config: Arc, + config: KvClientConfig, opts: KvClientOptions, ) -> impl Future> + Send; - fn reconfigure(&self, config: Arc, on_complete: Sender>); + fn reconfigure(&self, config: KvClientConfig) -> impl Future> + Send; fn has_feature(&self, feature: HelloFeature) -> bool; fn load_factor(&self) -> f64; fn remote_addr(&self) -> SocketAddr; @@ -79,7 +79,7 @@ pub(crate) struct StdKvClient { pending_operations: u64, cli: D, - current_config: Arc, + current_config: Mutex, supported_features: Vec, @@ -87,7 +87,7 @@ pub(crate) struct StdKvClient { // so that we can use it in our errors. Note that it is set before // we send the operation to select the bucket, since things happen // asynchronously and we do not support changing selected buckets. - selected_bucket: Arc>>, + selected_bucket: Mutex>, closed: Arc, @@ -101,17 +101,13 @@ where pub fn client(&self) -> &D { &self.cli } - - pub fn client_mut(&mut self) -> &mut D { - &mut self.cli - } } impl KvClient for StdKvClient where D: Dispatcher, { - async fn new(config: Arc, opts: KvClientOptions) -> CoreResult> { + async fn new(config: KvClientConfig, opts: KvClientOptions) -> CoreResult> { let requested_features = if config.disable_default_features { vec![] } else { @@ -174,13 +170,13 @@ where if should_bootstrap && config.disable_bootstrap { // TODO: error model needs thought. - return Err(CoreError { - msg: "oopsies".to_string(), - }); + return Err(CoreError::Placeholder( + "Bootstrap was disabled but options requiring bootstrap were specified".to_string(), + )); } let (connection_close_tx, mut connection_close_rx) = - oneshot::channel::>(); + oneshot::channel::>(); let memdx_client_opts = DispatcherOptions { on_connection_close_handler: Some(connection_close_tx), orphan_handler: opts.orphan_handler, @@ -213,9 +209,9 @@ where local_addr, pending_operations: 0, cli, - current_config: config, + current_config: Mutex::new(config), supported_features: vec![], - selected_bucket: Arc::new(Mutex::new(None)), + selected_bucket: Mutex::new(None), closed, id: id.clone(), }; @@ -234,7 +230,7 @@ where if should_bootstrap { if let Some(b) = &bootstrap_select_bucket { - let mut guard = kv_cli.selected_bucket.lock().unwrap(); + let mut guard = kv_cli.selected_bucket.lock().await; *guard = Some(b.bucket_name.clone()); }; @@ -264,8 +260,64 @@ where Ok(kv_cli) } - fn reconfigure(&self, config: Arc, on_complete: Sender>) { - todo!() + async fn reconfigure(&self, config: KvClientConfig) -> CoreResult<()> { + let mut current_config = self.current_config.lock().await; + + // TODO: compare root certs or something somehow. + if !(current_config.address == config.address + && current_config.accept_all_certs == config.accept_all_certs + && current_config.client_name == config.client_name + && current_config.disable_default_features == config.disable_default_features + && current_config.disable_error_map == config.disable_error_map + && current_config.disable_bootstrap == config.disable_bootstrap) + { + return Err(CoreError::Placeholder( + "Cannot reconfigure due to conflicting options".to_string(), + )); + } + + let selected_bucket_name = if current_config.selected_bucket != config.selected_bucket { + if current_config.selected_bucket.is_some() { + return Err(CoreError::Placeholder( + "Cannot reconfigure from one selected bucket to another".to_string(), + )); + } + + current_config + .selected_bucket + .clone_from(&config.selected_bucket); + config.selected_bucket.clone() + } else { + None + }; + + if *current_config.deref() != config { + return Err(CoreError::Placeholder( + "Client config after reconfigure did not match new configuration".to_string(), + )); + } + + if let Some(bucket_name) = selected_bucket_name { + let mut current_bucket = self.selected_bucket.lock().await; + *current_bucket = Some(bucket_name.clone()); + drop(current_bucket); + + match self + .select_bucket(SelectBucketRequest { bucket_name }) + .await + { + Ok(_) => {} + Err(e) => { + let mut current_bucket = self.selected_bucket.lock().await; + *current_bucket = None; + drop(current_bucket); + + current_config.selected_bucket = None; + } + } + } + + Ok(()) } fn has_feature(&self, feature: HelloFeature) -> bool { @@ -286,9 +338,7 @@ where async fn close(&self) -> CoreResult<()> { if self.closed.swap(true, Ordering::Relaxed) { - return Err(CoreError { - msg: "closed".to_string(), - }); + return Err(CoreError::Placeholder("Client closed".to_string())); } Ok(self.cli.close().await?) @@ -362,7 +412,7 @@ mod tests { }; let mut client = StdKvClient::::new( - Arc::new(client_config), + client_config, KvClientOptions { orphan_handler: Arc::new(orphan_tx), on_close_tx: None, diff --git a/sdk/couchbase-core/src/kvclient_ops.rs b/sdk/couchbase-core/src/kvclient_ops.rs index 64a6f008..e1fb8fe7 100644 --- a/sdk/couchbase-core/src/kvclient_ops.rs +++ b/sdk/couchbase-core/src/kvclient_ops.rs @@ -2,24 +2,37 @@ use crate::error::CoreError; use crate::kvclient::{KvClient, StdKvClient}; use crate::memdx::dispatcher::Dispatcher; use crate::memdx::hello_feature::HelloFeature; -use crate::memdx::op_bootstrap::{BootstrapOptions, OpBootstrap}; +use crate::memdx::op_bootstrap::{BootstrapOptions, OpBootstrap, OpBootstrapEncoder}; use crate::memdx::ops_core::OpsCore; use crate::memdx::ops_crud::OpsCrud; use crate::memdx::pendingop::PendingOp; -use crate::memdx::request::{GetRequest, SetRequest}; -use crate::memdx::response::{BootstrapResult, GetResponse, SetResponse}; +use crate::memdx::request::{GetRequest, SelectBucketRequest, SetRequest}; +use crate::memdx::response::{BootstrapResult, GetResponse, SelectBucketResponse, SetResponse}; use crate::result::CoreResult; impl StdKvClient where D: Dispatcher, { - pub async fn bootstrap(&mut self, opts: BootstrapOptions) -> CoreResult { - OpBootstrap::bootstrap(OpsCore {}, self.client_mut(), opts) + pub async fn bootstrap(&self, opts: BootstrapOptions) -> CoreResult { + OpBootstrap::bootstrap(OpsCore {}, self.client(), opts) .await .map_err(CoreError::from) } + pub async fn select_bucket( + &self, + req: SelectBucketRequest, + ) -> CoreResult { + let mut op = OpsCore {} + .select_bucket(self.client(), req) + .await + .map_err(CoreError::from)?; + + let res = op.recv().await?; + Ok(res) + } + pub async fn get(&self, req: GetRequest) -> CoreResult { let mut op = self.ops_crud().get(self.client(), req).await?; diff --git a/sdk/couchbase-core/src/kvclientmanager.rs b/sdk/couchbase-core/src/kvclientmanager.rs new file mode 100644 index 00000000..e8857791 --- /dev/null +++ b/sdk/couchbase-core/src/kvclientmanager.rs @@ -0,0 +1,399 @@ +use std::collections::HashMap; +use std::future::Future; +use std::marker::PhantomData; +use std::sync::Arc; +use std::time::Duration; + +use tokio::sync::mpsc::UnboundedSender; +use tokio::sync::Mutex; + +use crate::error::CoreError; +use crate::kvclient::{KvClient, KvClientConfig}; +use crate::kvclientpool::{KvClientPool, KvClientPoolConfig, KvClientPoolOptions}; +use crate::memdx::packet::ResponsePacket; +use crate::result::CoreResult; + +pub(crate) trait KvClientManager: Sized + Send + Sync { + fn new( + config: KvClientManagerConfig, + opts: KvClientManagerOptions, + ) -> impl Future> + Send; + fn reconfigure( + &self, + config: KvClientManagerConfig, + ) -> impl Future> + Send; + fn get_client(&self, endpoint: String) -> impl Future>> + Send; + fn get_random_client(&self) -> impl Future>> + Send; + fn shutdown_client( + &self, + endpoint: String, + client: Arc, + ) -> impl Future> + Send; + fn close(&self) -> impl Future> + Send; + fn orchestrate_operation( + &self, + endpoint: String, + operation: impl Fn(Arc) -> Fut, + ) -> impl Future> + where + Fut: Future> + Send; +} + +#[derive(Debug)] +pub(crate) struct KvClientManagerConfig { + pub num_pool_connections: usize, + pub clients: HashMap, +} + +#[derive(Debug, Clone)] +pub(crate) struct KvClientManagerOptions { + pub connect_timeout: Duration, + pub connect_throttle_period: Duration, + pub orphan_handler: Arc>, +} + +#[derive(Debug)] +struct KvClientManagerPool +where + P: KvClientPool, +{ + config: KvClientPoolConfig, + pool: Arc

, + _phantom_client_type: PhantomData, +} + +#[derive(Debug, Default)] +struct KvClientManagerState +where + P: KvClientPool, +{ + pub client_pools: HashMap>, +} + +pub(crate) struct StdKvClientManager +where + P: KvClientPool, +{ + state: Mutex>, + opts: KvClientManagerOptions, +} + +impl StdKvClientManager +where + K: KvClient, + P: KvClientPool, +{ + async fn get_pool(&self, endpoint: String) -> CoreResult> { + let state = self.state.lock().await; + + let pool = match state.client_pools.get(&endpoint) { + Some(p) => p, + None => { + return Err(CoreError::Placeholder("Endpoint not known".to_string())); + } + }; + + Ok(pool.pool.clone()) + } + + async fn get_random_pool(&self) -> CoreResult> { + let state = self.state.lock().await; + + // Just pick one at random for now + if let Some((_, pool)) = state.client_pools.iter().next() { + return Ok(pool.pool.clone()); + } + + Err(CoreError::Placeholder("Endpoint not known".to_string())) + } + + async fn create_pool(&self, pool_config: KvClientPoolConfig) -> KvClientManagerPool { + let pool = P::new( + pool_config.clone(), + KvClientPoolOptions { + connect_timeout: self.opts.connect_timeout, + connect_throttle_period: self.opts.connect_throttle_period, + orphan_handler: self.opts.orphan_handler.clone(), + }, + ) + .await; + + KvClientManagerPool { + config: pool_config, + pool: Arc::new(pool), + _phantom_client_type: Default::default(), + } + } +} + +impl KvClientManager for StdKvClientManager +where + K: KvClient, + P: KvClientPool, +{ + async fn new(config: KvClientManagerConfig, opts: KvClientManagerOptions) -> CoreResult { + let manager = Self { + state: Mutex::new(KvClientManagerState { + client_pools: Default::default(), + }), + opts, + }; + + manager.reconfigure(config).await?; + Ok(manager) + } + + async fn reconfigure(&self, config: KvClientManagerConfig) -> CoreResult<()> { + let mut guard = self.state.lock().await; + + let mut old_pools = std::mem::take(&mut guard.client_pools); + + let mut new_state = KvClientManagerState:: { + client_pools: Default::default(), + }; + + for (endpoint, endpoint_config) in config.clients { + let pool_config = KvClientPoolConfig { + num_connections: config.num_pool_connections, + client_config: endpoint_config, + }; + + let old_pool = old_pools.remove(&endpoint); + let new_pool = if let Some(pool) = old_pool { + // TODO: log on error. + if pool.pool.reconfigure(pool_config.clone()).await.is_ok() { + pool + } else { + self.create_pool(pool_config).await + } + } else { + self.create_pool(pool_config).await + }; + + new_state.client_pools.insert(endpoint, new_pool); + } + + for (_, pool) in old_pools { + // TODO: log? + pool.pool.close().await.unwrap_or_default(); + } + + *guard = new_state; + + Ok(()) + } + + async fn get_client(&self, endpoint: String) -> CoreResult> { + let pool = self.get_pool(endpoint).await?; + + pool.get_client().await + } + + async fn get_random_client(&self) -> CoreResult> { + let pool = self.get_random_pool().await?; + + pool.get_client().await + } + + async fn shutdown_client(&self, endpoint: String, client: Arc) -> CoreResult<()> { + let pool = self.get_pool(endpoint).await?; + + pool.shutdown_client(client).await; + + Ok(()) + } + + async fn close(&self) -> CoreResult<()> { + let mut guard = self.state.lock().await; + + let mut old_pools = std::mem::take(&mut guard.client_pools); + + for (_, pool) in old_pools { + // TODO: log error. + pool.pool.close().await.unwrap_or_default(); + } + + Ok(()) + } + + async fn orchestrate_operation( + &self, + endpoint: String, + operation: impl Fn(Arc) -> Fut, + ) -> CoreResult + where + Fut: Future> + Send, + { + loop { + let client = self.get_client(endpoint.clone()).await?; + + let res = operation(client.clone()).await; + match res { + Ok(r) => { + return Ok(r); + } + Err(e) => match e { + CoreError::Dispatch(_) => { + // This was a dispatch error, so we can just try with + // a different client instead... + // TODO: Log something + self.shutdown_client(endpoint.clone(), client) + .await + .unwrap_or_default(); + } + _ => { + return Err(e); + } + }, + } + } + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + use std::ops::Add; + use std::sync::Arc; + use std::time::Duration; + + use tokio::sync::mpsc::unbounded_channel; + use tokio::time::Instant; + + use crate::authenticator::PasswordAuthenticator; + use crate::kvclient::{KvClient, KvClientConfig, StdKvClient}; + use crate::kvclientmanager::{ + KvClientManager, KvClientManagerConfig, KvClientManagerOptions, StdKvClientManager, + }; + use crate::kvclientpool::{KvClientPool, NaiveKvClientPool}; + use crate::memdx::client::Client; + use crate::memdx::packet::ResponsePacket; + use crate::memdx::request::{GetRequest, SetRequest}; + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn roundtrip_a_request() { + let _ = env_logger::try_init(); + + let instant = Instant::now().add(Duration::new(7, 0)); + + let (orphan_tx, mut orphan_rx) = unbounded_channel::(); + + tokio::spawn(async move { + loop { + match orphan_rx.recv().await { + Some(resp) => { + dbg!("unexpected orphan", resp); + } + None => { + return; + } + } + } + }); + + let client_config = KvClientConfig { + address: "192.168.107.128:11210" + .parse() + .expect("Failed to parse address"), + root_certs: None, + accept_all_certs: None, + client_name: "myclient".to_string(), + authenticator: Some(Arc::new(PasswordAuthenticator { + username: "Administrator".to_string(), + password: "password".to_string(), + })), + selected_bucket: Some("default".to_string()), + disable_default_features: false, + disable_error_map: false, + disable_bootstrap: false, + }; + + let mut client_configs = HashMap::new(); + client_configs.insert("192.168.107.128:11210".to_string(), client_config); + + let manger_config = KvClientManagerConfig { + num_pool_connections: 1, + clients: client_configs, + }; + + let manager: StdKvClientManager< + NaiveKvClientPool>, + StdKvClient, + > = StdKvClientManager::new( + manger_config, + KvClientManagerOptions { + connect_timeout: Default::default(), + connect_throttle_period: Default::default(), + orphan_handler: Arc::new(orphan_tx), + }, + ) + .await + .unwrap(); + + let result = manager + .orchestrate_operation( + "192.168.107.128:11210".to_string(), + |client: Arc>| async move { + client + .set(SetRequest { + collection_id: 0, + key: "test".as_bytes().into(), + vbucket_id: 1, + flags: 0, + value: "test".as_bytes().into(), + datatype: 0, + expiry: None, + preserve_expiry: None, + cas: None, + on_behalf_of: None, + durability_level: None, + durability_level_timeout: None, + }) + .await + }, + ) + .await + .unwrap(); + + dbg!(result); + + let client = manager + .get_client("192.168.107.128:11210".to_string()) + .await + .unwrap(); + + let result = client + .set(SetRequest { + collection_id: 0, + key: "test".as_bytes().into(), + vbucket_id: 1, + flags: 0, + value: "test".as_bytes().into(), + datatype: 0, + expiry: None, + preserve_expiry: None, + cas: None, + on_behalf_of: None, + durability_level: None, + durability_level_timeout: None, + }) + .await + .unwrap(); + + dbg!(result); + + let get_result = client + .get(GetRequest { + collection_id: 0, + key: "test".as_bytes().into(), + vbucket_id: 1, + on_behalf_of: None, + }) + .await + .unwrap(); + + dbg!(get_result); + + manager.close().await.unwrap(); + } +} diff --git a/sdk/couchbase-core/src/kvclientpool.rs b/sdk/couchbase-core/src/kvclientpool.rs index c0794996..dfce2b37 100644 --- a/sdk/couchbase-core/src/kvclientpool.rs +++ b/sdk/couchbase-core/src/kvclientpool.rs @@ -1,5 +1,5 @@ use std::future::Future; -use std::ops::Sub; +use std::ops::{Deref, Sub}; use std::sync::Arc; use std::sync::atomic::{AtomicBool, Ordering}; use std::time::Duration; @@ -22,12 +22,16 @@ pub(crate) trait KvClientPool: Sized + Send + Sync { fn get_client(&self) -> impl Future>> + Send; fn shutdown_client(&self, client: Arc) -> impl Future + Send; fn close(&self) -> impl Future> + Send; + fn reconfigure( + &self, + config: KvClientPoolConfig, + ) -> impl Future> + Send; } -#[derive(Debug)] +#[derive(Debug, Clone)] pub(crate) struct KvClientPoolConfig { pub num_connections: usize, - pub client_config: Arc, + pub client_config: KvClientConfig, } pub(crate) struct KvClientPoolOptions { @@ -36,8 +40,6 @@ pub(crate) struct KvClientPoolOptions { pub orphan_handler: Arc>, } -type KvClientsList = Vec>; - struct NaiveKvClientPoolInner where K: KvClient, @@ -47,7 +49,7 @@ where config: KvClientPoolConfig, - clients: KvClientsList, + clients: Vec>, client_idx: usize, @@ -159,9 +161,7 @@ where client.close().await.unwrap_or_default(); } - return Err(CoreError { - msg: "closed".to_string(), - }); + return Err(CoreError::Placeholder("Closed".to_string())); } client_result @@ -169,9 +169,7 @@ where async fn get_client_slow(&mut self) -> CoreResult> { if self.closed.load(Ordering::SeqCst) { - return Err(CoreError { - msg: "closed".to_string(), - }); + return Err(CoreError::Placeholder("Closed".to_string())); } if !self.clients.is_empty() { @@ -183,7 +181,7 @@ where } if let Some(e) = &self.connect_error { - return Err(CoreError { msg: e.to_string() }); + return Err(CoreError::Placeholder(e.to_string())); } self.check_connections().await; @@ -228,9 +226,7 @@ where pub async fn close(&mut self) -> CoreResult<()> { if self.closed.swap(true, Ordering::SeqCst) { - return Err(CoreError { - msg: "closed".to_string(), - }); + return Err(CoreError::Placeholder("Closed".to_string())); } for mut client in &self.clients { @@ -240,6 +236,27 @@ where Ok(()) } + + pub async fn reconfigure(&mut self, config: KvClientPoolConfig) -> CoreResult<()> { + let mut old_clients = self.clients.clone(); + let mut new_clients = vec![]; + for client in old_clients { + if let Err(e) = client.reconfigure(config.client_config.clone()).await { + // TODO: log here. + dbg!(e); + client.close().await.unwrap_or_default(); + continue; + }; + + new_clients.push(client.clone()); + } + self.clients = new_clients; + self.config = config; + + self.check_connections().await; + + Ok(()) + } } impl KvClientPool for NaiveKvClientPool @@ -284,6 +301,11 @@ where let mut inner = self.inner.lock().await; inner.close().await } + + async fn reconfigure(&self, config: KvClientPoolConfig) -> CoreResult<()> { + let mut inner = self.inner.lock().await; + inner.reconfigure(config).await + } } #[cfg(test)] @@ -344,7 +366,7 @@ mod tests { let pool_config = KvClientPoolConfig { num_connections: 1, - client_config: Arc::new(client_config), + client_config, }; let pool: NaiveKvClientPool> = NaiveKvClientPool::new( @@ -395,4 +417,139 @@ mod tests { pool.close().await.unwrap(); } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn reconfigure() { + let _ = env_logger::try_init(); + + let instant = Instant::now().add(Duration::new(7, 0)); + + let (orphan_tx, mut orphan_rx) = unbounded_channel::(); + + tokio::spawn(async move { + loop { + match orphan_rx.recv().await { + Some(resp) => { + dbg!("unexpected orphan", resp); + } + None => { + return; + } + } + } + }); + + let client_config = KvClientConfig { + address: "192.168.107.128:11210" + .parse() + .expect("Failed to parse address"), + root_certs: None, + accept_all_certs: None, + client_name: "myclient".to_string(), + authenticator: Some(Arc::new(PasswordAuthenticator { + username: "Administrator".to_string(), + password: "password".to_string(), + })), + selected_bucket: None, + disable_default_features: false, + disable_error_map: false, + disable_bootstrap: false, + }; + + let pool_config = KvClientPoolConfig { + num_connections: 1, + client_config, + }; + + let pool: NaiveKvClientPool> = NaiveKvClientPool::new( + pool_config, + KvClientPoolOptions { + connect_timeout: Default::default(), + connect_throttle_period: Default::default(), + orphan_handler: Arc::new(orphan_tx), + }, + ) + .await; + + let client_config = KvClientConfig { + address: "192.168.107.128:11210" + .parse() + .expect("Failed to parse address"), + root_certs: None, + accept_all_certs: None, + client_name: "myclient".to_string(), + authenticator: Some(Arc::new(PasswordAuthenticator { + username: "Administrator".to_string(), + password: "password".to_string(), + })), + selected_bucket: Some("default".to_string()), + disable_default_features: false, + disable_error_map: false, + disable_bootstrap: false, + }; + + let client = pool.get_client().await.unwrap(); + let result = client + .set(SetRequest { + collection_id: 0, + key: "test".as_bytes().into(), + vbucket_id: 1, + flags: 0, + value: "test".as_bytes().into(), + datatype: 0, + expiry: None, + preserve_expiry: None, + cas: None, + on_behalf_of: None, + durability_level: None, + durability_level_timeout: None, + }) + .await; + if result.is_ok() { + panic!("result did not contain an error"); + } + + pool.reconfigure(KvClientPoolConfig { + num_connections: 1, + client_config, + }) + .await + .unwrap(); + + let client = pool.get_client().await.unwrap(); + + let result = client + .set(SetRequest { + collection_id: 0, + key: "test".as_bytes().into(), + vbucket_id: 1, + flags: 0, + value: "test".as_bytes().into(), + datatype: 0, + expiry: None, + preserve_expiry: None, + cas: None, + on_behalf_of: None, + durability_level: None, + durability_level_timeout: None, + }) + .await + .unwrap(); + + dbg!(result); + + let get_result = client + .get(GetRequest { + collection_id: 0, + key: "test".as_bytes().into(), + vbucket_id: 1, + on_behalf_of: None, + }) + .await + .unwrap(); + + dbg!(get_result); + + pool.close().await.unwrap(); + } } diff --git a/sdk/couchbase-core/src/lib.rs b/sdk/couchbase-core/src/lib.rs index 3c487a4d..17ca4bc4 100644 --- a/sdk/couchbase-core/src/lib.rs +++ b/sdk/couchbase-core/src/lib.rs @@ -3,6 +3,7 @@ pub mod cbconfig; mod error; mod kvclient; mod kvclient_ops; +mod kvclientmanager; mod kvclientpool; pub mod memdx; pub mod result; diff --git a/sdk/couchbase-core/src/memdx/auth_mechanism.rs b/sdk/couchbase-core/src/memdx/auth_mechanism.rs index c0a81cbb..9a9da2b5 100644 --- a/sdk/couchbase-core/src/memdx/auth_mechanism.rs +++ b/sdk/couchbase-core/src/memdx/auth_mechanism.rs @@ -1,4 +1,4 @@ -use crate::memdx::error::Error; +use crate::memdx::error::MemdxError; #[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)] pub enum AuthMechanism { @@ -22,7 +22,7 @@ impl From for Vec { } impl TryFrom<&str> for AuthMechanism { - type Error = Error; + type Error = MemdxError; fn try_from(value: &str) -> Result { let mech = match value { @@ -31,7 +31,10 @@ impl TryFrom<&str> for AuthMechanism { "SCRAM-SHA256" => AuthMechanism::ScramSha256, "SCRAM-SHA512" => AuthMechanism::ScramSha512, _ => { - return Err(Error::Protocol(format!("Unknown auth mechanism {}", value))); + return Err(MemdxError::Protocol(format!( + "Unknown auth mechanism {}", + value + ))); } }; diff --git a/sdk/couchbase-core/src/memdx/client.rs b/sdk/couchbase-core/src/memdx/client.rs index 4d87d737..1472e34b 100644 --- a/sdk/couchbase-core/src/memdx/client.rs +++ b/sdk/couchbase-core/src/memdx/client.rs @@ -26,12 +26,12 @@ use crate::memdx::client_response::ClientResponse; use crate::memdx::codec::KeyValueCodec; use crate::memdx::connection::{Connection, ConnectionType}; use crate::memdx::dispatcher::{Dispatcher, DispatcherOptions}; -use crate::memdx::error::{CancellationErrorKind, Error}; +use crate::memdx::error::{CancellationErrorKind, MemdxError}; use crate::memdx::packet::{RequestPacket, ResponsePacket}; use crate::memdx::pendingop::ClientPendingOp; -pub type Result = std::result::Result; -type ResponseSender = Sender>; +pub type MemdxResult = std::result::Result; +type ResponseSender = Sender>; type OpaqueMap = HashMap>; pub(crate) type CancellationSender = UnboundedSender<(u32, CancellationErrorKind)>; @@ -41,7 +41,7 @@ struct ReadLoopOptions { pub local_addr: Option, pub peer_addr: Option, pub orphan_handler: Arc>, - pub on_connection_close_tx: Option>>, + pub on_connection_close_tx: Option>>, pub on_client_close_rx: Receiver<()>, } @@ -90,7 +90,7 @@ impl Client { for entry in opaque_map.iter() { entry .1 - .send(Err(Error::ClosedInFlight)) + .send(Err(MemdxError::ClosedInFlight)) .await .unwrap_or_default(); } @@ -100,7 +100,7 @@ impl Client { stream: FramedRead, KeyValueCodec>, op_cancel_rx: UnboundedReceiver<(u32, CancellationErrorKind)>, opaque_map: MutexGuard<'_, OpaqueMap>, - on_connection_close_tx: Option>>, + on_connection_close_tx: Option>>, ) { drop(stream); drop(op_cancel_rx); @@ -138,7 +138,7 @@ impl Client { drop(map); sender - .send(Err(Error::Cancelled(cancel_info.1))) + .send(Err(MemdxError::Cancelled(cancel_info.1))) .await .unwrap(); } else { @@ -300,7 +300,7 @@ impl Dispatcher for Client { } } - async fn dispatch(&self, mut packet: RequestPacket) -> Result { + async fn dispatch(&self, mut packet: RequestPacket) -> MemdxResult { let (response_tx, response_rx) = mpsc::channel(1); let opaque = self.register_handler(Arc::new(response_tx)).await; packet.opaque = Some(opaque); @@ -323,14 +323,14 @@ impl Dispatcher for Client { let mut map = requests.lock().await; map.remove(&opaque); - Err(Error::Dispatch(e.kind())) + Err(MemdxError::Dispatch(e.kind())) } } } - async fn close(&self) -> Result<()> { + async fn close(&self) -> MemdxResult<()> { if self.closed.swap(true, Ordering::SeqCst) { - return Err(Error::Closed); + return Err(MemdxError::Closed); } let mut close_err = None; @@ -354,7 +354,7 @@ impl Dispatcher for Client { Self::drain_opaque_map(map).await; if let Some(e) = close_err { - return Err(Error::from(e)); + return Err(MemdxError::from(e)); } Ok(()) @@ -405,7 +405,7 @@ mod tests { .expect("Could not connect"); let (orphan_tx, mut orphan_rx) = unbounded_channel::(); - let (close_tx, mut close_rx) = oneshot::channel::>(); + let (close_tx, mut close_rx) = oneshot::channel::>(); tokio::spawn(async move { loop { @@ -442,7 +442,7 @@ mod tests { let bootstrap_result = OpBootstrap::bootstrap( OpsCore {}, - &mut client, + &client, BootstrapOptions { hello: Some(HelloRequest { client_name: "test-client".into(), diff --git a/sdk/couchbase-core/src/memdx/codec.rs b/sdk/couchbase-core/src/memdx/codec.rs index 53523609..23b36fa2 100644 --- a/sdk/couchbase-core/src/memdx/codec.rs +++ b/sdk/couchbase-core/src/memdx/codec.rs @@ -3,8 +3,8 @@ use std::io; use tokio_util::bytes::{Buf, BufMut, BytesMut}; use tokio_util::codec::{Decoder, Encoder}; -use crate::memdx::error::Error; -use crate::memdx::error::Error::Protocol; +use crate::memdx::error::MemdxError; +use crate::memdx::error::MemdxError::Protocol; use crate::memdx::magic::Magic; use crate::memdx::opcode::OpCode; use crate::memdx::packet::{RequestPacket, ResponsePacket}; @@ -17,7 +17,7 @@ pub struct KeyValueCodec(()); impl Decoder for KeyValueCodec { type Item = ResponsePacket; - type Error = Error; + type Error = MemdxError; fn decode(&mut self, buf: &mut BytesMut) -> Result, Self::Error> { let buf_len = buf.len(); diff --git a/sdk/couchbase-core/src/memdx/connection.rs b/sdk/couchbase-core/src/memdx/connection.rs index 67776220..826998fa 100644 --- a/sdk/couchbase-core/src/memdx/connection.rs +++ b/sdk/couchbase-core/src/memdx/connection.rs @@ -12,8 +12,8 @@ use tokio_rustls::rustls::client::danger::{ use tokio_rustls::rustls::pki_types::{CertificateDer, IpAddr, ServerName, UnixTime}; use tokio_rustls::TlsConnector; -use crate::memdx::client::Result; -use crate::memdx::error::Error; +use crate::memdx::client::MemdxResult; +use crate::memdx::error::MemdxError; #[derive(Debug, Default)] pub struct TlsConfig { @@ -42,7 +42,7 @@ pub struct Connection { } impl Connection { - pub async fn connect(addr: SocketAddr, opts: ConnectOptions) -> Result { + pub async fn connect(addr: SocketAddr, opts: ConnectOptions) -> MemdxResult { let remote_addr = addr.to_string(); if let Some(tls_config) = opts.tls_config { @@ -55,7 +55,7 @@ impl Connection { } else if let Some(roots) = tls_config.root_certs { builder.with_root_certificates(roots).with_no_client_auth() } else { - return Err(Error::Generic( + return Err(MemdxError::Generic( "If tls config is specified then roots or accept_all_certs must be specified" .to_string(), )); @@ -63,11 +63,11 @@ impl Connection { let tcp_socket = timeout_at(opts.deadline, TcpStream::connect(remote_addr)) .await? - .map_err(|e| Error::Connect(e.kind()))?; + .map_err(|e| MemdxError::Connect(e.kind()))?; tcp_socket .set_nodelay(false) - .map_err(|e| Error::Connect(e.kind()))?; + .map_err(|e| MemdxError::Connect(e.kind()))?; let local_addr = match tcp_socket.local_addr() { Ok(addr) => Some(addr), @@ -84,7 +84,7 @@ impl Connection { connector.connect(ServerName::IpAddress(IpAddr::from(addr.ip())), tcp_socket), ) .await? - .map_err(|e| Error::Connect(e.kind()))?; + .map_err(|e| MemdxError::Connect(e.kind()))?; Ok(Connection { inner: ConnectionType::Tls(socket), @@ -94,10 +94,10 @@ impl Connection { } else { let socket = timeout_at(opts.deadline, TcpStream::connect(remote_addr)) .await? - .map_err(|e| Error::Connect(e.kind()))?; + .map_err(|e| MemdxError::Connect(e.kind()))?; socket .set_nodelay(false) - .map_err(|e| Error::Connect(e.kind()))?; + .map_err(|e| MemdxError::Connect(e.kind()))?; let local_addr = match socket.local_addr() { Ok(addr) => Some(addr), diff --git a/sdk/couchbase-core/src/memdx/dispatcher.rs b/sdk/couchbase-core/src/memdx/dispatcher.rs index 334411fd..a108f673 100644 --- a/sdk/couchbase-core/src/memdx/dispatcher.rs +++ b/sdk/couchbase-core/src/memdx/dispatcher.rs @@ -4,7 +4,7 @@ use async_trait::async_trait; use tokio::sync::mpsc::UnboundedSender; use tokio::sync::oneshot; -use crate::memdx::client::Result; +use crate::memdx::client::MemdxResult; use crate::memdx::connection::Connection; use crate::memdx::packet::{RequestPacket, ResponsePacket}; use crate::memdx::pendingop::ClientPendingOp; @@ -12,12 +12,12 @@ use crate::memdx::pendingop::ClientPendingOp; #[derive(Debug)] pub struct DispatcherOptions { pub orphan_handler: Arc>, - pub on_connection_close_handler: Option>>, + pub on_connection_close_handler: Option>>, } #[async_trait] pub trait Dispatcher: Send + Sync { fn new(conn: Connection, opts: DispatcherOptions) -> Self; - async fn dispatch(&self, packet: RequestPacket) -> Result; - async fn close(&self) -> Result<()>; + async fn dispatch(&self, packet: RequestPacket) -> MemdxResult; + async fn close(&self) -> MemdxResult<()>; } diff --git a/sdk/couchbase-core/src/memdx/error.rs b/sdk/couchbase-core/src/memdx/error.rs index a8941c73..56589ba0 100644 --- a/sdk/couchbase-core/src/memdx/error.rs +++ b/sdk/couchbase-core/src/memdx/error.rs @@ -6,7 +6,7 @@ use tokio::time::error::Elapsed; use crate::scram::ScramError; #[derive(thiserror::Error, Debug, Eq, PartialEq)] -pub enum Error { +pub enum MemdxError { #[error("Connect failed {0}")] Connect(io::ErrorKind), #[error("Dispatch failed {0}")] @@ -67,19 +67,19 @@ impl Display for CancellationErrorKind { } // TODO: improve this. -impl From for Error { +impl From for MemdxError { fn from(value: io::Error) -> Self { - Error::Unknown(value.to_string()) + MemdxError::Unknown(value.to_string()) } } -impl From for Error { +impl From for MemdxError { fn from(value: ScramError) -> Self { Self::Auth(value.to_string()) } } -impl From for Error { +impl From for MemdxError { fn from(_value: Elapsed) -> Self { Self::Cancelled(CancellationErrorKind::Timeout) } diff --git a/sdk/couchbase-core/src/memdx/magic.rs b/sdk/couchbase-core/src/memdx/magic.rs index 71b8964b..e9858e4d 100644 --- a/sdk/couchbase-core/src/memdx/magic.rs +++ b/sdk/couchbase-core/src/memdx/magic.rs @@ -1,6 +1,6 @@ use std::fmt::{Debug, Display}; -use crate::memdx::error::Error; +use crate::memdx::error::MemdxError; #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] pub enum Magic { @@ -36,7 +36,7 @@ impl From for u8 { } impl TryFrom for Magic { - type Error = Error; + type Error = MemdxError; fn try_from(value: u8) -> Result { let magic = match value { @@ -45,7 +45,7 @@ impl TryFrom for Magic { 0x08 => Magic::ReqExt, 0x18 => Magic::ResExt, _ => { - return Err(Error::Protocol(format!("unknown magic {}", value))); + return Err(MemdxError::Protocol(format!("unknown magic {}", value))); } }; diff --git a/sdk/couchbase-core/src/memdx/op_auth_saslauto.rs b/sdk/couchbase-core/src/memdx/op_auth_saslauto.rs index 0d01158b..d95c47f3 100644 --- a/sdk/couchbase-core/src/memdx/op_auth_saslauto.rs +++ b/sdk/couchbase-core/src/memdx/op_auth_saslauto.rs @@ -3,11 +3,11 @@ use std::cmp::PartialEq; use tokio::time::Instant; use crate::memdx::auth_mechanism::AuthMechanism; -use crate::memdx::client::Result; +use crate::memdx::client::MemdxResult; use crate::memdx::dispatcher::Dispatcher; use crate::memdx::error::CancellationErrorKind::{RequestCancelled, Timeout}; -use crate::memdx::error::Error; -use crate::memdx::error::Error::Generic; +use crate::memdx::error::MemdxError; +use crate::memdx::error::MemdxError::Generic; use crate::memdx::op_auth_saslbyname::{ OpSASLAuthByNameEncoder, OpsSASLAuthByName, SASLAuthByNameOptions, }; @@ -29,9 +29,9 @@ pub struct SASLListMechsOptions {} pub trait OpSASLAutoEncoder: OpSASLAuthByNameEncoder { fn sasl_list_mechs( &self, - dispatcher: &mut D, + dispatcher: &D, request: SASLListMechsRequest, - ) -> impl std::future::Future>> + ) -> impl std::future::Future>> where D: Dispatcher; } @@ -43,10 +43,10 @@ impl OpsSASLAuthAuto { pub async fn sasl_auth_auto( &self, encoder: &E, - dispatcher: &mut D, + dispatcher: &D, deadline: Instant, opts: SASLAuthAutoOptions, - ) -> Result<()> + ) -> MemdxResult<()> where E: OpSASLAutoEncoder, D: Dispatcher, @@ -81,7 +81,9 @@ impl OpsSASLAuthAuto { { Ok(()) => Ok(()), Err(e) => { - if e == Error::Cancelled(Timeout) || e == Error::Cancelled(RequestCancelled) { + if e == MemdxError::Cancelled(Timeout) + || e == MemdxError::Cancelled(RequestCancelled) + { return Err(e); } diff --git a/sdk/couchbase-core/src/memdx/op_auth_saslbyname.rs b/sdk/couchbase-core/src/memdx/op_auth_saslbyname.rs index e3df6c81..445c44e1 100644 --- a/sdk/couchbase-core/src/memdx/op_auth_saslbyname.rs +++ b/sdk/couchbase-core/src/memdx/op_auth_saslbyname.rs @@ -1,7 +1,7 @@ use tokio::time::Instant; use crate::memdx::auth_mechanism::AuthMechanism; -use crate::memdx::client::Result; +use crate::memdx::client::MemdxResult; use crate::memdx::dispatcher::Dispatcher; use crate::memdx::op_auth_saslplain::{OpSASLPlainEncoder, OpsSASLAuthPlain, SASLAuthPlainOptions}; use crate::memdx::op_auth_saslscram::{OpSASLScramEncoder, OpsSASLAuthScram, SASLAuthScramOptions}; @@ -25,9 +25,9 @@ impl OpsSASLAuthByName { pub async fn sasl_auth_by_name( &self, encoder: &E, - dispatcher: &mut D, + dispatcher: &D, opts: SASLAuthByNameOptions, - ) -> Result<()> + ) -> MemdxResult<()> where E: OpSASLAuthByNameEncoder, D: Dispatcher, diff --git a/sdk/couchbase-core/src/memdx/op_auth_saslplain.rs b/sdk/couchbase-core/src/memdx/op_auth_saslplain.rs index 26b474bf..d9c7af24 100644 --- a/sdk/couchbase-core/src/memdx/op_auth_saslplain.rs +++ b/sdk/couchbase-core/src/memdx/op_auth_saslplain.rs @@ -1,9 +1,9 @@ use tokio::time::Instant; use crate::memdx::auth_mechanism::AuthMechanism; -use crate::memdx::client::Result; +use crate::memdx::client::MemdxResult; use crate::memdx::dispatcher::Dispatcher; -use crate::memdx::error::Error; +use crate::memdx::error::MemdxError; use crate::memdx::pendingop::{run_op_future_with_deadline, StandardPendingOp}; use crate::memdx::request::SASLAuthRequest; use crate::memdx::response::SASLAuthResponse; @@ -11,9 +11,9 @@ use crate::memdx::response::SASLAuthResponse; pub trait OpSASLPlainEncoder { fn sasl_auth( &self, - dispatcher: &mut D, + dispatcher: &D, req: SASLAuthRequest, - ) -> impl std::future::Future>> + ) -> impl std::future::Future>> where D: Dispatcher; } @@ -42,9 +42,9 @@ impl OpsSASLAuthPlain { pub async fn sasl_auth_plain( &self, encoder: &E, - dispatcher: &mut D, + dispatcher: &D, opts: SASLAuthPlainOptions, - ) -> Result<()> + ) -> MemdxResult<()> where E: OpSASLPlainEncoder, D: Dispatcher, @@ -64,7 +64,7 @@ impl OpsSASLAuthPlain { run_op_future_with_deadline(opts.deadline, encoder.sasl_auth(dispatcher, req)).await?; if resp.needs_more_steps { - return Err(Error::Protocol( + return Err(MemdxError::Protocol( "Server did not accept auth when the client expected".to_string(), )); } diff --git a/sdk/couchbase-core/src/memdx/op_auth_saslscram.rs b/sdk/couchbase-core/src/memdx/op_auth_saslscram.rs index 698dc370..70fe828f 100644 --- a/sdk/couchbase-core/src/memdx/op_auth_saslscram.rs +++ b/sdk/couchbase-core/src/memdx/op_auth_saslscram.rs @@ -4,9 +4,9 @@ use sha2::{Sha256, Sha512}; use tokio::time::Instant; use crate::memdx::auth_mechanism::AuthMechanism; -use crate::memdx::client::Result; +use crate::memdx::client::MemdxResult; use crate::memdx::dispatcher::Dispatcher; -use crate::memdx::error::Error; +use crate::memdx::error::MemdxError; use crate::memdx::op_auth_saslplain::OpSASLPlainEncoder; use crate::memdx::pendingop::{run_op_future_with_deadline, StandardPendingOp}; use crate::memdx::request::{SASLAuthRequest, SASLStepRequest}; @@ -16,9 +16,9 @@ use crate::scram; pub trait OpSASLScramEncoder: OpSASLPlainEncoder { fn sasl_step( &self, - dispatcher: &mut D, + dispatcher: &D, request: SASLStepRequest, - ) -> impl std::future::Future>> + ) -> impl std::future::Future>> where D: Dispatcher; } @@ -50,9 +50,9 @@ impl OpsSASLAuthScram { pub async fn sasl_auth_scram_512( &self, encoder: &E, - dispatcher: &mut D, + dispatcher: &D, opts: SASLAuthScramOptions, - ) -> Result<()> + ) -> MemdxResult<()> where E: OpSASLScramEncoder, D: Dispatcher, @@ -86,7 +86,7 @@ impl OpsSASLAuthScram { run_op_future_with_deadline(opts.deadline, encoder.sasl_step(dispatcher, req)).await?; if resp.needs_more_steps { - return Err(Error::Protocol( + return Err(MemdxError::Protocol( "Server did not accept auth when the client expected".to_string(), )); } @@ -97,9 +97,9 @@ impl OpsSASLAuthScram { pub async fn sasl_auth_scram_256( &self, encoder: &E, - dispatcher: &mut D, + dispatcher: &D, opts: SASLAuthScramOptions, - ) -> Result<()> + ) -> MemdxResult<()> where E: OpSASLScramEncoder, D: Dispatcher, @@ -133,7 +133,7 @@ impl OpsSASLAuthScram { run_op_future_with_deadline(opts.deadline, encoder.sasl_step(dispatcher, req)).await?; if resp.needs_more_steps { - return Err(Error::Protocol( + return Err(MemdxError::Protocol( "Server did not accept auth when the client expected".to_string(), )); } @@ -144,9 +144,9 @@ impl OpsSASLAuthScram { pub async fn sasl_auth_scram_1( &self, encoder: &E, - dispatcher: &mut D, + dispatcher: &D, opts: SASLAuthScramOptions, - ) -> Result<()> + ) -> MemdxResult<()> where E: OpSASLScramEncoder, D: Dispatcher, @@ -179,7 +179,7 @@ impl OpsSASLAuthScram { run_op_future_with_deadline(opts.deadline, encoder.sasl_step(dispatcher, req)).await?; if resp.needs_more_steps { - return Err(Error::Protocol( + return Err(MemdxError::Protocol( "Server did not accept auth when the client expected".to_string(), )); } diff --git a/sdk/couchbase-core/src/memdx/op_bootstrap.rs b/sdk/couchbase-core/src/memdx/op_bootstrap.rs index b873ccc1..962a645b 100644 --- a/sdk/couchbase-core/src/memdx/op_bootstrap.rs +++ b/sdk/couchbase-core/src/memdx/op_bootstrap.rs @@ -2,7 +2,7 @@ use log::warn; use tokio::select; use tokio::time::{Instant, sleep}; -use crate::memdx::client::Result; +use crate::memdx::client::MemdxResult; use crate::memdx::dispatcher::Dispatcher; use crate::memdx::error::CancellationErrorKind; use crate::memdx::op_auth_saslauto::{OpSASLAutoEncoder, OpsSASLAuthAuto, SASLAuthAutoOptions}; @@ -20,33 +20,33 @@ use crate::memdx::response::{ pub trait OpBootstrapEncoder { fn hello( &self, - dispatcher: &mut D, + dispatcher: &D, request: HelloRequest, - ) -> impl std::future::Future>> + ) -> impl std::future::Future>> where D: Dispatcher; fn get_error_map( &self, - dispatcher: &mut D, + dispatcher: &D, request: GetErrorMapRequest, - ) -> impl std::future::Future>> + ) -> impl std::future::Future>> where D: Dispatcher; fn select_bucket( &self, - dispatcher: &mut D, + dispatcher: &D, request: SelectBucketRequest, - ) -> impl std::future::Future>> + ) -> impl std::future::Future>> where D: Dispatcher; fn get_cluster_config( &self, - dispatcher: &mut D, + dispatcher: &D, request: GetClusterConfigRequest, - ) -> impl std::future::Future>> + ) -> impl std::future::Future>> where D: Dispatcher; } @@ -68,9 +68,9 @@ impl OpBootstrap { // make pipelining complex. It's a bit of a niche optimization so we can improve later. pub async fn bootstrap( encoder: E, - dispatcher: &mut D, + dispatcher: &D, opts: BootstrapOptions, - ) -> Result + ) -> MemdxResult where E: OpBootstrapEncoder + OpSASLAutoEncoder, D: Dispatcher, diff --git a/sdk/couchbase-core/src/memdx/opcode.rs b/sdk/couchbase-core/src/memdx/opcode.rs index a5d052f3..f74771e7 100644 --- a/sdk/couchbase-core/src/memdx/opcode.rs +++ b/sdk/couchbase-core/src/memdx/opcode.rs @@ -1,6 +1,6 @@ use std::fmt::{Display, Formatter}; -use crate::memdx::error::Error; +use crate::memdx::error::MemdxError; #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] pub enum OpCode { @@ -34,7 +34,7 @@ impl From for u8 { } impl TryFrom for OpCode { - type Error = Error; + type Error = MemdxError; fn try_from(value: u8) -> Result { let code = match value { @@ -49,7 +49,7 @@ impl TryFrom for OpCode { 0xb5 => OpCode::GetClusterConfig, 0xfe => OpCode::GetErrorMap, _ => { - return Err(Error::Protocol(format!("unknown opcode {}", value))); + return Err(MemdxError::Protocol(format!("unknown opcode {}", value))); } }; diff --git a/sdk/couchbase-core/src/memdx/ops_core.rs b/sdk/couchbase-core/src/memdx/ops_core.rs index d789020e..df076f12 100644 --- a/sdk/couchbase-core/src/memdx/ops_core.rs +++ b/sdk/couchbase-core/src/memdx/ops_core.rs @@ -2,9 +2,9 @@ use std::io::Write; use byteorder::{BigEndian, WriteBytesExt}; -use crate::memdx::client::Result; +use crate::memdx::client::MemdxResult; use crate::memdx::dispatcher::Dispatcher; -use crate::memdx::error::Error; +use crate::memdx::error::MemdxError; use crate::memdx::magic::Magic; use crate::memdx::op_auth_saslauto::OpSASLAutoEncoder; use crate::memdx::op_auth_saslbyname::OpSASLAuthByNameEncoder; @@ -27,14 +27,14 @@ use crate::memdx::status::Status; pub struct OpsCore {} impl OpsCore { - pub(crate) fn decode_error(resp: &ResponsePacket) -> Error { + pub(crate) fn decode_error(resp: &ResponsePacket) -> MemdxError { let status = resp.status; if status == Status::NotMyVbucket { - Error::NotMyVbucket + MemdxError::NotMyVbucket } else if status == Status::TmpFail { - Error::TmpFail + MemdxError::TmpFail } else { - Error::Unknown(format!("{}", status)) + MemdxError::Unknown(format!("{}", status)) } // TODO: decode error context @@ -44,9 +44,9 @@ impl OpsCore { impl OpBootstrapEncoder for OpsCore { async fn hello( &self, - dispatcher: &mut D, + dispatcher: &D, request: HelloRequest, - ) -> Result> + ) -> MemdxResult> where D: Dispatcher, { @@ -75,9 +75,9 @@ impl OpBootstrapEncoder for OpsCore { async fn get_error_map( &self, - dispatcher: &mut D, + dispatcher: &D, request: GetErrorMapRequest, - ) -> Result> + ) -> MemdxResult> where D: Dispatcher, { @@ -104,9 +104,9 @@ impl OpBootstrapEncoder for OpsCore { async fn select_bucket( &self, - dispatcher: &mut D, + dispatcher: &D, request: SelectBucketRequest, - ) -> Result> + ) -> MemdxResult> where D: Dispatcher, { @@ -133,9 +133,9 @@ impl OpBootstrapEncoder for OpsCore { async fn get_cluster_config( &self, - dispatcher: &mut D, + dispatcher: &D, _request: GetClusterConfigRequest, - ) -> Result> + ) -> MemdxResult> where D: Dispatcher, { @@ -161,9 +161,9 @@ impl OpBootstrapEncoder for OpsCore { impl OpSASLPlainEncoder for OpsCore { async fn sasl_auth( &self, - dispatcher: &mut D, + dispatcher: &D, request: SASLAuthRequest, - ) -> Result> + ) -> MemdxResult> where D: Dispatcher, { @@ -194,9 +194,9 @@ impl OpSASLAuthByNameEncoder for OpsCore {} impl OpSASLAutoEncoder for OpsCore { async fn sasl_list_mechs( &self, - dispatcher: &mut D, + dispatcher: &D, _request: SASLListMechsRequest, - ) -> Result> + ) -> MemdxResult> where D: Dispatcher, { @@ -222,9 +222,9 @@ impl OpSASLAutoEncoder for OpsCore { impl OpSASLScramEncoder for OpsCore { async fn sasl_step( &self, - dispatcher: &mut D, + dispatcher: &D, request: SASLStepRequest, - ) -> Result> + ) -> MemdxResult> where D: Dispatcher, { diff --git a/sdk/couchbase-core/src/memdx/ops_crud.rs b/sdk/couchbase-core/src/memdx/ops_crud.rs index 92de1e54..88cd5afd 100644 --- a/sdk/couchbase-core/src/memdx/ops_crud.rs +++ b/sdk/couchbase-core/src/memdx/ops_crud.rs @@ -3,10 +3,10 @@ use std::time::Duration; use byteorder::{BigEndian, WriteBytesExt}; use bytes::{BufMut, BytesMut}; -use crate::memdx::client::Result; +use crate::memdx::client::MemdxResult; use crate::memdx::dispatcher::Dispatcher; use crate::memdx::durability_level::{DurabilityLevel, DurabilityLevelSettings}; -use crate::memdx::error::Error; +use crate::memdx::error::MemdxError; use crate::memdx::ext_frame_code::{ExtReqFrameCode, ExtResFrameCode}; use crate::memdx::magic::Magic; use crate::memdx::opcode::OpCode; @@ -30,7 +30,7 @@ impl OpsCrud { &self, dispatcher: &D, request: SetRequest, - ) -> Result> + ) -> MemdxResult> where D: Dispatcher, { @@ -76,7 +76,7 @@ impl OpsCrud { &self, dispatcher: &D, request: GetRequest, - ) -> Result> + ) -> MemdxResult> where D: Dispatcher, { @@ -110,10 +110,10 @@ impl OpsCrud { Ok(StandardPendingOp::new(pending_op)) } - fn encode_collection_and_key(&self, collection_id: u32, key: Vec) -> Result> { + fn encode_collection_and_key(&self, collection_id: u32, key: Vec) -> MemdxResult> { if !self.collections_enabled { if collection_id != 0 { - return Err(Error::CollectionsNotEnabled); + return Err(MemdxError::CollectionsNotEnabled); } return Ok(key); @@ -129,14 +129,14 @@ impl OpsCrud { preserve_expiry: Option, on_behalf_of: Option, buf: &mut Vec, - ) -> Result { + ) -> MemdxResult { if let Some(obo) = on_behalf_of { append_ext_frame(ExtReqFrameCode::OnBehalfOf, obo.into_bytes(), buf)?; } if let Some(dura) = durability_level { if !self.durability_enabled { - return Err(Error::Protocol( + return Err(MemdxError::Protocol( "Cannot use synchronous durability when its not enabled".to_string(), )); } @@ -145,14 +145,14 @@ impl OpsCrud { append_ext_frame(ExtReqFrameCode::Durability, dura_buf, buf)?; } else if durability_timeout.is_some() { - return Err(Error::Protocol( + return Err(MemdxError::Protocol( "Cannot encode durability timeout without durability level".to_string(), )); } if preserve_expiry.is_some() { if !self.preserve_expiry_enabled { - return Err(Error::Protocol( + return Err(MemdxError::Protocol( "Cannot use preserve expiry when its not enabled".to_string(), )); } @@ -162,7 +162,7 @@ impl OpsCrud { let magic = if !buf.is_empty() { if !self.ext_frames_enabled { - return Err(Error::Protocol( + return Err(MemdxError::Protocol( "Cannot use framing extras when its not enabled".to_string(), )); } @@ -175,10 +175,10 @@ impl OpsCrud { Ok(magic) } - pub(crate) fn decode_common_status(status: Status) -> Result<()> { + pub(crate) fn decode_common_status(status: Status) -> MemdxResult<()> { let err = match status { - Status::CollectionUnknown => Error::UnknownCollectionID, - Status::AccessError => Error::Access, + Status::CollectionUnknown => MemdxError::UnknownCollectionID, + Status::AccessError => MemdxError::Access, _ => { return Ok(()); } @@ -187,7 +187,7 @@ impl OpsCrud { Err(err) } - pub(crate) fn decode_common_error(resp: &ResponsePacket) -> Error { + pub(crate) fn decode_common_error(resp: &ResponsePacket) -> MemdxError { if let Err(e) = Self::decode_common_status(resp.status) { return e; }; @@ -196,7 +196,7 @@ impl OpsCrud { } } -pub(crate) fn decode_res_ext_frames(buf: &[u8]) -> Result> { +pub(crate) fn decode_res_ext_frames(buf: &[u8]) -> MemdxResult> { let mut server_duration_data = None; iter_ext_frames(buf, |code, data| { @@ -212,9 +212,11 @@ pub(crate) fn decode_res_ext_frames(buf: &[u8]) -> Result> { Ok(None) } -pub fn decode_ext_frame(buf: &[u8]) -> Result<(ExtResFrameCode, Vec, usize)> { +pub fn decode_ext_frame(buf: &[u8]) -> MemdxResult<(ExtResFrameCode, Vec, usize)> { if buf.is_empty() { - return Err(Error::Protocol("Framing extras protocol error".to_string())); + return Err(MemdxError::Protocol( + "Framing extras protocol error".to_string(), + )); } let mut buf_pos = 0; @@ -227,7 +229,7 @@ pub fn decode_ext_frame(buf: &[u8]) -> Result<(ExtResFrameCode, Vec, usize)> if u_frame_code == 15 { if buf.len() < buf_pos + 1 { - return Err(Error::Protocol("Unexpected eof".to_string())); + return Err(MemdxError::Protocol("Unexpected eof".to_string())); } let frame_code_ext = buf[buf_pos]; @@ -238,7 +240,7 @@ pub fn decode_ext_frame(buf: &[u8]) -> Result<(ExtResFrameCode, Vec, usize)> if frame_len == 15 { if buf.len() < buf_pos + 1 { - return Err(Error::Protocol("Unexpected eof".to_string())); + return Err(MemdxError::Protocol("Unexpected eof".to_string())); } let frame_len_ext = buf[buf_pos]; @@ -248,7 +250,7 @@ pub fn decode_ext_frame(buf: &[u8]) -> Result<(ExtResFrameCode, Vec, usize)> let u_frame_len = frame_len as usize; if buf.len() < buf_pos + u_frame_len { - return Err(Error::Protocol("unexpected eof".to_string())); + return Err(MemdxError::Protocol("unexpected eof".to_string())); } let frame_body = &buf[buf_pos..buf_pos + u_frame_len]; @@ -257,7 +259,10 @@ pub fn decode_ext_frame(buf: &[u8]) -> Result<(ExtResFrameCode, Vec, usize)> Ok((frame_code, frame_body.to_vec(), buf_pos)) } -fn iter_ext_frames(buf: &[u8], mut cb: impl FnMut(ExtResFrameCode, Vec)) -> Result> { +fn iter_ext_frames( + buf: &[u8], + mut cb: impl FnMut(ExtResFrameCode, Vec), +) -> MemdxResult> { if !buf.is_empty() { let (frame_code, frame_body, buf_pos) = decode_ext_frame(buf)?; @@ -273,7 +278,7 @@ pub fn append_ext_frame( frame_code: ExtReqFrameCode, frame_body: Vec, buf: &mut Vec, -) -> Result<()> { +) -> MemdxResult<()> { let frame_len = frame_body.len(); let buf_len = buf.len(); @@ -285,7 +290,7 @@ pub fn append_ext_frame( *hdr_byte_ptr = (*hdr_byte_ptr as u16 | ((u_frame_code & 0x0f) << 4)) as u8; } else { if u_frame_code - 15 >= 15 { - return Err(Error::Protocol( + return Err(MemdxError::Protocol( "Extframe code too large to encode".to_string(), )); } @@ -326,7 +331,7 @@ pub fn make_uleb128_32(key: Vec, collection_id: u32) -> Vec { fn encode_durability_ext_frame( level: DurabilityLevel, timeout: Option, -) -> Result> { +) -> MemdxResult> { if timeout.is_none() { return Ok(vec![level.into()]); } @@ -335,7 +340,7 @@ fn encode_durability_ext_frame( let mut timeout_millis = timeout.as_millis(); if timeout_millis > 65535 { - return Err(Error::Protocol( + return Err(MemdxError::Protocol( "Cannot encode durability timeout greater than 65535 milliseconds".to_string(), )); } @@ -351,9 +356,9 @@ fn encode_durability_ext_frame( Ok(buf) } -pub(crate) fn decode_server_duration_ext_frame(mut data: Vec) -> Result { +pub(crate) fn decode_server_duration_ext_frame(mut data: Vec) -> MemdxResult { if data.len() != 2 { - return Err(Error::Protocol( + return Err(MemdxError::Protocol( "Invalid server duration extframe length".to_string(), )); } @@ -366,7 +371,7 @@ pub(crate) fn decode_server_duration_ext_frame(mut data: Vec) -> Result, -) -> Result { +) -> MemdxResult { if data.len() == 1 { let durability = DurabilityLevel::from(data.remove(0)); @@ -381,7 +386,7 @@ pub(crate) fn decode_durability_level_ext_frame( )); } - Err(Error::Protocol( + Err(MemdxError::Protocol( "Invalid durability extframe length".to_string(), )) } diff --git a/sdk/couchbase-core/src/memdx/pendingop.rs b/sdk/couchbase-core/src/memdx/pendingop.rs index c35d7bb7..58437d70 100644 --- a/sdk/couchbase-core/src/memdx/pendingop.rs +++ b/sdk/couchbase-core/src/memdx/pendingop.rs @@ -8,14 +8,14 @@ use tokio::sync::mpsc::Receiver; use tokio::time::{Instant, timeout_at}; use crate::memdx::client::CancellationSender; -use crate::memdx::client::Result; +use crate::memdx::client::MemdxResult; use crate::memdx::client_response::ClientResponse; use crate::memdx::error::CancellationErrorKind; -use crate::memdx::error::Error::{Cancelled, Closed}; +use crate::memdx::error::MemdxError::{Cancelled, Closed}; use crate::memdx::response::TryFromClientResponse; pub trait PendingOp { - fn recv(&mut self) -> impl std::future::Future> + fn recv(&mut self) -> impl std::future::Future> where T: TryFromClientResponse; fn cancel(&mut self, e: CancellationErrorKind); @@ -28,14 +28,14 @@ pub(crate) trait OpCanceller { pub struct ClientPendingOp { opaque: u32, cancel_chan: CancellationSender, - response_receiver: Receiver>, + response_receiver: Receiver>, } impl ClientPendingOp { pub(crate) fn new( opaque: u32, cancel_chan: CancellationSender, - response_receiver: Receiver>, + response_receiver: Receiver>, ) -> Self { ClientPendingOp { opaque, @@ -44,7 +44,7 @@ impl ClientPendingOp { } } - pub async fn recv(&mut self) -> Result { + pub async fn recv(&mut self) -> MemdxResult { match self.response_receiver.recv().await { Some(r) => r, None => Err(Closed), @@ -76,7 +76,7 @@ impl StandardPendingOp { } impl PendingOp for StandardPendingOp { - async fn recv(&mut self) -> Result { + async fn recv(&mut self) -> MemdxResult { let packet = self.wrapped.recv().await?; T::try_from(packet) @@ -87,10 +87,13 @@ impl PendingOp for StandardPendingOp { } } -pub(super) async fn run_op_future_with_deadline(deadline: Instant, fut: F) -> Result +pub(super) async fn run_op_future_with_deadline( + deadline: Instant, + fut: F, +) -> MemdxResult where O: PendingOp, - F: Future>, + F: Future>, T: TryFromClientResponse, { let mut op = match timeout_at(deadline, fut).await { diff --git a/sdk/couchbase-core/src/memdx/response.rs b/sdk/couchbase-core/src/memdx/response.rs index a8d1f92e..8c9b05ea 100644 --- a/sdk/couchbase-core/src/memdx/response.rs +++ b/sdk/couchbase-core/src/memdx/response.rs @@ -5,14 +5,14 @@ use byteorder::{BigEndian, ReadBytesExt}; use crate::memdx::auth_mechanism::AuthMechanism; use crate::memdx::client_response::ClientResponse; -use crate::memdx::error::Error; +use crate::memdx::error::MemdxError; use crate::memdx::hello_feature::HelloFeature; use crate::memdx::ops_core::OpsCore; use crate::memdx::ops_crud::{decode_res_ext_frames, OpsCrud}; use crate::memdx::status::Status; pub trait TryFromClientResponse: Sized { - fn try_from(resp: ClientResponse) -> Result; + fn try_from(resp: ClientResponse) -> Result; } #[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)] @@ -21,7 +21,7 @@ pub struct HelloResponse { } impl TryFromClientResponse for HelloResponse { - fn try_from(resp: ClientResponse) -> Result { + fn try_from(resp: ClientResponse) -> Result { let packet = resp.packet(); let status = packet.status; if status != Status::Success { @@ -31,7 +31,7 @@ impl TryFromClientResponse for HelloResponse { let mut features: Vec = Vec::new(); if let Some(value) = &packet.value { if value.len() % 2 != 0 { - return Err(Error::Protocol("invalid hello features length".into())); + return Err(MemdxError::Protocol("invalid hello features length".into())); } let mut cursor = Cursor::new(value); @@ -53,7 +53,7 @@ pub struct GetErrorMapResponse { } impl TryFromClientResponse for GetErrorMapResponse { - fn try_from(resp: ClientResponse) -> Result { + fn try_from(resp: ClientResponse) -> Result { let packet = resp.packet(); let status = packet.status; if status != Status::Success { @@ -72,12 +72,12 @@ impl TryFromClientResponse for GetErrorMapResponse { pub struct SelectBucketResponse {} impl TryFromClientResponse for SelectBucketResponse { - fn try_from(resp: ClientResponse) -> Result { + fn try_from(resp: ClientResponse) -> Result { let packet = resp.packet(); let status = packet.status; if status != Status::Success { if status == Status::AccessError || status == Status::KeyNotFound { - return Err(Error::UnknownBucketName); + return Err(MemdxError::UnknownBucketName); } return Err(OpsCore::decode_error(packet)); } @@ -93,7 +93,7 @@ pub struct SASLAuthResponse { } impl TryFromClientResponse for SASLAuthResponse { - fn try_from(resp: ClientResponse) -> Result { + fn try_from(resp: ClientResponse) -> Result { let packet = resp.packet(); let status = packet.status; if status == Status::SASLAuthContinue { @@ -124,7 +124,7 @@ pub struct SASLStepResponse { } impl TryFromClientResponse for SASLStepResponse { - fn try_from(resp: ClientResponse) -> Result { + fn try_from(resp: ClientResponse) -> Result { let packet = resp.packet(); let status = packet.status; if status != Status::Success { @@ -145,7 +145,7 @@ pub struct SASLListMechsResponse { } impl TryFromClientResponse for SASLListMechsResponse { - fn try_from(resp: ClientResponse) -> Result { + fn try_from(resp: ClientResponse) -> Result { let packet = resp.packet(); let status = packet.status; if status != Status::Success { @@ -154,7 +154,7 @@ impl TryFromClientResponse for SASLListMechsResponse { // ns_server has not posted a configuration for the bucket to kv_engine yet. We // transform this into a ErrTmpFail as we make the assumption that the // SelectBucket will have failed if this was anything but a transient issue. - return Err(Error::ConfigNotSet); + return Err(MemdxError::ConfigNotSet); } return Err(OpsCore::decode_error(packet)); } @@ -164,7 +164,7 @@ impl TryFromClientResponse for SASLListMechsResponse { let mechs_list_string = match String::from_utf8(value) { Ok(v) => v, Err(e) => { - return Err(Error::Protocol(e.to_string())); + return Err(MemdxError::Protocol(e.to_string())); } }; let mechs_list_split = mechs_list_string.split(' '); @@ -201,7 +201,7 @@ impl GetClusterConfigResponse { } impl TryFromClientResponse for GetClusterConfigResponse { - fn try_from(resp: ClientResponse) -> Result { + fn try_from(resp: ClientResponse) -> Result { let packet = resp.packet(); let status = packet.status; if status != Status::Success { @@ -211,7 +211,7 @@ impl TryFromClientResponse for GetClusterConfigResponse { let host = match resp.local_addr() { Some(addr) => addr.ip().to_string(), None => { - return Err(Error::Generic( + return Err(MemdxError::Generic( "Failed to identify memd hostname for $HOST replacement".to_string(), )) } @@ -252,35 +252,35 @@ pub struct SetResponse { } impl TryFromClientResponse for SetResponse { - fn try_from(resp: ClientResponse) -> Result { + fn try_from(resp: ClientResponse) -> Result { let packet = resp.packet(); let status = packet.status; if status == Status::TooBig { - return Err(Error::TooBig); + return Err(MemdxError::TooBig); } else if status == Status::Locked { - return Err(Error::Locked); + return Err(MemdxError::Locked); } else if status == Status::KeyExists { - return Err(Error::KeyExists); + return Err(MemdxError::KeyExists); } else if status != Status::Success { - return Err(Error::Unknown( + return Err(MemdxError::Unknown( OpsCrud::decode_common_error(resp.packet()).to_string(), )); } let mutation_token = if let Some(extras) = &packet.extras { if extras.len() != 16 { - return Err(Error::Protocol("Bad extras length".to_string())); + return Err(MemdxError::Protocol("Bad extras length".to_string())); } let mut extras = Cursor::new(extras); Some(MutationToken { vbuuid: extras .read_u64::() - .map_err(|e| Error::Unknown(e.to_string()))?, + .map_err(|e| MemdxError::Unknown(e.to_string()))?, seqno: extras .read_u64::() - .map_err(|e| Error::Unknown(e.to_string()))?, + .map_err(|e| MemdxError::Unknown(e.to_string()))?, }) } else { None @@ -310,29 +310,29 @@ pub struct GetResponse { } impl TryFromClientResponse for GetResponse { - fn try_from(resp: ClientResponse) -> Result { + fn try_from(resp: ClientResponse) -> Result { let packet = resp.packet(); let status = packet.status; if status == Status::KeyNotFound { - return Err(Error::KeyNotFound); + return Err(MemdxError::KeyNotFound); } else if status != Status::Success { - return Err(Error::Unknown( + return Err(MemdxError::Unknown( OpsCrud::decode_common_error(resp.packet()).to_string(), )); } let flags = if let Some(extras) = &packet.extras { if extras.len() != 4 { - return Err(Error::Protocol("Bad extras length".to_string())); + return Err(MemdxError::Protocol("Bad extras length".to_string())); } let mut extras = Cursor::new(extras); extras .read_u32::() - .map_err(|e| Error::Unknown(e.to_string()))? + .map_err(|e| MemdxError::Unknown(e.to_string()))? } else { - return Err(Error::Protocol("Bad extras length".to_string())); + return Err(MemdxError::Protocol("Bad extras length".to_string())); }; let server_duration = if let Some(f) = &packet.framing_extras { diff --git a/sdk/couchbase-core/src/memdx/sync_helpers.rs b/sdk/couchbase-core/src/memdx/sync_helpers.rs index e5462ed5..38b481b8 100644 --- a/sdk/couchbase-core/src/memdx/sync_helpers.rs +++ b/sdk/couchbase-core/src/memdx/sync_helpers.rs @@ -1,13 +1,13 @@ use std::future::Future; -use crate::memdx::client::Result; +use crate::memdx::client::MemdxResult; use crate::memdx::pendingop::{PendingOp, StandardPendingOp}; use crate::memdx::response::TryFromClientResponse; -pub async fn sync_unary_call(fut: Fut) -> Result +pub async fn sync_unary_call(fut: Fut) -> MemdxResult where RespT: TryFromClientResponse, - Fut: Future>>, + Fut: Future>>, { let mut op = fut.await?; From 7c29068f96dc3c91faf040425313dbc6f455a874 Mon Sep 17 00:00:00 2001 From: Charles Dixon Date: Wed, 17 Jul 2024 16:20:12 +0100 Subject: [PATCH 2/2] Minor refactor to use associated types Motivation ---------- Moving to using associated types in our traits allows us to keep the definition of the traits much tidier and easier to read. Changes ------- Update kvclientpool and kvclientmanager to use associated types. --- sdk/couchbase-core/src/kvclientmanager.rs | 97 +++++++++++++---------- sdk/couchbase-core/src/kvclientpool.rs | 16 ++-- 2 files changed, 66 insertions(+), 47 deletions(-) diff --git a/sdk/couchbase-core/src/kvclientmanager.rs b/sdk/couchbase-core/src/kvclientmanager.rs index e8857791..264f7968 100644 --- a/sdk/couchbase-core/src/kvclientmanager.rs +++ b/sdk/couchbase-core/src/kvclientmanager.rs @@ -1,6 +1,5 @@ use std::collections::HashMap; use std::future::Future; -use std::marker::PhantomData; use std::sync::Arc; use std::time::Duration; @@ -13,7 +12,9 @@ use crate::kvclientpool::{KvClientPool, KvClientPoolConfig, KvClientPoolOptions} use crate::memdx::packet::ResponsePacket; use crate::result::CoreResult; -pub(crate) trait KvClientManager: Sized + Send + Sync { +pub(crate) trait KvClientManager: Sized + Send + Sync { + type Pool: KvClientPool + Send + Sync; + fn new( config: KvClientManagerConfig, opts: KvClientManagerOptions, @@ -22,18 +23,27 @@ pub(crate) trait KvClientManager: Sized + Send + Sync { &self, config: KvClientManagerConfig, ) -> impl Future> + Send; - fn get_client(&self, endpoint: String) -> impl Future>> + Send; - fn get_random_client(&self) -> impl Future>> + Send; + fn get_client( + &self, + endpoint: String, + ) -> impl Future< + Output = CoreResult::Pool as KvClientPool>::Client>>, + > + Send; + fn get_random_client( + &self, + ) -> impl Future< + Output = CoreResult::Pool as KvClientPool>::Client>>, + > + Send; fn shutdown_client( &self, endpoint: String, - client: Arc, + client: Arc<<::Pool as KvClientPool>::Client>, ) -> impl Future> + Send; fn close(&self) -> impl Future> + Send; fn orchestrate_operation( &self, endpoint: String, - operation: impl Fn(Arc) -> Fut, + operation: impl Fn(Arc<<::Pool as KvClientPool>::Client>) -> Fut, ) -> impl Future> where Fut: Future> + Send; @@ -53,35 +63,33 @@ pub(crate) struct KvClientManagerOptions { } #[derive(Debug)] -struct KvClientManagerPool +struct KvClientManagerPool

where - P: KvClientPool, + P: KvClientPool, { config: KvClientPoolConfig, pool: Arc

, - _phantom_client_type: PhantomData, } #[derive(Debug, Default)] -struct KvClientManagerState +struct KvClientManagerState

where - P: KvClientPool, + P: KvClientPool, { - pub client_pools: HashMap>, + pub client_pools: HashMap>, } -pub(crate) struct StdKvClientManager +pub(crate) struct StdKvClientManager

where - P: KvClientPool, + P: KvClientPool, { - state: Mutex>, + state: Mutex>, opts: KvClientManagerOptions, } -impl StdKvClientManager +impl

StdKvClientManager

where - K: KvClient, - P: KvClientPool, + P: KvClientPool, { async fn get_pool(&self, endpoint: String) -> CoreResult> { let state = self.state.lock().await; @@ -107,7 +115,7 @@ where Err(CoreError::Placeholder("Endpoint not known".to_string())) } - async fn create_pool(&self, pool_config: KvClientPoolConfig) -> KvClientManagerPool { + async fn create_pool(&self, pool_config: KvClientPoolConfig) -> KvClientManagerPool

{ let pool = P::new( pool_config.clone(), KvClientPoolOptions { @@ -121,16 +129,16 @@ where KvClientManagerPool { config: pool_config, pool: Arc::new(pool), - _phantom_client_type: Default::default(), } } } -impl KvClientManager for StdKvClientManager +impl

KvClientManager for StdKvClientManager

where - K: KvClient, - P: KvClientPool, + P: KvClientPool, { + type Pool = P; + async fn new(config: KvClientManagerConfig, opts: KvClientManagerOptions) -> CoreResult { let manager = Self { state: Mutex::new(KvClientManagerState { @@ -148,7 +156,7 @@ where let mut old_pools = std::mem::take(&mut guard.client_pools); - let mut new_state = KvClientManagerState:: { + let mut new_state = KvClientManagerState::

{ client_pools: Default::default(), }; @@ -183,19 +191,28 @@ where Ok(()) } - async fn get_client(&self, endpoint: String) -> CoreResult> { + async fn get_client( + &self, + endpoint: String, + ) -> CoreResult::Pool as KvClientPool>::Client>> { let pool = self.get_pool(endpoint).await?; pool.get_client().await } - async fn get_random_client(&self) -> CoreResult> { + async fn get_random_client( + &self, + ) -> CoreResult::Pool as KvClientPool>::Client>> { let pool = self.get_random_pool().await?; pool.get_client().await } - async fn shutdown_client(&self, endpoint: String, client: Arc) -> CoreResult<()> { + async fn shutdown_client( + &self, + endpoint: String, + client: Arc<<::Pool as KvClientPool>::Client>, + ) -> CoreResult<()> { let pool = self.get_pool(endpoint).await?; pool.shutdown_client(client).await; @@ -219,7 +236,7 @@ where async fn orchestrate_operation( &self, endpoint: String, - operation: impl Fn(Arc) -> Fut, + operation: impl Fn(Arc) -> Fut, ) -> CoreResult where Fut: Future> + Send, @@ -316,19 +333,17 @@ mod tests { clients: client_configs, }; - let manager: StdKvClientManager< - NaiveKvClientPool>, - StdKvClient, - > = StdKvClientManager::new( - manger_config, - KvClientManagerOptions { - connect_timeout: Default::default(), - connect_throttle_period: Default::default(), - orphan_handler: Arc::new(orphan_tx), - }, - ) - .await - .unwrap(); + let manager: StdKvClientManager>> = + StdKvClientManager::new( + manger_config, + KvClientManagerOptions { + connect_timeout: Default::default(), + connect_throttle_period: Default::default(), + orphan_handler: Arc::new(orphan_tx), + }, + ) + .await + .unwrap(); let result = manager .orchestrate_operation( diff --git a/sdk/couchbase-core/src/kvclientpool.rs b/sdk/couchbase-core/src/kvclientpool.rs index dfce2b37..a9d2a7b6 100644 --- a/sdk/couchbase-core/src/kvclientpool.rs +++ b/sdk/couchbase-core/src/kvclientpool.rs @@ -14,13 +14,15 @@ use crate::memdx::dispatcher::Dispatcher; use crate::memdx::packet::ResponsePacket; use crate::result::CoreResult; -pub(crate) trait KvClientPool: Sized + Send + Sync { +pub(crate) trait KvClientPool: Sized + Send + Sync { + type Client: KvClient + Send + Sync; + fn new( config: KvClientPoolConfig, opts: KvClientPoolOptions, ) -> impl Future + Send; - fn get_client(&self) -> impl Future>> + Send; - fn shutdown_client(&self, client: Arc) -> impl Future + Send; + fn get_client(&self) -> impl Future>> + Send; + fn shutdown_client(&self, client: Arc) -> impl Future + Send; fn close(&self) -> impl Future> + Send; fn reconfigure( &self, @@ -259,10 +261,12 @@ where } } -impl KvClientPool for NaiveKvClientPool +impl KvClientPool for NaiveKvClientPool where K: KvClient + PartialEq + Sync + Send + 'static, { + type Client = K; + async fn new(config: KvClientPoolConfig, opts: KvClientPoolOptions) -> Self { // TODO: is unbounded the right option? let (on_client_close_tx, mut on_client_close_rx) = mpsc::unbounded_channel(); @@ -285,13 +289,13 @@ where NaiveKvClientPool { inner: clients } } - async fn get_client(&self) -> CoreResult> { + async fn get_client(&self) -> CoreResult> { let mut clients = self.inner.lock().await; clients.get_client().await } - async fn shutdown_client(&self, client: Arc) { + async fn shutdown_client(&self, client: Arc) { let mut clients = self.inner.lock().await; clients.shutdown_client(client).await;