Skip to content

Commit 2e0875b

Browse files
committed
added mean tests
1 parent a0279cd commit 2e0875b

File tree

1 file changed

+223
-0
lines changed
  • core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/statistics

1 file changed

+223
-0
lines changed

core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/statistics/mean.kt

Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@ import io.kotest.matchers.doubles.shouldBeNaN
44
import io.kotest.matchers.shouldBe
55
import org.jetbrains.kotlinx.dataframe.DataColumn
66
import org.jetbrains.kotlinx.dataframe.api.columnOf
7+
import org.jetbrains.kotlinx.dataframe.api.dataFrameOf
78
import org.jetbrains.kotlinx.dataframe.api.mean
9+
import org.jetbrains.kotlinx.dataframe.api.meanFor
10+
import org.jetbrains.kotlinx.dataframe.api.meanOf
11+
import org.jetbrains.kotlinx.dataframe.api.rowMean
812
import org.jetbrains.kotlinx.dataframe.impl.nothingType
913
import org.junit.Test
1014
import kotlin.reflect.typeOf
@@ -25,4 +29,223 @@ class MeanTests {
2529
DataColumn.createValueColumn("", emptyList<Nothing>(), nothingType(false)).mean().shouldBeNaN()
2630
DataColumn.createValueColumn("", listOf(null), nothingType(true)).mean().shouldBeNaN()
2731
}
32+
33+
/** Checks `mean()` on an `Int` column, both without and with a trailing `null` entry. */
@Test
fun `mean with int values`() {
    // (1 + 2 + 3 + 4 + 5) / 5
    val ints = columnOf(1, 2, 3, 4, 5)
    ints.mean() shouldBe 3.0

    // Same values plus a null: the expected mean stays the same.
    val nullableInts = columnOf<Int?>(1, 2, 3, 4, 5, null)
    nullableInts.mean() shouldBe 3.0
}
41+
42+
/** Checks `mean()` on a `Long` column, both without and with a trailing `null` entry. */
@Test
fun `mean with long values`() {
    // Plain Long column: (1 + 2 + 3 + 4 + 5) / 5
    columnOf(1L, 2L, 3L, 4L, 5L).mean() shouldBe 3.0

    // Nullable Long column: the expected mean is unchanged by the null entry.
    columnOf<Long?>(1L, 2L, 3L, 4L, 5L, null).mean() shouldBe 3.0
}
50+
51+
/** Checks that `mean()` on a `Float` column yields the expected `Double` value, with and without a `null`. */
@Test
fun `mean with float values`() {
    val floats = columnOf(1f, 2f, 3f, 4f, 5f)
    floats.mean() shouldBe 3.0

    // Adding a null entry must leave the expected mean at 3.0.
    val nullableFloats = columnOf<Float?>(1f, 2f, 3f, 4f, 5f, null)
    nullableFloats.mean() shouldBe 3.0
}
59+
60+
/** Checks `mean()` on a `Double` column, both without and with a trailing `null` entry. */
@Test
fun `mean with double values`() {
    // Plain Double column: (1 + 2 + 3 + 4 + 5) / 5
    columnOf(1.0, 2.0, 3.0, 4.0, 5.0).mean() shouldBe 3.0

    // Nullable Double column: the expected mean is unchanged by the null entry.
    columnOf<Double?>(1.0, 2.0, 3.0, 4.0, 5.0, null).mean() shouldBe 3.0
}
68+
69+
/** Checks `mean()` on a `Short` column, both without and with a trailing `null` entry. */
@Test
fun `mean with short values`() {
    val shorts = columnOf(
        1.toShort(),
        2.toShort(),
        3.toShort(),
        4.toShort(),
        5.toShort(),
    )
    shorts.mean() shouldBe 3.0

    // Same values plus a null: the expected mean stays at 3.0.
    val nullableShorts = columnOf<Short?>(
        1.toShort(),
        2.toShort(),
        3.toShort(),
        4.toShort(),
        5.toShort(),
        null,
    )
    nullableShorts.mean() shouldBe 3.0
}
77+
78+
/** Checks `mean()` on a `Byte` column, both without and with a trailing `null` entry. */
@Test
fun `mean with byte values`() {
    val bytes = columnOf(
        1.toByte(),
        2.toByte(),
        3.toByte(),
        4.toByte(),
        5.toByte(),
    )
    bytes.mean() shouldBe 3.0

    // Same values plus a null: the expected mean stays at 3.0.
    val nullableBytes = columnOf<Byte?>(
        1.toByte(),
        2.toByte(),
        3.toByte(),
        4.toByte(),
        5.toByte(),
        null,
    )
    nullableBytes.mean() shouldBe 3.0
}
86+
87+
/**
 * Checks `mean()` on columns declared as `Number` / `Number?` whose elements have
 * different runtime number types.
 */
@Test
fun `mean with mixed number types`() {
    // Int, Long, Float, Double and Short in a single column.
    val allKinds = columnOf<Number>(1, 2L, 3.0f, 4.0, 5.toShort())
    allKinds.mean() shouldBe 3.0

    // The same mixture with an additional null entry.
    val allKindsWithNull = columnOf<Number?>(1, 2L, 3.0f, 4.0, 5.toShort(), null)
    allKindsWithNull.mean() shouldBe 3.0

    // Integer types only: Int, Long, Short, Byte -> (1 + 2 + 3 + 4) / 4
    val integersOnly = columnOf<Number>(1, 2L, 3.toShort(), 4.toByte())
    integersOnly.mean() shouldBe 2.5

    // Floating point types only: Float and Double alternating.
    val floatingOnly = columnOf<Number>(1.0f, 2.0, 3.0f, 4.0)
    floatingOnly.mean() shouldBe 2.5

    // Integer and floating point types interleaved.
    val interleaved = columnOf<Number>(1, 2.0f, 3L, 4.0)
    interleaved.mean() shouldBe 2.5
}
107+
108+
/** Checks `meanOf` on a `String` column whose values are converted to `Int` by the supplied lambda. */
@Test
fun `meanOf with transformation`() {
    val digits = columnOf("1", "2", "3", "4", "5")
    digits.meanOf { text -> text.toInt() } shouldBe 3.0

    // The null element maps to null; the expected mean covers the other four: (1 + 2 + 3 + 5) / 4
    val digitsWithNull = columnOf("1", "2", "3", null, "5")
    digitsWithNull.meanOf { text -> text?.toInt() } shouldBe 2.75
}
116+
117+
/** Checks `rowMean()` for every row of a 3x3 integer data frame. */
@Suppress("ktlint:standard:argument-list-wrapping")
@Test
fun `rowMean with dataframe`() {
    val grid = dataFrameOf(
        "a", "b", "c",
    )(
        1, 2, 3,
        4, 5, 6,
        7, 8, 9,
    )

    // Expected mean of each row, in row order.
    val expectedRowMeans = listOf(2.0, 5.0, 8.0)
    expectedRowMeans.forEachIndexed { rowIndex, expected ->
        grid[rowIndex].rowMean() shouldBe expected
    }
}
132+
133+
/**
 * Checks the data frame level mean operations on a frame with an `Int`, a `Double`
 * and a `Float` column: per-column means, means for selected columns, and a single
 * mean over several named columns.
 */
@Suppress("ktlint:standard:argument-list-wrapping")
@Test
fun `dataframe mean`() {
    val frame = dataFrameOf(
        "a", "b", "c",
    )(
        1, 2.0, 3f,
        4, 5.0, 6f,
        7, 8.0, 9f,
    )

    // Per-column means of the whole frame.
    val columnMeans = frame.mean()
    columnMeans["a"] shouldBe 4.0
    columnMeans["b"] shouldBe 5.0
    columnMeans["c"] shouldBe 6.0

    // Per-column means restricted to "a" and "c".
    val selectedMeans = frame.meanFor("a", "c")
    selectedMeans["a"] shouldBe 4.0
    selectedMeans["c"] shouldBe 6.0

    // One mean over all values of the named columns: (1 + 2 + ... + 9) / 9
    val overallMean = frame.mean("a", "b", "c")
    overallMean shouldBe 5.0
}
157+
158+
/** Checks `meanOf` on a data frame with a row expression that adds columns "a" and "c". */
@Suppress("ktlint:standard:argument-list-wrapping")
@Test
fun `dataframe meanOf with transformation`() {
    val grid = dataFrameOf(
        "a", "b", "c",
    )(
        1, 2, 3,
        4, 5, 6,
        7, 8, 9,
    )

    // Row sums of a + c are 4, 10 and 16, so the expected mean is 30 / 3.
    val meanOfSums = grid.meanOf { "a"<Int>() + "c"<Int>() }
    meanOfSums shouldBe 10.0
}
171+
172+
/**
 * Exercises the `skipNaN` flag of the mean operations on columns and data frames.
 *
 * Each case asserts two things: with default arguments an input containing NaN yields NaN,
 * and with `skipNaN = true` the expected mean is computed from the non-NaN values only.
 * An input consisting solely of NaN is expected to yield NaN in both modes.
 */
@Suppress("ktlint:standard:argument-list-wrapping")
@Test
fun `mean with skipNaN for floating point numbers`() {
    // Float column with a single NaN in the middle.
    val floats = columnOf(1.0f, 2.0f, Float.NaN, 4.0f, 5.0f)
    floats.mean().shouldBeNaN()
    floats.mean(skipNaN = true) shouldBe 3.0 // (1 + 2 + 4 + 5) / 4

    // Double column with a single NaN in the middle.
    val doubles = columnOf(1.0, 2.0, Double.NaN, 4.0, 5.0)
    doubles.mean().shouldBeNaN()
    doubles.mean(skipNaN = true) shouldBe 3.0 // (1 + 2 + 4 + 5) / 4

    // NaN at the start, in the middle and at the end.
    val scatteredNaN = columnOf(Float.NaN, 2.0f, Float.NaN, 4.0f, Float.NaN)
    scatteredNaN.mean().shouldBeNaN()
    scatteredNaN.mean(skipNaN = true) shouldBe 3.0 // (2 + 4) / 2

    // Nothing but NaN: there is no value left to average, so both modes give NaN.
    val onlyNaN = columnOf(Float.NaN, Float.NaN, Float.NaN)
    onlyNaN.mean().shouldBeNaN()
    onlyNaN.mean(skipNaN = true).shouldBeNaN()

    // Column declared as Number with mixed element types and one NaN.
    val mixedNumbers = columnOf<Number>(1, 2.0f, Double.NaN, 4L, 5.0)
    mixedNumbers.mean().shouldBeNaN()
    mixedNumbers.mean(skipNaN = true) shouldBe 3.0 // (1 + 2 + 4 + 5) / 4

    // Data frame in which every column contains exactly one NaN.
    val nanFrame = dataFrameOf(
        "a", "b", "c",
    )(
        1.0, Double.NaN, 3.0,
        4.0, 5.0, Float.NaN,
        Double.NaN, 8.0, 9.0,
    )

    // Per-column means, default mode: every column is expected to be NaN.
    val defaultMeans = nanFrame.mean()
    (defaultMeans["a"] as Double).shouldBeNaN()
    (defaultMeans["b"] as Double).shouldBeNaN()
    (defaultMeans["c"] as Double).shouldBeNaN()

    // Per-column means with skipNaN.
    val skippedMeans = nanFrame.mean(skipNaN = true)
    skippedMeans["a"] shouldBe 2.5 // (1.0 + 4.0) / 2
    skippedMeans["b"] shouldBe 6.5 // (5.0 + 8.0) / 2
    skippedMeans["c"] shouldBe 6.0 // (3.0 + 9.0) / 2

    // meanFor on selected columns, default mode.
    val defaultSelected = nanFrame.meanFor("a", "c")
    (defaultSelected["a"] as Double).shouldBeNaN()
    (defaultSelected["c"] as Double).shouldBeNaN()

    // meanFor on selected columns with skipNaN.
    val skippedSelected = nanFrame.meanFor("a", "c", skipNaN = true)
    skippedSelected["a"] shouldBe 2.5 // (1.0 + 4.0) / 2
    skippedSelected["c"] shouldBe 6.0 // (3.0 + 9.0) / 2

    // Single mean over all named columns.
    nanFrame.mean("a", "b", "c").shouldBeNaN()
    nanFrame.mean("a", "b", "c", skipNaN = true) shouldBe 5.0 // (1 + 3 + 4 + 5 + 8 + 9) / 6

    // Frame for row expressions that divide "a" by "b"; "b" is zero in two rows.
    val ratioFrame = dataFrameOf(
        "a", "b",
    )(
        1.0, 0.0,
        4.0, 2.0,
        0.0, 0.0,
    )

    // Row quotients are 1.0 / 0.0 = Infinity, 4.0 / 2.0 = 2.0 and 0.0 / 0.0 = NaN;
    // the NaN from the last row is what makes the default-mode mean NaN.
    ratioFrame.meanOf { "a"<Double>() / "b"<Double>() }.shouldBeNaN()

    // Here every row with a zero divisor is mapped to NaN explicitly, so with
    // skipNaN only 4.0 / 2.0 = 2.0 is left to average.
    val guardedMean = ratioFrame.meanOf(skipNaN = true) {
        val divisor = "b"<Double>()
        if (divisor != 0.0) "a"<Double>() / divisor else Double.NaN
    }
    guardedMean shouldBe 2.0
}
28251
}

0 commit comments

Comments
 (0)