Skip to content

Commit 342c638

Browse files
committed
Fix load balancing strategy for multi database support
1 parent 76ddd16 commit 342c638

File tree

4 files changed

+155
-76
lines changed

4 files changed

+155
-76
lines changed

lib/src/pool.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ pub fn create_pool(config: &Config) -> Result<ConnectionPool> {
6363
&config.tls_config,
6464
)?;
6565
info!(
66-
"creating connection pool with max size {}",
66+
"creating connection pool for node {} with max size {}",
67+
config.uri,
6768
config.max_connections
6869
);
6970
Ok(ConnectionPool::builder(mgr)

lib/src/routing/connection_registry.rs

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@ use crate::routing::{RoutingTable, Server};
55
use crate::{Config, Database, Error};
66
use dashmap::DashMap;
77
use log::debug;
8+
use std::fmt::Debug;
89
use std::sync::{Arc, RwLock};
910
use std::time::Duration;
1011
use tokio::sync::mpsc;
1112
use tokio::sync::mpsc::Sender;
1213

1314
/// Represents a Bolt server, with its address, port and role.
14-
#[derive(Debug, Clone)]
15+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
1516
pub(crate) struct BoltServer {
1617
pub(crate) address: String,
1718
pub(crate) port: u16,
@@ -35,23 +36,12 @@ impl BoltServer {
3536
})
3637
.collect()
3738
}
38-
}
39-
40-
impl Eq for BoltServer {}
41-
42-
impl PartialEq for BoltServer {
43-
fn eq(&self, other: &Self) -> bool {
39+
40+
pub fn has_same_address(&self, other: &Self) -> bool {
4441
self.address == other.address && self.port == other.port
4542
}
4643
}
4744

48-
impl std::hash::Hash for BoltServer {
49-
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
50-
self.address.hash(state);
51-
self.port.hash(state);
52-
}
53-
}
54-
5545
/// A registry of connection pools, indexed by the Bolt server they connect to.
5646
pub type PoolRegistry = DashMap<BoltServer, ConnectionPool>;
5747
/// A map of registries, indexed by the database name.
@@ -138,6 +128,17 @@ async fn refresh_routing_tables(
138128
"No servers available in the routing table".to_string(),
139129
));
140130
}
131+
132+
// purge the pool registry of servers that are no longer in the routing tables
133+
let all_servers: Vec<BoltServer> = connection_registry
134+
.databases
135+
.iter()
136+
.flat_map(|kv| kv.value().clone())
137+
.collect();
138+
connection_registry
139+
.pool_registry
140+
.retain(|server, _| all_servers.contains(server));
141+
141142
Ok(*ttls.iter().min().unwrap())
142143
}
143144

@@ -179,9 +180,9 @@ async fn refresh_routing_table(
179180
})?,
180181
);
181182
}
182-
registry.retain(|k, _| servers.contains(k));
183183
debug!(
184-
"Registry updated. New size is {} with TTL {}s",
184+
"Registry updated for database {}. New size is {} with TTL {}s",
185+
db.as_ref().map_or("default".to_string(), |d| d.to_string()),
185186
registry.len(),
186187
routing_table.ttl
187188
);
@@ -271,7 +272,7 @@ impl ConnectionRegistry {
271272
if let Some(index) = self
272273
.databases
273274
.get(db_name.as_str())
274-
.and_then(|vec| vec.iter().position(|s| s == server))
275+
.and_then(|vec| vec.iter().position(|s| server.has_same_address(s)))
275276
{
276277
debug!("Marking server as available: {:?}", server);
277278
self.databases
@@ -300,6 +301,13 @@ impl ConnectionRegistry {
300301
}
301302
}
302303

304+
pub fn all_servers(&self) -> Vec<BoltServer> {
305+
self.pool_registry
306+
.iter()
307+
.map(|kv| kv.key().clone())
308+
.collect::<Vec<BoltServer>>()
309+
}
310+
303311
fn get_db_name(&self, db: Option<Database>) -> String {
304312
db.as_ref()
305313
.map(|d| d.to_string())
@@ -419,7 +427,7 @@ mod tests {
419427
let servers = registry.servers(None);
420428
assert_eq!(servers.len(), 5);
421429

422-
let strategy = RoundRobinStrategy::default();
430+
let strategy = RoundRobinStrategy::new(registry.clone());
423431
registry.mark_unavailable(BoltServer::resolve(&writers[0]).first().unwrap(), None);
424432
let servers = registry.servers(None);
425433
assert_eq!(servers.len(), 4);
@@ -540,7 +548,7 @@ mod tests {
540548
assert_eq!(servers2.len(), 5); // 2 readers, 2 writers, 1 router
541549
assert_eq!(servers2.first().unwrap().address, "host5");
542550

543-
let strategy = RoundRobinStrategy::default();
551+
let strategy = RoundRobinStrategy::new(registry.clone());
544552
let writer = strategy
545553
.select_writer(&registry.servers(Some("db1".into())))
546554
.unwrap();
Lines changed: 114 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,46 @@
1-
use crate::routing::connection_registry::BoltServer;
1+
use crate::routing::connection_registry::{BoltServer, ConnectionRegistry};
22
use crate::routing::load_balancing::LoadBalancingStrategy;
3-
use std::sync::atomic::AtomicUsize;
3+
use std::sync::atomic::{AtomicUsize, Ordering};
4+
use std::sync::Arc;
45

5-
#[derive(Default)]
66
pub struct RoundRobinStrategy {
7+
connection_registry: Arc<ConnectionRegistry>,
78
reader_index: AtomicUsize,
89
writer_index: AtomicUsize,
910
}
1011

1112
impl RoundRobinStrategy {
12-
fn select(servers: &[BoltServer], index: &AtomicUsize) -> Option<BoltServer> {
13+
pub fn new(connection_registry: Arc<ConnectionRegistry>) -> Self {
14+
RoundRobinStrategy {
15+
connection_registry,
16+
reader_index: AtomicUsize::new(0),
17+
writer_index: AtomicUsize::new(0),
18+
}
19+
}
20+
21+
fn select(all_servers: &[BoltServer], servers: &[BoltServer], index: &AtomicUsize) -> Option<BoltServer> {
1322
if servers.is_empty() {
1423
return None;
1524
}
1625

17-
let _ = index.compare_exchange(
18-
0,
19-
servers.len(),
20-
std::sync::atomic::Ordering::Relaxed,
21-
std::sync::atomic::Ordering::Relaxed,
22-
);
23-
let i = index.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
24-
if let Some(server) = servers.get(i - 1) {
25-
Some(server.clone())
26-
} else {
27-
//reset index
28-
index.store(servers.len(), std::sync::atomic::Ordering::Relaxed);
29-
servers.last().cloned()
26+
let mut used = vec![];
27+
loop {
28+
if used.len() >= all_servers.len() {
29+
return None; // All servers have been used
30+
}
31+
let _ = index.compare_exchange(
32+
0,
33+
all_servers.len(),
34+
Ordering::Relaxed,
35+
Ordering::Relaxed,
36+
);
37+
let i = index.fetch_sub(1, Ordering::Relaxed);
38+
if let Some(server) = all_servers.get(i - 1) {
39+
if servers.contains(server) {
40+
return Some(server.clone());
41+
}
42+
used.push(server.clone());
43+
}
3044
}
3145
}
3246
}
@@ -38,7 +52,13 @@ impl LoadBalancingStrategy for RoundRobinStrategy {
3852
.filter(|s| s.role == "READ")
3953
.cloned()
4054
.collect::<Vec<BoltServer>>();
41-
Self::select(&readers, &self.reader_index)
55+
let all_readers = self.connection_registry.all_servers()
56+
.iter()
57+
.filter(|s| s.role == "READ")
58+
.cloned()
59+
.collect::<Vec<BoltServer>>();
60+
61+
Self::select(&all_readers, &readers, &self.reader_index)
4262
}
4363

4464
fn select_writer(&self, servers: &[BoltServer]) -> Option<BoltServer> {
@@ -47,7 +67,13 @@ impl LoadBalancingStrategy for RoundRobinStrategy {
4767
.filter(|s| s.role == "WRITE")
4868
.cloned()
4969
.collect::<Vec<BoltServer>>();
50-
Self::select(&writers, &self.writer_index)
70+
let all_writers = self.connection_registry.all_servers()
71+
.iter()
72+
.filter(|s| s.role == "WRITE")
73+
.cloned()
74+
.collect::<Vec<BoltServer>>();
75+
76+
Self::select(&all_writers, &writers, &self.writer_index)
5177
}
5278
}
5379

@@ -59,54 +85,94 @@ mod tests {
5985
#[test]
6086
fn should_get_next_server() {
6187
let routers = vec![Server {
62-
addresses: vec!["192.168.0.1:7688".to_string()],
88+
addresses: vec!["server1:7687".to_string()],
89+
role: "ROUTE".to_string(),
90+
}];
91+
let readers1 = vec![Server {
92+
addresses: vec!["server1:7687".to_string()],
93+
role: "READ".to_string(),
94+
}, Server {
95+
addresses: vec!["server2:7687".to_string()],
96+
role: "READ".to_string(),
97+
}];
98+
let writers1 = vec![Server {
99+
addresses: vec!["server4:7687".to_string()],
63100
role: "WRITE".to_string(),
64101
}];
65-
let readers = vec![Server {
66-
addresses: vec![
67-
"192.168.0.2:7687".to_string(),
68-
"192.168.0.3:7687".to_string(),
69-
],
102+
let readers2 = vec![Server {
103+
addresses: vec!["server1:7687".to_string()],
104+
role: "READ".to_string(),
105+
}, Server {
106+
addresses: vec!["server3:7687".to_string()],
70107
role: "READ".to_string(),
71108
}];
72-
let writers = vec![Server {
73-
addresses: vec!["192.168.0.4:7688".to_string()],
109+
110+
let writers2 = vec![Server {
111+
addresses: vec!["server4:7687".to_string()],
74112
role: "WRITE".to_string(),
75113
}];
76114

77-
let cluster_routing_table = RoutingTable {
115+
let routing_table_1 = RoutingTable {
116+
ttl: 300,
117+
db: Some("db-1".into()),
118+
servers: routers
119+
.clone()
120+
.into_iter()
121+
.chain(readers1.clone())
122+
.chain(writers1.clone())
123+
.collect(),
124+
};
125+
let routing_table_2 = RoutingTable {
78126
ttl: 300,
79-
db: Some("neo4j".into()),
127+
db: Some("db-2".into()),
80128
servers: routers
81129
.clone()
82130
.into_iter()
83-
.chain(readers.clone())
84-
.chain(writers.clone())
131+
.chain(readers2.clone())
132+
.chain(writers2.clone())
85133
.collect(),
86134
};
87-
let all_servers = cluster_routing_table.resolve();
88-
assert_eq!(all_servers.len(), 4);
89-
let strategy = RoundRobinStrategy::default();
90135

91-
let reader = strategy.select_reader(&all_servers).unwrap();
92-
assert_eq!(
93-
format!("{}:{}", reader.address, reader.port),
94-
readers[0].addresses[1]
95-
);
96-
let reader = strategy.select_reader(&all_servers).unwrap();
136+
let registry = Arc::new(ConnectionRegistry::default());
137+
138+
let mut servers1 = routing_table_1.resolve();
139+
servers1.retain(|s| s.role == "READ");
140+
let mut servers2 = routing_table_2.resolve();
141+
servers2.retain(|s| s.role == "READ");
142+
143+
let mut all_readers: Vec<BoltServer> = Vec::new();
144+
for s in servers1.iter() {
145+
if !all_readers.iter().any(|x| x == s) {
146+
all_readers.push(s.clone());
147+
}
148+
}
149+
for s in servers2.iter() {
150+
if !all_readers.iter().any(|x| x == s) {
151+
all_readers.push(s.clone());
152+
}
153+
}
154+
all_readers.retain(|s| s.role == "READ");
155+
156+
assert_eq!(all_readers.len(), 3);
157+
let strategy = RoundRobinStrategy::new(registry.clone());
158+
159+
// select a reader for db-1
160+
let reader = RoundRobinStrategy::select(&all_readers, &servers1, &strategy.reader_index).unwrap();
97161
assert_eq!(
98-
format!("{}:{}", reader.address, reader.port),
99-
readers[0].addresses[0]
162+
reader.address,
163+
"server2"
100164
);
101-
let reader = strategy.select_reader(&all_servers).unwrap();
165+
// select a reader for db-2
166+
let reader = RoundRobinStrategy::select(&all_readers, &servers2, &strategy.reader_index).unwrap();
102167
assert_eq!(
103-
format!("{}:{}", reader.address, reader.port),
104-
readers[0].addresses[1]
168+
reader.address,
169+
"server1"
105170
);
106-
let writer = strategy.select_writer(&all_servers).unwrap();
171+
// select another reader for db-1
172+
let reader = RoundRobinStrategy::select(&all_readers, &servers1, &strategy.reader_index).unwrap();
107173
assert_eq!(
108-
format!("{}:{}", writer.address, writer.port),
109-
writers[0].addresses[0]
174+
reader.address,
175+
"server2"
110176
);
111177
}
112178
}

lib/src/routing/routed_connection_manager.rs

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ pub struct RoutedConnectionManager {
2323
channel: Sender<RegistryCommand>,
2424
}
2525

26+
const ROUTING_TABLE_MAX_WAIT_TIME_MS: i32 = 5000;
27+
2628
impl RoutedConnectionManager {
2729
pub fn new(config: &Config, provider: Arc<dyn RoutingTableProvider>) -> Result<Self, Error> {
2830
let backoff = Arc::new(
@@ -37,7 +39,7 @@ impl RoutedConnectionManager {
3739
let connection_registry = Arc::new(ConnectionRegistry::default());
3840
let channel = start_background_updater(config, connection_registry.clone(), provider);
3941
Ok(RoutedConnectionManager {
40-
load_balancing_strategy: Arc::new(RoundRobinStrategy::default()),
42+
load_balancing_strategy: Arc::new(RoundRobinStrategy::new(connection_registry.clone())),
4143
bookmarks: Arc::new(Mutex::new(vec![])),
4244
connection_registry,
4345
backoff,
@@ -76,15 +78,17 @@ impl RoutedConnectionManager {
7678
if servers.is_empty() {
7779
// the first time we need to wait until we get the routing table
7880
tokio::time::sleep(Duration::from_millis(10)).await;
79-
attempts += 1;
80-
if attempts > 500 {
81-
// 5 seconds max wait time
81+
attempts += 10;
82+
if attempts > ROUTING_TABLE_MAX_WAIT_TIME_MS {
83+
// 5 seconds max wait time by default
8284
error!(
83-
"Failed to get a connection after 5 seconds, routing table is still empty"
85+
"Failed to get a connection after {} seconds, routing table is still empty",
86+
ROUTING_TABLE_MAX_WAIT_TIME_MS / 1000
8487
);
85-
return Err(Error::ServerUnavailableError(
86-
"Routing table is still empty after 5 seconds".to_string(),
87-
));
88+
return Err(Error::ServerUnavailableError(format!(
89+
"Routing table is still empty after {} seconds",
90+
ROUTING_TABLE_MAX_WAIT_TIME_MS / 1000
91+
)));
8892
}
8993
continue;
9094
}

0 commit comments

Comments
 (0)