Skip to content

Commit 05e0993

Browse files
committed
adding tests and comments, missing types for median, expanding cumSum
1 parent b2bdb4a commit 05e0993

File tree

9 files changed

+197
-38
lines changed

9 files changed

+197
-38
lines changed

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/cumSum.kt

+14-4
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,21 @@ public fun <T : Number?> DataColumn<T>.cumSum(skipNA: Boolean = defaultCumSumSki
2727

2828
typeOf<Float?>() -> cast<Float?>().cumSum(skipNA).cast()
2929

30-
// careful, cast to Int can occur! TODO
31-
typeOf<Int>(), typeOf<Byte>(), typeOf<Short>() -> cast<Int>().cumSum().cast()
30+
typeOf<Int>() -> cast<Int>().cumSum().cast()
3231

33-
// careful, cast to Int can occur! TODO
34-
typeOf<Int?>(), typeOf<Byte?>(), typeOf<Short?>() -> cast<Int?>().cumSum(skipNA).cast()
32+
// TODO cumSum for Byte returns Int but is cast back to T: Byte
33+
typeOf<Byte>() -> cast<Byte>().cumSum().cast()
34+
35+
// TODO cumSum for Short returns Int but is cast back to T: Short
36+
typeOf<Short>() -> cast<Short>().cumSum().cast()
37+
38+
typeOf<Int?>() -> cast<Int?>().cumSum(skipNA).cast()
39+
40+
// TODO cumSum for Byte? returns Int? but is cast back to T: Byte?
41+
typeOf<Byte?>() -> cast<Byte?>().cumSum(skipNA).cast()
42+
43+
// TODO cumSum for Short? returns Int? but is cast back to T: Short?
44+
typeOf<Short?>() -> cast<Short?>().cumSum(skipNA).cast()
3545

3646
typeOf<Long>() -> cast<Long>().cumSum().cast()
3747

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/cumsum.kt

+60
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,66 @@ internal fun DataColumn<Int?>.cumSum(skipNA: Boolean = defaultCumSumSkipNA): Dat
8989
}
9090
}
9191

92+
/**
 * Cumulative sum over a [Byte] column, producing [Int] values.
 *
 * The running total is held in an [Int] that starts at 0. Every element handed to `map`
 * is added to it, and the updated total becomes the corresponding output element.
 */
@JvmName("byteCumsum")
internal fun DataColumn<Byte>.cumSum(): DataColumn<Int> {
    var runningTotal = 0
    return map { value ->
        // Byte is widened to Int by the addition, so the total is never narrowed back to Byte here.
        runningTotal += value
        runningTotal
    }
}
100+
101+
/**
 * Cumulative sum over a nullable [Byte] column, producing nullable [Int] values.
 *
 * A `null` input element always produces a `null` output element.
 *
 * @param skipNA when `true`, `null` elements are skipped and the running total continues with the
 * next non-null element; when `false`, the first `null` makes every element visited afterwards `null` too.
 */
@JvmName("cumsumByteNullable")
internal fun DataColumn<Byte?>.cumSum(skipNA: Boolean = defaultCumSumSkipNA): DataColumn<Int?> {
    var runningTotal = 0
    // Set once a null has been seen while skipNA is false; from then on only nulls are emitted.
    var nullsOnly = false
    return map { value ->
        if (value == null) {
            if (!skipNA) nullsOnly = true
            null
        } else if (nullsOnly) {
            null
        } else {
            runningTotal += value
            runningTotal
        }
    }
}
121+
122+
/**
 * Cumulative sum over a [Short] column, producing [Int] values.
 *
 * The running total is held in an [Int] that starts at 0. Every element handed to `map`
 * is added to it, and the updated total becomes the corresponding output element.
 */
@JvmName("shortCumsum")
internal fun DataColumn<Short>.cumSum(): DataColumn<Int> {
    var runningTotal = 0
    return map { value ->
        // Short is widened to Int by the addition, so the total is never narrowed back to Short here.
        runningTotal += value
        runningTotal
    }
}
130+
131+
/**
 * Cumulative sum over a nullable [Short] column, producing nullable [Int] values.
 *
 * A `null` input element always produces a `null` output element.
 *
 * @param skipNA when `true`, `null` elements are skipped and the running total continues with the
 * next non-null element; when `false`, the first `null` makes every element visited afterwards `null` too.
 */
@JvmName("cumsumShortNullable")
internal fun DataColumn<Short?>.cumSum(skipNA: Boolean = defaultCumSumSkipNA): DataColumn<Int?> {
    var runningTotal = 0
    // Set once a null has been seen while skipNA is false; from then on only nulls are emitted.
    var nullsOnly = false
    return map { value ->
        if (value == null) {
            if (!skipNA) nullsOnly = true
            null
        } else if (nullsOnly) {
            null
        } else {
            runningTotal += value
            runningTotal
        }
    }
}
151+
92152
@JvmName("longCumsum")
93153
internal fun DataColumn<Long>.cumSum(): DataColumn<Long> {
94154
var sum = 0L

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/mean.kt

+2
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ public fun Iterable<Long>.mean(): Double =
135135
if (count > 0) sum / count else Double.NaN
136136
}
137137

138+
// TODO result is Double, but should be BigDecimal, Issue #558
138139
@JvmName("bigIntegerMean")
139140
public fun Iterable<BigInteger>.mean(): Double =
140141
if (this is Collection) {
@@ -148,6 +149,7 @@ public fun Iterable<BigInteger>.mean(): Double =
148149
if (count > 0) sum / count else Double.NaN
149150
}
150151

152+
// TODO result is Double, but should be BigDecimal, Issue #558
151153
@JvmName("bigDecimalMean")
152154
public fun Iterable<BigDecimal>.mean(): Double =
153155
if (this is Collection) {

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/median.kt

+12-1
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,16 @@ package org.jetbrains.kotlinx.dataframe.math
22

33
import org.jetbrains.kotlinx.dataframe.impl.asList
44
import java.math.BigDecimal
5+
import java.math.BigInteger
56
import kotlin.reflect.KType
67
import kotlin.reflect.typeOf
78

89
/**
 * Nullable median entry point: forwards to the [median] overload that takes a [KType],
 * passing the reified type of [T].
 */
public inline fun <reified T : Comparable<T>> Iterable<T>.medianOrNull(): T? {
    val elementType = typeOf<T>()
    return median(elementType)
}
910

1011
/**
 * Non-null variant of [medianOrNull].
 *
 * @throws NullPointerException when [medianOrNull] yields `null`.
 */
public inline fun <reified T : Comparable<T>> Iterable<T>.median(): T {
    val result = medianOrNull()
    return result!!
}
1112

13+
// TODO median always returns the same type as its input elements, which can be confusing for iterables of even length
14+
// TODO (e.g. median of [1, 2] should be 1.5, but the type is Int, so it returns 1), Issue #558
1215
@PublishedApi
1316
internal inline fun <reified T : Comparable<T>> Iterable<T?>.median(type: KType): T? {
1417
val list = if (type.isMarkedNullable) filterNotNull() else (this as Iterable<T>).asList()
@@ -19,14 +22,22 @@ internal inline fun <reified T : Comparable<T>> Iterable<T?>.median(type: KType)
1922
return when (type.classifier) {
2023
Double::class -> ((list.quickSelect(index - 1) as Double + list.quickSelect(index) as Double) / 2.0) as T
2124

25+
Float::class -> ((list.quickSelect(index - 1) as Float + list.quickSelect(index) as Float) / 2.0f) as T
26+
2227
Int::class -> ((list.quickSelect(index - 1) as Int + list.quickSelect(index) as Int) / 2) as T
2328

29+
Short::class -> ((list.quickSelect(index - 1) as Short + list.quickSelect(index) as Short) / 2) as T
30+
2431
Long::class -> ((list.quickSelect(index - 1) as Long + list.quickSelect(index) as Long) / 2L) as T
2532

2633
Byte::class -> ((list.quickSelect(index - 1) as Byte + list.quickSelect(index) as Byte) / 2).toByte() as T
2734

2835
BigDecimal::class -> (
29-
(list.quickSelect(index - 1) as BigDecimal + list.quickSelect(index) as BigDecimal) / BigDecimal(2)
36+
(list.quickSelect(index - 1) as BigDecimal + list.quickSelect(index) as BigDecimal) / 2.toBigDecimal()
37+
) as T
38+
39+
BigInteger::class -> (
40+
(list.quickSelect(index - 1) as BigInteger + list.quickSelect(index) as BigInteger) / 2.toBigInteger()
3041
) as T
3142

3243
else -> list.quickSelect(index - 1)

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/minmax.kt

-25
This file was deleted.

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/sum.kt

+14-4
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,13 @@ internal fun <T : Number> Iterable<T>.sum(type: KType): T =
4545

4646
Float::class -> (this as Iterable<Float>).sum() as T
4747

48-
// careful, cast to Int occurs! TODO
49-
Int::class, Short::class, Byte::class -> (this as Iterable<Int>).sum() as T
48+
Int::class -> (this as Iterable<Int>).sum() as T
49+
50+
// TODO result should be Int, but same type as input is returned, Issue #558
51+
Short::class -> (this as Iterable<Short>).sum().toShort() as T
52+
53+
// TODO result should be Int, but same type as input is returned, Issue #558
54+
Byte::class -> (this as Iterable<Byte>).sum().toByte() as T
5055

5156
Long::class -> (this as Iterable<Long>).sum() as T
5257

@@ -69,8 +74,13 @@ internal fun <T : Number> Iterable<T?>.sum(type: KType): T =
6974

7075
Float::class -> (this as Iterable<Float?>).asSequence().filterNotNull().sum() as T
7176

72-
// careful, cast to Int occurs! TODO
73-
Int::class, Short::class, Byte::class -> (this as Iterable<Int?>).asSequence().filterNotNull().sum() as T
77+
Int::class -> (this as Iterable<Int?>).asSequence().filterNotNull().sum() as T
78+
79+
// TODO result should be Int, but same type as input is returned, Issue #558
80+
Short::class -> (this as Iterable<Short?>).asSequence().filterNotNull().sum().toShort() as T
81+
82+
// TODO result should be Int, but same type as input is returned, Issue #558
83+
Byte::class -> (this as Iterable<Short?>).asSequence().filterNotNull().sum().toByte() as T
7484

7585
Long::class -> (this as Iterable<Long?>).asSequence().filterNotNull().sum() as T
7686

core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/statistics/cumsum.kt

+16-4
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,26 @@ class CumsumTests {
2525

2626
@Test
2727
fun `short column`() {
28-
col.map { it?.toShort() }.cumSum().toList() shouldBe expected.map { it?.toShort() }
29-
col.map { it?.toShort() }.cumSum(skipNA = false).toList() shouldBe expectedNoSkip.map { it?.toShort() }
28+
col.map { it?.toShort() }.cumSum().toList() shouldBe expected
29+
col.map { it?.toShort() }.cumSum(skipNA = false).toList() shouldBe expectedNoSkip
30+
}
31+
32+
// cumSum applied to a whole frame: every column is accumulated independently with skipNA = false.
@Test
fun `frame with multiple columns`() {
    // The delegated property names double as the column names, so they are kept as-is.
    val col2 by columnOf(1.toShort(), 2, 3, 4, 5)
    val col3 by columnOf(1.toByte(), 2, 3, 4, null)

    val cumulative = dataFrameOf(col, col2, col3).cumSum(skipNA = false)

    cumulative[col].toList() shouldBe expectedNoSkip
    cumulative[col2].toList() shouldBe listOf(1, 3, 6, 10, 15)
    cumulative[col3].toList() shouldBe listOf(1, 3, 6, 10, null)
}
3143

3244
@Test
3345
fun `byte column`() {
34-
col.map { it?.toByte() }.cumSum().toList() shouldBe expected.map { it?.toByte() }
35-
col.map { it?.toByte() }.cumSum(skipNA = false).toList() shouldBe expectedNoSkip.map { it?.toByte() }
46+
col.map { it?.toByte() }.cumSum().toList() shouldBe expected
47+
col.map { it?.toByte() }.cumSum(skipNA = false).toList() shouldBe expectedNoSkip
3648
}
3749

3850
@Test

core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/statistics/std.kt

+13
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,19 @@ class StdTests {
2828
df.std().columnTypes().single() shouldBe typeOf<Double>()
2929
}
3030

31+
// std of a single Byte column, checked through each entry point used below.
@Test
fun `std one byte column`() {
    // The delegated property name doubles as the column name, so it is kept as-is.
    val value by columnOf(1.toByte(), 2.toByte(), 3.toByte())
    val frame = dataFrameOf(value)
    val expectedStd = 1.0

    // Raw values and the column itself
    value.values().std(typeOf<Byte>()) shouldBe expectedStd
    value.std() shouldBe expectedStd

    // Column reached through the frame
    frame[value].std() shouldBe expectedStd
    frame.std { value } shouldBe expectedStd

    // Frame-wide std produces a Double column
    frame.std().columnTypes().single() shouldBe typeOf<Double>()
}
43+
3144
@Test
3245
fun `std one double column`() {
3346
val value by columnOf(1.0, 2.0, 3.0)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
package org.jetbrains.kotlinx.dataframe.statistics
2+
3+
import io.kotest.matchers.shouldBe
4+
import org.jetbrains.kotlinx.dataframe.DataColumn
5+
import org.jetbrains.kotlinx.dataframe.api.columnOf
6+
import org.jetbrains.kotlinx.dataframe.api.dataFrameOf
7+
import org.jetbrains.kotlinx.dataframe.api.sum
8+
import org.jetbrains.kotlinx.dataframe.api.sumOf
9+
import org.junit.Test
10+
11+
/**
 * Tests for `sum` / `sumOf` on plain values, single columns and whole frames.
 *
 * The delegated property names (`value`, `value1`, …) double as column names and are therefore kept stable.
 */
class SumTests {

    @Test
    fun `test single column`() {
        val value by columnOf(1, 2, 3)
        val frame = dataFrameOf(value)
        val total = 6

        // Raw values and the column itself
        value.values().sum() shouldBe total
        value.sum() shouldBe total

        // Same column reached through the frame
        frame[value].sum() shouldBe total
        frame.sum { value } shouldBe total
        frame.sum()[value] shouldBe total
        frame.sumOf { value() } shouldBe total
    }

    @Test
    fun `test single short column`() {
        val value by columnOf(1.toShort(), 2.toShort(), 3.toShort())
        val frame = dataFrameOf(value)
        val total = 6

        // Raw values and the column itself
        value.values().sum() shouldBe total
        value.sum() shouldBe total

        // Same column reached through the frame
        frame[value].sum() shouldBe total
        frame.sum { value } shouldBe total
        frame.sum()[value] shouldBe total
        frame.sumOf { value() } shouldBe total
    }

    @Test
    fun `test multiple columns`() {
        val value1 by columnOf(1, 2, 3)
        val value2 by columnOf(4.0, 5.0, 6.0)
        val value3: DataColumn<Number?> by columnOf(7.0, 8, null)
        val frame = dataFrameOf(value1, value2, value3)

        val total1 = 6
        val total2 = 15.0
        val total3 = 15.0

        // Frame-wide sum, read back per column
        frame.sum()[value1] shouldBe total1
        frame.sum()[value2] shouldBe total2
        frame.sum()[value3] shouldBe total3

        // Row expression
        frame.sumOf { value1() } shouldBe total1
        frame.sumOf { value2() } shouldBe total2
        frame.sumOf { value3() } shouldBe total3

        // Column reference
        frame.sum(value1) shouldBe total1
        frame.sum(value2) shouldBe total2
        frame.sum(value3) shouldBe total3

        // Column selector
        frame.sum { value1 } shouldBe total1
        frame.sum { value2 } shouldBe total2
        frame.sum { value3 } shouldBe total3
    }
}

0 commit comments

Comments
 (0)