Skip to content

Commit 17d0bdb

Browse files
committed
Support coroutine tests in Burst
We were previously broken with @burst tests that called runTest, when executing on Kotlin/JS.
1 parent 6c54dec commit 17d0bdb

File tree

13 files changed

+336
-202
lines changed

13 files changed

+336
-202
lines changed

burst-gradle-plugin/src/test/kotlin/app/cash/burst/gradle/BurstGradlePluginTest.kt

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package app.cash.burst.gradle
1818

1919
import assertk.assertThat
20+
import assertk.assertions.contains
2021
import assertk.assertions.containsExactlyInAnyOrder
2122
import assertk.assertions.isEqualTo
2223
import assertk.assertions.isFalse
@@ -69,13 +70,29 @@ class BurstGradlePluginTest {
6970
// Each test class is executed normally with nothing skipped.
7071
with(tester.readTestSuite("CoffeeTest_Regular", testTaskName)) {
7172
assertThat(testCases.map { it.name }).containsExactlyInAnyOrder(
72-
"test_Milk[$platformName]",
73-
"test_None[$platformName]",
74-
"test_Oat[$platformName]",
73+
"basicTest_Milk[$platformName]",
74+
"basicTest_None[$platformName]",
75+
"basicTest_Oat[$platformName]",
76+
"coroutinesTest_Milk[$platformName]",
77+
"coroutinesTest_None[$platformName]",
78+
"coroutinesTest_Oat[$platformName]",
7579
)
7680

77-
val sampleSpecialization = testCases.single { it.name == "test_Milk[$platformName]" }
81+
val sampleSpecialization = testCases.single { it.name == "basicTest_Milk[$platformName]" }
7882
assertThat(sampleSpecialization.skipped).isFalse()
83+
84+
assertThat(systemOut).contains(
85+
"""
86+
|set up Regular
87+
|running Regular Oat in coffeeCoroutine
88+
|
89+
""".trimMargin(),
90+
"""
91+
|set up Regular
92+
|running Regular Oat
93+
|
94+
""".trimMargin(),
95+
)
7996
}
8097

8198
if (checkKlibMetadata) {
@@ -86,10 +103,14 @@ class BurstGradlePluginTest {
86103
val coffeeTestMetadata = klibMetadata.classes.first { it.name == "CoffeeTest" }
87104
assertThat(coffeeTestMetadata.functions.map { it.name }).containsExactlyInAnyOrder(
88105
"setUp",
89-
"test",
90-
"test_Milk",
91-
"test_None",
92-
"test_Oat",
106+
"basicTest",
107+
"basicTest_Milk",
108+
"basicTest_None",
109+
"basicTest_Oat",
110+
"coroutinesTest",
111+
"coroutinesTest_Milk",
112+
"coroutinesTest_None",
113+
"coroutinesTest_Oat",
93114
)
94115
}
95116
}

burst-gradle-plugin/src/test/projects/multiplatform/lib/build.gradle.kts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ kotlin {
1717
commonTest {
1818
dependencies {
1919
implementation(kotlin("test"))
20+
implementation(libs.kotlinx.coroutines.core)
21+
implementation(libs.kotlinx.coroutines.test)
2022
}
2123
}
2224
}

burst-gradle-plugin/src/test/projects/multiplatform/lib/src/commonTest/kotlin/CoffeeTest.kt

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
import app.cash.burst.Burst
2+
import kotlin.coroutines.coroutineContext
23
import kotlin.test.BeforeTest
34
import kotlin.test.Test
5+
import kotlin.time.Duration.Companion.milliseconds
6+
import kotlinx.coroutines.CoroutineName
7+
import kotlinx.coroutines.async
8+
import kotlinx.coroutines.delay
9+
import kotlinx.coroutines.test.runTest
410

511
@Burst
612
class CoffeeTest(
@@ -12,9 +18,18 @@ class CoffeeTest(
1218
}
1319

1420
@Test
15-
fun test(dairy: Dairy) {
21+
fun basicTest(dairy: Dairy) {
1622
println("running $espresso $dairy")
1723
}
24+
25+
@Test
26+
fun coroutinesTest(dairy: Dairy) = runTest(CoroutineName("coffeeCoroutine")) {
27+
val deferred = async {
28+
println("running $espresso $dairy in ${coroutineContext[CoroutineName]?.name}")
29+
}
30+
delay(1000.milliseconds)
31+
deferred.await()
32+
}
1833
}
1934

2035
enum class Espresso { Decaf, Regular, Double }

burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/Argument.kt

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,6 @@ internal sealed interface Argument {
4343
/** True if this argument matches the default parameter value. */
4444
val isDefault: Boolean
4545

46-
/** Where to assign this argument to in a call. */
47-
val indexInParameters: Int
48-
4946
/** A string that's safe to use in a declaration name. */
5047
val name: String
5148

@@ -64,9 +61,6 @@ private class EnumValueArgument(
6461
) : Argument {
6562
override val name = value.name.identifier
6663

67-
override val indexInParameters: Int
68-
get() = original.indexInParameters
69-
7064
override fun expression() =
7165
IrGetEnumValueImpl(original.startOffset, original.endOffset, type, value.symbol)
7266

@@ -83,9 +77,6 @@ private class BooleanArgument(
8377
) : Argument {
8478
override val name = value.toString()
8579

86-
override val indexInParameters: Int
87-
get() = original.indexInParameters
88-
8980
override fun expression() =
9081
IrConstImpl.boolean(original.startOffset, original.endOffset, booleanType, value)
9182

@@ -101,9 +92,6 @@ private class NullArgument(
10192
) : Argument {
10293
override val name = "null"
10394

104-
override val indexInParameters: Int
105-
get() = original.indexInParameters
106-
10795
override fun expression() = IrConstImpl.constNull(original.startOffset, original.endOffset, type)
10896

10997
override fun <R, D> accept(visitor: IrVisitor<R, D>, data: D): R {
@@ -118,8 +106,6 @@ private class BurstValuesArgument(
118106
index: Int,
119107
) : Argument {
120108
override val isDefault = index == 0
121-
override val indexInParameters: Int
122-
get() = parameter.indexInParameters
123109
override val name = value.suggestedName() ?: index.toString()
124110

125111
override fun expression() = value.deepCopyWithSymbols(parameter.parent)

burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/BurstApis.kt

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*/
16+
@file:OptIn(UnsafeDuringIrConstructionAPI::class)
17+
1618
package app.cash.burst.kotlin
1719

1820
import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext
@@ -24,7 +26,9 @@ import org.jetbrains.kotlin.ir.symbols.IrClassSymbol
2426
import org.jetbrains.kotlin.ir.symbols.IrFunctionSymbol
2527
import org.jetbrains.kotlin.ir.symbols.IrPropertySymbol
2628
import org.jetbrains.kotlin.ir.symbols.IrSimpleFunctionSymbol
29+
import org.jetbrains.kotlin.ir.symbols.UnsafeDuringIrConstructionAPI
2730
import org.jetbrains.kotlin.ir.types.IrType
31+
import org.jetbrains.kotlin.ir.types.classFqName
2832
import org.jetbrains.kotlin.ir.types.classOrNull
2933
import org.jetbrains.kotlin.ir.types.defaultType
3034
import org.jetbrains.kotlin.ir.util.hasAnnotation
@@ -37,7 +41,8 @@ internal class BurstApis private constructor(
3741
private val testClassSymbols: List<IrClassSymbol>,
3842
val beforeTestSymbols: List<IrClassSymbol>,
3943
val afterTestSymbols: List<IrClassSymbol>,
40-
val runTestSymbols: List<IrFunctionSymbol>,
44+
/** Null if `kotlinx.coroutines.test` isn't in this build. */
45+
val runTestSymbol: IrFunctionSymbol?,
4146
) {
4247
val burstValues: IrFunctionSymbol = pluginContext.referenceFunctions(burstValuesId).single()
4348

@@ -46,6 +51,7 @@ internal class BurstApis private constructor(
4651
testInterceptorClassId = burstFqPackage.classId("TestInterceptor"),
4752
)!!
4853

54+
/** Null if `app.cash.burst.coroutines` isn't in this build. */
4955
val coroutinesTestInterceptorApis: TestInterceptorApis? = pluginContext.testInterceptorApis(
5056
testFunctionClassId = burstCoroutinesFqPackage.classId("CoroutineTestFunction"),
5157
testInterceptorClassId = burstCoroutinesFqPackage.classId("CoroutineTestInterceptor"),
@@ -75,13 +81,8 @@ internal class BurstApis private constructor(
7581
.firstOrNull { it in afterTestSymbols }
7682
}
7783

78-
fun isEitherTestInterceptor(property: IrProperty): Boolean {
79-
return testInterceptorApis.isTestInterceptor(property) ||
80-
coroutinesTestInterceptorApis?.isTestInterceptor(property) == true
81-
}
82-
8384
fun isRunTest(irCall: IrCall): Boolean {
84-
return irCall.symbol in runTestSymbols
85+
return runTestSymbol != null && irCall.symbol == runTestSymbol
8586
}
8687

8788
companion object {
@@ -114,14 +115,18 @@ internal class BurstApis private constructor(
114115
pluginContext.referenceClass(kotlinAfterTestClassId),
115116
)
116117

117-
val runTestSymbols = pluginContext.referenceFunctions(runTestId).toList()
118+
val runTestSymbol = pluginContext.referenceFunctions(runTestId).singleOrNull {
119+
it.owner.parameters.size == 3 &&
120+
it.owner.parameters[0].type.classFqName == coroutineContextId.asSingleFqName() &&
121+
it.owner.parameters[1].type.classFqName == durationId.asSingleFqName()
122+
}
118123

119124
return BurstApis(
120125
pluginContext = pluginContext,
121126
testClassSymbols = testClassSymbols,
122127
beforeTestSymbols = beforeTestSymbols,
123128
afterTestSymbols = afterTestSymbols,
124-
runTestSymbols = runTestSymbols,
129+
runTestSymbol = runTestSymbol,
125130
)
126131
}
127132
}
@@ -168,6 +173,12 @@ private fun IrPluginContext.testInterceptorApis(
168173
private val kotlinPackage = FqPackageName("kotlin")
169174
private val throwableAddSuppressedId = kotlinPackage.callableId("addSuppressed")
170175

176+
private val kotlinCoroutinePackage = FqPackageName("kotlin.coroutines")
177+
private val coroutineContextId = kotlinCoroutinePackage.callableId("CoroutineContext")
178+
179+
private val kotlinTimeFqPackage = FqPackageName("kotlin.time")
180+
private val durationId = kotlinTimeFqPackage.classId("Duration")
181+
171182
private val burstFqPackage = FqPackageName("app.cash.burst")
172183
private val burstCoroutinesFqPackage = FqPackageName("app.cash.burst.coroutines")
173184
private val burstAnnotationId = burstFqPackage.classId("Burst")

burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/BurstIrGenerationExtension.kt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ class BurstIrGenerationExtension(
3333
override fun generate(moduleFragment: IrModuleFragment, pluginContext: IrPluginContext) {
3434
// Skip the rewrite if the Burst APIs aren't loaded. We don't expect to find @Burst anywhere.
3535
val burstApis = BurstApis.maybeCreate(pluginContext) ?: return
36+
val testFunctionReader = TestFunctionReader(burstApis)
3637

3738
val transformer = object : IrElementTransformerVoidWithContext() {
3839
override fun visitClassNew(declaration: IrClass): IrStatement {
@@ -61,16 +62,15 @@ class BurstIrGenerationExtension(
6162
val originalFunctions = classDeclaration.functions.toList()
6263

6364
for (function in originalFunctions) {
64-
val testAnnotationClassSymbol = burstApis.findTestAnnotation(function) ?: continue
65+
val testFunction = testFunctionReader.readOrNull(function) ?: continue
6566
if (!classHasAtBurst && !function.hasAtBurst) continue
6667

6768
try {
6869
val specializer = FunctionSpecializer(
6970
pluginContext = pluginContext,
7071
burstApis = burstApis,
7172
originalParent = classDeclaration,
72-
original = function,
73-
testAnnotationClassSymbol = testAnnotationClassSymbol,
73+
original = testFunction,
7474
)
7575
specializer.generateSpecializations()
7676
} catch (e: BurstCompilationException) {

burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/ClassSpecializer.kt

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,9 @@ internal class ClassSpecializer(
148148
context = pluginContext,
149149
symbol = superConstructor.symbol,
150150
) {
151+
arguments.clear()
151152
for (argument in specialization.arguments) {
152-
arguments[argument.indexInParameters] = argument.expression()
153+
arguments += argument.expression()
153154
}
154155
}
155156
statements += irInstanceInitializerCall(
@@ -176,8 +177,9 @@ internal class ClassSpecializer(
176177
context = pluginContext,
177178
symbol = superConstructor.symbol,
178179
) {
180+
arguments.clear()
179181
for (argument in specialization.arguments) {
180-
arguments[argument.indexInParameters] = argument.expression()
182+
arguments += argument.expression()
181183
}
182184
}
183185
}

0 commit comments

Comments
 (0)