Skip to content

Commit 6d24729

Browse files
committed
WIP rework of aggregator implementation
1 parent 31401f6 commit 6d24729

File tree

12 files changed

+382
-137
lines changed

12 files changed

+382
-137
lines changed

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/median.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ public inline fun <T, reified R : Comparable<R>> DataColumn<T>.medianOf(noinline
3939
// region DataRow
4040

4141
public fun AnyRow.rowMedianOrNull(): Any? =
42-
Aggregators.median.aggregateMixed(
42+
Aggregators.median.aggregateCalculatingType(
4343
values().filterIsInstance<Comparable<Any?>>().asIterable(),
4444
)
4545

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,9 @@ public inline fun <T, reified R : Number> DataColumn<T>.sumOf(crossinline expres
4444
// region DataRow
4545

4646
public fun AnyRow.rowSum(): Number =
47-
Aggregators.sum.aggregateMixed(
47+
Aggregators.sum.aggregateUnifyingNumbers(
4848
values = values().filterIsInstance<Number>(),
49-
types = columnTypes().filter { it.isSubtypeOf(typeOf<Number?>()) }.toSet(),
49+
numberTypes = columnTypes().filter { it.isSubtypeOf(typeOf<Number?>()) }.toSet(),
5050
) ?: 0
5151

5252
public inline fun <reified T : Number> AnyRow.rowSumOf(): T = values().filterIsInstance<T>().sum(typeOf<T>())

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/NumberTypeUtils.kt

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,3 +137,15 @@ internal fun Iterable<Number>.convertToUnifiedNumberType(
137137
converter(it) ?: error("Can not convert $it to $commonNumberType")
138138
}
139139
}
140+
141+
/** @include [Iterable.convertToUnifiedNumberType] */
142+
@JvmName("convertToUnifiedNumberTypeSequence")
143+
@Suppress("UNCHECKED_CAST")
144+
internal fun Sequence<Number>.convertToUnifiedNumberType(
145+
commonNumberType: KType = asIterable().types().unifiedNumberType(),
146+
): Sequence<Number> {
147+
val converter = createConverter(typeOf<Number>(), commonNumberType)!! as (Number) -> Number?
148+
return map {
149+
converter(it) ?: error("Can not convert $it to $commonNumberType")
150+
}
151+
}

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator.kt

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,50 @@ package org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators
33
import org.jetbrains.kotlinx.dataframe.DataColumn
44
import kotlin.reflect.KType
55

6+
/**
7+
* Base interface for all aggregators.
8+
*
9+
* Aggregators are used to compute a single value from an [Iterable] of values, a single [DataColumn],
10+
* or multiple [DataColumns][DataColumn].
11+
*
12+
* The [AggregatorBase] class is a base implementation of this interface.
13+
*/
614
@PublishedApi
7-
internal interface Aggregator<C, R> {
15+
internal interface Aggregator<Value, Return> {
816

17+
/** The name of this aggregator. */
918
val name: String
1019

11-
fun aggregate(column: DataColumn<C?>): R?
12-
20+
/** If `true`, the [Value] and [Return] types are the same. */
1321
val preservesType: Boolean
1422

15-
fun aggregate(columns: Iterable<DataColumn<C?>>): R?
16-
17-
fun aggregate(values: Iterable<C>, type: KType): R?
23+
/**
24+
* Base function of [Aggregator].
25+
*
26+
* Aggregates the given values, taking [type] into account, and computes a single resulting value.
27+
*
28+
* When using [AggregatorBase], this can be supplied by the [AggregatorBase.aggregator] argument.
29+
*/
30+
fun aggregate(values: Iterable<Value>, type: KType): Return?
31+
32+
/**
33+
* Aggregates the data in the given column and computes a single resulting value.
34+
* Nulls are filtered out by default, then the aggregation function (with [Iterable] and [KType]) is called.
35+
*
36+
* See [AggregatorBase.aggregate].
37+
*/
38+
fun aggregate(column: DataColumn<Value?>): Return?
39+
40+
/**
41+
* Aggregates the data in the multiple given columns and computes a single resulting value.
42+
*
43+
* Must be overridden when using [AggregatorBase].
44+
*/
45+
fun aggregate(columns: Iterable<DataColumn<Value?>>): Return?
1846
}
1947

2048
@PublishedApi
21-
internal fun <T> Aggregator<*, *>.cast(): Aggregator<T, T> = this as Aggregator<T, T>
49+
internal fun <Type> Aggregator<*, *>.cast(): Aggregator<Type, Type> = this as Aggregator<Type, Type>
2250

2351
@PublishedApi
24-
internal fun <T, P> Aggregator<*, *>.cast2(): Aggregator<T, P> = this as Aggregator<T, P>
52+
internal fun <Value, Return> Aggregator<*, *>.cast2(): Aggregator<Value, Return> = this as Aggregator<Value, Return>

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorBase.kt

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,62 @@ package org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators
33
import org.jetbrains.kotlinx.dataframe.DataColumn
44
import org.jetbrains.kotlinx.dataframe.api.asIterable
55
import org.jetbrains.kotlinx.dataframe.api.asSequence
6+
import org.jetbrains.kotlinx.dataframe.impl.commonType
67
import kotlin.reflect.KType
78

8-
internal abstract class AggregatorBase<C, R>(
9+
/**
10+
* Base class for [aggregators][Aggregator].
11+
*
12+
* Aggregators are used to compute a single value from an [Iterable] of values, a single [DataColumn],
13+
* or multiple [DataColumns][DataColumn].
14+
*
15+
* @param name The name of this aggregator.
16+
* @param aggregator Functional argument for the [aggregate] function.
17+
*/
18+
internal abstract class AggregatorBase<Value, Return>(
919
override val name: String,
10-
protected val aggregator: (Iterable<C>, KType) -> R?,
11-
) : Aggregator<C, R> {
20+
protected val aggregator: (values: Iterable<Value>, type: KType) -> Return?,
21+
) : Aggregator<Value, Return> {
1222

13-
override fun aggregate(column: DataColumn<C?>): R? =
23+
/**
24+
* Base function of [Aggregator].
25+
*
26+
* Aggregates the given values, taking [type] into account, and computes a single resulting value.
27+
* Uses [aggregator] to compute the result.
28+
*/
29+
override fun aggregate(values: Iterable<Value>, type: KType): Return? = aggregator(values, type)
30+
31+
/**
32+
* Aggregates the data in the given column and computes a single resulting value.
33+
* Nulls are filtered out before calling the aggregation function with [Iterable] and [KType].
34+
*/
35+
override fun aggregate(column: DataColumn<Value?>): Return? =
1436
if (column.hasNulls()) {
1537
aggregate(column.asSequence().filterNotNull().asIterable(), column.type())
1638
} else {
17-
aggregate(column.asIterable() as Iterable<C>, column.type())
39+
aggregate(column.asIterable() as Iterable<Value>, column.type())
40+
}
41+
42+
/**
43+
* Special case of [aggregate] with [Iterable] that calculates the common type of the values at runtime.
44+
* This is a heavy operation and should be avoided when possible.
45+
*/
46+
fun aggregateCalculatingType(values: Iterable<Value>): Return? {
47+
var hasNulls = false
48+
val classes = values.mapNotNull {
49+
if (it == null) {
50+
hasNulls = true
51+
null
52+
} else {
53+
it.javaClass.kotlin
54+
}
1855
}
56+
return aggregate(values, classes.commonType(hasNulls))
57+
}
1958

20-
override fun aggregate(values: Iterable<C>, type: KType): R? = aggregator(values, type)
59+
/**
60+
* Aggregates the data in the multiple given columns and computes a single resulting value.
61+
* Must be overridden to use.
62+
*/
63+
abstract override fun aggregate(columns: Iterable<DataColumn<Value?>>): Return?
2164
}
Lines changed: 56 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,72 @@
11
package org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators
22

3-
import kotlin.reflect.KProperty
4-
3+
/**
4+
* Wrapper around an [aggregator factory][AggregatorProvider] for aggregators that require a single parameter.
5+
*
6+
* Aggregators are cached by their parameter value.
7+
* @see AggregatorOptionSwitch2
8+
*/
59
@PublishedApi
6-
internal class AggregatorOptionSwitch<P, C, R>(val name: String, val getAggregator: (P) -> AggregatorProvider<C, R>) {
10+
internal class AggregatorOptionSwitch1<Param1, AggregatorType : Aggregator<*, *>>(
11+
val name: String,
12+
val getAggregator: (param1: Param1) -> AggregatorProvider<AggregatorType>,
13+
) {
714

8-
private val cache = mutableMapOf<P, Aggregator<C, R>>()
15+
private val cache: MutableMap<Param1, AggregatorType> = mutableMapOf()
916

10-
operator fun invoke(option: P) = cache.getOrPut(option) { getAggregator(option).create(name) }
17+
operator fun invoke(param1: Param1): AggregatorType =
18+
cache.getOrPut(param1) {
19+
getAggregator(param1).create(name)
20+
}
1121

12-
class Factory<P, C, R>(val getAggregator: (P) -> AggregatorProvider<C, R>) {
13-
operator fun getValue(obj: Any?, property: KProperty<*>) = AggregatorOptionSwitch(property.name, getAggregator)
14-
}
22+
/**
23+
* Creates [AggregatorOptionSwitch1].
24+
*
25+
* Used like:
26+
* ```kt
27+
* val myAggregator by AggregatorOptionSwitch1.Factory { param1: Param1 ->
28+
* MyAggregator.Factory(param1)
29+
* }
* ```
30+
*/
31+
class Factory<Param1, AggregatorType : Aggregator<*, *>>(
32+
val getAggregator: (Param1) -> AggregatorProvider<AggregatorType>,
33+
) : Provider<AggregatorOptionSwitch1<Param1, AggregatorType>> by Provider({ name ->
34+
AggregatorOptionSwitch1(name, getAggregator)
35+
})
1536
}
1637

38+
/**
39+
* Wrapper around an [aggregator factory][AggregatorProvider] for aggregators that require two parameters.
40+
*
41+
* Aggregators are cached by their parameter values.
42+
* @see AggregatorOptionSwitch1
43+
*/
1744
@PublishedApi
18-
internal class AggregatorOptionSwitch2<P1, P2, C, R>(
45+
internal class AggregatorOptionSwitch2<Param1, Param2, AggregatorType : Aggregator<*, *>>(
1946
val name: String,
20-
val getAggregator: (P1, P2) -> AggregatorProvider<C, R>,
47+
val getAggregator: (param1: Param1, param2: Param2) -> AggregatorProvider<AggregatorType>,
2148
) {
2249

23-
private val cache = mutableMapOf<Pair<P1, P2>, Aggregator<C, R>>()
50+
private val cache: MutableMap<Pair<Param1, Param2>, AggregatorType> = mutableMapOf()
2451

25-
operator fun invoke(option1: P1, option2: P2) =
26-
cache.getOrPut(option1 to option2) {
27-
getAggregator(option1, option2).create(name)
52+
operator fun invoke(param1: Param1, param2: Param2): AggregatorType =
53+
cache.getOrPut(param1 to param2) {
54+
getAggregator(param1, param2).create(name)
2855
}
2956

30-
class Factory<P1, P2, C, R>(val getAggregator: (P1, P2) -> AggregatorProvider<C, R>) {
31-
operator fun getValue(obj: Any?, property: KProperty<*>) = AggregatorOptionSwitch2(property.name, getAggregator)
32-
}
57+
/**
58+
* Creates [AggregatorOptionSwitch2].
59+
*
60+
* Used like:
61+
* ```kt
62+
* val myAggregator by AggregatorOptionSwitch2.Factory { param1: Param1, param2: Param2 ->
63+
* MyAggregator.Factory(param1, param2)
64+
* }
65+
* ```
66+
*/
67+
class Factory<Param1, Param2, AggregatorType : Aggregator<*, *>>(
68+
val getAggregator: (Param1, Param2) -> AggregatorProvider<AggregatorType>,
69+
) : Provider<AggregatorOptionSwitch2<Param1, Param2, AggregatorType>> by Provider({ name ->
70+
AggregatorOptionSwitch2(name, getAggregator)
71+
})
3372
}

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorProvider.kt

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,27 @@ package org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators
22

33
import kotlin.reflect.KProperty
44

5-
internal interface AggregatorProvider<C, R> {
5+
/**
6+
* Common interface for providers or "factory" objects that create anything of type [T].
7+
*
8+
* When implemented, this allows the object to be created using the `by` delegate, to give it a name, like:
9+
* ```kt
10+
* val myNamedValue by MyFactory
11+
* ```
12+
*/
13+
internal fun interface Provider<T> {
614

7-
operator fun getValue(obj: Any?, property: KProperty<*>): Aggregator<C, R> = create(property.name)
8-
9-
fun create(name: String): Aggregator<C, R>
15+
fun create(name: String): T
1016
}
17+
18+
internal operator fun <T> Provider<T>.getValue(obj: Any?, property: KProperty<*>): T = create(property.name)
19+
20+
/**
21+
* Common interface for providers of [Aggregators][Aggregator] or "factory" objects that create aggregators.
22+
*
23+
* When implemented, this allows an aggregator to be created using the `by` delegate, to give it a name, like:
24+
* ```kt
25+
* val myAggregator by MyAggregator.Factory
26+
* ```
27+
*/
28+
internal fun interface AggregatorProvider<AggregatorType : Aggregator<*, *>> : Provider<AggregatorType>

0 commit comments

Comments
 (0)