@@ -4,7 +4,11 @@ import io.kotest.matchers.doubles.shouldBeNaN
4
4
import io.kotest.matchers.shouldBe
5
5
import org.jetbrains.kotlinx.dataframe.DataColumn
6
6
import org.jetbrains.kotlinx.dataframe.api.columnOf
7
+ import org.jetbrains.kotlinx.dataframe.api.dataFrameOf
7
8
import org.jetbrains.kotlinx.dataframe.api.mean
9
+ import org.jetbrains.kotlinx.dataframe.api.meanFor
10
+ import org.jetbrains.kotlinx.dataframe.api.meanOf
11
+ import org.jetbrains.kotlinx.dataframe.api.rowMean
8
12
import org.jetbrains.kotlinx.dataframe.impl.nothingType
9
13
import org.junit.Test
10
14
import kotlin.reflect.typeOf
@@ -25,4 +29,223 @@ class MeanTests {
25
29
DataColumn .createValueColumn(" " , emptyList<Nothing >(), nothingType(false )).mean().shouldBeNaN()
26
30
DataColumn .createValueColumn(" " , listOf (null ), nothingType(true )).mean().shouldBeNaN()
27
31
}
32
+
33
/** Asserts that `mean()` over an `Int` column is 3.0, both without and with a `null` entry. */
@Test
fun `mean with int values`() {
    // Non-nullable column: (1 + 2 + 3 + 4 + 5) / 5
    columnOf(1, 2, 3, 4, 5).mean() shouldBe 3.0

    // Same values plus one null; the expected mean is unchanged
    columnOf<Int?>(1, 2, 3, 4, 5, null).mean() shouldBe 3.0
}
41
+
42
/** Asserts that `mean()` over a `Long` column is 3.0, both without and with a `null` entry. */
@Test
fun `mean with long values`() {
    val longs = columnOf(1L, 2L, 3L, 4L, 5L)
    val nullableLongs = columnOf<Long?>(1L, 2L, 3L, 4L, 5L, null)

    // The expected mean is the same for both columns
    longs.mean() shouldBe 3.0
    nullableLongs.mean() shouldBe 3.0
}
50
+
51
/** Asserts that `mean()` over a `Float` column is 3.0, both without and with a `null` entry. */
@Test
fun `mean with float values`() {
    // Non-nullable Float column
    columnOf(1.0f, 2.0f, 3.0f, 4.0f, 5.0f).mean() shouldBe 3.0

    // Nullable Float column with one null; the expected mean is unchanged
    columnOf<Float?>(1.0f, 2.0f, 3.0f, 4.0f, 5.0f, null).mean() shouldBe 3.0
}
59
+
60
/** Asserts that `mean()` over a `Double` column is 3.0, both without and with a `null` entry. */
@Test
fun `mean with double values`() {
    val doubles = columnOf(1.0, 2.0, 3.0, 4.0, 5.0)
    val nullableDoubles = columnOf<Double?>(1.0, 2.0, 3.0, 4.0, 5.0, null)

    // The expected mean is the same for both columns
    doubles.mean() shouldBe 3.0
    nullableDoubles.mean() shouldBe 3.0
}
68
+
69
/** Asserts that `mean()` over a `Short` column is 3.0, both without and with a `null` entry. */
@Test
fun `mean with short values`() {
    // Non-nullable Short column
    columnOf(1.toShort(), 2.toShort(), 3.toShort(), 4.toShort(), 5.toShort()).mean() shouldBe 3.0

    // Nullable Short column with one null; the expected mean is unchanged
    columnOf<Short?>(1.toShort(), 2.toShort(), 3.toShort(), 4.toShort(), 5.toShort(), null)
        .mean() shouldBe 3.0
}
77
+
78
/** Asserts that `mean()` over a `Byte` column is 3.0, both without and with a `null` entry. */
@Test
fun `mean with byte values`() {
    val bytes = columnOf(1.toByte(), 2.toByte(), 3.toByte(), 4.toByte(), 5.toByte())
    val nullableBytes = columnOf<Byte?>(1.toByte(), 2.toByte(), 3.toByte(), 4.toByte(), 5.toByte(), null)

    // The expected mean is the same for both columns
    bytes.mean() shouldBe 3.0
    nullableBytes.mean() shouldBe 3.0
}
86
+
87
/** Asserts the value of `mean()` for `Number` columns that combine several primitive number types. */
@Test
fun `mean with mixed number types`() {
    // Int, Long, Float, Double and Short in one column
    columnOf<Number>(1, 2L, 3.0f, 4.0, 5.toShort()).mean() shouldBe 3.0

    // Same values with an additional null entry; the expected mean is unchanged
    columnOf<Number?>(1, 2L, 3.0f, 4.0, 5.toShort(), null).mean() shouldBe 3.0

    // Integer types only: Int, Long, Short, Byte
    columnOf<Number>(1, 2L, 3.toShort(), 4.toByte()).mean() shouldBe 2.5

    // Floating point types only: Float and Double
    columnOf<Number>(1.0f, 2.0, 3.0f, 4.0).mean() shouldBe 2.5

    // Integer and floating point types interleaved
    columnOf<Number>(1, 2.0f, 3L, 4.0).mean() shouldBe 2.5
}
107
+
108
/** Asserts the value of `meanOf` on a `String` column whose entries are parsed to `Int` by the selector. */
@Test
fun `meanOf with transformation`() {
    // The literals must be bare digits: String.toInt() throws NumberFormatException on surrounding whitespace.
    val col = columnOf("1", "2", "3", "4", "5")
    col.meanOf { it.toInt() } shouldBe 3.0

    // The selector yields null for the null entry; the expected mean is (1 + 2 + 3 + 5) / 4
    val colWithNull = columnOf("1", "2", "3", null, "5")
    colWithNull.meanOf { it?.toInt() } shouldBe 2.75
}
116
+
117
/** Asserts that `rowMean()` returns the mean of the three cell values in each row. */
@Suppress("ktlint:standard:argument-list-wrapping")
@Test
fun `rowMean with dataframe`() {
    // Values are laid out one row per line, which is why argument-list wrapping is suppressed.
    val df = dataFrameOf(
        "a", "b", "c",
    )(
        1, 2, 3,
        4, 5, 6,
        7, 8, 9,
    )

    df[0].rowMean() shouldBe 2.0 // (1 + 2 + 3) / 3
    df[1].rowMean() shouldBe 5.0 // (4 + 5 + 6) / 3
    df[2].rowMean() shouldBe 8.0 // (7 + 8 + 9) / 3
}
132
+
133
/**
 * Asserts the results of `mean()`, `meanFor(...)` and `mean(columns...)` on a data frame
 * whose columns hold `Int`, `Double` and `Float` values respectively.
 */
@Suppress("ktlint:standard:argument-list-wrapping")
@Test
fun `dataframe mean`() {
    // Values are laid out one row per line, which is why argument-list wrapping is suppressed.
    val df = dataFrameOf(
        "a", "b", "c",
    )(
        1, 2.0, 3f,
        4, 5.0, 6f,
        7, 8.0, 9f,
    )

    // Per-column means
    val means = df.mean()
    means["a"] shouldBe 4.0 // (1 + 4 + 7) / 3
    means["b"] shouldBe 5.0 // (2 + 5 + 8) / 3
    means["c"] shouldBe 6.0 // (3 + 6 + 9) / 3

    // Per-column means restricted to the selected columns
    val meanFor = df.meanFor("a", "c")
    meanFor["a"] shouldBe 4.0
    meanFor["c"] shouldBe 6.0

    // Single mean over the values of all listed columns: (1 + ... + 9) / 9
    df.mean("a", "b", "c") shouldBe 5.0
}
157
+
158
/** Asserts the value of `meanOf` when the row expression is the sum of columns "a" and "c". */
@Suppress("ktlint:standard:argument-list-wrapping")
@Test
fun `dataframe meanOf with transformation`() {
    // Values are laid out one row per line, which is why argument-list wrapping is suppressed.
    val df = dataFrameOf(
        "a", "b", "c",
    )(
        1, 2, 3,
        4, 5, 6,
        7, 8, 9,
    )

    // Row sums are 4, 10 and 16, so the expected mean is 30 / 3
    df.meanOf { "a"<Int>() + "c"<Int>() } shouldBe 10.0
}
171
+
172
/**
 * Asserts how the mean functions treat NaN values: without `skipNaN` the expected result is NaN
 * whenever a NaN is present, and with `skipNaN = true` the expected result is the mean of the
 * remaining values (or NaN when no values remain).
 */
@Suppress("ktlint:standard:argument-list-wrapping")
@Test
fun `mean with skipNaN for floating point numbers`() {
    // Float column containing a single NaN
    val floatWithNaN = columnOf(1.0f, 2.0f, Float.NaN, 4.0f, 5.0f)
    floatWithNaN.mean().shouldBeNaN() // Default behavior: NaN propagates
    floatWithNaN.mean(skipNaN = true) shouldBe 3.0 // (1 + 2 + 4 + 5) / 4

    // Double column containing a single NaN
    val doubleWithNaN = columnOf(1.0, 2.0, Double.NaN, 4.0, 5.0)
    doubleWithNaN.mean().shouldBeNaN() // Default behavior: NaN propagates
    doubleWithNaN.mean(skipNaN = true) shouldBe 3.0 // (1 + 2 + 4 + 5) / 4

    // Several NaN values in different positions
    val multipleNaN = columnOf(Float.NaN, 2.0f, Float.NaN, 4.0f, Float.NaN)
    multipleNaN.mean().shouldBeNaN() // Default behavior: NaN propagates
    multipleNaN.mean(skipNaN = true) shouldBe 3.0 // (2 + 4) / 2

    // Only NaN values
    val allNaN = columnOf(Float.NaN, Float.NaN, Float.NaN)
    allNaN.mean().shouldBeNaN() // Default behavior: NaN propagates
    allNaN.mean(skipNaN = true).shouldBeNaN() // No valid values, result is NaN

    // Mixed number types including NaN
    val mixedWithNaN = columnOf<Number>(1, 2.0f, Double.NaN, 4L, 5.0)
    mixedWithNaN.mean().shouldBeNaN() // Default behavior: NaN propagates
    mixedWithNaN.mean(skipNaN = true) shouldBe 3.0 // (1 + 2 + 4 + 5) / 4

    // Data frame in which every column contains exactly one NaN.
    // Values are laid out one row per line, which is why argument-list wrapping is suppressed.
    val dfWithNaN = dataFrameOf(
        "a", "b", "c",
    )(
        1.0, Double.NaN, 3.0,
        4.0, 5.0, Float.NaN,
        Double.NaN, 8.0, 9.0,
    )

    // DataFrame mean without skipNaN
    val meansWithNaN = dfWithNaN.mean() // Default behavior
    (meansWithNaN["a"] as Double).shouldBeNaN() // Contains NaN
    (meansWithNaN["b"] as Double).shouldBeNaN() // Contains NaN
    (meansWithNaN["c"] as Double).shouldBeNaN() // Contains NaN

    // DataFrame mean with skipNaN
    val meansSkipNaN = dfWithNaN.mean(skipNaN = true) // Skip NaN values
    meansSkipNaN["a"] shouldBe 2.5 // (1.0 + 4.0) / 2
    meansSkipNaN["b"] shouldBe 6.5 // (5.0 + 8.0) / 2
    meansSkipNaN["c"] shouldBe 6.0 // (3.0 + 9.0) / 2

    // meanFor without skipNaN
    val meanForWithNaN = dfWithNaN.meanFor("a", "c") // Default behavior
    (meanForWithNaN["a"] as Double).shouldBeNaN() // Contains NaN
    (meanForWithNaN["c"] as Double).shouldBeNaN() // Contains NaN

    // meanFor with skipNaN
    val meanForSkipNaN = dfWithNaN.meanFor("a", "c", skipNaN = true) // Skip NaN values
    meanForSkipNaN["a"] shouldBe 2.5 // (1.0 + 4.0) / 2
    meanForSkipNaN["c"] shouldBe 6.0 // (3.0 + 9.0) / 2

    // Mean of all listed columns as a single value
    dfWithNaN.mean("a", "b", "c").shouldBeNaN() // Default behavior: NaN propagates
    dfWithNaN.mean("a", "b", "c", skipNaN = true) shouldBe 5.0 // (1 + 4 + 5 + 8 + 3 + 9) / 6

    // meanOf with a row expression that can produce non-finite values
    val dfForTransform = dataFrameOf(
        "a", "b",
    )(
        1.0, 0.0,
        4.0, 2.0,
        0.0, 0.0,
    )

    // Row results are 1.0 / 0.0 = +Infinity, 4.0 / 2.0 = 2.0 and 0.0 / 0.0 = NaN.
    // Only the 0.0 / 0.0 row is NaN; that NaN is what makes the overall mean NaN.
    dfForTransform.meanOf { "a"<Double>() / "b"<Double>() }.shouldBeNaN() // Default behavior: NaN propagates

    // Map every zero-denominator row to NaN explicitly so that skipNaN drops it
    // (this also removes the +Infinity row, which skipNaN alone would not be expected to skip).
    dfForTransform.meanOf(skipNaN = true) {
        val b = "b"<Double>()
        if (b == 0.0) Double.NaN else "a"<Double>() / b
    } shouldBe 2.0 // Only 4.0 / 2.0 = 2.0 is valid
}
28
251
}
0 commit comments