Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -945,6 +945,99 @@ class BurstKotlinPluginTest {
)
}

@Test
fun coroutines() {
val result = compile(
sourceFile = SourceFile.kotlin(
"CoffeeTest.kt",
"""
import app.cash.burst.Burst
import app.cash.burst.burstValues
import kotlin.test.Test
import kotlin.time.Duration.Companion.milliseconds
import kotlinx.coroutines.delay
import kotlinx.coroutines.test.runTest

@Burst
class CoffeeTest {
val log = mutableListOf<String>()

@Test
fun test(espresso: Espresso) = runTest {
delay(1000.milliseconds)
log += "running ${'$'}espresso"
}
}

enum class Espresso { Decaf, Regular, Double }
""",
),
)
assertEquals(KotlinCompilation.ExitCode.OK, result.exitCode, result.messages)

val baseClass = result.classLoader.loadClass("CoffeeTest")
val baseInstance = baseClass.constructors.single().newInstance()
val baseLog = baseClass.getMethod("getLog").invoke(baseInstance) as MutableList<*>

baseClass.getMethod("test_Decaf").invoke(baseInstance)
assertThat(baseLog).containsExactly(
"running Decaf",
)
}

/**
* We had a bug where we changed the signatures of user-defined functions, which would cause
* problems if those functions had other callsites.
*/
@Test
fun coroutinesAndTestComposition() {
val result = compile(
sourceFile = SourceFile.kotlin(
"CoffeeTest.kt",
"""
import app.cash.burst.Burst
import app.cash.burst.burstValues
import kotlin.test.Test
import kotlin.time.Duration.Companion.milliseconds
import kotlinx.coroutines.delay
import kotlinx.coroutines.test.runTest

@Burst
abstract class CoffeeTest {
abstract val log: MutableList<String>

@Test
fun test(espresso: Espresso) = runTest {
delay(1000.milliseconds)
log += "running ${'$'}espresso"
}
}

class RealCoffeeTest : CoffeeTest() {
override val log = mutableListOf<String>()

@Test
fun anotherTest() = test(Espresso.Double)
}

enum class Espresso { Decaf, Regular, Double }
""",
),
)
assertEquals(KotlinCompilation.ExitCode.OK, result.exitCode, result.messages)

val baseClass = result.classLoader.loadClass("RealCoffeeTest")
val baseInstance = baseClass.constructors.single().newInstance()
val baseLog = baseClass.getMethod("getLog").invoke(baseInstance) as MutableList<*>

baseClass.getMethod("test_Decaf").invoke(baseInstance)
baseClass.getMethod("anotherTest").invoke(baseInstance)
assertThat(baseLog).containsExactly(
"running Decaf",
"running Double",
)
}

private val Class<*>.testSuffixes: List<String>
get() = methods.mapNotNull {
when {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ import org.jetbrains.kotlin.name.Name
* Coroutines
* ----------
*
* If the original test uses `runTest()` for coroutines, we do these transformations:
* If the original test uses `runTest()` for coroutines, we copy it into a new function with these
* transformations:
*
* 1. A `TestScope` parameter is added.
* 2. A `suspend` modifier is added.
Expand Down Expand Up @@ -125,26 +126,15 @@ internal class FunctionSpecializer(
return
}

// If the function body starts with runTest(), remove that call and call its testBody directly.
if (original is TestFunction.Suspending) {
function.isSuspend = true
function.addValueParameter {
initDefaults(function)
name = Name.identifier("testScope")
type = burstApis.testScope!!
val delegate = when (original) {
is TestFunction.Suspending -> {
createSuspendingOverload(original)
.also {
originalParent.addDeclaration(it)
}
}
function.irFunctionBody(
context = pluginContext,
) {
+irCall(
callee = pluginContext.irBuiltIns.suspendFunctionN(1).symbol.functionByName("invoke"),
).apply {
arguments[0] = original.runTestCall.arguments[2]
arguments[1] = irGet(function.parameters.last())
type = pluginContext.irBuiltIns.unitType
}
}
function.returnType = pluginContext.irBuiltIns.unitType

is TestFunction.NonSuspending -> original.function
}

val specializations = specializations(pluginContext, burstApis, valueParameters)
Expand All @@ -155,6 +145,7 @@ internal class FunctionSpecializer(
originalDispatchReceiver = originalDispatchReceiver,
specialization = specialization,
isDefaultSpecialization = index == indexOfDefaultSpecialization,
delegate = delegate,
)
}

Expand All @@ -165,22 +156,54 @@ internal class FunctionSpecializer(
}
}

/**
* If the function body starts with `runTest()`, move its body to a new function that is
* suspending and that accepts a `TestScope` parameter.
*/
private fun createSuspendingOverload(original: TestFunction.Suspending): IrSimpleFunction {
val result = original.function.deepCopyWithSymbols(originalParent)
val runTestCall = TestFunctionReader(burstApis).readRunTestCall(result)!!

result.isSuspend = true
result.addValueParameter {
initDefaults(result)
name = Name.identifier("testScope")
type = burstApis.testScope!!
}

result.irFunctionBody(
context = pluginContext,
) {
+irCall(
callee = pluginContext.irBuiltIns.suspendFunctionN(1).symbol.functionByName("invoke"),
).apply {
arguments[0] = runTestCall.arguments[2]
arguments[1] = irGet(result.parameters.last())
type = pluginContext.irBuiltIns.unitType
}
}
result.returnType = pluginContext.irBuiltIns.unitType

result.patchDeclarationParents()
return result
}

private fun createFunction(
originalDispatchReceiver: IrValueParameter,
specialization: Specialization,
isDefaultSpecialization: Boolean,
delegate: IrSimpleFunction,
): IrSimpleFunction {
val function = original.function
val result = function.factory.buildFun {
initDefaults(function)
val result = pluginContext.irFactory.buildFun {
initDefaults(delegate)
modality = Modality.FINAL
name = when {
isDefaultSpecialization -> function.name
else -> Name.identifier("${function.name.identifier}_${specialization.name}")
isDefaultSpecialization -> delegate.name
else -> Name.identifier("${delegate.name.identifier}_${specialization.name}")
}
returnType = when {
original is TestFunction.Suspending -> burstApis.runTestSymbol!!.owner.returnType
else -> function.returnType
else -> delegate.returnType
}
}.apply {
parameters += buildReceiverParameter {
Expand Down Expand Up @@ -212,8 +235,8 @@ internal class FunctionSpecializer(
}
}

val callOriginal = irCall(
callee = function.symbol,
val callDelegate = irCall(
callee = delegate.symbol,
).apply {
arguments.clear()
arguments += irGet(receiverLocal)
Expand All @@ -222,28 +245,34 @@ internal class FunctionSpecializer(
}
}

if (original is TestFunction.Suspending) {
// Call runTest() with the original's arguments, but this specialization's body.
+irReturn(
irCall(
callee = burstApis.runTestSymbol!!,
when (original) {
// Call runTest() with the original's runTest() arguments. The test body calls the delegate.
is TestFunction.Suspending -> {
+irReturn(
irCall(
callee = burstApis.runTestSymbol!!,
).apply {
arguments.clear()
// TODO: patch these arguments with the specialized arguments.
arguments += original.runTestCall.arguments[0]?.deepCopyWithSymbols(result)
arguments += original.runTestCall.arguments[1]?.deepCopyWithSymbols(result)
arguments += irTestBodyLambda(
context = pluginContext,
burstApis = burstApis,
original = originalParent,
) { testScope ->
callDelegate.arguments += irGet(testScope)
+callDelegate
}
},
).apply {
arguments.clear()
// TODO: patch these arguments with the specialized arguments.
arguments += original.runTestCall.arguments[0]?.deepCopyWithSymbols(result)
arguments += original.runTestCall.arguments[1]?.deepCopyWithSymbols(result)
arguments += irTestBodyLambda(
context = pluginContext,
burstApis = burstApis,
original = originalParent,
) { testScope ->
callOriginal.arguments += irGet(testScope)
+callOriginal
}
},
)
} else {
+callOriginal
type = burstApis.runTestSymbol.owner.returnType
}
}

is TestFunction.NonSuspending -> {
+callDelegate
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,26 +54,32 @@ internal class TestFunctionReader(
/** Returns non-null if [function] is annotated `@Test`. */
fun readOrNull(function: IrSimpleFunction): TestFunction? {
val testAnnotation = burstApis.findTestAnnotation(function) ?: return null
var runTestCall: IrCall? = null
val runTestCall = readRunTestCall(function)

return when {
runTestCall != null -> TestFunction.Suspending(function, testAnnotation, runTestCall)
else -> TestFunction.NonSuspending(function, testAnnotation)
}
}

fun readRunTestCall(function: IrSimpleFunction): IrCall? {
var result: IrCall? = null

function.body?.transform(
object : IrTransformer<Unit>() {
override fun visitCall(
expression: IrCall,
data: Unit,
): IrElement {
if (runTestCall == null && burstApis.isRunTest(expression)) {
runTestCall = expression
if (result == null && burstApis.isRunTest(expression)) {
result = expression
}
return super.visitCall(expression, data)
}
},
Unit,
)

return when {
runTestCall != null -> TestFunction.Suspending(function, testAnnotation, runTestCall)
else -> TestFunction.NonSuspending(function, testAnnotation)
}
return result
}
}