@@ -428,31 +428,46 @@ def test_set_tracking_uri(mlflow_mock):
428
428
_ = logger .experiment
429
429
mlflow_mock .set_tracking_uri .assert_called_with ("the_tracking_uri" )
430
430
431
+
431
432
def test_mlflowlogger_metric_deduplication (monkeypatch ):
432
433
import types
434
+
433
435
from lightning .pytorch .loggers .mlflow import MLFlowLogger
434
436
435
437
# Dummy MLflow client to record log_batch calls
436
438
logged_metrics = []
439
+
437
440
class DummyMlflowClient :
438
441
def log_batch (self , run_id , metrics , ** kwargs ):
439
442
logged_metrics .extend (metrics )
440
- def set_tracking_uri (self , uri ): pass
441
- def create_run (self , experiment_id , tags ):
442
- class Run : info = types .SimpleNamespace (run_id = "dummy_run_id" )
443
+
444
+ def set_tracking_uri (self , uri ):
445
+ pass
446
+
447
+ def create_run (self , experiment_id , tags ):
448
+ class Run :
449
+ info = types .SimpleNamespace (run_id = "dummy_run_id" )
450
+
443
451
return Run ()
452
+
444
453
def get_run (self , run_id ):
445
- class Run : info = types .SimpleNamespace (experiment_id = "dummy_experiment_id" )
454
+ class Run :
455
+ info = types .SimpleNamespace (experiment_id = "dummy_experiment_id" )
456
+
446
457
return Run ()
447
- def get_experiment_by_name (self , name ): return None
448
- def create_experiment (self , name , artifact_location = None ): return "dummy_experiment_id"
458
+
459
+ def get_experiment_by_name (self , name ):
460
+ return None
461
+
462
+ def create_experiment (self , name , artifact_location = None ):
463
+ return "dummy_experiment_id"
449
464
450
465
# Patch the MLFlowLogger to use DummyMlflowClient
451
466
monkeypatch .setattr ("mlflow.tracking.MlflowClient" , lambda * a , ** k : DummyMlflowClient ())
452
467
453
468
logger = MLFlowLogger (experiment_name = "test_exp" )
454
- logger .log_metrics ({' foo' : 1.0 }, step = 5 )
455
- logger .log_metrics ({' foo' : 1.0 }, step = 5 ) # duplicate
469
+ logger .log_metrics ({" foo" : 1.0 }, step = 5 )
470
+ logger .log_metrics ({" foo" : 1.0 }, step = 5 ) # duplicate
456
471
457
472
# Only the first metric should be logged
458
473
assert len (logged_metrics ) == 1
0 commit comments