Skip to content

Commit e5a1e0c

Browse files
authored
Better transaction support in cluster mode
- if all the commands of a transaction are not executed on the same node, the transaction will fail cleanly. The test is done in Rustis, before actually sending the commands to the Redis cluster
1 parent a7aafba commit e5a1e0c

File tree

7 files changed

+141
-41
lines changed

7 files changed

+141
-41
lines changed

src/network/cluster_connection.rs

Lines changed: 74 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -152,15 +152,7 @@ impl ClusterConnection {
152152
)));
153153
};
154154

155-
let command_name = command_info.name.to_string();
156-
157-
let request_policy = command_info.command_tips.iter().find_map(|tip| {
158-
if let CommandTip::RequestPolicy(request_policy) = tip {
159-
Some(request_policy)
160-
} else {
161-
None
162-
}
163-
});
155+
let command_name = command_info.name.clone();
164156

165157
let node_idx = self.get_random_node_index();
166158
let keys = self
@@ -171,6 +163,14 @@ impl ClusterConnection {
171163

172164
debug!("[{}] keys: {keys:?}, slots: {slots:?}", self.tag);
173165

166+
let request_policy = command_info.command_tips.iter().find_map(|tip| {
167+
if let CommandTip::RequestPolicy(request_policy) = tip {
168+
Some(request_policy)
169+
} else {
170+
None
171+
}
172+
});
173+
174174
if let Some(request_policy) = request_policy {
175175
match request_policy {
176176
RequestPolicy::AllNodes => {
@@ -205,7 +205,7 @@ impl ClusterConnection {
205205

206206
pub async fn write_batch(
207207
&mut self,
208-
commands: impl Iterator<Item = &mut Command>,
208+
commands: SmallVec<[&mut Command; 10]>,
209209
retry_reasons: &[RetryReason],
210210
) -> Result<()> {
211211
if retry_reasons.iter().any(|r| {
@@ -231,8 +231,39 @@ impl ClusterConnection {
231231
})
232232
.collect::<Vec<_>>();
233233

234-
for command in commands {
235-
self.internal_write(command, &ask_reasons).await?;
234+
if commands.len() > 1 && commands[0].name == "MULTI" {
235+
let node_idx = self.get_random_node_index();
236+
let keys = self
237+
.command_info_manager
238+
.extract_keys(commands[1], &mut self.nodes[node_idx].connection)
239+
.await?;
240+
let slots = Self::hash_slots(&keys);
241+
if slots.is_empty() || !slots.windows(2).all(|s| s[0] == s[1]) {
242+
return Err(Error::Client(format!(
243+
"[{}] Cannot execute transaction with mismatched key slots",
244+
self.tag
245+
)));
246+
}
247+
let ref_slot = slots[0];
248+
249+
for command in commands {
250+
let keys = self
251+
.command_info_manager
252+
.extract_keys(command, &mut self.nodes[node_idx].connection)
253+
.await?;
254+
self.no_request_policy(
255+
command,
256+
command.name.to_string(),
257+
keys,
258+
SmallVec::from_slice(&[ref_slot]),
259+
&ask_reasons,
260+
)
261+
.await?;
262+
}
263+
} else {
264+
for command in commands {
265+
self.internal_write(command, &ask_reasons).await?;
266+
}
236267
}
237268

238269
Ok(())
@@ -308,7 +339,7 @@ impl ClusterConnection {
308339
Ok(())
309340
}
310341

311-
/// The client should execute the command on several shards.
342+
/// The client should execute the command on multiple shards.
312343
/// The shards that execute the command are determined by the hash slots of its input key name arguments.
313344
/// Examples for such commands include MSET, MGET and DEL.
314345
/// However, note that SUNIONSTORE isn't considered as multi_shard because all of its keys must belong to the same hash slot.
@@ -486,24 +517,28 @@ impl ClusterConnection {
486517

487518
let node_id = &self.nodes[node_idx].id;
488519

489-
let Some((req_idx, sub_req_idx)) = self
490-
.pending_requests
491-
.iter()
492-
.enumerate()
493-
.find_map(|(req_idx, req)| {
494-
let sub_req_idx = req
495-
.sub_requests
496-
.iter()
497-
.position(|sr| sr.node_id == *node_id && sr.result.is_none())?;
498-
Some((req_idx, sub_req_idx))
499-
}) else {
500-
log::error!("[{}] Received unexpected message: {result:?} from {}",
501-
self.tag, self.nodes[node_idx].connection.tag());
502-
return Some(Err(Error::Client(format!(
503-
"[{}] Received unexpected message",
504-
self.tag
505-
))));
506-
};
520+
let Some((req_idx, sub_req_idx)) =
521+
self.pending_requests
522+
.iter()
523+
.enumerate()
524+
.find_map(|(req_idx, req)| {
525+
let sub_req_idx = req
526+
.sub_requests
527+
.iter()
528+
.position(|sr| sr.node_id == *node_id && sr.result.is_none())?;
529+
Some((req_idx, sub_req_idx))
530+
})
531+
else {
532+
log::error!(
533+
"[{}] Received unexpected message: {result:?} from {}",
534+
self.tag,
535+
self.nodes[node_idx].connection.tag()
536+
);
537+
return Some(Err(Error::Client(format!(
538+
"[{}] Received unexpected message",
539+
self.tag
540+
))));
541+
};
507542

508543
self.pending_requests[req_idx].sub_requests[sub_req_idx].result = Some(result);
509544
trace!(
@@ -768,7 +803,8 @@ impl ClusterConnection {
768803
let mut deserializer = RespDeserializer::new(resp_buf);
769804
let Ok(chunks) = deserializer.array_chunks() else {
770805
return Some(Err(Error::Client(format!(
771-
"[{}] Unexpected result {sub_result:?}", self.tag
806+
"[{}] Unexpected result {sub_result:?}",
807+
self.tag
772808
))));
773809
};
774810

@@ -795,7 +831,8 @@ impl ClusterConnection {
795831
let mut deserializer = RespDeserializer::new(resp_buf);
796832
let Ok(chunks) = deserializer.array_chunks() else {
797833
return Some(Err(Error::Client(format!(
798-
"[{}] Unexpected result {sub_result:?}", self.tag
834+
"[{}] Unexpected result {sub_result:?}",
835+
self.tag
799836
))));
800837
};
801838

@@ -903,7 +940,8 @@ impl ClusterConnection {
903940
let mut slot_ranges = Vec::<SlotRange>::new();
904941

905942
for shard_info in shard_info_list.into_iter() {
906-
let Some(master_info) = shard_info.nodes.into_iter().find(|n| n.role == "master") else {
943+
let Some(master_info) = shard_info.nodes.into_iter().find(|n| n.role == "master")
944+
else {
907945
return Err(Error::Client("Cluster misconfiguration".to_owned()));
908946
};
909947
let master_id: NodeId = master_info.id.as_str().into();
@@ -1015,7 +1053,8 @@ impl ClusterConnection {
10151053
for mut shard_info in shard_info_list {
10161054
// ensure that the first node is master
10171055
if shard_info.nodes[0].role != "master" {
1018-
let Some(master_idx) = shard_info.nodes.iter().position(|n| n.role == "master") else {
1056+
let Some(master_idx) = shard_info.nodes.iter().position(|n| n.role == "master")
1057+
else {
10191058
return Err(Error::Client("Cluster misconfiguration".to_owned()));
10201059
};
10211060

src/network/connection.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use crate::{
66
StandaloneConnection,
77
};
88
use serde::de::DeserializeOwned;
9+
use smallvec::SmallVec;
910
use std::future::IntoFuture;
1011

1112
pub enum Connection {
@@ -42,7 +43,7 @@ impl Connection {
4243
#[inline]
4344
pub async fn write_batch(
4445
&mut self,
45-
commands: impl Iterator<Item = &mut Command>,
46+
commands: SmallVec::<[&mut Command; 10]>,
4647
retry_reasons: &[RetryReason],
4748
) -> Result<()> {
4849
match self {

src/network/network_handler.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,7 @@ impl NetworkHandler {
328328

329329
if let Err(e) = self
330330
.connection
331-
.write_batch(commands_to_write.into_iter(), &retry_reasons)
331+
.write_batch(commands_to_write, &retry_reasons)
332332
.await
333333
{
334334
error!("[{}] Error while writing batch: {e}", self.tag);

src/network/sentinel_connection.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use crate::{
55
sleep, Error, Result, RetryReason, StandaloneConnection,
66
};
77
use log::debug;
8+
use smallvec::SmallVec;
89

910
pub struct SentinelConnection {
1011
sentinel_config: SentinelConfig,
@@ -21,7 +22,7 @@ impl SentinelConnection {
2122
#[inline]
2223
pub async fn write_batch(
2324
&mut self,
24-
commands: impl Iterator<Item = &mut Command>,
25+
commands: SmallVec::<[&mut Command; 10]>,
2526
retry_reasons: &[RetryReason],
2627
) -> Result<()> {
2728
self.inner_connection

src/network/standalone_connection.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use bytes::BytesMut;
1212
use futures_util::{SinkExt, StreamExt};
1313
use log::{debug, log_enabled, Level};
1414
use serde::de::DeserializeOwned;
15+
use smallvec::SmallVec;
1516
use std::future::IntoFuture;
1617
use tokio::io::AsyncWriteExt;
1718
use tokio_util::codec::{Encoder, FramedRead, FramedWrite};
@@ -99,7 +100,7 @@ impl StandaloneConnection {
99100

100101
pub async fn write_batch(
101102
&mut self,
102-
commands: impl Iterator<Item = &mut Command>,
103+
commands: SmallVec::<[&mut Command; 10]>,
103104
_retry_reasons: &[RetryReason],
104105
) -> Result<()> {
105106
self.buffer.clear();

src/tests/pipeline.rs

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use crate::{
22
client::BatchPreparedCommand,
33
commands::{FlushingMode, ServerCommands, StringCommands},
44
resp::{cmd, Value},
5-
tests::get_test_client,
5+
tests::{get_test_client, get_cluster_test_client},
66
Result,
77
};
88
use serial_test::serial;
@@ -46,3 +46,24 @@ async fn error() -> Result<()> {
4646

4747
Ok(())
4848
}
49+
50+
51+
#[cfg_attr(feature = "tokio-runtime", tokio::test)]
52+
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
53+
#[serial]
54+
async fn pipeline_on_cluster() -> Result<()> {
55+
let client = get_cluster_test_client().await?;
56+
client.flushall(FlushingMode::Sync).await?;
57+
58+
let mut pipeline = client.create_pipeline();
59+
pipeline.set("key1", "value1").forget();
60+
pipeline.set("key2", "value2").forget();
61+
pipeline.get::<_, ()>("key1").queue();
62+
pipeline.get::<_, ()>("key2").queue();
63+
64+
let (value1, value2): (String, String) = pipeline.execute().await?;
65+
assert_eq!("value1", value1);
66+
assert_eq!("value2", value2);
67+
68+
Ok(())
69+
}

src/tests/transaction.rs

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use crate::{
22
client::BatchPreparedCommand,
33
commands::{FlushingMode, ListCommands, ServerCommands, StringCommands, TransactionCommands},
44
resp::cmd,
5-
tests::get_test_client,
5+
tests::{get_test_client, get_cluster_test_client},
66
Error, RedisError, RedisErrorKind, Result,
77
};
88
use serial_test::serial;
@@ -160,3 +160,40 @@ async fn transaction_discard() -> Result<()> {
160160

161161
Ok(())
162162
}
163+
164+
#[cfg_attr(feature = "tokio-runtime", tokio::test)]
165+
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
166+
#[serial]
167+
async fn transaction_on_cluster_connection_with_keys_with_same_slot() -> Result<()> {
168+
let client = get_cluster_test_client().await?;
169+
client.flushall(FlushingMode::Sync).await?;
170+
171+
let mut transaction = client.create_transaction();
172+
173+
transaction.mset([("{hash}key1", "value1"), ("{hash}key2", "value2")]).queue();
174+
transaction.get::<_, String>("{hash}key1").queue();
175+
transaction.get::<_, String>("{hash}key2").queue();
176+
let ((), val1, val2): ((), String, String) = transaction.execute().await.unwrap();
177+
assert_eq!("value1", val1);
178+
assert_eq!("value2", val2);
179+
180+
Ok(())
181+
}
182+
183+
#[cfg_attr(feature = "tokio-runtime", tokio::test)]
184+
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
185+
#[serial]
186+
async fn transaction_on_cluster_connection_with_keys_with_different_slots() -> Result<()> {
187+
let client = get_cluster_test_client().await?;
188+
client.flushall(FlushingMode::Sync).await?;
189+
190+
let mut transaction = client.create_transaction();
191+
192+
transaction.mset([("key1", "value1"), ("key2", "value2")]).queue();
193+
transaction.get::<_, String>("key1").queue();
194+
transaction.get::<_, String>("key2").queue();
195+
let result: Result<((), String, String)> = transaction.execute().await;
196+
assert!(result.is_err());
197+
198+
Ok(())
199+
}

0 commit comments

Comments
 (0)