Skip to content

Commit f3327d6

Browse files
committed
mean rework: returns null for no values regardless of the type. Added orNull overloads for each mean function. Added specific overloads for each primitive type -> Double(?) and big number -> BigDecimal(?)
1 parent 844fa24 commit f3327d6

File tree

13 files changed

+416
-153
lines changed

13 files changed

+416
-153
lines changed

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/describe.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ public interface ColumnDescription {
2525
public val nulls: Int
2626
public val top: Any
2727
public val freq: Int
28-
public val mean: Double
28+
public val mean: Number?
2929
public val std: Double
3030
public val min: Any
3131
public val p25: Any

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt

+312-61
Large diffs are not rendered by default.

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt

+14-16
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
package org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators
22

3-
import org.jetbrains.kotlinx.dataframe.math.mean
3+
import org.jetbrains.kotlinx.dataframe.math.meanOrNull
44
import org.jetbrains.kotlinx.dataframe.math.median
55
import org.jetbrains.kotlinx.dataframe.math.percentile
66
import org.jetbrains.kotlinx.dataframe.math.std
77
import org.jetbrains.kotlinx.dataframe.math.sum
8+
import java.math.BigDecimal
89
import kotlin.reflect.KType
910

1011
@PublishedApi
@@ -86,21 +87,18 @@ internal object Aggregators {
8687

8788
@Suppress("ClassName")
8889
object mean {
89-
val toNumber = withOption { skipNA: Boolean ->
90-
extendsNumbers { mean(it, skipNA) }
91-
}.create("meanToNumber")
92-
93-
val toDouble = withOption { skipNA: Boolean ->
94-
changesType(
95-
aggregateWithType = { mean(it, skipNA).asDoubleOrNaN() },
96-
aggregateWithValues = { mean(skipNA) },
97-
)
98-
}.create("meanToDouble")
99-
100-
val toBigDecimal = changesType(
101-
aggregateWithType = { mean(it) as BigDecimal? },
102-
aggregateWithValues = { filterNotNull().mean() },
103-
).create("meanToBigDecimal")
90+
val toNumber = withOneOption { skipNA: Boolean ->
91+
twoStepForNumbers { meanOrNull(it, skipNA) }
92+
}.create(mean::class.simpleName!!)
93+
94+
val toDouble = withOneOption { skipNA: Boolean ->
95+
twoStepForNumbers { meanOrNull(it, skipNA) as Double? }
96+
}.create(mean::class.simpleName!!)
97+
98+
val toBigDecimal =
99+
twoStepForNumbers {
100+
meanOrNull(it) as BigDecimal?
101+
}.create(mean::class.simpleName!!)
104102
}
105103

106104
val percentile by withOneOption { percentile: Double ->

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/TwoStepNumbersAggregator.kt

+2-2
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ internal class TwoStepNumbersAggregator<Return : Number>(
4242
) : AggregatorBase<Number, Return>(name, aggregator) {
4343

4444
override fun aggregate(values: Iterable<Number>, type: KType): Return? {
45-
require(type.isSubtypeOf(typeOf<Number>())) {
46-
"${TwoStepNumbersAggregator::class.simpleName}: Type $type is not a subtype of Number"
45+
require(type.isSubtypeOf(typeOf<Number?>())) {
46+
"${TwoStepNumbersAggregator::class.simpleName}: Type $type is not a subtype of Number?"
4747
}
4848
return super.aggregate(values, type)
4949
}

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/ofRowExpression.kt

+4-3
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ internal inline fun <C, reified V, R> Aggregator<V, R>.aggregateOf(
3030
internal inline fun <T, reified C, R> Aggregator<*, R>.aggregateOf(
3131
frame: DataFrame<T>,
3232
crossinline expression: RowExpression<T, C>,
33-
): R? = (this as Aggregator<C, R>).aggregateOf(frame.rows()) { expression(it, it) } // TODO: inline
33+
): R? = (this as Aggregator<C, R>).aggregateOf(frame.rows()) { expression(it, it) }
3434

3535
@PublishedApi
3636
internal fun <T, C, R> Aggregator<*, R>.aggregateOfDelegated(
@@ -50,7 +50,7 @@ internal inline fun <T, reified C, R> Aggregator<*, R>.of(
5050

5151
@PublishedApi
5252
internal inline fun <C, reified V, R> Aggregator<V, R>.of(data: DataColumn<C>, crossinline expression: (C) -> V): R? =
53-
aggregateOf(data.values()) { expression(it) } // TODO: inline
53+
aggregateOf(data.values()) { expression(it) }
5454

5555
@PublishedApi
5656
internal inline fun <T, reified C, reified R> Aggregator<*, R>.aggregateOf(
@@ -75,7 +75,8 @@ internal inline fun <T, reified C, reified R> Grouped<T>.aggregateOf(
7575
val type = typeOf<R>()
7676
return aggregateInternal {
7777
val value = aggregator.aggregateOf(df, expression)
78-
yield(path, value, type, null, false)
78+
val inferType = !aggregator.preservesType
79+
yield(path, value, type, null, inferType)
7980
}
8081
}
8182

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/api/describe.kt

+2-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import org.jetbrains.kotlinx.dataframe.api.isNumber
1616
import org.jetbrains.kotlinx.dataframe.api.map
1717
import org.jetbrains.kotlinx.dataframe.api.maxOrNull
1818
import org.jetbrains.kotlinx.dataframe.api.mean
19+
import org.jetbrains.kotlinx.dataframe.api.meanOrNull
1920
import org.jetbrains.kotlinx.dataframe.api.medianOrNull
2021
import org.jetbrains.kotlinx.dataframe.api.minOrNull
2122
import org.jetbrains.kotlinx.dataframe.api.move
@@ -56,7 +57,7 @@ internal fun describeImpl(cols: List<AnyCol>): DataFrame<ColumnDescription> {
5657
?.key
5758
}
5859
if (hasNumericCols) {
59-
ColumnDescription::mean from { if (it.isNumber()) it.asNumbers().mean() else null }
60+
ColumnDescription::mean from { if (it.isNumber()) it.asNumbers().meanOrNull() else null }
6061
ColumnDescription::std from { if (it.isNumber()) it.asNumbers().std() else null }
6162
}
6263
if (hasComparableCols || hasNumericCols) {
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
package org.jetbrains.kotlinx.dataframe.math
22

3-
import org.jetbrains.kotlinx.dataframe.api.isNaN
43
import org.jetbrains.kotlinx.dataframe.api.skipNA_default
54
import org.jetbrains.kotlinx.dataframe.impl.api.toBigDecimal
65
import org.jetbrains.kotlinx.dataframe.impl.convertToUnifiedNumberType
@@ -13,52 +12,53 @@ import kotlin.reflect.KType
1312
import kotlin.reflect.full.withNullability
1413
import kotlin.reflect.typeOf
1514

16-
/** @include [Sequence.mean] */
15+
/** @include [Sequence.meanOrNull] */
1716
@PublishedApi
18-
internal fun <T : Number> Iterable<T>.mean(type: KType, skipNA: Boolean = skipNA_default): Number? =
19-
asSequence().mean(type, skipNA)
17+
internal fun <T : Number> Iterable<T>.meanOrNull(type: KType, skipNA: Boolean = skipNA_default): Number? =
18+
asSequence().meanOrNull(type, skipNA)
2019

2120
/**
2221
* Returns the mean of the numbers in [this].
2322
*
2423
* If the input is empty, the return value will be `null`.
2524
*
2625
* If the [type] given or input consists of only [Int], [Short], [Byte], [Long], [Double], or [Float],
27-
* the return type will be [Double]`?` (Never `NaN`).
26+
* the return type will be [Double].
2827
*
29-
* If the [type] given or the input contains [BigInteger] or [BigDecimal], the return type will be [BigDecimal]`?`.
28+
* If the [type] given or the input contains [BigInteger] or [BigDecimal],
29+
* the return type will be [BigDecimal].
3030
* @param type The type of the numbers in the sequence.
3131
* @param skipNA Whether to skip `NaN` values (default: `false`). Only relevant for [Double] and [Float].
3232
*/
3333
@Suppress("UNCHECKED_CAST")
34-
internal fun <T : Number> Sequence<T>.mean(type: KType, skipNA: Boolean = skipNA_default): Number? {
34+
internal fun <T : Number> Sequence<T>.meanOrNull(type: KType, skipNA: Boolean = skipNA_default): Number? {
3535
if (type.isMarkedNullable) {
36-
return filterNotNull().mean(type.withNullability(false), skipNA)
36+
return filterNotNull().meanOrNull(type.withNullability(false), skipNA)
3737
}
3838
return when (type.classifier) {
39-
// Double -> Double?
40-
Double::class -> (this as Sequence<Double>).mean(skipNA).takeUnless { it.isNaN }
39+
// Double -> Double
40+
Double::class -> (this as Sequence<Double>).meanOrNull(skipNA)
4141

42-
// Float -> Double?
43-
Float::class -> (this as Sequence<Float>).mean(skipNA).takeUnless { it.isNaN }
42+
// Float -> Double
43+
Float::class -> (this as Sequence<Float>).meanOrNull(skipNA)
4444

45-
// Int -> Double?
46-
Int::class -> (this as Sequence<Int>).map { it.toDouble() }.mean(false).takeUnless { it.isNaN }
45+
// Int -> Double
46+
Int::class -> (this as Sequence<Int>).map { it.toDouble() }.meanOrNull(false)
4747

48-
// Short -> Double?
49-
Short::class -> (this as Sequence<Short>).map { it.toDouble() }.mean(false).takeUnless { it.isNaN }
48+
// Short -> Double
49+
Short::class -> (this as Sequence<Short>).map { it.toDouble() }.meanOrNull(false)
5050

51-
// Byte -> Double?
52-
Byte::class -> (this as Sequence<Byte>).map { it.toDouble() }.mean(false).takeUnless { it.isNaN }
51+
// Byte -> Double
52+
Byte::class -> (this as Sequence<Byte>).map { it.toDouble() }.meanOrNull(false)
5353

54-
// Long -> Double?
55-
Long::class -> (this as Sequence<Long>).map { it.toDouble() }.mean(false).takeUnless { it.isNaN }
54+
// Long -> Double
55+
Long::class -> (this as Sequence<Long>).map { it.toDouble() }.meanOrNull(false)
5656

57-
// BigInteger -> BigDecimal?
58-
BigInteger::class -> (this as Sequence<BigInteger>).mean()
57+
// BigInteger -> BigDecimal
58+
BigInteger::class -> (this as Sequence<BigInteger>).meanOrNull()
5959

60-
// BigDecimal -> BigDecimal?
61-
BigDecimal::class -> (this as Sequence<BigDecimal>).mean()
60+
// BigDecimal -> BigDecimal
61+
BigDecimal::class -> (this as Sequence<BigDecimal>).meanOrNull()
6262

6363
// Number -> Conversion(Common number type) -> Number? (Double or BigDecimal?)
6464
// fallback case, heavy as it needs to collect all types at runtime
@@ -69,7 +69,7 @@ internal fun <T : Number> Sequence<T>.mean(type: KType, skipNA: Boolean = skipNA
6969
error("Cannot find unified number type for $numberTypes")
7070
}
7171
this.convertToUnifiedNumberType(unifiedType)
72-
.mean(unifiedType, skipNA)
72+
.meanOrNull(unifiedType, skipNA)
7373
}
7474

7575
// this means the sequence is empty
@@ -79,43 +79,43 @@ internal fun <T : Number> Sequence<T>.mean(type: KType, skipNA: Boolean = skipNA
7979
}
8080
}
8181

82-
internal fun Sequence<Double>.mean(skipNA: Boolean = skipNA_default): Double {
82+
internal fun Sequence<Double>.meanOrNull(skipNA: Boolean = skipNA_default): Double? {
8383
var count = 0
8484
var sum: Double = 0.toDouble()
8585
for (element in this) {
8686
if (element.isNaN()) {
8787
if (skipNA) {
8888
continue
8989
} else {
90-
return Double.NaN
90+
return null
9191
}
9292
}
9393
sum += element
9494
count++
9595
}
96-
return if (count > 0) sum / count else Double.NaN
96+
return if (count > 0) sum / count else null
9797
}
9898

9999
@JvmName("meanFloat")
100-
internal fun Sequence<Float>.mean(skipNA: Boolean = skipNA_default): Double {
100+
internal fun Sequence<Float>.meanOrNull(skipNA: Boolean = skipNA_default): Double? {
101101
var count = 0
102102
var sum: Double = 0.toDouble()
103103
for (element in this) {
104104
if (element.isNaN()) {
105105
if (skipNA) {
106106
continue
107107
} else {
108-
return Double.NaN
108+
return null
109109
}
110110
}
111111
sum += element
112112
count++
113113
}
114-
return if (count > 0) sum / count else Double.NaN
114+
return if (count > 0) sum / count else null
115115
}
116116

117117
@JvmName("bigIntegerMean")
118-
internal fun Sequence<BigInteger>.mean(): BigDecimal? {
118+
internal fun Sequence<BigInteger>.meanOrNull(): BigDecimal? {
119119
var count = 0
120120
val sum = sumOf {
121121
count++
@@ -125,7 +125,7 @@ internal fun Sequence<BigInteger>.mean(): BigDecimal? {
125125
}
126126

127127
@JvmName("bigDecimalMean")
128-
internal fun Sequence<BigDecimal>.mean(): BigDecimal? {
128+
internal fun Sequence<BigDecimal>.meanOrNull(): BigDecimal? {
129129
var count = 0
130130
val sum = sumOf {
131131
count++
@@ -135,65 +135,65 @@ internal fun Sequence<BigDecimal>.mean(): BigDecimal? {
135135
}
136136

137137
@JvmName("doubleMean")
138-
internal fun Iterable<Double>.mean(skipNA: Boolean = skipNA_default): Double = asSequence().mean(skipNA)
138+
internal fun Iterable<Double>.meanOrNull(skipNA: Boolean = skipNA_default): Double? = asSequence().meanOrNull(skipNA)
139139

140140
@JvmName("floatMean")
141-
internal fun Iterable<Float>.mean(skipNA: Boolean = skipNA_default): Double = asSequence().mean(skipNA)
141+
internal fun Iterable<Float>.meanOrNull(skipNA: Boolean = skipNA_default): Double? = asSequence().meanOrNull(skipNA)
142142

143143
@JvmName("bigDecimalMean")
144-
internal fun Iterable<BigDecimal>.mean(): BigDecimal? = asSequence().mean()
144+
internal fun Iterable<BigDecimal>.meanOrNull(): BigDecimal? = asSequence().meanOrNull()
145145

146146
@JvmName("bigIntegerMean")
147-
internal fun Iterable<BigInteger>.mean(): BigDecimal? = asSequence().mean()
147+
internal fun Iterable<BigInteger>.meanOrNull(): BigDecimal? = asSequence().meanOrNull()
148148

149149
@JvmName("intMean")
150-
internal fun Iterable<Int>.mean(): Double =
150+
internal fun Iterable<Int>.meanOrNull(): Double? =
151151
if (this is Collection) {
152-
if (size > 0) sumOf { it.toDouble() } / size else Double.NaN
152+
if (size > 0) sumOf { it.toDouble() } / size else null
153153
} else {
154154
var count = 0
155155
val sum = sumOf {
156156
count++
157157
it.toDouble()
158158
}
159-
if (count > 0) sum / count else Double.NaN
159+
if (count > 0) sum / count else null
160160
}
161161

162162
@JvmName("shortMean")
163-
internal fun Iterable<Short>.mean(): Double =
163+
internal fun Iterable<Short>.meanOrNull(): Double? =
164164
if (this is Collection) {
165-
if (size > 0) sumOf { it.toDouble() } / size else Double.NaN
165+
if (size > 0) sumOf { it.toDouble() } / size else null
166166
} else {
167167
var count = 0
168168
val sum = sumOf {
169169
count++
170170
it.toDouble()
171171
}
172-
if (count > 0) sum / count else Double.NaN
172+
if (count > 0) sum / count else null
173173
}
174174

175175
@JvmName("byteMean")
176-
internal fun Iterable<Byte>.mean(): Double =
176+
internal fun Iterable<Byte>.meanOrNull(): Double? =
177177
if (this is Collection) {
178-
if (size > 0) sumOf { it.toDouble() } / size else Double.NaN
178+
if (size > 0) sumOf { it.toDouble() } / size else null
179179
} else {
180180
var count = 0
181181
val sum = sumOf {
182182
count++
183183
it.toDouble()
184184
}
185-
if (count > 0) sum / count else Double.NaN
185+
if (count > 0) sum / count else null
186186
}
187187

188188
@JvmName("longMean")
189-
internal fun Iterable<Long>.mean(): Double =
189+
internal fun Iterable<Long>.meanOrNull(): Double? =
190190
if (this is Collection) {
191-
if (size > 0) sumOf { it.toDouble() } / size else Double.NaN
191+
if (size > 0) sumOf { it.toDouble() } / size else null
192192
} else {
193193
var count = 0
194194
val sum = sumOf {
195195
count++
196196
it.toDouble()
197197
}
198-
if (count > 0) sum / count else Double.NaN
198+
if (count > 0) sum / count else null
199199
}

core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/describe.kt

+3-3
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class DescribeTests {
3939
nulls shouldBe 1
4040
top shouldBe 1
4141
freq shouldBe 1
42-
mean shouldBe 4.5
42+
mean shouldBe 4.5.toBigDecimal()
4343
std shouldBe 2.449489742783178
4444
min shouldBe 1.toBigDecimal()
4545
(p25 as BigDecimal).setScale(2) shouldBe 2.75.toBigDecimal()
@@ -64,8 +64,8 @@ class DescribeTests {
6464
nulls shouldBe 0
6565
top shouldBe 1
6666
freq shouldBe 1
67-
mean.isNaN() shouldBe true
68-
std.isNaN() shouldBe true
67+
mean shouldBe null
68+
std.isNaN shouldBe true
6969
min shouldBe 1.0 // TODO should be NaN too?
7070
p25 shouldBe 1.75
7171
median shouldBe 3.0

core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/puzzles/BasicTests.kt

+3-3
Original file line numberDiff line numberDiff line change
@@ -142,9 +142,9 @@ class BasicTests {
142142
@Test
143143
fun `calculate mean age for each animal`() {
144144
val expected = dataFrameOf("animal", "age")(
145-
"cat", Double.NaN,
145+
"cat", null,
146146
"snake", 2.5,
147-
"dog", Double.NaN,
147+
"dog", null,
148148
)
149149

150150
df.groupBy { animal }.mean { age } shouldBe expected
@@ -213,7 +213,7 @@ class BasicTests {
213213
val expected = dataFrameOf("animal", "1", "3", "2")(
214214
"cat", 2.5, 2.5, null,
215215
"snake", 4.5, null, 0.5,
216-
"dog", 3.0, Double.NaN, 6.0,
216+
"dog", 3.0, null, 6.0,
217217
)
218218

219219
val actualDfAcc = df.pivot(inward = false) { visits }.groupBy { animal }.mean(skipNA = true) { age }

core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/puzzles/MediumTests.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ class MediumTests {
6666
-1, 0, 1,
6767
)
6868

69-
df.convert { colsOf<Double>() }.with { (it - rowMean()).roundToInt() } shouldBe expected
69+
df.convert { colsOf<Double>() }.with { (it - rowMean().toDouble()).roundToInt() } shouldBe expected
7070
}
7171

7272
@Test

0 commit comments

Comments
 (0)