Skip to content

Commit 894d828

Browse files
committed
[SPARK-51947] Spark connect model cache offloading
### What changes were proposed in this pull request? Support offloading model cache to spark driver local disk ### Why are the changes needed? Motivation: this feature helps to reduce spark driver memory pressure. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Manually. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #50752 from WeichenXu123/mlcache-offload. Authored-by: Weichen Xu <weichen.xu@databricks.com> Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
1 parent e857f43 commit 894d828

File tree

7 files changed

+228
-57
lines changed

7 files changed

+228
-57
lines changed

python/pyspark/ml/tests/connect/test_parity_tuning.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,16 @@ class TuningParityTests(TuningTestsMixin, ReusedConnectTestCase):
2525
pass
2626

2727

28+
class TuningParityWithMLCacheOffloadingEnabledTests(TuningTestsMixin, ReusedConnectTestCase):
29+
@classmethod
30+
def conf(cls):
31+
conf = super().conf()
32+
conf.set("spark.connect.session.connectML.mlCache.offloading.enabled", "true")
33+
conf.set("spark.connect.session.connectML.mlCache.offloading.maxInMemorySize", "1024")
34+
conf.set("spark.connect.session.connectML.mlCache.offloading.timeout", "1")
35+
return conf
36+
37+
2838
if __name__ == "__main__":
2939
from pyspark.ml.tests.connect.test_parity_tuning import * # noqa: F401
3040

python/pyspark/testing/connectutils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -158,9 +158,9 @@ def conf(cls):
158158
# Set a static token for all tests so the parallelism doesn't overwrite each
159159
# tests' environment variables
160160
conf.set("spark.connect.authenticate.token", "deadbeef")
161-
# Make the max size of ML Cache larger, to avoid CONNECT_ML.CACHE_INVALID issues
162-
# in tests.
163-
conf.set("spark.connect.session.connectML.mlCache.maxSize", "1g")
161+
# Disable ml cache offloading,
162+
# offloading hasn't supported APIs like model.summary / model.evaluate
163+
conf.set("spark.connect.session.connectML.mlCache.offloading.enabled", "false")
164164
return conf
165165

166166
@classmethod

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -334,23 +334,36 @@ object Connect {
334334
}
335335
}
336336

337-
val CONNECT_SESSION_CONNECT_ML_CACHE_MAX_SIZE =
338-
buildConf("spark.connect.session.connectML.mlCache.maxSize")
339-
.doc("Maximum size of the MLCache per session. The cache will evict the least recently" +
340-
"used models if the size exceeds this limit. The size is in bytes.")
337+
val CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_MAX_IN_MEMORY_SIZE =
338+
buildConf("spark.connect.session.connectML.mlCache.offloading.maxInMemorySize")
339+
.doc(
340+
"In-memory maximum size of the MLCache per session. The cache will offload the least " +
341+
"recently used models to Spark driver local disk if the size exceeds this limit. " +
342+
"The size is in bytes. This configuration only works when " +
343+
"'spark.connect.session.connectML.mlCache.offloading.enabled' is 'true'.")
341344
.version("4.1.0")
342345
.internal()
343346
.bytesConf(ByteUnit.BYTE)
344347
// By default, 1/3 of total designated memory (the configured -Xmx).
345348
.createWithDefault(Runtime.getRuntime.maxMemory() / 3)
346349

347-
val CONNECT_SESSION_CONNECT_ML_CACHE_TIMEOUT =
348-
buildConf("spark.connect.session.connectML.mlCache.timeout")
350+
val CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_TIMEOUT =
351+
buildConf("spark.connect.session.connectML.mlCache.offloading.timeout")
349352
.doc(
350-
"Timeout of models in MLCache. Models will be evicted from the cache if they are not " +
351-
"used for this amount of time. The timeout is in minutes.")
353+
"Timeout of model offloading in MLCache. Models will be offloaded to Spark driver local " +
354+
"disk if they are not used for this amount of time. The timeout is in minutes. " +
355+
"This configuration only works when " +
356+
"'spark.connect.session.connectML.mlCache.offloading.enabled' is 'true'.")
352357
.version("4.1.0")
353358
.internal()
354359
.timeConf(TimeUnit.MINUTES)
355-
.createWithDefault(15)
360+
.createWithDefault(5)
361+
362+
val CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_ENABLED =
363+
buildConf("spark.connect.session.connectML.mlCache.offloading.enabled")
364+
.doc("Enables ML cache offloading.")
365+
.version("4.1.0")
366+
.internal()
367+
.booleanConf
368+
.createWithDefault(true)
356369
}

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala

Lines changed: 125 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,20 @@
1616
*/
1717
package org.apache.spark.sql.connect.ml
1818

19+
import java.io.File
20+
import java.nio.file.{Files, Path, Paths}
1921
import java.util.UUID
20-
import java.util.concurrent.{ConcurrentMap, TimeUnit}
22+
import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, TimeUnit}
2123
import java.util.concurrent.atomic.AtomicLong
2224

2325
import scala.collection.mutable
2426

2527
import com.google.common.cache.{CacheBuilder, RemovalNotification}
28+
import org.apache.commons.io.FileUtils
2629

2730
import org.apache.spark.internal.Logging
2831
import org.apache.spark.ml.Model
29-
import org.apache.spark.ml.util.ConnectHelper
32+
import org.apache.spark.ml.util.{ConnectHelper, MLWritable, Summary}
3033
import org.apache.spark.sql.connect.config.Connect
3134
import org.apache.spark.sql.connect.service.SessionHolder
3235

@@ -36,38 +39,74 @@ import org.apache.spark.sql.connect.service.SessionHolder
3639
private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
3740
private val helper = new ConnectHelper()
3841
private val helperID = "______ML_CONNECT_HELPER______"
42+
private val modelClassNameFile = "__model_class_name__"
3943

40-
private def getMaxCacheSizeKB: Long = {
41-
sessionHolder.session.conf.get(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_MAX_SIZE) / 1024
44+
// TODO: rename it to `totalInMemorySizeBytes` because it only counts the in-memory
45+
// part data size.
46+
private[ml] val totalSizeBytes: AtomicLong = new AtomicLong(0)
47+
48+
val offloadedModelsDir: Path = {
49+
val path = Paths.get(
50+
System.getProperty("java.io.tmpdir"),
51+
"spark_connect_model_cache",
52+
sessionHolder.sessionId)
53+
Files.createDirectories(path)
54+
}
55+
private[spark] def getOffloadingEnabled: Boolean = {
56+
sessionHolder.session.conf.get(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_ENABLED)
4257
}
4358

44-
private def getTimeoutMinute: Long = {
45-
sessionHolder.session.conf.get(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_TIMEOUT)
59+
private def getMaxInMemoryCacheSizeKB: Long = {
60+
sessionHolder.session.conf.get(
61+
Connect.CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_MAX_IN_MEMORY_SIZE) / 1024
62+
}
63+
64+
private def getOffloadingTimeoutMinute: Long = {
65+
sessionHolder.session.conf.get(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_TIMEOUT)
4666
}
4767

4868
private[ml] case class CacheItem(obj: Object, sizeBytes: Long)
49-
private[ml] val cachedModel: ConcurrentMap[String, CacheItem] = CacheBuilder
50-
.newBuilder()
51-
.softValues()
52-
.maximumWeight(getMaxCacheSizeKB)
53-
.expireAfterAccess(getTimeoutMinute, TimeUnit.MINUTES)
54-
.weigher((key: String, value: CacheItem) => {
55-
Math.ceil(value.sizeBytes.toDouble / 1024).toInt
56-
})
57-
.removalListener((removed: RemovalNotification[String, CacheItem]) =>
58-
totalSizeBytes.addAndGet(-removed.getValue.sizeBytes))
59-
.build[String, CacheItem]()
60-
.asMap()
69+
private[ml] val cachedModel: ConcurrentMap[String, CacheItem] = {
70+
if (getOffloadingEnabled) {
71+
CacheBuilder
72+
.newBuilder()
73+
.softValues()
74+
.removalListener((removed: RemovalNotification[String, CacheItem]) =>
75+
totalSizeBytes.addAndGet(-removed.getValue.sizeBytes))
76+
.maximumWeight(getMaxInMemoryCacheSizeKB)
77+
.weigher((key: String, value: CacheItem) => {
78+
Math.ceil(value.sizeBytes.toDouble / 1024).toInt
79+
})
80+
.expireAfterAccess(getOffloadingTimeoutMinute, TimeUnit.MINUTES)
81+
.build[String, CacheItem]()
82+
.asMap()
83+
} else {
84+
new ConcurrentHashMap[String, CacheItem]()
85+
}
86+
}
6187

62-
private[ml] val totalSizeBytes: AtomicLong = new AtomicLong(0)
88+
private[ml] val cachedSummary: ConcurrentMap[String, Summary] = {
89+
new ConcurrentHashMap[String, Summary]()
90+
}
6391

6492
private def estimateObjectSize(obj: Object): Long = {
6593
obj match {
6694
case model: Model[_] =>
6795
model.asInstanceOf[Model[_]].estimatedSize
6896
case _ =>
6997
// There can only be Models in the cache, so we should never reach here.
70-
1
98+
throw new RuntimeException(f"Unexpected model object type.")
99+
}
100+
}
101+
102+
private[spark] def checkSummaryAvail(): Unit = {
103+
if (getOffloadingEnabled) {
104+
throw MlUnsupportedException(
105+
"SparkML 'model.summary' and 'model.evaluate' APIs are not supported' when " +
106+
"Spark Connect session ML cache offloading is enabled. You can use APIs in " +
107+
"'pyspark.ml.evaluation' instead, or you can set Spark config " +
108+
"'spark.connect.session.connectML.mlCache.offloading.enabled' to 'false' to " +
109+
"disable Spark Connect session ML cache offloading.")
71110
}
72111
}
73112

@@ -80,9 +119,26 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
80119
*/
81120
def register(obj: Object): String = {
82121
val objectId = UUID.randomUUID().toString
83-
val sizeBytes = estimateObjectSize(obj)
84-
totalSizeBytes.addAndGet(sizeBytes)
85-
cachedModel.put(objectId, CacheItem(obj, sizeBytes))
122+
123+
if (obj.isInstanceOf[Summary]) {
124+
checkSummaryAvail()
125+
cachedSummary.put(objectId, obj.asInstanceOf[Summary])
126+
} else if (obj.isInstanceOf[Model[_]]) {
127+
val sizeBytes = if (getOffloadingEnabled) {
128+
estimateObjectSize(obj)
129+
} else {
130+
0L // Don't need to calculate size if disables offloading.
131+
}
132+
cachedModel.put(objectId, CacheItem(obj, sizeBytes))
133+
if (getOffloadingEnabled) {
134+
val savePath = offloadedModelsDir.resolve(objectId)
135+
obj.asInstanceOf[MLWritable].write.saveToLocal(savePath.toString)
136+
Files.writeString(savePath.resolve(modelClassNameFile), obj.getClass.getName)
137+
}
138+
totalSizeBytes.addAndGet(sizeBytes)
139+
} else {
140+
throw new RuntimeException("'MLCache.register' only accepts model or summary objects.")
141+
}
86142
objectId
87143
}
88144

@@ -97,8 +153,41 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
97153
if (refId == helperID) {
98154
helper
99155
} else {
100-
Option(cachedModel.get(refId)).map(_.obj).orNull
156+
var obj: Object =
157+
Option(cachedModel.get(refId)).map(_.obj).getOrElse(cachedSummary.get(refId))
158+
if (obj == null && getOffloadingEnabled) {
159+
val loadPath = offloadedModelsDir.resolve(refId)
160+
if (Files.isDirectory(loadPath)) {
161+
val className = Files.readString(loadPath.resolve(modelClassNameFile))
162+
obj = MLUtils.loadTransformer(
163+
sessionHolder,
164+
className,
165+
loadPath.toString,
166+
loadFromLocal = true)
167+
val sizeBytes = estimateObjectSize(obj)
168+
cachedModel.put(refId, CacheItem(obj, sizeBytes))
169+
totalSizeBytes.addAndGet(sizeBytes)
170+
}
171+
}
172+
obj
173+
}
174+
}
175+
176+
def _removeModel(refId: String): Boolean = {
177+
val removedModel = cachedModel.remove(refId)
178+
val removedFromMem = removedModel != null
179+
val removedFromDisk = if (getOffloadingEnabled) {
180+
val offloadingPath = new File(offloadedModelsDir.resolve(refId).toString)
181+
if (offloadingPath.exists()) {
182+
FileUtils.deleteDirectory(offloadingPath)
183+
true
184+
} else {
185+
false
186+
}
187+
} else {
188+
false
101189
}
190+
removedFromMem || removedFromDisk
102191
}
103192

104193
/**
@@ -107,9 +196,14 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
107196
* the key used to look up the corresponding object
108197
*/
109198
def remove(refId: String): Boolean = {
110-
val removed = cachedModel.remove(refId)
111-
// remove returns null if the key is not present
112-
removed != null
199+
val modelIsRemoved = _removeModel(refId)
200+
201+
if (modelIsRemoved) {
202+
true
203+
} else {
204+
val removedSummary = cachedSummary.remove(refId)
205+
removedSummary != null
206+
}
113207
}
114208

115209
/**
@@ -118,6 +212,10 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
118212
def clear(): Int = {
119213
val size = cachedModel.size()
120214
cachedModel.clear()
215+
cachedSummary.clear()
216+
if (getOffloadingEnabled) {
217+
FileUtils.cleanDirectory(new File(offloadedModelsDir.toString))
218+
}
121219
size
122220
}
123221

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import org.apache.spark.ml.param.{ParamMap, Params}
2727
import org.apache.spark.ml.util.{MLWritable, Summary}
2828
import org.apache.spark.sql.DataFrame
2929
import org.apache.spark.sql.connect.common.LiteralValueProtoConverter
30+
import org.apache.spark.sql.connect.config.Connect
3031
import org.apache.spark.sql.connect.ml.Serializer.deserializeMethodArguments
3132
import org.apache.spark.sql.connect.service.SessionHolder
3233

@@ -42,7 +43,7 @@ private class AttributeHelper(
4243
val sessionHolder: SessionHolder,
4344
val objRef: String,
4445
val methods: Array[Method]) {
45-
protected lazy val instance = {
46+
protected def instance(): Object = {
4647
val obj = sessionHolder.mlCache.get(objRef)
4748
if (obj == null) {
4849
throw MLCacheInvalidException(s"object $objRef")
@@ -52,7 +53,10 @@ private class AttributeHelper(
5253
// Get the attribute by reflection
5354
def getAttribute: Any = {
5455
assert(methods.length >= 1)
55-
methods.foldLeft(instance) { (obj, m) =>
56+
methods.foldLeft(instance()) { (obj, m) =>
57+
if (obj.isInstanceOf[Summary]) {
58+
sessionHolder.mlCache.checkSummaryAvail()
59+
}
5660
if (m.argValues.isEmpty) {
5761
MLUtils.invokeMethodAllowed(obj, m.name)
5862
} else {
@@ -71,7 +75,7 @@ private class ModelAttributeHelper(
7175

7276
def transform(relation: proto.MlRelation.Transform): DataFrame = {
7377
// Create a copied model to avoid concurrently modify model params.
74-
val model = instance.asInstanceOf[Model[_]]
78+
val model = instance().asInstanceOf[Model[_]]
7579
val copiedModel = model.copy(ParamMap.empty).asInstanceOf[Model[_]]
7680
MLUtils.setInstanceParams(copiedModel, relation.getParams)
7781
val inputDF = MLUtils.parseRelationProto(relation.getInput, sessionHolder)
@@ -119,13 +123,31 @@ private[connect] object MLHandler extends Logging {
119123

120124
mlCommand.getCommandCase match {
121125
case proto.MlCommand.CommandCase.FIT =>
126+
val offloadingEnabled = sessionHolder.session.conf.get(
127+
Connect.CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_ENABLED)
122128
val fitCmd = mlCommand.getFit
123129
val estimatorProto = fitCmd.getEstimator
124130
assert(estimatorProto.getType == proto.MlOperator.OperatorType.OPERATOR_TYPE_ESTIMATOR)
125131

126132
val dataset = MLUtils.parseRelationProto(fitCmd.getDataset, sessionHolder)
127133
val estimator =
128134
MLUtils.getEstimator(sessionHolder, estimatorProto, Some(fitCmd.getParams))
135+
if (offloadingEnabled) {
136+
if (estimator.getClass.getName == "org.apache.spark.ml.fpm.FPGrowth") {
137+
throw MlUnsupportedException(
138+
"FPGrowth algorithm is not supported " +
139+
"if Spark Connect model cache offloading is enabled.")
140+
}
141+
if (estimator.getClass.getName == "org.apache.spark.ml.clustering.LDA"
142+
&& estimator
143+
.asInstanceOf[org.apache.spark.ml.clustering.LDA]
144+
.getOptimizer
145+
.toLowerCase() == "em") {
146+
throw MlUnsupportedException(
147+
"LDA algorithm with 'em' optimizer is not supported " +
148+
"if Spark Connect model cache offloading is enabled.")
149+
}
150+
}
129151
val model = estimator.fit(dataset).asInstanceOf[Model[_]]
130152
val id = mlCache.register(model)
131153
proto.MlCommandResult

0 commit comments

Comments
 (0)