Skip to content

Commit 59e088a

Browse files
authored
Merge pull request #1114 from Kotlin/asGroupBy
[Compiler plugin] Support df.asGroupBy
2 parents 9f52d9a + 87dcab4 commit 59e088a

File tree

6 files changed

+63
-1
lines changed

6 files changed

+63
-1
lines changed

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/typeConversions.kt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ import org.jetbrains.kotlinx.dataframe.DataColumn
1010
import org.jetbrains.kotlinx.dataframe.DataFrame
1111
import org.jetbrains.kotlinx.dataframe.DataRow
1212
import org.jetbrains.kotlinx.dataframe.annotations.AccessApiOverload
13+
import org.jetbrains.kotlinx.dataframe.annotations.Interpretable
14+
import org.jetbrains.kotlinx.dataframe.annotations.Refine
1315
import org.jetbrains.kotlinx.dataframe.columns.ColumnAccessor
1416
import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup
1517
import org.jetbrains.kotlinx.dataframe.columns.ColumnPath
@@ -373,11 +375,15 @@ public fun <T, G> DataFrame<T>.asGroupBy(groupedColumn: ColumnReference<DataFram
373375
return asGroupBy { groups }
374376
}
375377

378+
@Refine
379+
@Interpretable("AsGroupByDefault")
376380
public fun <T> DataFrame<T>.asGroupBy(): GroupBy<T, T> {
377381
val groupCol = columns().single { it.isFrameColumn() }.asAnyFrameColumn().castFrameColumn<T>()
378382
return asGroupBy { groupCol }
379383
}
380384

385+
@Refine
386+
@Interpretable("AsGroupBy")
381387
public fun <T, G> DataFrame<T>.asGroupBy(selector: ColumnSelector<T, DataFrame<G>>): GroupBy<T, G> {
382388
val column = getColumn(selector).asFrameColumn()
383389
return GroupByImpl(this.move { column }.toEnd(), column) { none() }

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import org.jetbrains.kotlin.fir.types.resolvedType
1414
import org.jetbrains.kotlin.fir.types.typeContext
1515
import org.jetbrains.kotlin.fir.types.withNullability
1616
import org.jetbrains.kotlin.name.StandardClassIds
17+
import org.jetbrains.kotlinx.dataframe.api.remove
1718
import org.jetbrains.kotlinx.dataframe.plugin.InterpretationErrorReporter
1819
import org.jetbrains.kotlinx.dataframe.plugin.extensions.KotlinTypeFacade
1920
import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractInterpreter
@@ -27,12 +28,14 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleColumnGroup
2728
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleDataColumn
2829
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleFrameColumn
2930
import org.jetbrains.kotlinx.dataframe.plugin.impl.add
31+
import org.jetbrains.kotlinx.dataframe.plugin.impl.asDataFrame
3032
import org.jetbrains.kotlinx.dataframe.plugin.impl.data.ColumnWithPathApproximation
3133
import org.jetbrains.kotlinx.dataframe.plugin.impl.dataFrame
3234
import org.jetbrains.kotlinx.dataframe.plugin.impl.groupBy
3335
import org.jetbrains.kotlinx.dataframe.plugin.impl.ignore
3436
import org.jetbrains.kotlinx.dataframe.plugin.impl.makeNullable
3537
import org.jetbrains.kotlinx.dataframe.plugin.impl.simpleColumnOf
38+
import org.jetbrains.kotlinx.dataframe.plugin.impl.toPluginDataFrameSchema
3639
import org.jetbrains.kotlinx.dataframe.plugin.impl.type
3740
import org.jetbrains.kotlinx.dataframe.plugin.interpret
3841
import org.jetbrains.kotlinx.dataframe.plugin.loadInterpreter
@@ -53,6 +56,33 @@ class DataFrameGroupBy : AbstractInterpreter<GroupBy>() {
5356
}
5457
}
5558

59+
class AsGroupBy : AbstractInterpreter<GroupBy>() {
60+
val Arguments.receiver: PluginDataFrameSchema by dataFrame()
61+
val Arguments.selector: ColumnsResolver by arg()
62+
63+
override fun Arguments.interpret(): GroupBy {
64+
val column = selector.resolve(receiver).singleOrNull()?.column
65+
return if (column is SimpleFrameColumn) {
66+
GroupBy(receiver.asDataFrame().remove { selector }.toPluginDataFrameSchema(), PluginDataFrameSchema(column.columns()))
67+
} else {
68+
GroupBy.EMPTY
69+
}
70+
}
71+
}
72+
73+
class AsGroupByDefault : AbstractInterpreter<GroupBy>() {
74+
val Arguments.receiver: PluginDataFrameSchema by dataFrame()
75+
76+
override fun Arguments.interpret(): GroupBy {
77+
val groups = receiver.columns().singleOrNull { it is SimpleFrameColumn } as? SimpleFrameColumn
78+
return if (groups != null) {
79+
GroupBy(receiver.asDataFrame().remove(groups.name).toPluginDataFrameSchema(), PluginDataFrameSchema(groups.columns()))
80+
} else {
81+
GroupBy.EMPTY
82+
}
83+
}
84+
}
85+
5686
class NamedValue(val name: String, val type: ConeKotlinType)
5787

5888
class GroupByDsl {

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/interpret.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColumnsResolver
8383
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupBy
8484
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.SingleColumnApproximation
8585
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.TypeApproximation
86+
import org.jetbrains.kotlinx.dataframe.plugin.impl.simpleColumnOf
8687

8788
fun <T> KotlinTypeFacade.interpret(
8889
functionCall: FirFunctionCall,
@@ -400,7 +401,7 @@ private fun KotlinTypeFacade.columnWithPathApproximations(result: FirPropertyAcc
400401
is ConeStarProjection -> session.builtinTypes.nullableAnyType.type
401402
else -> arg as ConeClassLikeType
402403
}
403-
SimpleDataColumn(f(result), Marker(type))
404+
simpleColumnOf(f(result), type)
404405
}
405406
Names.COLUM_GROUP_CLASS_ID -> {
406407
val arg = it.typeArguments.single()

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.AllFrom2
8787
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.AllUpTo0
8888
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.AllUpTo1
8989
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.AllUpTo2
90+
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.AsGroupBy
91+
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.AsGroupByDefault
9092
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ByName
9193
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColGroups0
9294
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColGroups1
@@ -324,6 +326,8 @@ internal inline fun <reified T> String.load(): T {
324326
"Exclude1" -> Exclude1()
325327
"RenameInto" -> RenameInto()
326328
"DataFrameGroupBy" -> DataFrameGroupBy()
329+
"AsGroupBy" -> AsGroupBy()
330+
"AsGroupByDefault" -> AsGroupByDefault()
327331
"AggregateDslInto" -> AggregateDslInto()
328332
"GroupByToDataFrame" -> GroupByToDataFrame()
329333
"GroupByInto" -> GroupByInto()
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import org.jetbrains.kotlinx.dataframe.*
2+
import org.jetbrains.kotlinx.dataframe.annotations.*
3+
import org.jetbrains.kotlinx.dataframe.api.*
4+
import org.jetbrains.kotlinx.dataframe.io.*
5+
6+
fun box(): String {
7+
val df = dataFrameOf("i", "group")(1, dataFrameOf("a", "b")(111, 222))
8+
val aggregated1 = df.asGroupBy { group }.aggregate { maxOf { a } into "max" }
9+
val aggregated2 = df.asGroupBy().aggregate { maxOf { a } into "max" }
10+
11+
val i: Int = aggregated1.max[0]
12+
13+
compareSchemas(aggregated1, aggregated2)
14+
return "OK"
15+
}

plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,12 @@ public void testAllFilesPresentInBox() {
4040
KtTestUtil.assertAllTestsPresentByMetadataWithExcluded(this.getClass(), new File("testData/box"), Pattern.compile("^(.+)\\.kt$"), null, TargetBackend.JVM_IR, true);
4141
}
4242

43+
@Test
44+
@TestMetadata("asGroupBy.kt")
45+
public void testAsGroupBy() {
46+
runTest("testData/box/asGroupBy.kt");
47+
}
48+
4349
@Test
4450
@TestMetadata("castTo.kt")
4551
public void testCastTo() {

0 commit comments

Comments
 (0)