@@ -152,15 +152,7 @@ impl ClusterConnection {
152
152
) ) ) ;
153
153
} ;
154
154
155
- let command_name = command_info. name . to_string ( ) ;
156
-
157
- let request_policy = command_info. command_tips . iter ( ) . find_map ( |tip| {
158
- if let CommandTip :: RequestPolicy ( request_policy) = tip {
159
- Some ( request_policy)
160
- } else {
161
- None
162
- }
163
- } ) ;
155
+ let command_name = command_info. name . clone ( ) ;
164
156
165
157
let node_idx = self . get_random_node_index ( ) ;
166
158
let keys = self
@@ -171,6 +163,14 @@ impl ClusterConnection {
171
163
172
164
debug ! ( "[{}] keys: {keys:?}, slots: {slots:?}" , self . tag) ;
173
165
166
+ let request_policy = command_info. command_tips . iter ( ) . find_map ( |tip| {
167
+ if let CommandTip :: RequestPolicy ( request_policy) = tip {
168
+ Some ( request_policy)
169
+ } else {
170
+ None
171
+ }
172
+ } ) ;
173
+
174
174
if let Some ( request_policy) = request_policy {
175
175
match request_policy {
176
176
RequestPolicy :: AllNodes => {
@@ -205,7 +205,7 @@ impl ClusterConnection {
205
205
206
206
pub async fn write_batch (
207
207
& mut self ,
208
- commands : impl Iterator < Item = & mut Command > ,
208
+ commands : SmallVec < [ & mut Command ; 10 ] > ,
209
209
retry_reasons : & [ RetryReason ] ,
210
210
) -> Result < ( ) > {
211
211
if retry_reasons. iter ( ) . any ( |r| {
@@ -231,8 +231,39 @@ impl ClusterConnection {
231
231
} )
232
232
. collect :: < Vec < _ > > ( ) ;
233
233
234
- for command in commands {
235
- self . internal_write ( command, & ask_reasons) . await ?;
234
+ if commands. len ( ) > 1 && commands[ 0 ] . name == "MULTI" {
235
+ let node_idx = self . get_random_node_index ( ) ;
236
+ let keys = self
237
+ . command_info_manager
238
+ . extract_keys ( commands[ 1 ] , & mut self . nodes [ node_idx] . connection )
239
+ . await ?;
240
+ let slots = Self :: hash_slots ( & keys) ;
241
+ if slots. is_empty ( ) || !slots. windows ( 2 ) . all ( |s| s[ 0 ] == s[ 1 ] ) {
242
+ return Err ( Error :: Client ( format ! (
243
+ "[{}] Cannot execute transaction with mismatched key slots" ,
244
+ self . tag
245
+ ) ) ) ;
246
+ }
247
+ let ref_slot = slots[ 0 ] ;
248
+
249
+ for command in commands {
250
+ let keys = self
251
+ . command_info_manager
252
+ . extract_keys ( command, & mut self . nodes [ node_idx] . connection )
253
+ . await ?;
254
+ self . no_request_policy (
255
+ command,
256
+ command. name . to_string ( ) ,
257
+ keys,
258
+ SmallVec :: from_slice ( & [ ref_slot] ) ,
259
+ & ask_reasons,
260
+ )
261
+ . await ?;
262
+ }
263
+ } else {
264
+ for command in commands {
265
+ self . internal_write ( command, & ask_reasons) . await ?;
266
+ }
236
267
}
237
268
238
269
Ok ( ( ) )
@@ -308,7 +339,7 @@ impl ClusterConnection {
308
339
Ok ( ( ) )
309
340
}
310
341
311
- /// The client should execute the command on several shards.
342
+ /// The client should execute the command on multiple shards.
312
343
/// The shards that execute the command are determined by the hash slots of its input key name arguments.
313
344
/// Examples for such commands include MSET, MGET and DEL.
314
345
/// However, note that SUNIONSTORE isn't considered as multi_shard because all of its keys must belong to the same hash slot.
@@ -486,24 +517,28 @@ impl ClusterConnection {
486
517
487
518
let node_id = & self . nodes [ node_idx] . id ;
488
519
489
- let Some ( ( req_idx, sub_req_idx) ) = self
490
- . pending_requests
491
- . iter ( )
492
- . enumerate ( )
493
- . find_map ( |( req_idx, req) | {
494
- let sub_req_idx = req
495
- . sub_requests
496
- . iter ( )
497
- . position ( |sr| sr. node_id == * node_id && sr. result . is_none ( ) ) ?;
498
- Some ( ( req_idx, sub_req_idx) )
499
- } ) else {
500
- log:: error!( "[{}] Received unexpected message: {result:?} from {}" ,
501
- self . tag, self . nodes[ node_idx] . connection. tag( ) ) ;
502
- return Some ( Err ( Error :: Client ( format ! (
503
- "[{}] Received unexpected message" ,
504
- self . tag
505
- ) ) ) ) ;
506
- } ;
520
+ let Some ( ( req_idx, sub_req_idx) ) =
521
+ self . pending_requests
522
+ . iter ( )
523
+ . enumerate ( )
524
+ . find_map ( |( req_idx, req) | {
525
+ let sub_req_idx = req
526
+ . sub_requests
527
+ . iter ( )
528
+ . position ( |sr| sr. node_id == * node_id && sr. result . is_none ( ) ) ?;
529
+ Some ( ( req_idx, sub_req_idx) )
530
+ } )
531
+ else {
532
+ log:: error!(
533
+ "[{}] Received unexpected message: {result:?} from {}" ,
534
+ self . tag,
535
+ self . nodes[ node_idx] . connection. tag( )
536
+ ) ;
537
+ return Some ( Err ( Error :: Client ( format ! (
538
+ "[{}] Received unexpected message" ,
539
+ self . tag
540
+ ) ) ) ) ;
541
+ } ;
507
542
508
543
self . pending_requests [ req_idx] . sub_requests [ sub_req_idx] . result = Some ( result) ;
509
544
trace ! (
@@ -768,7 +803,8 @@ impl ClusterConnection {
768
803
let mut deserializer = RespDeserializer :: new ( resp_buf) ;
769
804
let Ok ( chunks) = deserializer. array_chunks ( ) else {
770
805
return Some ( Err ( Error :: Client ( format ! (
771
- "[{}] Unexpected result {sub_result:?}" , self . tag
806
+ "[{}] Unexpected result {sub_result:?}" ,
807
+ self . tag
772
808
) ) ) ) ;
773
809
} ;
774
810
@@ -795,7 +831,8 @@ impl ClusterConnection {
795
831
let mut deserializer = RespDeserializer :: new ( resp_buf) ;
796
832
let Ok ( chunks) = deserializer. array_chunks ( ) else {
797
833
return Some ( Err ( Error :: Client ( format ! (
798
- "[{}] Unexpected result {sub_result:?}" , self . tag
834
+ "[{}] Unexpected result {sub_result:?}" ,
835
+ self . tag
799
836
) ) ) ) ;
800
837
} ;
801
838
@@ -903,7 +940,8 @@ impl ClusterConnection {
903
940
let mut slot_ranges = Vec :: < SlotRange > :: new ( ) ;
904
941
905
942
for shard_info in shard_info_list. into_iter ( ) {
906
- let Some ( master_info) = shard_info. nodes . into_iter ( ) . find ( |n| n. role == "master" ) else {
943
+ let Some ( master_info) = shard_info. nodes . into_iter ( ) . find ( |n| n. role == "master" )
944
+ else {
907
945
return Err ( Error :: Client ( "Cluster misconfiguration" . to_owned ( ) ) ) ;
908
946
} ;
909
947
let master_id: NodeId = master_info. id . as_str ( ) . into ( ) ;
@@ -1015,7 +1053,8 @@ impl ClusterConnection {
1015
1053
for mut shard_info in shard_info_list {
1016
1054
// ensure that the first node is master
1017
1055
if shard_info. nodes [ 0 ] . role != "master" {
1018
- let Some ( master_idx) = shard_info. nodes . iter ( ) . position ( |n| n. role == "master" ) else {
1056
+ let Some ( master_idx) = shard_info. nodes . iter ( ) . position ( |n| n. role == "master" )
1057
+ else {
1019
1058
return Err ( Error :: Client ( "Cluster misconfiguration" . to_owned ( ) ) ) ;
1020
1059
} ;
1021
1060
0 commit comments