@@ -64,7 +64,9 @@ def test_save_model(self):
64
64
ml = MLflowRegistry (TRACKING_URI )
65
65
skeys = self .skeys
66
66
dkeys = self .dkeys
67
- status = ml .save (skeys = skeys , dkeys = dkeys , artifact = self .model , run_id = "1234" )
67
+ status = ml .save (
68
+ skeys = skeys , dkeys = dkeys , artifact = self .model , run_id = "1234" , artifact_type = "pytorch"
69
+ )
68
70
mock_status = "READY"
69
71
self .assertEqual (mock_status , status .status )
70
72
@@ -79,7 +81,7 @@ def test_save_model_sklearn(self):
79
81
ml = MLflowRegistry (TRACKING_URI )
80
82
skeys = self .skeys
81
83
dkeys = self .dkeys
82
- status = ml .save (skeys = skeys , dkeys = dkeys , artifact = model )
84
+ status = ml .save (skeys = skeys , dkeys = dkeys , artifact = model , artifact_type = "sklearn" )
83
85
mock_status = "READY"
84
86
self .assertEqual (mock_status , status .status )
85
87
@@ -96,7 +98,7 @@ def test_load_model_when_pytorch_model_exist1(self):
96
98
ml = MLflowRegistry (TRACKING_URI )
97
99
skeys = self .skeys
98
100
dkeys = self .dkeys
99
- ml .save (skeys = skeys , dkeys = dkeys , artifact = model , ** {"lr" : 0.01 })
101
+ ml .save (skeys = skeys , dkeys = dkeys , artifact = model , ** {"lr" : 0.01 }, artifact_type = "pytorch" )
100
102
data = ml .load (skeys = skeys , dkeys = dkeys , artifact_type = "pytorch" )
101
103
self .assertIsNotNone (data .metadata )
102
104
self .assertIsInstance (data .artifact , VanillaAE )
@@ -113,7 +115,7 @@ def test_load_model_when_pytorch_model_exist2(self):
113
115
ml = MLflowRegistry (TRACKING_URI , models_to_retain = 2 )
114
116
skeys = self .skeys
115
117
dkeys = self .dkeys
116
- ml .save (skeys = skeys , dkeys = dkeys , artifact = model )
118
+ ml .save (skeys = skeys , dkeys = dkeys , artifact = model , artifact_type = "pytorch" )
117
119
data = ml .load (skeys = skeys , dkeys = dkeys , artifact_type = "pytorch" )
118
120
self .assertEqual (data .metadata , {})
119
121
self .assertIsInstance (data .artifact , VanillaAE )
@@ -139,8 +141,8 @@ def test_load_model_when_sklearn_model_exist(self):
139
141
skeys = self .skeys
140
142
dkeys = self .dkeys
141
143
scaler = StandardScaler ()
142
- ml .save (skeys = skeys , dkeys = dkeys , artifact = scaler )
143
- data = ml .load (skeys = skeys , dkeys = dkeys )
144
+ ml .save (skeys = skeys , dkeys = dkeys , artifact = scaler , artifact_type = "sklearn" )
145
+ data = ml .load (skeys = skeys , dkeys = dkeys , artifact_type = "sklearn" )
144
146
print (data )
145
147
self .assertIsInstance (data .artifact , StandardScaler )
146
148
self .assertEqual (data .metadata , {})
@@ -158,8 +160,8 @@ def test_load_model_with_version(self):
158
160
ml = MLflowRegistry (TRACKING_URI )
159
161
skeys = self .skeys
160
162
dkeys = self .dkeys
161
- ml .save (skeys = skeys , dkeys = dkeys , artifact = model )
162
- data = ml .load (skeys = skeys , dkeys = dkeys , version = "5" , latest = False )
163
+ ml .save (skeys = skeys , dkeys = dkeys , artifact = model , artifact_type = "pytorch" )
164
+ data = ml .load (skeys = skeys , dkeys = dkeys , version = "5" , latest = False , artifact_type = "pytorch" )
163
165
self .assertIsInstance (data .artifact , VanillaAE )
164
166
self .assertEqual (data .metadata , {})
165
167
@@ -175,7 +177,7 @@ def test_staging_model_load_error(self):
175
177
ml = MLflowRegistry (TRACKING_URI , model_stage = ModelStage .STAGE )
176
178
skeys = self .skeys
177
179
dkeys = self .dkeys
178
- ml .load (skeys = skeys , dkeys = dkeys )
180
+ ml .load (skeys = skeys , dkeys = dkeys , artifact_type = "pytorch" )
179
181
self .assertRaises (ModelVersionError )
180
182
181
183
@patch ("mlflow.tracking.MlflowClient.search_model_versions" , mock_list_of_model_version2 )
@@ -188,7 +190,7 @@ def test_both_version_latest_model_with_version(self):
188
190
skeys = self .skeys
189
191
dkeys = self .dkeys
190
192
with self .assertRaises (ValueError ):
191
- ml .load (skeys = skeys , dkeys = dkeys , latest = False )
193
+ ml .load (skeys = skeys , dkeys = dkeys , latest = False , artifact_type = "pytorch" )
192
194
193
195
@patch ("mlflow.tracking.MlflowClient.search_model_versions" , mock_list_of_model_version2 )
194
196
@patch ("mlflow.tracking.MlflowClient.transition_model_version_stage" , mock_transition_stage )
@@ -211,7 +213,11 @@ def test_load_model_when_no_model_02(self):
211
213
fake_dkeys = ["error" ]
212
214
ml = MLflowRegistry (TRACKING_URI )
213
215
with self .assertLogs (level = "ERROR" ) as log :
214
- ml .load (skeys = fake_skeys , dkeys = fake_dkeys , artifact_type = "pytorch" )
216
+ ml .load (
217
+ skeys = fake_skeys ,
218
+ dkeys = fake_dkeys ,
219
+ artifact_type = "pytorch" ,
220
+ )
215
221
self .assertTrue (log .output )
216
222
217
223
@patch ("mlflow.tracking.MlflowClient.get_latest_versions" , mock_get_model_version )
@@ -237,6 +243,9 @@ def test_no_implementation(self):
237
243
with self .assertLogs (level = "ERROR" ) as log :
238
244
ml .load (skeys = fake_skeys , dkeys = fake_dkeys , artifact_type = "somerandom" )
239
245
self .assertTrue (log .output )
246
+ with self .assertLogs (level = "ERROR" ) as log :
247
+ ml .load (skeys = fake_skeys , dkeys = fake_dkeys )
248
+ self .assertTrue (log .output )
240
249
241
250
@patch ("mlflow.start_run" , Mock (return_value = ActiveRun (return_pytorch_rundata_dict ())))
242
251
@patch ("mlflow.active_run" , Mock (return_value = return_pytorch_rundata_dict ()))
@@ -252,7 +261,7 @@ def test_delete_model_when_model_exist(self):
252
261
ml = MLflowRegistry (TRACKING_URI )
253
262
skeys = self .skeys
254
263
dkeys = self .dkeys
255
- ml .save (skeys = skeys , dkeys = dkeys , artifact = model , ** {"lr" : 0.01 })
264
+ ml .save (skeys = skeys , dkeys = dkeys , artifact = model , artifact_type = "pytorch" , ** {"lr" : 0.01 })
256
265
ml .delete (skeys = skeys , dkeys = dkeys , version = "5" )
257
266
with self .assertLogs (level = "ERROR" ) as log :
258
267
ml .load (skeys = skeys , dkeys = dkeys )
@@ -276,7 +285,9 @@ def test_save_failed(self):
276
285
277
286
ml = MLflowRegistry (TRACKING_URI )
278
287
with self .assertLogs (level = "ERROR" ) as log :
279
- ml .save (skeys = fake_skeys , dkeys = fake_dkeys , artifact = self .model )
288
+ ml .save (
289
+ skeys = fake_skeys , dkeys = fake_dkeys , artifact = self .model , artifact_type = "pytorch"
290
+ )
280
291
self .assertTrue (log .output )
281
292
282
293
@patch ("mlflow.start_run" , Mock (return_value = ActiveRun (return_pytorch_rundata_dict ())))
@@ -290,7 +301,11 @@ def test_load_no_model_found(self):
290
301
ml = MLflowRegistry (TRACKING_URI )
291
302
skeys = self .skeys
292
303
dkeys = self .dkeys
293
- data = ml .load (skeys = skeys , dkeys = dkeys , artifact_type = "pytorch" )
304
+ data = ml .load (
305
+ skeys = skeys ,
306
+ dkeys = dkeys ,
307
+ artifact_type = "pytorch" ,
308
+ )
294
309
self .assertIsNone (data )
295
310
296
311
@patch ("mlflow.start_run" , Mock (return_value = ActiveRun (return_pytorch_rundata_dict ())))
@@ -317,7 +332,13 @@ def test_load_other_mlflow_err(self):
317
332
def test_is_model_stale_true (self ):
318
333
model = self .model
319
334
ml = MLflowRegistry (TRACKING_URI )
320
- ml .save (skeys = self .skeys , dkeys = self .dkeys , artifact = model , ** {"lr" : 0.01 })
335
+ ml .save (
336
+ skeys = self .skeys ,
337
+ dkeys = self .dkeys ,
338
+ artifact = model ,
339
+ ** {"lr" : 0.01 },
340
+ artifact_type = "pytorch" ,
341
+ )
321
342
data = ml .load (skeys = self .skeys , dkeys = self .dkeys , artifact_type = "pytorch" )
322
343
self .assertTrue (ml .is_artifact_stale (data , 12 ))
323
344
@@ -332,7 +353,13 @@ def test_is_model_stale_true(self):
332
353
def test_is_model_stale_false (self ):
333
354
model = self .model
334
355
ml = MLflowRegistry (TRACKING_URI )
335
- ml .save (skeys = self .skeys , dkeys = self .dkeys , artifact = model , ** {"lr" : 0.01 })
356
+ ml .save (
357
+ skeys = self .skeys ,
358
+ dkeys = self .dkeys ,
359
+ artifact = model ,
360
+ ** {"lr" : 0.01 },
361
+ artifact_type = "pytorch" ,
362
+ )
336
363
data = ml .load (skeys = self .skeys , dkeys = self .dkeys , artifact_type = "pytorch" )
337
364
with freeze_time ("2022-05-24 10:30:00" ):
338
365
self .assertFalse (ml .is_artifact_stale (data , 12 ))
@@ -365,7 +392,13 @@ def test_cache(self):
365
392
def test_cache_loading (self ):
366
393
cache_registry = LocalLRUCache (ttl = 50000 )
367
394
ml = MLflowRegistry (TRACKING_URI , cache_registry = cache_registry )
368
- ml .save (skeys = self .skeys , dkeys = self .dkeys , artifact = self .model , ** {"lr" : 0.01 })
395
+ ml .save (
396
+ skeys = self .skeys ,
397
+ dkeys = self .dkeys ,
398
+ artifact = self .model ,
399
+ ** {"lr" : 0.01 },
400
+ artifact_type = "pytorch" ,
401
+ )
369
402
ml .load (skeys = self .skeys , dkeys = self .dkeys , artifact_type = "pytorch" )
370
403
key = MLflowRegistry .construct_key (self .skeys , self .dkeys )
371
404
self .assertIsNotNone (ml ._load_from_cache (key ))
0 commit comments