Skip to content

Mean statistics fixes #1091

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Mar 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions core/api/core.api
Original file line number Diff line number Diff line change
Expand Up @@ -1777,12 +1777,12 @@ public final class org/jetbrains/kotlinx/dataframe/api/DataColumnArithmeticsKt {
}

public final class org/jetbrains/kotlinx/dataframe/api/DataColumnTypeKt {
public static final fun isBigNumber (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z
public static final fun isColumnGroup (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z
public static final fun isComparable (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z
public static final fun isFrameColumn (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z
public static final fun isList (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z
public static final fun isNumber (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z
public static final fun isPrimitiveNumber (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z
public static final fun isSubtypeOf (Lorg/jetbrains/kotlinx/dataframe/DataColumn;Lkotlin/reflect/KType;)Z
public static final fun isValueColumn (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z
public static final fun valuesAreComparable (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z
Expand Down Expand Up @@ -2842,8 +2842,6 @@ public final class org/jetbrains/kotlinx/dataframe/api/MeanKt {
public static synthetic fun meanFor$default (Lorg/jetbrains/kotlinx/dataframe/api/PivotGroupBy;[Ljava/lang/String;ZZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
public static synthetic fun meanFor$default (Lorg/jetbrains/kotlinx/dataframe/api/PivotGroupBy;[Lkotlin/reflect/KProperty;ZZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
public static synthetic fun meanFor$default (Lorg/jetbrains/kotlinx/dataframe/api/PivotGroupBy;[Lorg/jetbrains/kotlinx/dataframe/columns/ColumnReference;ZZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
public static final fun meanOrNull (Lorg/jetbrains/kotlinx/dataframe/DataColumn;Z)Ljava/lang/Double;
public static synthetic fun meanOrNull$default (Lorg/jetbrains/kotlinx/dataframe/DataColumn;ZILjava/lang/Object;)Ljava/lang/Double;
public static final fun rowMean (Lorg/jetbrains/kotlinx/dataframe/DataRow;Z)D
public static synthetic fun rowMean$default (Lorg/jetbrains/kotlinx/dataframe/DataRow;ZILjava/lang/Object;)D
}
Expand Down Expand Up @@ -3967,6 +3965,7 @@ public final class org/jetbrains/kotlinx/dataframe/api/TypeConversionsKt {
public static final fun asComparable (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Lorg/jetbrains/kotlinx/dataframe/DataColumn;
public static final fun asComparable (Lorg/jetbrains/kotlinx/dataframe/columns/ColumnSet;)Lorg/jetbrains/kotlinx/dataframe/columns/ColumnSet;
public static final fun asComparable (Lorg/jetbrains/kotlinx/dataframe/columns/SingleColumn;)Lorg/jetbrains/kotlinx/dataframe/columns/SingleColumn;
public static final fun asComparableNullable (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Lorg/jetbrains/kotlinx/dataframe/DataColumn;
public static final fun asDataColumn (Lorg/jetbrains/kotlinx/dataframe/columns/ColumnGroup;)Lorg/jetbrains/kotlinx/dataframe/DataColumn;
public static final fun asDataColumn (Lorg/jetbrains/kotlinx/dataframe/columns/FrameColumn;)Lorg/jetbrains/kotlinx/dataframe/DataColumn;
public static final fun asDataFrame (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
Expand Down Expand Up @@ -5103,6 +5102,10 @@ public final class org/jetbrains/kotlinx/dataframe/impl/ExceptionUtilsKt {
public static final fun suggestIfNull (Ljava/lang/Object;Ljava/lang/String;)Ljava/lang/Object;
}

public final class org/jetbrains/kotlinx/dataframe/impl/NumberTypeUtilsKt {
public static final fun getPrimitiveNumberTypes ()Ljava/util/Set;
}

public final class org/jetbrains/kotlinx/dataframe/impl/TypeUtilsKt {
public static final fun getValuesType (Ljava/util/List;Lkotlin/reflect/KType;Lorg/jetbrains/kotlinx/dataframe/api/Infer;)Lkotlin/reflect/KType;
public static final synthetic fun guessValueType (Lkotlin/sequences/Sequence;Lkotlin/reflect/KType;Z)Lkotlin/reflect/KType;
Expand Down Expand Up @@ -5210,6 +5213,10 @@ public final class org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/OfRowE
public static final fun aggregateOfDelegated (Lorg/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator;Lorg/jetbrains/kotlinx/dataframe/api/Grouped;Ljava/lang/String;Lkotlin/jvm/functions/Function2;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
}

public final class org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/RowKt {
public static final fun aggregateOfRow (Lorg/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator;Lorg/jetbrains/kotlinx/dataframe/DataRow;Lkotlin/jvm/functions/Function2;)Ljava/lang/Object;
}

public final class org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/WithinAllColumnsKt {
public static final fun aggregateAll (Lorg/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator;Lorg/jetbrains/kotlinx/dataframe/DataFrame;Lkotlin/jvm/functions/Function2;)Ljava/lang/Object;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,12 @@ import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup
import org.jetbrains.kotlinx.dataframe.columns.ColumnKind
import org.jetbrains.kotlinx.dataframe.columns.FrameColumn
import org.jetbrains.kotlinx.dataframe.columns.ValueColumn
import org.jetbrains.kotlinx.dataframe.impl.primitiveNumberTypes
import org.jetbrains.kotlinx.dataframe.type
import org.jetbrains.kotlinx.dataframe.typeClass
import org.jetbrains.kotlinx.dataframe.util.IS_COMPARABLE
import org.jetbrains.kotlinx.dataframe.util.IS_COMPARABLE_REPLACE
import org.jetbrains.kotlinx.dataframe.util.IS_INTER_COMPARABLE_IMPORT
import java.math.BigDecimal
import java.math.BigInteger
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.contract
import kotlin.reflect.KType
Expand Down Expand Up @@ -46,9 +45,14 @@ public inline fun <reified T> AnyCol.isSubtypeOf(): Boolean = isSubtypeOf(typeOf

public inline fun <reified T> AnyCol.isType(): Boolean = type() == typeOf<T>()

/** Returns `true` when this column's type is a subtype of `Number?` */
public fun AnyCol.isNumber(): Boolean = isSubtypeOf<Number?>()

public fun AnyCol.isBigNumber(): Boolean = isSubtypeOf<BigInteger?>() || isSubtypeOf<BigDecimal?>()
/**
* Returns `true` when this column has the (nullable) type of either:
* [Byte], [Short], [Int], [Long], [Float], or [Double].
*/
public fun AnyCol.isPrimitiveNumber(): Boolean = type().withNullability(false) in primitiveNumberTypes

public fun AnyCol.isList(): Boolean = typeClass == List::class

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,7 @@ public inline fun <C, reified R> ColumnReference<C>.map(

// region DataColumn

public inline fun <T, reified R> DataColumn<T>.map(
infer: Infer = Infer.Nulls,
crossinline transform: (T) -> R,
): DataColumn<R> {
public inline fun <T, reified R> DataColumn<T>.map(infer: Infer = Infer.Nulls, transform: (T) -> R): DataColumn<R> {
val newValues = Array(size()) { transform(get(it)) }.asList()
return DataColumn.createByType(name(), newValues, typeOf<R>(), infer)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,41 +18,60 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.cast2
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateAll
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateFor
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOf
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOfRow
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.of
import org.jetbrains.kotlinx.dataframe.impl.aggregation.numberColumns
import org.jetbrains.kotlinx.dataframe.impl.aggregation.primitiveNumberColumns
import org.jetbrains.kotlinx.dataframe.impl.columns.toNumberColumns
import org.jetbrains.kotlinx.dataframe.impl.suggestIfNull
import org.jetbrains.kotlinx.dataframe.math.mean
import org.jetbrains.kotlinx.dataframe.impl.primitiveNumberTypes
import kotlin.reflect.KProperty
import kotlin.reflect.full.withNullability
import kotlin.reflect.typeOf

/*
* TODO KDocs:
* Calculating the mean is supported for all primitive number types.
* Nulls are filtered from columns.
* The return type is always Double, Double.NaN for empty input, never null.
* (May introduce loss of precision for Longs).
* For mixed primitive number types, [TwoStepNumbersAggregator] unifies the numbers before calculating the mean.
*/

// region DataColumn

public fun <T : Number> DataColumn<T?>.mean(skipNA: Boolean = skipNA_default): Double =
meanOrNull(skipNA).suggestIfNull("mean")

public fun <T : Number> DataColumn<T?>.meanOrNull(skipNA: Boolean = skipNA_default): Double? =
Aggregators.mean(skipNA).aggregate(this)

public inline fun <T, reified R : Number> DataColumn<T>.meanOf(
skipNA: Boolean = skipNA_default,
noinline expression: (T) -> R?,
): Double = Aggregators.mean(skipNA).cast2<R?, Double>().aggregateOf(this, expression) ?: Double.NaN
): Double =
Aggregators.mean(skipNA)
.cast2<R?, Double>()
.aggregateOf(this, expression)

// endregion

// region DataRow

public fun AnyRow.rowMean(skipNA: Boolean = skipNA_default): Double =
values().filterIsInstance<Number>().map { it.toDouble() }.mean(skipNA)

public inline fun <reified T : Number> AnyRow.rowMeanOf(): Double = values().filterIsInstance<T>().mean(typeOf<T>())
Aggregators.mean(skipNA).aggregateOfRow(this) {
colsOf<Number?> { it.isPrimitiveNumber() }
}

public inline fun <reified T : Number?> AnyRow.rowMeanOf(skipNA: Boolean = skipNA_default): Double {
require(typeOf<T>().withNullability(false) in primitiveNumberTypes) {
"Type ${T::class.simpleName} is not a primitive number type. Mean only supports primitive number types."
}
return Aggregators.mean(skipNA)
.aggregateOfRow(this) { colsOf<T>() }
}

// endregion

// region DataFrame

public fun <T> DataFrame<T>.mean(skipNA: Boolean = skipNA_default): DataRow<T> = meanFor(skipNA, numberColumns())
public fun <T> DataFrame<T>.mean(skipNA: Boolean = skipNA_default): DataRow<T> =
meanFor(skipNA, primitiveNumberColumns())

public fun <T, C : Number> DataFrame<T>.meanFor(
skipNA: Boolean = skipNA_default,
Expand All @@ -77,7 +96,7 @@ public fun <T, C : Number> DataFrame<T>.meanFor(
public fun <T, C : Number> DataFrame<T>.mean(
skipNA: Boolean = skipNA_default,
columns: ColumnsSelector<T, C?>,
): Double = Aggregators.mean(skipNA).aggregateAll(this, columns) as Double? ?: Double.NaN
): Double = Aggregators.mean(skipNA).aggregateAll(this, columns)

public fun <T> DataFrame<T>.mean(vararg columns: String, skipNA: Boolean = skipNA_default): Double =
mean(skipNA) { columns.toNumberColumns() }
Expand All @@ -95,14 +114,15 @@ public fun <T, C : Number> DataFrame<T>.mean(vararg columns: KProperty<C?>, skip
public inline fun <T, reified D : Number> DataFrame<T>.meanOf(
skipNA: Boolean = skipNA_default,
noinline expression: RowExpression<T, D?>,
): Double = Aggregators.mean(skipNA).of(this, expression) ?: Double.NaN
): Double = Aggregators.mean(skipNA).of(this, expression)

// endregion

// region GroupBy
@Refine
@Interpretable("GroupByMean1")
public fun <T> Grouped<T>.mean(skipNA: Boolean = skipNA_default): DataFrame<T> = meanFor(skipNA, numberColumns())
public fun <T> Grouped<T>.mean(skipNA: Boolean = skipNA_default): DataFrame<T> =
meanFor(skipNA, primitiveNumberColumns())

@Refine
@Interpretable("GroupByMean0")
Expand Down Expand Up @@ -167,7 +187,7 @@ public inline fun <T, reified R : Number> Grouped<T>.meanOf(
// region Pivot

public fun <T> Pivot<T>.mean(skipNA: Boolean = skipNA_default, separate: Boolean = false): DataRow<T> =
meanFor(skipNA, separate, numberColumns())
meanFor(skipNA, separate, primitiveNumberColumns())

public fun <T, C : Number> Pivot<T>.meanFor(
skipNA: Boolean = skipNA_default,
Expand Down Expand Up @@ -210,7 +230,7 @@ public inline fun <T, reified R : Number> Pivot<T>.meanOf(
// region PivotGroupBy

public fun <T> PivotGroupBy<T>.mean(separate: Boolean = false, skipNA: Boolean = skipNA_default): DataFrame<T> =
meanFor(skipNA, separate, numberColumns())
meanFor(skipNA, separate, primitiveNumberColumns())

public fun <T, C : Number> PivotGroupBy<T>.meanFor(
skipNA: Boolean = skipNA_default,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,17 @@ public fun DataColumn<Any>.asNumbers(): ValueColumn<Number> {
return this as ValueColumn<Number>
}

public fun <T> DataColumn<T>.asComparable(): DataColumn<Comparable<T>> {
public fun <T : Any> DataColumn<T>.asComparable(): DataColumn<Comparable<T>> {
require(valuesAreComparable())
return this as DataColumn<Comparable<T>>
}

@JvmName("asComparableNullable")
public fun <T : Any?> DataColumn<T?>.asComparable(): DataColumn<Comparable<T>?> {
require(valuesAreComparable())
return this as DataColumn<Comparable<T>?>
}

public fun <T> ColumnReference<T?>.castToNotNullable(): ColumnReference<T> = cast()

public fun <T> DataColumn<T?>.castToNotNullable(): DataColumn<T> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,18 +234,21 @@ internal fun Iterable<KClass<*>>.unifiedNumberClass(
* or calculated with [Iterable.unifiedNumberType] from the iterable's elements if not explicitly specified.
*
* @param commonNumberType The desired common numeric type to convert the elements to.
* This is determined by default using the types of the elements in the iterable.
* By default, (or if `null`), this is determined using the types of the elements in the iterable.
* @return A new iterable of numbers where each element is converted to the specified or inferred common number type.
* @throws IllegalStateException if an element cannot be converted to the common number type.
* @see UnifyingNumbers
*/
@Suppress("UNCHECKED_CAST")
internal fun Iterable<Number>.convertToUnifiedNumberType(
@JvmName("convertNullableIterableToUnifiedNumberType")
internal fun Iterable<Number?>.convertToUnifiedNumberType(
options: UnifiedNumberTypeOptions = UnifiedNumberTypeOptions.DEFAULT,
commonNumberType: KType = this.types().unifiedNumberType(options),
): Iterable<Number> {
commonNumberType: KType? = null,
): Iterable<Number?> {
val commonNumberType = commonNumberType ?: this.filterNotNull().types().unifiedNumberType(options)
val converter = createConverter(typeOf<Number>(), commonNumberType)!! as (Number) -> Number?
return map {
if (it == null) return@map null
converter(it) ?: error("Can not convert $it to $commonNumberType")
}
}
Expand All @@ -255,23 +258,62 @@ internal fun Iterable<Number>.convertToUnifiedNumberType(
* or calculated with [Iterable.unifiedNumberType][kotlin.collections.Iterable.unifiedNumberType] from the iterable's elements if not explicitly specified.
*
* @param commonNumberType The desired common numeric type to convert the elements to.
* This is determined by default using the types of the elements in the iterable.
* By default, (or if `null`), this is determined using the types of the elements in the iterable.
* @return A new iterable of numbers where each element is converted to the specified or inferred common number type.
* @throws IllegalStateException if an element cannot be converted to the common number type.
* @see UnifyingNumbers */
@JvmName("convertToUnifiedNumberTypeSequence")
@Suppress("UNCHECKED_CAST")
internal fun Sequence<Number>.convertToUnifiedNumberType(
@JvmName("convertIterableToUnifiedNumberType")
internal fun Iterable<Number>.convertToUnifiedNumberType(
options: UnifiedNumberTypeOptions = UnifiedNumberTypeOptions.DEFAULT,
commonNumberType: KType? = null,
): Iterable<Number> =
(this as Iterable<Number?>)
.convertToUnifiedNumberType(options, commonNumberType) as Iterable<Number>

/** Converts the elements of the given iterable of numbers into a common numeric type based on complexity.
* The common numeric type is determined using the provided [commonNumberType] parameter
* or calculated with [Iterable.unifiedNumberType][kotlin.collections.Iterable.unifiedNumberType] from the iterable's elements if not explicitly specified.
*
* @param commonNumberType The desired common numeric type to convert the elements to.
* By default, (or if `null`), this is determined using the types of the elements in the iterable.
* @return A new iterable of numbers where each element is converted to the specified or inferred common number type.
* @throws IllegalStateException if an element cannot be converted to the common number type.
* @see UnifyingNumbers */
@Suppress("UNCHECKED_CAST")
@JvmName("convertNullableSequenceToUnifiedNumberType")
internal fun Sequence<Number?>.convertToUnifiedNumberType(
options: UnifiedNumberTypeOptions = UnifiedNumberTypeOptions.DEFAULT,
commonNumberType: KType = asIterable().types().unifiedNumberType(options),
): Sequence<Number> {
commonNumberType: KType? = null,
): Sequence<Number?> {
val commonNumberType = commonNumberType ?: this.filterNotNull().asIterable().types().unifiedNumberType(options)
val converter = createConverter(typeOf<Number>(), commonNumberType)!! as (Number) -> Number?
return map {
if (it == null) return@map null
converter(it) ?: error("Can not convert $it to $commonNumberType")
}
}

internal val primitiveNumberTypes =
/** Converts the elements of the given iterable of numbers into a common numeric type based on complexity.
* The common numeric type is determined using the provided [commonNumberType] parameter
* or calculated with [Iterable.unifiedNumberType][kotlin.collections.Iterable.unifiedNumberType] from the iterable's elements if not explicitly specified.
*
* @param commonNumberType The desired common numeric type to convert the elements to.
* By default, (or if `null`), this is determined using the types of the elements in the iterable.
* @return A new iterable of numbers where each element is converted to the specified or inferred common number type.
* @throws IllegalStateException if an element cannot be converted to the common number type.
* @see UnifyingNumbers */
@Suppress("UNCHECKED_CAST")
@JvmName("convert=SequenceToUnifiedNumberType")
internal fun Sequence<Number>.convertToUnifiedNumberType(
options: UnifiedNumberTypeOptions = UnifiedNumberTypeOptions.DEFAULT,
commonNumberType: KType? = null,
): Sequence<Number> =
(this as Sequence<Number?>)
.convertToUnifiedNumberType(options, commonNumberType) as Sequence<Number>

@PublishedApi
internal val primitiveNumberTypes: Set<KType> =
setOf(
typeOf<Byte>(),
typeOf<Short>(),
Expand All @@ -280,3 +322,11 @@ internal val primitiveNumberTypes =
typeOf<Float>(),
typeOf<Double>(),
)

internal fun Any.isPrimitiveNumber(): Boolean =
this is Byte ||
this is Short ||
this is Int ||
this is Long ||
this is Float ||
this is Double
Loading