From ab976a79324b443a845898825c0f6d9023f4a175 Mon Sep 17 00:00:00 2001 From: Jolan Rensen Date: Mon, 10 Mar 2025 20:50:53 +0100 Subject: [PATCH 01/11] marked aggregateBy for removal --- .../kotlinx/dataframe/impl/aggregation/modes/aggregateBy.kt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/aggregateBy.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/aggregateBy.kt index 239ed236de..659180a3aa 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/aggregateBy.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/aggregateBy.kt @@ -3,12 +3,14 @@ package org.jetbrains.kotlinx.dataframe.impl.aggregation.modes import org.jetbrains.kotlinx.dataframe.DataFrame import org.jetbrains.kotlinx.dataframe.DataFrameExpression import org.jetbrains.kotlinx.dataframe.DataRow +import org.jetbrains.kotlinx.dataframe.annotations.CandidateForRemoval import org.jetbrains.kotlinx.dataframe.api.GroupBy import org.jetbrains.kotlinx.dataframe.api.Grouped import org.jetbrains.kotlinx.dataframe.api.cast import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregateInternal import org.jetbrains.kotlinx.dataframe.impl.namedValues +@CandidateForRemoval internal fun Grouped.aggregateBy(body: DataFrameExpression?>): DataFrame { require(this is GroupBy<*, T>) val keyColumns = keys.columnNames().toSet() From bcf7231f708868460eff73bc5de423b10441294d Mon Sep 17 00:00:00 2001 From: Jolan Rensen Date: Mon, 10 Mar 2025 21:36:16 +0100 Subject: [PATCH 02/11] redid implementation of `mean`. It's now based on TwoStepNumbersAggregator, such that mixed numbers are unified first. Big number support is dropped. Created generic rowX() function aggregateOfRow(), used by rowMean() and rowMeanOf() --- .../kotlinx/dataframe/api/DataColumnType.kt | 3 + .../jetbrains/kotlinx/dataframe/api/mean.kt | 20 +++-- .../kotlinx/dataframe/impl/NumberTypeUtils.kt | 3 +- .../aggregation/aggregators/Aggregators.kt | 8 +- .../dataframe/impl/aggregation/modes/row.kt | 23 ++++++ .../jetbrains/kotlinx/dataframe/math/mean.kt | 76 ++++++++----------- .../statistics/{BasicMathTests.kt => mean.kt} | 8 +- 7 files changed, 80 insertions(+), 61 deletions(-) create mode 100644 core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/row.kt rename core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/statistics/{BasicMathTests.kt => mean.kt} (77%) diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/DataColumnType.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/DataColumnType.kt index 8547f2a060..11bc564346 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/DataColumnType.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/DataColumnType.kt @@ -7,6 +7,7 @@ import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup import org.jetbrains.kotlinx.dataframe.columns.ColumnKind import org.jetbrains.kotlinx.dataframe.columns.FrameColumn import org.jetbrains.kotlinx.dataframe.columns.ValueColumn +import org.jetbrains.kotlinx.dataframe.impl.primitiveNumberTypes import org.jetbrains.kotlinx.dataframe.type import org.jetbrains.kotlinx.dataframe.typeClass import org.jetbrains.kotlinx.dataframe.util.IS_COMPARABLE @@ -52,6 +53,8 @@ public fun AnyCol.isNumber(): Boolean = isSubtypeOf() public fun AnyCol.isBigNumber(): Boolean = isSubtypeOf() || isSubtypeOf() +public fun AnyCol.isPrimitiveNumber(): Boolean = type().withNullability(false) in primitiveNumberTypes + public fun AnyCol.isList(): Boolean = typeClass == List::class /** @include [valuesAreComparable] */ diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt index 994cbf27db..92bf8344ab 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt @@ -16,11 +16,12 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.cast2 import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateAll import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateFor import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOf +import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOfRow import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.of import org.jetbrains.kotlinx.dataframe.impl.aggregation.numberColumns import org.jetbrains.kotlinx.dataframe.impl.columns.toNumberColumns +import org.jetbrains.kotlinx.dataframe.impl.primitiveNumberTypes import org.jetbrains.kotlinx.dataframe.impl.suggestIfNull -import org.jetbrains.kotlinx.dataframe.math.mean import kotlin.reflect.KProperty import kotlin.reflect.typeOf @@ -42,9 +43,18 @@ public inline fun DataColumn.meanOf( // region DataRow public fun AnyRow.rowMean(skipNA: Boolean = skipNA_default): Double = - values().filterIsInstance().map { it.toDouble() }.mean(skipNA) - -public inline fun AnyRow.rowMeanOf(): Double = values().filterIsInstance().mean(typeOf()) + Aggregators.mean(skipNA).aggregateOfRow(this) { + colsOf { it.isPrimitiveNumber() } + } ?: Double.NaN + +public inline fun AnyRow.rowMeanOf(skipNA: Boolean = skipNA_default): Double { + require(typeOf() in primitiveNumberTypes) { + "Type ${T::class.simpleName} is not a primitive number type. Mean only supports primitive number types." + } + return Aggregators.mean(skipNA) + .aggregateOfRow(this) { colsOf() } + ?: Double.NaN +} // endregion @@ -75,7 +85,7 @@ public fun DataFrame.meanFor( public fun DataFrame.mean( skipNA: Boolean = skipNA_default, columns: ColumnsSelector, -): Double = Aggregators.mean(skipNA).aggregateAll(this, columns) as Double? ?: Double.NaN +): Double = Aggregators.mean(skipNA).aggregateAll(this, columns) ?: Double.NaN public fun DataFrame.mean(vararg columns: String, skipNA: Boolean = skipNA_default): Double = mean(skipNA) { columns.toNumberColumns() } diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/NumberTypeUtils.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/NumberTypeUtils.kt index fca8cb23e7..4db0d6d603 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/NumberTypeUtils.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/NumberTypeUtils.kt @@ -202,7 +202,8 @@ internal fun Sequence.convertToUnifiedNumberType( } } -internal val primitiveNumberTypes = +@PublishedApi +internal val primitiveNumberTypes: Set = setOf( typeOf(), typeOf(), diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt index 8c677a0990..9fa6e115d5 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt @@ -110,11 +110,9 @@ internal object Aggregators { // step one: T: Number? -> Double // step two: Double -> Double val mean by withOneOption { skipNA: Boolean -> - twoStepChangingType( - getReturnTypeOrNull = meanTypeConversion, - stepOneAggregator = { type -> mean(type, skipNA) }, - stepTwoAggregator = { mean(skipNA) }, - ) + twoStepForNumbers(meanTypeConversion) { type -> + mean(type, skipNA) + } } // T: Comparable? -> T diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/row.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/row.kt new file mode 100644 index 0000000000..5cdba355ea --- /dev/null +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/row.kt @@ -0,0 +1,23 @@ +package org.jetbrains.kotlinx.dataframe.impl.aggregation.modes + +import org.jetbrains.kotlinx.dataframe.AnyRow +import org.jetbrains.kotlinx.dataframe.ColumnsSelector +import org.jetbrains.kotlinx.dataframe.api.getColumns +import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Aggregator + +/** + * Generic function to apply an [Aggregator] ([this]) to aggregate values of a row. + * + * [Aggregator.aggregateCalculatingType] is used to deal with mixed types. + * + * @param row a row to aggregate + * @param columns selector of which columns inside the [row] to aggregate + */ +@PublishedApi +internal fun Aggregator.aggregateOfRow(row: AnyRow, columns: ColumnsSelector<*, V?>): R? { + val filteredColumns = row.df().getColumns(columns) + return aggregateCalculatingType( + values = filteredColumns.mapNotNull { row[it] }, + valueTypes = filteredColumns.map { it.type() }.toSet(), + ) +} diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/mean.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/mean.kt index fc5fb70b70..352ddacae5 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/mean.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/mean.kt @@ -1,7 +1,9 @@ package org.jetbrains.kotlinx.dataframe.math +import io.github.oshai.kotlinlogging.KotlinLogging import org.jetbrains.kotlinx.dataframe.api.skipNA_default import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.CalculateReturnTypeOrNull +import org.jetbrains.kotlinx.dataframe.impl.nothingType import org.jetbrains.kotlinx.dataframe.impl.renderType import java.math.BigDecimal import java.math.BigInteger @@ -9,6 +11,8 @@ import kotlin.reflect.KType import kotlin.reflect.full.withNullability import kotlin.reflect.typeOf +private val logger = KotlinLogging.logger { } + @PublishedApi internal fun Iterable.mean(type: KType, skipNA: Boolean = skipNA_default): Double = asSequence().mean(type, skipNA) @@ -18,30 +22,36 @@ internal fun Sequence.mean(type: KType, skipNA: Boolean = skipN if (type.isMarkedNullable) { return filterNotNull().mean(type.withNullability(false), skipNA) } - return when (type.classifier) { - Double::class -> (this as Sequence).mean(skipNA) - - Float::class -> (this as Sequence).mean(skipNA) + return when (type.withNullability(false)) { + typeOf() -> (this as Sequence).mean(skipNA) - Int::class -> (this as Sequence).map { it.toDouble() }.mean(false) + typeOf() -> (this as Sequence).mean(skipNA) - // for integer values NA is not possible - Short::class -> (this as Sequence).map { it.toDouble() }.mean(false) + typeOf() -> (this as Sequence).map { it.toDouble() }.mean(false) - Byte::class -> (this as Sequence).map { it.toDouble() }.mean(false) + typeOf() -> (this as Sequence).map { it.toDouble() }.mean(false) - Long::class -> (this as Sequence).map { it.toDouble() }.mean(false) + typeOf() -> (this as Sequence).map { it.toDouble() }.mean(false) - BigInteger::class -> (this as Sequence).map { it.toDouble() }.mean(false) + typeOf() -> { + logger.warn { "Converting Longs to Doubles to calculate the mean, loss of precision may occur." } + (this as Sequence).map { it.toDouble() }.mean(false) + } - BigDecimal::class -> (this as Sequence).map { it.toDouble() }.mean(skipNA) + typeOf(), typeOf() -> + throw IllegalArgumentException( + "Cannot calculate the mean for big numbers in DataFrame. Only primitive numbers are supported.", + ) - Number::class -> (this as Sequence).map { it.toDouble() }.mean(skipNA) + typeOf() -> + error("Encountered non-specific Number type in mean function. This should not occur.") // this means the sequence is empty - Nothing::class -> Double.NaN + nothingType -> Double.NaN - else -> throw IllegalArgumentException("Unable to compute the mean for type ${renderType(type)}") + else -> throw IllegalArgumentException( + "Unable to compute the mean for ${renderType(type)}, Only primitive numbers are supported.", + ) } } @@ -94,7 +104,7 @@ internal fun Iterable.mean(skipNA: Boolean = skipNA_default): Double = as @JvmName("intMean") internal fun Iterable.mean(): Double = if (this is Collection) { - if (size > 0) sumOf { it.toDouble() } / size else Double.NaN + if (isNotEmpty()) sumOf { it.toDouble() } / size else Double.NaN } else { var count = 0 val sum = sumOf { @@ -107,7 +117,7 @@ internal fun Iterable.mean(): Double = @JvmName("shortMean") internal fun Iterable.mean(): Double = if (this is Collection) { - if (size > 0) sumOf { it.toDouble() } / size else Double.NaN + if (isNotEmpty()) sumOf { it.toDouble() } / size else Double.NaN } else { var count = 0 val sum = sumOf { @@ -120,7 +130,7 @@ internal fun Iterable.mean(): Double = @JvmName("byteMean") internal fun Iterable.mean(): Double = if (this is Collection) { - if (size > 0) sumOf { it.toDouble() } / size else Double.NaN + if (isNotEmpty()) sumOf { it.toDouble() } / size else Double.NaN } else { var count = 0 val sum = sumOf { @@ -133,35 +143,7 @@ internal fun Iterable.mean(): Double = @JvmName("longMean") internal fun Iterable.mean(): Double = if (this is Collection) { - if (size > 0) sumOf { it.toDouble() } / size else Double.NaN - } else { - var count = 0 - val sum = sumOf { - count++ - it.toDouble() - } - if (count > 0) sum / count else Double.NaN - } - -// TODO result is Double, but should be BigDecimal, Issue #558 -@JvmName("bigIntegerMean") -internal fun Iterable.mean(): Double = - if (this is Collection) { - if (size > 0) sumOf { it.toDouble() } / size else Double.NaN - } else { - var count = 0 - val sum = sumOf { - count++ - it.toDouble() - } - if (count > 0) sum / count else Double.NaN - } - -// TODO result is Double, but should be BigDecimal, Issue #558 -@JvmName("bigDecimalMean") -internal fun Iterable.mean(): Double = - if (this is Collection) { - if (size > 0) sum().toDouble() / size else Double.NaN + if (isNotEmpty()) sumOf { it.toDouble() } / size else Double.NaN } else { var count = 0 val sum = sumOf { @@ -169,4 +151,6 @@ internal fun Iterable.mean(): Double = it.toDouble() } if (count > 0) sum / count else Double.NaN + }.also { + logger.warn { "Converting Longs to Doubles to calculate the mean, loss of precision may occur." } } diff --git a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/statistics/BasicMathTests.kt b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/statistics/mean.kt similarity index 77% rename from core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/statistics/BasicMathTests.kt rename to core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/statistics/mean.kt index 53121d1a9f..16ad780fac 100644 --- a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/statistics/BasicMathTests.kt +++ b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/statistics/mean.kt @@ -9,18 +9,18 @@ import org.jetbrains.kotlinx.dataframe.impl.nothingType import org.junit.Test import kotlin.reflect.typeOf -class BasicMathTests { +class MeanTests { @Test fun `type for column with mixed numbers`() { - val col = columnOf(10, 10.0, null) + val col = columnOf(10, 10.0, null) col.type() shouldBe typeOf() } @Test fun `mean with nans and nulls`() { - columnOf(10, 20, Double.NaN, null).mean().shouldBeNaN() - columnOf(10, 20, Double.NaN, null).mean(skipNA = true) shouldBe 15 + columnOf(10, 20, Double.NaN, null).mean().shouldBeNaN() + columnOf(10, 20, Double.NaN, null).mean(skipNA = true) shouldBe 15 DataColumn.createValueColumn("", emptyList(), nothingType(false)).mean().shouldBeNaN() DataColumn.createValueColumn("", listOf(null), nothingType(true)).mean().shouldBeNaN() From 15a2df5b6f64250f42cb8da953788f9a1e9d20e0 Mon Sep 17 00:00:00 2001 From: Jolan Rensen Date: Mon, 10 Mar 2025 21:36:37 +0100 Subject: [PATCH 03/11] removed big numbers from describe() --- .../kotlinx/dataframe/impl/api/describe.kt | 17 +++++-------- .../kotlinx/dataframe/api/describe.kt | 25 ++++++++----------- 2 files changed, 17 insertions(+), 25 deletions(-) diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/api/describe.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/api/describe.kt index 254a598a86..672381f95a 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/api/describe.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/api/describe.kt @@ -13,6 +13,7 @@ import org.jetbrains.kotlinx.dataframe.api.asNumbers import org.jetbrains.kotlinx.dataframe.api.cast import org.jetbrains.kotlinx.dataframe.api.concat import org.jetbrains.kotlinx.dataframe.api.isNumber +import org.jetbrains.kotlinx.dataframe.api.isPrimitiveNumber import org.jetbrains.kotlinx.dataframe.api.map import org.jetbrains.kotlinx.dataframe.api.maxOrNull import org.jetbrains.kotlinx.dataframe.api.mean @@ -29,7 +30,6 @@ import org.jetbrains.kotlinx.dataframe.columns.size import org.jetbrains.kotlinx.dataframe.columns.values import org.jetbrains.kotlinx.dataframe.impl.columns.addPath import org.jetbrains.kotlinx.dataframe.impl.columns.asAnyFrameColumn -import org.jetbrains.kotlinx.dataframe.impl.isBigNumber import org.jetbrains.kotlinx.dataframe.impl.renderType import org.jetbrains.kotlinx.dataframe.index import org.jetbrains.kotlinx.dataframe.kind @@ -38,7 +38,7 @@ import org.jetbrains.kotlinx.dataframe.type internal fun describeImpl(cols: List): DataFrame { val allCols = cols.collectAll(false) - val hasNumericCols = allCols.any { it.isNumber() } + val hasNumericCols = allCols.any { it.isPrimitiveNumber() } val hasComparableCols = allCols.any { it.valuesAreComparable() } val hasLongPaths = allCols.any { it.path().size > 1 } var df = allCols.toDataFrame { @@ -56,8 +56,8 @@ internal fun describeImpl(cols: List): DataFrame { ?.key } if (hasNumericCols) { - ColumnDescription::mean from { if (it.isNumber()) it.asNumbers().mean() else null } - ColumnDescription::std from { if (it.isNumber()) it.asNumbers().std() else null } + ColumnDescription::mean from { if (it.isPrimitiveNumber()) it.asNumbers().mean() else null } + ColumnDescription::std from { if (it.isPrimitiveNumber()) it.asNumbers().std() else null } } if (hasComparableCols || hasNumericCols) { ColumnDescription::min from inferType { @@ -115,13 +115,8 @@ private fun DataColumn.convertToComparableOrNull(): DataColumn asComparable() - // Found incomparable number types, convert all to Double or BigDecimal first - isNumber() -> - if (any { it?.isBigNumber() == true }) { - map { (it as Number?)?.toBigDecimal() } - } else { - map { (it as Number?)?.toDouble() } - }.cast() + // Found incomparable number types, convert all to Double first + isPrimitiveNumber() -> map { (it as Number?)?.toDouble() }.cast() else -> null } diff --git a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/describe.kt b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/describe.kt index 71f049f5ca..7bb70b3659 100644 --- a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/describe.kt +++ b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/describe.kt @@ -4,7 +4,6 @@ import io.kotest.matchers.doubles.shouldBeNaN import io.kotest.matchers.shouldBe import org.jetbrains.kotlinx.dataframe.alsoDebug import org.junit.Test -import java.math.BigDecimal class DescribeTests { @@ -17,15 +16,13 @@ class DescribeTests { @Test fun `describe nullable Number column`() { - val a by columnOf( + val a by columnOf( 1, 2.0, 3f, 4L, 5.toShort(), 6.toByte(), - 7.toBigInteger(), - 8.toBigDecimal(), null, ) val df = dataFrameOf(a) @@ -35,18 +32,18 @@ class DescribeTests { with(describe) { name shouldBe "a" type shouldBe "Number?" - count shouldBe 9 - unique shouldBe 9 + count shouldBe 7 + unique shouldBe 7 nulls shouldBe 1 top shouldBe 1 freq shouldBe 1 - mean shouldBe 4.5 - std shouldBe 2.449489742783178 - min shouldBe 1.toBigDecimal() - (p25 as BigDecimal).setScale(2) shouldBe 2.75.toBigDecimal() - median shouldBe 4.toBigDecimal() - p75 shouldBe 6.25.toBigDecimal() - max shouldBe 8.toBigDecimal() + this.mean shouldBe 3.5 + std shouldBe 1.8708286933869707 + min shouldBe 1.0 + p25 shouldBe 2.25 + median shouldBe 3.5 + p75 shouldBe 4.75 + max shouldBe 6.0 } } @@ -65,7 +62,7 @@ class DescribeTests { nulls shouldBe 0 top shouldBe 1 freq shouldBe 1 - mean.shouldBeNaN() + this.mean.shouldBeNaN() std.shouldBeNaN() min shouldBe 1.0 // TODO should be NaN too? p25 shouldBe 1.75 From 3ffb2ef6a20182ed600725fc265c43beab4ca1f8 Mon Sep 17 00:00:00 2001 From: Jolan Rensen Date: Tue, 11 Mar 2025 14:48:56 +0100 Subject: [PATCH 04/11] small extension to convertToUnifiedNumberType --- core/api/core.api | 10 +++ .../kotlinx/dataframe/api/DataColumnType.kt | 3 + .../jetbrains/kotlinx/dataframe/api/map.kt | 5 +- .../jetbrains/kotlinx/dataframe/api/mean.kt | 20 +++-- .../kotlinx/dataframe/api/typeConversions.kt | 8 +- .../kotlinx/dataframe/impl/NumberTypeUtils.kt | 70 ++++++++++++++--- .../aggregation/aggregators/Aggregators.kt | 8 +- .../kotlinx/dataframe/impl/api/describe.kt | 23 +++--- .../jetbrains/kotlinx/dataframe/math/mean.kt | 76 ++++++++----------- .../kotlinx/dataframe/api/describe.kt | 25 +++--- .../kotlinx/dataframe/api/statistics.kt | 29 +++---- .../dataframe/statistics/BasicMathTests.kt | 28 ------- .../kotlinx/dataframe/impl/NumberTypeUtils.kt | 49 ++++++++++-- 13 files changed, 209 insertions(+), 145 deletions(-) delete mode 100644 core/generated-sources/src/test/kotlin/org/jetbrains/kotlinx/dataframe/statistics/BasicMathTests.kt diff --git a/core/api/core.api b/core/api/core.api index c146483135..befaf6903f 100644 --- a/core/api/core.api +++ b/core/api/core.api @@ -1975,6 +1975,7 @@ public final class org/jetbrains/kotlinx/dataframe/api/DataColumnTypeKt { public static final fun isList (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z public static final fun isNumber (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z public static final fun isPrimitive (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z + public static final fun isPrimitiveNumber (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z public static final fun isSubtypeOf (Lorg/jetbrains/kotlinx/dataframe/DataColumn;Lkotlin/reflect/KType;)Z public static final fun isValueColumn (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z public static final fun valuesAreComparable (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z @@ -4159,6 +4160,7 @@ public final class org/jetbrains/kotlinx/dataframe/api/TypeConversionsKt { public static final fun asComparable (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Lorg/jetbrains/kotlinx/dataframe/DataColumn; public static final fun asComparable (Lorg/jetbrains/kotlinx/dataframe/columns/ColumnSet;)Lorg/jetbrains/kotlinx/dataframe/columns/ColumnSet; public static final fun asComparable (Lorg/jetbrains/kotlinx/dataframe/columns/SingleColumn;)Lorg/jetbrains/kotlinx/dataframe/columns/SingleColumn; + public static final fun asComparableNullable (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Lorg/jetbrains/kotlinx/dataframe/DataColumn; public static final fun asDataColumn (Lorg/jetbrains/kotlinx/dataframe/columns/ColumnGroup;)Lorg/jetbrains/kotlinx/dataframe/DataColumn; public static final fun asDataColumn (Lorg/jetbrains/kotlinx/dataframe/columns/FrameColumn;)Lorg/jetbrains/kotlinx/dataframe/DataColumn; public static final fun asDataFrame (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; @@ -5254,6 +5256,10 @@ public final class org/jetbrains/kotlinx/dataframe/impl/ExceptionUtilsKt { public static final fun suggestIfNull (Ljava/lang/Object;Ljava/lang/String;)Ljava/lang/Object; } +public final class org/jetbrains/kotlinx/dataframe/impl/NumberTypeUtilsKt { + public static final fun getPrimitiveNumberTypes ()Ljava/util/Set; +} + public final class org/jetbrains/kotlinx/dataframe/impl/TypeUtilsKt { public static final fun getValuesType (Ljava/util/List;Lkotlin/reflect/KType;Lorg/jetbrains/kotlinx/dataframe/api/Infer;)Lkotlin/reflect/KType; public static final synthetic fun guessValueType (Lkotlin/sequences/Sequence;Lkotlin/reflect/KType;Z)Lkotlin/reflect/KType; @@ -5361,6 +5367,10 @@ public final class org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/OfRowE public static final fun aggregateOfDelegated (Lorg/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator;Lorg/jetbrains/kotlinx/dataframe/api/Grouped;Ljava/lang/String;Lkotlin/jvm/functions/Function2;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; } +public final class org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/RowKt { + public static final fun aggregateOfRow (Lorg/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator;Lorg/jetbrains/kotlinx/dataframe/DataRow;Lkotlin/jvm/functions/Function2;)Ljava/lang/Object; +} + public final class org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/WithinAllColumnsKt { public static final fun aggregateAll (Lorg/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator;Lorg/jetbrains/kotlinx/dataframe/DataFrame;Lkotlin/jvm/functions/Function2;)Ljava/lang/Object; } diff --git a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/DataColumnType.kt b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/DataColumnType.kt index 3ba61039ad..811287f910 100644 --- a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/DataColumnType.kt +++ b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/DataColumnType.kt @@ -7,6 +7,7 @@ import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup import org.jetbrains.kotlinx.dataframe.columns.ColumnKind import org.jetbrains.kotlinx.dataframe.columns.FrameColumn import org.jetbrains.kotlinx.dataframe.columns.ValueColumn +import org.jetbrains.kotlinx.dataframe.impl.primitiveNumberTypes import org.jetbrains.kotlinx.dataframe.type import org.jetbrains.kotlinx.dataframe.typeClass import org.jetbrains.kotlinx.dataframe.util.IS_COMPARABLE @@ -52,6 +53,8 @@ public fun AnyCol.isNumber(): Boolean = isSubtypeOf() public fun AnyCol.isBigNumber(): Boolean = isSubtypeOf() || isSubtypeOf() +public fun AnyCol.isPrimitiveNumber(): Boolean = type().withNullability(false) in primitiveNumberTypes + public fun AnyCol.isList(): Boolean = typeClass == List::class /** Returns `true` if [this] column is intra-comparable, i.e. diff --git a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/map.kt b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/map.kt index 90381b9418..98c00fd6c5 100644 --- a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/map.kt +++ b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/map.kt @@ -31,10 +31,7 @@ public inline fun ColumnReference.map( // region DataColumn -public inline fun DataColumn.map( - infer: Infer = Infer.Nulls, - crossinline transform: (T) -> R, -): DataColumn { +public inline fun DataColumn.map(infer: Infer = Infer.Nulls, transform: (T) -> R): DataColumn { val newValues = Array(size()) { transform(get(it)) }.asList() return DataColumn.createByType(name(), newValues, typeOf(), infer) } diff --git a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt index 97dcc70087..a5592b12b4 100644 --- a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt +++ b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt @@ -18,11 +18,12 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.cast2 import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateAll import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateFor import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOf +import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOfRow import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.of import org.jetbrains.kotlinx.dataframe.impl.aggregation.numberColumns import org.jetbrains.kotlinx.dataframe.impl.columns.toNumberColumns +import org.jetbrains.kotlinx.dataframe.impl.primitiveNumberTypes import org.jetbrains.kotlinx.dataframe.impl.suggestIfNull -import org.jetbrains.kotlinx.dataframe.math.mean import kotlin.reflect.KProperty import kotlin.reflect.typeOf @@ -44,9 +45,18 @@ public inline fun DataColumn.meanOf( // region DataRow public fun AnyRow.rowMean(skipNA: Boolean = skipNA_default): Double = - values().filterIsInstance().map { it.toDouble() }.mean(skipNA) - -public inline fun AnyRow.rowMeanOf(): Double = values().filterIsInstance().mean(typeOf()) + Aggregators.mean(skipNA).aggregateOfRow(this) { + colsOf { it.isPrimitiveNumber() } + } ?: Double.NaN + +public inline fun AnyRow.rowMeanOf(skipNA: Boolean = skipNA_default): Double { + require(typeOf() in primitiveNumberTypes) { + "Type ${T::class.simpleName} is not a primitive number type. Mean only supports primitive number types." + } + return Aggregators.mean(skipNA) + .aggregateOfRow(this) { colsOf() } + ?: Double.NaN +} // endregion @@ -77,7 +87,7 @@ public fun DataFrame.meanFor( public fun DataFrame.mean( skipNA: Boolean = skipNA_default, columns: ColumnsSelector, -): Double = Aggregators.mean(skipNA).aggregateAll(this, columns) as Double? ?: Double.NaN +): Double = Aggregators.mean(skipNA).aggregateAll(this, columns) ?: Double.NaN public fun DataFrame.mean(vararg columns: String, skipNA: Boolean = skipNA_default): Double = mean(skipNA) { columns.toNumberColumns() } diff --git a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/typeConversions.kt b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/typeConversions.kt index 73940948e5..0aafa38d93 100644 --- a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/typeConversions.kt +++ b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/typeConversions.kt @@ -83,11 +83,17 @@ public fun DataColumn.asNumbers(): ValueColumn { return this as ValueColumn } -public fun DataColumn.asComparable(): DataColumn> { +public fun DataColumn.asComparable(): DataColumn> { require(valuesAreComparable()) return this as DataColumn> } +@JvmName("asComparableNullable") +public fun DataColumn.asComparable(): DataColumn?> { + require(valuesAreComparable()) + return this as DataColumn?> +} + public fun ColumnReference.castToNotNullable(): ColumnReference = cast() public fun DataColumn.castToNotNullable(): DataColumn { diff --git a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/NumberTypeUtils.kt b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/NumberTypeUtils.kt index c4e1a9679a..b66ecae103 100644 --- a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/NumberTypeUtils.kt +++ b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/NumberTypeUtils.kt @@ -234,18 +234,21 @@ internal fun Iterable>.unifiedNumberClass( * or calculated with [Iterable.unifiedNumberType] from the iterable's elements if not explicitly specified. * * @param commonNumberType The desired common numeric type to convert the elements to. - * This is determined by default using the types of the elements in the iterable. + * By default, (or if `null`), this is determined using the types of the elements in the iterable. * @return A new iterable of numbers where each element is converted to the specified or inferred common number type. * @throws IllegalStateException if an element cannot be converted to the common number type. * @see UnifyingNumbers */ @Suppress("UNCHECKED_CAST") -internal fun Iterable.convertToUnifiedNumberType( +@JvmName("convertNullableIterableToUnifiedNumberType") +internal fun Iterable.convertToUnifiedNumberType( options: UnifiedNumberTypeOptions = UnifiedNumberTypeOptions.DEFAULT, - commonNumberType: KType = this.types().unifiedNumberType(options), -): Iterable { + commonNumberType: KType? = null, +): Iterable { + val commonNumberType = commonNumberType ?: this.filterNotNull().types().unifiedNumberType(options) val converter = createConverter(typeOf(), commonNumberType)!! as (Number) -> Number? return map { + if (it == null) return@map null converter(it) ?: error("Can not convert $it to $commonNumberType") } } @@ -255,23 +258,62 @@ internal fun Iterable.convertToUnifiedNumberType( * or calculated with [Iterable.unifiedNumberType][kotlin.collections.Iterable.unifiedNumberType] from the iterable's elements if not explicitly specified. * * @param commonNumberType The desired common numeric type to convert the elements to. - * This is determined by default using the types of the elements in the iterable. + * By default, (or if `null`), this is determined using the types of the elements in the iterable. * @return A new iterable of numbers where each element is converted to the specified or inferred common number type. * @throws IllegalStateException if an element cannot be converted to the common number type. * @see UnifyingNumbers */ -@JvmName("convertToUnifiedNumberTypeSequence") @Suppress("UNCHECKED_CAST") -internal fun Sequence.convertToUnifiedNumberType( +@JvmName("convertIterableToUnifiedNumberType") +internal fun Iterable.convertToUnifiedNumberType( + options: UnifiedNumberTypeOptions = UnifiedNumberTypeOptions.DEFAULT, + commonNumberType: KType? = null, +): Iterable = + (this as Iterable) + .convertToUnifiedNumberType(options, commonNumberType) as Iterable + +/** Converts the elements of the given iterable of numbers into a common numeric type based on complexity. + * The common numeric type is determined using the provided [commonNumberType] parameter + * or calculated with [Iterable.unifiedNumberType][kotlin.collections.Iterable.unifiedNumberType] from the iterable's elements if not explicitly specified. + * + * @param commonNumberType The desired common numeric type to convert the elements to. + * By default, (or if `null`), this is determined using the types of the elements in the iterable. + * @return A new iterable of numbers where each element is converted to the specified or inferred common number type. + * @throws IllegalStateException if an element cannot be converted to the common number type. + * @see UnifyingNumbers */ +@Suppress("UNCHECKED_CAST") +@JvmName("convertNullableSequenceToUnifiedNumberType") +internal fun Sequence.convertToUnifiedNumberType( options: UnifiedNumberTypeOptions = UnifiedNumberTypeOptions.DEFAULT, - commonNumberType: KType = asIterable().types().unifiedNumberType(options), -): Sequence { + commonNumberType: KType? = null, +): Sequence { + val commonNumberType = commonNumberType ?: this.filterNotNull().asIterable().types().unifiedNumberType(options) val converter = createConverter(typeOf(), commonNumberType)!! as (Number) -> Number? return map { + if (it == null) return@map null converter(it) ?: error("Can not convert $it to $commonNumberType") } } -internal val primitiveNumberTypes = +/** Converts the elements of the given iterable of numbers into a common numeric type based on complexity. + * The common numeric type is determined using the provided [commonNumberType] parameter + * or calculated with [Iterable.unifiedNumberType][kotlin.collections.Iterable.unifiedNumberType] from the iterable's elements if not explicitly specified. + * + * @param commonNumberType The desired common numeric type to convert the elements to. + * By default, (or if `null`), this is determined using the types of the elements in the iterable. + * @return A new iterable of numbers where each element is converted to the specified or inferred common number type. + * @throws IllegalStateException if an element cannot be converted to the common number type. + * @see UnifyingNumbers */ +@Suppress("UNCHECKED_CAST") +@JvmName("convert=SequenceToUnifiedNumberType") +internal fun Sequence.convertToUnifiedNumberType( + options: UnifiedNumberTypeOptions = UnifiedNumberTypeOptions.DEFAULT, + commonNumberType: KType? = null, +): Sequence = + (this as Sequence) + .convertToUnifiedNumberType(options, commonNumberType) as Sequence + +@PublishedApi +internal val primitiveNumberTypes: Set = setOf( typeOf(), typeOf(), @@ -280,3 +322,11 @@ internal val primitiveNumberTypes = typeOf(), typeOf(), ) + +internal fun Any.isPrimitiveNumber(): Boolean = + this is Byte || + this is Short || + this is Int || + this is Long || + this is Float || + this is Double diff --git a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt index abbdffc575..45ba013cf3 100644 --- a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt +++ b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt @@ -245,11 +245,9 @@ internal object Aggregators { // step one: T: Number? -> Double // step two: Double -> Double val mean by withOneOption { skipNA: Boolean -> - twoStepChangingType( - getReturnTypeOrNull = meanTypeConversion, - stepOneAggregator = { type -> mean(type, skipNA) }, - stepTwoAggregator = { mean(skipNA) }, - ) + twoStepForNumbers(meanTypeConversion) { type -> + mean(type, skipNA) + } } // T: Comparable? -> T diff --git a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/api/describe.kt b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/api/describe.kt index 254a598a86..08ea27a182 100644 --- a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/api/describe.kt +++ b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/api/describe.kt @@ -29,7 +29,7 @@ import org.jetbrains.kotlinx.dataframe.columns.size import org.jetbrains.kotlinx.dataframe.columns.values import org.jetbrains.kotlinx.dataframe.impl.columns.addPath import org.jetbrains.kotlinx.dataframe.impl.columns.asAnyFrameColumn -import org.jetbrains.kotlinx.dataframe.impl.isBigNumber +import org.jetbrains.kotlinx.dataframe.impl.isPrimitiveNumber import org.jetbrains.kotlinx.dataframe.impl.renderType import org.jetbrains.kotlinx.dataframe.index import org.jetbrains.kotlinx.dataframe.kind @@ -111,17 +111,20 @@ private fun List.collectAll(atAnyDepth: Boolean): List = } /** Converts a column to a comparable column if it is not already comparable. */ -private fun DataColumn.convertToComparableOrNull(): DataColumn>? = - when { +@Suppress("UNCHECKED_CAST") +private fun DataColumn.convertToComparableOrNull(): DataColumn?>? { + return when { valuesAreComparable() -> asComparable() - // Found incomparable number types, convert all to Double or BigDecimal first - isNumber() -> - if (any { it?.isBigNumber() == true }) { - map { (it as Number?)?.toBigDecimal() } - } else { - map { (it as Number?)?.toDouble() } - }.cast() + // Found incomparable number types, convert all to Double first + isNumber() -> cast().map { + if (it?.isPrimitiveNumber() == false) { + // Cannot calculate statistics of a non-primitive number type + return@convertToComparableOrNull null + } + it?.toDouble() as Comparable? + } else -> null } +} diff --git a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/mean.kt b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/mean.kt index fc5fb70b70..352ddacae5 100644 --- a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/mean.kt +++ b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/mean.kt @@ -1,7 +1,9 @@ package org.jetbrains.kotlinx.dataframe.math +import io.github.oshai.kotlinlogging.KotlinLogging import org.jetbrains.kotlinx.dataframe.api.skipNA_default import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.CalculateReturnTypeOrNull +import org.jetbrains.kotlinx.dataframe.impl.nothingType import org.jetbrains.kotlinx.dataframe.impl.renderType import java.math.BigDecimal import java.math.BigInteger @@ -9,6 +11,8 @@ import kotlin.reflect.KType import kotlin.reflect.full.withNullability import kotlin.reflect.typeOf +private val logger = KotlinLogging.logger { } + @PublishedApi internal fun Iterable.mean(type: KType, skipNA: Boolean = skipNA_default): Double = asSequence().mean(type, skipNA) @@ -18,30 +22,36 @@ internal fun Sequence.mean(type: KType, skipNA: Boolean = skipN if (type.isMarkedNullable) { return filterNotNull().mean(type.withNullability(false), skipNA) } - return when (type.classifier) { - Double::class -> (this as Sequence).mean(skipNA) - - Float::class -> (this as Sequence).mean(skipNA) + return when (type.withNullability(false)) { + typeOf() -> (this as Sequence).mean(skipNA) - Int::class -> (this as Sequence).map { it.toDouble() }.mean(false) + typeOf() -> (this as Sequence).mean(skipNA) - // for integer values NA is not possible - Short::class -> (this as Sequence).map { it.toDouble() }.mean(false) + typeOf() -> (this as Sequence).map { it.toDouble() }.mean(false) - Byte::class -> (this as Sequence).map { it.toDouble() }.mean(false) + typeOf() -> (this as Sequence).map { it.toDouble() }.mean(false) - Long::class -> (this as Sequence).map { it.toDouble() }.mean(false) + typeOf() -> (this as Sequence).map { it.toDouble() }.mean(false) - BigInteger::class -> (this as Sequence).map { it.toDouble() }.mean(false) + typeOf() -> { + logger.warn { "Converting Longs to Doubles to calculate the mean, loss of precision may occur." } + (this as Sequence).map { it.toDouble() }.mean(false) + } - BigDecimal::class -> (this as Sequence).map { it.toDouble() }.mean(skipNA) + typeOf(), typeOf() -> + throw IllegalArgumentException( + "Cannot calculate the mean for big numbers in DataFrame. Only primitive numbers are supported.", + ) - Number::class -> (this as Sequence).map { it.toDouble() }.mean(skipNA) + typeOf() -> + error("Encountered non-specific Number type in mean function. This should not occur.") // this means the sequence is empty - Nothing::class -> Double.NaN + nothingType -> Double.NaN - else -> throw IllegalArgumentException("Unable to compute the mean for type ${renderType(type)}") + else -> throw IllegalArgumentException( + "Unable to compute the mean for ${renderType(type)}, Only primitive numbers are supported.", + ) } } @@ -94,7 +104,7 @@ internal fun Iterable.mean(skipNA: Boolean = skipNA_default): Double = as @JvmName("intMean") internal fun Iterable.mean(): Double = if (this is Collection) { - if (size > 0) sumOf { it.toDouble() } / size else Double.NaN + if (isNotEmpty()) sumOf { it.toDouble() } / size else Double.NaN } else { var count = 0 val sum = sumOf { @@ -107,7 +117,7 @@ internal fun Iterable.mean(): Double = @JvmName("shortMean") internal fun Iterable.mean(): Double = if (this is Collection) { - if (size > 0) sumOf { it.toDouble() } / size else Double.NaN + if (isNotEmpty()) sumOf { it.toDouble() } / size else Double.NaN } else { var count = 0 val sum = sumOf { @@ -120,7 +130,7 @@ internal fun Iterable.mean(): Double = @JvmName("byteMean") internal fun Iterable.mean(): Double = if (this is Collection) { - if (size > 0) sumOf { it.toDouble() } / size else Double.NaN + if (isNotEmpty()) sumOf { it.toDouble() } / size else Double.NaN } else { var count = 0 val sum = sumOf { @@ -133,35 +143,7 @@ internal fun Iterable.mean(): Double = @JvmName("longMean") internal fun Iterable.mean(): Double = if (this is Collection) { - if (size > 0) sumOf { it.toDouble() } / size else Double.NaN - } else { - var count = 0 - val sum = sumOf { - count++ - it.toDouble() - } - if (count > 0) sum / count else Double.NaN - } - -// TODO result is Double, but should be BigDecimal, Issue #558 -@JvmName("bigIntegerMean") -internal fun Iterable.mean(): Double = - if (this is Collection) { - if (size > 0) sumOf { it.toDouble() } / size else Double.NaN - } else { - var count = 0 - val sum = sumOf { - count++ - it.toDouble() - } - if (count > 0) sum / count else Double.NaN - } - -// TODO result is Double, but should be BigDecimal, Issue #558 -@JvmName("bigDecimalMean") -internal fun Iterable.mean(): Double = - if (this is Collection) { - if (size > 0) sum().toDouble() / size else Double.NaN + if (isNotEmpty()) sumOf { it.toDouble() } / size else Double.NaN } else { var count = 0 val sum = sumOf { @@ -169,4 +151,6 @@ internal fun Iterable.mean(): Double = it.toDouble() } if (count > 0) sum / count else Double.NaN + }.also { + logger.warn { "Converting Longs to Doubles to calculate the mean, loss of precision may occur." } } diff --git a/core/generated-sources/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/describe.kt b/core/generated-sources/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/describe.kt index 71f049f5ca..3161e3bc7b 100644 --- a/core/generated-sources/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/describe.kt +++ b/core/generated-sources/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/describe.kt @@ -4,7 +4,6 @@ import io.kotest.matchers.doubles.shouldBeNaN import io.kotest.matchers.shouldBe import org.jetbrains.kotlinx.dataframe.alsoDebug import org.junit.Test -import java.math.BigDecimal class DescribeTests { @@ -17,15 +16,13 @@ class DescribeTests { @Test fun `describe nullable Number column`() { - val a by columnOf( + val a by columnOf( 1, 2.0, 3f, 4L, 5.toShort(), 6.toByte(), - 7.toBigInteger(), - 8.toBigDecimal(), null, ) val df = dataFrameOf(a) @@ -35,18 +32,18 @@ class DescribeTests { with(describe) { name shouldBe "a" type shouldBe "Number?" - count shouldBe 9 - unique shouldBe 9 + count shouldBe 7 + unique shouldBe 7 nulls shouldBe 1 top shouldBe 1 freq shouldBe 1 - mean shouldBe 4.5 - std shouldBe 2.449489742783178 - min shouldBe 1.toBigDecimal() - (p25 as BigDecimal).setScale(2) shouldBe 2.75.toBigDecimal() - median shouldBe 4.toBigDecimal() - p75 shouldBe 6.25.toBigDecimal() - max shouldBe 8.toBigDecimal() + mean shouldBe 3.5 + std shouldBe 1.8708286933869707 + min shouldBe 1.0 + p25 shouldBe 2.0 + median shouldBe 3.0 + p75 shouldBe 4.0 + max shouldBe 6.0 } } @@ -65,7 +62,7 @@ class DescribeTests { nulls shouldBe 0 top shouldBe 1 freq shouldBe 1 - mean.shouldBeNaN() + this.mean.shouldBeNaN() std.shouldBeNaN() min shouldBe 1.0 // TODO should be NaN too? p25 shouldBe 1.75 diff --git a/core/generated-sources/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/statistics.kt b/core/generated-sources/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/statistics.kt index 006b8048b2..1d4d76b637 100644 --- a/core/generated-sources/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/statistics.kt +++ b/core/generated-sources/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/statistics.kt @@ -148,14 +148,15 @@ class StatisticsTests { "city", "name", "age", + "weight", "height", - "yearsToRetirement" - ) // TODO: why double values from weight are not in the list? are they not Comparable? + "yearsToRetirement", + ) val median01 = res0["age"][0] as Int median01 shouldBe 22 - //val median02 = res0["weight"][0] as Double - //median02 shouldBe 66.0 + // val median02 = res0["weight"][0] as Double + // median02 shouldBe 66.0 // scenario #1: particular column val res1 = personsDf.groupBy("city").medianFor("age") @@ -276,14 +277,15 @@ class StatisticsTests { "city", "name", "age", + "weight", "height", - "yearsToRetirement" - ) // TODO: why it's working for height and doesn't work for Double column weight + "yearsToRetirement", + ) val min01 = res0["age"][0] as Int min01 shouldBe 15 - //val min02 = res0["weight"][0] as Double - //min02 shouldBe 38.85756039691633 + // val min02 = res0["weight"][0] as Double + // min02 shouldBe 38.85756039691633 // scenario #1: particular column val res1 = personsDf.groupBy("city").minFor("age") @@ -342,7 +344,7 @@ class StatisticsTests { "age", "weight", "height", - "yearsToRetirement" + "yearsToRetirement", ) // TODO: why is here weight presented? looks like inconsitency val min41 = res4["age"][0] as Int @@ -362,12 +364,12 @@ class StatisticsTests { fun `max on GroupBy`() { // scenario #0: all numerical columns val res0 = personsDf.groupBy("city").max() - res0.columnNames() shouldBe listOf("city", "name", "age", "height", "yearsToRetirement") // TODO: DOUBLE weight? + res0.columnNames() shouldBe listOf("city", "name", "age", "weight", "height", "yearsToRetirement") val max01 = res0["age"][0] as Int max01 shouldBe 35 - //val max02 = res0["weight"][0] as Double - //max02 shouldBe 140.0 + // val max02 = res0["weight"][0] as Double + // max02 shouldBe 140.0 // scenario #1: particular column val res1 = personsDf.groupBy("city").maxFor("age") @@ -426,7 +428,7 @@ class StatisticsTests { "age", "weight", "height", - "yearsToRetirement" + "yearsToRetirement", ) // TODO: weight is here? val max41 = res4["age"][0] as Int @@ -442,4 +444,3 @@ class StatisticsTests { max51 shouldBe 35 } } - diff --git a/core/generated-sources/src/test/kotlin/org/jetbrains/kotlinx/dataframe/statistics/BasicMathTests.kt b/core/generated-sources/src/test/kotlin/org/jetbrains/kotlinx/dataframe/statistics/BasicMathTests.kt deleted file mode 100644 index 53121d1a9f..0000000000 --- a/core/generated-sources/src/test/kotlin/org/jetbrains/kotlinx/dataframe/statistics/BasicMathTests.kt +++ /dev/null @@ -1,28 +0,0 @@ -package org.jetbrains.kotlinx.dataframe.statistics - -import io.kotest.matchers.doubles.shouldBeNaN -import io.kotest.matchers.shouldBe -import org.jetbrains.kotlinx.dataframe.DataColumn -import org.jetbrains.kotlinx.dataframe.api.columnOf -import org.jetbrains.kotlinx.dataframe.api.mean -import org.jetbrains.kotlinx.dataframe.impl.nothingType -import org.junit.Test -import kotlin.reflect.typeOf - -class BasicMathTests { - - @Test - fun `type for column with mixed numbers`() { - val col = columnOf(10, 10.0, null) - col.type() shouldBe typeOf() - } - - @Test - fun `mean with nans and nulls`() { - columnOf(10, 20, Double.NaN, null).mean().shouldBeNaN() - columnOf(10, 20, Double.NaN, null).mean(skipNA = true) shouldBe 15 - - DataColumn.createValueColumn("", emptyList(), nothingType(false)).mean().shouldBeNaN() - DataColumn.createValueColumn("", listOf(null), nothingType(true)).mean().shouldBeNaN() - } -} diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/NumberTypeUtils.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/NumberTypeUtils.kt index 4db0d6d603..557c2ae0fd 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/NumberTypeUtils.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/NumberTypeUtils.kt @@ -173,35 +173,60 @@ internal fun Iterable>.unifiedNumberClass( * or calculated with [Iterable.unifiedNumberType] from the iterable's elements if not explicitly specified. * * @param commonNumberType The desired common numeric type to convert the elements to. - * This is determined by default using the types of the elements in the iterable. + * By default, (or if `null`), this is determined using the types of the elements in the iterable. * @return A new iterable of numbers where each element is converted to the specified or inferred common number type. * @throws IllegalStateException if an element cannot be converted to the common number type. * @see UnifyingNumbers */ @Suppress("UNCHECKED_CAST") -internal fun Iterable.convertToUnifiedNumberType( +@JvmName("convertNullableIterableToUnifiedNumberType") +internal fun Iterable.convertToUnifiedNumberType( options: UnifiedNumberTypeOptions = UnifiedNumberTypeOptions.DEFAULT, - commonNumberType: KType = this.types().unifiedNumberType(options), -): Iterable { + commonNumberType: KType? = null, +): Iterable { + val commonNumberType = commonNumberType ?: this.filterNotNull().types().unifiedNumberType(options) val converter = createConverter(typeOf(), commonNumberType)!! as (Number) -> Number? return map { + if (it == null) return@map null converter(it) ?: error("Can not convert $it to $commonNumberType") } } /** @include [Iterable.convertToUnifiedNumberType] */ -@JvmName("convertToUnifiedNumberTypeSequence") @Suppress("UNCHECKED_CAST") -internal fun Sequence.convertToUnifiedNumberType( +@JvmName("convertIterableToUnifiedNumberType") +internal fun Iterable.convertToUnifiedNumberType( + options: UnifiedNumberTypeOptions = UnifiedNumberTypeOptions.DEFAULT, + commonNumberType: KType? = null, +): Iterable = + (this as Iterable) + .convertToUnifiedNumberType(options, commonNumberType) as Iterable + +/** @include [Iterable.convertToUnifiedNumberType] */ +@Suppress("UNCHECKED_CAST") +@JvmName("convertNullableSequenceToUnifiedNumberType") +internal fun Sequence.convertToUnifiedNumberType( options: UnifiedNumberTypeOptions = UnifiedNumberTypeOptions.DEFAULT, - commonNumberType: KType = asIterable().types().unifiedNumberType(options), -): Sequence { + commonNumberType: KType? = null, +): Sequence { + val commonNumberType = commonNumberType ?: this.filterNotNull().asIterable().types().unifiedNumberType(options) val converter = createConverter(typeOf(), commonNumberType)!! as (Number) -> Number? return map { + if (it == null) return@map null converter(it) ?: error("Can not convert $it to $commonNumberType") } } +/** @include [Iterable.convertToUnifiedNumberType] */ +@Suppress("UNCHECKED_CAST") +@JvmName("convert=SequenceToUnifiedNumberType") +internal fun Sequence.convertToUnifiedNumberType( + options: UnifiedNumberTypeOptions = UnifiedNumberTypeOptions.DEFAULT, + commonNumberType: KType? = null, +): Sequence = + (this as Sequence) + .convertToUnifiedNumberType(options, commonNumberType) as Sequence + @PublishedApi internal val primitiveNumberTypes: Set = setOf( @@ -212,3 +237,11 @@ internal val primitiveNumberTypes: Set = typeOf(), typeOf(), ) + +internal fun Any.isPrimitiveNumber(): Boolean = + this is Byte || + this is Short || + this is Int || + this is Long || + this is Float || + this is Double From 499334d21ae7d48d405bb123c227e58eb857e742 Mon Sep 17 00:00:00 2001 From: Jolan Rensen Date: Tue, 11 Mar 2025 14:49:21 +0100 Subject: [PATCH 05/11] fixed describe and tests --- .../jetbrains/kotlinx/dataframe/api/map.kt | 5 +---- .../kotlinx/dataframe/api/typeConversions.kt | 8 ++++++- .../kotlinx/dataframe/impl/api/describe.kt | 22 +++++++++++++------ .../kotlinx/dataframe/api/describe.kt | 8 +++---- 4 files changed, 27 insertions(+), 16 deletions(-) diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/map.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/map.kt index 90381b9418..98c00fd6c5 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/map.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/map.kt @@ -31,10 +31,7 @@ public inline fun ColumnReference.map( // region DataColumn -public inline fun DataColumn.map( - infer: Infer = Infer.Nulls, - crossinline transform: (T) -> R, -): DataColumn { +public inline fun DataColumn.map(infer: Infer = Infer.Nulls, transform: (T) -> R): DataColumn { val newValues = Array(size()) { transform(get(it)) }.asList() return DataColumn.createByType(name(), newValues, typeOf(), infer) } diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/typeConversions.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/typeConversions.kt index 4dd9fcab52..274a2d257d 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/typeConversions.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/typeConversions.kt @@ -83,11 +83,17 @@ public fun DataColumn.asNumbers(): ValueColumn { return this as ValueColumn } -public fun DataColumn.asComparable(): DataColumn> { +public fun DataColumn.asComparable(): DataColumn> { require(valuesAreComparable()) return this as DataColumn> } +@JvmName("asComparableNullable") +public fun DataColumn.asComparable(): DataColumn?> { + require(valuesAreComparable()) + return this as DataColumn?> +} + public fun ColumnReference.castToNotNullable(): ColumnReference = cast() public fun DataColumn.castToNotNullable(): DataColumn { diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/api/describe.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/api/describe.kt index 672381f95a..08ea27a182 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/api/describe.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/api/describe.kt @@ -13,7 +13,6 @@ import org.jetbrains.kotlinx.dataframe.api.asNumbers import org.jetbrains.kotlinx.dataframe.api.cast import org.jetbrains.kotlinx.dataframe.api.concat import org.jetbrains.kotlinx.dataframe.api.isNumber -import org.jetbrains.kotlinx.dataframe.api.isPrimitiveNumber import org.jetbrains.kotlinx.dataframe.api.map import org.jetbrains.kotlinx.dataframe.api.maxOrNull import org.jetbrains.kotlinx.dataframe.api.mean @@ -30,6 +29,7 @@ import org.jetbrains.kotlinx.dataframe.columns.size import org.jetbrains.kotlinx.dataframe.columns.values import org.jetbrains.kotlinx.dataframe.impl.columns.addPath import org.jetbrains.kotlinx.dataframe.impl.columns.asAnyFrameColumn +import org.jetbrains.kotlinx.dataframe.impl.isPrimitiveNumber import org.jetbrains.kotlinx.dataframe.impl.renderType import org.jetbrains.kotlinx.dataframe.index import org.jetbrains.kotlinx.dataframe.kind @@ -38,7 +38,7 @@ import org.jetbrains.kotlinx.dataframe.type internal fun describeImpl(cols: List): DataFrame { val allCols = cols.collectAll(false) - val hasNumericCols = allCols.any { it.isPrimitiveNumber() } + val hasNumericCols = allCols.any { it.isNumber() } val hasComparableCols = allCols.any { it.valuesAreComparable() } val hasLongPaths = allCols.any { it.path().size > 1 } var df = allCols.toDataFrame { @@ -56,8 +56,8 @@ internal fun describeImpl(cols: List): DataFrame { ?.key } if (hasNumericCols) { - ColumnDescription::mean from { if (it.isPrimitiveNumber()) it.asNumbers().mean() else null } - ColumnDescription::std from { if (it.isPrimitiveNumber()) it.asNumbers().std() else null } + ColumnDescription::mean from { if (it.isNumber()) it.asNumbers().mean() else null } + ColumnDescription::std from { if (it.isNumber()) it.asNumbers().std() else null } } if (hasComparableCols || hasNumericCols) { ColumnDescription::min from inferType { @@ -111,12 +111,20 @@ private fun List.collectAll(atAnyDepth: Boolean): List = } /** Converts a column to a comparable column if it is not already comparable. */ -private fun DataColumn.convertToComparableOrNull(): DataColumn>? = - when { +@Suppress("UNCHECKED_CAST") +private fun DataColumn.convertToComparableOrNull(): DataColumn?>? { + return when { valuesAreComparable() -> asComparable() // Found incomparable number types, convert all to Double first - isPrimitiveNumber() -> map { (it as Number?)?.toDouble() }.cast() + isNumber() -> cast().map { + if (it?.isPrimitiveNumber() == false) { + // Cannot calculate statistics of a non-primitive number type + return@convertToComparableOrNull null + } + it?.toDouble() as Comparable? + } else -> null } +} diff --git a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/describe.kt b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/describe.kt index 7bb70b3659..3161e3bc7b 100644 --- a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/describe.kt +++ b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/describe.kt @@ -37,12 +37,12 @@ class DescribeTests { nulls shouldBe 1 top shouldBe 1 freq shouldBe 1 - this.mean shouldBe 3.5 + mean shouldBe 3.5 std shouldBe 1.8708286933869707 min shouldBe 1.0 - p25 shouldBe 2.25 - median shouldBe 3.5 - p75 shouldBe 4.75 + p25 shouldBe 2.0 + median shouldBe 3.0 + p75 shouldBe 4.0 max shouldBe 6.0 } } From ca73fbc62b48715e4166d6c72a6d35b96147b669 Mon Sep 17 00:00:00 2001 From: Jolan Rensen Date: Mon, 17 Mar 2025 12:50:03 +0100 Subject: [PATCH 06/11] mean fixes based on feedback and some extra cleanup and fixes --- core/api/core.api | 3 - .../kotlinx/dataframe/api/DataColumnType.kt | 9 ++- .../jetbrains/kotlinx/dataframe/api/mean.kt | 31 ++++--- .../dataframe/impl/aggregation/getColumns.kt | 12 ++- .../jetbrains/kotlinx/dataframe/math/mean.kt | 80 +------------------ .../testSets/person/DataFrameTests.kt | 2 +- .../kotlinx/dataframe/api/DataColumnType.kt | 9 ++- .../jetbrains/kotlinx/dataframe/api/mean.kt | 33 +++++--- .../dataframe/impl/aggregation/getColumns.kt | 12 ++- .../jetbrains/kotlinx/dataframe/math/mean.kt | 80 +------------------ .../testSets/person/DataFrameTests.kt | 2 +- 11 files changed, 76 insertions(+), 197 deletions(-) diff --git a/core/api/core.api b/core/api/core.api index 6a5a54db21..a991051ccb 100644 --- a/core/api/core.api +++ b/core/api/core.api @@ -1777,7 +1777,6 @@ public final class org/jetbrains/kotlinx/dataframe/api/DataColumnArithmeticsKt { } public final class org/jetbrains/kotlinx/dataframe/api/DataColumnTypeKt { - public static final fun isBigNumber (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z public static final fun isColumnGroup (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z public static final fun isComparable (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z public static final fun isFrameColumn (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z @@ -2844,8 +2843,6 @@ public final class org/jetbrains/kotlinx/dataframe/api/MeanKt { public static synthetic fun meanFor$default (Lorg/jetbrains/kotlinx/dataframe/api/PivotGroupBy;[Ljava/lang/String;ZZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; public static synthetic fun meanFor$default (Lorg/jetbrains/kotlinx/dataframe/api/PivotGroupBy;[Lkotlin/reflect/KProperty;ZZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; public static synthetic fun meanFor$default (Lorg/jetbrains/kotlinx/dataframe/api/PivotGroupBy;[Lorg/jetbrains/kotlinx/dataframe/columns/ColumnReference;ZZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; - public static final fun meanOrNull (Lorg/jetbrains/kotlinx/dataframe/DataColumn;Z)Ljava/lang/Double; - public static synthetic fun meanOrNull$default (Lorg/jetbrains/kotlinx/dataframe/DataColumn;ZILjava/lang/Object;)Ljava/lang/Double; public static final fun rowMean (Lorg/jetbrains/kotlinx/dataframe/DataRow;Z)D public static synthetic fun rowMean$default (Lorg/jetbrains/kotlinx/dataframe/DataRow;ZILjava/lang/Object;)D } diff --git a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/DataColumnType.kt b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/DataColumnType.kt index 811287f910..51262ce33d 100644 --- a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/DataColumnType.kt +++ b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/DataColumnType.kt @@ -13,8 +13,6 @@ import org.jetbrains.kotlinx.dataframe.typeClass import org.jetbrains.kotlinx.dataframe.util.IS_COMPARABLE import org.jetbrains.kotlinx.dataframe.util.IS_COMPARABLE_REPLACE import org.jetbrains.kotlinx.dataframe.util.IS_INTER_COMPARABLE_IMPORT -import java.math.BigDecimal -import java.math.BigInteger import kotlin.contracts.ExperimentalContracts import kotlin.contracts.contract import kotlin.reflect.KClass @@ -49,10 +47,13 @@ public inline fun AnyCol.isSubtypeOf(): Boolean = isSubtypeOf(typeOf public inline fun AnyCol.isType(): Boolean = type() == typeOf() +/** Returns `true` when this column's type is a subtype of `Number?` */ public fun AnyCol.isNumber(): Boolean = isSubtypeOf() -public fun AnyCol.isBigNumber(): Boolean = isSubtypeOf() || isSubtypeOf() - +/** + * Returns `true` when this column has the (nullable) type of either: + * [Byte], [Short], [Int], [Long], [Float], or [Double]. + */ public fun AnyCol.isPrimitiveNumber(): Boolean = type().withNullability(false) in primitiveNumberTypes public fun AnyCol.isList(): Boolean = typeClass == List::class diff --git a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt index a5592b12b4..e278a39219 100644 --- a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt +++ b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt @@ -20,25 +20,32 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateFor import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOf import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOfRow import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.of -import org.jetbrains.kotlinx.dataframe.impl.aggregation.numberColumns +import org.jetbrains.kotlinx.dataframe.impl.aggregation.primitiveNumberColumns import org.jetbrains.kotlinx.dataframe.impl.columns.toNumberColumns import org.jetbrains.kotlinx.dataframe.impl.primitiveNumberTypes -import org.jetbrains.kotlinx.dataframe.impl.suggestIfNull import kotlin.reflect.KProperty import kotlin.reflect.typeOf +/* + * Calculating the mean is supported for all primitive number types. + * The return type is always Double, Double.NaN for empty input, never null. + * (this may introduce loss of precision for Longs). + * For mixed primitive number types [TwoStepNumbersAggregator], unifies the numbers before calculating the mean. + */ + // region DataColumn public fun DataColumn.mean(skipNA: Boolean = skipNA_default): Double = - meanOrNull(skipNA).suggestIfNull("mean") - -public fun DataColumn.meanOrNull(skipNA: Boolean = skipNA_default): Double? = - Aggregators.mean(skipNA).aggregate(this) + Aggregators.mean(skipNA).aggregate(this)!! public inline fun DataColumn.meanOf( skipNA: Boolean = skipNA_default, noinline expression: (T) -> R?, -): Double = Aggregators.mean(skipNA).cast2().aggregateOf(this, expression) ?: Double.NaN +): Double = + Aggregators.mean(skipNA) + .cast2() + .aggregateOf(this, expression) + ?: Double.NaN // endregion @@ -62,7 +69,8 @@ public inline fun AnyRow.rowMeanOf(skipNA: Boolean = skipNA // region DataFrame -public fun DataFrame.mean(skipNA: Boolean = skipNA_default): DataRow = meanFor(skipNA, numberColumns()) +public fun DataFrame.mean(skipNA: Boolean = skipNA_default): DataRow = + meanFor(skipNA, primitiveNumberColumns()) public fun DataFrame.meanFor( skipNA: Boolean = skipNA_default, @@ -112,7 +120,8 @@ public inline fun DataFrame.meanOf( // region GroupBy @Refine @Interpretable("GroupByMean1") -public fun Grouped.mean(skipNA: Boolean = skipNA_default): DataFrame = meanFor(skipNA, numberColumns()) +public fun Grouped.mean(skipNA: Boolean = skipNA_default): DataFrame = + meanFor(skipNA, primitiveNumberColumns()) @Refine @Interpretable("GroupByMean0") @@ -177,7 +186,7 @@ public inline fun Grouped.meanOf( // region Pivot public fun Pivot.mean(skipNA: Boolean = skipNA_default, separate: Boolean = false): DataRow = - meanFor(skipNA, separate, numberColumns()) + meanFor(skipNA, separate, primitiveNumberColumns()) public fun Pivot.meanFor( skipNA: Boolean = skipNA_default, @@ -220,7 +229,7 @@ public inline fun Pivot.meanOf( // region PivotGroupBy public fun PivotGroupBy.mean(separate: Boolean = false, skipNA: Boolean = skipNA_default): DataFrame = - meanFor(skipNA, separate, numberColumns()) + meanFor(skipNA, separate, primitiveNumberColumns()) public fun PivotGroupBy.meanFor( skipNA: Boolean = skipNA_default, diff --git a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/getColumns.kt b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/getColumns.kt index b7b2c1052d..66f598b02d 100644 --- a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/getColumns.kt +++ b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/getColumns.kt @@ -6,6 +6,7 @@ import org.jetbrains.kotlinx.dataframe.aggregation.Aggregatable import org.jetbrains.kotlinx.dataframe.aggregation.NamedValue import org.jetbrains.kotlinx.dataframe.api.filter import org.jetbrains.kotlinx.dataframe.api.isNumber +import org.jetbrains.kotlinx.dataframe.api.isPrimitiveNumber import org.jetbrains.kotlinx.dataframe.api.valuesAreComparable import org.jetbrains.kotlinx.dataframe.columns.TypeSuggestion import org.jetbrains.kotlinx.dataframe.impl.columns.createColumnGuessingType @@ -14,10 +15,17 @@ internal inline fun Aggregatable.remainingColumns( crossinline predicate: (AnyCol) -> Boolean, ): ColumnsSelector = remainingColumnsSelector().filter { predicate(it.data) } -internal fun Aggregatable.intraComparableColumns() = +@Suppress("UNCHECKED_CAST") +internal fun Aggregatable.intraComparableColumns(): ColumnsSelector> = remainingColumns { it.valuesAreComparable() } as ColumnsSelector> -internal fun Aggregatable.numberColumns() = remainingColumns { it.isNumber() } as ColumnsSelector +@Suppress("UNCHECKED_CAST") +internal fun Aggregatable.numberColumns(): ColumnsSelector = + remainingColumns { it.isNumber() } as ColumnsSelector + +@Suppress("UNCHECKED_CAST") +internal fun Aggregatable.primitiveNumberColumns(): ColumnsSelector = + remainingColumns { it.isPrimitiveNumber() } as ColumnsSelector internal fun NamedValue.toColumnWithPath() = path to createColumnGuessingType( diff --git a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/mean.kt b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/mean.kt index 8250251fcc..f94c274a04 100644 --- a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/mean.kt +++ b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/mean.kt @@ -25,7 +25,7 @@ internal fun Sequence.mean(type: KType, skipNA: Boolean = skipN return when (type.withNullability(false)) { typeOf() -> (this as Sequence).mean(skipNA) - typeOf() -> (this as Sequence).mean(skipNA) + typeOf() -> (this as Sequence).map { it.toDouble() }.mean(skipNA) typeOf() -> (this as Sequence).map { it.toDouble() }.mean(false) @@ -76,81 +76,3 @@ internal fun Sequence.mean(skipNA: Boolean = skipNA_default): Double { } return if (count > 0) sum / count else Double.NaN } - -@JvmName("meanFloat") -internal fun Sequence.mean(skipNA: Boolean = skipNA_default): Double { - var count = 0 - var sum: Double = 0.toDouble() - for (element in this) { - if (element.isNaN()) { - if (skipNA) { - continue - } else { - return Double.NaN - } - } - sum += element - count++ - } - return if (count > 0) sum / count else Double.NaN -} - -@JvmName("doubleMean") -internal fun Iterable.mean(skipNA: Boolean = skipNA_default): Double = asSequence().mean(skipNA) - -@JvmName("floatMean") -internal fun Iterable.mean(skipNA: Boolean = skipNA_default): Double = asSequence().mean(skipNA) - -@JvmName("intMean") -internal fun Iterable.mean(): Double = - if (this is Collection) { - if (isNotEmpty()) sumOf { it.toDouble() } / size else Double.NaN - } else { - var count = 0 - val sum = sumOf { - count++ - it.toDouble() - } - if (count > 0) sum / count else Double.NaN - } - -@JvmName("shortMean") -internal fun Iterable.mean(): Double = - if (this is Collection) { - if (isNotEmpty()) sumOf { it.toDouble() } / size else Double.NaN - } else { - var count = 0 - val sum = sumOf { - count++ - it.toDouble() - } - if (count > 0) sum / count else Double.NaN - } - -@JvmName("byteMean") -internal fun Iterable.mean(): Double = - if (this is Collection) { - if (isNotEmpty()) sumOf { it.toDouble() } / size else Double.NaN - } else { - var count = 0 - val sum = sumOf { - count++ - it.toDouble() - } - if (count > 0) sum / count else Double.NaN - } - -@JvmName("longMean") -internal fun Iterable.mean(): Double = - if (this is Collection) { - if (isNotEmpty()) sumOf { it.toDouble() } / size else Double.NaN - } else { - var count = 0 - val sum = sumOf { - count++ - it.toDouble() - } - if (count > 0) sum / count else Double.NaN - }.also { - logger.warn { "Converting Longs to Doubles to calculate the mean, loss of precision may occur." } - } diff --git a/core/generated-sources/src/test/kotlin/org/jetbrains/kotlinx/dataframe/testSets/person/DataFrameTests.kt b/core/generated-sources/src/test/kotlin/org/jetbrains/kotlinx/dataframe/testSets/person/DataFrameTests.kt index 7f5886bbe5..da3d7c11aa 100644 --- a/core/generated-sources/src/test/kotlin/org/jetbrains/kotlinx/dataframe/testSets/person/DataFrameTests.kt +++ b/core/generated-sources/src/test/kotlin/org/jetbrains/kotlinx/dataframe/testSets/person/DataFrameTests.kt @@ -1485,7 +1485,7 @@ class DataFrameTests : BaseTest() { @Test fun `column stats`() { - typed.age.mean() shouldBe typed.age.toList().mean() + typed.age.mean() shouldBe typed.age.toList().average() typed.age.min() shouldBe typed.age.toList().minOrNull() typed.age.max() shouldBe typed.age.toList().maxOrNull() typed.age.sum() shouldBe typed.age.toList().sum() diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/DataColumnType.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/DataColumnType.kt index 11bc564346..e63b3af185 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/DataColumnType.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/DataColumnType.kt @@ -13,8 +13,6 @@ import org.jetbrains.kotlinx.dataframe.typeClass import org.jetbrains.kotlinx.dataframe.util.IS_COMPARABLE import org.jetbrains.kotlinx.dataframe.util.IS_COMPARABLE_REPLACE import org.jetbrains.kotlinx.dataframe.util.IS_INTER_COMPARABLE_IMPORT -import java.math.BigDecimal -import java.math.BigInteger import kotlin.contracts.ExperimentalContracts import kotlin.contracts.contract import kotlin.reflect.KClass @@ -49,10 +47,13 @@ public inline fun AnyCol.isSubtypeOf(): Boolean = isSubtypeOf(typeOf public inline fun AnyCol.isType(): Boolean = type() == typeOf() +/** Returns `true` when this column's type is a subtype of `Number?` */ public fun AnyCol.isNumber(): Boolean = isSubtypeOf() -public fun AnyCol.isBigNumber(): Boolean = isSubtypeOf() || isSubtypeOf() - +/** + * Returns `true` when this column has the (nullable) type of either: + * [Byte], [Short], [Int], [Long], [Float], or [Double]. + */ public fun AnyCol.isPrimitiveNumber(): Boolean = type().withNullability(false) in primitiveNumberTypes public fun AnyCol.isList(): Boolean = typeClass == List::class diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt index a5592b12b4..0f6a17181d 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt @@ -20,25 +20,34 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateFor import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOf import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOfRow import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.of -import org.jetbrains.kotlinx.dataframe.impl.aggregation.numberColumns +import org.jetbrains.kotlinx.dataframe.impl.aggregation.primitiveNumberColumns import org.jetbrains.kotlinx.dataframe.impl.columns.toNumberColumns import org.jetbrains.kotlinx.dataframe.impl.primitiveNumberTypes -import org.jetbrains.kotlinx.dataframe.impl.suggestIfNull import kotlin.reflect.KProperty import kotlin.reflect.typeOf +/* + * TODO KDocs: + * Calculating the mean is supported for all primitive number types. + * Nulls are filtered from columns. + * The return type is always Double, Double.NaN for empty input, never null. + * (May introduce loss of precision for Longs). + * For mixed primitive number types, [TwoStepNumbersAggregator] unifies the numbers before calculating the mean. + */ + // region DataColumn public fun DataColumn.mean(skipNA: Boolean = skipNA_default): Double = - meanOrNull(skipNA).suggestIfNull("mean") - -public fun DataColumn.meanOrNull(skipNA: Boolean = skipNA_default): Double? = - Aggregators.mean(skipNA).aggregate(this) + Aggregators.mean(skipNA).aggregate(this)!! public inline fun DataColumn.meanOf( skipNA: Boolean = skipNA_default, noinline expression: (T) -> R?, -): Double = Aggregators.mean(skipNA).cast2().aggregateOf(this, expression) ?: Double.NaN +): Double = + Aggregators.mean(skipNA) + .cast2() + .aggregateOf(this, expression) + ?: Double.NaN // endregion @@ -62,7 +71,8 @@ public inline fun AnyRow.rowMeanOf(skipNA: Boolean = skipNA // region DataFrame -public fun DataFrame.mean(skipNA: Boolean = skipNA_default): DataRow = meanFor(skipNA, numberColumns()) +public fun DataFrame.mean(skipNA: Boolean = skipNA_default): DataRow = + meanFor(skipNA, primitiveNumberColumns()) public fun DataFrame.meanFor( skipNA: Boolean = skipNA_default, @@ -112,7 +122,8 @@ public inline fun DataFrame.meanOf( // region GroupBy @Refine @Interpretable("GroupByMean1") -public fun Grouped.mean(skipNA: Boolean = skipNA_default): DataFrame = meanFor(skipNA, numberColumns()) +public fun Grouped.mean(skipNA: Boolean = skipNA_default): DataFrame = + meanFor(skipNA, primitiveNumberColumns()) @Refine @Interpretable("GroupByMean0") @@ -177,7 +188,7 @@ public inline fun Grouped.meanOf( // region Pivot public fun Pivot.mean(skipNA: Boolean = skipNA_default, separate: Boolean = false): DataRow = - meanFor(skipNA, separate, numberColumns()) + meanFor(skipNA, separate, primitiveNumberColumns()) public fun Pivot.meanFor( skipNA: Boolean = skipNA_default, @@ -220,7 +231,7 @@ public inline fun Pivot.meanOf( // region PivotGroupBy public fun PivotGroupBy.mean(separate: Boolean = false, skipNA: Boolean = skipNA_default): DataFrame = - meanFor(skipNA, separate, numberColumns()) + meanFor(skipNA, separate, primitiveNumberColumns()) public fun PivotGroupBy.meanFor( skipNA: Boolean = skipNA_default, diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/getColumns.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/getColumns.kt index b7b2c1052d..66f598b02d 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/getColumns.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/getColumns.kt @@ -6,6 +6,7 @@ import org.jetbrains.kotlinx.dataframe.aggregation.Aggregatable import org.jetbrains.kotlinx.dataframe.aggregation.NamedValue import org.jetbrains.kotlinx.dataframe.api.filter import org.jetbrains.kotlinx.dataframe.api.isNumber +import org.jetbrains.kotlinx.dataframe.api.isPrimitiveNumber import org.jetbrains.kotlinx.dataframe.api.valuesAreComparable import org.jetbrains.kotlinx.dataframe.columns.TypeSuggestion import org.jetbrains.kotlinx.dataframe.impl.columns.createColumnGuessingType @@ -14,10 +15,17 @@ internal inline fun Aggregatable.remainingColumns( crossinline predicate: (AnyCol) -> Boolean, ): ColumnsSelector = remainingColumnsSelector().filter { predicate(it.data) } -internal fun Aggregatable.intraComparableColumns() = +@Suppress("UNCHECKED_CAST") +internal fun Aggregatable.intraComparableColumns(): ColumnsSelector> = remainingColumns { it.valuesAreComparable() } as ColumnsSelector> -internal fun Aggregatable.numberColumns() = remainingColumns { it.isNumber() } as ColumnsSelector +@Suppress("UNCHECKED_CAST") +internal fun Aggregatable.numberColumns(): ColumnsSelector = + remainingColumns { it.isNumber() } as ColumnsSelector + +@Suppress("UNCHECKED_CAST") +internal fun Aggregatable.primitiveNumberColumns(): ColumnsSelector = + remainingColumns { it.isPrimitiveNumber() } as ColumnsSelector internal fun NamedValue.toColumnWithPath() = path to createColumnGuessingType( diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/mean.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/mean.kt index 8250251fcc..f94c274a04 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/mean.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/mean.kt @@ -25,7 +25,7 @@ internal fun Sequence.mean(type: KType, skipNA: Boolean = skipN return when (type.withNullability(false)) { typeOf() -> (this as Sequence).mean(skipNA) - typeOf() -> (this as Sequence).mean(skipNA) + typeOf() -> (this as Sequence).map { it.toDouble() }.mean(skipNA) typeOf() -> (this as Sequence).map { it.toDouble() }.mean(false) @@ -76,81 +76,3 @@ internal fun Sequence.mean(skipNA: Boolean = skipNA_default): Double { } return if (count > 0) sum / count else Double.NaN } - -@JvmName("meanFloat") -internal fun Sequence.mean(skipNA: Boolean = skipNA_default): Double { - var count = 0 - var sum: Double = 0.toDouble() - for (element in this) { - if (element.isNaN()) { - if (skipNA) { - continue - } else { - return Double.NaN - } - } - sum += element - count++ - } - return if (count > 0) sum / count else Double.NaN -} - -@JvmName("doubleMean") -internal fun Iterable.mean(skipNA: Boolean = skipNA_default): Double = asSequence().mean(skipNA) - -@JvmName("floatMean") -internal fun Iterable.mean(skipNA: Boolean = skipNA_default): Double = asSequence().mean(skipNA) - -@JvmName("intMean") -internal fun Iterable.mean(): Double = - if (this is Collection) { - if (isNotEmpty()) sumOf { it.toDouble() } / size else Double.NaN - } else { - var count = 0 - val sum = sumOf { - count++ - it.toDouble() - } - if (count > 0) sum / count else Double.NaN - } - -@JvmName("shortMean") -internal fun Iterable.mean(): Double = - if (this is Collection) { - if (isNotEmpty()) sumOf { it.toDouble() } / size else Double.NaN - } else { - var count = 0 - val sum = sumOf { - count++ - it.toDouble() - } - if (count > 0) sum / count else Double.NaN - } - -@JvmName("byteMean") -internal fun Iterable.mean(): Double = - if (this is Collection) { - if (isNotEmpty()) sumOf { it.toDouble() } / size else Double.NaN - } else { - var count = 0 - val sum = sumOf { - count++ - it.toDouble() - } - if (count > 0) sum / count else Double.NaN - } - -@JvmName("longMean") -internal fun Iterable.mean(): Double = - if (this is Collection) { - if (isNotEmpty()) sumOf { it.toDouble() } / size else Double.NaN - } else { - var count = 0 - val sum = sumOf { - count++ - it.toDouble() - } - if (count > 0) sum / count else Double.NaN - }.also { - logger.warn { "Converting Longs to Doubles to calculate the mean, loss of precision may occur." } - } diff --git a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/testSets/person/DataFrameTests.kt b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/testSets/person/DataFrameTests.kt index 7f5886bbe5..da3d7c11aa 100644 --- a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/testSets/person/DataFrameTests.kt +++ b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/testSets/person/DataFrameTests.kt @@ -1485,7 +1485,7 @@ class DataFrameTests : BaseTest() { @Test fun `column stats`() { - typed.age.mean() shouldBe typed.age.toList().mean() + typed.age.mean() shouldBe typed.age.toList().average() typed.age.min() shouldBe typed.age.toList().minOrNull() typed.age.max() shouldBe typed.age.toList().maxOrNull() typed.age.sum() shouldBe typed.age.toList().sum() From 2411a1baacb21d4db1f221a3318de48ab0d9f6de Mon Sep 17 00:00:00 2001 From: Jolan Rensen Date: Mon, 17 Mar 2025 12:54:28 +0100 Subject: [PATCH 07/11] merged master --- core/api/core.api | 13 ++++++++++--- .../org/jetbrains/kotlinx/dataframe/api/mean.kt | 6 ++++-- .../kotlinx/dataframe/io/DelimCsvTsvTests.kt | 7 ++++++- 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/core/api/core.api b/core/api/core.api index bd990d9eea..0f002f0ab0 100644 --- a/core/api/core.api +++ b/core/api/core.api @@ -1777,12 +1777,12 @@ public final class org/jetbrains/kotlinx/dataframe/api/DataColumnArithmeticsKt { } public final class org/jetbrains/kotlinx/dataframe/api/DataColumnTypeKt { - public static final fun isBigNumber (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z public static final fun isColumnGroup (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z public static final fun isComparable (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z public static final fun isFrameColumn (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z public static final fun isList (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z public static final fun isNumber (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z + public static final fun isPrimitiveNumber (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z public static final fun isSubtypeOf (Lorg/jetbrains/kotlinx/dataframe/DataColumn;Lkotlin/reflect/KType;)Z public static final fun isValueColumn (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z public static final fun valuesAreComparable (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z @@ -2842,8 +2842,6 @@ public final class org/jetbrains/kotlinx/dataframe/api/MeanKt { public static synthetic fun meanFor$default (Lorg/jetbrains/kotlinx/dataframe/api/PivotGroupBy;[Ljava/lang/String;ZZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; public static synthetic fun meanFor$default (Lorg/jetbrains/kotlinx/dataframe/api/PivotGroupBy;[Lkotlin/reflect/KProperty;ZZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; public static synthetic fun meanFor$default (Lorg/jetbrains/kotlinx/dataframe/api/PivotGroupBy;[Lorg/jetbrains/kotlinx/dataframe/columns/ColumnReference;ZZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; - public static final fun meanOrNull (Lorg/jetbrains/kotlinx/dataframe/DataColumn;Z)Ljava/lang/Double; - public static synthetic fun meanOrNull$default (Lorg/jetbrains/kotlinx/dataframe/DataColumn;ZILjava/lang/Object;)Ljava/lang/Double; public static final fun rowMean (Lorg/jetbrains/kotlinx/dataframe/DataRow;Z)D public static synthetic fun rowMean$default (Lorg/jetbrains/kotlinx/dataframe/DataRow;ZILjava/lang/Object;)D } @@ -3967,6 +3965,7 @@ public final class org/jetbrains/kotlinx/dataframe/api/TypeConversionsKt { public static final fun asComparable (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Lorg/jetbrains/kotlinx/dataframe/DataColumn; public static final fun asComparable (Lorg/jetbrains/kotlinx/dataframe/columns/ColumnSet;)Lorg/jetbrains/kotlinx/dataframe/columns/ColumnSet; public static final fun asComparable (Lorg/jetbrains/kotlinx/dataframe/columns/SingleColumn;)Lorg/jetbrains/kotlinx/dataframe/columns/SingleColumn; + public static final fun asComparableNullable (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Lorg/jetbrains/kotlinx/dataframe/DataColumn; public static final fun asDataColumn (Lorg/jetbrains/kotlinx/dataframe/columns/ColumnGroup;)Lorg/jetbrains/kotlinx/dataframe/DataColumn; public static final fun asDataColumn (Lorg/jetbrains/kotlinx/dataframe/columns/FrameColumn;)Lorg/jetbrains/kotlinx/dataframe/DataColumn; public static final fun asDataFrame (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; @@ -5103,6 +5102,10 @@ public final class org/jetbrains/kotlinx/dataframe/impl/ExceptionUtilsKt { public static final fun suggestIfNull (Ljava/lang/Object;Ljava/lang/String;)Ljava/lang/Object; } +public final class org/jetbrains/kotlinx/dataframe/impl/NumberTypeUtilsKt { + public static final fun getPrimitiveNumberTypes ()Ljava/util/Set; +} + public final class org/jetbrains/kotlinx/dataframe/impl/TypeUtilsKt { public static final fun getValuesType (Ljava/util/List;Lkotlin/reflect/KType;Lorg/jetbrains/kotlinx/dataframe/api/Infer;)Lkotlin/reflect/KType; public static final synthetic fun guessValueType (Lkotlin/sequences/Sequence;Lkotlin/reflect/KType;Z)Lkotlin/reflect/KType; @@ -5210,6 +5213,10 @@ public final class org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/OfRowE public static final fun aggregateOfDelegated (Lorg/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator;Lorg/jetbrains/kotlinx/dataframe/api/Grouped;Ljava/lang/String;Lkotlin/jvm/functions/Function2;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; } +public final class org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/RowKt { + public static final fun aggregateOfRow (Lorg/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator;Lorg/jetbrains/kotlinx/dataframe/DataRow;Lkotlin/jvm/functions/Function2;)Ljava/lang/Object; +} + public final class org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/WithinAllColumnsKt { public static final fun aggregateAll (Lorg/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator;Lorg/jetbrains/kotlinx/dataframe/DataFrame;Lkotlin/jvm/functions/Function2;)Ljava/lang/Object; } diff --git a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt index e278a39219..0f6a17181d 100644 --- a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt +++ b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt @@ -27,10 +27,12 @@ import kotlin.reflect.KProperty import kotlin.reflect.typeOf /* + * TODO KDocs: * Calculating the mean is supported for all primitive number types. + * Nulls are filtered from columns. * The return type is always Double, Double.NaN for empty input, never null. - * (this may introduce loss of precision for Longs). - * For mixed primitive number types [TwoStepNumbersAggregator], unifies the numbers before calculating the mean. + * (May introduce loss of precision for Longs). + * For mixed primitive number types, [TwoStepNumbersAggregator] unifies the numbers before calculating the mean. */ // region DataColumn diff --git a/dataframe-csv/generated-sources/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/DelimCsvTsvTests.kt b/dataframe-csv/generated-sources/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/DelimCsvTsvTests.kt index 31abcfd041..f303e5028c 100644 --- a/dataframe-csv/generated-sources/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/DelimCsvTsvTests.kt +++ b/dataframe-csv/generated-sources/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/DelimCsvTsvTests.kt @@ -37,6 +37,9 @@ import java.util.zip.GZIPInputStream import kotlin.reflect.KClass import kotlin.reflect.typeOf +// can be enabled for showing logs for these tests +private const val SHOW_LOGS = false + @Suppress("ktlint:standard:argument-list-wrapping") class DelimCsvTsvTests { @@ -45,12 +48,14 @@ class DelimCsvTsvTests { @Before fun setLogger() { + if (!SHOW_LOGS) return loggerBefore = System.getProperty(logLevel) - System.setProperty(logLevel, "debug") + System.setProperty(logLevel, "trace") } @After fun restoreLogger() { + if (!SHOW_LOGS) return if (loggerBefore != null) { System.setProperty(logLevel, loggerBefore) } From f17929bb74a69870f2dcd46cdf08ca7a1487b15e Mon Sep 17 00:00:00 2001 From: Jolan Rensen Date: Mon, 17 Mar 2025 13:39:53 +0100 Subject: [PATCH 08/11] made it no longer mandatory for an aggregator to return a nullable value. This simplifies logic in a lot of places. It can still be nullable for aggregators that require it (like min/max). --- .../org/jetbrains/kotlinx/dataframe/api/mean.kt | 15 +++++++-------- .../impl/aggregation/aggregators/Aggregator.kt | 15 ++++++++------- .../aggregation/aggregators/AggregatorBase.kt | 8 ++++---- .../impl/aggregation/aggregators/Aggregators.kt | 16 ++++++++-------- .../aggregators/FlatteningAggregator.kt | 4 ++-- .../aggregation/aggregators/TwoStepAggregator.kt | 8 ++++---- .../aggregators/TwoStepNumbersAggregator.kt | 12 ++++++------ .../impl/aggregation/modes/ofRowExpression.kt | 10 +++++----- .../dataframe/impl/aggregation/modes/row.kt | 2 +- .../impl/aggregation/modes/withinAllColumns.kt | 4 ++-- .../org/jetbrains/kotlinx/dataframe/api/mean.kt | 16 +++++++--------- .../impl/aggregation/aggregators/Aggregator.kt | 15 ++++++++------- .../aggregation/aggregators/AggregatorBase.kt | 8 ++++---- .../impl/aggregation/aggregators/Aggregators.kt | 6 +++--- .../aggregators/FlatteningAggregator.kt | 4 ++-- .../aggregation/aggregators/TwoStepAggregator.kt | 8 ++++---- .../aggregators/TwoStepNumbersAggregator.kt | 12 ++++++------ .../impl/aggregation/modes/ofRowExpression.kt | 10 +++++----- .../dataframe/impl/aggregation/modes/row.kt | 2 +- .../impl/aggregation/modes/withinAllColumns.kt | 4 ++-- 20 files changed, 89 insertions(+), 90 deletions(-) diff --git a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt index 0f6a17181d..a4fbad83d1 100644 --- a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt +++ b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt @@ -24,6 +24,7 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.primitiveNumberColumns import org.jetbrains.kotlinx.dataframe.impl.columns.toNumberColumns import org.jetbrains.kotlinx.dataframe.impl.primitiveNumberTypes import kotlin.reflect.KProperty +import kotlin.reflect.full.withNullability import kotlin.reflect.typeOf /* @@ -38,7 +39,7 @@ import kotlin.reflect.typeOf // region DataColumn public fun DataColumn.mean(skipNA: Boolean = skipNA_default): Double = - Aggregators.mean(skipNA).aggregate(this)!! + Aggregators.mean(skipNA).aggregate(this) public inline fun DataColumn.meanOf( skipNA: Boolean = skipNA_default, @@ -47,7 +48,6 @@ public inline fun DataColumn.meanOf( Aggregators.mean(skipNA) .cast2() .aggregateOf(this, expression) - ?: Double.NaN // endregion @@ -56,15 +56,14 @@ public inline fun DataColumn.meanOf( public fun AnyRow.rowMean(skipNA: Boolean = skipNA_default): Double = Aggregators.mean(skipNA).aggregateOfRow(this) { colsOf { it.isPrimitiveNumber() } - } ?: Double.NaN + } -public inline fun AnyRow.rowMeanOf(skipNA: Boolean = skipNA_default): Double { - require(typeOf() in primitiveNumberTypes) { +public inline fun AnyRow.rowMeanOf(skipNA: Boolean = skipNA_default): Double { + require(typeOf().withNullability(false) in primitiveNumberTypes) { "Type ${T::class.simpleName} is not a primitive number type. Mean only supports primitive number types." } return Aggregators.mean(skipNA) .aggregateOfRow(this) { colsOf() } - ?: Double.NaN } // endregion @@ -97,7 +96,7 @@ public fun DataFrame.meanFor( public fun DataFrame.mean( skipNA: Boolean = skipNA_default, columns: ColumnsSelector, -): Double = Aggregators.mean(skipNA).aggregateAll(this, columns) ?: Double.NaN +): Double = Aggregators.mean(skipNA).aggregateAll(this, columns) public fun DataFrame.mean(vararg columns: String, skipNA: Boolean = skipNA_default): Double = mean(skipNA) { columns.toNumberColumns() } @@ -115,7 +114,7 @@ public fun DataFrame.mean(vararg columns: KProperty, skip public inline fun DataFrame.meanOf( skipNA: Boolean = skipNA_default, noinline expression: RowExpression, -): Double = Aggregators.mean(skipNA).of(this, expression) ?: Double.NaN +): Double = Aggregators.mean(skipNA).of(this, expression) // endregion diff --git a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator.kt b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator.kt index 0050145715..7e1e8abd81 100644 --- a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator.kt +++ b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator.kt @@ -15,8 +15,7 @@ import kotlin.reflect.full.withNullability * @param Value The type of the values to be aggregated. * This can be nullable for [Iterables][Iterable] or not, depending on the use case. * For columns, [Value] will always be considered nullable; nulls are filtered out from columns anyway. - * @param Return The type of the resulting value. It doesn't matter if this is nullable or not, as the aggregator - * will always return a [Return]`?`. + * @param Return The type of the resulting value. Can optionally be nullable. */ @PublishedApi internal interface Aggregator { @@ -33,7 +32,7 @@ internal interface Aggregator { * * When the exact [type] is unknown, use [aggregateCalculatingType]. */ - fun aggregate(values: Iterable, type: KType): Return? + fun aggregate(values: Iterable, type: KType): Return /** * Aggregates the data in the given column and computes a single resulting value. @@ -41,12 +40,12 @@ internal interface Aggregator { * * See [AggregatorBase.aggregate]. */ - fun aggregate(column: DataColumn): Return? + fun aggregate(column: DataColumn): Return /** * Aggregates the data in the multiple given columns and computes a single resulting value. */ - fun aggregate(columns: Iterable>): Return? + fun aggregate(columns: Iterable>): Return /** * Special case of [aggregate] with [Iterable] that calculates the common type of the values at runtime. @@ -58,7 +57,7 @@ internal interface Aggregator { * It should contain all types of [values]. * If `null`, the types of [values] will be calculated at runtime (heavy!). */ - fun aggregateCalculatingType(values: Iterable, valueTypes: Set? = null): Return? + fun aggregateCalculatingType(values: Iterable, valueTypes: Set? = null): Return /** * Function that can give the return type of [aggregate] as [KType], given the type of the input. @@ -82,9 +81,11 @@ internal interface Aggregator { fun calculateReturnTypeOrNull(colTypes: Set, colsEmpty: Boolean): KType? } +@Suppress("UNCHECKED_CAST") @PublishedApi internal fun Aggregator<*, *>.cast(): Aggregator = this as Aggregator +@Suppress("UNCHECKED_CAST") @PublishedApi internal fun Aggregator<*, *>.cast2(): Aggregator = this as Aggregator @@ -92,7 +93,7 @@ internal fun Aggregator<*, *>.cast2(): Aggregator internal typealias CalculateReturnTypeOrNull = (type: KType, emptyInput: Boolean) -> KType? /** Type alias for [Aggregator.aggregate]. */ -internal typealias Aggregate = Iterable.(type: KType) -> Return? +internal typealias Aggregate = Iterable.(type: KType) -> Return /** Common case for [CalculateReturnTypeOrNull], preserves return type, but makes it nullable for empty inputs. */ internal val preserveReturnTypeNullIfEmpty: CalculateReturnTypeOrNull = { type, emptyInput -> diff --git a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorBase.kt b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorBase.kt index 99c7b33c61..6b8d16edf9 100644 --- a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorBase.kt +++ b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorBase.kt @@ -31,7 +31,7 @@ internal abstract class AggregatorBase( * * When the exact [type] is unknown, use [aggregateCalculatingType]. */ - override fun aggregate(values: Iterable, type: KType): Return? = aggregator(values, type) + override fun aggregate(values: Iterable, type: KType): Return = aggregator(values, type) /** * Function that can give the return type of [aggregate] as [KType], given the type of the input. @@ -52,7 +52,7 @@ internal abstract class AggregatorBase( * Nulls are filtered out by default, then [aggregate] (with [Iterable] and [KType]) is called. */ @Suppress("UNCHECKED_CAST") - override fun aggregate(column: DataColumn): Return? = + override fun aggregate(column: DataColumn): Return = aggregate( values = if (column.hasNulls()) { @@ -71,7 +71,7 @@ internal abstract class AggregatorBase( * If provided, this can be used to avoid calculating the types of [values][org.jetbrains.kotlinx.dataframe.values] at runtime with reflection. * It should contain all types of [values][org.jetbrains.kotlinx.dataframe.values]. * If `null`, the types of [values][org.jetbrains.kotlinx.dataframe.values] will be calculated at runtime (heavy!). */ - override fun aggregateCalculatingType(values: Iterable, valueTypes: Set?): Return? { + override fun aggregateCalculatingType(values: Iterable, valueTypes: Set?): Return { val commonType = if (valueTypes != null) { valueTypes.commonType(false) } else { @@ -93,7 +93,7 @@ internal abstract class AggregatorBase( * Aggregates the data in the multiple given columns and computes a single resulting value. * Must be overridden to use. */ - abstract override fun aggregate(columns: Iterable>): Return? + abstract override fun aggregate(columns: Iterable>): Return /** * Function that can give the return type of [aggregate] with columns as [KType], diff --git a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt index 0de1dddb95..9828a34f5e 100644 --- a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt +++ b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt @@ -30,7 +30,7 @@ internal object Aggregators { * -> stepOneAggregator(Iterable, colType) // called on each iterable * -> Iterable // nulls filtered out * -> stepTwoAggregator(Iterable, common valueType) - * -> Return? + * -> Return * ``` * * It can also be used as a "simple" aggregator by providing the same function for both steps. @@ -43,7 +43,7 @@ internal object Aggregators { * @param stepTwoAggregator Functional argument for the aggregation function used between different columns. * It is run on the results of [stepOneAggregator]. */ - private fun twoStepPreservingType(aggregator: Aggregate) = + private fun twoStepPreservingType(aggregator: Aggregate) = TwoStepAggregator.Factory( getReturnTypeOrNull = preserveReturnTypeNullIfEmpty, stepOneAggregator = aggregator, @@ -68,7 +68,7 @@ internal object Aggregators { * -> stepOneAggregator(Iterable, colType) // called on each iterable * -> Iterable // nulls filtered out * -> stepTwoAggregator(Iterable, common valueType) - * -> Return? + * -> Return * ``` * * It can also be used as a "simple" aggregator by providing the same function for both steps. @@ -106,7 +106,7 @@ internal object Aggregators { * Iterable> * -> Iterable // flattened without nulls * -> aggregator(Iterable, common colType) - * -> Return? + * -> Return * ``` * * This is essential for aggregators that depend on the distribution of all values across the dataframe, like @@ -119,7 +119,7 @@ internal object Aggregators { * @param aggregator Functional argument for the [aggregate][org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.FlatteningAggregator.aggregate] function. * Note that it must be able to handle `null` values for the [Iterable] overload of [aggregate][org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.FlatteningAggregator.aggregate]. */ - private fun flatteningPreservingTypes(aggregate: Aggregate) = + private fun flatteningPreservingTypes(aggregate: Aggregate) = FlatteningAggregator.Factory( getReturnTypeOrNull = preserveReturnTypeNullIfEmpty, aggregator = aggregate, @@ -140,7 +140,7 @@ internal object Aggregators { * Iterable> * -> Iterable // flattened without nulls * -> aggregator(Iterable, common colType) - * -> Return? + * -> Return * ``` * * This is essential for aggregators that depend on the distribution of all values across the dataframe, like @@ -182,7 +182,7 @@ internal object Aggregators { * -> aggregator(Iterable, unified number type of common colType) // called on each iterable * -> Iterable // nulls filtered out * -> aggregator(Iterable, unified number type of common valueType) - * -> Return? + * -> Return * ``` * * @param name The name of this aggregator. @@ -191,7 +191,7 @@ internal object Aggregators { * While it takes a [Number] argument, you can assume that all values are of the same specific type, however, * this type can be different for different calls to [aggregator][org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.AggregatorBase.aggregator]. */ - private fun twoStepForNumbers( + private fun twoStepForNumbers( getReturnTypeOrNull: CalculateReturnTypeOrNull, aggregate: Aggregate, ) = TwoStepNumbersAggregator.Factory( diff --git a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/FlatteningAggregator.kt b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/FlatteningAggregator.kt index 270777e7a2..53124d1c6b 100644 --- a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/FlatteningAggregator.kt +++ b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/FlatteningAggregator.kt @@ -18,7 +18,7 @@ import kotlin.reflect.full.withNullability * Iterable> * -> Iterable // flattened without nulls * -> aggregator(Iterable, common colType) - * -> Return? + * -> Return * ``` * * This is essential for aggregators that depend on the distribution of all values across the dataframe, like @@ -42,7 +42,7 @@ internal class FlatteningAggregator( * The columns are flattened into a single list of values, filtering nulls as usual; * then the aggregation function is with the common type of the columns. */ - override fun aggregate(columns: Iterable>): Return? { + override fun aggregate(columns: Iterable>): Return { val commonType = columns.map { it.type() }.commonType().withNullability(false) val allValues = columns.asSequence().flatMap { it.values() }.filterNotNull() return aggregate(allValues.asIterable(), commonType) diff --git a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/TwoStepAggregator.kt b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/TwoStepAggregator.kt index 11738fbf5e..879f740f79 100644 --- a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/TwoStepAggregator.kt +++ b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/TwoStepAggregator.kt @@ -23,7 +23,7 @@ import kotlin.reflect.full.withNullability * -> stepOneAggregator(Iterable, colType) // called on each iterable * -> Iterable // nulls filtered out * -> stepTwoAggregator(Iterable, common valueType) - * -> Return? + * -> Return * ``` * * It can also be used as a "simple" aggregator by providing the same function for both steps. @@ -40,7 +40,7 @@ internal class TwoStepAggregator( name: String, getReturnTypeOrNull: CalculateReturnTypeOrNull, stepOneAggregator: Aggregate, - private val stepTwoAggregator: Aggregate, + private val stepTwoAggregator: Aggregate, ) : AggregatorBase(name, getReturnTypeOrNull, stepOneAggregator) { /** @@ -50,7 +50,7 @@ internal class TwoStepAggregator( * * Post-step-one types are calculated by [calculateReturnTypeOrNull]. */ - override fun aggregate(columns: Iterable>): Return? { + override fun aggregate(columns: Iterable>): Return { val (values, types) = columns.mapNotNull { col -> // uses stepOneAggregator val value = aggregate(col) ?: return@mapNotNull null @@ -93,7 +93,7 @@ internal class TwoStepAggregator( class Factory( private val getReturnTypeOrNull: CalculateReturnTypeOrNull, private val stepOneAggregator: Aggregate, - private val stepTwoAggregator: Aggregate, + private val stepTwoAggregator: Aggregate, ) : AggregatorProvider> by AggregatorProvider({ name -> TwoStepAggregator( name = name, diff --git a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/TwoStepNumbersAggregator.kt b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/TwoStepNumbersAggregator.kt index bb229720d3..7441c63769 100644 --- a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/TwoStepNumbersAggregator.kt +++ b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/TwoStepNumbersAggregator.kt @@ -39,7 +39,7 @@ private val logger = KotlinLogging.logger { } * -> aggregator(Iterable, unified number type of common colType) // called on each iterable * -> Iterable // nulls filtered out * -> aggregator(Iterable, unified number type of common valueType) - * -> Return? + * -> Return * ``` * * @param name The name of this aggregator. @@ -48,7 +48,7 @@ private val logger = KotlinLogging.logger { } * While it takes a [Number] argument, you can assume that all values are of the same specific type, however, * this type can be different for different calls to [aggregator]. */ -internal class TwoStepNumbersAggregator( +internal class TwoStepNumbersAggregator( name: String, getReturnTypeOrNull: CalculateReturnTypeOrNull, aggregator: Aggregate, @@ -62,7 +62,7 @@ internal class TwoStepNumbersAggregator( * After the first aggregation, the number types are found by [calculateReturnTypeOrNull] and then * unified using [aggregateCalculatingType]. */ - override fun aggregate(columns: Iterable>): Return? { + override fun aggregate(columns: Iterable>): Return { val (values, types) = columns.mapNotNull { col -> val value = aggregate(col) ?: return@mapNotNull null val type = calculateReturnTypeOrNull( @@ -113,7 +113,7 @@ internal class TwoStepNumbersAggregator( * * When the exact [type] is unknown, use [aggregateCalculatingType]. */ - override fun aggregate(values: Iterable, type: KType): Return? { + override fun aggregate(values: Iterable, type: KType): Return { require(type.isSubtypeOf(typeOf())) { "${TwoStepNumbersAggregator::class.simpleName}: Type $type is not a subtype of Number?" } @@ -147,7 +147,7 @@ internal class TwoStepNumbersAggregator( * If `null`, the types of [values] will be calculated at runtime (heavy!). */ @Suppress("UNCHECKED_CAST") - override fun aggregateCalculatingType(values: Iterable, valueTypes: Set?): Return? { + override fun aggregateCalculatingType(values: Iterable, valueTypes: Set?): Return { val valueTypes = valueTypes ?: values.types() val commonType = valueTypes .unifiedNumberType(PRIMITIVES_ONLY) @@ -176,7 +176,7 @@ internal class TwoStepNumbersAggregator( * @param getReturnTypeOrNull Functional argument for the [calculateReturnTypeOrNull] function. * @param aggregator Functional argument for the [aggregate] function, used within a [DataColumn] or [Iterable]. */ - class Factory( + class Factory( private val getReturnTypeOrNull: CalculateReturnTypeOrNull, private val aggregate: Aggregate, ) : AggregatorProvider> by AggregatorProvider({ name -> diff --git a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/ofRowExpression.kt b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/ofRowExpression.kt index 80bdc5bc33..8d678b2914 100644 --- a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/ofRowExpression.kt +++ b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/ofRowExpression.kt @@ -19,19 +19,19 @@ import kotlin.reflect.typeOf internal inline fun Aggregator.aggregateOf( values: Iterable, noinline transform: (C) -> V, -): R? = aggregate(values.asSequence().map(transform).asIterable(), typeOf()) +): R = aggregate(values.asSequence().map(transform).asIterable(), typeOf()) @PublishedApi internal inline fun Aggregator.aggregateOf( column: DataColumn, noinline transform: (C) -> V, -): R? = aggregateOf(column.values(), transform) +): R = aggregateOf(column.values(), transform) @PublishedApi internal inline fun Aggregator<*, R>.aggregateOf( frame: DataFrame, crossinline expression: RowExpression, -): R? = (this as Aggregator).aggregateOf(frame.rows()) { expression(it, it) } +): R = (this as Aggregator).aggregateOf(frame.rows()) { expression(it, it) } @PublishedApi internal fun Aggregator<*, R>.aggregateOfDelegated( @@ -47,10 +47,10 @@ internal fun Aggregator<*, R>.aggregateOfDelegated( internal inline fun Aggregator<*, R>.of( data: DataFrame, crossinline expression: RowExpression, -): R? = aggregateOf(data as DataFrame, expression) +): R = aggregateOf(data as DataFrame, expression) @PublishedApi -internal inline fun Aggregator.of(data: DataColumn, crossinline expression: (C) -> V): R? = +internal inline fun Aggregator.of(data: DataColumn, crossinline expression: (C) -> V): R = aggregateOf(data.values()) { expression(it) } @PublishedApi diff --git a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/row.kt b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/row.kt index 5cdba355ea..53abbbe9b5 100644 --- a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/row.kt +++ b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/row.kt @@ -14,7 +14,7 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Aggregator * @param columns selector of which columns inside the [row] to aggregate */ @PublishedApi -internal fun Aggregator.aggregateOfRow(row: AnyRow, columns: ColumnsSelector<*, V?>): R? { +internal fun Aggregator.aggregateOfRow(row: AnyRow, columns: ColumnsSelector<*, V?>): R { val filteredColumns = row.df().getColumns(columns) return aggregateCalculatingType( values = filteredColumns.mapNotNull { row[it] }, diff --git a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/withinAllColumns.kt b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/withinAllColumns.kt index c46481bf65..b7217d3b1e 100644 --- a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/withinAllColumns.kt +++ b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/withinAllColumns.kt @@ -15,7 +15,7 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.internal import org.jetbrains.kotlinx.dataframe.impl.emptyPath @PublishedApi -internal fun Aggregator<*, R>.aggregateAll(data: DataFrame, columns: ColumnsSelector): R? = +internal fun Aggregator<*, R>.aggregateAll(data: DataFrame, columns: ColumnsSelector): R = data.aggregateAll(cast2(), columns) internal fun Aggregator<*, R>.aggregateAll( @@ -29,7 +29,7 @@ internal fun Aggregator<*, R>.aggregateAll( columns: ColumnsSelector, ): DataFrame = data.aggregateAll(cast(), columns) -internal fun DataFrame.aggregateAll(aggregator: Aggregator, columns: ColumnsSelector): R? = +internal fun DataFrame.aggregateAll(aggregator: Aggregator, columns: ColumnsSelector): R = aggregator.aggregate(get(columns)) internal fun Grouped.aggregateAll( diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt index 0f6a17181d..0bcf448e0b 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt @@ -24,6 +24,7 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.primitiveNumberColumns import org.jetbrains.kotlinx.dataframe.impl.columns.toNumberColumns import org.jetbrains.kotlinx.dataframe.impl.primitiveNumberTypes import kotlin.reflect.KProperty +import kotlin.reflect.full.withNullability import kotlin.reflect.typeOf /* @@ -37,8 +38,7 @@ import kotlin.reflect.typeOf // region DataColumn -public fun DataColumn.mean(skipNA: Boolean = skipNA_default): Double = - Aggregators.mean(skipNA).aggregate(this)!! +public fun DataColumn.mean(skipNA: Boolean = skipNA_default): Double = Aggregators.mean(skipNA).aggregate(this) public inline fun DataColumn.meanOf( skipNA: Boolean = skipNA_default, @@ -47,7 +47,6 @@ public inline fun DataColumn.meanOf( Aggregators.mean(skipNA) .cast2() .aggregateOf(this, expression) - ?: Double.NaN // endregion @@ -56,15 +55,14 @@ public inline fun DataColumn.meanOf( public fun AnyRow.rowMean(skipNA: Boolean = skipNA_default): Double = Aggregators.mean(skipNA).aggregateOfRow(this) { colsOf { it.isPrimitiveNumber() } - } ?: Double.NaN + } -public inline fun AnyRow.rowMeanOf(skipNA: Boolean = skipNA_default): Double { - require(typeOf() in primitiveNumberTypes) { +public inline fun AnyRow.rowMeanOf(skipNA: Boolean = skipNA_default): Double { + require(typeOf().withNullability(false) in primitiveNumberTypes) { "Type ${T::class.simpleName} is not a primitive number type. Mean only supports primitive number types." } return Aggregators.mean(skipNA) .aggregateOfRow(this) { colsOf() } - ?: Double.NaN } // endregion @@ -97,7 +95,7 @@ public fun DataFrame.meanFor( public fun DataFrame.mean( skipNA: Boolean = skipNA_default, columns: ColumnsSelector, -): Double = Aggregators.mean(skipNA).aggregateAll(this, columns) ?: Double.NaN +): Double = Aggregators.mean(skipNA).aggregateAll(this, columns) public fun DataFrame.mean(vararg columns: String, skipNA: Boolean = skipNA_default): Double = mean(skipNA) { columns.toNumberColumns() } @@ -115,7 +113,7 @@ public fun DataFrame.mean(vararg columns: KProperty, skip public inline fun DataFrame.meanOf( skipNA: Boolean = skipNA_default, noinline expression: RowExpression, -): Double = Aggregators.mean(skipNA).of(this, expression) ?: Double.NaN +): Double = Aggregators.mean(skipNA).of(this, expression) // endregion diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator.kt index 0050145715..7e1e8abd81 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator.kt @@ -15,8 +15,7 @@ import kotlin.reflect.full.withNullability * @param Value The type of the values to be aggregated. * This can be nullable for [Iterables][Iterable] or not, depending on the use case. * For columns, [Value] will always be considered nullable; nulls are filtered out from columns anyway. - * @param Return The type of the resulting value. It doesn't matter if this is nullable or not, as the aggregator - * will always return a [Return]`?`. + * @param Return The type of the resulting value. Can optionally be nullable. */ @PublishedApi internal interface Aggregator { @@ -33,7 +32,7 @@ internal interface Aggregator { * * When the exact [type] is unknown, use [aggregateCalculatingType]. */ - fun aggregate(values: Iterable, type: KType): Return? + fun aggregate(values: Iterable, type: KType): Return /** * Aggregates the data in the given column and computes a single resulting value. @@ -41,12 +40,12 @@ internal interface Aggregator { * * See [AggregatorBase.aggregate]. */ - fun aggregate(column: DataColumn): Return? + fun aggregate(column: DataColumn): Return /** * Aggregates the data in the multiple given columns and computes a single resulting value. */ - fun aggregate(columns: Iterable>): Return? + fun aggregate(columns: Iterable>): Return /** * Special case of [aggregate] with [Iterable] that calculates the common type of the values at runtime. @@ -58,7 +57,7 @@ internal interface Aggregator { * It should contain all types of [values]. * If `null`, the types of [values] will be calculated at runtime (heavy!). */ - fun aggregateCalculatingType(values: Iterable, valueTypes: Set? = null): Return? + fun aggregateCalculatingType(values: Iterable, valueTypes: Set? = null): Return /** * Function that can give the return type of [aggregate] as [KType], given the type of the input. @@ -82,9 +81,11 @@ internal interface Aggregator { fun calculateReturnTypeOrNull(colTypes: Set, colsEmpty: Boolean): KType? } +@Suppress("UNCHECKED_CAST") @PublishedApi internal fun Aggregator<*, *>.cast(): Aggregator = this as Aggregator +@Suppress("UNCHECKED_CAST") @PublishedApi internal fun Aggregator<*, *>.cast2(): Aggregator = this as Aggregator @@ -92,7 +93,7 @@ internal fun Aggregator<*, *>.cast2(): Aggregator internal typealias CalculateReturnTypeOrNull = (type: KType, emptyInput: Boolean) -> KType? /** Type alias for [Aggregator.aggregate]. */ -internal typealias Aggregate = Iterable.(type: KType) -> Return? +internal typealias Aggregate = Iterable.(type: KType) -> Return /** Common case for [CalculateReturnTypeOrNull], preserves return type, but makes it nullable for empty inputs. */ internal val preserveReturnTypeNullIfEmpty: CalculateReturnTypeOrNull = { type, emptyInput -> diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorBase.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorBase.kt index 906b40dc83..799442849f 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorBase.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorBase.kt @@ -31,7 +31,7 @@ internal abstract class AggregatorBase( * * When the exact [type] is unknown, use [aggregateCalculatingType]. */ - override fun aggregate(values: Iterable, type: KType): Return? = aggregator(values, type) + override fun aggregate(values: Iterable, type: KType): Return = aggregator(values, type) /** * Function that can give the return type of [aggregate] as [KType], given the type of the input. @@ -52,7 +52,7 @@ internal abstract class AggregatorBase( * Nulls are filtered out by default, then [aggregate] (with [Iterable] and [KType]) is called. */ @Suppress("UNCHECKED_CAST") - override fun aggregate(column: DataColumn): Return? = + override fun aggregate(column: DataColumn): Return = aggregate( values = if (column.hasNulls()) { @@ -64,7 +64,7 @@ internal abstract class AggregatorBase( ) /** @include [Aggregator.aggregateCalculatingType] */ - override fun aggregateCalculatingType(values: Iterable, valueTypes: Set?): Return? { + override fun aggregateCalculatingType(values: Iterable, valueTypes: Set?): Return { val commonType = if (valueTypes != null) { valueTypes.commonType(false) } else { @@ -86,7 +86,7 @@ internal abstract class AggregatorBase( * Aggregates the data in the multiple given columns and computes a single resulting value. * Must be overridden to use. */ - abstract override fun aggregate(columns: Iterable>): Return? + abstract override fun aggregate(columns: Iterable>): Return /** * Function that can give the return type of [aggregate] with columns as [KType], diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt index 6258e8bbd4..c81f6af8e1 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt @@ -17,7 +17,7 @@ internal object Aggregators { * * @include [TwoStepAggregator] */ - private fun twoStepPreservingType(aggregator: Aggregate) = + private fun twoStepPreservingType(aggregator: Aggregate) = TwoStepAggregator.Factory( getReturnTypeOrNull = preserveReturnTypeNullIfEmpty, stepOneAggregator = aggregator, @@ -44,7 +44,7 @@ internal object Aggregators { * * @include [FlatteningAggregator] */ - private fun flatteningPreservingTypes(aggregate: Aggregate) = + private fun flatteningPreservingTypes(aggregate: Aggregate) = FlatteningAggregator.Factory( getReturnTypeOrNull = preserveReturnTypeNullIfEmpty, aggregator = aggregate, @@ -68,7 +68,7 @@ internal object Aggregators { * * @include [TwoStepNumbersAggregator] */ - private fun twoStepForNumbers( + private fun twoStepForNumbers( getReturnTypeOrNull: CalculateReturnTypeOrNull, aggregate: Aggregate, ) = TwoStepNumbersAggregator.Factory( diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/FlatteningAggregator.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/FlatteningAggregator.kt index 270777e7a2..53124d1c6b 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/FlatteningAggregator.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/FlatteningAggregator.kt @@ -18,7 +18,7 @@ import kotlin.reflect.full.withNullability * Iterable> * -> Iterable // flattened without nulls * -> aggregator(Iterable, common colType) - * -> Return? + * -> Return * ``` * * This is essential for aggregators that depend on the distribution of all values across the dataframe, like @@ -42,7 +42,7 @@ internal class FlatteningAggregator( * The columns are flattened into a single list of values, filtering nulls as usual; * then the aggregation function is with the common type of the columns. */ - override fun aggregate(columns: Iterable>): Return? { + override fun aggregate(columns: Iterable>): Return { val commonType = columns.map { it.type() }.commonType().withNullability(false) val allValues = columns.asSequence().flatMap { it.values() }.filterNotNull() return aggregate(allValues.asIterable(), commonType) diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/TwoStepAggregator.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/TwoStepAggregator.kt index 11738fbf5e..879f740f79 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/TwoStepAggregator.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/TwoStepAggregator.kt @@ -23,7 +23,7 @@ import kotlin.reflect.full.withNullability * -> stepOneAggregator(Iterable, colType) // called on each iterable * -> Iterable // nulls filtered out * -> stepTwoAggregator(Iterable, common valueType) - * -> Return? + * -> Return * ``` * * It can also be used as a "simple" aggregator by providing the same function for both steps. @@ -40,7 +40,7 @@ internal class TwoStepAggregator( name: String, getReturnTypeOrNull: CalculateReturnTypeOrNull, stepOneAggregator: Aggregate, - private val stepTwoAggregator: Aggregate, + private val stepTwoAggregator: Aggregate, ) : AggregatorBase(name, getReturnTypeOrNull, stepOneAggregator) { /** @@ -50,7 +50,7 @@ internal class TwoStepAggregator( * * Post-step-one types are calculated by [calculateReturnTypeOrNull]. */ - override fun aggregate(columns: Iterable>): Return? { + override fun aggregate(columns: Iterable>): Return { val (values, types) = columns.mapNotNull { col -> // uses stepOneAggregator val value = aggregate(col) ?: return@mapNotNull null @@ -93,7 +93,7 @@ internal class TwoStepAggregator( class Factory( private val getReturnTypeOrNull: CalculateReturnTypeOrNull, private val stepOneAggregator: Aggregate, - private val stepTwoAggregator: Aggregate, + private val stepTwoAggregator: Aggregate, ) : AggregatorProvider> by AggregatorProvider({ name -> TwoStepAggregator( name = name, diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/TwoStepNumbersAggregator.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/TwoStepNumbersAggregator.kt index bb229720d3..7441c63769 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/TwoStepNumbersAggregator.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/TwoStepNumbersAggregator.kt @@ -39,7 +39,7 @@ private val logger = KotlinLogging.logger { } * -> aggregator(Iterable, unified number type of common colType) // called on each iterable * -> Iterable // nulls filtered out * -> aggregator(Iterable, unified number type of common valueType) - * -> Return? + * -> Return * ``` * * @param name The name of this aggregator. @@ -48,7 +48,7 @@ private val logger = KotlinLogging.logger { } * While it takes a [Number] argument, you can assume that all values are of the same specific type, however, * this type can be different for different calls to [aggregator]. */ -internal class TwoStepNumbersAggregator( +internal class TwoStepNumbersAggregator( name: String, getReturnTypeOrNull: CalculateReturnTypeOrNull, aggregator: Aggregate, @@ -62,7 +62,7 @@ internal class TwoStepNumbersAggregator( * After the first aggregation, the number types are found by [calculateReturnTypeOrNull] and then * unified using [aggregateCalculatingType]. */ - override fun aggregate(columns: Iterable>): Return? { + override fun aggregate(columns: Iterable>): Return { val (values, types) = columns.mapNotNull { col -> val value = aggregate(col) ?: return@mapNotNull null val type = calculateReturnTypeOrNull( @@ -113,7 +113,7 @@ internal class TwoStepNumbersAggregator( * * When the exact [type] is unknown, use [aggregateCalculatingType]. */ - override fun aggregate(values: Iterable, type: KType): Return? { + override fun aggregate(values: Iterable, type: KType): Return { require(type.isSubtypeOf(typeOf())) { "${TwoStepNumbersAggregator::class.simpleName}: Type $type is not a subtype of Number?" } @@ -147,7 +147,7 @@ internal class TwoStepNumbersAggregator( * If `null`, the types of [values] will be calculated at runtime (heavy!). */ @Suppress("UNCHECKED_CAST") - override fun aggregateCalculatingType(values: Iterable, valueTypes: Set?): Return? { + override fun aggregateCalculatingType(values: Iterable, valueTypes: Set?): Return { val valueTypes = valueTypes ?: values.types() val commonType = valueTypes .unifiedNumberType(PRIMITIVES_ONLY) @@ -176,7 +176,7 @@ internal class TwoStepNumbersAggregator( * @param getReturnTypeOrNull Functional argument for the [calculateReturnTypeOrNull] function. * @param aggregator Functional argument for the [aggregate] function, used within a [DataColumn] or [Iterable]. */ - class Factory( + class Factory( private val getReturnTypeOrNull: CalculateReturnTypeOrNull, private val aggregate: Aggregate, ) : AggregatorProvider> by AggregatorProvider({ name -> diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/ofRowExpression.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/ofRowExpression.kt index 80bdc5bc33..8d678b2914 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/ofRowExpression.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/ofRowExpression.kt @@ -19,19 +19,19 @@ import kotlin.reflect.typeOf internal inline fun Aggregator.aggregateOf( values: Iterable, noinline transform: (C) -> V, -): R? = aggregate(values.asSequence().map(transform).asIterable(), typeOf()) +): R = aggregate(values.asSequence().map(transform).asIterable(), typeOf()) @PublishedApi internal inline fun Aggregator.aggregateOf( column: DataColumn, noinline transform: (C) -> V, -): R? = aggregateOf(column.values(), transform) +): R = aggregateOf(column.values(), transform) @PublishedApi internal inline fun Aggregator<*, R>.aggregateOf( frame: DataFrame, crossinline expression: RowExpression, -): R? = (this as Aggregator).aggregateOf(frame.rows()) { expression(it, it) } +): R = (this as Aggregator).aggregateOf(frame.rows()) { expression(it, it) } @PublishedApi internal fun Aggregator<*, R>.aggregateOfDelegated( @@ -47,10 +47,10 @@ internal fun Aggregator<*, R>.aggregateOfDelegated( internal inline fun Aggregator<*, R>.of( data: DataFrame, crossinline expression: RowExpression, -): R? = aggregateOf(data as DataFrame, expression) +): R = aggregateOf(data as DataFrame, expression) @PublishedApi -internal inline fun Aggregator.of(data: DataColumn, crossinline expression: (C) -> V): R? = +internal inline fun Aggregator.of(data: DataColumn, crossinline expression: (C) -> V): R = aggregateOf(data.values()) { expression(it) } @PublishedApi diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/row.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/row.kt index 5cdba355ea..53abbbe9b5 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/row.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/row.kt @@ -14,7 +14,7 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Aggregator * @param columns selector of which columns inside the [row] to aggregate */ @PublishedApi -internal fun Aggregator.aggregateOfRow(row: AnyRow, columns: ColumnsSelector<*, V?>): R? { +internal fun Aggregator.aggregateOfRow(row: AnyRow, columns: ColumnsSelector<*, V?>): R { val filteredColumns = row.df().getColumns(columns) return aggregateCalculatingType( values = filteredColumns.mapNotNull { row[it] }, diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/withinAllColumns.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/withinAllColumns.kt index c46481bf65..b7217d3b1e 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/withinAllColumns.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/withinAllColumns.kt @@ -15,7 +15,7 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.internal import org.jetbrains.kotlinx.dataframe.impl.emptyPath @PublishedApi -internal fun Aggregator<*, R>.aggregateAll(data: DataFrame, columns: ColumnsSelector): R? = +internal fun Aggregator<*, R>.aggregateAll(data: DataFrame, columns: ColumnsSelector): R = data.aggregateAll(cast2(), columns) internal fun Aggregator<*, R>.aggregateAll( @@ -29,7 +29,7 @@ internal fun Aggregator<*, R>.aggregateAll( columns: ColumnsSelector, ): DataFrame = data.aggregateAll(cast(), columns) -internal fun DataFrame.aggregateAll(aggregator: Aggregator, columns: ColumnsSelector): R? = +internal fun DataFrame.aggregateAll(aggregator: Aggregator, columns: ColumnsSelector): R = aggregator.aggregate(get(columns)) internal fun Grouped.aggregateAll( From 7d72bdfead6c351283907ea99da82a2b7dcf86c5 Mon Sep 17 00:00:00 2001 From: Jolan Rensen Date: Mon, 17 Mar 2025 15:37:35 +0100 Subject: [PATCH 09/11] aggregators now always filter out nulls, not just for columns --- .../aggregation/aggregators/Aggregator.kt | 19 +++++++------ .../aggregation/aggregators/AggregatorBase.kt | 27 ++++++++++--------- .../aggregators/FlatteningAggregator.kt | 2 +- .../aggregators/TwoStepAggregator.kt | 2 +- .../aggregators/TwoStepNumbersAggregator.kt | 9 +++---- .../impl/aggregation/modes/ofRowExpression.kt | 7 ++++- 6 files changed, 37 insertions(+), 29 deletions(-) diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator.kt index 7e1e8abd81..3fd3829c5d 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator.kt @@ -13,8 +13,7 @@ import kotlin.reflect.full.withNullability * The [AggregatorBase] class is a base implementation of this interface. * * @param Value The type of the values to be aggregated. - * This can be nullable for [Iterables][Iterable] or not, depending on the use case. - * For columns, [Value] will always be considered nullable; nulls are filtered out from columns anyway. + * The input can always have nulls, they are filtered out. * @param Return The type of the resulting value. Can optionally be nullable. */ @PublishedApi @@ -26,17 +25,18 @@ internal interface Aggregator { /** * Base function of [Aggregator]. * - * Aggregates the given values, taking [type] into account, and computes a single resulting value. + * Aggregates the given values, taking [type] into account, + * filtering nulls, and computes a single resulting value. * * When using [AggregatorBase], this can be supplied by the [AggregatorBase.aggregator] argument. * * When the exact [type] is unknown, use [aggregateCalculatingType]. */ - fun aggregate(values: Iterable, type: KType): Return + fun aggregate(values: Iterable, type: KType): Return /** * Aggregates the data in the given column and computes a single resulting value. - * Nulls are filtered out by default, then [aggregate] (with [Iterable] and [KType]) is called. + * Calls [aggregate] (with [Iterable] and [KType]). * * See [AggregatorBase.aggregate]. */ @@ -57,7 +57,7 @@ internal interface Aggregator { * It should contain all types of [values]. * If `null`, the types of [values] will be calculated at runtime (heavy!). */ - fun aggregateCalculatingType(values: Iterable, valueTypes: Set? = null): Return + fun aggregateCalculatingType(values: Iterable, valueTypes: Set? = null): Return /** * Function that can give the return type of [aggregate] as [KType], given the type of the input. @@ -92,8 +92,11 @@ internal fun Aggregator<*, *>.cast2(): Aggregator /** Type alias for [Aggregator.calculateReturnTypeOrNull] */ internal typealias CalculateReturnTypeOrNull = (type: KType, emptyInput: Boolean) -> KType? -/** Type alias for [Aggregator.aggregate]. */ -internal typealias Aggregate = Iterable.(type: KType) -> Return +/** + * Type alias for the argument for [Aggregator.aggregate]. + * Nulls have already been filtered out when this argument is called. + */ +internal typealias Aggregate = Iterable.(type: KType) -> Return /** Common case for [CalculateReturnTypeOrNull], preserves return type, but makes it nullable for empty inputs. */ internal val preserveReturnTypeNullIfEmpty: CalculateReturnTypeOrNull = { type, emptyInput -> diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorBase.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorBase.kt index 799442849f..6596573903 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorBase.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorBase.kt @@ -14,7 +14,7 @@ import kotlin.reflect.full.withNullability * or multiple [DataColumns][DataColumn]. * * @param name The name of this aggregator. - * @param aggregator Functional argument for the [aggregate] function. + * @param aggregator Functional argument for the [aggregate] function. Nulls are filtered out before this is called. */ internal abstract class AggregatorBase( override val name: String, @@ -25,13 +25,19 @@ internal abstract class AggregatorBase( /** * Base function of [Aggregator]. * - * Aggregates the given values, taking [type] into account, and computes a single resulting value. + * Aggregates the given values, taking [type] into account, + * filtering nulls, and computes a single resulting value. * - * Uses [aggregator] to compute the result. + * When using [AggregatorBase], this can be supplied by the [AggregatorBase.aggregator] argument. * * When the exact [type] is unknown, use [aggregateCalculatingType]. */ - override fun aggregate(values: Iterable, type: KType): Return = aggregator(values, type) + @Suppress("UNCHECKED_CAST") + override fun aggregate(values: Iterable, type: KType): Return = + aggregator( + values.asSequence().filterNotNull().asIterable(), // TODO make dependant on type's nullability + type.withNullability(false), + ) /** * Function that can give the return type of [aggregate] as [KType], given the type of the input. @@ -44,7 +50,7 @@ internal abstract class AggregatorBase( * @return The return type of [aggregate] as [KType]. */ override fun calculateReturnTypeOrNull(type: KType, emptyInput: Boolean): KType? = - getReturnTypeOrNull(type, emptyInput) + getReturnTypeOrNull(type.withNullability(false), emptyInput) /** * Aggregates the data in the given column and computes a single resulting value. @@ -54,17 +60,12 @@ internal abstract class AggregatorBase( @Suppress("UNCHECKED_CAST") override fun aggregate(column: DataColumn): Return = aggregate( - values = - if (column.hasNulls()) { - column.asSequence().filterNotNull().asIterable() - } else { - column.asIterable() as Iterable - }, - type = column.type().withNullability(false), + values = column.asIterable(), + type = column.type(), ) /** @include [Aggregator.aggregateCalculatingType] */ - override fun aggregateCalculatingType(values: Iterable, valueTypes: Set?): Return { + override fun aggregateCalculatingType(values: Iterable, valueTypes: Set?): Return { val commonType = if (valueTypes != null) { valueTypes.commonType(false) } else { diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/FlatteningAggregator.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/FlatteningAggregator.kt index 53124d1c6b..1ba0c41eae 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/FlatteningAggregator.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/FlatteningAggregator.kt @@ -8,7 +8,7 @@ import kotlin.reflect.full.withNullability /** * Simple [Aggregator] implementation with flattening behavior for multiple columns. * - * Nulls are filtered from columns. + * Nulls are filtered out. * * When called on multiple columns, * the columns are flattened into a single list of values, filtering nulls as usual; diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/TwoStepAggregator.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/TwoStepAggregator.kt index 879f740f79..b593f4f261 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/TwoStepAggregator.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/TwoStepAggregator.kt @@ -10,7 +10,7 @@ import kotlin.reflect.full.withNullability /** * A slightly more advanced [Aggregator] implementation. * - * Nulls are filtered from columns. + * Nulls are filtered out. * * When called on multiple columns, this [Aggregator] works in two steps: * First, it aggregates within a [DataColumn]/[Iterable] ([stepOneAggregator]) with their (given) type, diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/TwoStepNumbersAggregator.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/TwoStepNumbersAggregator.kt index 7441c63769..7768a2d051 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/TwoStepNumbersAggregator.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/TwoStepNumbersAggregator.kt @@ -24,7 +24,7 @@ private val logger = KotlinLogging.logger { } * [Aggregator] made specifically for number calculations. * Mixed number types are [unified][UnifyingNumbers] to [primitives][PRIMITIVES_ONLY]. * - * Nulls are filtered from columns. + * Nulls are filtered out. * * When called on multiple columns (with potentially mixed [Number] types), * this [Aggregator] works in two steps: @@ -113,11 +113,10 @@ internal class TwoStepNumbersAggregator( * * When the exact [type] is unknown, use [aggregateCalculatingType]. */ - override fun aggregate(values: Iterable, type: KType): Return { + override fun aggregate(values: Iterable, type: KType): Return { require(type.isSubtypeOf(typeOf())) { "${TwoStepNumbersAggregator::class.simpleName}: Type $type is not a subtype of Number?" } - return when (type.withNullability(false)) { // If the type is not a specific number, but rather a mixed Number, we unify the types first. // This is heavy and could be avoided by calling aggregate with a specific number type @@ -147,8 +146,8 @@ internal class TwoStepNumbersAggregator( * If `null`, the types of [values] will be calculated at runtime (heavy!). */ @Suppress("UNCHECKED_CAST") - override fun aggregateCalculatingType(values: Iterable, valueTypes: Set?): Return { - val valueTypes = valueTypes ?: values.types() + override fun aggregateCalculatingType(values: Iterable, valueTypes: Set?): Return { + val valueTypes = valueTypes ?: values.filterNotNull().types() val commonType = valueTypes .unifiedNumberType(PRIMITIVES_ONLY) .withNullability(false) diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/ofRowExpression.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/ofRowExpression.kt index 8d678b2914..41a74fdc1b 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/ofRowExpression.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/ofRowExpression.kt @@ -13,13 +13,18 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregateInternal import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Aggregator import org.jetbrains.kotlinx.dataframe.impl.aggregation.internal import org.jetbrains.kotlinx.dataframe.impl.emptyPath +import kotlin.reflect.full.withNullability import kotlin.reflect.typeOf @PublishedApi internal inline fun Aggregator.aggregateOf( values: Iterable, noinline transform: (C) -> V, -): R = aggregate(values.asSequence().map(transform).asIterable(), typeOf()) +): R = + aggregate( + values = values.asSequence().mapNotNull(transform).asIterable(), + type = typeOf().withNullability(false), + ) @PublishedApi internal inline fun Aggregator.aggregateOf( From b42c7b4d58b30d983f3d04ddc0133b4588b11072 Mon Sep 17 00:00:00 2001 From: Jolan Rensen Date: Mon, 17 Mar 2025 16:24:29 +0100 Subject: [PATCH 10/11] Fixed AggregatorBase only filtering nulls when the type says they exist. unifying numbers can now handle null/nothing in the input. --- .../documentation/UnifyingNumbers.kt | 5 +++++ .../kotlinx/dataframe/impl/NumberTypeUtils.kt | 13 +++++++++--- .../kotlinx/dataframe/impl/TypeUtils.kt | 17 +++++++++++----- .../aggregation/aggregators/Aggregator.kt | 5 +++-- .../aggregation/aggregators/AggregatorBase.kt | 20 +++++++++++++++---- .../aggregators/TwoStepNumbersAggregator.kt | 17 +++++++++------- 6 files changed, 56 insertions(+), 21 deletions(-) diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/documentation/UnifyingNumbers.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/documentation/UnifyingNumbers.kt index 2d1a0125c5..dd149f2328 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/documentation/UnifyingNumbers.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/documentation/UnifyingNumbers.kt @@ -21,6 +21,8 @@ import org.jetbrains.kotlinx.dataframe.impl.UnifiedNumberTypeOptions * potentially losing a little precision, but a warning will be given. * * See [UnifiedNumberTypeOptions] for these settings. + * + * At the bottom of the graph is [Nothing]. This can be interpreted as `null`. */ internal interface UnifyingNumbers { @@ -40,6 +42,9 @@ internal interface UnifyingNumbers { * | / | * | / | * UByte Byte + * \\ / + * \\ / + * Nothing? * ``` */ interface Graph diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/NumberTypeUtils.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/NumberTypeUtils.kt index 557c2ae0fd..f3f42ced5b 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/NumberTypeUtils.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/NumberTypeUtils.kt @@ -82,6 +82,9 @@ internal fun getUnifiedNumberTypeGraph( addEdge(typeOf(), typeOf()) addEdge(typeOf(), typeOf()) + + addEdge(typeOf(), nothingType) + addEdge(typeOf(), nothingType) } } @@ -121,7 +124,11 @@ internal fun getUnifiedNumberType( ?: error("Can not find common number type for $first and $second") } - return if (first.isMarkedNullable || second.isMarkedNullable) result.withNullability(true) else result + return if (first.isMarkedNullable || second.isMarkedNullable) { + result.withNullability(true) + } else { + result + } } /** @include [getUnifiedNumberType] */ @@ -184,7 +191,7 @@ internal fun Iterable.convertToUnifiedNumberType( options: UnifiedNumberTypeOptions = UnifiedNumberTypeOptions.DEFAULT, commonNumberType: KType? = null, ): Iterable { - val commonNumberType = commonNumberType ?: this.filterNotNull().types().unifiedNumberType(options) + val commonNumberType = commonNumberType ?: this.types().unifiedNumberType(options) val converter = createConverter(typeOf(), commonNumberType)!! as (Number) -> Number? return map { if (it == null) return@map null @@ -209,7 +216,7 @@ internal fun Sequence.convertToUnifiedNumberType( options: UnifiedNumberTypeOptions = UnifiedNumberTypeOptions.DEFAULT, commonNumberType: KType? = null, ): Sequence { - val commonNumberType = commonNumberType ?: this.filterNotNull().asIterable().types().unifiedNumberType(options) + val commonNumberType = commonNumberType ?: this.asIterable().types().unifiedNumberType(options) val converter = createConverter(typeOf(), commonNumberType)!! as (Number) -> Number? return map { if (it == null) return@map null diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/TypeUtils.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/TypeUtils.kt index b5f9635139..d6b4b96ed4 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/TypeUtils.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/TypeUtils.kt @@ -667,19 +667,26 @@ internal fun Any.isBigNumber(): Boolean = this is BigInteger || this is BigDecim * * The [KClass] is determined by retrieving the runtime class of each element. * + * [Nothing::class][Nothing] is used for elements that are `null`. + * * @return A set of [KClass] objects representing the runtime types of elements in the iterable. */ -internal fun Iterable.classes(): Set> = mapTo(mutableSetOf()) { it::class } +internal fun Iterable.classes(): Set> = + mapTo(mutableSetOf()) { + if (it == null) Nothing::class else it::class + } /** * Returns a set of [KType] objects representing the star-projected types of the runtime classes * of all unique elements in the iterable. * - * The method internally relies on the [classes] function to collect the runtime classes of the - * elements in the iterable and then maps each class to its star-projected type. - * * This can be a heavy operation! * + * [typeOf()][nullableNothingType] is used for elements that are `null`. + * * @return A set of [KType] objects corresponding to the star-projected runtime types of elements in the iterable. */ -internal fun Iterable.types(): Set = classes().mapTo(mutableSetOf()) { it.createStarProjectedType(false) } +internal fun Iterable.types(): Set = + mapTo(mutableSetOf()) { + if (it == null) nullableNothingType else it::class.createStarProjectedType(false) + } diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator.kt index 3fd3829c5d..74bbc1a6e9 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator.kt @@ -26,7 +26,8 @@ internal interface Aggregator { * Base function of [Aggregator]. * * Aggregates the given values, taking [type] into account, - * filtering nulls, and computes a single resulting value. + * filtering nulls (only if [type.isMarkedNullable][KType.isMarkedNullable]), + * and computes a single resulting value. * * When using [AggregatorBase], this can be supplied by the [AggregatorBase.aggregator] argument. * @@ -55,7 +56,7 @@ internal interface Aggregator { * @param valueTypes The types of the values. * If provided, this can be used to avoid calculating the types of [values] at runtime with reflection. * It should contain all types of [values]. - * If `null`, the types of [values] will be calculated at runtime (heavy!). + * If `null` or empty, the types of [values] will be calculated at runtime (heavy!). */ fun aggregateCalculatingType(values: Iterable, valueTypes: Set? = null): Return diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorBase.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorBase.kt index 6596573903..ada6638ceb 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorBase.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorBase.kt @@ -4,6 +4,7 @@ import org.jetbrains.kotlinx.dataframe.DataColumn import org.jetbrains.kotlinx.dataframe.api.asIterable import org.jetbrains.kotlinx.dataframe.api.asSequence import org.jetbrains.kotlinx.dataframe.impl.commonType +import org.jetbrains.kotlinx.dataframe.impl.nothingType import kotlin.reflect.KType import kotlin.reflect.full.withNullability @@ -26,7 +27,8 @@ internal abstract class AggregatorBase( * Base function of [Aggregator]. * * Aggregates the given values, taking [type] into account, - * filtering nulls, and computes a single resulting value. + * filtering nulls (only if [type.isMarkedNullable][KType.isMarkedNullable]), + * and computes a single resulting value. * * When using [AggregatorBase], this can be supplied by the [AggregatorBase.aggregator] argument. * @@ -35,7 +37,13 @@ internal abstract class AggregatorBase( @Suppress("UNCHECKED_CAST") override fun aggregate(values: Iterable, type: KType): Return = aggregator( - values.asSequence().filterNotNull().asIterable(), // TODO make dependant on type's nullability + // values = + if (type.isMarkedNullable) { + values.asSequence().filterNotNull().asIterable() + } else { + values as Iterable + }, + // type = type.withNullability(false), ) @@ -66,7 +74,7 @@ internal abstract class AggregatorBase( /** @include [Aggregator.aggregateCalculatingType] */ override fun aggregateCalculatingType(values: Iterable, valueTypes: Set?): Return { - val commonType = if (valueTypes != null) { + val commonType = if (valueTypes != null && valueTypes.isNotEmpty()) { valueTypes.commonType(false) } else { var hasNulls = false @@ -78,7 +86,11 @@ internal abstract class AggregatorBase( it.javaClass.kotlin } } - classes.commonType(hasNulls) + if (classes.isEmpty()) { + nothingType(hasNulls) + } else { + classes.commonType(hasNulls) + } } return aggregate(values, commonType) } diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/TwoStepNumbersAggregator.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/TwoStepNumbersAggregator.kt index 7768a2d051..31784c6036 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/TwoStepNumbersAggregator.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/TwoStepNumbersAggregator.kt @@ -7,6 +7,7 @@ import org.jetbrains.kotlinx.dataframe.documentation.UnifyingNumbers import org.jetbrains.kotlinx.dataframe.impl.UnifiedNumberTypeOptions.Companion.PRIMITIVES_ONLY import org.jetbrains.kotlinx.dataframe.impl.anyNull import org.jetbrains.kotlinx.dataframe.impl.convertToUnifiedNumberType +import org.jetbrains.kotlinx.dataframe.impl.isNothing import org.jetbrains.kotlinx.dataframe.impl.nothingType import org.jetbrains.kotlinx.dataframe.impl.primitiveNumberTypes import org.jetbrains.kotlinx.dataframe.impl.renderType @@ -106,6 +107,8 @@ internal class TwoStepNumbersAggregator( * * Aggregates the given values, taking [type] into account, and computes a single resulting value. * + * Nulls are filtered out (only if [type.isMarkedNullable][KType.isMarkedNullable]). + * * Uses [aggregator] to compute the result. * * This function is modified to call [aggregateCalculatingType] when it encounters mixed number types. @@ -143,21 +146,21 @@ internal class TwoStepNumbersAggregator( * @param valueTypes The types of the numbers. * If provided, this can be used to avoid calculating the types of [values] at runtime with reflection. * It should contain all types of [values]. - * If `null`, the types of [values] will be calculated at runtime (heavy!). + * If `null` or empty, the types of [values] will be calculated at runtime (heavy!). */ @Suppress("UNCHECKED_CAST") override fun aggregateCalculatingType(values: Iterable, valueTypes: Set?): Return { - val valueTypes = valueTypes ?: values.filterNotNull().types() - val commonType = valueTypes - .unifiedNumberType(PRIMITIVES_ONLY) - .withNullability(false) + val valueTypes = valueTypes?.takeUnless { it.isEmpty() } ?: values.types() + val commonType = valueTypes.unifiedNumberType(PRIMITIVES_ONLY) - if (commonType == typeOf() && (typeOf() in valueTypes || typeOf() in valueTypes)) { + if (commonType.isSubtypeOf(typeOf()) && + (typeOf() in valueTypes || typeOf() in valueTypes) + ) { logger.warn { "Number unification of Long -> Double happened during aggregation. Loss of precision may have occurred." } } - if (commonType !in primitiveNumberTypes && commonType != nothingType) { + if (commonType.withNullability(false) !in primitiveNumberTypes && !commonType.isNothing) { throw IllegalArgumentException( "Cannot calculate $name of ${renderType(commonType)}, only primitive numbers are supported.", ) From 33e35bc50893d7468e09669dbb369faf0a594eb1 Mon Sep 17 00:00:00 2001 From: Jolan Rensen Date: Mon, 17 Mar 2025 17:50:09 +0100 Subject: [PATCH 11/11] fixed ofRowExpression.kt: removed of-overloads, as they are duplicates of aggregateOf. Fixed nullability in lambda return types. Made sure all lambdas are crossinline. Added test for medianOf to check everything still works as expected. --- .../jetbrains/kotlinx/dataframe/api/mean.kt | 13 +++------- .../jetbrains/kotlinx/dataframe/api/median.kt | 3 +-- .../kotlinx/dataframe/api/percentile.kt | 3 +-- .../jetbrains/kotlinx/dataframe/api/std.kt | 3 +-- .../jetbrains/kotlinx/dataframe/api/sum.kt | 5 ++-- .../impl/aggregation/modes/ofRowExpression.kt | 26 ++++++------------- .../kotlinx/dataframe/statistics/median.kt | 22 ++++++++++++++++ 7 files changed, 39 insertions(+), 36 deletions(-) diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt index 0bcf448e0b..bbdb9707d0 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt @@ -14,12 +14,10 @@ import org.jetbrains.kotlinx.dataframe.columns.ColumnReference import org.jetbrains.kotlinx.dataframe.columns.toColumnSet import org.jetbrains.kotlinx.dataframe.columns.toColumnsSetOf import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Aggregators -import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.cast2 import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateAll import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateFor import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOf import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOfRow -import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.of import org.jetbrains.kotlinx.dataframe.impl.aggregation.primitiveNumberColumns import org.jetbrains.kotlinx.dataframe.impl.columns.toNumberColumns import org.jetbrains.kotlinx.dataframe.impl.primitiveNumberTypes @@ -42,11 +40,8 @@ public fun DataColumn.mean(skipNA: Boolean = skipNA_default): Double = public inline fun DataColumn.meanOf( skipNA: Boolean = skipNA_default, - noinline expression: (T) -> R?, -): Double = - Aggregators.mean(skipNA) - .cast2() - .aggregateOf(this, expression) + crossinline expression: (T) -> R?, +): Double = Aggregators.mean(skipNA).aggregateOf(this, expression) // endregion @@ -112,8 +107,8 @@ public fun DataFrame.mean(vararg columns: KProperty, skip public inline fun DataFrame.meanOf( skipNA: Boolean = skipNA_default, - noinline expression: RowExpression, -): Double = Aggregators.mean(skipNA).of(this, expression) + crossinline expression: RowExpression, +): Double = Aggregators.mean(skipNA).aggregateOf(this, expression) // endregion diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/median.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/median.kt index 8da5194a7d..0dfd662d1e 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/median.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/median.kt @@ -18,7 +18,6 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.intraComparableColumns import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateAll import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateFor import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOf -import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.of import org.jetbrains.kotlinx.dataframe.impl.columns.toComparableColumns import org.jetbrains.kotlinx.dataframe.impl.suggestIfNull import org.jetbrains.kotlinx.dataframe.math.median @@ -101,7 +100,7 @@ public fun > DataFrame.medianOrNull(vararg columns: KPro public inline fun > DataFrame.medianOf( crossinline expression: RowExpression, -): R? = Aggregators.median.of(this, expression) as R? +): R? = Aggregators.median.aggregateOf(this, expression) as R? // endregion diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/percentile.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/percentile.kt index b0a08bef6d..421817eeb7 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/percentile.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/percentile.kt @@ -16,7 +16,6 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.intraComparableColumns import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateAll import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateFor import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOf -import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.of import org.jetbrains.kotlinx.dataframe.impl.columns.toComparableColumns import org.jetbrains.kotlinx.dataframe.impl.suggestIfNull import org.jetbrains.kotlinx.dataframe.math.percentile @@ -121,7 +120,7 @@ public fun > DataFrame.percentileOrNull(percentile: Doub public inline fun > DataFrame.percentileOf( percentile: Double, crossinline expression: RowExpression, -): R? = Aggregators.percentile(percentile).of(this, expression) as R? +): R? = Aggregators.percentile(percentile).aggregateOf(this, expression) as R? // endregion diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/std.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/std.kt index 163cabf4c7..52745aa555 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/std.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/std.kt @@ -18,7 +18,6 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.cast2 import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateAll import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateFor import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOf -import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.of import org.jetbrains.kotlinx.dataframe.impl.aggregation.numberColumns import org.jetbrains.kotlinx.dataframe.math.std import kotlin.reflect.KProperty @@ -99,7 +98,7 @@ public inline fun DataFrame.stdOf( skipNA: Boolean = skipNA_default, ddof: Int = ddof_default, crossinline expression: RowExpression, -): Double = Aggregators.std(skipNA, ddof).of(this, expression) ?: .0 +): Double = Aggregators.std(skipNA, ddof).aggregateOf(this, expression) ?: .0 // endregion diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt index c0bda09485..af0016ee18 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt @@ -20,7 +20,6 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.cast import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateAll import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateFor import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOf -import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.of import org.jetbrains.kotlinx.dataframe.impl.aggregation.numberColumns import org.jetbrains.kotlinx.dataframe.impl.columns.toNumberColumns import org.jetbrains.kotlinx.dataframe.impl.zero @@ -38,8 +37,8 @@ public fun DataColumn.sum(): T = values.sum(type()) @JvmName("sumTNullable") public fun DataColumn.sum(): T = values.sum(type()) -public inline fun DataColumn.sumOf(crossinline expression: (T) -> R): R? = - (Aggregators.sum as Aggregator<*, *>).cast().of(this, expression) +public inline fun DataColumn.sumOf(noinline expression: (T) -> R): R? = + (Aggregators.sum as Aggregator<*, *>).cast().aggregateOf(this, expression) // endregion diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/ofRowExpression.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/ofRowExpression.kt index 41a74fdc1b..c2ca6a2483 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/ofRowExpression.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/ofRowExpression.kt @@ -19,23 +19,23 @@ import kotlin.reflect.typeOf @PublishedApi internal inline fun Aggregator.aggregateOf( values: Iterable, - noinline transform: (C) -> V, + crossinline transform: (C) -> V?, ): R = aggregate( - values = values.asSequence().mapNotNull(transform).asIterable(), + values = values.asSequence().mapNotNull { transform(it) }.asIterable(), type = typeOf().withNullability(false), ) @PublishedApi internal inline fun Aggregator.aggregateOf( column: DataColumn, - noinline transform: (C) -> V, + crossinline transform: (C) -> V?, ): R = aggregateOf(column.values(), transform) @PublishedApi internal inline fun Aggregator<*, R>.aggregateOf( frame: DataFrame, - crossinline expression: RowExpression, + crossinline expression: RowExpression, ): R = (this as Aggregator).aggregateOf(frame.rows()) { expression(it, it) } @PublishedApi @@ -48,33 +48,23 @@ internal fun Aggregator<*, R>.aggregateOfDelegated( body(this, this) } -@PublishedApi -internal inline fun Aggregator<*, R>.of( - data: DataFrame, - crossinline expression: RowExpression, -): R = aggregateOf(data as DataFrame, expression) - -@PublishedApi -internal inline fun Aggregator.of(data: DataColumn, crossinline expression: (C) -> V): R = - aggregateOf(data.values()) { expression(it) } - @PublishedApi internal inline fun Aggregator<*, R>.aggregateOf( data: Grouped, resultName: String? = null, - crossinline expression: RowExpression, + crossinline expression: RowExpression, ): DataFrame = data.aggregateOf(resultName, expression, this as Aggregator) @PublishedApi internal inline fun Aggregator<*, R>.aggregateOf( data: PivotGroupBy, - crossinline expression: RowExpression, + crossinline expression: RowExpression, ): DataFrame = data.aggregateOf(expression, this as Aggregator) @PublishedApi internal inline fun Grouped.aggregateOf( resultName: String?, - crossinline expression: RowExpression, + crossinline expression: RowExpression, aggregator: Aggregator, ): DataFrame { val path = pathOf(resultName ?: aggregator.name) @@ -97,7 +87,7 @@ internal inline fun Grouped.aggregateOf( @PublishedApi internal inline fun PivotGroupBy.aggregateOf( - crossinline expression: RowExpression, + crossinline expression: RowExpression, aggregator: Aggregator, ): DataFrame = aggregate { diff --git a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/statistics/median.kt b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/statistics/median.kt index bc793b2704..9430cd4748 100644 --- a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/statistics/median.kt +++ b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/statistics/median.kt @@ -4,14 +4,36 @@ import io.kotest.matchers.shouldBe import org.jetbrains.kotlinx.dataframe.api.Infer import org.jetbrains.kotlinx.dataframe.api.columnOf import org.jetbrains.kotlinx.dataframe.api.dataFrameOf +import org.jetbrains.kotlinx.dataframe.api.groupBy import org.jetbrains.kotlinx.dataframe.api.mapToColumn import org.jetbrains.kotlinx.dataframe.api.median +import org.jetbrains.kotlinx.dataframe.api.medianOf import org.jetbrains.kotlinx.dataframe.api.rowMedian import org.junit.Test +import kotlin.reflect.typeOf @Suppress("ktlint:standard:argument-list-wrapping") class MedianTests { + val personsDf = dataFrameOf("name", "age", "city", "weight", "height", "yearsToRetirement")( + "Alice", 15, "London", 99.5, "1.85", 50, + "Bob", 20, "Paris", 140.0, "1.35", 45, + "Charlie", 100, "Dubai", 75.0, "1.95", 0, + "Rose", 1, "Moscow", 45.33, "0.79", 64, + "Dylan", 35, "London", 23.4, "1.83", 30, + "Eve", 40, "Paris", 56.72, "1.85", 25, + "Frank", 55, "Dubai", 78.9, "1.35", 10, + "Grace", 29, "Moscow", 67.8, "1.65", 36, + "Hank", 60, "Paris", 80.22, "1.75", 5, + "Isla", 22, "London", 75.1, "1.85", 43, + ) + + @Test + fun `medianOf test`() { + val d = personsDf.groupBy("city").medianOf("newAge") { "age"() * 10 } + d["newAge"].type() shouldBe typeOf() + } + @Test fun `median of two columns`() { val df = dataFrameOf("a", "b")(