Skip to content

Commit 499334d

Browse files
committed
fixed describe and tests
1 parent 3ffb2ef commit 499334d

File tree

4 files changed

+27
-16
lines changed

4 files changed

+27
-16
lines changed

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/map.kt

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,7 @@ public inline fun <C, reified R> ColumnReference<C>.map(
3131

3232
// region DataColumn
3333

34-
public inline fun <T, reified R> DataColumn<T>.map(
35-
infer: Infer = Infer.Nulls,
36-
crossinline transform: (T) -> R,
37-
): DataColumn<R> {
34+
public inline fun <T, reified R> DataColumn<T>.map(infer: Infer = Infer.Nulls, transform: (T) -> R): DataColumn<R> {
3835
val newValues = Array(size()) { transform(get(it)) }.asList()
3936
return DataColumn.createByType(name(), newValues, typeOf<R>(), infer)
4037
}

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/typeConversions.kt

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,11 +83,17 @@ public fun DataColumn<Any>.asNumbers(): ValueColumn<Number> {
8383
return this as ValueColumn<Number>
8484
}
8585

86-
public fun <T> DataColumn<T>.asComparable(): DataColumn<Comparable<T>> {
86+
public fun <T : Any> DataColumn<T>.asComparable(): DataColumn<Comparable<T>> {
8787
require(valuesAreComparable())
8888
return this as DataColumn<Comparable<T>>
8989
}
9090

91+
@JvmName("asComparableNullable")
92+
public fun <T : Any?> DataColumn<T?>.asComparable(): DataColumn<Comparable<T>?> {
93+
require(valuesAreComparable())
94+
return this as DataColumn<Comparable<T>?>
95+
}
96+
9197
public fun <T> ColumnReference<T?>.castToNotNullable(): ColumnReference<T> = cast()
9298

9399
public fun <T> DataColumn<T?>.castToNotNullable(): DataColumn<T> {

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/api/describe.kt

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ import org.jetbrains.kotlinx.dataframe.api.asNumbers
1313
import org.jetbrains.kotlinx.dataframe.api.cast
1414
import org.jetbrains.kotlinx.dataframe.api.concat
1515
import org.jetbrains.kotlinx.dataframe.api.isNumber
16-
import org.jetbrains.kotlinx.dataframe.api.isPrimitiveNumber
1716
import org.jetbrains.kotlinx.dataframe.api.map
1817
import org.jetbrains.kotlinx.dataframe.api.maxOrNull
1918
import org.jetbrains.kotlinx.dataframe.api.mean
@@ -30,6 +29,7 @@ import org.jetbrains.kotlinx.dataframe.columns.size
3029
import org.jetbrains.kotlinx.dataframe.columns.values
3130
import org.jetbrains.kotlinx.dataframe.impl.columns.addPath
3231
import org.jetbrains.kotlinx.dataframe.impl.columns.asAnyFrameColumn
32+
import org.jetbrains.kotlinx.dataframe.impl.isPrimitiveNumber
3333
import org.jetbrains.kotlinx.dataframe.impl.renderType
3434
import org.jetbrains.kotlinx.dataframe.index
3535
import org.jetbrains.kotlinx.dataframe.kind
@@ -38,7 +38,7 @@ import org.jetbrains.kotlinx.dataframe.type
3838
internal fun describeImpl(cols: List<AnyCol>): DataFrame<ColumnDescription> {
3939
val allCols = cols.collectAll(false)
4040

41-
val hasNumericCols = allCols.any { it.isPrimitiveNumber() }
41+
val hasNumericCols = allCols.any { it.isNumber() }
4242
val hasComparableCols = allCols.any { it.valuesAreComparable() }
4343
val hasLongPaths = allCols.any { it.path().size > 1 }
4444
var df = allCols.toDataFrame {
@@ -56,8 +56,8 @@ internal fun describeImpl(cols: List<AnyCol>): DataFrame<ColumnDescription> {
5656
?.key
5757
}
5858
if (hasNumericCols) {
59-
ColumnDescription::mean from { if (it.isPrimitiveNumber()) it.asNumbers().mean() else null }
60-
ColumnDescription::std from { if (it.isPrimitiveNumber()) it.asNumbers().std() else null }
59+
ColumnDescription::mean from { if (it.isNumber()) it.asNumbers().mean() else null }
60+
ColumnDescription::std from { if (it.isNumber()) it.asNumbers().std() else null }
6161
}
6262
if (hasComparableCols || hasNumericCols) {
6363
ColumnDescription::min from inferType {
@@ -111,12 +111,20 @@ private fun List<AnyCol>.collectAll(atAnyDepth: Boolean): List<AnyCol> =
111111
}
112112

113113
/** Converts a column to a comparable column if it is not already comparable. */
114-
private fun DataColumn<Any?>.convertToComparableOrNull(): DataColumn<Comparable<Any?>>? =
115-
when {
114+
@Suppress("UNCHECKED_CAST")
115+
private fun DataColumn<Any?>.convertToComparableOrNull(): DataColumn<Comparable<Any>?>? {
116+
return when {
116117
valuesAreComparable() -> asComparable()
117118

118119
// Found incomparable number types, convert all to Double first
119-
isPrimitiveNumber() -> map { (it as Number?)?.toDouble() }.cast()
120+
isNumber() -> cast<Number?>().map {
121+
if (it?.isPrimitiveNumber() == false) {
122+
// Cannot calculate statistics of a non-primitive number type
123+
return@convertToComparableOrNull null
124+
}
125+
it?.toDouble() as Comparable<Any>?
126+
}
120127

121128
else -> null
122129
}
130+
}

core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/describe.kt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,12 @@ class DescribeTests {
3737
nulls shouldBe 1
3838
top shouldBe 1
3939
freq shouldBe 1
40-
this.mean shouldBe 3.5
40+
mean shouldBe 3.5
4141
std shouldBe 1.8708286933869707
4242
min shouldBe 1.0
43-
p25 shouldBe 2.25
44-
median shouldBe 3.5
45-
p75 shouldBe 4.75
43+
p25 shouldBe 2.0
44+
median shouldBe 3.0
45+
p75 shouldBe 4.0
4646
max shouldBe 6.0
4747
}
4848
}

0 commit comments

Comments
 (0)