1
1
use crate :: connection:: NeoUrl ;
2
2
use crate :: pool:: { create_pool, ConnectionPool } ;
3
- use crate :: routing:: { RoutingTable , Server } ;
3
+ use crate :: routing:: routing_table_provider:: RoutingTableProvider ;
4
+ use crate :: routing:: Server ;
4
5
use crate :: { Config , Error } ;
5
6
use dashmap:: DashMap ;
6
- use futures:: lock:: Mutex ;
7
7
use log:: debug;
8
- use std:: sync:: atomic:: { AtomicU64 , Ordering } ;
9
8
use std:: sync:: Arc ;
10
- use std:: time:: Instant ;
9
+ use std:: time:: Duration ;
10
+ use tokio:: sync:: mpsc;
11
+ use tokio:: sync:: mpsc:: Sender ;
11
12
13
+ /// Represents a Bolt server, with its address, port and role.
12
14
#[ derive( Debug , Clone , PartialEq , Eq , Hash ) ]
13
15
pub ( crate ) struct BoltServer {
14
16
pub ( crate ) address : String ,
@@ -36,83 +38,129 @@ impl BoltServer {
36
38
}
37
39
}
38
40
41
+ /// A registry of connection pools, indexed by the Bolt server they connect to.
39
42
pub type Registry = DashMap < BoltServer , ConnectionPool > ;
40
43
41
44
#[ derive( Clone ) ]
42
45
pub ( crate ) struct ConnectionRegistry {
43
- config : Config ,
44
- creation_time : Arc < Mutex < Instant > > ,
45
- ttl : Arc < AtomicU64 > ,
46
46
pub ( crate ) connections : Registry ,
47
47
}
48
48
49
- impl ConnectionRegistry {
50
- pub ( crate ) fn new ( config : & Config ) -> Self {
49
+ #[ allow( dead_code) ]
50
+ pub ( crate ) enum RegistryCommand {
51
+ Refresh ,
52
+ Stop ,
53
+ }
54
+
55
+ impl Default for ConnectionRegistry {
56
+ fn default ( ) -> Self {
51
57
ConnectionRegistry {
52
- config : config. clone ( ) ,
53
- creation_time : Arc :: new ( Mutex :: new ( Instant :: now ( ) ) ) ,
54
- ttl : Arc :: new ( AtomicU64 :: new ( 0 ) ) ,
55
- connections : DashMap :: new ( ) ,
58
+ connections : Registry :: new ( ) ,
56
59
}
57
60
}
61
+ }
62
+
63
+ async fn refresh_routing_table (
64
+ config : Config ,
65
+ registry : Arc < ConnectionRegistry > ,
66
+ provider : Arc < Box < dyn RoutingTableProvider > > ,
67
+ ) -> Result < u64 , Error > {
68
+ debug ! ( "Routing table expired or empty, refreshing..." ) ;
69
+ let routing_table = provider. fetch_routing_table ( & config) . await ?;
70
+ debug ! ( "Routing table refreshed: {:?}" , routing_table) ;
71
+ let servers = routing_table. resolve ( ) ;
72
+ let url = NeoUrl :: parse ( config. uri . as_str ( ) ) ?;
73
+ // Convert neo4j scheme to bolt scheme to create connection pools.
74
+ // We need to use the bolt scheme since we don't want new connections to be routed
75
+ let scheme = match url. scheme ( ) {
76
+ "neo4j" => "bolt" ,
77
+ "neo4j+s" => "bolt+s" ,
78
+ "neo4j+ssc" => "bolt+ssc" ,
79
+ _ => panic ! ( "Unsupported scheme: {}" , url. scheme( ) ) ,
80
+ } ;
58
81
59
- pub ( crate ) async fn update_if_expired < F , R > ( & self , f : F ) -> Result < ( ) , Error >
60
- where
61
- F : FnOnce ( ) -> R ,
62
- R : std:: future:: Future < Output = Result < RoutingTable , Error > > ,
63
- {
64
- let now = Instant :: now ( ) ;
65
- debug ! ( "Checking if routing table is expired..." ) ;
66
- let mut guard = self . creation_time . lock ( ) . await ;
67
- if self . connections . is_empty ( )
68
- || now. duration_since ( * guard) . as_secs ( ) > self . ttl . load ( Ordering :: Relaxed )
69
- {
70
- debug ! ( "Routing table expired or empty, refreshing..." ) ;
71
- let routing_table = f ( ) . await ?;
72
- debug ! ( "Routing table refreshed: {:?}" , routing_table) ;
73
- let registry = & self . connections ;
74
- let servers = routing_table. resolve ( ) ;
75
- let url = NeoUrl :: parse ( self . config . uri . as_str ( ) ) ?;
76
- // Convert neo4j scheme to bolt scheme to create connection pools.
77
- // We need to use the bolt scheme since we don't want new connections to be routed
78
- let scheme = match url. scheme ( ) {
79
- "neo4j" => "bolt" ,
80
- "neo4j+s" => "bolt+s" ,
81
- "neo4j+ssc" => "bolt+ssc" ,
82
- _ => return Err ( Error :: UnsupportedScheme ( url. scheme ( ) . to_string ( ) ) ) ,
83
- } ;
84
-
85
- for server in servers. iter ( ) {
86
- if registry. contains_key ( server) {
87
- continue ;
82
+ for server in servers. iter ( ) {
83
+ if registry. connections . contains_key ( server) {
84
+ continue ;
85
+ }
86
+ let uri = format ! ( "{}://{}:{}" , scheme, server. address, server. port) ;
87
+ debug ! ( "Creating pool for server: {}" , uri) ;
88
+ registry. connections . insert (
89
+ server. clone ( ) ,
90
+ create_pool ( & Config {
91
+ uri,
92
+ ..config. clone ( )
93
+ } )
94
+ . await ?,
95
+ ) ;
96
+ }
97
+ registry. connections . retain ( |k, _| servers. contains ( k) ) ;
98
+ debug ! (
99
+ "Registry updated. New size is {} with TTL {}s" ,
100
+ registry. connections. len( ) ,
101
+ routing_table. ttl
102
+ ) ;
103
+ Ok ( routing_table. ttl )
104
+ }
105
+
106
+ pub ( crate ) async fn start_background_updater (
107
+ config : & Config ,
108
+ registry : Arc < ConnectionRegistry > ,
109
+ provider : Arc < Box < dyn RoutingTableProvider > > ,
110
+ ) -> Sender < RegistryCommand > {
111
+ let config_clone = config. clone ( ) ;
112
+ let ( tx, mut rx) = mpsc:: channel ( 1 ) ;
113
+
114
+ // This thread is in charge of refreshing the routing table periodically
115
+ tokio:: spawn ( async move {
116
+ let mut ttl =
117
+ refresh_routing_table ( config_clone. clone ( ) , registry. clone ( ) , provider. clone ( ) )
118
+ . await
119
+ . expect ( "Failed to get routing table. Exiting..." ) ;
120
+ debug ! ( "Starting background updater with TTL: {}" , ttl) ;
121
+ let mut interval = tokio:: time:: interval ( Duration :: from_secs ( ttl) ) ;
122
+ interval. tick ( ) . await ; // first tick is immediate
123
+ loop {
124
+ tokio:: select! {
125
+ // Trigger periodic updates
126
+ _ = interval. tick( ) => {
127
+ ttl = match refresh_routing_table( config_clone. clone( ) , registry. clone( ) , provider. clone( ) ) . await {
128
+ Ok ( ttl) => ttl,
129
+ Err ( e) => {
130
+ debug!( "Failed to refresh routing table: {}" , e) ;
131
+ ttl
132
+ }
133
+ } ;
134
+ interval = tokio:: time:: interval( Duration :: from_secs( ttl) ) ; // recreate interval with the new TTL
135
+ }
136
+ // Handle forced updates
137
+ cmd = rx. recv( ) => {
138
+ match cmd {
139
+ Some ( RegistryCommand :: Refresh ) => {
140
+ ttl = match refresh_routing_table( config_clone. clone( ) , registry. clone( ) , provider. clone( ) ) . await {
141
+ Ok ( ttl) => ttl,
142
+ Err ( e) => {
143
+ debug!( "Failed to refresh routing table: {}" , e) ;
144
+ ttl
145
+ }
146
+ } ;
147
+ interval = tokio:: time:: interval( Duration :: from_secs( ttl) ) ; // recreate interval with the new TTL
148
+ }
149
+ Some ( RegistryCommand :: Stop ) | None => {
150
+ debug!( "Stopping background updater" ) ;
151
+ break ;
152
+ }
153
+ }
88
154
}
89
- let uri = format ! ( "{}://{}:{}" , scheme, server. address, server. port) ;
90
- debug ! ( "Creating pool for server: {}" , uri) ;
91
- registry. insert (
92
- server. clone ( ) ,
93
- create_pool ( & Config {
94
- uri,
95
- ..self . config . clone ( )
96
- } )
97
- . await ?,
98
- ) ;
99
155
}
100
- registry. retain ( |k, _| servers. contains ( k) ) ;
101
- let _ = self
102
- . ttl
103
- . fetch_update ( Ordering :: Relaxed , Ordering :: Relaxed , |_ttl| {
104
- Some ( routing_table. ttl )
105
- } )
106
- . unwrap ( ) ;
107
- debug ! (
108
- "Registry updated. New size is {} with TTL {}s" ,
109
- registry. len( ) ,
110
- routing_table. ttl
111
- ) ;
112
- * guard = now;
156
+
157
+ interval. tick ( ) . await ;
113
158
}
114
- Ok ( ( ) )
115
- }
159
+ } ) ;
160
+ tx
161
+ }
162
+
163
+ impl ConnectionRegistry {
116
164
/// Retrieve the pool for a specific server.
117
165
pub fn get_pool ( & self , server : & BoltServer ) -> Option < ConnectionPool > {
118
166
self . connections . get ( server) . map ( |entry| entry. clone ( ) )
@@ -135,8 +183,30 @@ mod tests {
135
183
use super :: * ;
136
184
use crate :: auth:: ConnectionTLSConfig ;
137
185
use crate :: routing:: load_balancing:: LoadBalancingStrategy ;
138
- use crate :: routing:: RoundRobinStrategy ;
139
186
use crate :: routing:: Server ;
187
+ use crate :: routing:: { RoundRobinStrategy , RoutingTable } ;
188
+ use std:: future:: Future ;
189
+ use std:: pin:: Pin ;
190
+
191
+ struct TestRoutingTableProvider {
192
+ routing_table : RoutingTable ,
193
+ }
194
+
195
+ impl TestRoutingTableProvider {
196
+ fn new ( routing_table : RoutingTable ) -> Self {
197
+ TestRoutingTableProvider { routing_table }
198
+ }
199
+ }
200
+
201
+ impl RoutingTableProvider for TestRoutingTableProvider {
202
+ fn fetch_routing_table (
203
+ & self ,
204
+ _: & Config ,
205
+ ) -> Pin < Box < dyn Future < Output = Result < RoutingTable , Error > > + Send > > {
206
+ let routing_table = self . routing_table . clone ( ) ;
207
+ Box :: pin ( async move { Ok ( routing_table) } )
208
+ }
209
+ }
140
210
141
211
#[ tokio:: test]
142
212
async fn test_available_servers ( ) {
@@ -165,7 +235,7 @@ mod tests {
165
235
role: "ROUTE" . to_string( ) ,
166
236
} ] ;
167
237
let cluster_routing_table = RoutingTable {
168
- ttl : 0 ,
238
+ ttl : 300 ,
169
239
db : None ,
170
240
servers : readers
171
241
. clone ( )
@@ -183,11 +253,17 @@ mod tests {
183
253
fetch_size : 0 ,
184
254
tls_config : ConnectionTLSConfig :: None ,
185
255
} ;
186
- let registry = ConnectionRegistry :: new ( & config) ;
187
- registry
188
- . update_if_expired ( || async { Ok ( cluster_routing_table) } )
189
- . await
190
- . unwrap ( ) ;
256
+ let registry = Arc :: new ( ConnectionRegistry :: default ( ) ) ;
257
+ let ttl = refresh_routing_table (
258
+ config. clone ( ) ,
259
+ registry. clone ( ) ,
260
+ Arc :: new ( Box :: new ( TestRoutingTableProvider :: new (
261
+ cluster_routing_table,
262
+ ) ) ) ,
263
+ )
264
+ . await
265
+ . unwrap ( ) ;
266
+ assert_eq ! ( ttl, 300 ) ;
191
267
assert_eq ! ( registry. connections. len( ) , 5 ) ;
192
268
let strategy = RoundRobinStrategy :: default ( ) ;
193
269
registry. mark_unavailable ( BoltServer :: resolve ( & writers[ 0 ] ) . first ( ) . unwrap ( ) ) ;
0 commit comments