Skip to content

Commit 04b5171

Browse files
committed
Fix cancellation propagation in operators
1 parent e2e5650 commit 04b5171

File tree

12 files changed

+174
-40
lines changed

12 files changed

+174
-40
lines changed

async-enumerable-dotnet-test/FlatMapTest.cs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,5 +185,34 @@ await AsyncEnumerable.Range(1, 5)
185185
.FlatMap<int, int>(v => throw new InvalidOperationException())
186186
.AssertFailure(typeof(InvalidOperationException));
187187
}
188+
189+
[Fact]
190+
public async void NoCancelDelay()
191+
{
192+
var start = DateTime.Now;
193+
try
194+
{
195+
await foreach (var item in async_enumerable_dotnet.AsyncEnumerable.Interval(TimeSpan.Zero, TimeSpan.FromSeconds(10))
196+
.Map(x => async_enumerable_dotnet.AsyncEnumerable.Just(x))
197+
.Merge())
198+
{
199+
Console.WriteLine(item);
200+
201+
throw new Exception("expected");
202+
}
203+
204+
throw new Exception("unexpected");
205+
}
206+
catch (Exception e) when (e.Message == "expected")
207+
{
208+
// expected
209+
}
210+
var end = DateTime.Now;
211+
212+
if (end - start > TimeSpan.FromSeconds(5))
213+
{
214+
Assert.True(false, "Test took too much time");
215+
}
216+
}
188217
}
189218
}

async-enumerable-dotnet-test/SwitchMapTest.cs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
using Xunit;
66
using async_enumerable_dotnet;
77
using System;
8+
using System.Threading.Tasks;
9+
using System.Threading;
10+
using System.Data;
811

912
namespace async_enumerable_dotnet_test
1013
{
@@ -87,5 +90,33 @@ await AsyncEnumerable.Just(AsyncEnumerable.Range(2, 5))
8790
.Switch()
8891
.AssertResult(2, 3, 4, 5, 6);
8992
}
93+
94+
[Fact]
95+
public async void NoCancelDelay()
96+
{
97+
var start = DateTime.Now;
98+
try
99+
{
100+
await foreach (var item in async_enumerable_dotnet.AsyncEnumerable.Interval(TimeSpan.Zero, TimeSpan.FromSeconds(10))
101+
.Map(x => async_enumerable_dotnet.AsyncEnumerable.Just(x))
102+
.Switch())
103+
{
104+
Console.WriteLine(item);
105+
106+
throw new Exception("expected");
107+
}
108+
109+
throw new Exception("unexpected");
110+
} catch (Exception e) when (e.Message == "expected")
111+
{
112+
// expected
113+
}
114+
var end = DateTime.Now;
115+
116+
if (end - start > TimeSpan.FromSeconds(5))
117+
{
118+
Assert.True(false, "Test took too much time");
119+
}
120+
}
90121
}
91122
}

async-enumerable-dotnet/impl/BufferBoundary.cs

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ private sealed class BufferBoundaryExactEnumerator : IAsyncEnumerator<TCollectio
7676
private TCollection _buffer;
7777
private int _size;
7878

79+
private bool _suppressCancel;
80+
7981
public TCollection Current { get; private set; }
8082

8183
public BufferBoundaryExactEnumerator(IAsyncEnumerator<TSource> source,
@@ -172,6 +174,8 @@ public async ValueTask<bool> MoveNextAsync()
172174

173175
public async ValueTask DisposeAsync()
174176
{
177+
_mainCancel.Cancel();
178+
_otherCancel.Cancel();
175179
if (Interlocked.Increment(ref _sourceDisposeWip) == 1)
176180
{
177181
Dispose(_source);
@@ -207,7 +211,16 @@ private void HandleNextSource(Task<bool> t)
207211
{
208212
if (t.IsCanceled)
209213
{
210-
// FIXME ignore???
214+
if (!Volatile.Read(ref _suppressCancel))
215+
{
216+
ExceptionHelper.AddException(ref _error, new OperationCanceledException());
217+
_queue.Enqueue(new Entry
218+
{
219+
Done = true
220+
});
221+
Volatile.Write(ref _suppressCancel, true);
222+
_otherCancel.Cancel();
223+
}
211224
}
212225
else if (t.IsFaulted)
213226
{
@@ -216,6 +229,7 @@ private void HandleNextSource(Task<bool> t)
216229
{
217230
Done = true
218231
});
232+
Volatile.Write(ref _suppressCancel, true);
219233
_otherCancel.Cancel();
220234
}
221235
else if (t.Result)
@@ -231,6 +245,7 @@ private void HandleNextSource(Task<bool> t)
231245
{
232246
Done = true
233247
});
248+
Volatile.Write(ref _suppressCancel, true);
234249
_otherCancel.Cancel();
235250
}
236251
if (TryDisposeSource())
@@ -260,7 +275,16 @@ private void HandleNextOther(Task<bool> t)
260275
{
261276
if (t.IsCanceled)
262277
{
263-
// FIXME ignore???
278+
if (!Volatile.Read(ref _suppressCancel))
279+
{
280+
ExceptionHelper.AddException(ref _error, new OperationCanceledException());
281+
_queue.Enqueue(new Entry
282+
{
283+
Done = true
284+
});
285+
Volatile.Write(ref _suppressCancel, true);
286+
_mainCancel.Cancel();
287+
}
264288
}
265289
else if (t.IsFaulted)
266290
{
@@ -269,6 +293,7 @@ private void HandleNextOther(Task<bool> t)
269293
{
270294
Done = true
271295
});
296+
Volatile.Write(ref _suppressCancel, true);
272297
_mainCancel.Cancel();
273298
}
274299
else if (t.Result)
@@ -284,6 +309,7 @@ private void HandleNextOther(Task<bool> t)
284309
{
285310
Done = true
286311
});
312+
Volatile.Write(ref _suppressCancel, true);
287313
_mainCancel.Cancel();
288314
}
289315
if (TryDisposeOther())

async-enumerable-dotnet/impl/CombineLatest.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,10 @@ private void NextHandler(Task<bool> t)
259259
{
260260
if (t.IsCanceled)
261261
{
262-
// FIXME ignore???
262+
if (TryDispose())
263+
{
264+
_parent.InnerError(_index, new OperationCanceledException());
265+
}
263266
}
264267
else if (t.IsFaulted)
265268
{

async-enumerable-dotnet/impl/ConcatMapEager.cs

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,9 @@ public ConcatMapEager(IAsyncEnumerable<TSource> source, Func<TSource, IAsyncEnum
3030

3131
public IAsyncEnumerator<TResult> GetAsyncEnumerator(CancellationToken cancellationToken)
3232
{
33-
var en = new ConcatMapEagerEnumerator(_source.GetAsyncEnumerator(cancellationToken), _mapper, _maxConcurrency, _prefetch,
34-
cancellationToken);
33+
var sourceCTS = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
34+
var en = new ConcatMapEagerEnumerator(_source.GetAsyncEnumerator(sourceCTS.Token), _mapper, _maxConcurrency, _prefetch,
35+
sourceCTS);
3536
en.MoveNextSource();
3637
return en;
3738
}
@@ -44,7 +45,7 @@ private sealed class ConcatMapEagerEnumerator : IAsyncEnumerator<TResult>
4445

4546
private readonly int _prefetch;
4647

47-
private readonly CancellationToken _ct;
48+
private readonly CancellationTokenSource _sourceCTS;
4849

4950
private int _sourceOutstanding;
5051

@@ -69,7 +70,7 @@ private sealed class ConcatMapEagerEnumerator : IAsyncEnumerator<TResult>
6970

7071
public ConcatMapEagerEnumerator(IAsyncEnumerator<TSource> source,
7172
Func<TSource, IAsyncEnumerable<TResult>> mapper, int maxConcurrency, int prefetch,
72-
CancellationToken ct)
73+
CancellationTokenSource cts)
7374
{
7475
_source = source;
7576
_mapper = mapper;
@@ -78,11 +79,12 @@ public ConcatMapEagerEnumerator(IAsyncEnumerator<TSource> source,
7879
_disposeWip = 1;
7980
_inners = new ConcurrentQueue<InnerHandler>();
8081
_disposeTask = new TaskCompletionSource<bool>();
81-
_ct = ct;
82+
_sourceCTS = cts;
8283
}
8384

8485
public ValueTask DisposeAsync()
8586
{
87+
_sourceCTS.Cancel();
8688
_disposeRequested = true;
8789
if (Interlocked.Increment(ref _sourceDisposeWip) == 1)
8890
{
@@ -175,6 +177,15 @@ private bool TryDispose()
175177

176178
private void NextHandler(Task<bool> t)
177179
{
180+
if (t.IsCanceled)
181+
{
182+
ExceptionHelper.AddException(ref _error, new OperationCanceledException());
183+
_sourceDone = true;
184+
if (TryDispose())
185+
{
186+
ResumeHelper.Resume(ref _resume);
187+
}
188+
} else
178189
if (t.IsFaulted)
179190
{
180191
ExceptionHelper.AddException(ref _error, ExceptionHelper.Extract(t.Exception));
@@ -186,7 +197,7 @@ private void NextHandler(Task<bool> t)
186197
}
187198
else if (t.Result)
188199
{
189-
var cts = CancellationTokenSource.CreateLinkedTokenSource(_ct);
200+
var cts = CancellationTokenSource.CreateLinkedTokenSource(_sourceCTS.Token);
190201
IAsyncEnumerator<TResult> src;
191202
try
192203
{
@@ -312,7 +323,15 @@ private bool TryDispose()
312323

313324
private void InnerNextHandler(Task<bool> t)
314325
{
315-
if (t.IsFaulted)
326+
if (t.IsCanceled)
327+
{
328+
ExceptionHelper.AddException(ref _parent._error, new OperationCanceledException());
329+
Done = true;
330+
if (TryDispose())
331+
{
332+
ResumeHelper.Resume(ref _parent._resume);
333+
}
334+
} else if (t.IsFaulted)
316335
{
317336
ExceptionHelper.AddException(ref _parent._error, ExceptionHelper.Extract(t.Exception));
318337
Done = true;

async-enumerable-dotnet/impl/Debounce.cs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ public Debounce(IAsyncEnumerable<T> source, TimeSpan delay, bool emitLast)
2626

2727
public IAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken)
2828
{
29-
var en = new DebounceEnumerator(_source.GetAsyncEnumerator(cancellationToken), _delay, _emitLast, cancellationToken);
29+
var sourceCTS = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
30+
var en = new DebounceEnumerator(_source.GetAsyncEnumerator(sourceCTS.Token), _delay, _emitLast, sourceCTS);
3031
en.MoveNext();
3132
return en;
3233
}
@@ -39,7 +40,7 @@ private sealed class DebounceEnumerator : IAsyncEnumerator<T>
3940

4041
private readonly bool _emitLast;
4142

42-
private readonly CancellationToken _ct;
43+
private readonly CancellationTokenSource _sourceCTS;
4344

4445
public T Current { get; private set; }
4546

@@ -62,16 +63,17 @@ private sealed class DebounceEnumerator : IAsyncEnumerator<T>
6263

6364
private CancellationTokenSource _cts;
6465

65-
public DebounceEnumerator(IAsyncEnumerator<T> source, TimeSpan delay, bool emitLast, CancellationToken ct)
66+
public DebounceEnumerator(IAsyncEnumerator<T> source, TimeSpan delay, bool emitLast, CancellationTokenSource cts)
6667
{
6768
_source = source;
6869
_delay = delay;
6970
_emitLast = emitLast;
70-
_ct = ct;
71+
_sourceCTS = cts;
7172
}
7273

7374
public ValueTask DisposeAsync()
7475
{
76+
_sourceCTS.Cancel();
7577
CancellationHelper.Cancel(ref _cts);
7678
if (Interlocked.Increment(ref _disposeWip) == 1)
7779
{
@@ -175,7 +177,7 @@ private void HandleMain(Task<bool> t)
175177
_emitLastItem = v;
176178
}
177179
var idx = ++_sourceIndex;
178-
var newCts = CancellationTokenSource.CreateLinkedTokenSource(_ct);
180+
var newCts = CancellationTokenSource.CreateLinkedTokenSource(_sourceCTS.Token);
179181
if (CancellationHelper.Replace(ref _cts, newCts))
180182
{
181183
Task.Delay(_delay, newCts.Token)

async-enumerable-dotnet/impl/FlatMap.cs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ public FlatMap(IAsyncEnumerable<TSource> source, Func<TSource, IAsyncEnumerable<
3030

3131
public IAsyncEnumerator<TResult> GetAsyncEnumerator(CancellationToken cancellationToken)
3232
{
33-
var en = new FlatMapEnumerator(_source.GetAsyncEnumerator(cancellationToken), _mapper, _maxConcurrency, _prefetch, cancellationToken);
33+
var sourceCTS = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
34+
var en = new FlatMapEnumerator(_source.GetAsyncEnumerator(sourceCTS.Token), _mapper, _maxConcurrency, _prefetch, sourceCTS);
3435
en.MoveNext();
3536
return en;
3637
}
@@ -45,7 +46,7 @@ internal sealed class FlatMapEnumerator : IAsyncEnumerator<TResult>
4546

4647
private readonly ConcurrentQueue<Item> _queue;
4748

48-
private readonly CancellationToken _ct;
49+
private readonly CancellationTokenSource _sourceCTS;
4950

5051
private TaskCompletionSource<bool> _resume;
5152

@@ -71,21 +72,22 @@ internal sealed class FlatMapEnumerator : IAsyncEnumerator<TResult>
7172

7273
public FlatMapEnumerator(IAsyncEnumerator<TSource> source, Func<TSource, IAsyncEnumerable<TResult>> mapper,
7374
int maxConcurrency, int prefetch,
74-
CancellationToken ct)
75+
CancellationTokenSource cts)
7576
{
7677
_source = source;
7778
_mapper = mapper;
7879
_queue = new ConcurrentQueue<Item>();
7980
_prefetch = prefetch;
8081
_allDisposeWip = 1; // the main source is one
8182
_allDisposeTask = new TaskCompletionSource<bool>();
82-
_ct = ct;
83+
_sourceCTS = cts;
8384
Volatile.Write(ref _outstanding, maxConcurrency);
8485
Volatile.Write(ref _inners, Empty);
8586
}
8687

8788
public ValueTask DisposeAsync()
8889
{
90+
_sourceCTS.Cancel();
8991
if (Interlocked.Increment(ref _sourceDisposeWip) == 1)
9092
{
9193
Dispose(_source);
@@ -145,7 +147,7 @@ private void Handle(Task<bool> task)
145147

146148
if (TryDispose())
147149
{
148-
var cts = CancellationTokenSource.CreateLinkedTokenSource(_ct);
150+
var cts = CancellationTokenSource.CreateLinkedTokenSource(_sourceCTS.Token);
149151
IAsyncEnumerator<TResult> innerSource;
150152
try
151153
{

async-enumerable-dotnet/impl/Multicast.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ internal MulticastEnumerator(IAsyncEnumerator<TSource> source, IAsyncConsumer<TS
4444

4545
public ValueTask DisposeAsync()
4646
{
47+
_cancelSource.Cancel();
4748
if (Interlocked.Increment(ref _sourceDisposeWip) == 1)
4849
{
4950
Dispose(_source);

0 commit comments

Comments
 (0)