Skip to content

Commit adac331

Browse files
committed
[Compiler plugin] Support xs operation
1 parent 8497e3e commit adac331

File tree

6 files changed

+100
-1
lines changed

6 files changed

+100
-1
lines changed

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/xs.kt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,27 +2,37 @@ package org.jetbrains.kotlinx.dataframe.api
22

33
import org.jetbrains.kotlinx.dataframe.ColumnsSelector
44
import org.jetbrains.kotlinx.dataframe.DataFrame
5+
import org.jetbrains.kotlinx.dataframe.annotations.Interpretable
6+
import org.jetbrains.kotlinx.dataframe.annotations.Refine
57
import org.jetbrains.kotlinx.dataframe.impl.api.xsImpl
68

79
// region DataFrame
810

11+
@Refine
12+
@Interpretable("DataFrameXs")
913
public fun <T> DataFrame<T>.xs(vararg keyValues: Any?): DataFrame<T> =
1014
xs(*keyValues) {
1115
colsAtAnyDepth { !it.isColumnGroup() }.take(keyValues.size)
1216
}
1317

18+
@Refine
19+
@Interpretable("DataFrameXs")
1420
public fun <T, C> DataFrame<T>.xs(vararg keyValues: C, keyColumns: ColumnsSelector<T, C>): DataFrame<T> =
1521
xsImpl(keyColumns, false, *keyValues)
1622

1723
// endregion
1824

1925
// region GroupBy
2026

27+
@Refine
28+
@Interpretable("GroupByXs")
2129
public fun <T, G> GroupBy<T, G>.xs(vararg keyValues: Any?): GroupBy<T, G> =
2230
xs(*keyValues) {
2331
colsAtAnyDepth { !it.isColumnGroup() }.take(keyValues.size)
2432
}
2533

34+
@Refine
35+
@Interpretable("GroupByXs")
2636
public fun <T, G, C> GroupBy<T, G>.xs(vararg keyValues: C, keyColumns: ColumnsSelector<T, C>): GroupBy<T, G> =
2737
xsImpl(*keyValues, keyColumns = keyColumns)
2838

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,11 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.type
2828
import org.jetbrains.kotlinx.dataframe.plugin.interpret
2929
import org.jetbrains.kotlinx.dataframe.plugin.loadInterpreter
3030

31-
class GroupBy(val keys: PluginDataFrameSchema, val groups: PluginDataFrameSchema)
31+
class GroupBy(val keys: PluginDataFrameSchema, val groups: PluginDataFrameSchema) {
32+
companion object {
33+
val EMPTY = GroupBy(PluginDataFrameSchema.EMPTY, PluginDataFrameSchema.EMPTY)
34+
}
35+
}
3236

3337
class DataFrameGroupBy : AbstractInterpreter<GroupBy>() {
3438
val Arguments.receiver: PluginDataFrameSchema by dataFrame()
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
package org.jetbrains.kotlinx.dataframe.plugin.impl.api
2+
3+
import org.jetbrains.kotlin.fir.expressions.FirExpression
4+
import org.jetbrains.kotlin.fir.expressions.FirVarargArgumentsExpression
5+
import org.jetbrains.kotlinx.dataframe.api.getColumnsWithPaths
6+
import org.jetbrains.kotlinx.dataframe.api.isColumnGroup
7+
import org.jetbrains.kotlinx.dataframe.api.remove
8+
import org.jetbrains.kotlinx.dataframe.api.toPath
9+
import org.jetbrains.kotlinx.dataframe.columns.toColumnSet
10+
import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractInterpreter
11+
import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractSchemaModificationInterpreter
12+
import org.jetbrains.kotlinx.dataframe.plugin.impl.Arguments
13+
import org.jetbrains.kotlinx.dataframe.plugin.impl.Interpreter
14+
import org.jetbrains.kotlinx.dataframe.plugin.impl.PluginDataFrameSchema
15+
import org.jetbrains.kotlinx.dataframe.plugin.impl.Present
16+
import org.jetbrains.kotlinx.dataframe.plugin.impl.asDataFrame
17+
import org.jetbrains.kotlinx.dataframe.plugin.impl.dataFrame
18+
import org.jetbrains.kotlinx.dataframe.plugin.impl.groupBy
19+
import org.jetbrains.kotlinx.dataframe.plugin.impl.toPluginDataFrameSchema
20+
21+
class DataFrameXs : AbstractSchemaModificationInterpreter() {
22+
val Arguments.receiver: PluginDataFrameSchema by dataFrame()
23+
val Arguments.keyValues: FirExpression by arg(lens = Interpreter.Id)
24+
val Arguments.keyColumns: ColumnsResolver? by arg(defaultValue = Present(null))
25+
26+
override fun Arguments.interpret(): PluginDataFrameSchema {
27+
val keyColumns = keyColumns?.let { it.resolve(receiver).map { it.path.toPath() } }
28+
val n = (keyValues as? FirVarargArgumentsExpression)?.arguments?.size ?: return PluginDataFrameSchema.EMPTY
29+
return receiver
30+
.asDataFrame()
31+
.remove { keyColumns?.toColumnSet() ?: colsAtAnyDepth { !it.isColumnGroup() }.take(n) }
32+
.toPluginDataFrameSchema()
33+
}
34+
}
35+
36+
class GroupByXs : AbstractInterpreter<GroupBy>() {
37+
val Arguments.receiver by groupBy()
38+
val Arguments.keyValues: FirExpression by arg(lens = Interpreter.Id)
39+
val Arguments.keyColumns: ColumnsResolver? by arg(defaultValue = Present(null))
40+
41+
override fun Arguments.interpret(): GroupBy {
42+
val keyColumns = keyColumns?.let { it.resolve(receiver.keys).map { it.path.toPath() } }
43+
val n = (keyValues as? FirVarargArgumentsExpression)?.arguments?.size ?: return GroupBy.EMPTY
44+
45+
val toRemove = receiver.keys.asDataFrame()
46+
.getColumnsWithPaths { keyColumns?.toColumnSet() ?: colsAtAnyDepth { !it.isColumnGroup() }.take(n) }
47+
.toColumnSet()
48+
val updatedKeys = receiver.keys.asDataFrame().remove { toRemove }.toPluginDataFrameSchema()
49+
val updatedGroups = receiver.groups.asDataFrame().remove { toRemove }.toPluginDataFrameSchema()
50+
return GroupBy(updatedKeys, updatedGroups)
51+
}
52+
}

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColumnRange
106106
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DataFrameBuilderInvoke0
107107
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DataFrameOf0
108108
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DataFrameOf3
109+
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DataFrameXs
109110
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DataRowReadJsonStr
110111
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Drop0
111112
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Drop1
@@ -131,6 +132,7 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByMinOf
131132
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByReduceExpression
132133
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByReduceInto
133134
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByReducePredicate
135+
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByXs
134136
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Last0
135137
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Last1
136138
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Last2
@@ -435,6 +437,8 @@ internal inline fun <reified T> String.load(): T {
435437
"GroupByReduceInto" -> GroupByReduceInto()
436438
"GroupByMaxOf" -> GroupByMaxOf()
437439
"GroupByMinOf" -> GroupByMinOf()
440+
"DataFrameXs" -> DataFrameXs()
441+
"GroupByXs" -> GroupByXs()
438442
else -> error("$this")
439443
} as T
440444
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import org.jetbrains.kotlinx.dataframe.*
2+
import org.jetbrains.kotlinx.dataframe.annotations.*
3+
import org.jetbrains.kotlinx.dataframe.api.*
4+
import org.jetbrains.kotlinx.dataframe.io.*
5+
6+
fun box(): String {
7+
val typed = dataFrameOf("name", "age", "city", "weight")(
8+
"Alice", 15, "London", 54,
9+
"Bob", 45, "Dubai", 87,
10+
"Charlie", 20, "Moscow", null,
11+
"Charlie", 40, "Milan", null,
12+
"Bob", 30, "Tokyo", 68,
13+
"Alice", 20, null, 55,
14+
"Charlie", 30, "Moscow", 90,
15+
)
16+
17+
typed.xs("Charlie").compareSchemas()
18+
19+
typed.groupBy { name }.xs("Charlie").toDataFrame().compareSchemas()
20+
21+
typed.groupBy { name }.xs("Alice") { name }.toDataFrame().compareSchemas()
22+
return "OK"
23+
}

plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -694,6 +694,12 @@ public void testWrongReceiver() {
694694
runTest("testData/box/wrongReceiver.kt");
695695
}
696696

697+
@Test
698+
@TestMetadata("xs.kt")
699+
public void testXs() {
700+
runTest("testData/box/xs.kt");
701+
}
702+
697703
@Nested
698704
@TestMetadata("testData/box/colKinds")
699705
@TestDataPath("$PROJECT_ROOT")

0 commit comments

Comments
 (0)