Commit cdd5296
[SPARK-51856][ML][CONNECT] Update model size API to count distributed DataFrame size
### What changes were proposed in this pull request?

Update model size API to count distributed DataFrame size

### Why are the changes needed?

For Spark server ML cache management.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #50652 from WeichenXu123/get-model-ser-size-api.

Lead-authored-by: Weichen Xu <weichen.xu@databricks.com>
Co-authored-by: WeichenXu <weichen.xu@databricks.com>
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
1 parent 772c4cb commit cdd5296

File tree

8 files changed: +50 -5 lines changed

mllib/src/main/scala/org/apache/spark/ml/Estimator.scala

Lines changed: 2 additions & 2 deletions
@@ -87,8 +87,8 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage {
    * Estimate an upper-bound size of the model to be fitted in bytes, based on the
    * parameters and the dataset, e.g., using $(k) and numFeatures to estimate a
    * k-means model size.
-   * 1, Only driver side memory usage is counted, distributed objects (like DataFrame,
-   * RDD, Graph, Summary) are ignored.
+   * 1, Both driver side memory usage and distributed objects size (like DataFrame,
+   * RDD, Graph, Summary) are counted.
    * 2, Lazy vals are not counted, e.g., an auxiliary object used in prediction.
    * 3, If there is no enough information to get an accurate size, try to estimate the
    * upper-bound size, e.g.
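As a concrete reading of the k-means example mentioned in this doc comment, an upper bound can come straight from the parameters. A minimal sketch, with an illustrative function name that is not part of this commit:

    // Hypothetical upper bound for a k-means model: k cluster centers,
    // each a dense vector of numFeatures doubles (8 bytes per element).
    def kmeansUpperBoundBytes(k: Int, numFeatures: Long): Long =
      k.toLong * numFeatures * 8L

    kmeansUpperBoundBytes(k = 10, numFeatures = 100) // 8000 bytes of centers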

mllib/src/main/scala/org/apache/spark/ml/Model.scala

Lines changed: 2 additions & 2 deletions
@@ -49,8 +49,8 @@ abstract class Model[M <: Model[M]] extends Transformer { self =>
    * For ml connect only.
    * Estimate the size of this model in bytes.
    * This is an approximation, the real size might be different.
-   * 1, Only driver side memory usage is counted, distributed objects (like DataFrame,
-   * RDD, Graph, Summary) are ignored.
+   * 1, Both driver side memory usage and distributed objects size (like DataFrame,
+   * RDD, Graph, Summary) are counted.
    * 2, Lazy vals are not counted, e.g., an auxiliary object used in prediction.
    * 3, The default implementation uses `org.apache.spark.util.SizeEstimator.estimate`,
    * some models override the default implementation to achieve more precise estimation.
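For context on point 3: `org.apache.spark.util.SizeEstimator.estimate` walks a driver-side object graph and returns an approximate deep size in bytes, so distributed data is invisible to it, which is why DataFrame-backed models need explicit overrides. A minimal sketch of the default behavior (the coefficient class here is hypothetical):

    import org.apache.spark.util.SizeEstimator

    // Hypothetical driver-side model state: 1000 double coefficients.
    case class DenseCoefficients(values: Array[Double])

    val coeffs = DenseCoefficients(Array.fill(1000)(0.0))
    // Approximate deep size in bytes of the driver-side object graph.
    val bytes: Long = SizeEstimator.estimate(coeffs)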

mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala

Lines changed: 5 additions & 0 deletions
@@ -805,6 +805,11 @@ class DistributedLDAModel private[ml] (
   override def toString: String = {
     s"DistributedLDAModel: uid=$uid, k=${$(k)}, numFeatures=$vocabSize"
   }
+
+  override def estimatedSize: Long = {
+    // TODO: Implement this method.
+    throw new UnsupportedOperationException
+  }
 }

mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala

Lines changed: 5 additions & 0 deletions
@@ -322,6 +322,11 @@ class FPGrowthModel private[ml] (
   override def toString: String = {
     s"FPGrowthModel: uid=$uid, numTrainingRecords=$numTrainingRecords"
   }
+
+  override def estimatedSize: Long = {
+    // TODO: Implement this method.
+    throw new UnsupportedOperationException
+  }
 }

 @Since("2.2.0")

mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala

Lines changed: 12 additions & 0 deletions
@@ -540,6 +540,11 @@ class ALSModel private[ml] (
     }
   }

+  override def estimatedSize: Long = {
+    val userCount = userFactors.count()
+    val itemCount = itemFactors.count()
+    (userCount + itemCount) * (rank + 1) * 4
+  }
 }

 @Since("1.6.0")
@@ -771,6 +776,13 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]

   @Since("1.5.0")
   override def copy(extra: ParamMap): ALS = defaultCopy(extra)
+
+  override def estimateModelSize(dataset: Dataset[_]): Long = {
+    val userCount = dataset.select(getUserCol).distinct().count()
+    val itemCount = dataset.select(getItemCol).distinct().count()
+    val rank = getRank
+    (userCount + itemCount) * (rank + 1) * 4
+  }
 }
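Both overrides implement the same arithmetic: an ALS factor row pairs an Int id (4 bytes) with `rank` single-precision values (4 bytes each), so one row is roughly (rank + 1) * 4 bytes, summed over distinct users and items. A standalone sketch of the formula, not code from the commit:

    // One factor row: an Int id (4 bytes) plus `rank` Floats (4 bytes each),
    // summed over the distinct users and items in the factor DataFrames.
    def alsFactorBytes(userCount: Long, itemCount: Long, rank: Int): Long =
      (userCount + itemCount) * (rank + 1) * 4L

    alsFactorBytes(userCount = 3, itemCount = 2, rank = 8) // (3 + 2) * 9 * 4 = 180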

mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala

Lines changed: 18 additions & 0 deletions
@@ -1128,6 +1128,24 @@ class ALSStorageSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
     levels.foreach(level => assert(level == StorageLevel.MEMORY_ONLY))
     nonDefaultListener.storageLevels.foreach(level => assert(level == StorageLevel.DISK_ONLY))
   }
+
+  test("saved model size estimation") {
+    import testImplicits._
+
+    val als = new ALS().setMaxIter(1).setRank(8)
+    val estimatedDFSize = (3 + 2) * (8 + 1) * 4
+    val df = sc.parallelize(Seq(
+      (123, 1, 0.5),
+      (123, 2, 0.7),
+      (123, 3, 0.6),
+      (111, 2, 1.0),
+      (111, 1, 0.1)
+    )).toDF("item", "user", "rating")
+    assert(als.estimateModelSize(df) === estimatedDFSize)
+
+    val model = als.fit(df)
+    assert(model.estimatedSize == estimatedDFSize)
+  }
 }

 private class IntermediateRDDStorageListener extends SparkListener {
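For reference, the expected value follows the ALS formula above: the toy DataFrame has 3 distinct users (1, 2, 3) and 2 distinct items (123, 111) at rank 8, so (3 + 2) * (8 + 1) * 4 = 180 bytes, and the test checks that the pre-fit estimate and the fitted model's `estimatedSize` agree on it.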

python/pyspark/ml/tests/test_clustering.py

Lines changed: 3 additions & 0 deletions
@@ -37,6 +37,7 @@
     DistributedLDAModel,
     PowerIterationClustering,
 )
+from pyspark.sql import is_remote
 from pyspark.testing.sqlutils import ReusedSQLTestCase


@@ -377,6 +378,8 @@ def test_local_lda(self):
         self.assertEqual(str(model), str(model2))

     def test_distributed_lda(self):
+        if is_remote():
+            self.skipTest("Do not support Spark Connect.")
         spark = self.spark
         df = (
             spark.createDataFrame(

python/pyspark/ml/tests/test_fpm.py

Lines changed: 3 additions & 1 deletion
@@ -18,7 +18,7 @@
 import tempfile
 import unittest

-from pyspark.sql import Row
+from pyspark.sql import is_remote, Row
 import pyspark.sql.functions as sf
 from pyspark.ml.fpm import (
     FPGrowth,
@@ -30,6 +30,8 @@

 class FPMTestsMixin:
     def test_fp_growth(self):
+        if is_remote():
+            self.skipTest("Do not support Spark Connect.")
         df = self.spark.createDataFrame(
             [
                 ["r z h k p"],
