@@ -13,15 +13,23 @@ use async_trait::async_trait;
13
13
use bytes:: BytesMut ;
14
14
use eyre:: Result ;
15
15
use iris_mpc_common:: fast_metrics:: FastHistogram ;
16
- use std:: io;
17
- use std:: { collections:: HashMap , time:: Instant } ;
16
+ use std:: {
17
+ collections:: HashMap ,
18
+ io,
19
+ sync:: {
20
+ atomic:: { AtomicBool , Ordering } ,
21
+ Arc ,
22
+ } ,
23
+ time:: Instant ,
24
+ } ;
18
25
use tokio:: {
19
26
io:: { AsyncReadExt , AsyncWriteExt , BufReader , ReadHalf , WriteHalf } ,
20
27
sync:: {
21
28
mpsc:: { self , error:: TryRecvError , UnboundedReceiver , UnboundedSender } ,
22
- oneshot,
29
+ oneshot, Mutex ,
23
30
} ,
24
31
} ;
32
+ use tokio_util:: sync:: CancellationToken ;
25
33
26
34
// 32 KiB, expressed in bytes.
// NOTE(review): the code that consumes this constant is outside the visible
// chunk; presumably it sizes a per-connection I/O buffer — confirm at the use site.
const BUFFER_CAPACITY : usize = 32 * 1024 ;
27
35
// 2 MiB, expressed in bytes.
// NOTE(review): the code that consumes this constant is outside the visible
// chunk; presumably it sizes the buffer used when reading from a peer — confirm
// at the use site.
const READ_BUF_SIZE : usize = 2 * 1024 * 1024 ;
@@ -69,21 +77,26 @@ impl<T: NetworkConnection + 'static> TcpNetworkHandle<T> {
69
77
reconnector : Reconnector < T > ,
70
78
connections : PeerConnections < T > ,
71
79
config : TcpConfig ,
80
+ ct : CancellationToken ,
72
81
) -> Self {
73
82
let peers = connections. keys ( ) . cloned ( ) . collect ( ) ;
74
83
let mut ch_map = HashMap :: new ( ) ;
84
+ let conn_state = ConnectionState :: new ( ) ;
75
85
for ( peer_id, connections) in connections {
76
86
let mut m = HashMap :: new ( ) ;
77
87
for ( stream_id, connection) in connections {
78
88
let rc = reconnector. clone ( ) ;
79
89
let ( cmd_tx, cmd_rx) = mpsc:: unbounded_channel ( ) ;
80
90
m. insert ( stream_id, cmd_tx) ;
81
91
92
+ let ct2 = ct. clone ( ) ;
82
93
tokio:: spawn ( manage_connection (
83
94
connection,
84
95
rc,
85
96
config. get_sessions_for_stream ( & stream_id) ,
86
97
cmd_rx,
98
+ ct2,
99
+ conn_state. clone ( ) ,
87
100
) ) ;
88
101
}
89
102
ch_map. insert ( peer_id. clone ( ) , m) ;
@@ -220,6 +233,8 @@ async fn manage_connection<T: NetworkConnection>(
220
233
reconnector : Reconnector < T > ,
221
234
num_sessions : usize ,
222
235
mut cmd_ch : UnboundedReceiver < Cmd > ,
236
+ ct : CancellationToken ,
237
+ conn_state : ConnectionState ,
223
238
) {
224
239
let Connection {
225
240
peer,
@@ -257,35 +272,55 @@ async fn manage_connection<T: NetworkConnection>(
257
272
// when new sessions are requested, also tear down and stand up the forwarders.
258
273
loop {
259
274
let r_writer = writer. as_mut ( ) . expect ( "writer should be Some" ) ;
260
- let r_outbound = & mut outbound_rx;
261
- let outbound_task = async move {
262
- let r = handle_outbound_traffic ( r_writer, r_outbound, num_sessions) . await ;
263
- tracing:: warn!( "handle_outbound_traffic exited: {r:?}" ) ;
264
- } ;
275
+ let outbound_task = handle_outbound_traffic ( r_writer, & mut outbound_rx, num_sessions) ;
265
276
266
277
let r_reader = reader. as_mut ( ) . expect ( "reader should be Some" ) ;
267
- let r_inbound = & inbound_forwarder;
268
- let inbound_task = async move {
269
- let r = handle_inbound_traffic ( r_reader, r_inbound) . await ;
270
- tracing:: warn!( "handle_inbound_traffic exited: {r:?}" ) ;
271
- } ;
278
+ let inbound_task = handle_inbound_traffic ( r_reader, & inbound_forwarder) ;
272
279
273
280
enum Evt {
274
281
Cmd ( Cmd ) ,
275
282
Disconnected ,
283
+ Shutdown ,
276
284
}
277
285
let event = tokio:: select! {
278
286
maybe_cmd = cmd_ch. recv( ) => {
279
287
match maybe_cmd {
280
288
Some ( cmd) => Evt :: Cmd ( cmd) ,
281
289
None => {
282
- tracing:: info!( "cmd channel closed" ) ;
290
+ // basically means the networking stack was shut down before a command was received.
291
+ if conn_state. exited( ) {
292
+ tracing:: info!( "cmd channel closed" ) ;
293
+ } ;
294
+ ct. cancel( ) ;
283
295
return ;
284
296
}
285
297
}
286
298
}
287
- _ = inbound_task => Evt :: Disconnected ,
288
- _ = outbound_task => Evt :: Disconnected ,
299
+ r = inbound_task => {
300
+ if ct. is_cancelled( ) {
301
+ Evt :: Shutdown
302
+ } else if let Err ( e) = r {
303
+ if conn_state. incr_reconnect( ) . await {
304
+ tracing:: error!( e=%e, "TCP/TLS connection closed unexpectedly. reconnecting..." ) ;
305
+ }
306
+ Evt :: Disconnected
307
+ } else {
308
+ unreachable!( ) ;
309
+ }
310
+ } ,
311
+ r = outbound_task => {
312
+ if ct. is_cancelled( ) {
313
+ Evt :: Shutdown
314
+ } else if let Err ( e) = r {
315
+ if conn_state. incr_reconnect( ) . await {
316
+ tracing:: error!( e=%e, "TCP/TLS connection closed unexpectedly. reconnecting..." ) ;
317
+ }
318
+ Evt :: Disconnected
319
+ } else {
320
+ unreachable!( ) ;
321
+ }
322
+ } ,
323
+ _ = ct. cancelled( ) => Evt :: Shutdown ,
289
324
} ;
290
325
291
326
// update the Arcs depending on the event. wait for reconnect if needed.
@@ -314,21 +349,35 @@ async fn manage_connection<T: NetworkConnection>(
314
349
)
315
350
. await
316
351
{
317
- tracing:: error !( "reconnect failed: {e:?}" ) ;
352
+ tracing:: debug !( "reconnect failed: {e:?}" ) ;
318
353
return ;
319
354
} ;
320
355
rsp. send ( Ok ( ( ) ) ) . unwrap ( ) ;
321
356
}
322
357
} ,
323
358
Evt :: Disconnected => {
324
- tracing:: info !( "reconnecting to {:?}: {:?}" , peer, stream_id) ;
359
+ tracing:: debug !( "reconnecting to {:?}: {:?}" , peer, stream_id) ;
325
360
if let Err ( e) =
326
361
reconnect_and_replace ( & reconnector, & peer, stream_id, & mut reader, & mut writer)
327
362
. await
328
363
{
329
- tracing:: error!( "reconnect failed: {e:?}" ) ;
364
+ if conn_state. exited ( ) {
365
+ if ct. is_cancelled ( ) {
366
+ tracing:: info!( "shutting down TCP/TLS networking stack" ) ;
367
+ } else {
368
+ tracing:: error!( "reconnect failed: {e:?}" ) ;
369
+ }
370
+ }
330
371
return ;
331
- } ;
372
+ } else if conn_state. decr_reconnect ( ) . await {
373
+ tracing:: info!( "all connections re-established" ) ;
374
+ }
375
+ }
376
+ Evt :: Shutdown => {
377
+ if conn_state. exited ( ) {
378
+ tracing:: info!( "shutting down TCP/TLS networking stack" ) ;
379
+ }
380
+ return ;
332
381
}
333
382
}
334
383
}
@@ -412,17 +461,11 @@ async fn handle_outbound_traffic<T: NetworkConnection>(
412
461
}
413
462
}
414
463
415
- if let Err ( e) = write_buf ( stream, & mut buf) . await {
416
- tracing:: error!( error=%e, "Failed to flush buffer on outbound_rx" ) ;
417
- return Err ( e) ;
418
- }
464
+ write_buf ( stream, & mut buf) . await ?
419
465
}
420
466
421
467
if !buf. is_empty ( ) {
422
- if let Err ( e) = write_buf ( stream, & mut buf) . await {
423
- tracing:: error!( error=%e, "Failed to flush buffer when outbound_rx closed" ) ;
424
- return Err ( e) ;
425
- }
468
+ write_buf ( stream, & mut buf) . await ?
426
469
}
427
470
// the channel will not receive any more commands
428
471
tracing:: debug!( "outbound_rx closed" ) ;
@@ -493,7 +536,10 @@ async fn handle_inbound_traffic<T: NetworkConnection>(
493
536
}
494
537
}
495
538
Err ( e) => {
496
- tracing:: error!( "failed to deserialize message: {e}" ) ;
539
+ return Err ( io:: Error :: new (
540
+ io:: ErrorKind :: Other ,
541
+ format ! ( "failed to deserialize message: {e}" ) ,
542
+ ) ) ;
497
543
}
498
544
} ;
499
545
} else {
@@ -515,3 +561,72 @@ async fn write_buf<T: NetworkConnection>(
515
561
buf. clear ( ) ;
516
562
Ok ( ( ) )
517
563
}
564
+
565
+ // state which is shared by all connections. used to reduce the number of logs
566
+ // emitted upon loss of network connectivity.
567
+ #[ derive( Clone ) ]
568
+ struct ConnectionState {
569
+ inner : Arc < ConnectionStateInner > ,
570
+ }
571
+
572
+ impl ConnectionState {
573
+ fn new ( ) -> Self {
574
+ Self {
575
+ inner : Arc :: new ( ConnectionStateInner :: new ( ) ) ,
576
+ }
577
+ }
578
+
579
+ fn exited ( & self ) -> bool {
580
+ self . inner
581
+ . exited
582
+ . compare_exchange ( false , true , Ordering :: SeqCst , Ordering :: SeqCst )
583
+ . is_ok ( )
584
+ }
585
+
586
+ async fn incr_reconnect ( & self ) -> bool {
587
+ self . inner . reconnecting . increment ( ) . await
588
+ }
589
+
590
+ async fn decr_reconnect ( & self ) -> bool {
591
+ self . inner . reconnecting . decrement ( ) . await
592
+ }
593
+ }
594
+
595
+ struct ConnectionStateInner {
596
+ reconnecting : Counter ,
597
+ exited : AtomicBool ,
598
+ }
599
+
600
+ impl ConnectionStateInner {
601
+ fn new ( ) -> Self {
602
+ Self {
603
+ reconnecting : Counter :: new ( ) ,
604
+ exited : AtomicBool :: new ( false ) ,
605
+ }
606
+ }
607
+ }
608
+
609
/// A shared, saturating counter that reports edge transitions (0 -> 1 and
/// 1 -> 0).
///
/// The value is a plain integer, so it is kept in an atomic rather than behind
/// an async mutex: updates never suspend and never contend on a lock. The
/// methods stay `async` so existing `.await` call sites keep compiling.
struct Counter {
    num: std::sync::atomic::AtomicUsize,
}

impl Counter {
    /// Creates a counter starting at zero.
    fn new() -> Self {
        Self {
            num: std::sync::atomic::AtomicUsize::new(0),
        }
    }

    /// Adds one to the counter.
    ///
    /// Returns `true` if the counter was zero before the increment.
    async fn increment(&self) -> bool {
        // `fetch_add` yields the value held *before* the addition.
        self.num.fetch_add(1, Ordering::SeqCst) == 0
    }

    /// Subtracts one from the counter, saturating at zero.
    ///
    /// Returns `true` if the counter was exactly one before the decrement.
    async fn decrement(&self) -> bool {
        // `checked_sub` refuses to go below zero, in which case `fetch_update`
        // leaves the value untouched and returns `Err(0)`; otherwise it
        // returns `Ok(previous)`.
        self.num
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |n| n.checked_sub(1))
            == Ok(1)
    }
}
0 commit comments