Skip to content

Commit 49f8f4c

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent fd4bafa commit 49f8f4c

File tree

1 file changed

+23
-8
lines changed

1 file changed

+23
-8
lines changed

tests/tests_pytorch/loggers/test_mlflow.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -428,31 +428,46 @@ def test_set_tracking_uri(mlflow_mock):
428428
_ = logger.experiment
429429
mlflow_mock.set_tracking_uri.assert_called_with("the_tracking_uri")
430430

431+
431432
def test_mlflowlogger_metric_deduplication(monkeypatch):
432433
import types
434+
433435
from lightning.pytorch.loggers.mlflow import MLFlowLogger
434436

435437
# Dummy MLflow client to record log_batch calls
436438
logged_metrics = []
439+
437440
class DummyMlflowClient:
438441
def log_batch(self, run_id, metrics, **kwargs):
439442
logged_metrics.extend(metrics)
440-
def set_tracking_uri(self, uri): pass
441-
def create_run(self, experiment_id, tags):
442-
class Run: info = types.SimpleNamespace(run_id="dummy_run_id")
443+
444+
def set_tracking_uri(self, uri):
445+
pass
446+
447+
def create_run(self, experiment_id, tags):
448+
class Run:
449+
info = types.SimpleNamespace(run_id="dummy_run_id")
450+
443451
return Run()
452+
444453
def get_run(self, run_id):
445-
class Run: info = types.SimpleNamespace(experiment_id="dummy_experiment_id")
454+
class Run:
455+
info = types.SimpleNamespace(experiment_id="dummy_experiment_id")
456+
446457
return Run()
447-
def get_experiment_by_name(self, name): return None
448-
def create_experiment(self, name, artifact_location=None): return "dummy_experiment_id"
458+
459+
def get_experiment_by_name(self, name):
460+
return None
461+
462+
def create_experiment(self, name, artifact_location=None):
463+
return "dummy_experiment_id"
449464

450465
# Patch the MLFlowLogger to use DummyMlflowClient
451466
monkeypatch.setattr("mlflow.tracking.MlflowClient", lambda *a, **k: DummyMlflowClient())
452467

453468
logger = MLFlowLogger(experiment_name="test_exp")
454-
logger.log_metrics({'foo': 1.0}, step=5)
455-
logger.log_metrics({'foo': 1.0}, step=5) # duplicate
469+
logger.log_metrics({"foo": 1.0}, step=5)
470+
logger.log_metrics({"foo": 1.0}, step=5) # duplicate
456471

457472
# Only the first metric should be logged
458473
assert len(logged_metrics) == 1

0 commit comments

Comments
 (0)