diff --git a/generator/graphql-kotlin-schema-generator/src/main/kotlin/com/expediagroup/graphql/generator/execution/FunctionDataFetcher.kt b/generator/graphql-kotlin-schema-generator/src/main/kotlin/com/expediagroup/graphql/generator/execution/FunctionDataFetcher.kt index 645ea39718..173d6b469e 100644 --- a/generator/graphql-kotlin-schema-generator/src/main/kotlin/com/expediagroup/graphql/generator/execution/FunctionDataFetcher.kt +++ b/generator/graphql-kotlin-schema-generator/src/main/kotlin/com/expediagroup/graphql/generator/execution/FunctionDataFetcher.kt @@ -24,6 +24,7 @@ import graphql.schema.DataFetcher import graphql.schema.DataFetchingEnvironment import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.future.future +import kotlinx.coroutines.runBlocking import java.lang.reflect.InvocationTargetException import java.util.concurrent.CompletableFuture import kotlin.coroutines.EmptyCoroutineContext @@ -60,7 +61,7 @@ open class FunctionDataFetcher( if (fn.isSuspend) { runSuspendingFunction(environment, parameterValues) } else { - runBlockingFunction(parameterValues) + runBlockingFunction(environment, parameterValues) } } else { null @@ -123,8 +124,14 @@ open class FunctionDataFetcher( * Once all parameters values are properly converted, this function will be called to run a simple blocking function. * If you need to override the exception handling you can override this method. */ - protected open fun runBlockingFunction(parameterValues: Map): Any? = try { - fn.callBy(parameterValues) + protected open fun runBlockingFunction( + environment: DataFetchingEnvironment, + parameterValues: Map + ): Any? = try { + val coroutineScope = environment.graphQlContext.getOrDefault(CoroutineScope(EmptyCoroutineContext)) + runBlocking(coroutineScope.coroutineContext) { + fn.callBy(parameterValues) + } } catch (exception: InvocationTargetException) { throw exception.cause ?: exception } diff --git a/generator/graphql-kotlin-schema-generator/src/test/kotlin/com/expediagroup/graphql/generator/execution/FunctionDataFetcherTest.kt b/generator/graphql-kotlin-schema-generator/src/test/kotlin/com/expediagroup/graphql/generator/execution/FunctionDataFetcherTest.kt index 1907788742..91f8c5364f 100644 --- a/generator/graphql-kotlin-schema-generator/src/test/kotlin/com/expediagroup/graphql/generator/execution/FunctionDataFetcherTest.kt +++ b/generator/graphql-kotlin-schema-generator/src/test/kotlin/com/expediagroup/graphql/generator/execution/FunctionDataFetcherTest.kt @@ -22,14 +22,15 @@ import graphql.GraphQLException import graphql.schema.DataFetchingEnvironment import io.mockk.every import io.mockk.mockk +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.asCoroutineDispatcher import kotlinx.coroutines.coroutineScope +import kotlinx.coroutines.runBlocking import org.junit.jupiter.api.Test import java.util.concurrent.CompletableFuture -import kotlin.test.assertEquals -import kotlin.test.assertFalse -import kotlin.test.assertNotNull -import kotlin.test.assertNull -import kotlin.test.assertTrue +import java.util.concurrent.Executors +import java.util.concurrent.ThreadFactory +import kotlin.test.* class FunctionDataFetcherTest { @@ -37,6 +38,15 @@ class FunctionDataFetcherTest { fun print(string: String): String } + val threadFactory = object: ThreadFactory { + override fun newThread(r: Runnable): Thread? { + val thread = Thread(r) + thread.name = "custom-thread-1" + return thread + } + } + val customCoroutineDispatcher = Executors.newSingleThreadExecutor(threadFactory).asCoroutineDispatcher() + class MyClass : MyInterface { override fun print(string: String) = string @@ -52,7 +62,9 @@ class FunctionDataFetcherTest { string } - fun throwException() { throw GraphQLException("Test Exception") } + fun throwException() { + throw GraphQLException("Test Exception") + } suspend fun suspendThrow(): String = coroutineScope { throw GraphQLException("Suspended Exception") @@ -75,14 +87,21 @@ class FunctionDataFetcherTest { is OptionalInput.Undefined -> "optional was UNDEFINED" is OptionalInput.Defined -> "optional was ${input.optional.value}" } + + fun threadNameSync(): String { + return Thread.currentThread().name + } + + fun threadNameAsync(): String { + return Thread.currentThread().name + } } data class InputWrapper(val required: String, val optional: OptionalInput) @GraphQLName("MyInputClassRenamed") data class MyInputClass( - @GraphQLName("jacksonField") - val field1: String + @GraphQLName("jacksonField") val field1: String ) @Test @@ -101,6 +120,7 @@ class FunctionDataFetcherTest { every { getSource() } returns MyClass() every { arguments } returns mapOf("string" to "hello") every { containsArgument("string") } returns true + every { graphQlContext } returns GraphQLContext.newContext().build() } assertEquals(expected = "hello", actual = dataFetcher.get(mockEnvironment)) } @@ -111,6 +131,7 @@ class FunctionDataFetcherTest { val mockEnvironment: DataFetchingEnvironment = mockk { every { arguments } returns mapOf("string" to "hello") every { containsArgument("string") } returns true + every { graphQlContext } returns GraphQLContext.newContext().build() } assertEquals(expected = "hello", actual = dataFetcher.get(mockEnvironment)) } @@ -137,6 +158,7 @@ class FunctionDataFetcherTest { every { getSource() } returns MyClass() every { arguments } returns mapOf("string" to "hello") every { containsArgument("string") } returns true + every { graphQlContext } returns GraphQLContext.newContext().build() } assertEquals(expected = "hello", actual = dataFetcher.get(mockEnvironment)) } @@ -148,6 +170,7 @@ class FunctionDataFetcherTest { every { getSource() } returns MyClass() every { arguments } returns emptyMap() every { containsArgument(any()) } returns false + every { graphQlContext } returns GraphQLContext.newContext().build() } assertEquals(expected = "hello", actual = dataFetcher.get(mockEnvironment)) } @@ -159,6 +182,7 @@ class FunctionDataFetcherTest { every { getSource() } returns MyClass() every { arguments } returns mapOf("string" to "foo") every { containsArgument("string") } returns true + every { graphQlContext } returns GraphQLContext.newContext().build() } assertEquals(expected = "foo", actual = dataFetcher.get(mockEnvironment)) } @@ -170,6 +194,7 @@ class FunctionDataFetcherTest { every { getSource() } returns MyClass() every { arguments } returns mapOf("string" to null) every { containsArgument("string") } returns true + every { graphQlContext } returns GraphQLContext.newContext().build() } assertNull(dataFetcher.get(mockEnvironment)) } @@ -180,6 +205,7 @@ class FunctionDataFetcherTest { val mockEnvironment: DataFetchingEnvironment = mockk { every { arguments } returns mapOf("items" to listOf("foo", "bar")) every { containsArgument("items") } returns true + every { graphQlContext } returns GraphQLContext.newContext().build() } assertEquals(expected = "foo:bar", actual = dataFetcher.get(mockEnvironment)) @@ -194,6 +220,7 @@ class FunctionDataFetcherTest { every { field } returns mockk { every { name } returns "fooBarBaz" } + every { graphQlContext } returns GraphQLContext.newContext().build() } assertEquals(expected = "fooBarBaz", actual = dataFetcher.get(mockEnvironment)) } @@ -219,6 +246,7 @@ class FunctionDataFetcherTest { val mockEnvironment: DataFetchingEnvironment = mockk { every { arguments } returns emptyMap() every { containsArgument(any()) } returns false + every { graphQlContext } returns GraphQLContext.newContext().build() } try { @@ -256,6 +284,7 @@ class FunctionDataFetcherTest { val mockEnvironment: DataFetchingEnvironment = mockk { every { arguments } returns mapOf("myCustomArgument" to mapOf("jacksonField" to "foo")) every { containsArgument("myCustomArgument") } returns true + every { graphQlContext } returns GraphQLContext.newContext().build() } assertEquals(expected = "You sent foo", actual = dataFetcher.get(mockEnvironment)) } @@ -266,6 +295,7 @@ class FunctionDataFetcherTest { val mockEnvironment: DataFetchingEnvironment = mockk { every { arguments } returns mapOf("input" to "hello") every { containsArgument("input") } returns true + every { graphQlContext } returns GraphQLContext.newContext().build() } assertEquals(expected = "input was hello", actual = dataFetcher.get(mockEnvironment)) } @@ -276,6 +306,7 @@ class FunctionDataFetcherTest { val mockEnvironment: DataFetchingEnvironment = mockk { every { arguments } returns mapOf("input" to null) every { containsArgument("input") } returns true + every { graphQlContext } returns GraphQLContext.newContext().build() } assertEquals(expected = "input was null", actual = dataFetcher.get(mockEnvironment)) } @@ -286,6 +317,7 @@ class FunctionDataFetcherTest { val mockEnvironment: DataFetchingEnvironment = mockk { every { arguments } returns emptyMap() every { containsArgument(any()) } returns false + every { graphQlContext } returns GraphQLContext.newContext().build() } assertEquals(expected = "input was UNDEFINED", actual = dataFetcher.get(mockEnvironment)) } @@ -296,6 +328,7 @@ class FunctionDataFetcherTest { val mockEnvironment: DataFetchingEnvironment = mockk { every { arguments } returns mapOf("input" to listOf(linkedMapOf("jacksonField" to "foo"))) every { containsArgument("input") } returns true + every { graphQlContext } returns GraphQLContext.newContext().build() } val result = dataFetcher.get(mockEnvironment) assertEquals(expected = "first input was foo", actual = result) @@ -307,6 +340,7 @@ class FunctionDataFetcherTest { val mockEnvironment: DataFetchingEnvironment = mockk { every { arguments } returns mapOf("input" to mapOf("required" to "hello", "optional" to "hello")) every { containsArgument("input") } returns true + every { graphQlContext } returns GraphQLContext.newContext().build() } assertEquals(expected = "optional was hello", actual = dataFetcher.get(mockEnvironment)) } @@ -317,7 +351,59 @@ class FunctionDataFetcherTest { val mockEnvironment = mockk { every { arguments } returns mapOf("input" to mapOf("required" to "hello")) every { containsArgument("input") } returns true + every { graphQlContext } returns GraphQLContext.newContext().build() } assertEquals(expected = "optional was UNDEFINED", actual = dataFetcher.get(mockEnvironment)) } + + @Test + fun `use default scope for sync function`() { + val dataFetcher = FunctionDataFetcher(target = MyClass(), fn = MyClass::threadNameSync) + val mockEnvironment = mockk { + every { arguments } returns mapOf("input" to mapOf("required" to "hello")) + every { containsArgument("input") } returns true + every { graphQlContext } returns GraphQLContext.newContext().build() + } + val result = dataFetcher.get(mockEnvironment) as String + assertContains(charSequence = result, other = "Test worker") + } + + @Test + fun `use provided scope for sync function`() = runBlocking { + val dataFetcher = FunctionDataFetcher(target = MyClass(), fn = MyClass::threadNameSync) + val scope = CoroutineScope(customCoroutineDispatcher) + val mockEnvironment = mockk { + every { arguments } returns mapOf("input" to mapOf("required" to "hello")) + every { containsArgument("input") } returns true + every { graphQlContext } returns GraphQLContext.newContext().put(CoroutineScope::class, scope).build() + } + val result = dataFetcher.get(mockEnvironment) as String + assertContains(charSequence = result, other = "custom-thread-1") + } + + + @Test + fun `use default scope for async function`() { + val dataFetcher = FunctionDataFetcher(target = MyClass(), fn = MyClass::threadNameAsync) + val mockEnvironment = mockk { + every { arguments } returns mapOf("input" to mapOf("required" to "hello")) + every { containsArgument("input") } returns true + every { graphQlContext } returns GraphQLContext.newContext().build() + } + val result = dataFetcher.get(mockEnvironment) as String + assertContains(charSequence = result, other = "Test worker") + } + + @Test + fun `use provided scope for async function`() = runBlocking { + val dataFetcher = FunctionDataFetcher(target = MyClass(), fn = MyClass::threadNameAsync) + val scope = CoroutineScope(customCoroutineDispatcher) + val mockEnvironment = mockk { + every { arguments } returns mapOf("input" to mapOf("required" to "hello")) + every { containsArgument("input") } returns true + every { graphQlContext } returns GraphQLContext.newContext().put(CoroutineScope::class, scope).build() + } + val result = dataFetcher.get(mockEnvironment) as String + assertContains(charSequence = result, other = "custom-thread-1") + } }