Skip to content

Commit 268d238

Browse files
committed
added calculateReturnTypeOrNull system to aggregators to avoid runtime value instance checks where we know the types already
1 parent 3ea0167 commit 268d238

File tree

8 files changed

+123
-35
lines changed

8 files changed

+123
-35
lines changed

core/api/core.api

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10082,6 +10082,7 @@ public abstract interface class org/jetbrains/kotlinx/dataframe/impl/aggregation
1008210082
public abstract fun aggregate (Ljava/lang/Iterable;Lkotlin/reflect/KType;)Ljava/lang/Object;
1008310083
public abstract fun aggregate (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Ljava/lang/Object;
1008410084
public abstract fun aggregateCalculatingType (Ljava/lang/Iterable;Ljava/util/Set;)Ljava/lang/Object;
10085+
public abstract fun calculateReturnTypeOrNull (Lkotlin/reflect/KType;Z)Lkotlin/reflect/KType;
1008510086
public abstract fun getName ()Ljava/lang/String;
1008610087
public abstract fun getPreservesType ()Z
1008710088
}

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator.kt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators
22

33
import org.jetbrains.kotlinx.dataframe.DataColumn
44
import kotlin.reflect.KType
5+
import kotlin.reflect.full.withNullability
56

67
/**
78
* Base interface for all aggregators.
@@ -56,10 +57,19 @@ internal interface Aggregator<Value, Return> {
5657
* If provided, [valueTypes] can be used to avoid calculating the types of [values] at runtime.
5758
*/
5859
fun aggregateCalculatingType(values: Iterable<Value>, valueTypes: Set<KType>? = null): Return?
60+
61+
/**
62+
* Function that can give the return type of [aggregate] as [KType], given the type of the input.
63+
*/
64+
fun calculateReturnTypeOrNull(type: KType, emptyInput: Boolean): KType?
5965
}
6066

6167
@PublishedApi
6268
internal fun <Type> Aggregator<*, *>.cast(): Aggregator<Type, Type> = this as Aggregator<Type, Type>
6369

6470
@PublishedApi
6571
internal fun <Value, Return> Aggregator<*, *>.cast2(): Aggregator<Value, Return> = this as Aggregator<Value, Return>
72+
73+
internal val preserveReturnTypeNullIfEmpty: (KType, Boolean) -> KType = { type, emptyInput ->
74+
type.withNullability(emptyInput)
75+
}

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorBase.kt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import kotlin.reflect.full.withNullability
1818
*/
1919
internal abstract class AggregatorBase<Value, Return>(
2020
override val name: String,
21+
protected val getReturnTypeOrNull: (type: KType, emptyInput: Boolean) -> KType?,
2122
protected val aggregator: (values: Iterable<Value>, type: KType) -> Return?,
2223
) : Aggregator<Value, Return> {
2324

@@ -29,6 +30,9 @@ internal abstract class AggregatorBase<Value, Return>(
2930
*/
3031
override fun aggregate(values: Iterable<Value>, type: KType): Return? = aggregator(values, type)
3132

33+
override fun calculateReturnTypeOrNull(type: KType, emptyInput: Boolean): KType? =
34+
getReturnTypeOrNull(type, emptyInput)
35+
3236
/**
3337
* Aggregates the data in the given column and computes a single resulting value.
3438
* Nulls are filtered out before calling the aggregation function with [Iterable] and [KType].

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
package org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators
22

33
import org.jetbrains.kotlinx.dataframe.math.meanOrNull
4+
import org.jetbrains.kotlinx.dataframe.math.meanTypeResultOrNull
45
import org.jetbrains.kotlinx.dataframe.math.median
56
import org.jetbrains.kotlinx.dataframe.math.percentile
67
import org.jetbrains.kotlinx.dataframe.math.std
78
import org.jetbrains.kotlinx.dataframe.math.sum
89
import java.math.BigDecimal
910
import kotlin.reflect.KType
11+
import kotlin.reflect.full.withNullability
12+
import kotlin.reflect.typeOf
1013

1114
@PublishedApi
1215
internal object Aggregators {
@@ -18,6 +21,7 @@ internal object Aggregators {
1821
*/
1922
private fun <Type> twoStepPreservingType(aggregator: Iterable<Type>.(type: KType) -> Type?) =
2023
TwoStepAggregator.Factory(
24+
getReturnTypeOrNull = preserveReturnTypeNullIfEmpty,
2125
stepOneAggregator = aggregator,
2226
stepTwoAggregator = aggregator,
2327
preservesType = true,
@@ -29,9 +33,11 @@ internal object Aggregators {
2933
* @include [TwoStepAggregator]
3034
*/
3135
private fun <Value, Return> twoStepChangingType(
36+
getReturnTypeOrNull: (type: KType, emptyInput: Boolean) -> KType?,
3237
stepOneAggregator: Iterable<Value>.(type: KType) -> Return,
3338
stepTwoAggregator: Iterable<Return>.(type: KType) -> Return,
3439
) = TwoStepAggregator.Factory(
40+
getReturnTypeOrNull = getReturnTypeOrNull,
3541
stepOneAggregator = stepOneAggregator,
3642
stepTwoAggregator = stepTwoAggregator,
3743
preservesType = false,
@@ -44,6 +50,7 @@ internal object Aggregators {
4450
*/
4551
private fun <Type> flatteningPreservingTypes(aggregate: Iterable<Type?>.(type: KType) -> Type?) =
4652
FlatteningAggregator.Factory(
53+
getReturnTypeOrNull = preserveReturnTypeNullIfEmpty,
4754
aggregator = aggregate,
4855
preservesType = true,
4956
)
@@ -53,19 +60,27 @@ internal object Aggregators {
5360
*
5461
* @include [FlatteningAggregator]
5562
*/
56-
private fun <Value, Return> flatteningChangingTypes(aggregate: Iterable<Value?>.(type: KType) -> Return?) =
57-
FlatteningAggregator.Factory(
58-
aggregator = aggregate,
59-
preservesType = false,
60-
)
63+
private fun <Value, Return> flatteningChangingTypes(
64+
getReturnTypeOrNull: (type: KType, emptyInput: Boolean) -> KType?,
65+
aggregate: Iterable<Value?>.(type: KType) -> Return?,
66+
) = FlatteningAggregator.Factory(
67+
getReturnTypeOrNull = getReturnTypeOrNull,
68+
aggregator = aggregate,
69+
preservesType = false,
70+
)
6171

6272
/**
6373
* Factory for a two-step aggregator that works only with numbers.
6474
*
6575
* @include [TwoStepNumbersAggregator]
6676
*/
67-
private fun <Return : Number> twoStepForNumbers(aggregate: Iterable<Number>.(numberType: KType) -> Return?) =
68-
TwoStepNumbersAggregator.Factory(aggregate)
77+
private fun <Return : Number> twoStepForNumbers(
78+
getReturnTypeOrNull: (type: KType, emptyInput: Boolean) -> KType?,
79+
aggregate: Iterable<Number>.(numberType: KType) -> Return?,
80+
) = TwoStepNumbersAggregator.Factory(
81+
getReturnTypeOrNull = getReturnTypeOrNull,
82+
aggregate = aggregate,
83+
)
6984

7085
/** @include [AggregatorOptionSwitch1] */
7186
private fun <Param1, AggregatorType : Aggregator<*, *>> withOneOption(
@@ -82,27 +97,29 @@ internal object Aggregators {
8297
val max by twoStepPreservingType<Comparable<Any?>> { maxOrNull() }
8398

8499
val std by withTwoOptions { skipNA: Boolean, ddof: Int ->
85-
flatteningChangingTypes<Number, Double> { std(it, skipNA, ddof) }
100+
flatteningChangingTypes<Number, Double>(
101+
getReturnTypeOrNull = { type, emptyInput -> typeOf<Double>().withNullability(emptyInput) },
102+
) { std(it, skipNA, ddof) }
86103
}
87104

88105
@Suppress("ClassName")
89106
object mean {
90107
val toNumber = withOneOption { skipNA: Boolean ->
91-
twoStepForNumbers { meanOrNull(it, skipNA) }
108+
twoStepForNumbers(::meanTypeResultOrNull) { meanOrNull(it, skipNA) }
92109
}.create(mean::class.simpleName!!)
93110

94111
val toDouble = withOneOption { skipNA: Boolean ->
95-
twoStepForNumbers { meanOrNull(it, skipNA) as Double? }
112+
twoStepForNumbers(::meanTypeResultOrNull) { meanOrNull(it, skipNA) as Double? }
96113
}.create(mean::class.simpleName!!)
97114

98115
val toBigDecimal =
99-
twoStepForNumbers {
116+
twoStepForNumbers(::meanTypeResultOrNull) {
100117
meanOrNull(it) as BigDecimal?
101118
}.create(mean::class.simpleName!!)
102119
}
103120

104121
val percentile by withOneOption { percentile: Double ->
105-
flatteningChangingTypes<Comparable<Any?>, Comparable<Any?>> { type ->
122+
flatteningChangingTypes<Comparable<Any?>, Comparable<Any?>>(preserveReturnTypeNullIfEmpty) { type ->
106123
percentile(percentile, type)
107124
}
108125
}
@@ -111,5 +128,7 @@ internal object Aggregators {
111128
median(it)
112129
}
113130

114-
val sum by twoStepForNumbers { sum(it) }
131+
val sum by twoStepForNumbers(
132+
getReturnTypeOrNull = { type, _ -> type.withNullability(false) },
133+
) { sum(it) }
115134
}

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/FlatteningAggregator.kt

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,10 @@ import kotlin.reflect.full.withNullability
3232
*/
3333
internal class FlatteningAggregator<Value, Return>(
3434
name: String,
35+
getReturnTypeOrNull: (type: KType, emptyInput: Boolean) -> KType?,
3536
aggregator: (values: Iterable<Value>, type: KType) -> Return?,
3637
override val preservesType: Boolean,
37-
) : AggregatorBase<Value, Return>(name, aggregator) {
38+
) : AggregatorBase<Value, Return>(name, getReturnTypeOrNull, aggregator) {
3839

3940
/**
4041
* Aggregates the data in the multiple given columns and computes a single resulting value.
@@ -54,9 +55,15 @@ internal class FlatteningAggregator<Value, Return>(
5455
* @param preservesType If `true`, [Value][Value]` == `[Return][Return].
5556
*/
5657
class Factory<Value, Return>(
58+
private val getReturnTypeOrNull: (type: KType, emptyInput: Boolean) -> KType?,
5759
private val aggregator: (Iterable<Value>, KType) -> Return?,
5860
private val preservesType: Boolean,
5961
) : AggregatorProvider<FlatteningAggregator<Value, Return>> by AggregatorProvider({ name ->
60-
FlatteningAggregator(name = name, aggregator = aggregator, preservesType = preservesType)
62+
FlatteningAggregator(
63+
name = name,
64+
getReturnTypeOrNull = getReturnTypeOrNull,
65+
aggregator = aggregator,
66+
preservesType = preservesType,
67+
)
6168
})
6269
}

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/TwoStepAggregator.kt

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
package org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators
22

33
import org.jetbrains.kotlinx.dataframe.DataColumn
4-
import org.jetbrains.kotlinx.dataframe.impl.classes
54
import org.jetbrains.kotlinx.dataframe.impl.commonType
5+
import org.jetbrains.kotlinx.dataframe.size
66
import kotlin.reflect.KType
7+
import kotlin.reflect.full.starProjectedType
78
import kotlin.reflect.full.withNullability
89

910
/**
@@ -38,28 +39,30 @@ import kotlin.reflect.full.withNullability
3839
*/
3940
internal class TwoStepAggregator<Value, Return>(
4041
name: String,
42+
getReturnTypeOrNull: (type: KType, emptyInput: Boolean) -> KType?,
4143
stepOneAggregator: (values: Iterable<Value>, type: KType) -> Return?,
4244
private val stepTwoAggregator: (values: Iterable<Return>, type: KType) -> Return?,
4345
override val preservesType: Boolean,
44-
) : AggregatorBase<Value, Return>(name, stepOneAggregator) {
46+
) : AggregatorBase<Value, Return>(name, getReturnTypeOrNull, stepOneAggregator) {
4547

4648
/**
4749
* Aggregates the data in the multiple given columns and computes a single resulting value.
4850
*
4951
* This function calls [stepOneAggregator] on each column and then [stepTwoAggregator] on the results.
5052
*/
5153
override fun aggregate(columns: Iterable<DataColumn<Value?>>): Return? {
52-
val columnValues = columns.mapNotNull {
54+
val (values, types) = columns.mapNotNull { col ->
5355
// uses stepOneAggregator
54-
aggregate(it)
55-
}
56-
val commonType = if (preservesType) {
57-
columns.map { it.type() }.commonType().withNullability(false)
58-
} else {
59-
// heavy!
60-
columnValues.classes().commonType(false)
61-
}
62-
return stepTwoAggregator(columnValues, commonType)
56+
val value = aggregate(col) ?: return@mapNotNull null
57+
val type = calculateReturnTypeOrNull(
58+
type = col.type().withNullability(false),
59+
emptyInput = col.size() == 0,
60+
) ?: value::class.starProjectedType // heavy fallback type calculation
61+
62+
value to type
63+
}.unzip()
64+
val commonType = types.commonType()
65+
return stepTwoAggregator(values, commonType)
6366
}
6467

6568
/**
@@ -71,12 +74,14 @@ internal class TwoStepAggregator<Value, Return>(
7174
* @param preservesType If `true`, [Value][Value]` == `[Return][Return].
7275
*/
7376
class Factory<Value, Return>(
77+
private val getReturnTypeOrNull: (type: KType, emptyInput: Boolean) -> KType?,
7478
private val stepOneAggregator: (Iterable<Value>, KType) -> Return?,
7579
private val stepTwoAggregator: (Iterable<Return>, KType) -> Return?,
7680
private val preservesType: Boolean,
7781
) : AggregatorProvider<TwoStepAggregator<Value, Return>> by AggregatorProvider({ name ->
7882
TwoStepAggregator(
7983
name = name,
84+
getReturnTypeOrNull = getReturnTypeOrNull,
8085
stepOneAggregator = stepOneAggregator,
8186
stepTwoAggregator = stepTwoAggregator,
8287
preservesType = preservesType,

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/TwoStepNumbersAggregator.kt

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import org.jetbrains.kotlinx.dataframe.impl.types
77
import org.jetbrains.kotlinx.dataframe.impl.unifiedNumberType
88
import kotlin.reflect.KType
99
import kotlin.reflect.full.isSubtypeOf
10+
import kotlin.reflect.full.starProjectedType
1011
import kotlin.reflect.full.withNullability
1112
import kotlin.reflect.typeOf
1213

@@ -38,8 +39,9 @@ import kotlin.reflect.typeOf
3839
*/
3940
internal class TwoStepNumbersAggregator<Return : Number>(
4041
name: String,
42+
getReturnTypeOrNull: (type: KType, emptyInput: Boolean) -> KType?,
4143
aggregator: (values: Iterable<Number>, numberType: KType) -> Return?,
42-
) : AggregatorBase<Number, Return>(name, aggregator) {
44+
) : AggregatorBase<Number, Return>(name, getReturnTypeOrNull, aggregator) {
4345

4446
override fun aggregate(values: Iterable<Number>, type: KType): Return? {
4547
require(type.isSubtypeOf(typeOf<Number?>())) {
@@ -48,11 +50,22 @@ internal class TwoStepNumbersAggregator<Return : Number>(
4850
return super.aggregate(values, type)
4951
}
5052

51-
override fun aggregate(columns: Iterable<DataColumn<Number?>>): Return? =
52-
aggregateCalculatingType(
53-
values = columns.mapNotNull { aggregate(it) },
54-
valueTypes = null, // makes the operation heavy
53+
override fun aggregate(columns: Iterable<DataColumn<Number?>>): Return? {
54+
val (values, types) = columns.mapNotNull { col ->
55+
val value = aggregate(col) ?: return@mapNotNull null
56+
val type = calculateReturnTypeOrNull(
57+
type = col.type().withNullability(false),
58+
emptyInput = col.size() == 0,
59+
) ?: value::class.starProjectedType // heavy fallback type calculation
60+
61+
value to type
62+
}.unzip()
63+
64+
return aggregateCalculatingType(
65+
values = values,
66+
valueTypes = types.toSet(),
5567
)
68+
}
5669

5770
/**
5871
* Special case of [aggregate] with [Iterable] that calculates the [unified number type][UnifyingNumbers]
@@ -71,8 +84,14 @@ internal class TwoStepNumbersAggregator<Return : Number>(
7184

7285
override val preservesType = false
7386

74-
class Factory<Return : Number>(private val aggregate: Iterable<Number>.(numberType: KType) -> Return?) :
75-
AggregatorProvider<TwoStepNumbersAggregator<Return>> by AggregatorProvider({ name ->
76-
TwoStepNumbersAggregator(name = name, aggregator = aggregate)
87+
class Factory<Return : Number>(
88+
private val getReturnTypeOrNull: (type: KType, emptyInput: Boolean) -> KType?,
89+
private val aggregate: Iterable<Number>.(numberType: KType) -> Return?,
90+
) : AggregatorProvider<TwoStepNumbersAggregator<Return>> by AggregatorProvider({ name ->
91+
TwoStepNumbersAggregator(
92+
name = name,
93+
getReturnTypeOrNull = getReturnTypeOrNull,
94+
aggregator = aggregate,
95+
)
7796
})
7897
}

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/mean.kt

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ package org.jetbrains.kotlinx.dataframe.math
33
import org.jetbrains.kotlinx.dataframe.api.skipNA_default
44
import org.jetbrains.kotlinx.dataframe.impl.api.toBigDecimal
55
import org.jetbrains.kotlinx.dataframe.impl.convertToUnifiedNumberType
6+
import org.jetbrains.kotlinx.dataframe.impl.nothingType
7+
import org.jetbrains.kotlinx.dataframe.impl.nullableNothingType
68
import org.jetbrains.kotlinx.dataframe.impl.renderType
79
import org.jetbrains.kotlinx.dataframe.impl.types
810
import org.jetbrains.kotlinx.dataframe.impl.unifiedNumberType
@@ -79,6 +81,27 @@ internal fun <T : Number> Sequence<T>.meanOrNull(type: KType, skipNA: Boolean =
7981
}
8082
}
8183

84+
internal fun meanTypeResultOrNull(type: KType, emptyInput: Boolean): KType? =
85+
when (val type = type.withNullability(false)) {
86+
typeOf<Double>(),
87+
typeOf<Float>(),
88+
typeOf<Int>(),
89+
typeOf<Short>(),
90+
typeOf<Byte>(),
91+
typeOf<Long>(),
92+
-> typeOf<Double>().withNullability(emptyInput)
93+
94+
typeOf<BigInteger>(),
95+
typeOf<BigDecimal>(),
96+
-> typeOf<BigDecimal>().withNullability(emptyInput)
97+
98+
nothingType -> nullableNothingType
99+
100+
typeOf<Number>() -> null
101+
102+
else -> throw IllegalArgumentException("Unable to compute the mean for type ${renderType(type)}")
103+
}
104+
82105
internal fun Sequence<Double>.meanOrNull(skipNA: Boolean = skipNA_default): Double? {
83106
var count = 0
84107
var sum: Double = 0.toDouble()

0 commit comments

Comments
 (0)