Skip to content

Commit 2ac173c

Browse files
zaleslawCopilot
andauthored
Add support for DataFrame sum operation with tests (#1148)
* Add support for DataFrame `sum` operation with tests Introduced the `sum` operation for DataFrames, supporting numerical columns aggregation. Updated relevant tests and added new test cases to verify functionality. Included schema modifications for handling numerical column operations. * Make aggregator-related classes and functions public Converted various internal classes, interfaces, and functions related to aggregation into public entities. This change expands their visibility, enabling external usage and facilitating integration with other modules or libraries. * Enhance type conversions between `KType` and `ConeKotlinType` to ensure compatibility and correctness in sum calculations. * Update plugins/kotlin-dataframe/testData/box/sum.kt Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Refactor type conversion and column handling logic * Fixed review * Fixed conflict * Fix linting --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 5c2aabe commit 2ac173c

File tree

16 files changed

+406
-78
lines changed

16 files changed

+406
-78
lines changed

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt

+4-1
Original file line numberDiff line numberDiff line change
@@ -125,10 +125,13 @@ public fun AnyRow.rowSumOf(type: KType, skipNaN: Boolean = skipNaNDefault): Numb
125125
// endregion
126126

127127
// region DataFrame
128-
128+
@Refine
129+
@Interpretable("Sum0")
129130
public fun <T> DataFrame<T>.sum(skipNaN: Boolean = skipNaNDefault): DataRow<T> =
130131
sumFor(skipNaN, primitiveOrMixedNumberColumns())
131132

133+
@Refine
134+
@Interpretable("Sum1")
132135
public fun <T, C : Number?> DataFrame<T>.sumFor(
133136
skipNaN: Boolean = skipNaNDefault,
134137
columns: ColumnsForAggregateSelector<T, C>,

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator.kt

+7-8
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,11 @@ import kotlin.reflect.full.withNullability
3737
* @param Return The type of the resulting value. Can optionally be nullable.
3838
* @see [invoke]
3939
*/
40-
@PublishedApi
41-
internal class Aggregator<in Value : Any, out Return : Any?>(
42-
val aggregationHandler: AggregatorAggregationHandler<Value, Return>,
43-
val inputHandler: AggregatorInputHandler<Value, Return>,
44-
val multipleColumnsHandler: AggregatorMultipleColumnsHandler<Value, Return>,
45-
val name: String,
40+
public class Aggregator<in Value : Any, out Return : Any?>(
41+
public val aggregationHandler: AggregatorAggregationHandler<Value, Return>,
42+
public val inputHandler: AggregatorInputHandler<Value, Return>,
43+
public val multipleColumnsHandler: AggregatorMultipleColumnsHandler<Value, Return>,
44+
public val name: String,
4645
) : AggregatorInputHandler<Value, Return> by inputHandler,
4746
AggregatorMultipleColumnsHandler<Value, Return> by multipleColumnsHandler,
4847
AggregatorAggregationHandler<Value, Return> by aggregationHandler {
@@ -96,7 +95,7 @@ internal class Aggregator<in Value : Any, out Return : Any?>(
9695
internal fun <Value : Any, Return : Any?> Aggregator<Value, Return>.aggregate(
9796
values: Sequence<Value?>,
9897
valueType: ValueType,
99-
) = aggregateSequence(values, valueType)
98+
): Return = aggregateSequence(values, valueType)
10099

101100
/**
102101
* Performs aggregation on the given [values], taking [valueType] into account.
@@ -106,7 +105,7 @@ internal fun <Value : Any, Return : Any?> Aggregator<Value, Return>.aggregate(
106105
internal fun <Value : Any, Return : Any?> Aggregator<Value, Return>.aggregate(
107106
values: Sequence<Value?>,
108107
valueType: KType,
109-
) = aggregate(values, valueType.toValueType(needsFullConversion = false))
108+
): Return = aggregate(values, valueType.toValueType(needsFullConversion = false))
110109

111110
/**
112111
* If the specific [ValueType] of the input is not known, but you still want to call [aggregate],

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorAggregationHandler.kt

+5-6
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@ import kotlin.reflect.KType
1111
* It also provides information on which return type will be given, as [KType], given a [value type][ValueType].
1212
* It can also provide the index of the result in the input values if it is a selecting aggregator.
1313
*/
14-
@PublishedApi
15-
internal interface AggregatorAggregationHandler<in Value : Any, out Return : Any?> : AggregatorHandler<Value, Return> {
14+
public interface AggregatorAggregationHandler<in Value : Any, out Return : Any?> : AggregatorHandler<Value, Return> {
1615

1716
/**
1817
* Base function of [Aggregator].
@@ -23,13 +22,13 @@ internal interface AggregatorAggregationHandler<in Value : Any, out Return : Any
2322
*
2423
* When the exact [valueType] is unknown, use [calculateValueType] or [aggregateCalculatingValueType].
2524
*/
26-
fun aggregateSequence(values: Sequence<Value?>, valueType: ValueType): Return
25+
public fun aggregateSequence(values: Sequence<Value?>, valueType: ValueType): Return
2726

2827
/**
2928
* Aggregates the data in the given column and computes a single resulting value.
3029
* Calls [aggregateSequence].
3130
*/
32-
fun aggregateSingleColumn(column: DataColumn<Value?>): Return =
31+
public fun aggregateSingleColumn(column: DataColumn<Value?>): Return =
3332
aggregateSequence(
3433
values = column.asSequence(),
3534
valueType = column.type().toValueType(),
@@ -43,7 +42,7 @@ internal interface AggregatorAggregationHandler<in Value : Any, out Return : Any
4342
* @param emptyInput If `true`, the input values are considered empty. This often affects the return type.
4443
* @return The return type of [aggregateSequence] as [KType].
4544
*/
46-
fun calculateReturnType(valueType: KType, emptyInput: Boolean): KType
45+
public fun calculateReturnType(valueType: KType, emptyInput: Boolean): KType
4746

4847
/**
4948
* Function that can give the index of the aggregation result in the input [values], if it applies.
@@ -54,5 +53,5 @@ internal interface AggregatorAggregationHandler<in Value : Any, out Return : Any
5453
*
5554
* Defaults to `-1`.
5655
*/
57-
fun indexOfAggregationResultSingleSequence(values: Sequence<Value?>, valueType: ValueType): Int
56+
public fun indexOfAggregationResultSingleSequence(values: Sequence<Value?>, valueType: ValueType): Int
5857
}

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorHandler.kt

+3-3
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,16 @@ package org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators
77
* the [init] function of each [AggregatorAggregationHandlers][AggregatorAggregationHandler] is called,
88
* which allows the handler to refer to [Aggregator] instance via [aggregator].
99
*/
10-
internal interface AggregatorHandler<in Value : Any, out Return : Any?> {
10+
public interface AggregatorHandler<in Value : Any, out Return : Any?> {
1111

1212
/**
1313
* Reference to the aggregator instance.
1414
*
1515
* Can only be used once [init] has run.
1616
*/
17-
var aggregator: Aggregator<@UnsafeVariance Value, @UnsafeVariance Return>?
17+
public var aggregator: Aggregator<@UnsafeVariance Value, @UnsafeVariance Return>?
1818

19-
fun init(aggregator: Aggregator<@UnsafeVariance Value, @UnsafeVariance Return>) {
19+
public fun init(aggregator: Aggregator<@UnsafeVariance Value, @UnsafeVariance Return>) {
2020
this.aggregator = aggregator
2121
}
2222
}

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorInputHandler.kt

+4-4
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@ import kotlin.reflect.KType
88
* It can also calculate a specific [value type][ValueType] from the input values or input types
99
* if the (specific) type is not known.
1010
*/
11-
internal interface AggregatorInputHandler<in Value : Any, out Return : Any?> : AggregatorHandler<Value, Return> {
11+
public interface AggregatorInputHandler<in Value : Any, out Return : Any?> : AggregatorHandler<Value, Return> {
1212

1313
/**
1414
* If the specific [ValueType] of the input is not known, but you still want to call [aggregate],
1515
* this function can be called to calculate it by combining the set of known [valueTypes].
1616
*/
17-
fun calculateValueType(valueTypes: Set<KType>): ValueType
17+
public fun calculateValueType(valueTypes: Set<KType>): ValueType
1818

1919
/**
2020
* WARNING: HEAVY!
@@ -23,7 +23,7 @@ internal interface AggregatorInputHandler<in Value : Any, out Return : Any?> : A
2323
* this function can be called to calculate it by getting the types of [values] at runtime.
2424
* This is heavy because it uses reflection on each value.
2525
*/
26-
fun calculateValueType(values: Sequence<Value?>): ValueType
26+
public fun calculateValueType(values: Sequence<Value?>): ValueType
2727

2828
/**
2929
* Preprocesses the input values before aggregation.
@@ -32,7 +32,7 @@ internal interface AggregatorInputHandler<in Value : Any, out Return : Any?> : A
3232
*
3333
* @return A pair of the preprocessed values and the (potentially new) type of the values.
3434
*/
35-
fun preprocessAggregation(
35+
public fun preprocessAggregation(
3636
values: Sequence<Value?>,
3737
valueType: ValueType,
3838
): Pair<Sequence<@UnsafeVariance Value?>, KType>

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorMultipleColumnsHandler.kt

+3-3
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@ import kotlin.reflect.KType
1010
* [AggregatorAggregationHandler].
1111
* It can also calculate the return type of the aggregation given all input column types.
1212
*/
13-
internal interface AggregatorMultipleColumnsHandler<in Value : Any, out Return : Any?> :
13+
public interface AggregatorMultipleColumnsHandler<in Value : Any, out Return : Any?> :
1414
AggregatorHandler<Value, Return> {
1515

1616
/**
1717
* Aggregates the data in the multiple given columns and computes a single resulting value.
1818
* Calls [Aggregator.aggregateSequence] or [Aggregator.aggregateSingleColumn].
1919
*/
20-
fun aggregateMultipleColumns(columns: Sequence<DataColumn<Value?>>): Return
20+
public fun aggregateMultipleColumns(columns: Sequence<DataColumn<Value?>>): Return
2121

2222
/**
2323
* Function that can give the return type of [aggregateMultipleColumns], given types of the columns.
@@ -26,5 +26,5 @@ internal interface AggregatorMultipleColumnsHandler<in Value : Any, out Return :
2626
* @param colTypes The types of the input columns.
2727
* @param colsEmpty If `true`, all the input columns are considered empty. This often affects the return type.
2828
*/
29-
fun calculateReturnTypeMultipleColumns(colTypes: Set<KType>, colsEmpty: Boolean): KType
29+
public fun calculateReturnTypeMultipleColumns(colTypes: Set<KType>, colsEmpty: Boolean): KType
3030
}

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorOptionSwitch.kt

+14-15
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,20 @@ package org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators
66
* Aggregators are cached by their parameter value.
77
* @see AggregatorOptionSwitch2
88
*/
9-
@PublishedApi
10-
internal class AggregatorOptionSwitch1<in Param1, in Value : Any, out Return : Any?>(
11-
val name: String,
12-
val getAggregator: (param1: Param1) -> AggregatorProvider<Value, Return>,
9+
public class AggregatorOptionSwitch1<in Param1, in Value : Any, out Return : Any?>(
10+
public val name: String,
11+
public val getAggregator: (param1: Param1) -> AggregatorProvider<Value, Return>,
1312
) {
1413

1514
private val cache: MutableMap<Param1, Aggregator<Value, Return>> = mutableMapOf()
1615

17-
operator fun invoke(param1: Param1): Aggregator<Value, Return> =
16+
public operator fun invoke(param1: Param1): Aggregator<Value, @UnsafeVariance Return> =
1817
cache.getOrPut(param1) {
1918
getAggregator(param1).create(name)
2019
}
2120

2221
@Suppress("FunctionName")
23-
companion object {
22+
public companion object {
2423

2524
/**
2625
* Creates [AggregatorOptionSwitch1].
@@ -31,9 +30,10 @@ internal class AggregatorOptionSwitch1<in Param1, in Value : Any, out Return : A
3130
* MyAggregator.Factory(param1)
3231
* }
3332
*/
34-
fun <Param1, Value : Any, Return : Any?> Factory(
33+
public fun <Param1, Value : Any, Return : Any?> Factory(
3534
getAggregator: (param1: Param1) -> AggregatorProvider<Value, Return>,
36-
) = Provider { name -> AggregatorOptionSwitch1(name, getAggregator) }
35+
): Provider<AggregatorOptionSwitch1<Param1, Value, Return>> =
36+
Provider { name -> AggregatorOptionSwitch1(name, getAggregator) }
3737
}
3838
}
3939

@@ -43,21 +43,20 @@ internal class AggregatorOptionSwitch1<in Param1, in Value : Any, out Return : A
4343
* Aggregators are cached by their parameter values.
4444
* @see AggregatorOptionSwitch1
4545
*/
46-
@PublishedApi
47-
internal class AggregatorOptionSwitch2<in Param1, in Param2, in Value : Any, out Return : Any?>(
48-
val name: String,
49-
val getAggregator: (param1: Param1, param2: Param2) -> AggregatorProvider<Value, Return>,
46+
public class AggregatorOptionSwitch2<in Param1, in Param2, in Value : Any, out Return : Any?>(
47+
public val name: String,
48+
public val getAggregator: (param1: Param1, param2: Param2) -> AggregatorProvider<Value, Return>,
5049
) {
5150

5251
private val cache: MutableMap<Pair<Param1, Param2>, Aggregator<Value, Return>> = mutableMapOf()
5352

54-
operator fun invoke(param1: Param1, param2: Param2): Aggregator<Value, Return> =
53+
public operator fun invoke(param1: Param1, param2: Param2): Aggregator<Value, @UnsafeVariance Return> =
5554
cache.getOrPut(param1 to param2) {
5655
getAggregator(param1, param2).create(name)
5756
}
5857

5958
@Suppress("FunctionName")
60-
companion object {
59+
public companion object {
6160

6261
/**
6362
* Creates [AggregatorOptionSwitch2].
@@ -68,7 +67,7 @@ internal class AggregatorOptionSwitch2<in Param1, in Param2, in Value : Any, out
6867
* MyAggregator.Factory(param1, param2)
6968
* }
7069
*/
71-
fun <Param1, Param2, Value : Any, Return : Any?> Factory(
70+
internal fun <Param1, Param2, Value : Any, Return : Any?> Factory(
7271
getAggregator: (param1: Param1, param2: Param2) -> AggregatorProvider<Value, Return>,
7372
) = Provider { name -> AggregatorOptionSwitch2(name, getAggregator) }
7473
}

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorProvider.kt

+4-3
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@ import kotlin.reflect.KProperty
1010
* val myNamedValue by MyFactory
1111
* ```
1212
*/
13-
internal fun interface Provider<out T> {
13+
public fun interface Provider<out T> {
1414

15-
fun create(name: String): T
15+
public fun create(name: String): T
1616
}
1717

1818
internal operator fun <T> Provider<T>.getValue(obj: Any?, property: KProperty<*>): T = create(property.name)
@@ -25,4 +25,5 @@ internal operator fun <T> Provider<T>.getValue(obj: Any?, property: KProperty<*>
2525
* val myAggregator by MyAggregator.Factory
2626
* ```
2727
*/
28-
internal fun interface AggregatorProvider<in Value : Any, out Return : Any?> : Provider<Aggregator<Value, Return>>
28+
public fun interface AggregatorProvider<in Value : Any, out Return : Any?> :
29+
Provider<Aggregator<Value, @UnsafeVariance Return>>

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt

+17-13
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@ import org.jetbrains.kotlinx.dataframe.math.stdTypeConversion
2727
import org.jetbrains.kotlinx.dataframe.math.sum
2828
import org.jetbrains.kotlinx.dataframe.math.sumTypeConversion
2929

30-
@PublishedApi
31-
internal object Aggregators {
30+
public object Aggregators {
3231

3332
// TODO these might need some small refactoring
3433

@@ -112,7 +111,7 @@ internal object Aggregators {
112111

113112
// T: Comparable<T> -> T?
114113
// T : Comparable<T & Any>? -> T?
115-
fun <T : Comparable<T & Any>?> min(skipNaN: Boolean): Aggregator<T & Any, T?> = min.invoke(skipNaN).cast2()
114+
public fun <T : Comparable<T & Any>?> min(skipNaN: Boolean): Aggregator<T & Any, T?> = min.invoke(skipNaN).cast2()
116115

117116
private val min by withOneOption { skipNaN: Boolean ->
118117
twoStepSelectingForAny<Comparable<Any>, Comparable<Any>?>(
@@ -124,7 +123,7 @@ internal object Aggregators {
124123

125124
// T: Comparable<T> -> T?
126125
// T : Comparable<T & Any>? -> T?
127-
fun <T : Comparable<T & Any>?> max(skipNaN: Boolean): Aggregator<T & Any, T?> = max.invoke(skipNaN).cast2()
126+
public fun <T : Comparable<T & Any>?> max(skipNaN: Boolean): Aggregator<T & Any, T?> = max.invoke(skipNaN).cast2()
128127

129128
private val max by withOneOption { skipNaN: Boolean ->
130129
twoStepSelectingForAny<Comparable<Any>, Comparable<Any>?>(
@@ -135,36 +134,41 @@ internal object Aggregators {
135134
}
136135

137136
// T: Number? -> Double
138-
val std by withTwoOptions { skipNaN: Boolean, ddof: Int ->
137+
public val std: AggregatorOptionSwitch2<Boolean, Int, Number, Double> by withTwoOptions {
138+
skipNaN: Boolean,
139+
ddof: Int,
140+
->
139141
flattenReducingForNumbers(stdTypeConversion) { type ->
140142
std(type, skipNaN, ddof)
141143
}
142144
}
143145

144146
// step one: T: Number? -> Double
145147
// step two: Double -> Double
146-
val mean by withOneOption { skipNaN: Boolean ->
148+
public val mean: AggregatorOptionSwitch1<Boolean, Number, Double> by withOneOption { skipNaN: Boolean ->
147149
twoStepReducingForNumbers(meanTypeConversion) { type ->
148150
mean(type, skipNaN)
149151
}
150152
}
151153

152154
// T : primitive Number? -> Double?
153155
// T : Comparable<T & Any>? -> T?
154-
fun <T> percentileCommon(
156+
public fun <T> percentileCommon(
155157
percentile: Double,
156158
skipNaN: Boolean,
157159
): Aggregator<T & Any, T?>
158160
where T : Comparable<T & Any>? =
159161
this.percentile.invoke(percentile, skipNaN).cast2()
160162

161163
// T : Comparable<T & Any>? -> T?
162-
fun <T> percentileComparables(percentile: Double): Aggregator<T & Any, T?>
164+
public fun <T> percentileComparables(
165+
percentile: Double,
166+
): Aggregator<T & Any, T?>
163167
where T : Comparable<T & Any>? =
164168
percentileCommon<T>(percentile, skipNaNDefault).cast2()
165169

166170
// T : primitive Number? -> Double?
167-
fun <T> percentileNumbers(
171+
public fun <T> percentileNumbers(
168172
percentile: Double,
169173
skipNaN: Boolean,
170174
): Aggregator<T & Any, Double?>
@@ -182,17 +186,17 @@ internal object Aggregators {
182186

183187
// T : primitive Number? -> Double?
184188
// T : Comparable<T & Any>? -> T?
185-
fun <T> medianCommon(skipNaN: Boolean): Aggregator<T & Any, T?>
189+
public fun <T> medianCommon(skipNaN: Boolean): Aggregator<T & Any, T?>
186190
where T : Comparable<T & Any>? =
187191
median.invoke(skipNaN).cast2()
188192

189193
// T : Comparable<T & Any>? -> T?
190-
fun <T> medianComparables(): Aggregator<T & Any, T?>
194+
public fun <T> medianComparables(): Aggregator<T & Any, T?>
191195
where T : Comparable<T & Any>? =
192196
medianCommon<T>(skipNaNDefault).cast2()
193197

194198
// T : primitive Number? -> Double?
195-
fun <T> medianNumbers(
199+
public fun <T> medianNumbers(
196200
skipNaN: Boolean,
197201
): Aggregator<T & Any, Double?>
198202
where T : Comparable<T & Any>?, T : Number? =
@@ -211,7 +215,7 @@ internal object Aggregators {
211215
// Byte -> Int
212216
// Short -> Int
213217
// Nothing -> Double
214-
val sum by withOneOption { skipNaN: Boolean ->
218+
public val sum: AggregatorOptionSwitch1<Boolean, Number, Number> by withOneOption { skipNaN: Boolean ->
215219
twoStepReducingForNumbers(sumTypeConversion) { type ->
216220
sum(type, skipNaN)
217221
}

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/ValueType.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,6 @@ import kotlin.reflect.KType
1010
* for the values to become the correct value type. If `false`, the values are already the right type,
1111
* or a simple cast will suffice.
1212
*/
13-
internal data class ValueType(val kType: KType, val needsFullConversion: Boolean = false)
13+
public data class ValueType(val kType: KType, val needsFullConversion: Boolean = false)
1414

1515
internal fun KType.toValueType(needsFullConversion: Boolean = false): ValueType = ValueType(this, needsFullConversion)

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/multipleColumnsHandlers/TwoStepMultipleColumnsHandler.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ import kotlin.reflect.KType
3131
* If not supplied, the handler of the first step is reused.
3232
* @see [FlatteningMultipleColumnsHandler]
3333
*/
34-
internal class TwoStepMultipleColumnsHandler<in Value : Any, out Return : Any?>(
34+
internal class TwoStepMultipleColumnsHandler<in Value : Any, Return : Any?>(
3535
stepTwoAggregationHandler: AggregatorAggregationHandler<Return & Any, Return>? = null,
3636
stepTwoInputHandler: AggregatorInputHandler<Return & Any, Return>? = null,
3737
) : AggregatorMultipleColumnsHandler<Value, Return> {

0 commit comments

Comments
 (0)