Skip to content

Mean statistics fixes #1091

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Mar 18, 2025
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions core/api/core.api
Original file line number Diff line number Diff line change
Expand Up @@ -1784,6 +1784,7 @@ public final class org/jetbrains/kotlinx/dataframe/api/DataColumnTypeKt {
public static final fun isList (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z
public static final fun isNumber (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z
public static final fun isPrimitive (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z
public static final fun isPrimitiveNumber (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z
public static final fun isSubtypeOf (Lorg/jetbrains/kotlinx/dataframe/DataColumn;Lkotlin/reflect/KType;)Z
public static final fun isValueColumn (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z
public static final fun valuesAreComparable (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z
Expand Down Expand Up @@ -3968,6 +3969,7 @@ public final class org/jetbrains/kotlinx/dataframe/api/TypeConversionsKt {
public static final fun asComparable (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Lorg/jetbrains/kotlinx/dataframe/DataColumn;
public static final fun asComparable (Lorg/jetbrains/kotlinx/dataframe/columns/ColumnSet;)Lorg/jetbrains/kotlinx/dataframe/columns/ColumnSet;
public static final fun asComparable (Lorg/jetbrains/kotlinx/dataframe/columns/SingleColumn;)Lorg/jetbrains/kotlinx/dataframe/columns/SingleColumn;
public static final fun asComparableNullable (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Lorg/jetbrains/kotlinx/dataframe/DataColumn;
public static final fun asDataColumn (Lorg/jetbrains/kotlinx/dataframe/columns/ColumnGroup;)Lorg/jetbrains/kotlinx/dataframe/DataColumn;
public static final fun asDataColumn (Lorg/jetbrains/kotlinx/dataframe/columns/FrameColumn;)Lorg/jetbrains/kotlinx/dataframe/DataColumn;
public static final fun asDataFrame (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
Expand Down Expand Up @@ -5104,6 +5106,10 @@ public final class org/jetbrains/kotlinx/dataframe/impl/ExceptionUtilsKt {
public static final fun suggestIfNull (Ljava/lang/Object;Ljava/lang/String;)Ljava/lang/Object;
}

public final class org/jetbrains/kotlinx/dataframe/impl/NumberTypeUtilsKt {
public static final fun getPrimitiveNumberTypes ()Ljava/util/Set;
}

public final class org/jetbrains/kotlinx/dataframe/impl/TypeUtilsKt {
public static final fun getValuesType (Ljava/util/List;Lkotlin/reflect/KType;Lorg/jetbrains/kotlinx/dataframe/api/Infer;)Lkotlin/reflect/KType;
public static final synthetic fun guessValueType (Lkotlin/sequences/Sequence;Lkotlin/reflect/KType;Z)Lkotlin/reflect/KType;
Expand Down Expand Up @@ -5211,6 +5217,10 @@ public final class org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/OfRowE
public static final fun aggregateOfDelegated (Lorg/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator;Lorg/jetbrains/kotlinx/dataframe/api/Grouped;Ljava/lang/String;Lkotlin/jvm/functions/Function2;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
}

public final class org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/RowKt {
public static final fun aggregateOfRow (Lorg/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator;Lorg/jetbrains/kotlinx/dataframe/DataRow;Lkotlin/jvm/functions/Function2;)Ljava/lang/Object;
}

public final class org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/WithinAllColumnsKt {
public static final fun aggregateAll (Lorg/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator;Lorg/jetbrains/kotlinx/dataframe/DataFrame;Lkotlin/jvm/functions/Function2;)Ljava/lang/Object;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup
import org.jetbrains.kotlinx.dataframe.columns.ColumnKind
import org.jetbrains.kotlinx.dataframe.columns.FrameColumn
import org.jetbrains.kotlinx.dataframe.columns.ValueColumn
import org.jetbrains.kotlinx.dataframe.impl.primitiveNumberTypes
import org.jetbrains.kotlinx.dataframe.type
import org.jetbrains.kotlinx.dataframe.typeClass
import org.jetbrains.kotlinx.dataframe.util.IS_COMPARABLE
Expand Down Expand Up @@ -52,6 +53,8 @@ public fun AnyCol.isNumber(): Boolean = isSubtypeOf<Number?>()

public fun AnyCol.isBigNumber(): Boolean = isSubtypeOf<BigInteger?>() || isSubtypeOf<BigDecimal?>()

public fun AnyCol.isPrimitiveNumber(): Boolean = type().withNullability(false) in primitiveNumberTypes

public fun AnyCol.isList(): Boolean = typeClass == List::class

/** Returns `true` if [this] column is intra-comparable, i.e.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,7 @@ public inline fun <C, reified R> ColumnReference<C>.map(

// region DataColumn

public inline fun <T, reified R> DataColumn<T>.map(
infer: Infer = Infer.Nulls,
crossinline transform: (T) -> R,
): DataColumn<R> {
public inline fun <T, reified R> DataColumn<T>.map(infer: Infer = Infer.Nulls, transform: (T) -> R): DataColumn<R> {
val newValues = Array(size()) { transform(get(it)) }.asList()
return DataColumn.createByType(name(), newValues, typeOf<R>(), infer)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.cast2
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateAll
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateFor
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOf
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOfRow
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.of
import org.jetbrains.kotlinx.dataframe.impl.aggregation.numberColumns
import org.jetbrains.kotlinx.dataframe.impl.columns.toNumberColumns
import org.jetbrains.kotlinx.dataframe.impl.primitiveNumberTypes
import org.jetbrains.kotlinx.dataframe.impl.suggestIfNull
import org.jetbrains.kotlinx.dataframe.math.mean
import kotlin.reflect.KProperty
import kotlin.reflect.typeOf

Expand All @@ -44,9 +45,18 @@ public inline fun <T, reified R : Number> DataColumn<T>.meanOf(
// region DataRow

public fun AnyRow.rowMean(skipNA: Boolean = skipNA_default): Double =
values().filterIsInstance<Number>().map { it.toDouble() }.mean(skipNA)

public inline fun <reified T : Number> AnyRow.rowMeanOf(): Double = values().filterIsInstance<T>().mean(typeOf<T>())
Aggregators.mean(skipNA).aggregateOfRow(this) {
colsOf<Number?> { it.isPrimitiveNumber() }
} ?: Double.NaN

public inline fun <reified T : Number> AnyRow.rowMeanOf(skipNA: Boolean = skipNA_default): Double {
require(typeOf<T>() in primitiveNumberTypes) {
"Type ${T::class.simpleName} is not a primitive number type. Mean only supports primitive number types."
}
return Aggregators.mean(skipNA)
.aggregateOfRow(this) { colsOf<T>() }
?: Double.NaN
}

// endregion

Expand Down Expand Up @@ -77,7 +87,7 @@ public fun <T, C : Number> DataFrame<T>.meanFor(
public fun <T, C : Number> DataFrame<T>.mean(
skipNA: Boolean = skipNA_default,
columns: ColumnsSelector<T, C?>,
): Double = Aggregators.mean(skipNA).aggregateAll(this, columns) as Double? ?: Double.NaN
): Double = Aggregators.mean(skipNA).aggregateAll(this, columns) ?: Double.NaN

public fun <T> DataFrame<T>.mean(vararg columns: String, skipNA: Boolean = skipNA_default): Double =
mean(skipNA) { columns.toNumberColumns() }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,17 @@ public fun DataColumn<Any>.asNumbers(): ValueColumn<Number> {
return this as ValueColumn<Number>
}

public fun <T> DataColumn<T>.asComparable(): DataColumn<Comparable<T>> {
public fun <T : Any> DataColumn<T>.asComparable(): DataColumn<Comparable<T>> {
require(valuesAreComparable())
return this as DataColumn<Comparable<T>>
}

@JvmName("asComparableNullable")
public fun <T : Any?> DataColumn<T?>.asComparable(): DataColumn<Comparable<T>?> {
require(valuesAreComparable())
return this as DataColumn<Comparable<T>?>
}

public fun <T> ColumnReference<T?>.castToNotNullable(): ColumnReference<T> = cast()

public fun <T> DataColumn<T?>.castToNotNullable(): DataColumn<T> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,18 +234,21 @@ internal fun Iterable<KClass<*>>.unifiedNumberClass(
* or calculated with [Iterable.unifiedNumberType] from the iterable's elements if not explicitly specified.
*
* @param commonNumberType The desired common numeric type to convert the elements to.
* This is determined by default using the types of the elements in the iterable.
* By default, (or if `null`), this is determined using the types of the elements in the iterable.
* @return A new iterable of numbers where each element is converted to the specified or inferred common number type.
* @throws IllegalStateException if an element cannot be converted to the common number type.
* @see UnifyingNumbers
*/
@Suppress("UNCHECKED_CAST")
internal fun Iterable<Number>.convertToUnifiedNumberType(
@JvmName("convertNullableIterableToUnifiedNumberType")
internal fun Iterable<Number?>.convertToUnifiedNumberType(
options: UnifiedNumberTypeOptions = UnifiedNumberTypeOptions.DEFAULT,
commonNumberType: KType = this.types().unifiedNumberType(options),
): Iterable<Number> {
commonNumberType: KType? = null,
): Iterable<Number?> {
val commonNumberType = commonNumberType ?: this.filterNotNull().types().unifiedNumberType(options)
val converter = createConverter(typeOf<Number>(), commonNumberType)!! as (Number) -> Number?
return map {
if (it == null) return@map null
converter(it) ?: error("Can not convert $it to $commonNumberType")
}
}
Expand All @@ -255,23 +258,62 @@ internal fun Iterable<Number>.convertToUnifiedNumberType(
* or calculated with [Iterable.unifiedNumberType][kotlin.collections.Iterable.unifiedNumberType] from the iterable's elements if not explicitly specified.
*
* @param commonNumberType The desired common numeric type to convert the elements to.
* This is determined by default using the types of the elements in the iterable.
* By default, (or if `null`), this is determined using the types of the elements in the iterable.
* @return A new iterable of numbers where each element is converted to the specified or inferred common number type.
* @throws IllegalStateException if an element cannot be converted to the common number type.
* @see UnifyingNumbers */
@JvmName("convertToUnifiedNumberTypeSequence")
@Suppress("UNCHECKED_CAST")
internal fun Sequence<Number>.convertToUnifiedNumberType(
@JvmName("convertIterableToUnifiedNumberType")
internal fun Iterable<Number>.convertToUnifiedNumberType(
options: UnifiedNumberTypeOptions = UnifiedNumberTypeOptions.DEFAULT,
commonNumberType: KType? = null,
): Iterable<Number> =
(this as Iterable<Number?>)
.convertToUnifiedNumberType(options, commonNumberType) as Iterable<Number>

/** Converts the elements of the given iterable of numbers into a common numeric type based on complexity.
* The common numeric type is determined using the provided [commonNumberType] parameter
* or calculated with [Iterable.unifiedNumberType][kotlin.collections.Iterable.unifiedNumberType] from the iterable's elements if not explicitly specified.
*
* @param commonNumberType The desired common numeric type to convert the elements to.
* By default, (or if `null`), this is determined using the types of the elements in the iterable.
* @return A new iterable of numbers where each element is converted to the specified or inferred common number type.
* @throws IllegalStateException if an element cannot be converted to the common number type.
* @see UnifyingNumbers */
@Suppress("UNCHECKED_CAST")
@JvmName("convertNullableSequenceToUnifiedNumberType")
internal fun Sequence<Number?>.convertToUnifiedNumberType(
options: UnifiedNumberTypeOptions = UnifiedNumberTypeOptions.DEFAULT,
commonNumberType: KType = asIterable().types().unifiedNumberType(options),
): Sequence<Number> {
commonNumberType: KType? = null,
): Sequence<Number?> {
val commonNumberType = commonNumberType ?: this.filterNotNull().asIterable().types().unifiedNumberType(options)
val converter = createConverter(typeOf<Number>(), commonNumberType)!! as (Number) -> Number?
return map {
if (it == null) return@map null
converter(it) ?: error("Can not convert $it to $commonNumberType")
}
}

internal val primitiveNumberTypes =
/** Converts the elements of the given iterable of numbers into a common numeric type based on complexity.
* The common numeric type is determined using the provided [commonNumberType] parameter
* or calculated with [Iterable.unifiedNumberType][kotlin.collections.Iterable.unifiedNumberType] from the iterable's elements if not explicitly specified.
*
* @param commonNumberType The desired common numeric type to convert the elements to.
* By default, (or if `null`), this is determined using the types of the elements in the iterable.
* @return A new iterable of numbers where each element is converted to the specified or inferred common number type.
* @throws IllegalStateException if an element cannot be converted to the common number type.
* @see UnifyingNumbers */
@Suppress("UNCHECKED_CAST")
@JvmName("convert=SequenceToUnifiedNumberType")
internal fun Sequence<Number>.convertToUnifiedNumberType(
options: UnifiedNumberTypeOptions = UnifiedNumberTypeOptions.DEFAULT,
commonNumberType: KType? = null,
): Sequence<Number> =
(this as Sequence<Number?>)
.convertToUnifiedNumberType(options, commonNumberType) as Sequence<Number>

@PublishedApi
internal val primitiveNumberTypes: Set<KType> =
setOf(
typeOf<Byte>(),
typeOf<Short>(),
Expand All @@ -280,3 +322,11 @@ internal val primitiveNumberTypes =
typeOf<Float>(),
typeOf<Double>(),
)

internal fun Any.isPrimitiveNumber(): Boolean =
this is Byte ||
this is Short ||
this is Int ||
this is Long ||
this is Float ||
this is Double
Original file line number Diff line number Diff line change
Expand Up @@ -235,11 +235,9 @@ internal object Aggregators {
// step one: T: Number? -> Double
// step two: Double -> Double
val mean by withOneOption { skipNA: Boolean ->
twoStepChangingType(
getReturnTypeOrNull = meanTypeConversion,
stepOneAggregator = { type -> mean(type, skipNA) },
stepTwoAggregator = { mean(skipNA) },
)
twoStepForNumbers(meanTypeConversion) { type ->
mean(type, skipNA)
}
}

// T: Comparable<T>? -> T
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package org.jetbrains.kotlinx.dataframe.impl.aggregation.modes

import org.jetbrains.kotlinx.dataframe.AnyRow
import org.jetbrains.kotlinx.dataframe.ColumnsSelector
import org.jetbrains.kotlinx.dataframe.api.getColumns
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Aggregator

/**
* Generic function to apply an [Aggregator] ([this]) to aggregate values of a row.
*
* [Aggregator.aggregateCalculatingType] is used to deal with mixed types.
*
* @param row a row to aggregate
* @param columns selector of which columns inside the [row] to aggregate
*/
@PublishedApi
internal fun <V, R> Aggregator<V, R>.aggregateOfRow(row: AnyRow, columns: ColumnsSelector<*, V?>): R? {
val filteredColumns = row.df().getColumns(columns)
return aggregateCalculatingType(
values = filteredColumns.mapNotNull { row[it] },
valueTypes = filteredColumns.map { it.type() }.toSet(),
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import org.jetbrains.kotlinx.dataframe.columns.size
import org.jetbrains.kotlinx.dataframe.columns.values
import org.jetbrains.kotlinx.dataframe.impl.columns.addPath
import org.jetbrains.kotlinx.dataframe.impl.columns.asAnyFrameColumn
import org.jetbrains.kotlinx.dataframe.impl.isBigNumber
import org.jetbrains.kotlinx.dataframe.impl.isPrimitiveNumber
import org.jetbrains.kotlinx.dataframe.impl.renderType
import org.jetbrains.kotlinx.dataframe.index
import org.jetbrains.kotlinx.dataframe.kind
Expand Down Expand Up @@ -111,17 +111,20 @@ private fun List<AnyCol>.collectAll(atAnyDepth: Boolean): List<AnyCol> =
}

/** Converts a column to a comparable column if it is not already comparable. */
private fun DataColumn<Any?>.convertToComparableOrNull(): DataColumn<Comparable<Any?>>? =
when {
@Suppress("UNCHECKED_CAST")
private fun DataColumn<Any?>.convertToComparableOrNull(): DataColumn<Comparable<Any>?>? {
return when {
valuesAreComparable() -> asComparable()

// Found incomparable number types, convert all to Double or BigDecimal first
isNumber() ->
if (any { it?.isBigNumber() == true }) {
map { (it as Number?)?.toBigDecimal() }
} else {
map { (it as Number?)?.toDouble() }
}.cast()
// Found incomparable number types, convert all to Double first
isNumber() -> cast<Number?>().map {
if (it?.isPrimitiveNumber() == false) {
// Cannot calculate statistics of a non-primitive number type
return@convertToComparableOrNull null
}
it?.toDouble() as Comparable<Any>?
}

else -> null
}
}
Loading