Skip to content

Commit cfbc3c5

Browse files
committed
Generate fields for markers in REPL so compiler plugin can extract the schema
1 parent 2fbf725 commit cfbc3c5

File tree

8 files changed

+160
-19
lines changed

8 files changed

+160
-19
lines changed

core/api/core.api

+3-2
Original file line numberDiff line numberDiff line change
@@ -5915,7 +5915,7 @@ public final class org/jetbrains/kotlinx/dataframe/impl/io/FastDoubleParser {
59155915

59165916
public final class org/jetbrains/kotlinx/dataframe/impl/schema/DataFrameSchemaImpl : org/jetbrains/kotlinx/dataframe/schema/DataFrameSchema {
59175917
public fun <init> (Ljava/util/Map;)V
5918-
public fun compare (Lorg/jetbrains/kotlinx/dataframe/schema/DataFrameSchema;)Lorg/jetbrains/kotlinx/dataframe/schema/CompareResult;
5918+
public fun compare (Lorg/jetbrains/kotlinx/dataframe/schema/DataFrameSchema;Z)Lorg/jetbrains/kotlinx/dataframe/schema/CompareResult;
59195919
public fun equals (Ljava/lang/Object;)Z
59205920
public fun getColumns ()Ljava/util/Map;
59215921
public fun hashCode ()I
@@ -6668,7 +6668,8 @@ public final class org/jetbrains/kotlinx/dataframe/schema/CompareResult$Companio
66686668
}
66696669

66706670
public abstract interface class org/jetbrains/kotlinx/dataframe/schema/DataFrameSchema {
6671-
public abstract fun compare (Lorg/jetbrains/kotlinx/dataframe/schema/DataFrameSchema;)Lorg/jetbrains/kotlinx/dataframe/schema/CompareResult;
6671+
public abstract fun compare (Lorg/jetbrains/kotlinx/dataframe/schema/DataFrameSchema;Z)Lorg/jetbrains/kotlinx/dataframe/schema/CompareResult;
6672+
public static synthetic fun compare$default (Lorg/jetbrains/kotlinx/dataframe/schema/DataFrameSchema;Lorg/jetbrains/kotlinx/dataframe/schema/DataFrameSchema;ZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/schema/CompareResult;
66726673
public abstract fun getColumns ()Ljava/util/Map;
66736674
}
66746675

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/codeGen/ReplCodeGeneratorImpl.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ internal class ReplCodeGeneratorImpl : ReplCodeGenerator {
8181
val result = generator.generate(
8282
schema = schema,
8383
name = name,
84-
fields = false,
84+
fields = true,
8585
extensionProperties = true,
8686
isOpen = isOpen,
8787
visibility = MarkerVisibility.IMPLICIT_PUBLIC,

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/schema/DataFrameSchemaImpl.kt

+4-4
Original file line numberDiff line numberDiff line change
@@ -7,23 +7,23 @@ import org.jetbrains.kotlinx.dataframe.schema.DataFrameSchema
77

88
public class DataFrameSchemaImpl(override val columns: Map<String, ColumnSchema>) : DataFrameSchema {
99

10-
override fun compare(other: DataFrameSchema): CompareResult {
10+
override fun compare(other: DataFrameSchema, strictlyEqualNestedSchemas: Boolean): CompareResult {
1111
require(other is DataFrameSchemaImpl)
1212
if (this === other) return CompareResult.Equals
1313
var result = CompareResult.Equals
1414
columns.forEach {
1515
val otherColumn = other.columns[it.key]
1616
if (otherColumn == null) {
17-
result = result.combine(CompareResult.IsDerived)
17+
result = result.combine(if (strictlyEqualNestedSchemas) CompareResult.None else CompareResult.IsDerived)
1818
} else {
19-
result = result.combine(it.value.compare(otherColumn))
19+
result = result.combine(it.value.compareStrictlyEqualNestedSchemas(otherColumn))
2020
}
2121
if (result == CompareResult.None) return CompareResult.None
2222
}
2323
other.columns.forEach {
2424
val thisField = columns[it.key]
2525
if (thisField == null) {
26-
result = result.combine(CompareResult.IsSuper)
26+
result = result.combine(if (strictlyEqualNestedSchemas) CompareResult.None else CompareResult.IsSuper)
2727
if (result == CompareResult.None) return CompareResult.None
2828
}
2929
}

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/schema/ColumnSchema.kt

+31-3
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ public abstract class ColumnSchema {
5656
override val type: KType get() = typeOf<AnyRow>()
5757

5858
public fun compare(other: Group): CompareResult = schema.compare(other.schema)
59+
60+
internal fun compareStrictlyEqualNestedSchemas(other: Group): CompareResult =
61+
schema.compare(other.schema, strictlyEqualNestedSchemas = true)
5962
}
6063

6164
public class Frame(
@@ -68,6 +71,12 @@ public abstract class ColumnSchema {
6871

6972
public fun compare(other: Frame): CompareResult =
7073
schema.compare(other.schema).combine(CompareResult.compareNullability(nullable, other.nullable))
74+
75+
internal fun compareStrictlyEqualNestedSchemas(other: Frame): CompareResult =
76+
schema.compare(
77+
other.schema,
78+
strictlyEqualNestedSchemas = true,
79+
).combine(CompareResult.compareNullability(nullable, other.nullable))
7180
}
7281

7382
/** Checks equality just on kind, type, or schema. */
@@ -83,13 +92,32 @@ public abstract class ColumnSchema {
8392
}
8493
}
8594

86-
public fun compare(other: ColumnSchema): CompareResult {
95+
public fun compare(other: ColumnSchema): CompareResult = compare(other, false)
96+
97+
internal fun compareStrictlyEqualNestedSchemas(other: ColumnSchema): CompareResult = compare(other, true)
98+
99+
private fun compare(other: ColumnSchema, strictlyEqualNestedSchemas: Boolean): CompareResult {
87100
if (kind != other.kind) return CompareResult.None
88101
if (this === other) return CompareResult.Equals
89102
return when (this) {
90103
is Value -> compare(other as Value)
91-
is Group -> compare(other as Group)
92-
is Frame -> compare(other as Frame)
104+
105+
is Group -> if (strictlyEqualNestedSchemas) {
106+
compareStrictlyEqualNestedSchemas(
107+
other as Group,
108+
)
109+
} else {
110+
compare(other as Group)
111+
}
112+
113+
is Frame -> if (strictlyEqualNestedSchemas) {
114+
compareStrictlyEqualNestedSchemas(
115+
other as Frame,
116+
)
117+
} else {
118+
compare(other as Frame)
119+
}
120+
93121
else -> throw NotImplementedError()
94122
}
95123
}

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/schema/DataFrameSchema.kt

+5-1
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,9 @@ public interface DataFrameSchema {
44

55
public val columns: Map<String, ColumnSchema>
66

7-
public fun compare(other: DataFrameSchema): CompareResult
7+
/**
8+
* By default generated markers for leafs aren't used as supertypes: @DataSchema(isOpen = false)
9+
* strictlyEqualNestedSchemas = true takes this into account for internal codegen logic
10+
*/
11+
public fun compare(other: DataFrameSchema, strictlyEqualNestedSchemas: Boolean = false): CompareResult
812
}

core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/codeGen/CodeGenerationTests.kt

+21-4
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,12 @@ class CodeGenerationTests : BaseTest() {
103103
val expectedDeclaration =
104104
"""
105105
@DataSchema
106-
interface $typeName { }
106+
interface $typeName {
107+
val age: Int
108+
val city: String?
109+
val name: String
110+
val weight: Int?
111+
}
107112
108113
""".trimIndent() + "\n" + expectedProperties(typeName, typeName)
109114

@@ -129,7 +134,12 @@ class CodeGenerationTests : BaseTest() {
129134
val expectedDeclaration =
130135
"""
131136
@DataSchema
132-
interface $typeName { }
137+
interface $typeName {
138+
val age: Int
139+
val city: String?
140+
val name: String
141+
val weight: Int?
142+
}
133143
134144
""".trimIndent() + "\n" + expectedProperties(typeName, typeName)
135145

@@ -149,7 +159,10 @@ class CodeGenerationTests : BaseTest() {
149159
val declaration1 =
150160
"""
151161
@DataSchema(isOpen = false)
152-
interface $type1 { }
162+
interface $type1 {
163+
val city: String?
164+
val name: String
165+
}
153166
154167
val $dfName<$type1>.city: $dataCol<$stringName?> @JvmName("${type1}_city") get() = this["city"] as $dataCol<$stringName?>
155168
val $dfRowName<$type1>.city: $stringName? @JvmName("${type1}_city") get() = this["city"] as $stringName?
@@ -161,7 +174,11 @@ class CodeGenerationTests : BaseTest() {
161174
val declaration2 =
162175
"""
163176
@DataSchema
164-
interface $type2 { }
177+
interface $type2 {
178+
val age: Int
179+
val nameAndCity: _DataFrameType1
180+
val weight: Int?
181+
}
165182
166183
val $dfName<$type2>.age: $dataCol<$intName> @JvmName("${type2}_age") get() = this["age"] as $dataCol<$intName>
167184
val $dfRowName<$type2>.age: $intName @JvmName("${type2}_age") get() = this["age"] as $intName

core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/codeGen/ReplCodeGenTests.kt

+86-4
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,17 @@ package org.jetbrains.kotlinx.dataframe.codeGen
22

33
import io.kotest.matchers.shouldBe
44
import io.kotest.matchers.string.shouldNotBeEmpty
5+
import org.jetbrains.kotlinx.dataframe.AnyRow
56
import org.jetbrains.kotlinx.dataframe.ColumnsScope
67
import org.jetbrains.kotlinx.dataframe.DataColumn
78
import org.jetbrains.kotlinx.dataframe.DataRow
89
import org.jetbrains.kotlinx.dataframe.annotations.DataSchema
10+
import org.jetbrains.kotlinx.dataframe.api.add
11+
import org.jetbrains.kotlinx.dataframe.api.asFrame
12+
import org.jetbrains.kotlinx.dataframe.api.convert
913
import org.jetbrains.kotlinx.dataframe.api.dataFrameOf
1014
import org.jetbrains.kotlinx.dataframe.api.filter
15+
import org.jetbrains.kotlinx.dataframe.api.first
1116
import org.jetbrains.kotlinx.dataframe.api.select
1217
import org.jetbrains.kotlinx.dataframe.impl.codeGen.ReplCodeGenerator
1318
import org.jetbrains.kotlinx.dataframe.impl.codeGen.ReplCodeGeneratorImpl
@@ -81,7 +86,12 @@ class ReplCodeGenTests : BaseTest() {
8186
val expected =
8287
"""
8388
@DataSchema
84-
interface $marker { }
89+
interface $marker {
90+
val age: Int
91+
val city: String?
92+
val name: String
93+
val weight: Int?
94+
}
8595
8696
val $dfName<$marker>.age: $dataCol<$intName> @JvmName("${marker}_age") get() = this["age"] as $dataCol<$intName>
8797
val $dfRowName<$marker>.age: $intName @JvmName("${marker}_age") get() = this["age"] as $intName
@@ -103,7 +113,9 @@ class ReplCodeGenTests : BaseTest() {
103113
val expected3 =
104114
"""
105115
@DataSchema
106-
interface $marker3 : $markerFull { }
116+
interface $marker3 : $markerFull {
117+
override val city: String
118+
}
107119
108120
val $dfName<$marker3>.city: $dataCol<$stringName> @JvmName("${marker3}_city") get() = this["city"] as $dataCol<$stringName>
109121
val $dfRowName<$marker3>.city: $stringName @JvmName("${marker3}_city") get() = this["city"] as $stringName
@@ -120,7 +132,9 @@ class ReplCodeGenTests : BaseTest() {
120132
val expected5 =
121133
"""
122134
@DataSchema
123-
interface $marker5 : $markerFull { }
135+
interface $marker5 : $markerFull {
136+
override val weight: Int
137+
}
124138
125139
val $dfName<$marker5>.weight: $dataCol<$intName> @JvmName("${marker5}_weight") get() = this["weight"] as $dataCol<$intName>
126140
val $dfRowName<$marker5>.weight: $intName @JvmName("${marker5}_weight") get() = this["weight"] as $intName
@@ -163,7 +177,10 @@ class ReplCodeGenTests : BaseTest() {
163177
val expected =
164178
"""
165179
@DataSchema
166-
interface $marker : ${Test2._DataFrameType::class.qualifiedName} { }
180+
interface $marker : ${Test2._DataFrameType::class.qualifiedName} {
181+
val city: String?
182+
val weight: Int?
183+
}
167184
168185
val $dfName<$marker>.city: $dataCol<$stringName?> @JvmName("${marker}_city") get() = this["city"] as $dataCol<$stringName?>
169186
val $dfRowName<$marker>.city: $stringName? @JvmName("${marker}_city") get() = this["city"] as $stringName?
@@ -218,4 +235,69 @@ class ReplCodeGenTests : BaseTest() {
218235
val c = repl.process(Test4.df, Test4::df)
219236
"""val .*ColumnsScope<\w*>.a:""".toRegex().findAll(c.declarations).count() shouldBe 1
220237
}
238+
239+
object Test5 {
240+
@DataSchema(isOpen = false)
241+
interface _DataFrameType1 {
242+
val a: Int
243+
val b: Int
244+
}
245+
246+
val ColumnsScope<_DataFrameType1>.a: DataColumn<Int>
247+
@JvmName("_DataFrameType1_a")
248+
get() = this["a"] as DataColumn<Int>
249+
val DataRow<_DataFrameType1>.a: Int
250+
@JvmName("_DataFrameType1_a")
251+
get() = this["a"] as Int
252+
val ColumnsScope<_DataFrameType1>.b: DataColumn<Int>
253+
@JvmName("_DataFrameType1_b")
254+
get() = this["b"] as DataColumn<Int>
255+
val DataRow<_DataFrameType1>.b: Int
256+
@JvmName("_DataFrameType1_b")
257+
get() = this["b"] as Int
258+
259+
@DataSchema
260+
interface _DataFrameType {
261+
val col: String
262+
val leaf: _DataFrameType1
263+
}
264+
265+
val df = dataFrameOf("col" to listOf("a"), "leaf" to listOf(dataFrameOf("a")(1).first()))
266+
.convert("leaf").cast<AnyRow>().asFrame { it.add("c") { 3 } }
267+
}
268+
269+
@Test
270+
fun `process closed inheritance override`() {
271+
// if ReplCodeGenerator would generate schemas with isOpen = true or with fields = false, _DataFrameType2 could implement _DataFrameType
272+
// but with isOpen = false and fields = true _DataFrameType2 : _DataFrameType produces incorrect override that couldn't be compiled
273+
// so we avoid this relation
274+
val repl = ReplCodeGenerator.create()
275+
repl.process<Test5._DataFrameType>()
276+
repl.process<Test5._DataFrameType1>()
277+
val c = repl.process(Test5.df, Test5::df)
278+
c.declarations shouldBe
279+
"""
280+
@DataSchema(isOpen = false)
281+
interface _DataFrameType3 {
282+
val a: Int
283+
val c: Int
284+
}
285+
286+
val $dfName<_DataFrameType3>.a: $dataCol<Int> @JvmName("_DataFrameType3_a") get() = this["a"] as $dataCol<Int>
287+
val $dfRowName<_DataFrameType3>.a: Int @JvmName("_DataFrameType3_a") get() = this["a"] as Int
288+
val $dfName<_DataFrameType3>.c: $dataCol<Int> @JvmName("_DataFrameType3_c") get() = this["c"] as $dataCol<Int>
289+
val $dfRowName<_DataFrameType3>.c: Int @JvmName("_DataFrameType3_c") get() = this["c"] as Int
290+
291+
@DataSchema
292+
interface _DataFrameType2 {
293+
val col: String
294+
val leaf: _DataFrameType3
295+
}
296+
297+
val $dfName<_DataFrameType2>.col: $dataCol<String> @JvmName("_DataFrameType2_col") get() = this["col"] as $dataCol<String>
298+
val $dfRowName<_DataFrameType2>.col: String @JvmName("_DataFrameType2_col") get() = this["col"] as String
299+
val $dfName<_DataFrameType2>.leaf: ColumnGroup<_DataFrameType3> @JvmName("_DataFrameType2_leaf") get() = this["leaf"] as ColumnGroup<_DataFrameType3>
300+
val $dfRowName<_DataFrameType2>.leaf: $dfRowName<_DataFrameType3> @JvmName("_DataFrameType2_leaf") get() = this["leaf"] as $dfRowName<_DataFrameType3>
301+
""".trimIndent()
302+
}
221303
}

dataframe-jupyter/src/test/kotlin/org/jetbrains/kotlinx/dataframe/jupyter/CodeGenerationTests.kt

+9
Original file line numberDiff line numberDiff line change
@@ -61,4 +61,13 @@ class CodeGenerationTests : DataFrameJupyterTest() {
6161
ab.a
6262
""".checkCompilation()
6363
}
64+
65+
@Test
66+
fun `nested schema with isOpen = false is ignored in marker generation`() {
67+
"""
68+
val df = dataFrameOf("col" to listOf("a"), "leaf" to listOf(dataFrameOf("a", "b")(1, 2).first()))
69+
val df1 = df.convert { leaf }.asFrame { it.add("c") { 3 } }
70+
df1.leaf.c
71+
""".checkCompilation()
72+
}
6473
}

0 commit comments

Comments
 (0)