diff --git a/Sources/Helpers/Task+withTimeout.swift b/Sources/Helpers/Task+withTimeout.swift index 5477806b..80ab8fb7 100644 --- a/Sources/Helpers/Task+withTimeout.swift +++ b/Sources/Helpers/Task+withTimeout.swift @@ -10,7 +10,7 @@ import Foundation @discardableResult package func withTimeout( interval: TimeInterval, - @_inheritActorContext operation: @escaping @Sendable () async throws -> R + @_inheritActorContext operation: @escaping @Sendable () async -> R ) async throws -> R { try await withThrowingTaskGroup(of: R.self) { group in defer { @@ -20,7 +20,7 @@ package func withTimeout( let deadline = Date(timeIntervalSinceNow: interval) group.addTask { - try await operation() + await operation() } group.addTask { diff --git a/Sources/Realtime/RealtimeChannelV2.swift b/Sources/Realtime/RealtimeChannelV2.swift index 9b6efa14..746ef43b 100644 --- a/Sources/Realtime/RealtimeChannelV2.swift +++ b/Sources/Realtime/RealtimeChannelV2.swift @@ -83,9 +83,101 @@ public final class RealtimeChannelV2: Sendable { callbackManager.reset() } - /// Subscribes to the channel + /// Subscribes to the channel. + public func subscribeWithError() async throws { + logger?.debug("Starting subscription to channel '\(topic)' (attempt 1/\(socket.options.maxRetryAttempts))") + + status = .subscribing + + defer { + // If the subscription fails, we need to set the status to unsubscribed + // to avoid the channel being stuck in a subscribing state. + if status != .subscribed { + status = .unsubscribed + } + } + + var attempts = 0 + + while attempts < socket.options.maxRetryAttempts { + attempts += 1 + + do { + logger?.debug( + "Attempting to subscribe to channel '\(topic)' (attempt \(attempts)/\(socket.options.maxRetryAttempts))" + ) + + try await withTimeout(interval: socket.options.timeoutInterval) { [self] in + await _subscribe() + } + + logger?.debug("Successfully subscribed to channel '\(topic)'") + return + + } catch is TimeoutError { + logger?.debug( + "Subscribe timed out for channel '\(topic)' (attempt \(attempts)/\(socket.options.maxRetryAttempts))" + ) + + if attempts < socket.options.maxRetryAttempts { + // Add exponential backoff with jitter + let delay = calculateRetryDelay(for: attempts) + logger?.debug( + "Retrying subscription to channel '\(topic)' in \(String(format: "%.2f", delay)) seconds..." + ) + + do { + try await _clock.sleep(for: delay) + } catch { + // If sleep is cancelled, break out of retry loop + logger?.debug("Subscription retry cancelled for channel '\(topic)'") + throw CancellationError() + } + } else { + logger?.error( + "Failed to subscribe to channel '\(topic)' after \(socket.options.maxRetryAttempts) attempts due to timeout" + ) + } + } catch is CancellationError { + logger?.debug("Subscription retry cancelled for channel '\(topic)'") + throw CancellationError() + } catch { + preconditionFailure( + "The only possible error here is TimeoutError or CancellationError, this should never happen." + ) + } + } + + logger?.error("Subscription to channel '\(topic)' failed after \(attempts) attempts") + throw RealtimeError.maxRetryAttemptsReached + } + + /// Subscribes to the channel. + @available(*, deprecated, message: "Use `subscribeWithError` instead") @MainActor public func subscribe() async { + try? await subscribeWithError() + } + + /// Calculates retry delay with exponential backoff and jitter + private func calculateRetryDelay(for attempt: Int) -> TimeInterval { + let baseDelay: TimeInterval = 1.0 + let maxDelay: TimeInterval = 30.0 + let backoffMultiplier: Double = 2.0 + + let exponentialDelay = baseDelay * pow(backoffMultiplier, Double(attempt - 1)) + let cappedDelay = min(exponentialDelay, maxDelay) + + // Add jitter (±25% random variation) to prevent thundering herd + let jitterRange = cappedDelay * 0.25 + let jitter = Double.random(in: -jitterRange...jitterRange) + + return max(0.1, cappedDelay + jitter) + } + + /// Subscribes to the channel + @MainActor + private func _subscribe() async { if socket.status != .connected { if socket.options.connectOnSubscribe != true { reportIssue( @@ -96,7 +188,6 @@ public final class RealtimeChannelV2: Sendable { await socket.connect() } - status = .subscribing logger?.debug("Subscribing to channel \(topic)") config.presence.enabled = callbackManager.callbacks.contains(where: { $0.isPresence }) @@ -125,18 +216,7 @@ public final class RealtimeChannelV2: Sendable { payload: try! JSONObject(payload) ) - do { - try await withTimeout(interval: socket.options.timeoutInterval) { [self] in - _ = await statusChange.first { @Sendable in $0 == .subscribed } - } - } catch { - if error is TimeoutError { - logger?.debug("Subscribe timed out.") - await subscribe() - } else { - logger?.error("Subscribe failed: \(error)") - } - } + _ = await statusChange.first { @Sendable in $0 == .subscribed } } public func unsubscribe() async { @@ -175,13 +255,6 @@ public final class RealtimeChannelV2: Sendable { @MainActor public func broadcast(event: String, message: JSONObject) async { if status != .subscribed { - struct Message: Encodable { - let topic: String - let event: String - let payload: JSONObject - let `private`: Bool - } - var headers: HTTPFields = [.contentType: "application/json"] if let apiKey = socket.options.apikey { headers[.apiKey] = apiKey @@ -190,6 +263,17 @@ public final class RealtimeChannelV2: Sendable { headers[.authorization] = "Bearer \(accessToken)" } + struct BroadcastMessagePayload: Encodable { + let messages: [Message] + + struct Message: Encodable { + let topic: String + let event: String + let payload: JSONObject + let `private`: Bool + } + } + let task = Task { [headers] in _ = try? await socket.http.send( HTTPRequest( @@ -197,16 +281,16 @@ public final class RealtimeChannelV2: Sendable { method: .post, headers: headers, body: JSONEncoder().encode( - [ - "messages": [ - Message( + BroadcastMessagePayload( + messages: [ + BroadcastMessagePayload.Message( topic: topic, event: event, payload: message, private: config.isPrivate ) ] - ] + ) ) ) ) diff --git a/Sources/Realtime/RealtimeError.swift b/Sources/Realtime/RealtimeError.swift index db0d3770..675ca27e 100644 --- a/Sources/Realtime/RealtimeError.swift +++ b/Sources/Realtime/RealtimeError.swift @@ -14,3 +14,10 @@ struct RealtimeError: LocalizedError { self.errorDescription = errorDescription } } + +extension RealtimeError { + /// The maximum retry attempts reached. + static var maxRetryAttemptsReached: Self { + Self("Maximum retry attempts reached.") + } +} diff --git a/Sources/Realtime/Types.swift b/Sources/Realtime/Types.swift index f1bd073e..30d625e0 100644 --- a/Sources/Realtime/Types.swift +++ b/Sources/Realtime/Types.swift @@ -20,6 +20,7 @@ public struct RealtimeClientOptions: Sendable { var timeoutInterval: TimeInterval var disconnectOnSessionLoss: Bool var connectOnSubscribe: Bool + var maxRetryAttempts: Int /// Sets the log level for Realtime var logLevel: LogLevel? @@ -32,6 +33,7 @@ public struct RealtimeClientOptions: Sendable { public static let defaultTimeoutInterval: TimeInterval = 10 public static let defaultDisconnectOnSessionLoss = true public static let defaultConnectOnSubscribe: Bool = true + public static let defaultMaxRetryAttempts: Int = 5 public init( headers: [String: String] = [:], @@ -40,6 +42,7 @@ public struct RealtimeClientOptions: Sendable { timeoutInterval: TimeInterval = Self.defaultTimeoutInterval, disconnectOnSessionLoss: Bool = Self.defaultDisconnectOnSessionLoss, connectOnSubscribe: Bool = Self.defaultConnectOnSubscribe, + maxRetryAttempts: Int = Self.defaultMaxRetryAttempts, logLevel: LogLevel? = nil, fetch: (@Sendable (_ request: URLRequest) async throws -> (Data, URLResponse))? = nil, accessToken: (@Sendable () async throws -> String?)? = nil, @@ -51,6 +54,7 @@ public struct RealtimeClientOptions: Sendable { self.timeoutInterval = timeoutInterval self.disconnectOnSessionLoss = disconnectOnSessionLoss self.connectOnSubscribe = connectOnSubscribe + self.maxRetryAttempts = maxRetryAttempts self.logLevel = logLevel self.fetch = fetch self.accessToken = accessToken diff --git a/Tests/IntegrationTests/RealtimeIntegrationTests.swift b/Tests/IntegrationTests/RealtimeIntegrationTests.swift index e641154e..5ad82f26 100644 --- a/Tests/IntegrationTests/RealtimeIntegrationTests.swift +++ b/Tests/IntegrationTests/RealtimeIntegrationTests.swift @@ -70,7 +70,11 @@ struct TestLogger: SupabaseLogger { await Task.yield() - await channel.subscribe() + do { + try await channel.subscribeWithError() + } catch { + XCTFail("Expected .subscribed but got error: \(error)") + } struct Message: Codable { var value: Int @@ -141,7 +145,11 @@ struct TestLogger: SupabaseLogger { await Task.yield() - await channel.subscribe() + do { + try await channel.subscribeWithError() + } catch { + XCTFail("Expected .subscribed but got error: \(error)") + } struct UserState: Codable, Equatable { let email: String @@ -201,7 +209,11 @@ struct TestLogger: SupabaseLogger { } await Task.yield() - await channel.subscribe() + do { + try await channel.subscribeWithError() + } catch { + XCTFail("Expected .subscribed but got error: \(error)") + } struct Entry: Codable, Equatable { let key: String diff --git a/Tests/RealtimeTests/RealtimeChannelTests.swift b/Tests/RealtimeTests/RealtimeChannelTests.swift index 8589519d..9362513a 100644 --- a/Tests/RealtimeTests/RealtimeChannelTests.swift +++ b/Tests/RealtimeTests/RealtimeChannelTests.swift @@ -161,7 +161,7 @@ final class RealtimeChannelTests: XCTestCase { XCTAssertTrue(channel.callbackManager.callbacks.contains(where: { $0.isPresence })) // Start subscription process - Task { + let subscribeTask = Task { await channel.subscribe() } @@ -191,5 +191,8 @@ final class RealtimeChannelTests: XCTestCase { presenceSubscription.cancel() await channel.unsubscribe() socket.disconnect() + + // Note: We don't assert the subscribe status here because the test doesn't wait for completion + // The subscription is still in progress when we clean up } } diff --git a/Tests/RealtimeTests/RealtimeTests.swift b/Tests/RealtimeTests/RealtimeTests.swift index 59cb9ff5..f24aec6f 100644 --- a/Tests/RealtimeTests/RealtimeTests.swift +++ b/Tests/RealtimeTests/RealtimeTests.swift @@ -108,6 +108,23 @@ final class RealtimeTests: XCTestCase { } .store(in: &subscriptions) + // Set up server to respond to heartbeats + server.onEvent = { @Sendable [server] event in + guard let msg = event.realtimeMessage else { return } + + if msg.event == "heartbeat" { + server?.send( + RealtimeMessageV2( + joinRef: msg.joinRef, + ref: msg.ref, + topic: "phoenix", + event: "phx_reply", + payload: ["response": [:]] + ) + ) + } + } + await sut.connect() XCTAssertEqual(socketStatuses.value, [.disconnected, .connecting, .connected]) @@ -127,14 +144,17 @@ final class RealtimeTests: XCTestCase { .store(in: &subscriptions) let subscribeTask = Task { - await channel.subscribe() + try await channel.subscribeWithError() } await Task.yield() server.send(.messagesSubscribed) // Wait until it subscribes to assert WS events - await subscribeTask.value - + do { + try await subscribeTask.value + } catch { + XCTFail("Expected .subscribed but got error: \(error)") + } XCTAssertEqual(channelStatuses.value, [.unsubscribed, .subscribing, .subscribed]) assertInlineSnapshot(of: client.sentEvents.map(\.json), as: .json) { @@ -216,11 +236,17 @@ final class RealtimeTests: XCTestCase { await testClock.advance(by: .seconds(heartbeatInterval)) Task { - await channel.subscribe() + try await channel.subscribeWithError() } // Wait for the timeout for rejoining. await testClock.advance(by: .seconds(timeoutInterval)) + + // Wait for the retry delay (base delay is 1.0s, but we need to account for jitter) + // The retry delay is calculated as: baseDelay * pow(2, attempt-1) + jitter + // For attempt 2: 1.0 * pow(2, 1) = 2.0s + jitter (up to ±25% = ±0.5s) + // So we need to wait at least 2.5s to ensure the retry happens + await testClock.advance(by: .seconds(2.5)) let events = client.sentEvents.compactMap { $0.realtimeMessage }.filter { $0.event == "phx_join" @@ -281,6 +307,161 @@ final class RealtimeTests: XCTestCase { } } + // Succeeds after 2 retries (on 3rd attempt) + func testSubscribeTimeout_successAfterRetries() async throws { + let successAttempt = 3 + let channel = sut.channel("public:messages") + let joinEventCount = LockIsolated(0) + + server.onEvent = { @Sendable [server] event in + guard let msg = event.realtimeMessage else { return } + + if msg.event == "heartbeat" { + server?.send( + RealtimeMessageV2( + joinRef: msg.joinRef, + ref: msg.ref, + topic: "phoenix", + event: "phx_reply", + payload: ["response": [:]] + ) + ) + } else if msg.event == "phx_join" { + joinEventCount.withValue { $0 += 1 } + // Respond on the 3rd attempt + if joinEventCount.value == successAttempt { + server?.send(.messagesSubscribed) + } + } + } + + await sut.connect() + await testClock.advance(by: .seconds(heartbeatInterval)) + + let subscribeTask = Task { + _ = try? await channel.subscribeWithError() + } + + // Wait for each attempt and retry delay + for attempt in 1..