Skip to content

Commit 062ba43

Browse files
authored
Refresh routing table using a separated thread (#216)
* Refresh routing table using a separated thread
* Fix compilation issue
* Stop background task when the refresh command fails
* Handle None in channel
1 parent 392e977 commit 062ba43

File tree

5 files changed

+241
-110
lines changed

5 files changed

+241
-110
lines changed

lib/src/graph.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
use {
33
crate::connection::{ConnectionInfo, Routing},
44
crate::graph::ConnectionPoolManager::Routed,
5-
crate::routing::RoutedConnectionManager,
5+
crate::routing::{ClusterRoutingTableProvider, RoutedConnectionManager},
6+
log::debug,
67
};
78

89
use crate::graph::ConnectionPoolManager::Direct;
@@ -73,7 +74,11 @@ impl Graph {
7374
&config.tls_config,
7475
)?;
7576
if matches!(info.routing, Routing::Yes(_)) {
76-
let pool = Routed(RoutedConnectionManager::new(&config).await?);
77+
debug!("Routing enabled, creating a routed connection manager");
78+
let pool = Routed(
79+
RoutedConnectionManager::new(&config, Box::new(ClusterRoutingTableProvider))
80+
.await?,
81+
);
7782
Ok(Graph {
7883
config: config.into_live_config(),
7984
pool,

lib/src/routing/connection_registry.rs

Lines changed: 150 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
use crate::connection::NeoUrl;
22
use crate::pool::{create_pool, ConnectionPool};
3-
use crate::routing::{RoutingTable, Server};
3+
use crate::routing::routing_table_provider::RoutingTableProvider;
4+
use crate::routing::Server;
45
use crate::{Config, Error};
56
use dashmap::DashMap;
6-
use futures::lock::Mutex;
77
use log::debug;
8-
use std::sync::atomic::{AtomicU64, Ordering};
98
use std::sync::Arc;
10-
use std::time::Instant;
9+
use std::time::Duration;
10+
use tokio::sync::mpsc;
11+
use tokio::sync::mpsc::Sender;
1112

13+
/// Represents a Bolt server, with its address, port and role.
1214
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
1315
pub(crate) struct BoltServer {
1416
pub(crate) address: String,
@@ -36,83 +38,129 @@ impl BoltServer {
3638
}
3739
}
3840

41+
/// A registry of connection pools, indexed by the Bolt server they connect to.
3942
pub type Registry = DashMap<BoltServer, ConnectionPool>;
4043

4144
#[derive(Clone)]
4245
pub(crate) struct ConnectionRegistry {
43-
config: Config,
44-
creation_time: Arc<Mutex<Instant>>,
45-
ttl: Arc<AtomicU64>,
4646
pub(crate) connections: Registry,
4747
}
4848

49-
impl ConnectionRegistry {
50-
pub(crate) fn new(config: &Config) -> Self {
49+
#[allow(dead_code)]
50+
pub(crate) enum RegistryCommand {
51+
Refresh,
52+
Stop,
53+
}
54+
55+
impl Default for ConnectionRegistry {
56+
fn default() -> Self {
5157
ConnectionRegistry {
52-
config: config.clone(),
53-
creation_time: Arc::new(Mutex::new(Instant::now())),
54-
ttl: Arc::new(AtomicU64::new(0)),
55-
connections: DashMap::new(),
58+
connections: Registry::new(),
5659
}
5760
}
61+
}
62+
63+
async fn refresh_routing_table(
64+
config: Config,
65+
registry: Arc<ConnectionRegistry>,
66+
provider: Arc<Box<dyn RoutingTableProvider>>,
67+
) -> Result<u64, Error> {
68+
debug!("Routing table expired or empty, refreshing...");
69+
let routing_table = provider.fetch_routing_table(&config).await?;
70+
debug!("Routing table refreshed: {:?}", routing_table);
71+
let servers = routing_table.resolve();
72+
let url = NeoUrl::parse(config.uri.as_str())?;
73+
// Convert neo4j scheme to bolt scheme to create connection pools.
74+
// We need to use the bolt scheme since we don't want new connections to be routed
75+
let scheme = match url.scheme() {
76+
"neo4j" => "bolt",
77+
"neo4j+s" => "bolt+s",
78+
"neo4j+ssc" => "bolt+ssc",
79+
_ => panic!("Unsupported scheme: {}", url.scheme()),
80+
};
5881

59-
pub(crate) async fn update_if_expired<F, R>(&self, f: F) -> Result<(), Error>
60-
where
61-
F: FnOnce() -> R,
62-
R: std::future::Future<Output = Result<RoutingTable, Error>>,
63-
{
64-
let now = Instant::now();
65-
debug!("Checking if routing table is expired...");
66-
let mut guard = self.creation_time.lock().await;
67-
if self.connections.is_empty()
68-
|| now.duration_since(*guard).as_secs() > self.ttl.load(Ordering::Relaxed)
69-
{
70-
debug!("Routing table expired or empty, refreshing...");
71-
let routing_table = f().await?;
72-
debug!("Routing table refreshed: {:?}", routing_table);
73-
let registry = &self.connections;
74-
let servers = routing_table.resolve();
75-
let url = NeoUrl::parse(self.config.uri.as_str())?;
76-
// Convert neo4j scheme to bolt scheme to create connection pools.
77-
// We need to use the bolt scheme since we don't want new connections to be routed
78-
let scheme = match url.scheme() {
79-
"neo4j" => "bolt",
80-
"neo4j+s" => "bolt+s",
81-
"neo4j+ssc" => "bolt+ssc",
82-
_ => return Err(Error::UnsupportedScheme(url.scheme().to_string())),
83-
};
84-
85-
for server in servers.iter() {
86-
if registry.contains_key(server) {
87-
continue;
82+
for server in servers.iter() {
83+
if registry.connections.contains_key(server) {
84+
continue;
85+
}
86+
let uri = format!("{}://{}:{}", scheme, server.address, server.port);
87+
debug!("Creating pool for server: {}", uri);
88+
registry.connections.insert(
89+
server.clone(),
90+
create_pool(&Config {
91+
uri,
92+
..config.clone()
93+
})
94+
.await?,
95+
);
96+
}
97+
registry.connections.retain(|k, _| servers.contains(k));
98+
debug!(
99+
"Registry updated. New size is {} with TTL {}s",
100+
registry.connections.len(),
101+
routing_table.ttl
102+
);
103+
Ok(routing_table.ttl)
104+
}
105+
106+
pub(crate) async fn start_background_updater(
107+
config: &Config,
108+
registry: Arc<ConnectionRegistry>,
109+
provider: Arc<Box<dyn RoutingTableProvider>>,
110+
) -> Sender<RegistryCommand> {
111+
let config_clone = config.clone();
112+
let (tx, mut rx) = mpsc::channel(1);
113+
114+
// This thread is in charge of refreshing the routing table periodically
115+
tokio::spawn(async move {
116+
let mut ttl =
117+
refresh_routing_table(config_clone.clone(), registry.clone(), provider.clone())
118+
.await
119+
.expect("Failed to get routing table. Exiting...");
120+
debug!("Starting background updater with TTL: {}", ttl);
121+
let mut interval = tokio::time::interval(Duration::from_secs(ttl));
122+
interval.tick().await; // first tick is immediate
123+
loop {
124+
tokio::select! {
125+
// Trigger periodic updates
126+
_ = interval.tick() => {
127+
ttl = match refresh_routing_table(config_clone.clone(), registry.clone(), provider.clone()).await {
128+
Ok(ttl) => ttl,
129+
Err(e) => {
130+
debug!("Failed to refresh routing table: {}", e);
131+
ttl
132+
}
133+
};
134+
interval = tokio::time::interval(Duration::from_secs(ttl)); // recreate interval with the new TTL
135+
}
136+
// Handle forced updates
137+
cmd = rx.recv() => {
138+
match cmd {
139+
Some(RegistryCommand::Refresh) => {
140+
ttl = match refresh_routing_table(config_clone.clone(), registry.clone(), provider.clone()).await {
141+
Ok(ttl) => ttl,
142+
Err(e) => {
143+
debug!("Failed to refresh routing table: {}", e);
144+
ttl
145+
}
146+
};
147+
interval = tokio::time::interval(Duration::from_secs(ttl)); // recreate interval with the new TTL
148+
}
149+
Some(RegistryCommand::Stop) | None => {
150+
debug!("Stopping background updater");
151+
break;
152+
}
153+
}
88154
}
89-
let uri = format!("{}://{}:{}", scheme, server.address, server.port);
90-
debug!("Creating pool for server: {}", uri);
91-
registry.insert(
92-
server.clone(),
93-
create_pool(&Config {
94-
uri,
95-
..self.config.clone()
96-
})
97-
.await?,
98-
);
99155
}
100-
registry.retain(|k, _| servers.contains(k));
101-
let _ = self
102-
.ttl
103-
.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |_ttl| {
104-
Some(routing_table.ttl)
105-
})
106-
.unwrap();
107-
debug!(
108-
"Registry updated. New size is {} with TTL {}s",
109-
registry.len(),
110-
routing_table.ttl
111-
);
112-
*guard = now;
156+
157+
interval.tick().await;
113158
}
114-
Ok(())
115-
}
159+
});
160+
tx
161+
}
162+
163+
impl ConnectionRegistry {
116164
/// Retrieve the pool for a specific server.
117165
pub fn get_pool(&self, server: &BoltServer) -> Option<ConnectionPool> {
118166
self.connections.get(server).map(|entry| entry.clone())
@@ -135,8 +183,30 @@ mod tests {
135183
use super::*;
136184
use crate::auth::ConnectionTLSConfig;
137185
use crate::routing::load_balancing::LoadBalancingStrategy;
138-
use crate::routing::RoundRobinStrategy;
139186
use crate::routing::Server;
187+
use crate::routing::{RoundRobinStrategy, RoutingTable};
188+
use std::future::Future;
189+
use std::pin::Pin;
190+
191+
struct TestRoutingTableProvider {
192+
routing_table: RoutingTable,
193+
}
194+
195+
impl TestRoutingTableProvider {
196+
fn new(routing_table: RoutingTable) -> Self {
197+
TestRoutingTableProvider { routing_table }
198+
}
199+
}
200+
201+
impl RoutingTableProvider for TestRoutingTableProvider {
202+
fn fetch_routing_table(
203+
&self,
204+
_: &Config,
205+
) -> Pin<Box<dyn Future<Output = Result<RoutingTable, Error>> + Send>> {
206+
let routing_table = self.routing_table.clone();
207+
Box::pin(async move { Ok(routing_table) })
208+
}
209+
}
140210

141211
#[tokio::test]
142212
async fn test_available_servers() {
@@ -165,7 +235,7 @@ mod tests {
165235
role: "ROUTE".to_string(),
166236
}];
167237
let cluster_routing_table = RoutingTable {
168-
ttl: 0,
238+
ttl: 300,
169239
db: None,
170240
servers: readers
171241
.clone()
@@ -183,11 +253,17 @@ mod tests {
183253
fetch_size: 0,
184254
tls_config: ConnectionTLSConfig::None,
185255
};
186-
let registry = ConnectionRegistry::new(&config);
187-
registry
188-
.update_if_expired(|| async { Ok(cluster_routing_table) })
189-
.await
190-
.unwrap();
256+
let registry = Arc::new(ConnectionRegistry::default());
257+
let ttl = refresh_routing_table(
258+
config.clone(),
259+
registry.clone(),
260+
Arc::new(Box::new(TestRoutingTableProvider::new(
261+
cluster_routing_table,
262+
))),
263+
)
264+
.await
265+
.unwrap();
266+
assert_eq!(ttl, 300);
191267
assert_eq!(registry.connections.len(), 5);
192268
let strategy = RoundRobinStrategy::default();
193269
registry.mark_unavailable(BoltServer::resolve(&writers[0]).first().unwrap());

lib/src/routing/mod.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
mod connection_registry;
22
mod load_balancing;
33
mod routed_connection_manager;
4+
mod routing_table_provider;
5+
46
use std::fmt::{Display, Formatter};
57
#[cfg(feature = "unstable-bolt-protocol-impl-v2")]
68
use {crate::connection::Routing, serde::Deserialize};
@@ -29,7 +31,7 @@ pub struct Extra {
2931
pub(crate) imp_user: Option<String>,
3032
}
3133

32-
#[derive(Debug, Clone, PartialEq, Eq)]
34+
#[derive(Debug, Default, Clone, PartialEq, Eq)]
3335
#[cfg_attr(feature = "unstable-bolt-protocol-impl-v2", derive(Deserialize))]
3436
pub struct RoutingTable {
3537
pub(crate) ttl: u64,
@@ -163,3 +165,4 @@ use crate::routing::connection_registry::BoltServer;
163165
use crate::{Database, Version};
164166
pub use load_balancing::round_robin_strategy::RoundRobinStrategy;
165167
pub use routed_connection_manager::RoutedConnectionManager;
168+
pub use routing_table_provider::ClusterRoutingTableProvider;

0 commit comments

Comments
 (0)