Commit 26545d0

[SPARK-51394][ML] Optimize out the additional shuffle in stats tests
### What changes were proposed in this pull request?
Optimize out the additional shuffle in the stats tests.

### Why are the changes needed?
For simplification: the single summary row can be built without an extra shuffle.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Existing tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #50166 from zhengruifeng/ml_cst.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
1 parent c005a37 commit 26545d0
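
For context, the sketch below is a minimal standalone illustration (not the Spark ML sources themselves) of the pattern this commit applies in all three test helpers: the per-feature results used to be gathered with a collect_list aggregation, which forces a shuffle into a single group, and are now folded into one summary row with coalesce(1) + mapPartitions, which needs no exchange. The object name, the toy perFeature data, and the single pValues column are illustrative assumptions.

// Illustrative sketch only; names and data are made up for the example.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{collect_list, struct}

object ShufflePatternSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("sketch").getOrCreate()
    import spark.implicits._

    // Toy per-feature statistics: (featureIndex, pValue), spread over 3 partitions.
    val perFeature = spark.sparkContext.parallelize(Seq((0, 0.01), (1, 0.20), (2, 0.05)), 3)

    // Old shape: aggregate every row into one group via collect_list, which shuffles.
    val viaCollectList = perFeature.toDF("featureIndex", "pValue")
      .agg(collect_list(struct("*")))
      .as[Seq[(Int, Double)]]
      .map(seq => seq.toArray.sortBy(_._1).map(_._2))
      .toDF("pValues")

    // New shape: the per-feature results are tiny, so coalesce them into one partition
    // (no shuffle) and build the single summary row locally inside mapPartitions.
    // The sortBy restores featureIndex order, since partition order is not guaranteed.
    val viaCoalesce = perFeature.coalesce(1)
      .mapPartitions(iter => Iterator.single(iter.toArray.sortBy(_._1).map(_._2)))
      .toDF("pValues")

    // Both yield one row; comparing .explain() output shows the exchange disappear.
    viaCollectList.show(truncate = false)
    viaCoalesce.show(truncate = false)
    spark.stop()
  }
}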

3 files changed, +29 -31 lines


mllib/src/main/scala/org/apache/spark/ml/stat/ANOVATest.scala

Lines changed: 9 additions & 11 deletions
@@ -69,20 +69,18 @@ private[ml] object ANOVATest {
     val spark = dataset.sparkSession
     import spark.implicits._
 
-    val resultDF = testClassification(dataset, featuresCol, labelCol)
-      .toDF("featureIndex", "pValue", "degreesOfFreedom", "fValue")
+    val resRdd = testClassification(dataset, featuresCol, labelCol)
 
     if (flatten) {
-      resultDF
+      resRdd.toDF("featureIndex", "pValue", "degreesOfFreedom", "fValue")
     } else {
-      resultDF.agg(collect_list(struct("*")))
-        .as[Seq[(Int, Double, Long, Double)]]
-        .map { seq =>
-          val results = seq.toArray.sortBy(_._1)
-          val pValues = Vectors.dense(results.map(_._2))
-          val degreesOfFreedom = results.map(_._3)
-          val fValues = Vectors.dense(results.map(_._4))
-          (pValues, degreesOfFreedom, fValues)
+      resRdd.coalesce(1)
+        .mapPartitions { iter =>
+          val res = iter.toArray.sortBy(_._1)
+          val pValues = Vectors.dense(res.map(_._2))
+          val degreesOfFreedom = res.map(_._3)
+          val fValues = Vectors.dense(res.map(_._4))
+          Iterator.single((pValues, degreesOfFreedom, fValues))
         }.toDF("pValues", "degreesOfFreedom", "fValues")
     }
   }

mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquareTest.scala

Lines changed: 10 additions & 11 deletions
@@ -80,22 +80,21 @@ object ChiSquareTest {
     val data = dataset.select(col(labelCol).cast("double"), col(featuresCol)).rdd
       .map { case Row(label: Double, vec: Vector) => (label, OldVectors.fromML(vec)) }
 
-    val resultDF = OldChiSqTest.computeChiSquared(data)
+    val resRDD = OldChiSqTest.computeChiSquared(data)
       .map { case (col, pValue, degreesOfFreedom, statistic, _) =>
         (col, pValue, degreesOfFreedom, statistic)
-      }.toDF("featureIndex", "pValue", "degreesOfFreedom", "statistic")
+      }
 
     if (flatten) {
-      resultDF
+      resRDD.toDF("featureIndex", "pValue", "degreesOfFreedom", "statistic")
     } else {
-      resultDF.agg(collect_list(struct("*")))
-        .as[Seq[(Int, Double, Int, Double)]]
-        .map { seq =>
-          val results = seq.toArray.sortBy(_._1)
-          val pValues = Vectors.dense(results.map(_._2))
-          val degreesOfFreedom = results.map(_._3)
-          val statistics = Vectors.dense(results.map(_._4))
-          (pValues, degreesOfFreedom, statistics)
+      resRDD.coalesce(1)
+        .mapPartitions { iter =>
+          val res = iter.toArray.sortBy(_._1)
+          val pValues = Vectors.dense(res.map(_._2))
+          val degreesOfFreedom = res.map(_._3)
+          val statistics = Vectors.dense(res.map(_._4))
+          Iterator.single((pValues, degreesOfFreedom, statistics))
         }.toDF("pValues", "degreesOfFreedom", "statistics")
     }
   }

mllib/src/main/scala/org/apache/spark/ml/stat/FValueTest.scala

Lines changed: 10 additions & 9 deletions
@@ -70,20 +70,21 @@ private[ml] object FValueTest {
     val spark = dataset.sparkSession
     import spark.implicits._
 
+    val resRDD = testRegression(dataset, featuresCol, labelCol)
+
     val resultDF = testRegression(dataset, featuresCol, labelCol)
       .toDF("featureIndex", "pValue", "degreesOfFreedom", "fValue")
 
     if (flatten) {
-      resultDF
+      resRDD.toDF("featureIndex", "pValue", "degreesOfFreedom", "fValue")
     } else {
-      resultDF.agg(collect_list(struct("*")))
-        .as[Seq[(Int, Double, Long, Double)]]
-        .map { seq =>
-          val results = seq.toArray.sortBy(_._1)
-          val pValues = Vectors.dense(results.map(_._2))
-          val degreesOfFreedom = results.map(_._3)
-          val fValues = Vectors.dense(results.map(_._4))
-          (pValues, degreesOfFreedom, fValues)
+      resRDD.coalesce(1)
+        .mapPartitions { iter =>
+          val res = iter.toArray.sortBy(_._1)
+          val pValues = Vectors.dense(res.map(_._2))
+          val degreesOfFreedom = res.map(_._3)
+          val fValues = Vectors.dense(res.map(_._4))
+          Iterator.single((pValues, degreesOfFreedom, fValues))
         }.toDF("pValues", "degreesOfFreedom", "fValues")
     }
   }
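
Of the three helpers, only ChiSquareTest is a public API (ANOVATest and FValueTest are private[ml] and reached through the corresponding feature selectors). Below is a small usage sketch, assuming the four-argument overload with the flatten parameter that Spark has exposed since 3.1; the tiny DataFrame and object name are made up for illustration, and the flatten = false call exercises the branch rewritten above.

// Usage sketch; the toy data is illustrative, not taken from the Spark test suite.
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.stat.ChiSquareTest
import org.apache.spark.sql.SparkSession

object ChiSquareTestUsage {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("chi2-usage").getOrCreate()

    val data = Seq(
      (0.0, Vectors.dense(0.5, 10.0)),
      (0.0, Vectors.dense(1.5, 20.0)),
      (1.0, Vectors.dense(1.5, 30.0)),
      (0.0, Vectors.dense(3.5, 30.0)),
      (1.0, Vectors.dense(3.5, 40.0)))
    val df = spark.createDataFrame(data).toDF("label", "features")

    // flatten = true: one row per feature (featureIndex, pValue, degreesOfFreedom, statistic).
    ChiSquareTest.test(df, "features", "label", true).show()

    // flatten = false: a single summary row with vector-valued columns
    // (pValues, degreesOfFreedom, statistics) -- the branch rewritten in this commit.
    ChiSquareTest.test(df, "features", "label", false).show()

    spark.stop()
  }
}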
