Skip to content

Commit b60d6fc

Browse files
Add Compiler Plugin support for statistics on GroupBy (#1077)
* Add GroupBySumOf functionality to groupBy operations. Introduces the `GroupBySumOf` interpreter for aggregation, enabling the calculation of column sums with customizable expressions and result names in grouped DataFrames. Adds tests and updates APIs to support and validate this feature.
* Add commented GroupBySum0 support with updated scenarios and tests.
* Refactor GroupBy statistics functionality and tests. Updated statistical aggregation functions for GroupBy with comments addressing open questions. Added comprehensive tests to verify behavior across various statistics (sum, mean, median, std, min, and max), replacing older test cases for cleaner coverage.
* Add support for GroupBy mean and median operations.
* Add support for min and max functions.
* Added support for std function.
* Updated support for sum/sumFor.
* Added support for all statistics, but ran into a limitation of FIR.
* Fixed comparable types for median.
* Fixed for max/min.
* Fixed for std/mean.
* Added missing casts to median/percentile. Could result in Comparable<Any?> columns.
* Refactor groupBy for enhanced type safety and comparability. Replaced direct subtype checks with `isIntraComparable` to improve type safety when resolving columns. Updated documentation syntax for better consistency and clarity. Added schema comparison in test to validate grouping behavior.
* Refactor `GroupBy` aggregation classes and test handling. Revised `GroupBy` aggregation logic by restructuring classes, improving naming consistency, and refining comments/documentation. Updated test cases to address initializer type mismatches and better handle scenarios involving multiple columns. Added relevant TODOs for unresolved cases linked to issue #1090.
---------
Co-authored-by: Jolan Rensen <jolan.rensen@jetbrains.com>
1 parent ccbcd19 commit b60d6fc

File tree

17 files changed

+1159
-14
lines changed

17 files changed

+1159
-14
lines changed

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/max.kt

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,12 @@ public fun <T, C : Comparable<C>> DataFrame<T>.maxByOrNull(column: KProperty<C?>
133133
// endregion
134134

135135
// region GroupBy
136-
136+
@Refine
137+
@Interpretable("GroupByMax1")
137138
public fun <T> Grouped<T>.max(): DataFrame<T> = maxFor(interComparableColumns())
138139

140+
@Refine
141+
@Interpretable("GroupByMax0")
139142
public fun <T, C : Comparable<C>> Grouped<T>.maxFor(columns: ColumnsForAggregateSelector<T, C?>): DataFrame<T> =
140143
Aggregators.max.aggregateFor(this, columns)
141144

@@ -149,6 +152,8 @@ public fun <T, C : Comparable<C>> Grouped<T>.maxFor(vararg columns: ColumnRefere
149152
public fun <T, C : Comparable<C>> Grouped<T>.maxFor(vararg columns: KProperty<C?>): DataFrame<T> =
150153
maxFor { columns.toColumnSet() }
151154

155+
@Refine
156+
@Interpretable("GroupByMax0")
152157
public fun <T, C : Comparable<C>> Grouped<T>.max(name: String? = null, columns: ColumnsSelector<T, C?>): DataFrame<T> =
153158
Aggregators.max.aggregateAll(this, name, columns)
154159

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ import org.jetbrains.kotlinx.dataframe.DataRow
88
import org.jetbrains.kotlinx.dataframe.RowExpression
99
import org.jetbrains.kotlinx.dataframe.aggregation.ColumnsForAggregateSelector
1010
import org.jetbrains.kotlinx.dataframe.annotations.AccessApiOverload
11+
import org.jetbrains.kotlinx.dataframe.annotations.Interpretable
12+
import org.jetbrains.kotlinx.dataframe.annotations.Refine
1113
import org.jetbrains.kotlinx.dataframe.columns.ColumnReference
1214
import org.jetbrains.kotlinx.dataframe.columns.toColumnSet
1315
import org.jetbrains.kotlinx.dataframe.columns.toColumnsSetOf
@@ -98,9 +100,12 @@ public inline fun <T, reified D : Number> DataFrame<T>.meanOf(
98100
// endregion
99101

100102
// region GroupBy
101-
103+
@Refine
104+
@Interpretable("GroupByMean1")
102105
public fun <T> Grouped<T>.mean(skipNA: Boolean = skipNA_default): DataFrame<T> = meanFor(skipNA, numberColumns())
103106

107+
@Refine
108+
@Interpretable("GroupByMean0")
104109
public fun <T, C : Number> Grouped<T>.meanFor(
105110
skipNA: Boolean = skipNA_default,
106111
columns: ColumnsForAggregateSelector<T, C?>,
@@ -121,6 +126,8 @@ public fun <T, C : Number> Grouped<T>.meanFor(
121126
skipNA: Boolean = skipNA_default,
122127
): DataFrame<T> = meanFor(skipNA) { columns.toColumnSet() }
123128

129+
@Refine
130+
@Interpretable("GroupByMean0")
124131
public fun <T, C : Number> Grouped<T>.mean(
125132
name: String? = null,
126133
skipNA: Boolean = skipNA_default,
@@ -147,6 +154,8 @@ public fun <T, C : Number> Grouped<T>.mean(
147154
skipNA: Boolean = skipNA_default,
148155
): DataFrame<T> = mean(name, skipNA) { columns.toColumnSet() }
149156

157+
@Refine
158+
@Interpretable("GroupByMeanOf")
150159
public inline fun <T, reified R : Number> Grouped<T>.meanOf(
151160
name: String? = null,
152161
skipNA: Boolean = skipNA_default,

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/median.kt

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ import org.jetbrains.kotlinx.dataframe.DataRow
88
import org.jetbrains.kotlinx.dataframe.RowExpression
99
import org.jetbrains.kotlinx.dataframe.aggregation.ColumnsForAggregateSelector
1010
import org.jetbrains.kotlinx.dataframe.annotations.AccessApiOverload
11+
import org.jetbrains.kotlinx.dataframe.annotations.Interpretable
12+
import org.jetbrains.kotlinx.dataframe.annotations.Refine
1113
import org.jetbrains.kotlinx.dataframe.columns.ColumnReference
1214
import org.jetbrains.kotlinx.dataframe.columns.toColumnSet
1315
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Aggregators
@@ -16,6 +18,7 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.interComparableColumns
1618
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateAll
1719
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateFor
1820
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOf
21+
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOfDelegated
1922
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.of
2023
import org.jetbrains.kotlinx.dataframe.impl.columns.toComparableColumns
2124
import org.jetbrains.kotlinx.dataframe.impl.suggestIfNull
@@ -103,9 +106,12 @@ public inline fun <T, reified R : Comparable<R>> DataFrame<T>.medianOf(
103106
// endregion
104107

105108
// region GroupBy
106-
109+
@Refine
110+
@Interpretable("GroupByMedian1")
107111
public fun <T> Grouped<T>.median(): DataFrame<T> = medianFor(interComparableColumns())
108112

113+
@Refine
114+
@Interpretable("GroupByMedian0")
109115
public fun <T, C : Comparable<C>> Grouped<T>.medianFor(columns: ColumnsForAggregateSelector<T, C?>): DataFrame<T> =
110116
Aggregators.median.aggregateFor(this, columns)
111117

@@ -119,6 +125,8 @@ public fun <T, C : Comparable<C>> Grouped<T>.medianFor(vararg columns: ColumnRef
119125
public fun <T, C : Comparable<C>> Grouped<T>.medianFor(vararg columns: KProperty<C?>): DataFrame<T> =
120126
medianFor { columns.toColumnSet() }
121127

128+
@Refine
129+
@Interpretable("GroupByMedian0")
122130
public fun <T, C : Comparable<C>> Grouped<T>.median(
123131
name: String? = null,
124132
columns: ColumnsSelector<T, C?>,
@@ -137,10 +145,12 @@ public fun <T, C : Comparable<C>> Grouped<T>.median(
137145
public fun <T, C : Comparable<C>> Grouped<T>.median(vararg columns: KProperty<C?>, name: String? = null): DataFrame<T> =
138146
median(name) { columns.toColumnSet() }
139147

148+
@Refine
149+
@Interpretable("GroupByMedianOf")
140150
public inline fun <T, reified R : Comparable<R>> Grouped<T>.medianOf(
141151
name: String? = null,
142152
crossinline expression: RowExpression<T, R?>,
143-
): DataFrame<T> = Aggregators.median.aggregateOf(this, name, expression)
153+
): DataFrame<T> = Aggregators.median.cast<R?>().aggregateOf(this, name, expression)
144154

145155
// endregion
146156

@@ -227,6 +237,6 @@ public fun <T, C : Comparable<C>> PivotGroupBy<T>.median(vararg columns: KProper
227237

228238
public inline fun <T, reified R : Comparable<R>> PivotGroupBy<T>.medianOf(
229239
crossinline expression: RowExpression<T, R?>,
230-
): DataFrame<T> = Aggregators.median.aggregateOf(this, expression)
240+
): DataFrame<T> = Aggregators.median.cast<R?>().aggregateOf(this, expression)
231241

232242
// endregion

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/min.kt

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,12 @@ public fun <T, C : Comparable<C>> DataFrame<T>.minByOrNull(column: KProperty<C?>
133133
// endregion
134134

135135
// region GroupBy
136-
136+
@Refine
137+
@Interpretable("GroupByMin1")
137138
public fun <T> Grouped<T>.min(): DataFrame<T> = minFor(interComparableColumns())
138139

140+
@Refine
141+
@Interpretable("GroupByMin0")
139142
public fun <T, C : Comparable<C>> Grouped<T>.minFor(columns: ColumnsForAggregateSelector<T, C?>): DataFrame<T> =
140143
Aggregators.min.aggregateFor(this, columns)
141144

@@ -149,6 +152,8 @@ public fun <T, C : Comparable<C>> Grouped<T>.minFor(vararg columns: ColumnRefere
149152
public fun <T, C : Comparable<C>> Grouped<T>.minFor(vararg columns: KProperty<C?>): DataFrame<T> =
150153
minFor { columns.toColumnSet() }
151154

155+
@Refine
156+
@Interpretable("GroupByMin0")
152157
public fun <T, C : Comparable<C>> Grouped<T>.min(name: String? = null, columns: ColumnsSelector<T, C?>): DataFrame<T> =
153158
Aggregators.min.aggregateAll(this, name, columns)
154159

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/percentile.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ public inline fun <T, reified R : Comparable<R>> Grouped<T>.percentileOf(
177177
percentile: Double,
178178
name: String? = null,
179179
crossinline expression: RowExpression<T, R?>,
180-
): DataFrame<T> = Aggregators.percentile(percentile).aggregateOf(this, name, expression)
180+
): DataFrame<T> = Aggregators.percentile(percentile).cast<R?>().aggregateOf(this, name, expression)
181181

182182
// endregion
183183

@@ -289,6 +289,6 @@ public fun <T, C : Comparable<C>> PivotGroupBy<T>.percentile(
289289
public inline fun <T, reified R : Comparable<R>> PivotGroupBy<T>.percentileOf(
290290
percentile: Double,
291291
crossinline expression: RowExpression<T, R?>,
292-
): DataFrame<T> = Aggregators.percentile(percentile).aggregateOf(this, expression)
292+
): DataFrame<T> = Aggregators.percentile(percentile).cast<R?>().aggregateOf(this, expression)
293293

294294
// endregion

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/std.kt

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ import org.jetbrains.kotlinx.dataframe.DataRow
88
import org.jetbrains.kotlinx.dataframe.RowExpression
99
import org.jetbrains.kotlinx.dataframe.aggregation.ColumnsForAggregateSelector
1010
import org.jetbrains.kotlinx.dataframe.annotations.AccessApiOverload
11+
import org.jetbrains.kotlinx.dataframe.annotations.Interpretable
12+
import org.jetbrains.kotlinx.dataframe.annotations.Refine
1113
import org.jetbrains.kotlinx.dataframe.columns.ColumnReference
1214
import org.jetbrains.kotlinx.dataframe.columns.toColumnSet
1315
import org.jetbrains.kotlinx.dataframe.columns.toColumnsSetOf
@@ -102,10 +104,13 @@ public inline fun <T, reified R : Number> DataFrame<T>.stdOf(
102104
// endregion
103105

104106
// region GroupBy
105-
107+
@Refine
108+
@Interpretable("GroupByStd1")
106109
public fun <T> Grouped<T>.std(skipNA: Boolean = skipNA_default, ddof: Int = ddof_default): DataFrame<T> =
107110
stdFor(skipNA, ddof, numberColumns())
108111

112+
@Refine
113+
@Interpretable("GroupByStd0")
109114
public fun <T> Grouped<T>.stdFor(
110115
skipNA: Boolean = skipNA_default,
111116
ddof: Int = ddof_default,
@@ -118,6 +123,7 @@ public fun <T> Grouped<T>.stdFor(
118123
ddof: Int = ddof_default,
119124
): DataFrame<T> = stdFor(skipNA, ddof) { columns.toColumnsSetOf() }
120125

126+
@AccessApiOverload
121127
public fun <T, C : Number> Grouped<T>.stdFor(
122128
vararg columns: ColumnReference<C?>,
123129
skipNA: Boolean = skipNA_default,
@@ -131,13 +137,16 @@ public fun <T, C : Number> Grouped<T>.stdFor(
131137
ddof: Int = ddof_default,
132138
): DataFrame<T> = stdFor(skipNA, ddof) { columns.toColumnSet() }
133139

140+
@Refine
141+
@Interpretable("GroupByStd0")
134142
public fun <T> Grouped<T>.std(
135143
name: String? = null,
136144
skipNA: Boolean = skipNA_default,
137145
ddof: Int = ddof_default,
138146
columns: ColumnsSelector<T, Number?>,
139147
): DataFrame<T> = Aggregators.std(skipNA, ddof).aggregateAll(this, name, columns)
140148

149+
@AccessApiOverload
141150
public fun <T> Grouped<T>.std(
142151
vararg columns: ColumnReference<Number?>,
143152
name: String? = null,
@@ -160,6 +169,8 @@ public fun <T> Grouped<T>.std(
160169
ddof: Int = ddof_default,
161170
): DataFrame<T> = std(name, skipNA, ddof) { columns.toColumnSet() }
162171

172+
@Refine
173+
@Interpretable("GroupByStdOf")
163174
public inline fun <T, reified R : Number> Grouped<T>.stdOf(
164175
name: String? = null,
165176
skipNA: Boolean = skipNA_default,

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ import org.jetbrains.kotlinx.dataframe.DataRow
88
import org.jetbrains.kotlinx.dataframe.RowExpression
99
import org.jetbrains.kotlinx.dataframe.aggregation.ColumnsForAggregateSelector
1010
import org.jetbrains.kotlinx.dataframe.annotations.AccessApiOverload
11+
import org.jetbrains.kotlinx.dataframe.annotations.Interpretable
12+
import org.jetbrains.kotlinx.dataframe.annotations.Refine
1113
import org.jetbrains.kotlinx.dataframe.columns.ColumnReference
1214
import org.jetbrains.kotlinx.dataframe.columns.toColumnSet
1315
import org.jetbrains.kotlinx.dataframe.columns.toColumnsSetOf
@@ -89,9 +91,12 @@ public inline fun <T, reified C : Number?> DataFrame<T>.sumOf(crossinline expres
8991
// endregion
9092

9193
// region GroupBy
92-
94+
@Refine
95+
@Interpretable("GroupBySum1")
9396
public fun <T> Grouped<T>.sum(): DataFrame<T> = sumFor(numberColumns())
9497

98+
@Refine
99+
@Interpretable("GroupBySum0")
95100
public fun <T, C : Number> Grouped<T>.sumFor(columns: ColumnsForAggregateSelector<T, C?>): DataFrame<T> =
96101
Aggregators.sum.aggregateFor(this, columns)
97102

@@ -105,6 +110,8 @@ public fun <T, C : Number> Grouped<T>.sumFor(vararg columns: ColumnReference<C?>
105110
public fun <T, C : Number> Grouped<T>.sumFor(vararg columns: KProperty<C?>): DataFrame<T> =
106111
sumFor { columns.toColumnSet() }
107112

113+
@Refine
114+
@Interpretable("GroupBySum0")
108115
public fun <T, C : Number> Grouped<T>.sum(name: String? = null, columns: ColumnsSelector<T, C?>): DataFrame<T> =
109116
Aggregators.sum.aggregateAll(this, name, columns)
110117

@@ -119,6 +126,8 @@ public fun <T, C : Number> Grouped<T>.sum(vararg columns: ColumnReference<C?>, n
119126
public fun <T, C : Number> Grouped<T>.sum(vararg columns: KProperty<C?>, name: String? = null): DataFrame<T> =
120127
sum(name) { columns.toColumnSet() }
121128

129+
@Refine
130+
@Interpretable("GroupBySumOf")
122131
public inline fun <T, reified R : Number> Grouped<T>.sumOf(
123132
resultName: String? = null,
124133
crossinline expression: RowExpression<T, R?>,

0 commit comments

Comments
 (0)