16
16
*/
17
17
package org.apache.spark.sql.connect.ml

import java.io.File
import java.nio.file.{Files, Path, Paths}
import java.util.UUID
import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, TimeUnit}
import java.util.concurrent.atomic.AtomicLong

import scala.collection.mutable

import com.google.common.cache.{CacheBuilder, RemovalNotification}
import org.apache.commons.io.FileUtils

import org.apache.spark.internal.Logging
import org.apache.spark.ml.Model
import org.apache.spark.ml.util.{ConnectHelper, MLWritable, Summary}
import org.apache.spark.sql.connect.config.Connect
import org.apache.spark.sql.connect.service.SessionHolder
32
35
@@ -36,38 +39,74 @@ import org.apache.spark.sql.connect.service.SessionHolder
36
39
private [connect] class MLCache (sessionHolder : SessionHolder ) extends Logging {
37
40
private val helper = new ConnectHelper ()
38
41
private val helperID = " ______ML_CONNECT_HELPER______"
42
+ private val modelClassNameFile = " __model_class_name__"
39
43
40
- private def getMaxCacheSizeKB : Long = {
41
- sessionHolder.session.conf.get(Connect .CONNECT_SESSION_CONNECT_ML_CACHE_MAX_SIZE ) / 1024
44
+ // TODO: rename it to `totalInMemorySizeBytes` because it only counts the in-memory
45
+ // part data size.
46
+ private [ml] val totalSizeBytes : AtomicLong = new AtomicLong (0 )
47
+
48
+ val offloadedModelsDir : Path = {
49
+ val path = Paths .get(
50
+ System .getProperty(" java.io.tmpdir" ),
51
+ " spark_connect_model_cache" ,
52
+ sessionHolder.sessionId)
53
+ Files .createDirectories(path)
54
+ }
55
+ private [spark] def getOffloadingEnabled : Boolean = {
56
+ sessionHolder.session.conf.get(Connect .CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_ENABLED )
42
57
}
43
58
44
- private def getTimeoutMinute : Long = {
45
- sessionHolder.session.conf.get(Connect .CONNECT_SESSION_CONNECT_ML_CACHE_TIMEOUT )
59
+ private def getMaxInMemoryCacheSizeKB : Long = {
60
+ sessionHolder.session.conf.get(
61
+ Connect .CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_MAX_IN_MEMORY_SIZE ) / 1024
62
+ }
63
+
64
+ private def getOffloadingTimeoutMinute : Long = {
65
+ sessionHolder.session.conf.get(Connect .CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_TIMEOUT )
46
66
}
47
67
48
68
private [ml] case class CacheItem (obj : Object , sizeBytes : Long )
49
- private [ml] val cachedModel : ConcurrentMap [String , CacheItem ] = CacheBuilder
50
- .newBuilder()
51
- .softValues()
52
- .maximumWeight(getMaxCacheSizeKB)
53
- .expireAfterAccess(getTimeoutMinute, TimeUnit .MINUTES )
54
- .weigher((key : String , value : CacheItem ) => {
55
- Math .ceil(value.sizeBytes.toDouble / 1024 ).toInt
56
- })
57
- .removalListener((removed : RemovalNotification [String , CacheItem ]) =>
58
- totalSizeBytes.addAndGet(- removed.getValue.sizeBytes))
59
- .build[String , CacheItem ]()
60
- .asMap()
69
+ private [ml] val cachedModel : ConcurrentMap [String , CacheItem ] = {
70
+ if (getOffloadingEnabled) {
71
+ CacheBuilder
72
+ .newBuilder()
73
+ .softValues()
74
+ .removalListener((removed : RemovalNotification [String , CacheItem ]) =>
75
+ totalSizeBytes.addAndGet(- removed.getValue.sizeBytes))
76
+ .maximumWeight(getMaxInMemoryCacheSizeKB)
77
+ .weigher((key : String , value : CacheItem ) => {
78
+ Math .ceil(value.sizeBytes.toDouble / 1024 ).toInt
79
+ })
80
+ .expireAfterAccess(getOffloadingTimeoutMinute, TimeUnit .MINUTES )
81
+ .build[String , CacheItem ]()
82
+ .asMap()
83
+ } else {
84
+ new ConcurrentHashMap [String , CacheItem ]()
85
+ }
86
+ }
61
87
62
- private [ml] val totalSizeBytes : AtomicLong = new AtomicLong (0 )
88
+ private [ml] val cachedSummary : ConcurrentMap [String , Summary ] = {
89
+ new ConcurrentHashMap [String , Summary ]()
90
+ }
63
91
64
92
private def estimateObjectSize (obj : Object ): Long = {
65
93
obj match {
66
94
case model : Model [_] =>
67
95
model.asInstanceOf [Model [_]].estimatedSize
68
96
case _ =>
69
97
// There can only be Models in the cache, so we should never reach here.
70
- 1
98
+ throw new RuntimeException (f " Unexpected model object type. " )
99
+ }
100
+ }
101
+
102
+ private [spark] def checkSummaryAvail (): Unit = {
103
+ if (getOffloadingEnabled) {
104
+ throw MlUnsupportedException (
105
+ " SparkML 'model.summary' and 'model.evaluate' APIs are not supported' when " +
106
+ " Spark Connect session ML cache offloading is enabled. You can use APIs in " +
107
+ " 'pyspark.ml.evaluation' instead, or you can set Spark config " +
108
+ " 'spark.connect.session.connectML.mlCache.offloading.enabled' to 'false' to " +
109
+ " disable Spark Connect session ML cache offloading." )
71
110
}
72
111
}
73
112
@@ -80,9 +119,26 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
80
119
*/
81
120
def register (obj : Object ): String = {
82
121
val objectId = UUID .randomUUID().toString
83
- val sizeBytes = estimateObjectSize(obj)
84
- totalSizeBytes.addAndGet(sizeBytes)
85
- cachedModel.put(objectId, CacheItem (obj, sizeBytes))
122
+
123
+ if (obj.isInstanceOf [Summary ]) {
124
+ checkSummaryAvail()
125
+ cachedSummary.put(objectId, obj.asInstanceOf [Summary ])
126
+ } else if (obj.isInstanceOf [Model [_]]) {
127
+ val sizeBytes = if (getOffloadingEnabled) {
128
+ estimateObjectSize(obj)
129
+ } else {
130
+ 0L // Don't need to calculate size if disables offloading.
131
+ }
132
+ cachedModel.put(objectId, CacheItem (obj, sizeBytes))
133
+ if (getOffloadingEnabled) {
134
+ val savePath = offloadedModelsDir.resolve(objectId)
135
+ obj.asInstanceOf [MLWritable ].write.saveToLocal(savePath.toString)
136
+ Files .writeString(savePath.resolve(modelClassNameFile), obj.getClass.getName)
137
+ }
138
+ totalSizeBytes.addAndGet(sizeBytes)
139
+ } else {
140
+ throw new RuntimeException (" 'MLCache.register' only accepts model or summary objects." )
141
+ }
86
142
objectId
87
143
}
88
144
@@ -97,8 +153,41 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
97
153
if (refId == helperID) {
98
154
helper
99
155
} else {
100
- Option (cachedModel.get(refId)).map(_.obj).orNull
156
+ var obj : Object =
157
+ Option (cachedModel.get(refId)).map(_.obj).getOrElse(cachedSummary.get(refId))
158
+ if (obj == null && getOffloadingEnabled) {
159
+ val loadPath = offloadedModelsDir.resolve(refId)
160
+ if (Files .isDirectory(loadPath)) {
161
+ val className = Files .readString(loadPath.resolve(modelClassNameFile))
162
+ obj = MLUtils .loadTransformer(
163
+ sessionHolder,
164
+ className,
165
+ loadPath.toString,
166
+ loadFromLocal = true )
167
+ val sizeBytes = estimateObjectSize(obj)
168
+ cachedModel.put(refId, CacheItem (obj, sizeBytes))
169
+ totalSizeBytes.addAndGet(sizeBytes)
170
+ }
171
+ }
172
+ obj
173
+ }
174
+ }
175
+
176
+ def _removeModel (refId : String ): Boolean = {
177
+ val removedModel = cachedModel.remove(refId)
178
+ val removedFromMem = removedModel != null
179
+ val removedFromDisk = if (getOffloadingEnabled) {
180
+ val offloadingPath = new File (offloadedModelsDir.resolve(refId).toString)
181
+ if (offloadingPath.exists()) {
182
+ FileUtils .deleteDirectory(offloadingPath)
183
+ true
184
+ } else {
185
+ false
186
+ }
187
+ } else {
188
+ false
101
189
}
190
+ removedFromMem || removedFromDisk
102
191
}
103
192
104
193
/**
@@ -107,9 +196,14 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
107
196
* the key used to look up the corresponding object
108
197
*/
109
198
def remove (refId : String ): Boolean = {
110
- val removed = cachedModel.remove(refId)
111
- // remove returns null if the key is not present
112
- removed != null
199
+ val modelIsRemoved = _removeModel(refId)
200
+
201
+ if (modelIsRemoved) {
202
+ true
203
+ } else {
204
+ val removedSummary = cachedSummary.remove(refId)
205
+ removedSummary != null
206
+ }
113
207
}
114
208
115
209
/**
@@ -118,6 +212,10 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
118
212
def clear (): Int = {
119
213
val size = cachedModel.size()
120
214
cachedModel.clear()
215
+ cachedSummary.clear()
216
+ if (getOffloadingEnabled) {
217
+ FileUtils .cleanDirectory(new File (offloadedModelsDir.toString))
218
+ }
121
219
size
122
220
}
123
221
0 commit comments