From 41b79173b5f22108dd625019a8c2458ed6f092a5 Mon Sep 17 00:00:00 2001 From: Sergei Sysoev Date: Wed, 5 Feb 2025 21:12:15 +0100 Subject: [PATCH] Move `ThreadContextElement` to common --- .../api/kotlinx-coroutines-core.klib.api | 5 + .../common/src/ThreadContextElement.common.kt | 82 ++++++++ .../src/internal/ThreadContext.common.kt | 53 ++++- .../jsAndWasmShared/src/CoroutineContext.kt | 102 +++++++++- .../src/internal/ThreadContext.kt | 38 +++- .../jvm/src/ThreadContextElement.kt | 79 -------- .../jvm/src/internal/ThreadContext.kt | 51 ----- .../native/src/CoroutineContext.kt | 184 +++++++++++++++++- .../native/src/internal/ThreadContext.kt | 38 +++- 9 files changed, 491 insertions(+), 141 deletions(-) create mode 100644 kotlinx-coroutines-core/common/src/ThreadContextElement.common.kt diff --git a/kotlinx-coroutines-core/api/kotlinx-coroutines-core.klib.api b/kotlinx-coroutines-core/api/kotlinx-coroutines-core.klib.api index 9ba54a4e00..468fb059ae 100644 --- a/kotlinx-coroutines-core/api/kotlinx-coroutines-core.klib.api +++ b/kotlinx-coroutines-core/api/kotlinx-coroutines-core.klib.api @@ -186,6 +186,11 @@ abstract interface <#A: kotlin/Any?> kotlinx.coroutines/CompletableDeferred : ko abstract fun completeExceptionally(kotlin/Throwable): kotlin/Boolean // kotlinx.coroutines/CompletableDeferred.completeExceptionally|completeExceptionally(kotlin.Throwable){}[0] } +abstract interface <#A: kotlin/Any?> kotlinx.coroutines/ThreadContextElement : kotlin.coroutines/CoroutineContext.Element { // kotlinx.coroutines/ThreadContextElement|null[0] + abstract fun restoreThreadContext(kotlin.coroutines/CoroutineContext, #A) // kotlinx.coroutines/ThreadContextElement.restoreThreadContext|restoreThreadContext(kotlin.coroutines.CoroutineContext;1:0){}[0] + abstract fun updateThreadContext(kotlin.coroutines/CoroutineContext): #A // kotlinx.coroutines/ThreadContextElement.updateThreadContext|updateThreadContext(kotlin.coroutines.CoroutineContext){}[0] +} + abstract interface <#A: kotlin/Throwable & kotlinx.coroutines/CopyableThrowable<#A>> kotlinx.coroutines/CopyableThrowable { // kotlinx.coroutines/CopyableThrowable|null[0] abstract fun createCopy(): #A? // kotlinx.coroutines/CopyableThrowable.createCopy|createCopy(){}[0] } diff --git a/kotlinx-coroutines-core/common/src/ThreadContextElement.common.kt b/kotlinx-coroutines-core/common/src/ThreadContextElement.common.kt new file mode 100644 index 0000000000..e5ea541c95 --- /dev/null +++ b/kotlinx-coroutines-core/common/src/ThreadContextElement.common.kt @@ -0,0 +1,82 @@ +package kotlinx.coroutines + +import kotlin.coroutines.* + +/** + * Defines elements in [CoroutineContext] that are installed into thread context + * every time the coroutine with this element in the context is resumed on a thread. + * + * Implementations of this interface define a type [S] of the thread-local state that they need to store on + * resume of a coroutine and restore later on suspend. The infrastructure provides the corresponding storage. + * + * Example usage looks like this: + * + * ``` + * // Appends "name" of a coroutine to a current thread name when coroutine is executed + * class CoroutineName(val name: String) : ThreadContextElement { + * // declare companion object for a key of this element in coroutine context + * companion object Key : CoroutineContext.Key + * + * // provide the key of the corresponding context element + * override val key: CoroutineContext.Key + * get() = Key + * + * // this is invoked before coroutine is resumed on current thread + * override fun updateThreadContext(context: CoroutineContext): String { + * val previousName = Thread.currentThread().name + * Thread.currentThread().name = "$previousName # $name" + * return previousName + * } + * + * // this is invoked after coroutine has suspended on current thread + * override fun restoreThreadContext(context: CoroutineContext, oldState: String) { + * Thread.currentThread().name = oldState + * } + * } + * + * // Usage + * launch(Dispatchers.Main + CoroutineName("Progress bar coroutine")) { ... } + * ``` + * + * Every time this coroutine is resumed on a thread, UI thread name is updated to + * "UI thread original name # Progress bar coroutine" and the thread name is restored to the original one when + * this coroutine suspends. + * + * To use [ThreadLocal] variable within the coroutine use [ThreadLocal.asContextElement][asContextElement] function. + * + * ### Reentrancy and thread-safety + * + * Correct implementations of this interface must expect that calls to [restoreThreadContext] + * may happen in parallel to the subsequent [updateThreadContext] and [restoreThreadContext] operations. + * See [CopyableThreadContextElement] for advanced interleaving details. + * + * All implementations of [ThreadContextElement] should be thread-safe and guard their internal mutable state + * within an element accordingly. + */ +public interface ThreadContextElement : CoroutineContext.Element { + /** + * Updates context of the current thread. + * This function is invoked before the coroutine in the specified [context] is resumed in the current thread + * when the context of the coroutine this element. + * The result of this function is the old value of the thread-local state that will be passed to [restoreThreadContext]. + * This method should handle its own exceptions and do not rethrow it. Thrown exceptions will leave coroutine which + * context is updated in an undefined state and may crash an application. + * + * @param context the coroutine context. + */ + public fun updateThreadContext(context: CoroutineContext): S + + /** + * Restores context of the current thread. + * This function is invoked after the coroutine in the specified [context] is suspended in the current thread + * if [updateThreadContext] was previously invoked on resume of this coroutine. + * The value of [oldState] is the result of the previous invocation of [updateThreadContext] and it should + * be restored in the thread-local state by this function. + * This method should handle its own exceptions and do not rethrow it. Thrown exceptions will leave coroutine which + * context is updated in an undefined state and may crash an application. + * + * @param context the coroutine context. + * @param oldState the value returned by the previous invocation of [updateThreadContext]. + */ + public fun restoreThreadContext(context: CoroutineContext, oldState: S) +} \ No newline at end of file diff --git a/kotlinx-coroutines-core/common/src/internal/ThreadContext.common.kt b/kotlinx-coroutines-core/common/src/internal/ThreadContext.common.kt index c52d35c128..03d718eae4 100644 --- a/kotlinx-coroutines-core/common/src/internal/ThreadContext.common.kt +++ b/kotlinx-coroutines-core/common/src/internal/ThreadContext.common.kt @@ -1,5 +1,56 @@ package kotlinx.coroutines.internal +import kotlinx.coroutines.* import kotlin.coroutines.* +import kotlin.jvm.* -internal expect fun threadContextElements(context: CoroutineContext): Any +@JvmField +internal val NO_THREAD_ELEMENTS = Symbol("NO_THREAD_ELEMENTS") + +// Used when there are >= 2 active elements in the context +@Suppress("UNCHECKED_CAST") +internal class ThreadState(@JvmField val context: CoroutineContext, n: Int) { + private val values = arrayOfNulls(n) + private val elements = arrayOfNulls>(n) + private var i = 0 + + fun append(element: ThreadContextElement<*>, value: Any?) { + values[i] = value + elements[i++] = element as ThreadContextElement + } + + fun restore(context: CoroutineContext) { + for (i in elements.indices.reversed()) { + elements[i]!!.restoreThreadContext(context, values[i]) + } + } +} + +// Counts ThreadContextElements in the context +// Any? here is Int | ThreadContextElement (when count is one) +private val countAll = + fun (countOrElement: Any?, element: CoroutineContext.Element): Any? { + if (element is ThreadContextElement<*>) { + val inCount = countOrElement as? Int ?: 1 + return if (inCount == 0) element else inCount + 1 + } + return countOrElement + } + +// Find one (first) ThreadContextElement in the context, it is used when we know there is exactly one +internal val findOne = + fun (found: ThreadContextElement<*>?, element: CoroutineContext.Element): ThreadContextElement<*>? { + if (found != null) return found + return element as? ThreadContextElement<*> + } + +// Updates state for ThreadContextElements in the context using the given ThreadState +internal val updateState = + fun (state: ThreadState, element: CoroutineContext.Element): ThreadState { + if (element is ThreadContextElement<*>) { + state.append(element, element.updateThreadContext(state.context)) + } + return state + } + +internal fun threadContextElements(context: CoroutineContext): Any = context.fold(0, countAll)!! diff --git a/kotlinx-coroutines-core/jsAndWasmShared/src/CoroutineContext.kt b/kotlinx-coroutines-core/jsAndWasmShared/src/CoroutineContext.kt index ae9a9444b9..3d3fdf9cf5 100644 --- a/kotlinx-coroutines-core/jsAndWasmShared/src/CoroutineContext.kt +++ b/kotlinx-coroutines-core/jsAndWasmShared/src/CoroutineContext.kt @@ -1,6 +1,10 @@ package kotlinx.coroutines +import kotlinx.coroutines.internal.CoroutineStackFrame +import kotlinx.coroutines.internal.NO_THREAD_ELEMENTS import kotlinx.coroutines.internal.ScopeCoroutine +import kotlinx.coroutines.internal.restoreThreadContext +import kotlinx.coroutines.internal.updateThreadContext import kotlin.coroutines.* @PublishedApi // Used from kotlinx-coroutines-test via suppress, not part of ABI @@ -18,8 +22,73 @@ public actual fun CoroutineContext.newCoroutineContext(addedContext: CoroutineCo } // No debugging facilities on Wasm and JS -internal actual inline fun withCoroutineContext(context: CoroutineContext, countOrElement: Any?, block: () -> T): T = block() -internal actual inline fun withContinuationContext(continuation: Continuation<*>, countOrElement: Any?, block: () -> T): T = block() +internal actual inline fun withCoroutineContext(context: CoroutineContext, countOrElement: Any?, block: () -> T): T { + val oldValue = updateThreadContext(context, countOrElement) + try { + return block() + } finally { + restoreThreadContext(context, oldValue) + } +} + +internal actual inline fun withContinuationContext(continuation: Continuation<*>, countOrElement: Any?, block: () -> T): T { + val context = continuation.context + val oldValue = updateThreadContext(context, countOrElement) + val undispatchedCompletion = if (oldValue !== NO_THREAD_ELEMENTS) { + // Only if some values were replaced we'll go to the slow path of figuring out where/how to restore them + continuation.updateUndispatchedCompletion(context, oldValue) + } else { + null // fast path -- don't even try to find undispatchedCompletion as there's nothing to restore in the context + } + try { + return block() + } finally { + if (undispatchedCompletion == null || undispatchedCompletion.clearThreadContext()) { + restoreThreadContext(context, oldValue) + } + } +} + +internal fun Continuation<*>.updateUndispatchedCompletion(context: CoroutineContext, oldValue: Any?): UndispatchedCoroutine<*>? { + if (this !is CoroutineStackFrame) return null + /* + * Fast-path to detect whether we have undispatched coroutine at all in our stack. + * + * Implementation note. + * If we ever find that stackwalking for thread-locals is way too slow, here is another idea: + * 1) Store undispatched coroutine right in the `UndispatchedMarker` instance + * 2) To avoid issues with cross-dispatch boundary, remove `UndispatchedMarker` + * from the context when creating dispatched coroutine in `withContext`. + * Another option is to "unmark it" instead of removing to save an allocation. + * Both options should work, but it requires more careful studying of the performance + * and, mostly, maintainability impact. + */ + val potentiallyHasUndispatchedCoroutine = context[UndispatchedMarker] !== null + if (!potentiallyHasUndispatchedCoroutine) return null + val completion = undispatchedCompletion() + completion?.saveThreadContext(context, oldValue) + return completion +} + +internal tailrec fun CoroutineStackFrame.undispatchedCompletion(): UndispatchedCoroutine<*>? { + // Find direct completion of this continuation + val completion: CoroutineStackFrame = when (this) { + is DispatchedCoroutine<*> -> return null + else -> callerFrame ?: return null // something else -- not supported + } + if (completion is UndispatchedCoroutine<*>) return completion // found UndispatchedCoroutine! + return completion.undispatchedCompletion() // walk up the call stack with tail call +} + +/** + * Marker indicating that [UndispatchedCoroutine] exists somewhere up in the stack. + * Used as a performance optimization to avoid stack walking where it is not necessary. + */ +private object UndispatchedMarker: CoroutineContext.Element, CoroutineContext.Key { + override val key: CoroutineContext.Key<*> + get() = this +} + internal actual fun Continuation<*>.toDebugString(): String = toString() internal actual val CoroutineContext.coroutineName: String? get() = null // not supported on Wasm and JS @@ -27,7 +96,34 @@ internal actual class UndispatchedCoroutine actual constructor( context: CoroutineContext, uCont: Continuation ) : ScopeCoroutine(context, uCont) { - override fun afterResume(state: Any?) = uCont.resumeWith(recoverResult(state, uCont)) + + private var savedContext: CoroutineContext? = null + private var savedOldValue: Any? = null + + fun saveThreadContext(context: CoroutineContext, oldValue: Any?) { + savedContext = context + savedOldValue = oldValue + } + + fun clearThreadContext(): Boolean { + if (savedContext == null) return false + savedContext = null + savedOldValue = null + return true + } + + override fun afterResume(state: Any?) { + savedContext?.let { context -> + restoreThreadContext(context, savedOldValue) + savedContext = null + savedOldValue = null + } + // resume undispatched -- update context but stay on the same dispatcher + val result = recoverResult(state, uCont) + withContinuationContext(uCont, null) { + uCont.resumeWith(result) + } + } } internal actual inline fun withThreadLocalContext(context: CoroutineContext, block: () -> T) : T = block() diff --git a/kotlinx-coroutines-core/jsAndWasmShared/src/internal/ThreadContext.kt b/kotlinx-coroutines-core/jsAndWasmShared/src/internal/ThreadContext.kt index 3f56f99d6c..a7915e43de 100644 --- a/kotlinx-coroutines-core/jsAndWasmShared/src/internal/ThreadContext.kt +++ b/kotlinx-coroutines-core/jsAndWasmShared/src/internal/ThreadContext.kt @@ -1,5 +1,41 @@ package kotlinx.coroutines.internal +import kotlinx.coroutines.* import kotlin.coroutines.* -internal actual fun threadContextElements(context: CoroutineContext): Any = 0 +// countOrElement is pre-cached in dispatched continuation +// returns NO_THREAD_ELEMENTS if the contest does not have any ThreadContextElements +internal fun updateThreadContext(context: CoroutineContext, countOrElement: Any?): Any? { + @Suppress("NAME_SHADOWING") + val countOrElement = countOrElement ?: threadContextElements(context) + @Suppress("IMPLICIT_BOXING_IN_IDENTITY_EQUALS") + return when { + countOrElement == 0 -> NO_THREAD_ELEMENTS // very fast path when there are no active ThreadContextElements + countOrElement is Int -> { + // slow path for multiple active ThreadContextElements, allocates ThreadState for multiple old values + context.fold(ThreadState(context, countOrElement), updateState) + } + else -> { + // fast path for one ThreadContextElement (no allocations, no additional context scan) + @Suppress("UNCHECKED_CAST") + val element = countOrElement as ThreadContextElement + element.updateThreadContext(context) + } + } +} + +internal fun restoreThreadContext(context: CoroutineContext, oldState: Any?) { + when { + oldState === NO_THREAD_ELEMENTS -> return // very fast path when there are no ThreadContextElements + oldState is ThreadState -> { + // slow path with multiple stored ThreadContextElements + oldState.restore(context) + } + else -> { + // fast path for one ThreadContextElement, but need to find it + @Suppress("UNCHECKED_CAST") + val element = context.fold(null, findOne) as ThreadContextElement + element.restoreThreadContext(context, oldState) + } + } +} diff --git a/kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt b/kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt index c1898fbd65..9f52f61d78 100644 --- a/kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt +++ b/kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt @@ -3,85 +3,6 @@ package kotlinx.coroutines import kotlinx.coroutines.internal.* import kotlin.coroutines.* -/** - * Defines elements in [CoroutineContext] that are installed into thread context - * every time the coroutine with this element in the context is resumed on a thread. - * - * Implementations of this interface define a type [S] of the thread-local state that they need to store on - * resume of a coroutine and restore later on suspend. The infrastructure provides the corresponding storage. - * - * Example usage looks like this: - * - * ``` - * // Appends "name" of a coroutine to a current thread name when coroutine is executed - * class CoroutineName(val name: String) : ThreadContextElement { - * // declare companion object for a key of this element in coroutine context - * companion object Key : CoroutineContext.Key - * - * // provide the key of the corresponding context element - * override val key: CoroutineContext.Key - * get() = Key - * - * // this is invoked before coroutine is resumed on current thread - * override fun updateThreadContext(context: CoroutineContext): String { - * val previousName = Thread.currentThread().name - * Thread.currentThread().name = "$previousName # $name" - * return previousName - * } - * - * // this is invoked after coroutine has suspended on current thread - * override fun restoreThreadContext(context: CoroutineContext, oldState: String) { - * Thread.currentThread().name = oldState - * } - * } - * - * // Usage - * launch(Dispatchers.Main + CoroutineName("Progress bar coroutine")) { ... } - * ``` - * - * Every time this coroutine is resumed on a thread, UI thread name is updated to - * "UI thread original name # Progress bar coroutine" and the thread name is restored to the original one when - * this coroutine suspends. - * - * To use [ThreadLocal] variable within the coroutine use [ThreadLocal.asContextElement][asContextElement] function. - * - * ### Reentrancy and thread-safety - * - * Correct implementations of this interface must expect that calls to [restoreThreadContext] - * may happen in parallel to the subsequent [updateThreadContext] and [restoreThreadContext] operations. - * See [CopyableThreadContextElement] for advanced interleaving details. - * - * All implementations of [ThreadContextElement] should be thread-safe and guard their internal mutable state - * within an element accordingly. - */ -public interface ThreadContextElement : CoroutineContext.Element { - /** - * Updates context of the current thread. - * This function is invoked before the coroutine in the specified [context] is resumed in the current thread - * when the context of the coroutine this element. - * The result of this function is the old value of the thread-local state that will be passed to [restoreThreadContext]. - * This method should handle its own exceptions and do not rethrow it. Thrown exceptions will leave coroutine which - * context is updated in an undefined state and may crash an application. - * - * @param context the coroutine context. - */ - public fun updateThreadContext(context: CoroutineContext): S - - /** - * Restores context of the current thread. - * This function is invoked after the coroutine in the specified [context] is suspended in the current thread - * if [updateThreadContext] was previously invoked on resume of this coroutine. - * The value of [oldState] is the result of the previous invocation of [updateThreadContext] and it should - * be restored in the thread-local state by this function. - * This method should handle its own exceptions and do not rethrow it. Thrown exceptions will leave coroutine which - * context is updated in an undefined state and may crash an application. - * - * @param context the coroutine context. - * @param oldState the value returned by the previous invocation of [updateThreadContext]. - */ - public fun restoreThreadContext(context: CoroutineContext, oldState: S) -} - /** * A [ThreadContextElement] copied whenever a child coroutine inherits a context containing it. * diff --git a/kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt b/kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt index 8f21b13c25..5b876071f6 100644 --- a/kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt +++ b/kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt @@ -3,57 +3,6 @@ package kotlinx.coroutines.internal import kotlinx.coroutines.* import kotlin.coroutines.* -@JvmField -internal val NO_THREAD_ELEMENTS = Symbol("NO_THREAD_ELEMENTS") - -// Used when there are >= 2 active elements in the context -@Suppress("UNCHECKED_CAST") -private class ThreadState(@JvmField val context: CoroutineContext, n: Int) { - private val values = arrayOfNulls(n) - private val elements = arrayOfNulls>(n) - private var i = 0 - - fun append(element: ThreadContextElement<*>, value: Any?) { - values[i] = value - elements[i++] = element as ThreadContextElement - } - - fun restore(context: CoroutineContext) { - for (i in elements.indices.reversed()) { - elements[i]!!.restoreThreadContext(context, values[i]) - } - } -} - -// Counts ThreadContextElements in the context -// Any? here is Int | ThreadContextElement (when count is one) -private val countAll = - fun (countOrElement: Any?, element: CoroutineContext.Element): Any? { - if (element is ThreadContextElement<*>) { - val inCount = countOrElement as? Int ?: 1 - return if (inCount == 0) element else inCount + 1 - } - return countOrElement - } - -// Find one (first) ThreadContextElement in the context, it is used when we know there is exactly one -private val findOne = - fun (found: ThreadContextElement<*>?, element: CoroutineContext.Element): ThreadContextElement<*>? { - if (found != null) return found - return element as? ThreadContextElement<*> - } - -// Updates state for ThreadContextElements in the context using the given ThreadState -private val updateState = - fun (state: ThreadState, element: CoroutineContext.Element): ThreadState { - if (element is ThreadContextElement<*>) { - state.append(element, element.updateThreadContext(state.context)) - } - return state - } - -internal actual fun threadContextElements(context: CoroutineContext): Any = context.fold(0, countAll)!! - // countOrElement is pre-cached in dispatched continuation // returns NO_THREAD_ELEMENTS if the contest does not have any ThreadContextElements internal fun updateThreadContext(context: CoroutineContext, countOrElement: Any?): Any? { diff --git a/kotlinx-coroutines-core/native/src/CoroutineContext.kt b/kotlinx-coroutines-core/native/src/CoroutineContext.kt index 5500e5b5ca..d3c83ff95d 100644 --- a/kotlinx-coroutines-core/native/src/CoroutineContext.kt +++ b/kotlinx-coroutines-core/native/src/CoroutineContext.kt @@ -1,6 +1,7 @@ package kotlinx.coroutines import kotlinx.coroutines.internal.* +import kotlin.concurrent.* import kotlin.coroutines.* internal actual object DefaultExecutor : CoroutineDispatcher(), Delay { @@ -40,16 +41,189 @@ public actual fun CoroutineContext.newCoroutineContext(addedContext: CoroutineCo } // No debugging facilities on native -internal actual inline fun withCoroutineContext(context: CoroutineContext, countOrElement: Any?, block: () -> T): T = block() -internal actual inline fun withContinuationContext(continuation: Continuation<*>, countOrElement: Any?, block: () -> T): T = block() +internal actual inline fun withCoroutineContext(context: CoroutineContext, countOrElement: Any?, block: () -> T): T { + val oldValue = updateThreadContext(context, countOrElement) + try { + return block() + } finally { + restoreThreadContext(context, oldValue) + } +} + +internal actual inline fun withContinuationContext(continuation: Continuation<*>, countOrElement: Any?, block: () -> T): T { + val context = continuation.context + val oldValue = updateThreadContext(context, countOrElement) + val undispatchedCompletion = if (oldValue !== NO_THREAD_ELEMENTS) { + // Only if some values were replaced we'll go to the slow path of figuring out where/how to restore them + continuation.updateUndispatchedCompletion(context, oldValue) + } else { + null // fast path -- don't even try to find undispatchedCompletion as there's nothing to restore in the context + } + try { + return block() + } finally { + if (undispatchedCompletion == null || undispatchedCompletion.clearThreadContext()) { + restoreThreadContext(context, oldValue) + } + } +} + +internal fun Continuation<*>.updateUndispatchedCompletion(context: CoroutineContext, oldValue: Any?): UndispatchedCoroutine<*>? { + if (this !is CoroutineStackFrame) return null + /* + * Fast-path to detect whether we have undispatched coroutine at all in our stack. + * + * Implementation note. + * If we ever find that stackwalking for thread-locals is way too slow, here is another idea: + * 1) Store undispatched coroutine right in the `UndispatchedMarker` instance + * 2) To avoid issues with cross-dispatch boundary, remove `UndispatchedMarker` + * from the context when creating dispatched coroutine in `withContext`. + * Another option is to "unmark it" instead of removing to save an allocation. + * Both options should work, but it requires more careful studying of the performance + * and, mostly, maintainability impact. + */ + val potentiallyHasUndispatchedCoroutine = context[UndispatchedMarker] !== null + if (!potentiallyHasUndispatchedCoroutine) return null + val completion = undispatchedCompletion() + completion?.saveThreadContext(context, oldValue) + return completion +} + +internal tailrec fun CoroutineStackFrame.undispatchedCompletion(): UndispatchedCoroutine<*>? { + // Find direct completion of this continuation + val completion: CoroutineStackFrame = when (this) { + is DispatchedCoroutine<*> -> return null + else -> callerFrame ?: return null // something else -- not supported + } + if (completion is UndispatchedCoroutine<*>) return completion // found UndispatchedCoroutine! + return completion.undispatchedCompletion() // walk up the call stack with tail call +} + +/** + * Marker indicating that [UndispatchedCoroutine] exists somewhere up in the stack. + * Used as a performance optimization to avoid stack walking where it is not necessary. + */ +private object UndispatchedMarker: CoroutineContext.Element, CoroutineContext.Key { + override val key: CoroutineContext.Key<*> + get() = this +} + internal actual fun Continuation<*>.toDebugString(): String = toString() internal actual val CoroutineContext.coroutineName: String? get() = null // not supported on native -internal actual class UndispatchedCoroutine actual constructor( +internal actual class UndispatchedCoroutineactual constructor ( context: CoroutineContext, uCont: Continuation -) : ScopeCoroutine(context, uCont) { - override fun afterResume(state: Any?) = uCont.resumeWith(recoverResult(state, uCont)) +) : ScopeCoroutine(if (context[UndispatchedMarker] == null) context + UndispatchedMarker else context, uCont) { + + /** + * The state of [ThreadContextElement]s associated with the current undispatched coroutine. + * It is stored in a thread local because this coroutine can be used concurrently in suspend-resume race scenario. + * See the following, boiled down example with inlined `withContinuationContext` body: + * ``` + * val state = saveThreadContext(ctx) + * try { + * invokeSmthWithThisCoroutineAsCompletion() // Completion implies that 'afterResume' will be called + * // COROUTINE_SUSPENDED is returned + * } finally { + * thisCoroutine().clearThreadContext() // Concurrently the "smth" could've been already resumed on a different thread + * // and it also calls saveThreadContext and clearThreadContext + * } + * ``` + * + * Usage note: + * + * This part of the code is performance-sensitive. + * It is a well-established pattern to wrap various activities into system-specific undispatched + * `withContext` for the sake of logging, MDC, tracing etc., meaning that there exists thousands of + * undispatched coroutines. + * [ThreadLocal.set] leaves a footprint in the corresponding Thread's `ThreadLocalMap`. + * We attempt to narrow down the lifetime of this thread local as much as possible: + * - It's never accessed when we are sure there are no thread context elements + * - It's cleaned up via [ThreadLocal.remove] as soon as the coroutine is suspended or finished. + */ + private val threadStateToRecover = ThreadLocal?>(this) + + /* + * Indicates that a coroutine has at least one thread context element associated with it + * and that 'threadStateToRecover' is going to be set in case of dispatchhing in order to preserve them. + * Better than nullable thread-local for easier debugging. + * + * It is used as a performance optimization to avoid 'threadStateToRecover' initialization + * (note: tl.get() initializes thread local), + * and is prone to false-positives as it is never reset: otherwise + * it may lead to logical data races between suspensions point where + * coroutine is yet being suspended in one thread while already being resumed + * in another. + */ + @Volatile + private var threadLocalIsSet = false + + init { + /* + * This is a hack for a very specific case in #2930 unless #3253 is implemented. + * 'ThreadLocalStressTest' covers this change properly. + * + * The scenario this change covers is the following: + * 1) The coroutine is being started as plain non kotlinx.coroutines related suspend function, + * e.g. `suspend fun main` or, more importantly, Ktor `SuspendFunGun`, that is invoking + * `withContext(tlElement)` which creates `UndispatchedCoroutine`. + * 2) It (original continuation) is then not wrapped into `DispatchedContinuation` via `intercept()` + * and goes neither through `DC.run` nor through `resumeUndispatchedWith` that both + * do thread context element tracking. + * 3) So thread locals never got chance to get properly set up via `saveThreadContext`, + * but when `withContext` finishes, it attempts to recover thread locals in its `afterResume`. + * + * Here we detect precisely this situation and properly setup context to recover later. + * + */ + if (uCont.context[ContinuationInterceptor] !is CoroutineDispatcher) { + /* + * We cannot just "read" the elements as there is no such API, + * so we update-restore it immediately and use the intermediate value + * as the initial state, leveraging the fact that thread context element + * is idempotent and such situations are increasingly rare. + */ + val values = updateThreadContext(context, null) + restoreThreadContext(context, values) + saveThreadContext(context, values) + } + } + + fun saveThreadContext(context: CoroutineContext, oldValue: Any?) { + threadLocalIsSet = true // Specify that thread-local is touched at all + threadStateToRecover.set(context to oldValue) + } + + fun clearThreadContext(): Boolean { + return !(threadLocalIsSet && threadStateToRecover.get() == null).also { + threadStateToRecover.remove() + } + } + + override fun afterResume(state: Any?) { + if (threadLocalIsSet) { + threadStateToRecover.get()?.let { (ctx, value) -> + restoreThreadContext(ctx, value) + } + threadStateToRecover.remove() + } + // resume undispatched -- update context but stay on the same dispatcher + val result = recoverResult(state, uCont) + withContinuationContext(uCont, null) { + uCont.resumeWith(result) + } + } } +private class ThreadLocal(private val key: Any) { + @Suppress("UNCHECKED_CAST") + fun get(): T? = ThreadLocalMap[key] as? T + fun set(value: T) { ThreadLocalMap[key] = value } + fun remove() { ThreadLocalMap.remove(key) } +} + +@kotlin.native.concurrent.ThreadLocal +private object ThreadLocalMap: MutableMap by mutableMapOf() + internal actual inline fun withThreadLocalContext(context: CoroutineContext, block: () -> T) : T = block() diff --git a/kotlinx-coroutines-core/native/src/internal/ThreadContext.kt b/kotlinx-coroutines-core/native/src/internal/ThreadContext.kt index 3f56f99d6c..a7915e43de 100644 --- a/kotlinx-coroutines-core/native/src/internal/ThreadContext.kt +++ b/kotlinx-coroutines-core/native/src/internal/ThreadContext.kt @@ -1,5 +1,41 @@ package kotlinx.coroutines.internal +import kotlinx.coroutines.* import kotlin.coroutines.* -internal actual fun threadContextElements(context: CoroutineContext): Any = 0 +// countOrElement is pre-cached in dispatched continuation +// returns NO_THREAD_ELEMENTS if the contest does not have any ThreadContextElements +internal fun updateThreadContext(context: CoroutineContext, countOrElement: Any?): Any? { + @Suppress("NAME_SHADOWING") + val countOrElement = countOrElement ?: threadContextElements(context) + @Suppress("IMPLICIT_BOXING_IN_IDENTITY_EQUALS") + return when { + countOrElement == 0 -> NO_THREAD_ELEMENTS // very fast path when there are no active ThreadContextElements + countOrElement is Int -> { + // slow path for multiple active ThreadContextElements, allocates ThreadState for multiple old values + context.fold(ThreadState(context, countOrElement), updateState) + } + else -> { + // fast path for one ThreadContextElement (no allocations, no additional context scan) + @Suppress("UNCHECKED_CAST") + val element = countOrElement as ThreadContextElement + element.updateThreadContext(context) + } + } +} + +internal fun restoreThreadContext(context: CoroutineContext, oldState: Any?) { + when { + oldState === NO_THREAD_ELEMENTS -> return // very fast path when there are no ThreadContextElements + oldState is ThreadState -> { + // slow path with multiple stored ThreadContextElements + oldState.restore(context) + } + else -> { + // fast path for one ThreadContextElement, but need to find it + @Suppress("UNCHECKED_CAST") + val element = context.fold(null, findOne) as ThreadContextElement + element.restoreThreadContext(context, oldState) + } + } +}