Skip to content

Commit 9440640

Browse files
committed
[Compiler plugin] Support joinWith operations
1 parent 862240c commit 9440640

File tree

6 files changed

+230
-2
lines changed

6 files changed

+230
-2
lines changed

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/joinWith.kt

+16
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ package org.jetbrains.kotlinx.dataframe.api
33
import org.jetbrains.kotlinx.dataframe.DataFrame
44
import org.jetbrains.kotlinx.dataframe.DataRow
55
import org.jetbrains.kotlinx.dataframe.Selector
6+
import org.jetbrains.kotlinx.dataframe.annotations.Interpretable
7+
import org.jetbrains.kotlinx.dataframe.annotations.Refine
68
import org.jetbrains.kotlinx.dataframe.impl.api.joinWithImpl
79

810
public interface JoinedDataRow<out A, out B> : DataRow<A> {
@@ -11,27 +13,41 @@ public interface JoinedDataRow<out A, out B> : DataRow<A> {
1113

1214
public typealias JoinExpression<A, B> = Selector<JoinedDataRow<A, B>, Boolean>
1315

16+
/**
 * Joins this [DataFrame] with [right] using a Boolean [joinExpression] evaluated on a [JoinedDataRow].
 *
 * The call is forwarded to `joinWithImpl`; its `addNewColumns` flag is taken from
 * `type.addNewColumns`.
 *
 * @param right the other frame taking part in the join.
 * @param type the kind of join to perform; defaults to [JoinType.Inner].
 * @param joinExpression the [JoinExpression] passed on to `joinWithImpl`.
 * @return the frame produced by `joinWithImpl`.
 */
@Refine
@Interpretable("JoinWith")
public fun <A, B> DataFrame<A>.joinWith(
    right: DataFrame<B>,
    type: JoinType = JoinType.Inner,
    joinExpression: JoinExpression<A, B>,
): DataFrame<A> {
    val includeRightColumns = type.addNewColumns
    return joinWithImpl(right, type, addNewColumns = includeRightColumns, joinExpression)
}
1923

24+
/**
 * Shorthand for [joinWith] with [JoinType.Inner].
 *
 * @param right the other frame taking part in the join.
 * @param joinExpression the [JoinExpression] forwarded to [joinWith].
 * @return whatever [joinWith] returns for [JoinType.Inner].
 */
@Refine
@Interpretable("InnerJoinWith")
public fun <A, B> DataFrame<A>.innerJoinWith(
    right: DataFrame<B>,
    joinExpression: JoinExpression<A, B>,
): DataFrame<A> {
    return joinWith(right = right, type = JoinType.Inner, joinExpression = joinExpression)
}
2228

29+
/**
 * Shorthand for [joinWith] with [JoinType.Left].
 *
 * @param right the other frame taking part in the join.
 * @param joinExpression the [JoinExpression] forwarded to [joinWith].
 * @return whatever [joinWith] returns for [JoinType.Left].
 */
@Refine
@Interpretable("LeftJoinWith")
public fun <A, B> DataFrame<A>.leftJoinWith(
    right: DataFrame<B>,
    joinExpression: JoinExpression<A, B>,
): DataFrame<A> {
    return joinWith(right = right, type = JoinType.Left, joinExpression = joinExpression)
}
2533

34+
/**
 * Shorthand for [joinWith] with [JoinType.Right].
 *
 * @param right the other frame taking part in the join.
 * @param joinExpression the [JoinExpression] forwarded to [joinWith].
 * @return whatever [joinWith] returns for [JoinType.Right].
 */
@Refine
@Interpretable("RightJoinWith")
public fun <A, B> DataFrame<A>.rightJoinWith(
    right: DataFrame<B>,
    joinExpression: JoinExpression<A, B>,
): DataFrame<A> {
    return joinWith(right = right, type = JoinType.Right, joinExpression = joinExpression)
}
2838

39+
/**
 * Shorthand for [joinWith] with [JoinType.Full].
 *
 * @param right the other frame taking part in the join.
 * @param joinExpression the [JoinExpression] forwarded to [joinWith].
 * @return whatever [joinWith] returns for [JoinType.Full].
 */
@Refine
@Interpretable("FullJoinWith")
public fun <A, B> DataFrame<A>.fullJoinWith(
    right: DataFrame<B>,
    joinExpression: JoinExpression<A, B>,
): DataFrame<A> {
    return joinWith(right = right, type = JoinType.Full, joinExpression = joinExpression)
}
3143

44+
/**
 * Calls `joinWithImpl` with [JoinType.Inner] and `addNewColumns = false`.
 *
 * NOTE(review): with `addNewColumns = false` the result presumably carries only this frame's
 * columns — confirm against `joinWithImpl`.
 *
 * @param right the other frame taking part in the join.
 * @param joinExpression the [JoinExpression] forwarded to `joinWithImpl`.
 * @return the frame produced by `joinWithImpl`.
 */
@Refine
@Interpretable("FilterJoinWith")
public fun <A, B> DataFrame<A>.filterJoinWith(
    right: DataFrame<B>,
    joinExpression: JoinExpression<A, B>,
): DataFrame<A> {
    return joinWithImpl(right, JoinType.Inner, addNewColumns = false, joinExpression)
}
3448

49+
@Refine
50+
@Interpretable("ExcludeJoinWith")
3551
public fun <A, B> DataFrame<A>.excludeJoinWith(
3652
right: DataFrame<B>,
3753
joinExpression: JoinExpression<A, B>,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
package org.jetbrains.kotlinx.dataframe.plugin.impl.api
2+
3+
import org.jetbrains.kotlinx.dataframe.api.JoinType
4+
import org.jetbrains.kotlinx.dataframe.impl.ColumnNameGenerator
5+
import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractInterpreter
6+
import org.jetbrains.kotlinx.dataframe.plugin.impl.Arguments
7+
import org.jetbrains.kotlinx.dataframe.plugin.impl.PluginDataFrameSchema
8+
import org.jetbrains.kotlinx.dataframe.plugin.impl.Present
9+
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleCol
10+
import org.jetbrains.kotlinx.dataframe.plugin.impl.dataFrame
11+
import org.jetbrains.kotlinx.dataframe.plugin.impl.enum
12+
import org.jetbrains.kotlinx.dataframe.plugin.impl.ignore
13+
import org.jetbrains.kotlinx.dataframe.plugin.impl.makeNullable
14+
15+
/**
 * Common base of the compiler-plugin interpreters for the `joinWith` family.
 *
 * It declares the arguments shared by every variant (the receiver schema, the `right` schema and
 * the join expression, which is declared through `ignore()`) and offers [join] to build the
 * resulting [PluginDataFrameSchema] for a particular [JoinType].
 */
internal abstract class AbstractJoinWith() : AbstractInterpreter<PluginDataFrameSchema>() {
    val Arguments.receiver: PluginDataFrameSchema by dataFrame()
    val Arguments.right: PluginDataFrameSchema by dataFrame()
    val Arguments.joinExpression by ignore()

    /**
     * Builds the schema of a join of the given [type].
     *
     * Receiver columns always come first, followed (where the join type includes them) by the
     * columns of `right`. All names go through one shared [ColumnNameGenerator], so a column
     * added later is renamed to whatever `addUnique` returns for it.
     *
     * - [JoinType.Inner]: receiver columns and `right` columns, both unchanged.
     * - [JoinType.Left]: `right` columns are passed through `makeNullable`.
     * - [JoinType.Right]: receiver columns are passed through `makeNullable`.
     * - [JoinType.Full]: both sides are passed through `makeNullable`.
     * - [JoinType.Filter] and [JoinType.Exclude]: receiver columns only.
     */
    fun Arguments.join(type: JoinType): PluginDataFrameSchema {
        val leftColumns = receiver.columns()
        val rightColumns = right.columns()

        // One generator for both sides: receiver names are registered before names from `right`.
        val uniqueNames = ColumnNameGenerator()

        fun withUniqueNames(columns: List<SimpleCol>): List<SimpleCol> =
            columns.map { column -> column.rename(uniqueNames.addUnique(column.name)) }

        fun asNullable(columns: List<SimpleCol>): List<SimpleCol> =
            columns.map { column -> makeNullable(column) }

        val joinedColumns: List<SimpleCol> = when (type) {
            JoinType.Inner ->
                withUniqueNames(leftColumns) + withUniqueNames(rightColumns)

            JoinType.Left ->
                withUniqueNames(leftColumns) + withUniqueNames(asNullable(rightColumns))

            JoinType.Right ->
                withUniqueNames(asNullable(leftColumns)) + withUniqueNames(rightColumns)

            JoinType.Full ->
                withUniqueNames(asNullable(leftColumns)) + withUniqueNames(asNullable(rightColumns))

            JoinType.Filter, JoinType.Exclude ->
                withUniqueNames(leftColumns)
        }
        return PluginDataFrameSchema(joinedColumns)
    }
}
62+
63+
/**
 * Interpreter for `joinWith`: reads the `type` argument, falling back to [JoinType.Inner] when
 * it is absent, and builds the schema with [join].
 */
internal class JoinWith : AbstractJoinWith() {
    val Arguments.type: JoinType by enum(defaultValue = Present(JoinType.Inner))

    override fun Arguments.interpret(): PluginDataFrameSchema = join(type)
}
70+
71+
/** Interpreter that builds the result schema with [join] for [JoinType.Left]. */
internal class LeftJoinWith : AbstractJoinWith() {
    override fun Arguments.interpret(): PluginDataFrameSchema = join(JoinType.Left)
}
76+
77+
/** Interpreter that builds the result schema with [join] for [JoinType.Right]. */
internal class RightJoinWith : AbstractJoinWith() {
    override fun Arguments.interpret(): PluginDataFrameSchema = join(JoinType.Right)
}
82+
83+
/** Interpreter that builds the result schema with [join] for [JoinType.Full]. */
internal class FullJoinWith : AbstractJoinWith() {
    override fun Arguments.interpret(): PluginDataFrameSchema = join(JoinType.Full)
}
88+
89+
/** Interpreter that builds the result schema with [join] for [JoinType.Inner]. */
internal class InnerJoinWith : AbstractJoinWith() {
    override fun Arguments.interpret(): PluginDataFrameSchema = join(JoinType.Inner)
}
94+
95+
/** Interpreter that builds the result schema with [join] for [JoinType.Filter]. */
internal class FilterJoinWith : AbstractJoinWith() {
    override fun Arguments.interpret(): PluginDataFrameSchema = join(JoinType.Filter)
}
100+
101+
/** Interpreter that builds the result schema with [join] for [JoinType.Exclude]. */
internal class ExcludeJoinWith : AbstractJoinWith() {
    override fun Arguments.interpret(): PluginDataFrameSchema = join(JoinType.Exclude)
}

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt

+14
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,10 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DropLast1
111111
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DropLast2
112112
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DropNa0
113113
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ExcludeJoin
114+
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ExcludeJoinWith
114115
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.FillNulls0
115116
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.FilterJoin
117+
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.FilterJoinWith
116118
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.First0
117119
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.First1
118120
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.First2
@@ -122,6 +124,7 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.FrameCols0
122124
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.FrameCols1
123125
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.FrameCols2
124126
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.FullJoin
127+
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.FullJoinWith
125128
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByAdd
126129
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByCount0
127130
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByInto
@@ -153,7 +156,10 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByStdOf
153156
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupBySum0
154157
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupBySum1
155158
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupBySumOf
159+
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.InnerJoinWith
156160
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.InsertAt
161+
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.JoinWith
162+
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.LeftJoinWith
157163
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.MapToFrame
158164
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Merge0
159165
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.MergeId
@@ -199,6 +205,7 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.RenameToCamelCaseClause
199205
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Reorder
200206
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ReorderColumnsByName
201207
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.RightJoin
208+
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.RightJoinWith
202209
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Single0
203210
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Single1
204211
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Single2
@@ -322,6 +329,13 @@ internal inline fun <reified T> String.load(): T {
322329
"InnerJoin" -> InnerJoin()
323330
"ExcludeJoin" -> ExcludeJoin()
324331
"FilterJoin" -> FilterJoin()
332+
"JoinWith" -> JoinWith()
333+
"LeftJoinWith" -> LeftJoinWith()
334+
"RightJoinWith" -> RightJoinWith()
335+
"FullJoinWith" -> FullJoinWith()
336+
"InnerJoinWith" -> InnerJoinWith()
337+
"ExcludeJoinWith" -> ExcludeJoinWith()
338+
"FilterJoinWith" -> FilterJoinWith()
325339
"Match0" -> Match0()
326340
"Rename" -> Rename()
327341
"RenameMapping" -> RenameMapping()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import org.jetbrains.kotlinx.dataframe.*
2+
import org.jetbrains.kotlinx.dataframe.annotations.*
3+
import org.jetbrains.kotlinx.dataframe.api.*
4+
import org.jetbrains.kotlinx.dataframe.io.*
5+
6+
/**
 * Box test that calls every `joinWith` variant on two small frames and invokes `assert()` on
 * each result.
 *
 * NOTE(review): `assert()` is presumably the plugin's schema check — confirm.
 */
fun box(): String {
    val leftFrame = dataFrameOf("name", "age", "city", "weight")(
        "Alice", 15, "London", 54,
        "Bob", 45, "Dubai", 87,
        "Charlie", 20, "Moscow", null,
        "Charlie", 40, "Milan", null,
        "Bob", 30, "Tokyo", 68,
        "Alice", 20, null, 55,
        "Charlie", 30, "Moscow", 90,
    )

    // Shares the "name" and "age" column names with leftFrame.
    val rightFrame = dataFrameOf("name", "origin", "grade", "age")(
        "Alice", "London", 3, "young",
        "Alice", "London", 5, "old",
        "Bob", "Tokyo", 4, "young",
        "Bob", "Paris", 5, "old",
        "Charlie", "Moscow", 1, "young",
        "Charlie", "Moscow", 2, "old",
        "Bob", "Paris", 4, null,
    )

    // joinWith with the default join type.
    leftFrame.joinWith(rightFrame) { name == right.name && city == right.origin }.assert()

    // joinWith with the join type passed explicitly.
    leftFrame.joinWith(rightFrame, type = JoinType.Inner) { name == right.name && city == right.origin }.assert()

    // The dedicated shorthand for each join kind.
    leftFrame.innerJoinWith(rightFrame) { name == right.name && city == right.origin }.assert()

    leftFrame.leftJoinWith(rightFrame) { name == right.name && city == right.origin }.assert()

    leftFrame.rightJoinWith(rightFrame) { name == right.name && city == right.origin }.assert()

    leftFrame.fullJoinWith(rightFrame) { name == right.name && city == right.origin }.assert()

    leftFrame.filterJoinWith(rightFrame) { city == right.origin }.assert()

    leftFrame.excludeJoinWith(rightFrame) { city == right.origin }.assert()

    return "OK"
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import org.jetbrains.kotlinx.dataframe.*
2+
import org.jetbrains.kotlinx.dataframe.annotations.*
3+
import org.jetbrains.kotlinx.dataframe.api.*
4+
import org.jetbrains.kotlinx.dataframe.io.*
5+
6+
/**
 * Box test for `innerJoinWith` when both frames contain a column group named "gr".
 *
 * After the join it accesses `gr1.age` on the result, prints the result's schema and calls
 * `assert()` on it.
 */
fun box(): String {
    val leftFrame = dataFrameOf("name", "age", "city", "weight")(
        "Alice", 15, "London", 54,
        "Bob", 45, "Dubai", 87,
        "Charlie", 20, "Moscow", null,
        "Charlie", 40, "Milan", null,
        "Bob", 30, "Tokyo", 68,
        "Alice", 20, null, 55,
        "Charlie", 30, "Moscow", 90,
    )

    val rightFrame = dataFrameOf("name", "origin", "grade", "age")(
        "Alice", "London", 3, "young",
        "Alice", "London", 5, "old",
        "Bob", "Tokyo", 4, "young",
        "Bob", "Paris", 5, "old",
        "Charlie", "Moscow", 1, "young",
        "Charlie", "Moscow", 2, "old",
        "Bob", "Paris", 4, null,
    )

    // Both sides get a column group called "gr" before they are joined.
    val joined = leftFrame.group { name and age }.into("gr")
        .innerJoinWith(rightFrame.group { origin and age }.into("gr")) { gr.name == right.name }

    // columns from right are duplicated, including groups
    joined.gr1.age

    println(joined.schema())
    joined.assert()

    return "OK"
}

plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java

+14-2
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,8 @@ public void testDiff() {
157157
@Test
158158
@TestMetadata("distinct.kt")
159159
public void testDistinct() {
160-
runTest("testData/box/distinct.kt");
161-
}
160+
runTest("testData/box/distinct.kt");
161+
}
162162

163163
@Test
164164
@TestMetadata("dropNA.kt")
@@ -382,6 +382,18 @@ public void testJoinKinds() {
382382
runTest("testData/box/joinKinds.kt");
383383
}
384384

385+
@Test
386+
@TestMetadata("joinWithKinds.kt")
387+
public void testJoinWithKinds() {
388+
runTest("testData/box/joinWithKinds.kt");
389+
}
390+
391+
@Test
392+
@TestMetadata("joinWith_duplicateColumnGroups.kt")
393+
public void testJoinWith_duplicateColumnGroups() {
394+
runTest("testData/box/joinWith_duplicateColumnGroups.kt");
395+
}
396+
385397
@Test
386398
@TestMetadata("join_1.kt")
387399
public void testJoin_1() {

0 commit comments

Comments
 (0)