Skip to content

Commit 6b4e9d0

Browse files
committed
expanded cumSum with Short, Byte, and BigInteger
1 parent b76c5a2 commit 6b4e9d0

File tree

3 files changed

+133
-1
lines changed
  • core/src

3 files changed

+133
-1
lines changed

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/cumSum.kt

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import org.jetbrains.kotlinx.dataframe.math.cumSum
99
import org.jetbrains.kotlinx.dataframe.math.defaultCumSumSkipNA
1010
import org.jetbrains.kotlinx.dataframe.typeClass
1111
import java.math.BigDecimal
12+
import java.math.BigInteger
1213
import kotlin.reflect.KProperty
1314
import kotlin.reflect.typeOf
1415

@@ -22,15 +23,30 @@ public fun <T : Number?> DataColumn<T>.cumSum(skipNA: Boolean = defaultCumSumSki
2223
typeOf<Float?>() -> cast<Float?>().cumSum(skipNA).cast()
2324
typeOf<Int>() -> cast<Int>().cumSum().cast()
2425
typeOf<Int?>() -> cast<Int?>().cumSum(skipNA).cast()
26+
typeOf<Byte>() -> cast<Byte>().cumSum().cast()
27+
typeOf<Byte?>() -> cast<Byte?>().cumSum(skipNA).cast()
28+
typeOf<Short>() -> cast<Short>().cumSum().cast()
29+
typeOf<Short?>() -> cast<Short?>().cumSum(skipNA).cast()
2530
typeOf<Long>() -> cast<Long>().cumSum().cast()
2631
typeOf<Long?>() -> cast<Long?>().cumSum(skipNA).cast()
32+
typeOf<BigInteger>() -> cast<BigInteger>().cumSum().cast()
33+
typeOf<BigInteger?>() -> cast<BigInteger?>().cumSum(skipNA).cast()
2734
typeOf<BigDecimal>() -> cast<BigDecimal>().cumSum().cast()
2835
typeOf<BigDecimal?>() -> cast<BigDecimal?>().cumSum(skipNA).cast()
2936
typeOf<Number?>(), typeOf<Number>() -> convertToDouble().cumSum(skipNA).cast()
3037
else -> error("Cumsum for type ${type()} is not supported")
3138
}
3239

33-
private val supportedClasses = setOf(Double::class, Float::class, Int::class, Long::class, BigDecimal::class)
40+
private val supportedClasses = setOf(
41+
Double::class,
42+
Float::class,
43+
Int::class,
44+
Byte::class,
45+
Short::class,
46+
Long::class,
47+
BigInteger::class,
48+
BigDecimal::class,
49+
)
3450

3551
// endregion
3652

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/cumsum.kt

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import org.jetbrains.kotlinx.dataframe.DataColumn
44
import org.jetbrains.kotlinx.dataframe.api.isNA
55
import org.jetbrains.kotlinx.dataframe.api.map
66
import java.math.BigDecimal
7+
import java.math.BigInteger
78

89
internal val defaultCumSumSkipNA: Boolean = true
910

@@ -88,6 +89,66 @@ internal fun DataColumn<Int?>.cumSum(skipNA: Boolean = defaultCumSumSkipNA): Dat
8889
}
8990
}
9091

92+
@JvmName("byteCumsum")
93+
internal fun DataColumn<Byte>.cumSum(): DataColumn<Byte> {
94+
var sum = 0.toByte()
95+
return map {
96+
sum = (sum + it).toByte()
97+
sum
98+
}
99+
}
100+
101+
@JvmName("cumsumByteNullable")
102+
internal fun DataColumn<Byte?>.cumSum(skipNA: Boolean = defaultCumSumSkipNA): DataColumn<Byte?> {
103+
var sum = 0.toByte()
104+
var fillNull = false
105+
return map {
106+
when {
107+
it == null -> {
108+
if (!skipNA) fillNull = true
109+
null
110+
}
111+
112+
fillNull -> null
113+
114+
else -> {
115+
sum = (sum + it).toByte()
116+
sum
117+
}
118+
}
119+
}
120+
}
121+
122+
@JvmName("shortCumsum")
123+
internal fun DataColumn<Short>.cumSum(): DataColumn<Short> {
124+
var sum = 0.toShort()
125+
return map {
126+
sum = (sum + it).toShort()
127+
sum
128+
}
129+
}
130+
131+
@JvmName("cumsumShortNullable")
132+
internal fun DataColumn<Short?>.cumSum(skipNA: Boolean = defaultCumSumSkipNA): DataColumn<Short?> {
133+
var sum = 0.toShort()
134+
var fillNull = false
135+
return map {
136+
when {
137+
it == null -> {
138+
if (!skipNA) fillNull = true
139+
null
140+
}
141+
142+
fillNull -> null
143+
144+
else -> {
145+
sum = (sum + it).toShort()
146+
sum
147+
}
148+
}
149+
}
150+
}
151+
91152
@JvmName("longCumsum")
92153
internal fun DataColumn<Long>.cumSum(): DataColumn<Long> {
93154
var sum = 0L
@@ -118,6 +179,36 @@ internal fun DataColumn<Long?>.cumSum(skipNA: Boolean = defaultCumSumSkipNA): Da
118179
}
119180
}
120181

182+
@JvmName("bigIntegerCumsum")
183+
internal fun DataColumn<BigInteger>.cumSum(): DataColumn<BigInteger> {
184+
var sum = BigInteger.ZERO
185+
return map {
186+
sum += it
187+
sum
188+
}
189+
}
190+
191+
@JvmName("cumsumBigIntegerNullable")
192+
internal fun DataColumn<BigInteger?>.cumSum(skipNA: Boolean = defaultCumSumSkipNA): DataColumn<BigInteger?> {
193+
var sum = BigInteger.ZERO
194+
var fillNull = false
195+
return map {
196+
when {
197+
it == null -> {
198+
if (!skipNA) fillNull = true
199+
null
200+
}
201+
202+
fillNull -> null
203+
204+
else -> {
205+
sum += it
206+
sum
207+
}
208+
}
209+
}
210+
}
211+
121212
@JvmName("bigDecimalCumsum")
122213
internal fun DataColumn<BigDecimal>.cumSum(): DataColumn<BigDecimal> {
123214
var sum = BigDecimal.ZERO

core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/statistics/cumsum.kt

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import org.jetbrains.kotlinx.dataframe.api.concat
77
import org.jetbrains.kotlinx.dataframe.api.cumSum
88
import org.jetbrains.kotlinx.dataframe.api.dataFrameOf
99
import org.jetbrains.kotlinx.dataframe.api.groupBy
10+
import org.jetbrains.kotlinx.dataframe.api.map
1011
import org.junit.Test
1112

1213
@Suppress("ktlint:standard:argument-list-wrapping")
@@ -22,6 +23,30 @@ class CumsumTests {
2223
col.cumSum(skipNA = false).toList() shouldBe expectedNoSkip
2324
}
2425

26+
@Test
27+
fun `short column`() {
28+
col.map { it?.toShort() }.cumSum().toList() shouldBe expected.map { it?.toShort() }
29+
col.map { it?.toShort() }.cumSum(skipNA = false).toList() shouldBe expectedNoSkip.map { it?.toShort() }
30+
}
31+
32+
@Test
33+
fun `byte column`() {
34+
col.map { it?.toByte() }.cumSum().toList() shouldBe expected.map { it?.toByte() }
35+
col.map { it?.toByte() }.cumSum(skipNA = false).toList() shouldBe expectedNoSkip.map { it?.toByte() }
36+
}
37+
38+
@Test
39+
fun `big int column`() {
40+
col.map { it?.toBigInteger() }.cumSum().toList() shouldBe expected.map { it?.toBigInteger() }
41+
col.map { it?.toBigInteger() }.cumSum(skipNA = false).toList() shouldBe expectedNoSkip.map { it?.toBigInteger() }
42+
}
43+
44+
@Test
45+
fun `big decimal column`() {
46+
col.map { it?.toBigDecimal() }.cumSum().toList() shouldBe expected.map { it?.toBigDecimal() }
47+
col.map { it?.toBigDecimal() }.cumSum(skipNA = false).toList() shouldBe expectedNoSkip.map { it?.toBigDecimal() }
48+
}
49+
2550
@Test
2651
fun frame() {
2752
val str by columnOf("a", "b", "c", "d", "e")

0 commit comments

Comments
 (0)