Skip to content

Refresh routing table using a separate thread #216

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions lib/src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
use {
crate::connection::{ConnectionInfo, Routing},
crate::graph::ConnectionPoolManager::Routed,
crate::routing::RoutedConnectionManager,
crate::routing::{ClusterRoutingTableProvider, RoutedConnectionManager},
log::debug,
};

use crate::graph::ConnectionPoolManager::Direct;
Expand Down Expand Up @@ -73,7 +74,11 @@ impl Graph {
&config.tls_config,
)?;
if matches!(info.routing, Routing::Yes(_)) {
let pool = Routed(RoutedConnectionManager::new(&config).await?);
debug!("Routing enabled, creating a routed connection manager");
let pool = Routed(
RoutedConnectionManager::new(&config, Box::new(ClusterRoutingTableProvider))
.await?,
);
Ok(Graph {
config: config.into_live_config(),
pool,
Expand Down
224 changes: 150 additions & 74 deletions lib/src/routing/connection_registry.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
use crate::connection::NeoUrl;
use crate::pool::{create_pool, ConnectionPool};
use crate::routing::{RoutingTable, Server};
use crate::routing::routing_table_provider::RoutingTableProvider;
use crate::routing::Server;
use crate::{Config, Error};
use dashmap::DashMap;
use futures::lock::Mutex;
use log::debug;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Instant;
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::sync::mpsc::Sender;

/// Represents a Bolt server, with its address, port and role.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub(crate) struct BoltServer {
pub(crate) address: String,
Expand Down Expand Up @@ -36,83 +38,129 @@ impl BoltServer {
}
}

/// A registry of connection pools, one per Bolt server, keyed by the server they connect to.
pub type Registry = DashMap<BoltServer, ConnectionPool>;

#[derive(Clone)]
pub(crate) struct ConnectionRegistry {
config: Config,
creation_time: Arc<Mutex<Instant>>,
ttl: Arc<AtomicU64>,
pub(crate) connections: Registry,
}

impl ConnectionRegistry {
pub(crate) fn new(config: &Config) -> Self {
/// Commands accepted by the background routing-table updater task.
// `dead_code` allowed because not every variant is constructed in all builds.
#[allow(dead_code)]
pub(crate) enum RegistryCommand {
    // Force an immediate refresh of the routing table.
    Refresh,
    // Terminate the background updater task.
    Stop,
}

impl Default for ConnectionRegistry {
    /// Creates an empty registry with no per-server pools yet; the background
    /// updater fills it on its first refresh.
    fn default() -> Self {
        ConnectionRegistry {
            // Only the pool map is stored; refresh state lives in the updater task.
            connections: Registry::new(),
        }
    }
}

async fn refresh_routing_table(
config: Config,
registry: Arc<ConnectionRegistry>,
provider: Arc<Box<dyn RoutingTableProvider>>,
) -> Result<u64, Error> {
debug!("Routing table expired or empty, refreshing...");
let routing_table = provider.fetch_routing_table(&config).await?;
debug!("Routing table refreshed: {:?}", routing_table);
let servers = routing_table.resolve();
let url = NeoUrl::parse(config.uri.as_str())?;
// Convert neo4j scheme to bolt scheme to create connection pools.
// We need to use the bolt scheme since we don't want new connections to be routed
let scheme = match url.scheme() {
"neo4j" => "bolt",
"neo4j+s" => "bolt+s",
"neo4j+ssc" => "bolt+ssc",
_ => panic!("Unsupported scheme: {}", url.scheme()),
};

pub(crate) async fn update_if_expired<F, R>(&self, f: F) -> Result<(), Error>
where
F: FnOnce() -> R,
R: std::future::Future<Output = Result<RoutingTable, Error>>,
{
let now = Instant::now();
debug!("Checking if routing table is expired...");
let mut guard = self.creation_time.lock().await;
if self.connections.is_empty()
|| now.duration_since(*guard).as_secs() > self.ttl.load(Ordering::Relaxed)
{
debug!("Routing table expired or empty, refreshing...");
let routing_table = f().await?;
debug!("Routing table refreshed: {:?}", routing_table);
let registry = &self.connections;
let servers = routing_table.resolve();
let url = NeoUrl::parse(self.config.uri.as_str())?;
// Convert neo4j scheme to bolt scheme to create connection pools.
// We need to use the bolt scheme since we don't want new connections to be routed
let scheme = match url.scheme() {
"neo4j" => "bolt",
"neo4j+s" => "bolt+s",
"neo4j+ssc" => "bolt+ssc",
_ => return Err(Error::UnsupportedScheme(url.scheme().to_string())),
};

for server in servers.iter() {
if registry.contains_key(server) {
continue;
for server in servers.iter() {
if registry.connections.contains_key(server) {
continue;
}
let uri = format!("{}://{}:{}", scheme, server.address, server.port);
debug!("Creating pool for server: {}", uri);
registry.connections.insert(
server.clone(),
create_pool(&Config {
uri,
..config.clone()
})
.await?,
);
}
registry.connections.retain(|k, _| servers.contains(k));
debug!(
"Registry updated. New size is {} with TTL {}s",
registry.connections.len(),
routing_table.ttl
);
Ok(routing_table.ttl)
}

pub(crate) async fn start_background_updater(
config: &Config,
registry: Arc<ConnectionRegistry>,
provider: Arc<Box<dyn RoutingTableProvider>>,
) -> Sender<RegistryCommand> {
let config_clone = config.clone();
let (tx, mut rx) = mpsc::channel(1);

// This thread is in charge of refreshing the routing table periodically
tokio::spawn(async move {
let mut ttl =
refresh_routing_table(config_clone.clone(), registry.clone(), provider.clone())
.await
.expect("Failed to get routing table. Exiting...");
debug!("Starting background updater with TTL: {}", ttl);
let mut interval = tokio::time::interval(Duration::from_secs(ttl));
interval.tick().await; // first tick is immediate
loop {
tokio::select! {
// Trigger periodic updates
_ = interval.tick() => {
ttl = match refresh_routing_table(config_clone.clone(), registry.clone(), provider.clone()).await {
Ok(ttl) => ttl,
Err(e) => {
debug!("Failed to refresh routing table: {}", e);
ttl
}
};
interval = tokio::time::interval(Duration::from_secs(ttl)); // recreate interval with the new TTL
}
// Handle forced updates
cmd = rx.recv() => {
match cmd {
Some(RegistryCommand::Refresh) => {
ttl = match refresh_routing_table(config_clone.clone(), registry.clone(), provider.clone()).await {
Ok(ttl) => ttl,
Err(e) => {
debug!("Failed to refresh routing table: {}", e);
ttl
}
};
interval = tokio::time::interval(Duration::from_secs(ttl)); // recreate interval with the new TTL
}
Some(RegistryCommand::Stop) | None => {
debug!("Stopping background updater");
break;
}
}
}
let uri = format!("{}://{}:{}", scheme, server.address, server.port);
debug!("Creating pool for server: {}", uri);
registry.insert(
server.clone(),
create_pool(&Config {
uri,
..self.config.clone()
})
.await?,
);
}
registry.retain(|k, _| servers.contains(k));
let _ = self
.ttl
.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |_ttl| {
Some(routing_table.ttl)
})
.unwrap();
debug!(
"Registry updated. New size is {} with TTL {}s",
registry.len(),
routing_table.ttl
);
*guard = now;

interval.tick().await;
}
Ok(())
}
});
tx
}

impl ConnectionRegistry {
/// Retrieve the pool for a specific server.
pub fn get_pool(&self, server: &BoltServer) -> Option<ConnectionPool> {
self.connections.get(server).map(|entry| entry.clone())
Expand All @@ -135,8 +183,30 @@ mod tests {
use super::*;
use crate::auth::ConnectionTLSConfig;
use crate::routing::load_balancing::LoadBalancingStrategy;
use crate::routing::RoundRobinStrategy;
use crate::routing::Server;
use crate::routing::{RoundRobinStrategy, RoutingTable};
use std::future::Future;
use std::pin::Pin;

// Test double: a provider that always serves one fixed, pre-built routing table.
struct TestRoutingTableProvider {
routing_table: RoutingTable,
}

impl TestRoutingTableProvider {
fn new(routing_table: RoutingTable) -> Self {
TestRoutingTableProvider { routing_table }
}
}

impl RoutingTableProvider for TestRoutingTableProvider {
    /// Hands back a clone of the canned routing table; never fails and
    /// ignores the supplied config.
    fn fetch_routing_table(
        &self,
        _: &Config,
    ) -> Pin<Box<dyn Future<Output = Result<RoutingTable, Error>> + Send>> {
        let table = self.routing_table.clone();
        Box::pin(async move { Ok(table) })
    }
}

#[tokio::test]
async fn test_available_servers() {
Expand Down Expand Up @@ -165,7 +235,7 @@ mod tests {
role: "ROUTE".to_string(),
}];
let cluster_routing_table = RoutingTable {
ttl: 0,
ttl: 300,
db: None,
servers: readers
.clone()
Expand All @@ -183,11 +253,17 @@ mod tests {
fetch_size: 0,
tls_config: ConnectionTLSConfig::None,
};
let registry = ConnectionRegistry::new(&config);
registry
.update_if_expired(|| async { Ok(cluster_routing_table) })
.await
.unwrap();
let registry = Arc::new(ConnectionRegistry::default());
let ttl = refresh_routing_table(
config.clone(),
registry.clone(),
Arc::new(Box::new(TestRoutingTableProvider::new(
cluster_routing_table,
))),
)
.await
.unwrap();
assert_eq!(ttl, 300);
assert_eq!(registry.connections.len(), 5);
let strategy = RoundRobinStrategy::default();
registry.mark_unavailable(BoltServer::resolve(&writers[0]).first().unwrap());
Expand Down
5 changes: 4 additions & 1 deletion lib/src/routing/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
mod connection_registry;
mod load_balancing;
mod routed_connection_manager;
mod routing_table_provider;

use std::fmt::{Display, Formatter};
#[cfg(feature = "unstable-bolt-protocol-impl-v2")]
use {crate::connection::Routing, serde::Deserialize};
Expand Down Expand Up @@ -29,7 +31,7 @@ pub struct Extra {
pub(crate) imp_user: Option<String>,
}

#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Debug, Default, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "unstable-bolt-protocol-impl-v2", derive(Deserialize))]
pub struct RoutingTable {
pub(crate) ttl: u64,
Expand Down Expand Up @@ -163,3 +165,4 @@ use crate::routing::connection_registry::BoltServer;
use crate::{Database, Version};
pub use load_balancing::round_robin_strategy::RoundRobinStrategy;
pub use routed_connection_manager::RoutedConnectionManager;
pub use routing_table_provider::ClusterRoutingTableProvider;
Loading
Loading