Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
namespace NServiceBus.Transport.RabbitMQ.Tests.ConnectionString
{
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using global::RabbitMQ.Client;
using global::RabbitMQ.Client.Events;
using NUnit.Framework;

[TestFixture]
public class ChannelProviderTests
{
[Test]
public async Task Should_recover_connection_and_dispose_old_one_when_connection_shutdown()
{
var channelProvider = new TestableChannelProvider();
channelProvider.CreateConnection();

var publishConnection = channelProvider.PublishConnections.Dequeue();
publishConnection.RaiseConnectionShutdown(new ShutdownEventArgs(ShutdownInitiator.Library, 0, "Test"));

channelProvider.DelayTaskCompletionSource.SetResult();

await channelProvider.FireAndForgetAction(CancellationToken.None);

var recoveredConnection = channelProvider.PublishConnections.Dequeue();

Assert.That(publishConnection.WasDisposed, Is.True);
Assert.That(recoveredConnection.WasDisposed, Is.False);
}

[Test]
public void Should_dispose_connection_when_disposed()
{
var channelProvider = new TestableChannelProvider();
channelProvider.CreateConnection();

var publishConnection = channelProvider.PublishConnections.Dequeue();
channelProvider.Dispose();

Assert.That(publishConnection.WasDisposed, Is.True);
}

[Test]
public async Task Should_not_attempt_to_recover_during_dispose_when_retry_delay_still_pending()
{
var channelProvider = new TestableChannelProvider();
channelProvider.CreateConnection();

var publishConnection = channelProvider.PublishConnections.Dequeue();
publishConnection.RaiseConnectionShutdown(new ShutdownEventArgs(ShutdownInitiator.Library, 0, "Test"));

// Deliberately not completing the delay task with channelProvider.DelayTaskCompletionSource.SetResult(); before disposing
// to simulate a pending delay task
channelProvider.Dispose();

await channelProvider.FireAndForgetAction(CancellationToken.None);

Assert.That(publishConnection.WasDisposed, Is.True);
Assert.That(channelProvider.PublishConnections.TryDequeue(out _), Is.False);
}

[Test]
public async Task Should_dispose_newly_established_connection()
{
var channelProvider = new TestableChannelProvider();
channelProvider.CreateConnection();

var publishConnection = channelProvider.PublishConnections.Dequeue();
publishConnection.RaiseConnectionShutdown(new ShutdownEventArgs(ShutdownInitiator.Library, 0, "Test"));

// This simulates the race of the reconnection loop being fired off with the delay task completed during
// the disposal of the channel provider. To achieve that it is necessary to kick off the reconnection loop
// and await its completion after the channel provider has been disposed.
var fireAndForgetTask = channelProvider.FireAndForgetAction(CancellationToken.None);
channelProvider.DelayTaskCompletionSource.SetResult();
channelProvider.Dispose();

await fireAndForgetTask;

var recoveredConnection = channelProvider.PublishConnections.Dequeue();

Assert.That(publishConnection.WasDisposed, Is.True);
Assert.That(recoveredConnection.WasDisposed, Is.True);
}

class TestableChannelProvider() : ChannelProvider(null!, TimeSpan.Zero, null!)
{
public Queue<FakeConnection> PublishConnections { get; } = new();

public TaskCompletionSource DelayTaskCompletionSource { get; } = new(TaskCreationOptions.RunContinuationsAsynchronously);

public Func<CancellationToken, Task> FireAndForgetAction { get; private set; }

protected override IConnection CreatePublishConnection()
{
var connection = new FakeConnection();
PublishConnections.Enqueue(connection);
return connection;
}

protected override void FireAndForget(Func<CancellationToken, Task> action, CancellationToken cancellationToken = default)
=> FireAndForgetAction = _ => action(cancellationToken);

protected override async Task DelayReconnect(CancellationToken cancellationToken = default)
{
await using var _ = cancellationToken.Register(() => DelayTaskCompletionSource.TrySetCanceled(cancellationToken));
await DelayTaskCompletionSource.Task;
}
}

class FakeConnection : IConnection
{
public int LocalPort { get; }
public int RemotePort { get; }

public void Dispose() => WasDisposed = true;

public bool WasDisposed { get; private set; }

public void UpdateSecret(string newSecret, string reason) => throw new NotImplementedException();

public void Abort() => throw new NotImplementedException();

public void Abort(ushort reasonCode, string reasonText) => throw new NotImplementedException();

public void Abort(TimeSpan timeout) => throw new NotImplementedException();

public void Abort(ushort reasonCode, string reasonText, TimeSpan timeout) => throw new NotImplementedException();

public void Close() => throw new NotImplementedException();

public void Close(ushort reasonCode, string reasonText) => throw new NotImplementedException();

public void Close(TimeSpan timeout) => throw new NotImplementedException();

public void Close(ushort reasonCode, string reasonText, TimeSpan timeout) => throw new NotImplementedException();

public IModel CreateModel() => throw new NotImplementedException();

public void HandleConnectionBlocked(string reason) => throw new NotImplementedException();

public void HandleConnectionUnblocked() => throw new NotImplementedException();

public ushort ChannelMax { get; }
public IDictionary<string, object> ClientProperties { get; }
public ShutdownEventArgs CloseReason { get; }
public AmqpTcpEndpoint Endpoint { get; }
public uint FrameMax { get; }
public TimeSpan Heartbeat { get; }
public bool IsOpen { get; }
public AmqpTcpEndpoint[] KnownHosts { get; }
public IProtocol Protocol { get; }
public IDictionary<string, object> ServerProperties { get; }
public IList<ShutdownReportEntry> ShutdownReport { get; }
public string ClientProvidedName { get; } = $"FakeConnection{Interlocked.Increment(ref connectionCounter)}";
public event EventHandler<CallbackExceptionEventArgs> CallbackException = (_, _) => { };
public event EventHandler<ConnectionBlockedEventArgs> ConnectionBlocked = (_, _) => { };
public event EventHandler<ShutdownEventArgs> ConnectionShutdown = (_, _) => { };
public event EventHandler<EventArgs> ConnectionUnblocked = (_, _) => { };

public void RaiseConnectionShutdown(ShutdownEventArgs args) => ConnectionShutdown?.Invoke(this, args);

static int connectionCounter;
}
}
}
79 changes: 60 additions & 19 deletions src/NServiceBus.Transport.RabbitMQ/Connection/ChannelProvider.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#nullable enable

namespace NServiceBus.Transport.RabbitMQ
{
using System;
Expand All @@ -7,7 +9,7 @@ namespace NServiceBus.Transport.RabbitMQ
using global::RabbitMQ.Client;
using Logging;

sealed class ChannelProvider : IDisposable
class ChannelProvider : IDisposable
{
public ChannelProvider(ConnectionFactory connectionFactory, TimeSpan retryDelay, IRoutingTopology routingTopology)
{
Expand All @@ -19,36 +21,56 @@ public ChannelProvider(ConnectionFactory connectionFactory, TimeSpan retryDelay,
channels = new ConcurrentQueue<ConfirmsAwareChannel>();
}

public void CreateConnection()
public void CreateConnection() => connection = CreateConnectionWithShutdownListener();

protected virtual IConnection CreatePublishConnection() => connectionFactory.CreatePublishConnection();

IConnection CreateConnectionWithShutdownListener()
{
connection = connectionFactory.CreatePublishConnection();
connection.ConnectionShutdown += Connection_ConnectionShutdown;
var newConnection = CreatePublishConnection();
newConnection.ConnectionShutdown += Connection_ConnectionShutdown;
return newConnection;
}

void Connection_ConnectionShutdown(object sender, ShutdownEventArgs e)
void Connection_ConnectionShutdown(object? sender, ShutdownEventArgs e)
{
if (e.Initiator != ShutdownInitiator.Application)
if (e.Initiator == ShutdownInitiator.Application || sender is null)
{
var connection = (IConnection)sender;

// Task.Run() so the call returns immediately instead of waiting for the first await or return down the call stack
_ = Task.Run(() => ReconnectSwallowingExceptions(connection.ClientProvidedName), CancellationToken.None);
return;
}

var connectionThatWasShutdown = (IConnection)sender;

FireAndForget(cancellationToken => ReconnectSwallowingExceptions(connectionThatWasShutdown.ClientProvidedName, cancellationToken), stoppingTokenSource.Token);
}

#pragma warning disable PS0018 // A task-returning method should have a CancellationToken parameter unless it has a parameter implementing ICancellableContext
async Task ReconnectSwallowingExceptions(string connectionName)
#pragma warning restore PS0018 // A task-returning method should have a CancellationToken parameter unless it has a parameter implementing ICancellableContext
async Task ReconnectSwallowingExceptions(string connectionName, CancellationToken cancellationToken)
{
while (true)
while (!cancellationToken.IsCancellationRequested)
{
Logger.InfoFormat("'{0}': Attempting to reconnect in {1} seconds.", connectionName, retryDelay.TotalSeconds);

await Task.Delay(retryDelay).ConfigureAwait(false);

try
{
CreateConnection();
await DelayReconnect(cancellationToken).ConfigureAwait(false);

var newConnection = CreateConnectionWithShutdownListener();

// A race condition is possible where CreatePublishConnection is invoked during Dispose
// where the returned connection isn't disposed so invoking Dispose to be sure
if (cancellationToken.IsCancellationRequested)
{
newConnection.Dispose();
break;
}

var oldConnection = Interlocked.Exchange(ref connection, newConnection);
oldConnection?.Dispose();
break;
}
catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
{
Logger.InfoFormat("'{0}': Stopped trying to reconnecting to the broker due to shutdown", connectionName);
break;
}
catch (Exception ex)
Expand All @@ -60,6 +82,12 @@ async Task ReconnectSwallowingExceptions(string connectionName)
Logger.InfoFormat("'{0}': Connection to the broker reestablished successfully.", connectionName);
}

protected virtual void FireAndForget(Func<CancellationToken, Task> action, CancellationToken cancellationToken = default) =>
// Task.Run() so the call returns immediately instead of waiting for the first await or return down the call stack
_ = Task.Run(() => action(cancellationToken), CancellationToken.None);

protected virtual Task DelayReconnect(CancellationToken cancellationToken = default) => Task.Delay(retryDelay, cancellationToken);

public ConfirmsAwareChannel GetPublishChannel()
{
if (!channels.TryDequeue(out var channel) || channel.IsClosed)
Expand All @@ -86,19 +114,32 @@ public void ReturnPublishChannel(ConfirmsAwareChannel channel)

public void Dispose()
{
connection?.Dispose();
if (disposed)
{
return;
}

stoppingTokenSource.Cancel();
stoppingTokenSource.Dispose();

var oldConnection = Interlocked.Exchange(ref connection, null);
oldConnection?.Dispose();

foreach (var channel in channels)
{
channel.Dispose();
}

disposed = true;
}

readonly ConnectionFactory connectionFactory;
readonly TimeSpan retryDelay;
readonly IRoutingTopology routingTopology;
readonly ConcurrentQueue<ConfirmsAwareChannel> channels;
IConnection connection;
readonly CancellationTokenSource stoppingTokenSource = new();
volatile IConnection? connection;
bool disposed;

static readonly ILog Logger = LogManager.GetLogger(typeof(ChannelProvider));
}
Expand Down