1
- use crate :: routing:: connection_registry:: BoltServer ;
1
+ use crate :: routing:: connection_registry:: { BoltServer , ConnectionRegistry } ;
2
2
use crate :: routing:: load_balancing:: LoadBalancingStrategy ;
3
- use std:: sync:: atomic:: AtomicUsize ;
3
+ use std:: sync:: atomic:: { AtomicUsize , Ordering } ;
4
+ use std:: sync:: Arc ;
4
5
5
- #[ derive( Default ) ]
6
6
pub struct RoundRobinStrategy {
7
+ connection_registry : Arc < ConnectionRegistry > ,
7
8
reader_index : AtomicUsize ,
8
9
writer_index : AtomicUsize ,
9
10
}
10
11
11
12
impl RoundRobinStrategy {
12
- fn select ( servers : & [ BoltServer ] , index : & AtomicUsize ) -> Option < BoltServer > {
13
+ pub fn new ( connection_registry : Arc < ConnectionRegistry > ) -> Self {
14
+ RoundRobinStrategy {
15
+ connection_registry,
16
+ reader_index : AtomicUsize :: new ( 0 ) ,
17
+ writer_index : AtomicUsize :: new ( 0 ) ,
18
+ }
19
+ }
20
+
21
+ fn select ( all_servers : & [ BoltServer ] , servers : & [ BoltServer ] , index : & AtomicUsize ) -> Option < BoltServer > {
13
22
if servers. is_empty ( ) {
14
23
return None ;
15
24
}
16
25
17
- let _ = index. compare_exchange (
18
- 0 ,
19
- servers. len ( ) ,
20
- std:: sync:: atomic:: Ordering :: Relaxed ,
21
- std:: sync:: atomic:: Ordering :: Relaxed ,
22
- ) ;
23
- let i = index. fetch_sub ( 1 , std:: sync:: atomic:: Ordering :: Relaxed ) ;
24
- if let Some ( server) = servers. get ( i - 1 ) {
25
- Some ( server. clone ( ) )
26
- } else {
27
- //reset index
28
- index. store ( servers. len ( ) , std:: sync:: atomic:: Ordering :: Relaxed ) ;
29
- servers. last ( ) . cloned ( )
26
+ let mut used = vec ! [ ] ;
27
+ loop {
28
+ if used. len ( ) >= all_servers. len ( ) {
29
+ return None ; // All servers have been used
30
+ }
31
+ let _ = index. compare_exchange (
32
+ 0 ,
33
+ all_servers. len ( ) ,
34
+ Ordering :: Relaxed ,
35
+ Ordering :: Relaxed ,
36
+ ) ;
37
+ let i = index. fetch_sub ( 1 , Ordering :: Relaxed ) ;
38
+ if let Some ( server) = all_servers. get ( i - 1 ) {
39
+ if servers. contains ( server) {
40
+ return Some ( server. clone ( ) ) ;
41
+ }
42
+ used. push ( server. clone ( ) ) ;
43
+ }
30
44
}
31
45
}
32
46
}
@@ -38,7 +52,13 @@ impl LoadBalancingStrategy for RoundRobinStrategy {
38
52
. filter ( |s| s. role == "READ" )
39
53
. cloned ( )
40
54
. collect :: < Vec < BoltServer > > ( ) ;
41
- Self :: select ( & readers, & self . reader_index )
55
+ let all_readers = self . connection_registry . all_servers ( )
56
+ . iter ( )
57
+ . filter ( |s| s. role == "READ" )
58
+ . cloned ( )
59
+ . collect :: < Vec < BoltServer > > ( ) ;
60
+
61
+ Self :: select ( & all_readers, & readers, & self . reader_index )
42
62
}
43
63
44
64
fn select_writer ( & self , servers : & [ BoltServer ] ) -> Option < BoltServer > {
@@ -47,7 +67,13 @@ impl LoadBalancingStrategy for RoundRobinStrategy {
47
67
. filter ( |s| s. role == "WRITE" )
48
68
. cloned ( )
49
69
. collect :: < Vec < BoltServer > > ( ) ;
50
- Self :: select ( & writers, & self . writer_index )
70
+ let all_writers = self . connection_registry . all_servers ( )
71
+ . iter ( )
72
+ . filter ( |s| s. role == "WRITE" )
73
+ . cloned ( )
74
+ . collect :: < Vec < BoltServer > > ( ) ;
75
+
76
+ Self :: select ( & all_writers, & writers, & self . writer_index )
51
77
}
52
78
}
53
79
@@ -59,54 +85,94 @@ mod tests {
59
85
#[ test]
60
86
fn should_get_next_server ( ) {
61
87
let routers = vec ! [ Server {
62
- addresses: vec![ "192.168.0.1:7688" . to_string( ) ] ,
88
+ addresses: vec![ "server1:7687" . to_string( ) ] ,
89
+ role: "ROUTE" . to_string( ) ,
90
+ } ] ;
91
+ let readers1 = vec ! [ Server {
92
+ addresses: vec![ "server1:7687" . to_string( ) ] ,
93
+ role: "READ" . to_string( ) ,
94
+ } , Server {
95
+ addresses: vec![ "server2:7687" . to_string( ) ] ,
96
+ role: "READ" . to_string( ) ,
97
+ } ] ;
98
+ let writers1 = vec ! [ Server {
99
+ addresses: vec![ "server4:7687" . to_string( ) ] ,
63
100
role: "WRITE" . to_string( ) ,
64
101
} ] ;
65
- let readers = vec ! [ Server {
66
- addresses: vec![
67
- "192.168.0.2:7687 ". to_string( ) ,
68
- "192.168.0.3:7687" . to_string ( ) ,
69
- ] ,
102
+ let readers2 = vec ! [ Server {
103
+ addresses: vec![ "server1:7687" . to_string ( ) ] ,
104
+ role : "READ ". to_string( ) ,
105
+ } , Server {
106
+ addresses : vec! [ "server3:7687" . to_string ( ) ] ,
70
107
role: "READ" . to_string( ) ,
71
108
} ] ;
72
- let writers = vec ! [ Server {
73
- addresses: vec![ "192.168.0.4:7688" . to_string( ) ] ,
109
+
110
+ let writers2 = vec ! [ Server {
111
+ addresses: vec![ "server4:7687" . to_string( ) ] ,
74
112
role: "WRITE" . to_string( ) ,
75
113
} ] ;
76
114
77
- let cluster_routing_table = RoutingTable {
115
+ let routing_table_1 = RoutingTable {
116
+ ttl : 300 ,
117
+ db : Some ( "db-1" . into ( ) ) ,
118
+ servers : routers
119
+ . clone ( )
120
+ . into_iter ( )
121
+ . chain ( readers1. clone ( ) )
122
+ . chain ( writers1. clone ( ) )
123
+ . collect ( ) ,
124
+ } ;
125
+ let routing_table_2 = RoutingTable {
78
126
ttl : 300 ,
79
- db : Some ( "neo4j " . into ( ) ) ,
127
+ db : Some ( "db-2 " . into ( ) ) ,
80
128
servers : routers
81
129
. clone ( )
82
130
. into_iter ( )
83
- . chain ( readers . clone ( ) )
84
- . chain ( writers . clone ( ) )
131
+ . chain ( readers2 . clone ( ) )
132
+ . chain ( writers2 . clone ( ) )
85
133
. collect ( ) ,
86
134
} ;
87
- let all_servers = cluster_routing_table. resolve ( ) ;
88
- assert_eq ! ( all_servers. len( ) , 4 ) ;
89
- let strategy = RoundRobinStrategy :: default ( ) ;
90
135
91
- let reader = strategy. select_reader ( & all_servers) . unwrap ( ) ;
92
- assert_eq ! (
93
- format!( "{}:{}" , reader. address, reader. port) ,
94
- readers[ 0 ] . addresses[ 1 ]
95
- ) ;
96
- let reader = strategy. select_reader ( & all_servers) . unwrap ( ) ;
136
+ let registry = Arc :: new ( ConnectionRegistry :: default ( ) ) ;
137
+
138
+ let mut servers1 = routing_table_1. resolve ( ) ;
139
+ servers1. retain ( |s| s. role == "READ" ) ;
140
+ let mut servers2 = routing_table_2. resolve ( ) ;
141
+ servers2. retain ( |s| s. role == "READ" ) ;
142
+
143
+ let mut all_readers: Vec < BoltServer > = Vec :: new ( ) ;
144
+ for s in servers1. iter ( ) {
145
+ if !all_readers. iter ( ) . any ( |x| x == s) {
146
+ all_readers. push ( s. clone ( ) ) ;
147
+ }
148
+ }
149
+ for s in servers2. iter ( ) {
150
+ if !all_readers. iter ( ) . any ( |x| x == s) {
151
+ all_readers. push ( s. clone ( ) ) ;
152
+ }
153
+ }
154
+ all_readers. retain ( |s| s. role == "READ" ) ;
155
+
156
+ assert_eq ! ( all_readers. len( ) , 3 ) ;
157
+ let strategy = RoundRobinStrategy :: new ( registry. clone ( ) ) ;
158
+
159
+ // select a reader for db-1
160
+ let reader = RoundRobinStrategy :: select ( & all_readers, & servers1, & strategy. reader_index ) . unwrap ( ) ;
97
161
assert_eq ! (
98
- format! ( "{}:{}" , reader. address, reader . port ) ,
99
- readers [ 0 ] . addresses [ 0 ]
162
+ reader. address,
163
+ "server2"
100
164
) ;
101
- let reader = strategy. select_reader ( & all_servers) . unwrap ( ) ;
165
+ // select a reader for db-2
166
+ let reader = RoundRobinStrategy :: select ( & all_readers, & servers2, & strategy. reader_index ) . unwrap ( ) ;
102
167
assert_eq ! (
103
- format! ( "{}:{}" , reader. address, reader . port ) ,
104
- readers [ 0 ] . addresses [ 1 ]
168
+ reader. address,
169
+ "server1"
105
170
) ;
106
- let writer = strategy. select_writer ( & all_servers) . unwrap ( ) ;
171
+ // select another reader for db-1
172
+ let reader = RoundRobinStrategy :: select ( & all_readers, & servers1, & strategy. reader_index ) . unwrap ( ) ;
107
173
assert_eq ! (
108
- format! ( "{}:{}" , writer . address, writer . port ) ,
109
- writers [ 0 ] . addresses [ 0 ]
174
+ reader . address,
175
+ "server2"
110
176
) ;
111
177
}
112
178
}
0 commit comments