Commit c4d3bcc

fix: support all sklearn classes (#408)
1. Support artifact-type in mlflow registry from user.
2. Update docs

---------

Signed-off-by: Kushal Batra <i.kushalbatra@gmail.com>
1 parent aa3b4bb commit c4d3bcc

File tree: 4 files changed (+76, -39 lines)

- docs/ml-flow.md
- examples/multi_udf/src/udf/train.py
- numalogic/registry/mlflow_registry.py
- tests/registry/test_mlflow_registry.py

docs/ml-flow.md

Lines changed: 3 additions & 3 deletions
@@ -23,7 +23,7 @@ Numalogic provides `MLflowRegistry`, to save and load models to/from MLflow.
 
 Here, `tracking_uri` is the uri where mlflow server is running. The `static_keys` and `dynamic_keys` are used to form a unique key for the model.
 
-The `artifact` would be the model or transformer object that needs to be saved.
+The `artifact` would be the model or transformer object that needs to be saved. Artifact saving also takes in 'artifact_type' which is the type of the artifact being saved. Currently, 'pytorch', 'sklearn' and 'pyfunc' is supported.
 A dictionary of metadata can also be saved along with the artifact.
 ```python
 from numalogic.registry import MLflowRegistry
@@ -37,13 +37,13 @@ dynamic_keys = ["vanilla", "seq10"]
 
 registry = MLflowRegistry(tracking_uri="http://0.0.0.0:5000")
 registry.save(
-    skeys=static_keys, dkeys=dynamic_keys, artifact=model, seq_len=10, lr=0.001
+    skeys=static_keys, dkeys=dynamic_keys, artifact_type='pytorch', artifact=model, seq_len=10, lr=0.001
 )
 ```
 
 ### Model loading
 
-Once, the models are save to MLflow, the `load` function of `MLflowRegistry` can be used to load the model.
+Once, the models are save to MLflow, the `load` function of `MLflowRegistry` can be used to load the model. Like how the artifacts were saved with 'artifact_type', the same type shall be passed to the `load` function as well.
 
 ```python
 from numalogic.registry import MLflowRegistry
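Putting the two doc snippets above together, here is a minimal end-to-end sketch of the updated save/load API (an editorial addition, not part of the diff; it assumes an MLflow server running at the given URI, and the key names are illustrative):

```python
from sklearn.preprocessing import StandardScaler

from numalogic.registry import MLflowRegistry

# Illustrative keys; any static/dynamic key combination works.
static_keys = ["demo"]
dynamic_keys = ["scaler"]

registry = MLflowRegistry(tracking_uri="http://0.0.0.0:5000")

# Save: 'artifact_type' tells the registry which mlflow flavor to log with.
registry.save(
    skeys=static_keys,
    dkeys=dynamic_keys,
    artifact=StandardScaler(),
    artifact_type="sklearn",
)

# Load: pass the same artifact_type that was used while saving.
artifact_data = registry.load(skeys=static_keys, dkeys=dynamic_keys, artifact_type="sklearn")
print(artifact_data.artifact)
```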

examples/multi_udf/src/udf/train.py

Lines changed: 9 additions & 4 deletions
@@ -34,10 +34,15 @@ def __init__(self):
         self.model_key = "ae::model"
 
     def _save_artifact(
-        self, model, skeys: list[str], dkeys: list[str], _: Optional[TimeseriesTrainer] = None
+        self,
+        model,
+        skeys: list[str],
+        dkeys: list[str],
+        artifact_type: str,
+        _: Optional[TimeseriesTrainer] = None,
     ) -> None:
         """Saves the model in the registry."""
-        self.registry.save(skeys=skeys, dkeys=dkeys, artifact=model)
+        self.registry.save(skeys=skeys, dkeys=dkeys, artifact=model, artifact_type=artifact_type)
 
     @staticmethod
     def _fit_preprocess(data: pd.DataFrame) -> npt.NDArray[float]:
@@ -93,8 +98,8 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
         thresh_clf = self._fit_threshold(train_reconerr.numpy())
 
         # Save to registry
-        self._save_artifact(model, ["ae"], ["model"], trainer)
-        self._save_artifact(thresh_clf, ["thresh_clf"], ["model"])
+        self._save_artifact(model, ["ae"], ["model"], "pytorch", trainer)
+        self._save_artifact(thresh_clf, ["thresh_clf"], ["model"], artifact_type="sklearn")
         LOGGER.info("%s - Model Saving complete", payload.uuid)
 
         # Train is the last vertex in the graph
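For orientation, the two `_save_artifact` calls above reduce to plain registry saves with an explicit flavor, since the helper just forwards `artifact_type` and ignores the trainer argument. A self-contained sketch with stand-in models (an editorial addition; it assumes an MLflow server at the URI below):

```python
from sklearn.preprocessing import StandardScaler
from torch import nn

from numalogic.registry import MLflowRegistry

registry = MLflowRegistry(tracking_uri="http://0.0.0.0:5000")

model = nn.Linear(10, 10)      # stand-in for the trained autoencoder
thresh_clf = StandardScaler()  # stand-in for the fitted threshold artifact

# Equivalent to the two _save_artifact call sites in the hunk above.
registry.save(skeys=["ae"], dkeys=["model"], artifact=model, artifact_type="pytorch")
registry.save(skeys=["thresh_clf"], dkeys=["model"], artifact=thresh_clf, artifact_type="sklearn")
```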

numalogic/registry/mlflow_registry.py

Lines changed: 14 additions & 15 deletions
@@ -21,9 +21,7 @@
 from mlflow.exceptions import RestException
 from mlflow.protos.databricks_pb2 import ErrorCode, RESOURCE_DOES_NOT_EXIST
 from mlflow.tracking import MlflowClient
-from torch import nn
 
-from numalogic.base import BaseThresholdModel, BaseTransformer
 from numalogic.registry import ArtifactManager, ArtifactData
 from numalogic.registry.artifact import ArtifactCache
 from numalogic.tools.exceptions import ModelVersionError
@@ -65,7 +63,8 @@ class MLflowRegistry(ArtifactManager):
     >>> data = [[0, 0], [0, 0], [1, 1], [1, 1]]
     >>> scaler = StandardScaler.fit(data)
     >>> registry = MLflowRegistry(tracking_uri="http://0.0.0.0:8080")
-    >>> registry.save(skeys=["model"], dkeys=["AE"], artifact=VanillaAE(10))
+    >>> registry.save(skeys=["model"], dkeys=["AE"], artifact=VanillaAE(10),
+    >>> artifact_type="pytorch")
     >>> artifact_data = registry.load(skeys=["model"], dkeys=["AE"], artifact_type="pytorch")
     """
 
@@ -100,17 +99,13 @@ def __init__(
         self.model_stage = model_stage
         self.cache_registry = cache_registry
 
-    @staticmethod
-    def handler_from_obj(artifact: artifact_t):
-        if isinstance(artifact, nn.Module):
-            return mlflow.pytorch
-        if isinstance(artifact, (BaseThresholdModel, BaseTransformer)):
-            return mlflow.sklearn
-        return mlflow.pyfunc
-
     @staticmethod
     def handler_from_type(artifact_type: str):
         """Helper method to return the right handler given the artifact type."""
+        if not artifact_type:
+            raise ValueError(
+                "Artifact Type not provided. Options include: {pytorch, sklearn, pyfunc}"
+            )
         if artifact_type == "pytorch":
             return mlflow.pytorch
         if artifact_type == "sklearn":
@@ -137,9 +132,9 @@ def load(
         self,
         skeys: KEYS,
         dkeys: KEYS,
+        artifact_type: Optional[str] = None,
         latest: bool = True,
         version: Optional[str] = None,
-        artifact_type: str = "pytorch",
     ) -> Optional[ArtifactData]:
         """Load the artifact from the registry. The artifact is loaded from the cache if available.
 
@@ -149,7 +144,8 @@ def load(
             dkeys: Dynamic keys
             latest: Load the latest version of the model (default = True)
             version: Version of the model to load (default = None)
-            artifact_type: Type of the artifact to load (default = "pytorch").
+            artifact_type: Type of the artifact to load. Options include: pytorch, pyfunc
+                and sklearn.
 
         Returns
         -------
@@ -205,6 +201,7 @@ def save(
         dkeys: KEYS,
         artifact: artifact_t,
         run_id: Optional[str] = None,
+        artifact_type: Optional[str] = None,
         **metadata: META_VT,
     ) -> Optional[ModelVersion]:
         """Saves the artifact into mlflow registry and updates version.
@@ -216,13 +213,15 @@ def save(
             artifact: primary artifact to be saved
             run_id: mlflow run id
             metadata: additional metadata surrounding the artifact that needs to be saved.
+            artifact_type: Type of the artifact to save. Options include: pytorch, pyfunc
+                and sklearn.
 
         Returns
         -------
             mlflow ModelVersion instance
         """
         model_key = self.construct_key(skeys, dkeys)
-        handler = self.handler_from_obj(artifact)
+        handler = self.handler_from_type(artifact_type)
         try:
             mlflow.start_run(run_id=run_id)
             handler.log_model(artifact, "model", registered_model_name=model_key)
@@ -241,7 +240,7 @@ save(
     @staticmethod
     def is_artifact_stale(artifact_data: ArtifactData, freq_hr: int) -> bool:
         """Returns whether the given artifact is stale or not, i.e. if
-        more time has elasped since it was last retrained.
+        more time has elapsed since it was last retrained.
 
         Args:
         ----
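To make the registry change concrete: the mlflow handler is now chosen purely from the declared type string, and omitting it is an explicit error rather than a silent fallback to the pyfunc flavor. A small sketch exercising only the `handler_from_type` logic shown above (an editorial addition, not part of the diff):

```python
import mlflow

from numalogic.registry import MLflowRegistry

# The declared type string picks the mlflow flavor module.
assert MLflowRegistry.handler_from_type("pytorch") is mlflow.pytorch
assert MLflowRegistry.handler_from_type("sklearn") is mlflow.sklearn

# No type given -> ValueError, instead of the old isinstance-based dispatch.
try:
    MLflowRegistry.handler_from_type(None)
except ValueError as err:
    print(err)
```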

tests/registry/test_mlflow_registry.py

Lines changed: 50 additions & 17 deletions
@@ -64,7 +64,9 @@ def test_save_model(self):
         ml = MLflowRegistry(TRACKING_URI)
         skeys = self.skeys
         dkeys = self.dkeys
-        status = ml.save(skeys=skeys, dkeys=dkeys, artifact=self.model, run_id="1234")
+        status = ml.save(
+            skeys=skeys, dkeys=dkeys, artifact=self.model, run_id="1234", artifact_type="pytorch"
+        )
         mock_status = "READY"
         self.assertEqual(mock_status, status.status)
 
@@ -79,7 +81,7 @@ def test_save_model_sklearn(self):
         ml = MLflowRegistry(TRACKING_URI)
         skeys = self.skeys
         dkeys = self.dkeys
-        status = ml.save(skeys=skeys, dkeys=dkeys, artifact=model)
+        status = ml.save(skeys=skeys, dkeys=dkeys, artifact=model, artifact_type="sklearn")
         mock_status = "READY"
         self.assertEqual(mock_status, status.status)
 
@@ -96,7 +98,7 @@ def test_load_model_when_pytorch_model_exist1(self):
         ml = MLflowRegistry(TRACKING_URI)
         skeys = self.skeys
         dkeys = self.dkeys
-        ml.save(skeys=skeys, dkeys=dkeys, artifact=model, **{"lr": 0.01})
+        ml.save(skeys=skeys, dkeys=dkeys, artifact=model, **{"lr": 0.01}, artifact_type="pytorch")
         data = ml.load(skeys=skeys, dkeys=dkeys, artifact_type="pytorch")
         self.assertIsNotNone(data.metadata)
         self.assertIsInstance(data.artifact, VanillaAE)
@@ -113,7 +115,7 @@ def test_load_model_when_pytorch_model_exist2(self):
         ml = MLflowRegistry(TRACKING_URI, models_to_retain=2)
         skeys = self.skeys
         dkeys = self.dkeys
-        ml.save(skeys=skeys, dkeys=dkeys, artifact=model)
+        ml.save(skeys=skeys, dkeys=dkeys, artifact=model, artifact_type="pytorch")
         data = ml.load(skeys=skeys, dkeys=dkeys, artifact_type="pytorch")
         self.assertEqual(data.metadata, {})
         self.assertIsInstance(data.artifact, VanillaAE)
@@ -139,8 +141,8 @@ def test_load_model_when_sklearn_model_exist(self):
         skeys = self.skeys
         dkeys = self.dkeys
         scaler = StandardScaler()
-        ml.save(skeys=skeys, dkeys=dkeys, artifact=scaler)
-        data = ml.load(skeys=skeys, dkeys=dkeys)
+        ml.save(skeys=skeys, dkeys=dkeys, artifact=scaler, artifact_type="sklearn")
+        data = ml.load(skeys=skeys, dkeys=dkeys, artifact_type="sklearn")
         print(data)
         self.assertIsInstance(data.artifact, StandardScaler)
         self.assertEqual(data.metadata, {})
@@ -158,8 +160,8 @@ def test_load_model_with_version(self):
         ml = MLflowRegistry(TRACKING_URI)
         skeys = self.skeys
         dkeys = self.dkeys
-        ml.save(skeys=skeys, dkeys=dkeys, artifact=model)
-        data = ml.load(skeys=skeys, dkeys=dkeys, version="5", latest=False)
+        ml.save(skeys=skeys, dkeys=dkeys, artifact=model, artifact_type="pytorch")
+        data = ml.load(skeys=skeys, dkeys=dkeys, version="5", latest=False, artifact_type="pytorch")
         self.assertIsInstance(data.artifact, VanillaAE)
         self.assertEqual(data.metadata, {})
 
@@ -175,7 +177,7 @@ def test_staging_model_load_error(self):
         ml = MLflowRegistry(TRACKING_URI, model_stage=ModelStage.STAGE)
         skeys = self.skeys
         dkeys = self.dkeys
-        ml.load(skeys=skeys, dkeys=dkeys)
+        ml.load(skeys=skeys, dkeys=dkeys, artifact_type="pytorch")
         self.assertRaises(ModelVersionError)
 
     @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
@@ -188,7 +190,7 @@ def test_both_version_latest_model_with_version(self):
         skeys = self.skeys
         dkeys = self.dkeys
         with self.assertRaises(ValueError):
-            ml.load(skeys=skeys, dkeys=dkeys, latest=False)
+            ml.load(skeys=skeys, dkeys=dkeys, latest=False, artifact_type="pytorch")
 
     @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
     @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@@ -211,7 +213,11 @@ def test_load_model_when_no_model_02(self):
         fake_dkeys = ["error"]
         ml = MLflowRegistry(TRACKING_URI)
         with self.assertLogs(level="ERROR") as log:
-            ml.load(skeys=fake_skeys, dkeys=fake_dkeys, artifact_type="pytorch")
+            ml.load(
+                skeys=fake_skeys,
+                dkeys=fake_dkeys,
+                artifact_type="pytorch",
+            )
         self.assertTrue(log.output)
 
     @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@@ -237,6 +243,9 @@ def test_no_implementation(self):
         with self.assertLogs(level="ERROR") as log:
             ml.load(skeys=fake_skeys, dkeys=fake_dkeys, artifact_type="somerandom")
         self.assertTrue(log.output)
+        with self.assertLogs(level="ERROR") as log:
+            ml.load(skeys=fake_skeys, dkeys=fake_dkeys)
+        self.assertTrue(log.output)
 
     @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict())))
     @patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict()))
@@ -252,7 +261,7 @@ def test_delete_model_when_model_exist(self):
         ml = MLflowRegistry(TRACKING_URI)
         skeys = self.skeys
         dkeys = self.dkeys
-        ml.save(skeys=skeys, dkeys=dkeys, artifact=model, **{"lr": 0.01})
+        ml.save(skeys=skeys, dkeys=dkeys, artifact=model, artifact_type="pytorch", **{"lr": 0.01})
         ml.delete(skeys=skeys, dkeys=dkeys, version="5")
         with self.assertLogs(level="ERROR") as log:
             ml.load(skeys=skeys, dkeys=dkeys)
@@ -276,7 +285,9 @@ def test_save_failed(self):
 
         ml = MLflowRegistry(TRACKING_URI)
         with self.assertLogs(level="ERROR") as log:
-            ml.save(skeys=fake_skeys, dkeys=fake_dkeys, artifact=self.model)
+            ml.save(
+                skeys=fake_skeys, dkeys=fake_dkeys, artifact=self.model, artifact_type="pytorch"
+            )
         self.assertTrue(log.output)
 
     @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict())))
@@ -290,7 +301,11 @@ def test_load_no_model_found(self):
         ml = MLflowRegistry(TRACKING_URI)
         skeys = self.skeys
         dkeys = self.dkeys
-        data = ml.load(skeys=skeys, dkeys=dkeys, artifact_type="pytorch")
+        data = ml.load(
+            skeys=skeys,
+            dkeys=dkeys,
+            artifact_type="pytorch",
+        )
         self.assertIsNone(data)
 
     @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict())))
@@ -317,7 +332,13 @@ def test_load_other_mlflow_err(self):
     def test_is_model_stale_true(self):
         model = self.model
         ml = MLflowRegistry(TRACKING_URI)
-        ml.save(skeys=self.skeys, dkeys=self.dkeys, artifact=model, **{"lr": 0.01})
+        ml.save(
+            skeys=self.skeys,
+            dkeys=self.dkeys,
+            artifact=model,
+            **{"lr": 0.01},
+            artifact_type="pytorch",
+        )
         data = ml.load(skeys=self.skeys, dkeys=self.dkeys, artifact_type="pytorch")
         self.assertTrue(ml.is_artifact_stale(data, 12))
 
@@ -332,7 +353,13 @@ def test_is_model_stale_true(self):
     def test_is_model_stale_false(self):
         model = self.model
         ml = MLflowRegistry(TRACKING_URI)
-        ml.save(skeys=self.skeys, dkeys=self.dkeys, artifact=model, **{"lr": 0.01})
+        ml.save(
+            skeys=self.skeys,
+            dkeys=self.dkeys,
+            artifact=model,
+            **{"lr": 0.01},
+            artifact_type="pytorch",
+        )
         data = ml.load(skeys=self.skeys, dkeys=self.dkeys, artifact_type="pytorch")
         with freeze_time("2022-05-24 10:30:00"):
             self.assertFalse(ml.is_artifact_stale(data, 12))
@@ -365,7 +392,13 @@ def test_cache(self):
     def test_cache_loading(self):
         cache_registry = LocalLRUCache(ttl=50000)
         ml = MLflowRegistry(TRACKING_URI, cache_registry=cache_registry)
-        ml.save(skeys=self.skeys, dkeys=self.dkeys, artifact=self.model, **{"lr": 0.01})
+        ml.save(
+            skeys=self.skeys,
+            dkeys=self.dkeys,
+            artifact=self.model,
+            **{"lr": 0.01},
+            artifact_type="pytorch",
+        )
         ml.load(skeys=self.skeys, dkeys=self.dkeys, artifact_type="pytorch")
         key = MLflowRegistry.construct_key(self.skeys, self.dkeys)
         self.assertIsNotNone(ml._load_from_cache(key))
