Skip to content

Commit d514df7

Browse files
Merge pull request #1107 from Kotlin/concat_with_keys
concatWithKeys impl
2 parents 253ce70 + ddb2eb3 commit d514df7

File tree

7 files changed

+111
-2
lines changed

7 files changed

+111
-2
lines changed

core/api/core.api

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1341,6 +1341,7 @@ public final class org/jetbrains/kotlinx/dataframe/api/ConcatKt {
13411341
public static final fun concat (Lorg/jetbrains/kotlinx/dataframe/api/ReducedGroupBy;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
13421342
public static final fun concatRows (Ljava/lang/Iterable;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
13431343
public static final fun concatT (Lorg/jetbrains/kotlinx/dataframe/DataFrame;Ljava/lang/Iterable;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
1344+
public static final fun concatWithKeys (Lorg/jetbrains/kotlinx/dataframe/api/GroupBy;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
13441345
}
13451346

13461347
public final class org/jetbrains/kotlinx/dataframe/api/ConstructorsKt {

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/concat.kt

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,12 @@ package org.jetbrains.kotlinx.dataframe.api
33
import org.jetbrains.kotlinx.dataframe.DataColumn
44
import org.jetbrains.kotlinx.dataframe.DataFrame
55
import org.jetbrains.kotlinx.dataframe.DataRow
6+
import org.jetbrains.kotlinx.dataframe.annotations.Interpretable
7+
import org.jetbrains.kotlinx.dataframe.annotations.Refine
68
import org.jetbrains.kotlinx.dataframe.columns.values
79
import org.jetbrains.kotlinx.dataframe.impl.api.concatImpl
810
import org.jetbrains.kotlinx.dataframe.impl.asList
11+
import org.jetbrains.kotlinx.dataframe.type
912

1013
// region DataColumn
1114

@@ -40,6 +43,64 @@ public fun <T> DataFrame<T>.concat(frames: Iterable<DataFrame<T>>): DataFrame<T>
4043

4144
public fun <T, G> GroupBy<T, G>.concat(): DataFrame<G> = groups.concat()
4245

46+
/**
47+
* Concatenates all groups in this [GroupBy] into a single [DataFrame],
48+
* preserving and including all grouping key columns that are not present in the group's columns.
49+
*
50+
* Doesn't affect key columns that have the same name as columns inside the groups (even if their content differs).
51+
*
52+
* This function is especially useful when grouping by expressions or renamed columns,
53+
* and you want the resulting [DataFrame] to include those keys as part of the output.
54+
*
55+
* ### Example
56+
*
57+
* ```kotlin
58+
* val df = dataFrameOf(
59+
* "value" to listOf(1, 2, 3, 3),
60+
* "type" to listOf("a", "b", "a", "b")
61+
* )
62+
*
63+
* val gb = df.groupBy { expr { "Category: \${type.uppercase()}" } named "category" }
64+
* ```
65+
*
66+
* A regular `concat()` will return a [DataFrame] similar to the original `df`
67+
* (with the same columns and rows but in the different orders):
68+
*
69+
* ```
70+
* gb.concat()
71+
* ```
72+
* | value | type |
73+
* | :---- | :--- |
74+
* | 1 | a |
75+
* | 3 | a |
76+
* | 2 | b |
77+
* | 3 | b |
78+
*
79+
* But `concatWithKeys()` will include the new "category" key column:
80+
*
81+
* ```
82+
* gb.concatWithKeys()
83+
* ```
84+
* | value | type | category |
85+
* | :---- | :--- | :------------ |
86+
* | 1 | a | Category: A |
87+
* | 3 | a | Category: A |
88+
* | 2 | b | Category: B |
89+
* | 3 | b | Category: B |
90+
*
91+
* @return A new [DataFrame] where all groups are combined and additional key columns are included in each row.
92+
*/
93+
@Refine
94+
@Interpretable("ConcatWithKeys")
95+
public fun <T, G> GroupBy<T, G>.concatWithKeys(): DataFrame<G> =
96+
mapToFrames {
97+
val rowsCount = group.rowsCount()
98+
val keyColumns = keys.columns().filter { it.name !in group.columnNames() }.map { keyColumn ->
99+
DataColumn.createByType(keyColumn.name, List(rowsCount) { key[keyColumn] }, keyColumn.type)
100+
}
101+
group.addAll(keyColumns)
102+
}.concat()
103+
43104
// endregion
44105

45106
// region ReducedGroupBy

core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/concat.kt

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,16 @@ class ConcatTests {
1111
val b by columnOf(3.0, null)
1212
a.concat(b) shouldBe columnOf(1, 2, 3.0, null).named("a")
1313
}
14+
15+
@Test
16+
fun `concat with keys`() {
17+
val df = dataFrameOf(
18+
"value" to listOf(1, 2, 3, 3),
19+
"type" to listOf("a", "b", "a", "b"),
20+
)
21+
val gb = df.groupBy { expr { "Category: ${(this["type"] as String).uppercase()}" } named "category" }
22+
val dfWithCategory = gb.concatWithKeys()
23+
24+
dfWithCategory.columnNames() shouldBe listOf("value", "type", "category")
25+
}
1426
}

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.toPluginDataFrameSchema
3939
import org.jetbrains.kotlinx.dataframe.plugin.impl.type
4040
import org.jetbrains.kotlinx.dataframe.plugin.interpret
4141
import org.jetbrains.kotlinx.dataframe.plugin.loadInterpreter
42+
import kotlin.collections.plus
4243

4344
class GroupBy(val keys: PluginDataFrameSchema, val groups: PluginDataFrameSchema) {
4445
companion object {
@@ -450,6 +451,12 @@ private fun isIntraComparable(col: SimpleDataColumn, session: FirSession): Boole
450451
return col.type.type.isSubtypeOf(comparable, session)
451452
}
452453

454+
class ConcatWithKeys : AbstractSchemaModificationInterpreter() {
455+
val Arguments.receiver by groupBy()
453456

454-
455-
457+
override fun Arguments.interpret(): PluginDataFrameSchema {
458+
val originalColumns = receiver.groups.columns()
459+
return PluginDataFrameSchema(
460+
originalColumns + receiver.keys.columns().filter { it.name !in originalColumns.map { it.name } })
461+
}
462+
}

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColsOf0
101101
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColsOf1
102102
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColsOf2
103103
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColumnRange
104+
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ConcatWithKeys
104105
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DataFrameBuilderInvoke0
105106
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DataFrameOf0
106107
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DataFrameOf3
@@ -210,6 +211,7 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.WithoutNulls1
210211
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.WithoutNulls2
211212
import org.jetbrains.kotlinx.dataframe.plugin.utils.Names
212213

214+
213215
internal fun FirFunctionCall.loadInterpreter(session: FirSession): Interpreter<*>? {
214216
val interpreter = Stdlib.interpreter(this)
215217
if (interpreter != null) return interpreter
@@ -463,6 +465,7 @@ internal inline fun <reified T> String.load(): T {
463465
"GroupByStdOf" -> GroupByStdOf()
464466
"DataFrameXs" -> DataFrameXs()
465467
"GroupByXs" -> GroupByXs()
468+
"ConcatWithKeys" -> ConcatWithKeys()
466469
else -> error("$this")
467470
} as T
468471
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import org.jetbrains.kotlinx.dataframe.*
2+
import org.jetbrains.kotlinx.dataframe.annotations.*
3+
import org.jetbrains.kotlinx.dataframe.api.*
4+
import org.jetbrains.kotlinx.dataframe.io.*
5+
6+
fun box(): String {
7+
val df = dataFrameOf(
8+
"value" to listOf(1, 2, 3, 3),
9+
"type" to listOf("a", "b", "a", "b")
10+
)
11+
val gb = df.groupBy { expr { "Category: ${type.uppercase()}" } named "category" }
12+
val categoryKey = gb.keys.category
13+
14+
val dfWithCategory = gb.concatWithKeys()
15+
16+
val category: DataColumn<String> = dfWithCategory.category
17+
18+
return "OK"
19+
}

plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,12 @@ public void testColumnWithStarProjection() {
8282
runTest("testData/box/columnWithStarProjection.kt");
8383
}
8484

85+
@Test
86+
@TestMetadata("concatWithKeys.kt")
87+
public void testConcatWithKeys() {
88+
runTest("testData/box/concatWithKeys.kt");
89+
}
90+
8591
@Test
8692
@TestMetadata("conflictingJvmDeclarations.kt")
8793
public void testConflictingJvmDeclarations() {

0 commit comments

Comments
 (0)