Skip to content

Commit 844fa24

Browse files
committed
Merge branch 'mean-rework' into aggregators
# Conflicts: # core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt
2 parents 756473e + 2209c1d commit 844fa24

File tree

6 files changed

+291
-66
lines changed

6 files changed

+291
-66
lines changed

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt

Lines changed: 171 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -21,30 +21,181 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.numberColumns
2121
import org.jetbrains.kotlinx.dataframe.impl.columns.toNumberColumns
2222
import org.jetbrains.kotlinx.dataframe.impl.suggestIfNull
2323
import org.jetbrains.kotlinx.dataframe.math.mean
24+
import java.math.BigDecimal
25+
import java.math.BigInteger
26+
import kotlin.experimental.ExperimentalTypeInference
2427
import kotlin.reflect.KProperty
2528
import kotlin.reflect.typeOf
2629

2730
// region DataColumn
2831

29-
public fun <T : Number> DataColumn<T?>.mean(skipNA: Boolean = skipNA_default): Double =
30-
meanOrNull(skipNA).suggestIfNull("mean")
32+
// region mean
3133

32-
public fun <T : Number> DataColumn<T?>.meanOrNull(skipNA: Boolean = skipNA_default): Double? =
33-
Aggregators.mean(skipNA).aggregate(this)
34+
/**
 * Returns the mean of this [Int] column as a [Double].
 *
 * Delegates to [meanOrNull]; a `null` result is turned into a [NoSuchElementException]
 * by `suggestIfNull`, whose message points the caller at `meanOrNull`.
 */
@JvmName("meanInt")
public fun DataColumn<Int?>.mean(): Double {
    val result = meanOrNull()
    return result.suggestIfNull("mean")
}
3436

35-
public inline fun <T, reified R : Number> DataColumn<T>.meanOf(
36-
skipNA: Boolean = skipNA_default,
37-
noinline expression: (T) -> R?,
38-
): Double = Aggregators.mean(skipNA).cast2<R?, Double>().aggregateOf(this, expression) ?: Double.NaN
37+
/** Mean of this [Short] column as a [Double]; a `null` result from [meanOrNull] is rejected by `suggestIfNull`. */
@JvmName("meanShort")
public fun DataColumn<Short?>.mean(): Double {
    val result = meanOrNull()
    return result.suggestIfNull("mean")
}

/** Mean of this [Byte] column as a [Double]; a `null` result from [meanOrNull] is rejected by `suggestIfNull`. */
@JvmName("meanByte")
public fun DataColumn<Byte?>.mean(): Double {
    val result = meanOrNull()
    return result.suggestIfNull("mean")
}

/** Mean of this [Long] column as a [Double]; a `null` result from [meanOrNull] is rejected by `suggestIfNull`. */
@JvmName("meanLong")
public fun DataColumn<Long?>.mean(): Double {
    val result = meanOrNull()
    return result.suggestIfNull("mean")
}

/**
 * Mean of this [Double] column as a [Double].
 *
 * @param skipNA forwarded unchanged to [meanOrNull].
 */
@JvmName("meanDouble")
public fun DataColumn<Double?>.mean(skipNA: Boolean = skipNA_default): Double {
    val result = meanOrNull(skipNA)
    return result.suggestIfNull("mean")
}

/**
 * Mean of this [Float] column as a [Double].
 *
 * @param skipNA forwarded unchanged to [meanOrNull].
 */
@JvmName("meanFloat")
public fun DataColumn<Float?>.mean(skipNA: Boolean = skipNA_default): Double {
    val result = meanOrNull(skipNA)
    return result.suggestIfNull("mean")
}

/** Mean of this [BigInteger] column as a [BigDecimal]; a `null` result from [meanOrNull] is rejected by `suggestIfNull`. */
@JvmName("meanBigInteger")
public fun DataColumn<BigInteger?>.mean(): BigDecimal {
    val result = meanOrNull()
    return result.suggestIfNull("mean")
}

/** Mean of this [BigDecimal] column as a [BigDecimal]; a `null` result from [meanOrNull] is rejected by `suggestIfNull`. */
@JvmName("meanBigDecimal")
public fun DataColumn<BigDecimal?>.mean(): BigDecimal {
    val result = meanOrNull()
    return result.suggestIfNull("mean")
}

/**
 * Mean of this [Number] column, returned exactly as produced by [meanOrNull].
 *
 * NOTE(review): unlike the other `mean` overloads this one returns a nullable value and does not
 * go through `suggestIfNull` — confirm that this is intentional.
 *
 * @param skipNA forwarded unchanged to [meanOrNull].
 */
@JvmName("meanNumber")
public fun DataColumn<Number?>.mean(skipNA: Boolean = skipNA_default): Number? {
    return meanOrNull(skipNA)
}
3960

4061
// endregion
4162

42-
// region DataRow
63+
// region meanOrNull
64+
65+
/** Aggregates this [Int] column with `Aggregators.mean.toDouble`, using `skipNA_default`. */
@JvmName("meanOrNullInt")
public fun DataColumn<Int?>.meanOrNull(): Double? {
    val aggregator = Aggregators.mean.toDouble(skipNA_default)
    return aggregator.aggregate(this)
}

/** Aggregates this [Short] column with `Aggregators.mean.toDouble`, using `skipNA_default`. */
@JvmName("meanOrNullShort")
public fun DataColumn<Short?>.meanOrNull(): Double? {
    val aggregator = Aggregators.mean.toDouble(skipNA_default)
    return aggregator.aggregate(this)
}

/** Aggregates this [Byte] column with `Aggregators.mean.toDouble`, using `skipNA_default`. */
@JvmName("meanOrNullByte")
public fun DataColumn<Byte?>.meanOrNull(): Double? {
    val aggregator = Aggregators.mean.toDouble(skipNA_default)
    return aggregator.aggregate(this)
}

/** Aggregates this [Long] column with `Aggregators.mean.toDouble`, using `skipNA_default`. */
@JvmName("meanOrNullLong")
public fun DataColumn<Long?>.meanOrNull(): Double? {
    val aggregator = Aggregators.mean.toDouble(skipNA_default)
    return aggregator.aggregate(this)
}

/**
 * Aggregates this [Double] column with `Aggregators.mean.toDouble`.
 *
 * @param skipNA passed to the aggregator as its option.
 */
@JvmName("meanOrNullDouble")
public fun DataColumn<Double?>.meanOrNull(skipNA: Boolean = skipNA_default): Double? {
    val aggregator = Aggregators.mean.toDouble(skipNA)
    return aggregator.aggregate(this)
}

/**
 * Aggregates this [Float] column with `Aggregators.mean.toDouble`.
 *
 * @param skipNA passed to the aggregator as its option.
 */
@JvmName("meanOrNullFloat")
public fun DataColumn<Float?>.meanOrNull(skipNA: Boolean = skipNA_default): Double? {
    val aggregator = Aggregators.mean.toDouble(skipNA)
    return aggregator.aggregate(this)
}

/** Aggregates this [BigInteger] column with `Aggregators.mean.toBigDecimal`. */
@JvmName("meanOrNullBigInteger")
public fun DataColumn<BigInteger?>.meanOrNull(): BigDecimal? {
    val aggregator = Aggregators.mean.toBigDecimal
    return aggregator.aggregate(this)
}

/** Aggregates this [BigDecimal] column with `Aggregators.mean.toBigDecimal`. */
@JvmName("meanOrNullBigDecimal")
public fun DataColumn<BigDecimal?>.meanOrNull(): BigDecimal? {
    val aggregator = Aggregators.mean.toBigDecimal
    return aggregator.aggregate(this)
}

/**
 * Aggregates this [Number] column with `Aggregators.mean.toNumber`.
 *
 * @param skipNA passed to the aggregator as its option.
 */
@JvmName("meanOrNullNumber")
public fun DataColumn<Number?>.meanOrNull(skipNA: Boolean = skipNA_default): Number? {
    val aggregator = Aggregators.mean.toNumber(skipNA)
    return aggregator.aggregate(this)
}
94+
95+
// endregion
96+
97+
// region meanOf
98+
99+
/**
 * Mean of the [Int] values produced by [expression] for each element; [Double.NaN] when the
 * aggregation yields `null`.
 */
@OptIn(ExperimentalTypeInference::class)
@JvmName("meanOfInt")
// NOTE(review): the annotation below is disabled only on this overload — confirm whether that is intentional.
//@OverloadResolutionByLambdaReturnType
public fun <T> DataColumn<T>.meanOf(expression: (T) -> Int?): Double {
    val aggregator = Aggregators.mean.toDouble(skipNA_default).cast2<Int?, Double>()
    return aggregator.aggregateOf(this, expression) ?: Double.NaN
}

/**
 * Mean of the [Short] values produced by [expression] for each element; [Double.NaN] when the
 * aggregation yields `null`.
 */
@OptIn(ExperimentalTypeInference::class)
@JvmName("meanOfShort")
@OverloadResolutionByLambdaReturnType
public fun <T> DataColumn<T>.meanOf(expression: (T) -> Short?): Double {
    val aggregator = Aggregators.mean.toDouble(skipNA_default).cast2<Short?, Double>()
    return aggregator.aggregateOf(this, expression) ?: Double.NaN
}

/**
 * Mean of the [Byte] values produced by [expression] for each element; [Double.NaN] when the
 * aggregation yields `null`.
 */
@OptIn(ExperimentalTypeInference::class)
@JvmName("meanOfByte")
@OverloadResolutionByLambdaReturnType
public fun <T> DataColumn<T>.meanOf(expression: (T) -> Byte?): Double {
    val aggregator = Aggregators.mean.toDouble(skipNA_default).cast2<Byte?, Double>()
    return aggregator.aggregateOf(this, expression) ?: Double.NaN
}

/**
 * Mean of the [Long] values produced by [expression] for each element; [Double.NaN] when the
 * aggregation yields `null`.
 */
@OptIn(ExperimentalTypeInference::class)
@JvmName("meanOfLong")
@OverloadResolutionByLambdaReturnType
public fun <T> DataColumn<T>.meanOf(expression: (T) -> Long?): Double {
    val aggregator = Aggregators.mean.toDouble(skipNA_default).cast2<Long?, Double>()
    return aggregator.aggregateOf(this, expression) ?: Double.NaN
}

/**
 * Mean of the [Double] values produced by [expression] for each element; [Double.NaN] when the
 * aggregation yields `null`.
 *
 * @param skipNA passed to the aggregator as its option.
 */
@OptIn(ExperimentalTypeInference::class)
@JvmName("meanOfDouble")
@OverloadResolutionByLambdaReturnType
public fun <T> DataColumn<T>.meanOf(skipNA: Boolean = skipNA_default, expression: (T) -> Double?): Double {
    val aggregator = Aggregators.mean.toDouble(skipNA).cast2<Double?, Double>()
    return aggregator.aggregateOf(this, expression) ?: Double.NaN
}

/**
 * Mean of the [Float] values produced by [expression] for each element; [Double.NaN] when the
 * aggregation yields `null`.
 *
 * @param skipNA passed to the aggregator as its option.
 */
@OptIn(ExperimentalTypeInference::class)
@JvmName("meanOfFloat")
@OverloadResolutionByLambdaReturnType
public fun <T> DataColumn<T>.meanOf(skipNA: Boolean = skipNA_default, expression: (T) -> Float?): Double {
    val aggregator = Aggregators.mean.toDouble(skipNA).cast2<Float?, Double>()
    return aggregator.aggregateOf(this, expression) ?: Double.NaN
}

/** Mean of the [BigInteger] values produced by [expression], as a nullable [BigDecimal]. */
@OptIn(ExperimentalTypeInference::class)
@JvmName("meanOfBigInteger")
@OverloadResolutionByLambdaReturnType
public fun <T> DataColumn<T>.meanOf(expression: (T) -> BigInteger?): BigDecimal? {
    val aggregator = Aggregators.mean.toBigDecimal.cast2<BigInteger?, BigDecimal?>()
    return aggregator.aggregateOf(this, expression)
}

/** Mean of the [BigDecimal] values produced by [expression], as a nullable [BigDecimal]. */
@OptIn(ExperimentalTypeInference::class)
@JvmName("meanOfBigDecimal")
@OverloadResolutionByLambdaReturnType
public fun <T> DataColumn<T>.meanOf(expression: (T) -> BigDecimal?): BigDecimal? {
    val aggregator = Aggregators.mean.toBigDecimal.cast2<BigDecimal?, BigDecimal?>()
    return aggregator.aggregateOf(this, expression)
}

/**
 * Mean of the [Number] values produced by [expression], returned as a nullable [Number].
 *
 * @param skipNA passed to the aggregator as its option.
 */
@OptIn(ExperimentalTypeInference::class)
@JvmName("meanOfNumber")
@OverloadResolutionByLambdaReturnType
public fun <T> DataColumn<T>.meanOf(skipNA: Boolean = skipNA_default, expression: (T) -> Number?): Number? {
    val aggregator = Aggregators.mean.toNumber(skipNA).cast2<Number?, Number?>()
    return aggregator.aggregateOf(this, expression)
}
176+
177+
/**
 * Scratch entry point that exercises overload resolution of `meanOf` on a small data frame
 * and prints the result together with its runtime class.
 *
 * NOTE(review): this looks like temporary experiment code living in a public API file —
 * consider moving it to a test before release.
 */
public fun main() {
    val data = (1..10).toList()
    val df = data.toDataFrame()

    // The constant `if (true)` is deliberate: the two branches have different numeric types,
    // so the lambda's return type is not one of the concrete primitive overloads.
    val mean = df.value.meanOf { if (true) it.toLong() else it.toDouble() }
    // Kept as a compile-time check that a BigInteger-returning lambda resolves to an overload.
    val mean2 = df.value.meanOf { it.toBigInteger() }

    println(mean)
    // `mean` is nullable here; a safe call prints `null` instead of crashing on `!!`.
    println(mean?.let { it::class })
}
187+
188+
// endregion
189+
190+
// endregion
191+
192+
// region DataRow
193+
// todo
44194
/**
 * Mean of all [Number] values in this row, each converted with [Number.toDouble];
 * values that are not numbers are ignored.
 *
 * @param skipNA forwarded to the underlying `mean` computation.
 */
public fun AnyRow.rowMean(skipNA: Boolean = skipNA_default): Double {
    val numericValues = values().filterIsInstance<Number>()
    val asDoubles = numericValues.map { number -> number.toDouble() }
    return asDoubles.mean(skipNA)
}
46196

47-
public inline fun <reified T : Number> AnyRow.rowMeanOf(): Double = values().filterIsInstance<T>().mean(typeOf<T>())
197+
/**
 * Mean of the values in this row that are instances of [T], returned as a [Double].
 *
 * The typed `mean` result is converted with [Number.toDouble] instead of being hard-cast to
 * [Double], so a non-`Double` result (for example a `BigDecimal` mean) no longer fails with a
 * `ClassCastException`. A `null` result yields [Double.NaN].
 */
public inline fun <reified T : Number> AnyRow.rowMeanOf(): Double =
    values().filterIsInstance<T>().mean(typeOf<T>())?.toDouble() ?: Double.NaN
48199

49200
// endregion
50201

@@ -55,7 +206,7 @@ public fun <T> DataFrame<T>.mean(skipNA: Boolean = skipNA_default): DataRow<T> =
55206
/**
 * Applies the `Aggregators.mean.toNumber` aggregator to each of the selected [columns]
 * and returns the results as a single row.
 *
 * @param skipNA passed to the aggregator as its option.
 * @param columns selector of the columns to aggregate.
 */
public fun <T, C : Number> DataFrame<T>.meanFor(
    skipNA: Boolean = skipNA_default,
    columns: ColumnsForAggregateSelector<T, C?>,
): DataRow<T> {
    val aggregator = Aggregators.mean.toNumber(skipNA)
    return aggregator.aggregateFor(this, columns)
}
59210

60211
/**
 * Name-based variant of `meanFor`: converts the given column names with `toNumberColumns`
 * and delegates to the selector-based overload.
 */
public fun <T> DataFrame<T>.meanFor(vararg columns: String, skipNA: Boolean = skipNA_default): DataRow<T> {
    return meanFor(skipNA) { columns.toNumberColumns() }
}
@@ -72,10 +223,11 @@ public fun <T, C : Number> DataFrame<T>.meanFor(
72223
skipNA: Boolean = skipNA_default,
73224
): DataRow<T> = meanFor(skipNA) { columns.toColumnSet() }
74225

226+
// todo
75227
/**
 * Mean over all values of the selected [columns], as a [Double].
 *
 * The `toNumber` aggregator result is converted with [Number.toDouble] rather than cast to
 * `Double?`, so a non-`Double` numeric result (for example `BigDecimal`) no longer throws a
 * `ClassCastException`. A `null` result yields [Double.NaN].
 *
 * @param skipNA passed to the aggregator as its option.
 * @param columns selector of the columns whose values are aggregated together.
 */
public fun <T, C : Number> DataFrame<T>.mean(
    skipNA: Boolean = skipNA_default,
    columns: ColumnsSelector<T, C?>,
): Double {
    val result = Aggregators.mean.toNumber(skipNA).aggregateAll(this, columns) as Number?
    return result?.toDouble() ?: Double.NaN
}
79231

80232
/**
 * Name-based variant of `mean`: converts the given column names with `toNumberColumns`
 * and delegates to the selector-based overload.
 */
public fun <T> DataFrame<T>.mean(vararg columns: String, skipNA: Boolean = skipNA_default): Double {
    return mean(skipNA) { columns.toNumberColumns() }
}
@@ -93,7 +245,7 @@ public fun <T, C : Number> DataFrame<T>.mean(vararg columns: KProperty<C?>, skip
93245
/**
 * Mean of the values produced by [expression] for every row, as a [Double].
 *
 * The `toNumber` aggregator result is converted with [Number.toDouble] rather than cast to
 * `Double?`, so a non-`Double` numeric result (for example when [D] is `BigDecimal`) no longer
 * throws a `ClassCastException`. A `null` result yields [Double.NaN].
 *
 * @param skipNA passed to the aggregator as its option.
 * @param expression row expression producing the values to average.
 */
public inline fun <T, reified D : Number> DataFrame<T>.meanOf(
    skipNA: Boolean = skipNA_default,
    noinline expression: RowExpression<T, D?>,
): Double = (Aggregators.mean.toNumber(skipNA).of(this, expression) as Number?)?.toDouble() ?: Double.NaN
97249

98250
// endregion
99251

@@ -104,7 +256,7 @@ public fun <T> Grouped<T>.mean(skipNA: Boolean = skipNA_default): DataFrame<T> =
104256
/**
 * Applies the `Aggregators.mean.toNumber` aggregator to each of the selected [columns]
 * of this grouped data.
 *
 * @param skipNA passed to the aggregator as its option.
 * @param columns selector of the columns to aggregate.
 */
public fun <T, C : Number> Grouped<T>.meanFor(
    skipNA: Boolean = skipNA_default,
    columns: ColumnsForAggregateSelector<T, C?>,
): DataFrame<T> {
    val aggregator = Aggregators.mean.toNumber(skipNA)
    return aggregator.aggregateFor(this, columns)
}
108260

109261
/**
 * Name-based variant of `meanFor`: converts the given column names with `toNumberColumns`
 * and delegates to the selector-based overload.
 */
public fun <T> Grouped<T>.meanFor(vararg columns: String, skipNA: Boolean = skipNA_default): DataFrame<T> {
    return meanFor(skipNA) { columns.toNumberColumns() }
}
@@ -125,7 +277,7 @@ public fun <T, C : Number> Grouped<T>.mean(
125277
name: String? = null,
126278
skipNA: Boolean = skipNA_default,
127279
columns: ColumnsSelector<T, C?>,
128-
): DataFrame<T> = Aggregators.mean(skipNA).aggregateAll(this, name, columns)
280+
): DataFrame<T> = Aggregators.mean.toNumber(skipNA).aggregateAll(this, name, columns)
129281

130282
public fun <T> Grouped<T>.mean(
131283
vararg columns: String,
@@ -151,7 +303,7 @@ public inline fun <T, reified R : Number> Grouped<T>.meanOf(
151303
name: String? = null,
152304
skipNA: Boolean = skipNA_default,
153305
crossinline expression: RowExpression<T, R?>,
154-
): DataFrame<T> = Aggregators.mean(skipNA).aggregateOf(this, name, expression)
306+
): DataFrame<T> = Aggregators.mean.toNumber(skipNA).aggregateOf(this, name, expression)
155307

156308
// endregion
157309

@@ -207,7 +359,7 @@ public fun <T, C : Number> PivotGroupBy<T>.meanFor(
207359
skipNA: Boolean = skipNA_default,
208360
separate: Boolean = false,
209361
columns: ColumnsForAggregateSelector<T, C?>,
210-
): DataFrame<T> = Aggregators.mean(skipNA).aggregateFor(this, separate, columns)
362+
): DataFrame<T> = Aggregators.mean.toNumber(skipNA).aggregateFor(this, separate, columns)
211363

212364
public fun <T> PivotGroupBy<T>.meanFor(
213365
vararg columns: String,
@@ -232,7 +384,7 @@ public fun <T, C : Number> PivotGroupBy<T>.meanFor(
232384
public fun <T, R : Number> PivotGroupBy<T>.mean(
233385
skipNA: Boolean = skipNA_default,
234386
columns: ColumnsSelector<T, R?>,
235-
): DataFrame<T> = Aggregators.mean(skipNA).aggregateAll(this, columns)
387+
): DataFrame<T> = Aggregators.mean.toNumber(skipNA).aggregateAll(this, columns)
236388

237389
/**
 * Name-based variant of `mean`: converts the given column names with `toColumnsSetOf`
 * and delegates to the selector-based overload.
 */
public fun <T> PivotGroupBy<T>.mean(vararg columns: String, skipNA: Boolean = skipNA_default): DataFrame<T> {
    return mean(skipNA) { columns.toColumnsSetOf() }
}
@@ -252,6 +404,6 @@ public fun <T, R : Number> PivotGroupBy<T>.mean(
252404
/**
 * Applies the `Aggregators.mean.toNumber` aggregator to the values produced by [expression]
 * for this pivot-group-by.
 *
 * @param skipNA passed to the aggregator as its option.
 * @param expression row expression producing the values to average.
 */
public inline fun <T, reified R : Number> PivotGroupBy<T>.meanOf(
    skipNA: Boolean = skipNA_default,
    crossinline expression: RowExpression<T, R?>,
): DataFrame<T> {
    val aggregator = Aggregators.mean.toNumber(skipNA)
    return aggregator.aggregateOf(this, expression)
}
256408

257409
// endregion
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,22 @@
11
package org.jetbrains.kotlinx.dataframe.impl
22

3+
import java.math.BigDecimal
4+
import java.math.BigInteger
5+
36
/**
 * Returns the receiver when it is not `null`.
 *
 * @param message text of the exception raised for a `null` receiver.
 * @throws NoSuchElementException when the receiver is `null`.
 */
internal fun <T> T?.throwIfNull(message: String): T {
    if (this == null) throw NoSuchElementException(message)
    return this
}
47

58
/**
 * Returns the receiver when it is not `null`; otherwise fails via `throwIfNull` with a message
 * that names [operation] and suggests its `OrNull` counterpart.
 */
@PublishedApi
internal fun <T> T?.suggestIfNull(operation: String): T {
    val hint = "No elements for `$operation` operation. Use `${operation}OrNull` instead."
    return throwIfNull(hint)
}
11+
12+
/**
 * [BigInteger] variant of `suggestIfNull`: returns the receiver when it is not `null`; otherwise
 * fails via `throwIfNull` with a message that also mentions a possible NaN result.
 */
@PublishedApi
internal fun BigInteger?.suggestIfNull(operation: String): BigInteger {
    val hint =
        "The `$operation` operation either had no elements, or the result is NaN. Use `${operation}OrNull` instead."
    return throwIfNull(hint)
}
17+
18+
/**
 * [BigDecimal] variant of `suggestIfNull`: returns the receiver when it is not `null`; otherwise
 * fails via `throwIfNull` with a message that also mentions a possible NaN result.
 */
@PublishedApi
internal fun BigDecimal?.suggestIfNull(operation: String): BigDecimal {
    val hint =
        "The `$operation` operation either had no elements, or the result is NaN. Use `${operation}OrNull` instead."
    return throwIfNull(hint)
}

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/TypeUtils.kt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -643,3 +643,13 @@ internal fun Iterable<Any>.classes(): Set<KClass<*>> = mapTo(mutableSetOf()) { i
643643
* @return A set of [KType] objects corresponding to the star-projected runtime types of elements in the iterable.
644644
*/
645645
internal fun Iterable<Any>.types(): Set<KType> {
    // Collect one non-nullable, star-projected type per distinct runtime class.
    val collected = mutableSetOf<KType>()
    for (klass in classes()) {
        collected += klass.createStarProjectedType(false)
    }
    return collected
}
646+
647+
/**
 * Casts this [Number] to [Double]; a `null` receiver becomes [Double.NaN].
 *
 * This is a cast, not a conversion: a non-null receiver that is not a [Double]
 * fails with a [ClassCastException].
 */
internal fun Number?.asDoubleOrNaN(): Double =
    when (this) {
        null -> Double.NaN
        else -> this as Double
    }
651+
652+
/**
 * Casts this [Number] to [Float]; a `null` receiver becomes [Float.NaN].
 *
 * This is a cast, not a conversion: a non-null receiver that is not a [Float]
 * fails with a [ClassCastException].
 */
internal fun Number?.asFloatOrNaN(): Float =
    when (this) {
        null -> Float.NaN
        else -> this as Float
    }

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,23 @@ internal object Aggregators {
8484
flatteningChangingTypes<Number, Double> { std(it, skipNA, ddof) }
8585
}
8686

87-
val mean by withOneOption { skipNA: Boolean ->
88-
twoStepChangingType({ mean(it, skipNA) }) { mean(skipNA) }
87+
// Groups the mean aggregators by result type: `toNumber`, `toDouble` and `toBigDecimal`.
// The lowercase object name is deliberate (hence the suppression), so call sites read
// `Aggregators.mean.toDouble(skipNA)`.
@Suppress("ClassName")
88+
object mean {
89+
// Aggregator taking a `skipNA` option; created under the name "meanToNumber".
// NOTE(review): presumably keeps the result within the input's number type — confirm against `extendsNumbers`.
val toNumber = withOption { skipNA: Boolean ->
90+
extendsNumbers { mean(it, skipNA) }
91+
}.create("meanToNumber")
92+
93+
// Aggregator taking a `skipNA` option; created under the name "meanToDouble".
// The typed path passes the mean through `asDoubleOrNaN()`.
val toDouble = withOption { skipNA: Boolean ->
94+
changesType(
95+
aggregateWithType = { mean(it, skipNA).asDoubleOrNaN() },
96+
aggregateWithValues = { mean(skipNA) },
97+
)
98+
}.create("meanToDouble")
99+
100+
// Aggregator without options; created under the name "meanToBigDecimal".
// The typed path casts the mean to `BigDecimal?`; the values path drops nulls before averaging.
val toBigDecimal = changesType(
101+
aggregateWithType = { mean(it) as BigDecimal? },
102+
aggregateWithValues = { filterNotNull().mean() },
103+
).create("meanToBigDecimal")
89104
}
90105

91106
val percentile by withOneOption { percentile: Double ->

0 commit comments

Comments
 (0)