Skip to content

Commit fd960cd

Browse files
Copy the target test function, don't mutate it (#187)
If there are other references to the test function, they will be broken because we've changed the signature. Co-authored-by: Jesse Wilson <jwilson@squareup.com>
1 parent 4d9d0ad commit fd960cd

File tree

3 files changed

+184
-56
lines changed

3 files changed

+184
-56
lines changed

burst-kotlin-plugin-tests/src/test/kotlin/app/cash/burst/kotlin/BurstKotlinPluginTest.kt

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -945,6 +945,99 @@ class BurstKotlinPluginTest {
945945
)
946946
}
947947

948+
@Test
949+
fun coroutines() {
950+
val result = compile(
951+
sourceFile = SourceFile.kotlin(
952+
"CoffeeTest.kt",
953+
"""
954+
import app.cash.burst.Burst
955+
import app.cash.burst.burstValues
956+
import kotlin.test.Test
957+
import kotlin.time.Duration.Companion.milliseconds
958+
import kotlinx.coroutines.delay
959+
import kotlinx.coroutines.test.runTest
960+
961+
@Burst
962+
class CoffeeTest {
963+
val log = mutableListOf<String>()
964+
965+
@Test
966+
fun test(espresso: Espresso) = runTest {
967+
delay(1000.milliseconds)
968+
log += "running ${'$'}espresso"
969+
}
970+
}
971+
972+
enum class Espresso { Decaf, Regular, Double }
973+
""",
974+
),
975+
)
976+
assertEquals(KotlinCompilation.ExitCode.OK, result.exitCode, result.messages)
977+
978+
val baseClass = result.classLoader.loadClass("CoffeeTest")
979+
val baseInstance = baseClass.constructors.single().newInstance()
980+
val baseLog = baseClass.getMethod("getLog").invoke(baseInstance) as MutableList<*>
981+
982+
baseClass.getMethod("test_Decaf").invoke(baseInstance)
983+
assertThat(baseLog).containsExactly(
984+
"running Decaf",
985+
)
986+
}
987+
988+
/**
989+
* We had a bug where we changed the signatures of user-defined functions, which would cause
990+
* problems if those functions had other callsites.
991+
*/
992+
@Test
993+
fun coroutinesAndTestComposition() {
994+
val result = compile(
995+
sourceFile = SourceFile.kotlin(
996+
"CoffeeTest.kt",
997+
"""
998+
import app.cash.burst.Burst
999+
import app.cash.burst.burstValues
1000+
import kotlin.test.Test
1001+
import kotlin.time.Duration.Companion.milliseconds
1002+
import kotlinx.coroutines.delay
1003+
import kotlinx.coroutines.test.runTest
1004+
1005+
@Burst
1006+
abstract class CoffeeTest {
1007+
abstract val log: MutableList<String>
1008+
1009+
@Test
1010+
fun test(espresso: Espresso) = runTest {
1011+
delay(1000.milliseconds)
1012+
log += "running ${'$'}espresso"
1013+
}
1014+
}
1015+
1016+
class RealCoffeeTest : CoffeeTest() {
1017+
override val log = mutableListOf<String>()
1018+
1019+
@Test
1020+
fun anotherTest() = test(Espresso.Double)
1021+
}
1022+
1023+
enum class Espresso { Decaf, Regular, Double }
1024+
""",
1025+
),
1026+
)
1027+
assertEquals(KotlinCompilation.ExitCode.OK, result.exitCode, result.messages)
1028+
1029+
val baseClass = result.classLoader.loadClass("RealCoffeeTest")
1030+
val baseInstance = baseClass.constructors.single().newInstance()
1031+
val baseLog = baseClass.getMethod("getLog").invoke(baseInstance) as MutableList<*>
1032+
1033+
baseClass.getMethod("test_Decaf").invoke(baseInstance)
1034+
baseClass.getMethod("anotherTest").invoke(baseInstance)
1035+
assertThat(baseLog).containsExactly(
1036+
"running Decaf",
1037+
"running Double",
1038+
)
1039+
}
1040+
9481041
private val Class<*>.testSuffixes: List<String>
9491042
get() = methods.mapNotNull {
9501043
when {

burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/FunctionSpecializer.kt

Lines changed: 78 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ import org.jetbrains.kotlin.name.Name
6363
* Coroutines
6464
* ----------
6565
*
66-
* If the original test uses `runTest()` for coroutines, we do these transformations:
66+
* If the original test uses `runTest()` for coroutines, we copy it into a new function with these
67+
* transformations:
6768
*
6869
* 1. A `TestScope` parameter is added.
6970
* 2. A `suspend` modifier is added.
@@ -125,26 +126,15 @@ internal class FunctionSpecializer(
125126
return
126127
}
127128

128-
// If the function body starts with runTest(), remove that call and call its testBody directly.
129-
if (original is TestFunction.Suspending) {
130-
function.isSuspend = true
131-
function.addValueParameter {
132-
initDefaults(function)
133-
name = Name.identifier("testScope")
134-
type = burstApis.testScope!!
129+
val delegate = when (original) {
130+
is TestFunction.Suspending -> {
131+
createSuspendingOverload(original)
132+
.also {
133+
originalParent.addDeclaration(it)
134+
}
135135
}
136-
function.irFunctionBody(
137-
context = pluginContext,
138-
) {
139-
+irCall(
140-
callee = pluginContext.irBuiltIns.suspendFunctionN(1).symbol.functionByName("invoke"),
141-
).apply {
142-
arguments[0] = original.runTestCall.arguments[2]
143-
arguments[1] = irGet(function.parameters.last())
144-
type = pluginContext.irBuiltIns.unitType
145-
}
146-
}
147-
function.returnType = pluginContext.irBuiltIns.unitType
136+
137+
is TestFunction.NonSuspending -> original.function
148138
}
149139

150140
val specializations = specializations(pluginContext, burstApis, valueParameters)
@@ -155,6 +145,7 @@ internal class FunctionSpecializer(
155145
originalDispatchReceiver = originalDispatchReceiver,
156146
specialization = specialization,
157147
isDefaultSpecialization = index == indexOfDefaultSpecialization,
148+
delegate = delegate,
158149
)
159150
}
160151

@@ -165,22 +156,54 @@ internal class FunctionSpecializer(
165156
}
166157
}
167158

159+
/**
160+
* If the function body starts with `runTest()`, move its body to a new function that is
161+
* suspending and that accepts a `TestScope` parameter.
162+
*/
163+
private fun createSuspendingOverload(original: TestFunction.Suspending): IrSimpleFunction {
164+
val result = original.function.deepCopyWithSymbols(originalParent)
165+
val runTestCall = TestFunctionReader(burstApis).readRunTestCall(result)!!
166+
167+
result.isSuspend = true
168+
result.addValueParameter {
169+
initDefaults(result)
170+
name = Name.identifier("testScope")
171+
type = burstApis.testScope!!
172+
}
173+
174+
result.irFunctionBody(
175+
context = pluginContext,
176+
) {
177+
+irCall(
178+
callee = pluginContext.irBuiltIns.suspendFunctionN(1).symbol.functionByName("invoke"),
179+
).apply {
180+
arguments[0] = runTestCall.arguments[2]
181+
arguments[1] = irGet(result.parameters.last())
182+
type = pluginContext.irBuiltIns.unitType
183+
}
184+
}
185+
result.returnType = pluginContext.irBuiltIns.unitType
186+
187+
result.patchDeclarationParents()
188+
return result
189+
}
190+
168191
private fun createFunction(
169192
originalDispatchReceiver: IrValueParameter,
170193
specialization: Specialization,
171194
isDefaultSpecialization: Boolean,
195+
delegate: IrSimpleFunction,
172196
): IrSimpleFunction {
173-
val function = original.function
174-
val result = function.factory.buildFun {
175-
initDefaults(function)
197+
val result = pluginContext.irFactory.buildFun {
198+
initDefaults(delegate)
176199
modality = Modality.FINAL
177200
name = when {
178-
isDefaultSpecialization -> function.name
179-
else -> Name.identifier("${function.name.identifier}_${specialization.name}")
201+
isDefaultSpecialization -> delegate.name
202+
else -> Name.identifier("${delegate.name.identifier}_${specialization.name}")
180203
}
181204
returnType = when {
182205
original is TestFunction.Suspending -> burstApis.runTestSymbol!!.owner.returnType
183-
else -> function.returnType
206+
else -> delegate.returnType
184207
}
185208
}.apply {
186209
parameters += buildReceiverParameter {
@@ -212,8 +235,8 @@ internal class FunctionSpecializer(
212235
}
213236
}
214237

215-
val callOriginal = irCall(
216-
callee = function.symbol,
238+
val callDelegate = irCall(
239+
callee = delegate.symbol,
217240
).apply {
218241
arguments.clear()
219242
arguments += irGet(receiverLocal)
@@ -222,28 +245,34 @@ internal class FunctionSpecializer(
222245
}
223246
}
224247

225-
if (original is TestFunction.Suspending) {
226-
// Call runTest() with the original's arguments, but this specialization's body.
227-
+irReturn(
228-
irCall(
229-
callee = burstApis.runTestSymbol!!,
248+
when (original) {
249+
// Call runTest() with the original's runTest() arguments. The test body calls the delegate.
250+
is TestFunction.Suspending -> {
251+
+irReturn(
252+
irCall(
253+
callee = burstApis.runTestSymbol!!,
254+
).apply {
255+
arguments.clear()
256+
// TODO: patch these arguments with the specialized arguments.
257+
arguments += original.runTestCall.arguments[0]?.deepCopyWithSymbols(result)
258+
arguments += original.runTestCall.arguments[1]?.deepCopyWithSymbols(result)
259+
arguments += irTestBodyLambda(
260+
context = pluginContext,
261+
burstApis = burstApis,
262+
original = originalParent,
263+
) { testScope ->
264+
callDelegate.arguments += irGet(testScope)
265+
+callDelegate
266+
}
267+
},
230268
).apply {
231-
arguments.clear()
232-
// TODO: patch these arguments with the specialized arguments.
233-
arguments += original.runTestCall.arguments[0]?.deepCopyWithSymbols(result)
234-
arguments += original.runTestCall.arguments[1]?.deepCopyWithSymbols(result)
235-
arguments += irTestBodyLambda(
236-
context = pluginContext,
237-
burstApis = burstApis,
238-
original = originalParent,
239-
) { testScope ->
240-
callOriginal.arguments += irGet(testScope)
241-
+callOriginal
242-
}
243-
},
244-
)
245-
} else {
246-
+callOriginal
269+
type = burstApis.runTestSymbol.owner.returnType
270+
}
271+
}
272+
273+
is TestFunction.NonSuspending -> {
274+
+callDelegate
275+
}
247276
}
248277
}
249278

burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/TestFunction.kt

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -54,26 +54,32 @@ internal class TestFunctionReader(
5454
/** Returns non-null if [function] is annotated `@Test`. */
5555
fun readOrNull(function: IrSimpleFunction): TestFunction? {
5656
val testAnnotation = burstApis.findTestAnnotation(function) ?: return null
57-
var runTestCall: IrCall? = null
57+
val runTestCall = readRunTestCall(function)
58+
59+
return when {
60+
runTestCall != null -> TestFunction.Suspending(function, testAnnotation, runTestCall)
61+
else -> TestFunction.NonSuspending(function, testAnnotation)
62+
}
63+
}
64+
65+
fun readRunTestCall(function: IrSimpleFunction): IrCall? {
66+
var result: IrCall? = null
5867

5968
function.body?.transform(
6069
object : IrTransformer<Unit>() {
6170
override fun visitCall(
6271
expression: IrCall,
6372
data: Unit,
6473
): IrElement {
65-
if (runTestCall == null && burstApis.isRunTest(expression)) {
66-
runTestCall = expression
74+
if (result == null && burstApis.isRunTest(expression)) {
75+
result = expression
6776
}
6877
return super.visitCall(expression, data)
6978
}
7079
},
7180
Unit,
7281
)
7382

74-
return when {
75-
runTestCall != null -> TestFunction.Suspending(function, testAnnotation, runTestCall)
76-
else -> TestFunction.NonSuspending(function, testAnnotation)
77-
}
83+
return result
7884
}
7985
}

0 commit comments

Comments
 (0)