diff --git a/KMPNativeCoroutinesAsync/AsyncSequence.swift b/KMPNativeCoroutinesAsync/AsyncSequence.swift index fe31ee8a..1e658ca3 100644 --- a/KMPNativeCoroutinesAsync/AsyncSequence.swift +++ b/KMPNativeCoroutinesAsync/AsyncSequence.swift @@ -24,50 +24,83 @@ public struct NativeFlowAsyncSequence: AsyncSequen public class Iterator: AsyncIteratorProtocol, @unchecked Sendable { + private enum State { + case new(NativeFlow) + case producing(UnsafeContinuation) + case consuming(() -> Unit) + case completed(Failure?) + case cancelled + } + private let semaphore = DispatchSemaphore(value: 1) + private var state: State private var nativeCancellable: NativeCancellable? - private var item: (Output, () -> Unit)? = nil - private var result: Failure?? = Optional.none - private var cancellationError: Failure? = nil - private var continuation: UnsafeContinuation? = nil - init(nativeFlow: NativeFlow) { - nativeCancellable = nativeFlow({ item, next, unit in - self.semaphore.wait() - defer { self.semaphore.signal() } - if let continuation = self.continuation { - continuation.resume(returning: item) - self.continuation = nil - return next() - } else { - self.item = (item, next) - return unit - } - }, { error, unit in - self.semaphore.wait() - defer { self.semaphore.signal() } - self.result = Optional.some(error) - if let continuation = self.continuation { - if let error = error { - continuation.resume(throwing: error) - } else { - continuation.resume(returning: nil) - } - self.continuation = nil - } - self.nativeCancellable = nil + init(nativeFlow: @escaping NativeFlow) { + state = .new(nativeFlow) + } + + private func onItem(item: Output, next: @escaping () -> Unit, unit: Unit) -> Unit { + semaphore.wait() + defer { semaphore.signal() } + switch state { + case .new: + fatalError("onItem can't be called while in state new") + case .producing(let continuation): + continuation.resume(returning: item) + state = .consuming(next) return unit - }, { cancellationError, unit in - self.semaphore.wait() - defer { self.semaphore.signal() } - self.cancellationError = cancellationError - if let continuation = self.continuation { + case .consuming: + fatalError("onItem can't be called while in state consuming") + case .completed: + fatalError("onItem can't be called while in state completed") + case .cancelled: + fatalError("onItem can't be called while in state cancelled") + } + } + + private func onComplete(error: Failure?, unit: Unit) -> Unit { + semaphore.wait() + defer { semaphore.signal() } + switch state { + case .new: + fatalError("onComplete can't be called while in state new") + case .producing(let continuation): + if let error { + continuation.resume(throwing: error) + } else { continuation.resume(returning: nil) - self.continuation = nil } - self.nativeCancellable = nil + state = .completed(error) + return unit + case .consuming: + state = .completed(error) + return unit + case .completed: + return unit + case .cancelled: + return unit + } + } + + private func onCancelled(error: Failure, unit: Unit) -> Unit { + semaphore.wait() + defer { semaphore.signal() } + switch state { + case .new: + fatalError("onCancelled can't be called while in state new") + case .producing(let continuation): + continuation.resume(throwing: CancellationError()) + state = .cancelled return unit - }) + case .consuming: + state = .cancelled + return unit + case .completed: + return unit + case .cancelled: + return unit + } } public func next() async throws -> Output? { @@ -75,28 +108,34 @@ public struct NativeFlowAsyncSequence: AsyncSequen try await withUnsafeThrowingContinuation { continuation in self.semaphore.wait() defer { self.semaphore.signal() } - if let (item, next) = self.item { - continuation.resume(returning: item) + switch state { + case .new(let nativeFlow): + nativeCancellable = nativeFlow(onItem, onComplete, onCancelled) + state = .producing(continuation) + case .producing: + fatalError("Concurrent calls to next aren't supported") + case .consuming(let next): _ = next() - self.item = nil - } else if let result = self.result { - if let error = result { + state = .producing(continuation) + case .completed(let error): + if let error { continuation.resume(throwing: error) } else { continuation.resume(returning: nil) } - } else if self.cancellationError != nil { + case .cancelled: continuation.resume(throwing: CancellationError()) - } else { - guard self.continuation == nil else { - fatalError("Concurrent calls to next aren't supported") - } - self.continuation = continuation } } } onCancel: { + self.semaphore.wait() + if case .new = state { + state = .cancelled + } + let nativeCancellable = self.nativeCancellable + self.nativeCancellable = nil + self.semaphore.signal() _ = nativeCancellable?() - nativeCancellable = nil } } } diff --git a/KMPNativeCoroutinesAsyncTests/AsyncSequenceTests.swift b/KMPNativeCoroutinesAsyncTests/AsyncSequenceTests.swift index 320d3f43..1e25fe42 100644 --- a/KMPNativeCoroutinesAsyncTests/AsyncSequenceTests.swift +++ b/KMPNativeCoroutinesAsyncTests/AsyncSequenceTests.swift @@ -13,7 +13,7 @@ class AsyncSequenceTests: XCTestCase { private class TestValue { } - func testCancellableInvoked() async { + func testCancellableInvoked() async throws { var cancelCount = 0 let nativeFlow: NativeFlow = { _, _, cancelCallback in return { @@ -25,6 +25,7 @@ class AsyncSequenceTests: XCTestCase { for try await _ in asyncSequence(for: nativeFlow) { } } XCTAssertEqual(cancelCount, 0, "Cancellable shouldn't be invoked yet") + try await Task.sleep(nanoseconds: 10_000_000) // Gives the sequence a moment to start handle.cancel() let result = await handle.result XCTAssertEqual(cancelCount, 1, "Cancellable should be invoked once") @@ -65,8 +66,10 @@ class AsyncSequenceTests: XCTestCase { func testCompletionWithError() async { let sendError = NSError(domain: "Test", code: 0) let nativeFlow: NativeFlow = { _, completionCallback, _ in - completionCallback(sendError, ()) - return { } + let handle = Task { + completionCallback(sendError, ()) + } + return { handle.cancel() } } var valueCount = 0 do { diff --git a/sample/Async/AsyncSequenceIntegrationTests.swift b/sample/Async/AsyncSequenceIntegrationTests.swift index 04177f40..fe6dade1 100644 --- a/sample/Async/AsyncSequenceIntegrationTests.swift +++ b/sample/Async/AsyncSequenceIntegrationTests.swift @@ -41,8 +41,7 @@ class AsyncSequenceIntegrationTests: XCTestCase { var receivedValueCount: Int32 = 0 for try await _ in sequence { let emittedCount = integrationTests.emittedCount - // Note the AsyncSequence buffers at most a single item - XCTAssert(emittedCount == receivedValueCount || emittedCount == receivedValueCount + 1, "Back pressure isn't applied") + XCTAssert(emittedCount == receivedValueCount, "Back pressure isn't applied") delay(0.2) receivedValueCount += 1 }