Commit aeff679

WeichenXu123 authored and zhengruifeng committed
[SPARK-51880][ML][PYTHON][CONNECT] Fix ML cache object python client references
### What changes were proposed in this pull request?

Fix ML cache object Python client references. When a model is copied on the client side, multiple client model objects end up referring to the same server-side cached model. In this case we need a reference count: only when the reference count drops to zero can the server-side cached model be released.

### Why are the changes needed?

Bugfix.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #50707 from WeichenXu123/ml-ref-id-fix.

Lead-authored-by: Weichen Xu <weichen.xu@databricks.com>
Co-authored-by: WeichenXu <weichen.xu@databricks.com>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>

1 parent 6f47783 · commit aeff679
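The core of the fix is a thread-safe, reference-counted handle around each server-side model id. Below is a minimal self-contained sketch of the scheme, mirroring the RemoteModelRef class added to python/pyspark/ml/util.py in this commit; release_server_model and the ids are invented stand-ins for the real del_remote_cache call and real cache ids.

import threading

def release_server_model(ref_id: str) -> None:
    # Stand-in for del_remote_cache(ref_id): asks the server to drop the cached model.
    print(f"released server-side cached model {ref_id}")

class RemoteModelRef:
    def __init__(self, ref_id: str) -> None:
        self._ref_id = ref_id
        self._ref_count = 1          # the object that created the handle holds one reference
        self._lock = threading.Lock()

    def add_ref(self) -> None:
        with self._lock:
            assert self._ref_count > 0
            self._ref_count += 1

    def release_ref(self) -> None:
        with self._lock:
            assert self._ref_count > 0
            self._ref_count -= 1
            if self._ref_count == 0:
                # Last client reference is gone; only now is it safe to free the server object.
                release_server_model(self._ref_id)

ref = RemoteModelRef("model-123")    # fitted model: count == 1
ref.add_ref()                        # model.copy() shares the handle: count == 2
ref.release_ref()                    # copy garbage-collected: count == 1, server model kept
ref.release_ref()                    # original garbage-collected: count == 0, server model freed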

8 files changed: +123 −40 lines changed


python/pyspark/ml/classification.py

Lines changed: 12 additions & 2 deletions
@@ -2306,7 +2306,12 @@ def featureImportances(self) -> Vector:
     def trees(self) -> List[DecisionTreeClassificationModel]:
         """Trees in this ensemble. Warning: These have null parent Estimators."""
         if is_remote():
-            return [DecisionTreeClassificationModel(m) for m in self._call_java("trees").split(",")]
+            from pyspark.ml.util import RemoteModelRef
+
+            return [
+                DecisionTreeClassificationModel(RemoteModelRef(m))
+                for m in self._call_java("trees").split(",")
+            ]
         return [DecisionTreeClassificationModel(m) for m in list(self._call_java("trees"))]
 
     @property
@@ -2805,7 +2810,12 @@ def featureImportances(self) -> Vector:
     def trees(self) -> List[DecisionTreeRegressionModel]:
         """Trees in this ensemble. Warning: These have null parent Estimators."""
         if is_remote():
-            return [DecisionTreeRegressionModel(m) for m in self._call_java("trees").split(",")]
+            from pyspark.ml.util import RemoteModelRef
+
+            return [
+                DecisionTreeRegressionModel(RemoteModelRef(m))
+                for m in self._call_java("trees").split(",")
+            ]
         return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))]
 
     def evaluateEachIteration(self, dataset: DataFrame) -> List[float]:

python/pyspark/ml/connect/readwrite.py

Lines changed: 8 additions & 3 deletions
@@ -77,11 +77,13 @@ def saveInstance(
     # Spark Connect ML is built on scala Spark.ML, that means we're only
     # supporting JavaModel or JavaEstimator or JavaEvaluator
     if isinstance(instance, JavaModel):
+        from pyspark.ml.util import RemoteModelRef
+
         model = cast("JavaModel", instance)
         params = serialize_ml_params(model, session.client)
-        assert isinstance(model._java_obj, str)
+        assert isinstance(model._java_obj, RemoteModelRef)
         writer = pb2.MlCommand.Write(
-            obj_ref=pb2.ObjectRef(id=model._java_obj),
+            obj_ref=pb2.ObjectRef(id=model._java_obj.ref_id),
             params=params,
             path=path,
             should_overwrite=shouldOverwrite,
@@ -270,9 +272,12 @@ def _get_class() -> Type[RL]:
     py_type = _get_class()
     # It must be JavaWrapper, since we're passing the string to the _java_obj
     if issubclass(py_type, JavaWrapper):
+        from pyspark.ml.util import RemoteModelRef
+
         if ml_type == pb2.MlOperator.OPERATOR_TYPE_MODEL:
             session.client.add_ml_cache(result.obj_ref.id)
-            instance = py_type(result.obj_ref.id)
+            remote_model_ref = RemoteModelRef(result.obj_ref.id)
+            instance = py_type(remote_model_ref)
         else:
             instance = py_type()
         instance._resetUid(result.uid)

python/pyspark/ml/feature.py

Lines changed: 22 additions & 15 deletions
@@ -64,6 +64,7 @@
     _jvm,
 )
 from pyspark.ml.common import inherit_doc
+from pyspark.ml.util import RemoteModelRef
 from pyspark.sql.types import ArrayType, StringType
 from pyspark.sql.utils import is_remote
 
@@ -1224,10 +1225,12 @@ def from_vocabulary(
 
         if is_remote():
             model = CountVectorizerModel()
-            model._java_obj = invoke_helper_attr(
-                "countVectorizerModelFromVocabulary",
-                model.uid,
-                list(vocabulary),
+            model._java_obj = RemoteModelRef(
+                invoke_helper_attr(
+                    "countVectorizerModelFromVocabulary",
+                    model.uid,
+                    list(vocabulary),
+                )
             )
 
         else:
@@ -4843,10 +4846,12 @@ def from_labels(
         """
         if is_remote():
             model = StringIndexerModel()
-            model._java_obj = invoke_helper_attr(
-                "stringIndexerModelFromLabels",
-                model.uid,
-                (list(labels), ArrayType(StringType())),
+            model._java_obj = RemoteModelRef(
+                invoke_helper_attr(
+                    "stringIndexerModelFromLabels",
+                    model.uid,
+                    (list(labels), ArrayType(StringType())),
+                )
             )
 
         else:
@@ -4882,13 +4887,15 @@ def from_arrays_of_labels(
         """
         if is_remote():
             model = StringIndexerModel()
-            model._java_obj = invoke_helper_attr(
-                "stringIndexerModelFromLabelsArray",
-                model.uid,
-                (
-                    [list(labels) for labels in arrayOfLabels],
-                    ArrayType(ArrayType(StringType())),
-                ),
+            model._java_obj = RemoteModelRef(
+                invoke_helper_attr(
+                    "stringIndexerModelFromLabelsArray",
+                    model.uid,
+                    (
+                        [list(labels) for labels in arrayOfLabels],
+                        ArrayType(ArrayType(StringType())),
+                    ),
+                )
             )
 
         else:

python/pyspark/ml/regression.py

Lines changed: 12 additions & 2 deletions
@@ -1614,7 +1614,12 @@ class RandomForestRegressionModel(
     def trees(self) -> List[DecisionTreeRegressionModel]:
         """Trees in this ensemble. Warning: These have null parent Estimators."""
         if is_remote():
-            return [DecisionTreeRegressionModel(m) for m in self._call_java("trees").split(",")]
+            from pyspark.ml.util import RemoteModelRef
+
+            return [
+                DecisionTreeRegressionModel(RemoteModelRef(m))
+                for m in self._call_java("trees").split(",")
+            ]
         return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))]
 
     @property
@@ -2005,7 +2010,12 @@ def featureImportances(self) -> Vector:
     def trees(self) -> List[DecisionTreeRegressionModel]:
         """Trees in this ensemble. Warning: These have null parent Estimators."""
         if is_remote():
-            return [DecisionTreeRegressionModel(m) for m in self._call_java("trees").split(",")]
+            from pyspark.ml.util import RemoteModelRef
+
+            return [
+                DecisionTreeRegressionModel(RemoteModelRef(m))
+                for m in self._call_java("trees").split(",")
+            ]
         return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))]
 
     def evaluateEachIteration(self, dataset: DataFrame, loss: str) -> List[float]:

python/pyspark/ml/tests/test_tuning.py

Lines changed: 0 additions & 1 deletion
@@ -97,7 +97,6 @@ def test_train_validation_split(self):
         self.assertEqual(str(tvs_model.getEstimator()), str(model2.getEstimator()))
         self.assertEqual(str(tvs_model.getEvaluator()), str(model2.getEvaluator()))
 
-    @unittest.skip("Disabled due to a Python side reference count issue in _parallelFitTasks.")
     def test_cross_validator(self):
         dataset = self.spark.createDataFrame(
             [

python/pyspark/ml/util.py

Lines changed: 54 additions & 13 deletions
@@ -17,6 +17,7 @@
 
 import json
 import os
+import threading
 import time
 import uuid
 import functools
@@ -75,7 +76,7 @@ def try_remote_intermediate_result(f: FuncT) -> FuncT:
     @functools.wraps(f)
     def wrapped(self: "JavaWrapper") -> Any:
         if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
-            return f"{self._java_obj}.{f.__name__}"
+            return f"{str(self._java_obj)}.{f.__name__}"
         else:
             return f(self)
 
@@ -108,13 +109,18 @@ def invoke_remote_attribute_relation(
     from pyspark.ml.connect.proto import AttributeRelation
     from pyspark.sql.connect.session import SparkSession
     from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
+    from pyspark.ml.wrapper import JavaModel
 
     session = SparkSession.getActiveSession()
     assert session is not None
 
-    assert isinstance(instance._java_obj, str)
-
-    methods, obj_ref = _extract_id_methods(instance._java_obj)
+    if isinstance(instance, JavaModel):
+        assert isinstance(instance._java_obj, RemoteModelRef)
+        object_id = instance._java_obj.ref_id
+    else:
+        # model summary
+        object_id = instance._java_obj  # type: ignore
+    methods, obj_ref = _extract_id_methods(object_id)
     methods.append(pb2.Fetch.Method(method=method, args=serialize(session.client, *args)))
     plan = AttributeRelation(obj_ref, methods)
 
@@ -139,6 +145,33 @@ def wrapped(self: "JavaWrapper", *args: Any, **kwargs: Any) -> Any:
     return cast(FuncT, wrapped)
 
 
+class RemoteModelRef:
+    def __init__(self, ref_id: str) -> None:
+        self._ref_id = ref_id
+        self._ref_count = 1
+        self._lock = threading.Lock()
+
+    @property
+    def ref_id(self) -> str:
+        return self._ref_id
+
+    def add_ref(self) -> None:
+        with self._lock:
+            assert self._ref_count > 0
+            self._ref_count += 1
+
+    def release_ref(self) -> None:
+        with self._lock:
+            assert self._ref_count > 0
+            self._ref_count -= 1
+            if self._ref_count == 0:
+                # Delete the model if possible
+                del_remote_cache(self.ref_id)
+
+    def __str__(self) -> str:
+        return self.ref_id
+
+
 def try_remote_fit(f: FuncT) -> FuncT:
     """Mark the function that fits a model."""
 
@@ -165,7 +198,8 @@ def wrapped(self: "JavaEstimator", dataset: "ConnectDataFrame") -> Any:
         (_, properties, _) = client.execute_command(command)
         model_info = deserialize(properties)
         client.add_ml_cache(model_info.obj_ref.id)
-        model = self._create_model(model_info.obj_ref.id)
+        remote_model_ref = RemoteModelRef(model_info.obj_ref.id)
+        model = self._create_model(remote_model_ref)
         if model.__class__.__name__ not in ["Bucketizer"]:
             model._resetUid(self.uid)
         return self._copyValues(model)
@@ -192,11 +226,11 @@ def wrapped(self: "JavaWrapper", dataset: "ConnectDataFrame") -> Any:
         if isinstance(self, Model):
             from pyspark.ml.connect.proto import TransformerRelation
 
-            assert isinstance(self._java_obj, str)
+            assert isinstance(self._java_obj, RemoteModelRef)
             params = serialize_ml_params(self, session.client)
             plan = TransformerRelation(
                 child=dataset._plan,
-                name=self._java_obj,
+                name=self._java_obj.ref_id,
                 ml_params=params,
                 is_model=True,
             )
@@ -246,11 +280,20 @@ def wrapped(self: "JavaWrapper", name: str, *args: Any) -> Any:
         from pyspark.sql.connect.session import SparkSession
         from pyspark.ml.connect.util import _extract_id_methods
         from pyspark.ml.connect.serialize import serialize, deserialize
+        from pyspark.ml.wrapper import JavaModel
 
         session = SparkSession.getActiveSession()
         assert session is not None
-        assert isinstance(self._java_obj, str)
-        methods, obj_ref = _extract_id_methods(self._java_obj)
+        if self._java_obj == ML_CONNECT_HELPER_ID:
+            obj_id = ML_CONNECT_HELPER_ID
+        else:
+            if isinstance(self, JavaModel):
+                assert isinstance(self._java_obj, RemoteModelRef)
+                obj_id = self._java_obj.ref_id
+            else:
+                # model summary
+                obj_id = self._java_obj  # type: ignore
+        methods, obj_ref = _extract_id_methods(obj_id)
         methods.append(pb2.Fetch.Method(method=name, args=serialize(session.client, *args)))
         command = pb2.Command()
         command.ml_command.fetch.CopyFrom(
@@ -301,10 +344,8 @@ def wrapped(self: "JavaWrapper") -> Any:
         except Exception:
             return
 
-        if in_remote:
-            # Delete the model if possible
-            model_id = self._java_obj
-            del_remote_cache(cast(str, model_id))
+        if in_remote and isinstance(self._java_obj, RemoteModelRef):
+            self._java_obj.release_ref()
             return
         else:
             return f(self)
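A note on the __str__ override above (hedged; the id is invented for illustration): try_remote_intermediate_result addresses server-side attributes as "<ref_id>.<attribute>" strings, so a RemoteModelRef must stringify to its raw id to stay interchangeable with the plain-string ids still used for model summaries.

from pyspark.ml.util import RemoteModelRef

ref = RemoteModelRef("abc-123")
assert str(ref) == "abc-123"
assert f"{str(ref)}.summary" == "abc-123.summary"  # same shape as the old plain-string form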

python/pyspark/ml/wrapper.py

Lines changed: 13 additions & 3 deletions
@@ -356,9 +356,15 @@ def copy(self: "JP", extra: Optional["ParamMap"] = None) -> "JP":
         if extra is None:
             extra = dict()
         that = super(JavaParams, self).copy(extra)
-        if self._java_obj is not None and not isinstance(self._java_obj, str):
-            that._java_obj = self._java_obj.copy(self._empty_java_param_map())
-            that._transfer_params_to_java()
+        if self._java_obj is not None:
+            from pyspark.ml.util import RemoteModelRef
+
+            if isinstance(self._java_obj, RemoteModelRef):
+                that._java_obj = self._java_obj
+                self._java_obj.add_ref()
+            elif not isinstance(self._java_obj, str):
+                that._java_obj = self._java_obj.copy(self._empty_java_param_map())
+                that._transfer_params_to_java()
         return that
 
     @try_remote_intercept
@@ -452,6 +458,10 @@ def __init__(self, java_model: Optional["JavaObject"] = None):
         other ML classes).
         """
         super(JavaModel, self).__init__(java_model)
+        if is_remote() and java_model is not None:
+            from pyspark.ml.util import RemoteModelRef
+
+            assert isinstance(java_model, RemoteModelRef)
        if java_model is not None and not is_remote():
             # SPARK-10931: This is a temporary fix to allow models to own params
             # from estimators. Eventually, these params should be in models through
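Taken together with the __del__ hook in util.py, this copy() change gives the intended lifecycle. A minimal hedged usage sketch (assuming model is any fitted model on a Spark Connect session):

model2 = model.copy()   # copy() shares the RemoteModelRef and calls add_ref(): count == 2
del model2              # __del__ -> release_ref(): count == 1, server-side model kept
del model               # __del__ -> release_ref(): count == 0, server cache entry released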

python/pyspark/sql/connect/client/core.py

Lines changed: 2 additions & 1 deletion
@@ -1981,9 +1981,10 @@ def add_ml_cache(self, cache_id: str) -> None:
         self.thread_local.ml_caches.add(cache_id)
 
     def remove_ml_cache(self, cache_id: str) -> None:
+        deleted = self._delete_ml_cache([cache_id])
+        # TODO: Fix the code: change thread-local `ml_caches` to global `ml_caches`.
         if hasattr(self.thread_local, "ml_caches"):
             if cache_id in self.thread_local.ml_caches:
-                deleted = self._delete_ml_cache([cache_id])
                 for obj_id in deleted:
                     self.thread_local.ml_caches.remove(obj_id)
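Presumably the deletion RPC is hoisted out of the thread-local check because, with reference counting, release_ref() can fire from whichever thread drops the last Python reference, not necessarily the thread whose ml_caches set recorded the id; the TODO in the diff flags the remaining bookkeeping gap (thread-local ml_caches should become global).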
