Skip to content

fix: respect custom coroutine scope when calling blocking functions in a data fetcher #2116

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import graphql.schema.DataFetcher
import graphql.schema.DataFetchingEnvironment
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.future.future
import kotlinx.coroutines.runBlocking
import java.lang.reflect.InvocationTargetException
import java.util.concurrent.CompletableFuture
import kotlin.coroutines.EmptyCoroutineContext
Expand Down Expand Up @@ -60,7 +61,7 @@ open class FunctionDataFetcher(
if (fn.isSuspend) {
runSuspendingFunction(environment, parameterValues)
} else {
runBlockingFunction(parameterValues)
runBlockingFunction(environment, parameterValues)
}
} else {
null
Expand Down Expand Up @@ -123,8 +124,14 @@ open class FunctionDataFetcher(
* Once all parameters values are properly converted, this function will be called to run a simple blocking function.
* If you need to override the exception handling you can override this method.
*/
protected open fun runBlockingFunction(parameterValues: Map<KParameter, Any?>): Any? = try {
fn.callBy(parameterValues)
protected open fun runBlockingFunction(
environment: DataFetchingEnvironment,
parameterValues: Map<KParameter, Any?>
): Any? = try {
val coroutineScope = environment.graphQlContext.getOrDefault(CoroutineScope(EmptyCoroutineContext))
runBlocking(coroutineScope.coroutineContext) {
fn.callBy(parameterValues)
}
} catch (exception: InvocationTargetException) {
throw exception.cause ?: exception
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,31 @@ import graphql.GraphQLException
import graphql.schema.DataFetchingEnvironment
import io.mockk.every
import io.mockk.mockk
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.asCoroutineDispatcher
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.runBlocking
import org.junit.jupiter.api.Test
import java.util.concurrent.CompletableFuture
import kotlin.test.assertEquals
import kotlin.test.assertFalse
import kotlin.test.assertNotNull
import kotlin.test.assertNull
import kotlin.test.assertTrue
import java.util.concurrent.Executors
import java.util.concurrent.ThreadFactory
import kotlin.test.*

class FunctionDataFetcherTest {

interface MyInterface {
fun print(string: String): String
}

val threadFactory = object: ThreadFactory {
override fun newThread(r: Runnable): Thread? {
val thread = Thread(r)
thread.name = "custom-thread-1"
return thread
}
}
val customCoroutineDispatcher = Executors.newSingleThreadExecutor(threadFactory).asCoroutineDispatcher()

class MyClass : MyInterface {
override fun print(string: String) = string

Expand All @@ -52,7 +62,9 @@ class FunctionDataFetcherTest {
string
}

fun throwException() { throw GraphQLException("Test Exception") }
fun throwException() {
throw GraphQLException("Test Exception")
}

suspend fun suspendThrow(): String = coroutineScope<String> {
throw GraphQLException("Suspended Exception")
Expand All @@ -75,14 +87,21 @@ class FunctionDataFetcherTest {
is OptionalInput.Undefined -> "optional was UNDEFINED"
is OptionalInput.Defined -> "optional was ${input.optional.value}"
}

fun threadNameSync(): String {
return Thread.currentThread().name
}

fun threadNameAsync(): String {
return Thread.currentThread().name
}
}

data class InputWrapper(val required: String, val optional: OptionalInput<String>)

@GraphQLName("MyInputClassRenamed")
data class MyInputClass(
@GraphQLName("jacksonField")
val field1: String
@GraphQLName("jacksonField") val field1: String
)

@Test
Expand All @@ -101,6 +120,7 @@ class FunctionDataFetcherTest {
every { getSource<Any>() } returns MyClass()
every { arguments } returns mapOf("string" to "hello")
every { containsArgument("string") } returns true
every { graphQlContext } returns GraphQLContext.newContext().build()
}
assertEquals(expected = "hello", actual = dataFetcher.get(mockEnvironment))
}
Expand All @@ -111,6 +131,7 @@ class FunctionDataFetcherTest {
val mockEnvironment: DataFetchingEnvironment = mockk {
every { arguments } returns mapOf("string" to "hello")
every { containsArgument("string") } returns true
every { graphQlContext } returns GraphQLContext.newContext().build()
}
assertEquals(expected = "hello", actual = dataFetcher.get(mockEnvironment))
}
Expand All @@ -137,6 +158,7 @@ class FunctionDataFetcherTest {
every { getSource<Any>() } returns MyClass()
every { arguments } returns mapOf("string" to "hello")
every { containsArgument("string") } returns true
every { graphQlContext } returns GraphQLContext.newContext().build()
}
assertEquals(expected = "hello", actual = dataFetcher.get(mockEnvironment))
}
Expand All @@ -148,6 +170,7 @@ class FunctionDataFetcherTest {
every { getSource<Any>() } returns MyClass()
every { arguments } returns emptyMap()
every { containsArgument(any()) } returns false
every { graphQlContext } returns GraphQLContext.newContext().build()
}
assertEquals(expected = "hello", actual = dataFetcher.get(mockEnvironment))
}
Expand All @@ -159,6 +182,7 @@ class FunctionDataFetcherTest {
every { getSource<Any>() } returns MyClass()
every { arguments } returns mapOf("string" to "foo")
every { containsArgument("string") } returns true
every { graphQlContext } returns GraphQLContext.newContext().build()
}
assertEquals(expected = "foo", actual = dataFetcher.get(mockEnvironment))
}
Expand All @@ -170,6 +194,7 @@ class FunctionDataFetcherTest {
every { getSource<Any>() } returns MyClass()
every { arguments } returns mapOf("string" to null)
every { containsArgument("string") } returns true
every { graphQlContext } returns GraphQLContext.newContext().build()
}
assertNull(dataFetcher.get(mockEnvironment))
}
Expand All @@ -180,6 +205,7 @@ class FunctionDataFetcherTest {
val mockEnvironment: DataFetchingEnvironment = mockk {
every { arguments } returns mapOf("items" to listOf("foo", "bar"))
every { containsArgument("items") } returns true
every { graphQlContext } returns GraphQLContext.newContext().build()
}

assertEquals(expected = "foo:bar", actual = dataFetcher.get(mockEnvironment))
Expand All @@ -194,6 +220,7 @@ class FunctionDataFetcherTest {
every { field } returns mockk {
every { name } returns "fooBarBaz"
}
every { graphQlContext } returns GraphQLContext.newContext().build()
}
assertEquals(expected = "fooBarBaz", actual = dataFetcher.get(mockEnvironment))
}
Expand All @@ -219,6 +246,7 @@ class FunctionDataFetcherTest {
val mockEnvironment: DataFetchingEnvironment = mockk {
every { arguments } returns emptyMap()
every { containsArgument(any()) } returns false
every { graphQlContext } returns GraphQLContext.newContext().build()
}

try {
Expand Down Expand Up @@ -256,6 +284,7 @@ class FunctionDataFetcherTest {
val mockEnvironment: DataFetchingEnvironment = mockk {
every { arguments } returns mapOf("myCustomArgument" to mapOf("jacksonField" to "foo"))
every { containsArgument("myCustomArgument") } returns true
every { graphQlContext } returns GraphQLContext.newContext().build()
}
assertEquals(expected = "You sent foo", actual = dataFetcher.get(mockEnvironment))
}
Expand All @@ -266,6 +295,7 @@ class FunctionDataFetcherTest {
val mockEnvironment: DataFetchingEnvironment = mockk {
every { arguments } returns mapOf("input" to "hello")
every { containsArgument("input") } returns true
every { graphQlContext } returns GraphQLContext.newContext().build()
}
assertEquals(expected = "input was hello", actual = dataFetcher.get(mockEnvironment))
}
Expand All @@ -276,6 +306,7 @@ class FunctionDataFetcherTest {
val mockEnvironment: DataFetchingEnvironment = mockk {
every { arguments } returns mapOf("input" to null)
every { containsArgument("input") } returns true
every { graphQlContext } returns GraphQLContext.newContext().build()
}
assertEquals(expected = "input was null", actual = dataFetcher.get(mockEnvironment))
}
Expand All @@ -286,6 +317,7 @@ class FunctionDataFetcherTest {
val mockEnvironment: DataFetchingEnvironment = mockk {
every { arguments } returns emptyMap()
every { containsArgument(any()) } returns false
every { graphQlContext } returns GraphQLContext.newContext().build()
}
assertEquals(expected = "input was UNDEFINED", actual = dataFetcher.get(mockEnvironment))
}
Expand All @@ -296,6 +328,7 @@ class FunctionDataFetcherTest {
val mockEnvironment: DataFetchingEnvironment = mockk {
every { arguments } returns mapOf("input" to listOf(linkedMapOf("jacksonField" to "foo")))
every { containsArgument("input") } returns true
every { graphQlContext } returns GraphQLContext.newContext().build()
}
val result = dataFetcher.get(mockEnvironment)
assertEquals(expected = "first input was foo", actual = result)
Expand All @@ -307,6 +340,7 @@ class FunctionDataFetcherTest {
val mockEnvironment: DataFetchingEnvironment = mockk {
every { arguments } returns mapOf("input" to mapOf("required" to "hello", "optional" to "hello"))
every { containsArgument("input") } returns true
every { graphQlContext } returns GraphQLContext.newContext().build()
}
assertEquals(expected = "optional was hello", actual = dataFetcher.get(mockEnvironment))
}
Expand All @@ -317,7 +351,59 @@ class FunctionDataFetcherTest {
val mockEnvironment = mockk<DataFetchingEnvironment> {
every { arguments } returns mapOf("input" to mapOf("required" to "hello"))
every { containsArgument("input") } returns true
every { graphQlContext } returns GraphQLContext.newContext().build()
}
assertEquals(expected = "optional was UNDEFINED", actual = dataFetcher.get(mockEnvironment))
}

@Test
fun `use default scope for sync function`() {
val dataFetcher = FunctionDataFetcher(target = MyClass(), fn = MyClass::threadNameSync)
val mockEnvironment = mockk<DataFetchingEnvironment> {
every { arguments } returns mapOf("input" to mapOf("required" to "hello"))
every { containsArgument("input") } returns true
every { graphQlContext } returns GraphQLContext.newContext().build()
}
val result = dataFetcher.get(mockEnvironment) as String
assertContains(charSequence = result, other = "Test worker")
}

@Test
fun `use provided scope for sync function`() = runBlocking {
val dataFetcher = FunctionDataFetcher(target = MyClass(), fn = MyClass::threadNameSync)
val scope = CoroutineScope(customCoroutineDispatcher)
val mockEnvironment = mockk<DataFetchingEnvironment> {
every { arguments } returns mapOf("input" to mapOf("required" to "hello"))
every { containsArgument("input") } returns true
every { graphQlContext } returns GraphQLContext.newContext().put(CoroutineScope::class, scope).build()
}
val result = dataFetcher.get(mockEnvironment) as String
assertContains(charSequence = result, other = "custom-thread-1")
}


@Test
fun `use default scope for async function`() {
val dataFetcher = FunctionDataFetcher(target = MyClass(), fn = MyClass::threadNameAsync)
val mockEnvironment = mockk<DataFetchingEnvironment> {
every { arguments } returns mapOf("input" to mapOf("required" to "hello"))
every { containsArgument("input") } returns true
every { graphQlContext } returns GraphQLContext.newContext().build()
}
val result = dataFetcher.get(mockEnvironment) as String
assertContains(charSequence = result, other = "Test worker")
}

@Test
fun `use provided scope for async function`() = runBlocking {
val dataFetcher = FunctionDataFetcher(target = MyClass(), fn = MyClass::threadNameAsync)
val scope = CoroutineScope(customCoroutineDispatcher)
val mockEnvironment = mockk<DataFetchingEnvironment> {
every { arguments } returns mapOf("input" to mapOf("required" to "hello"))
every { containsArgument("input") } returns true
every { graphQlContext } returns GraphQLContext.newContext().put(CoroutineScope::class, scope).build()
}
val result = dataFetcher.get(mockEnvironment) as String
assertContains(charSequence = result, other = "custom-thread-1")
}
}