Skip to content

Commit 73ceaa8

Browse files
committed
enabled some previously broken sum.kt tests and fixed edge case in NumberInputHandler
1 parent 2e0875b commit 73ceaa8

File tree

3 files changed

+22
-65
lines changed

3 files changed

+22
-65
lines changed

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/inputHandlers/ComparableOrNumberInputHandler.kt

Lines changed: 0 additions & 54 deletions
This file was deleted.

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/inputHandlers/NumberInputHandler.kt

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.ValueType
99
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.aggregate
1010
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.toValueType
1111
import org.jetbrains.kotlinx.dataframe.impl.convertToUnifiedNumberType
12+
import org.jetbrains.kotlinx.dataframe.impl.isMixedNumber
1213
import org.jetbrains.kotlinx.dataframe.impl.isNothing
14+
import org.jetbrains.kotlinx.dataframe.impl.isPrimitiveOrMixedNumber
1315
import org.jetbrains.kotlinx.dataframe.impl.nothingType
1416
import org.jetbrains.kotlinx.dataframe.impl.primitiveNumberTypes
1517
import org.jetbrains.kotlinx.dataframe.impl.renderType
@@ -47,14 +49,23 @@ internal class NumberInputHandler<out Return : Any?> : AggregatorInputHandler<Nu
4749
valueType: ValueType,
4850
): Pair<Sequence<Number?>, KType> {
4951
require(valueType.kType.isSubtypeOf(typeOf<Number?>())) {
50-
"${NumberInputHandler::class.simpleName}: Type $valueType is not a subtype of Number?, only primitive numbers are supported in statistics"
52+
"Type $valueType is not a subtype of Number?, only primitive numbers are supported in ${aggregator!!.name}."
5153
}
5254
return when (valueType.kType.withNullability(false)) {
5355
// If the type is not a specific number, but rather a mixed Number, we unify the types first.
5456
// This is heavy and could be avoided by calling aggregate with a specific number type
5557
// or calling aggregateCalculatingType with all known number types
5658
typeOf<Number>() -> {
5759
val unifiedType = calculateValueType(values).kType
60+
61+
// If calculateValueType returns Number(?),
62+
// it means the values cannot be unified to a primitive number type
63+
require(!unifiedType.isMixedNumber()) {
64+
"Types ${
65+
values.asIterable().types().toSet()
66+
} are not all primitive numbers, only those are supported in ${aggregator!!.name}."
67+
}
68+
5869
val unifiedValues = values.convertToUnifiedNumberType(
5970
UnifiedNumberTypeOptions.PRIMITIVES_ONLY,
6071
unifiedType,
@@ -90,23 +101,21 @@ internal class NumberInputHandler<out Return : Any?> : AggregatorInputHandler<Nu
90101
* this function can be called to calculate it in terms of [number unification][UnifyingNumbers]
91102
*
92103
* @throws IllegalArgumentException if the input type is not [Number]`(?)` or a primitive number type.
104+
* @return The (primitive) unified number type of the input values.
105+
* If no valid unification can be found or the input is solely [Number]`(?)`, the type [Number]`(?)` is returned.
93106
*/
94107
override fun calculateValueType(valueTypes: Set<KType>): ValueType {
95108
val unifiedType = valueTypes.unifiedNumberTypeOrNull(UnifiedNumberTypeOptions.Companion.PRIMITIVES_ONLY)
96-
?: throw IllegalArgumentException(
97-
"Cannot calculate the ${aggregator!!.name} of the number types: ${
98-
valueTypes.joinToString { renderType(it) }
99-
}. Note, only primitive number types are supported in statistics.",
100-
)
109+
?: typeOf<Number>().withNullability(valueTypes.any { it.isMarkedNullable })
101110

102111
if (unifiedType.isSubtypeOf(typeOf<Double?>()) &&
103112
(typeOf<ULong>() in valueTypes || typeOf<Long>() in valueTypes)
104113
) {
105114
logger.warn {
106-
"Number unification of Long -> Double happened during aggregation. Loss of precision may have occurred."
115+
"Number unification of Long -> Double happened during ${aggregator!!.name} aggregation. Loss of precision may have occurred."
107116
}
108117
}
109-
if (unifiedType.withNullability(false) !in primitiveNumberTypes && !unifiedType.isNothing) {
118+
if (!unifiedType.isPrimitiveOrMixedNumber() && !unifiedType.isNothing) {
110119
throw IllegalArgumentException(
111120
"Cannot calculate ${aggregator!!.name} of ${
112121
renderType(unifiedType)
@@ -124,7 +133,9 @@ internal class NumberInputHandler<out Return : Any?> : AggregatorInputHandler<Nu
124133
* this function can be called to calculate it in terms of [number unification][UnifyingNumbers]
125134
* by getting the types of [values] at runtime.
126135
*
127-
* @throws IllegalArgumentException if the input type contains a non-primitive number type.
136+
* @throws IllegalArgumentException if the input type is not [Number]`(?)` or a primitive number type.
137+
* @return The (primitive) unified number type of the input values.
138+
* If no valid unification can be found or the input is solely [Number]`(?)`, the type [Number]`(?)` is returned.
128139
*/
129140
override fun calculateValueType(values: Sequence<Number?>): ValueType =
130141
calculateValueType(values.asIterable().types().toSet())

core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/statistics/sum.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,10 @@ class SumTests {
6060
df.sumOf { value3() } shouldBe expected3
6161
df.sum(value1) shouldBe expected1
6262
df.sum(value2) shouldBe expected2
63-
// TODO sum rework, has Number in results df.sum(value3) shouldBe expected3
63+
df.sum(value3) shouldBe expected3
6464
df.sum { value1 } shouldBe expected1
6565
df.sum { value2 } shouldBe expected2
66-
// TODO sum rework, has Number in results df.sum { value3 } shouldBe expected3
66+
df.sum { value3 } shouldBe expected3
6767
}
6868

6969
/** [Issue #1068](https://github.com/Kotlin/dataframe/issues/1068) */

0 commit comments

Comments
 (0)