diff --git a/requirements/pytorch/test.txt b/requirements/pytorch/test.txt index fd4237ef74e66..e59a404849056 100644 --- a/requirements/pytorch/test.txt +++ b/requirements/pytorch/test.txt @@ -17,3 +17,4 @@ fastapi # for `ServableModuleValidator` # not setting version as re-defined in uvicorn # for `ServableModuleValidator` # not setting version as re-defined in App tensorboard >=2.9.1, <2.20.0 # for `TensorBoardLogger` +mlflow # for `MLFlowLogger` tests; not pinning version here diff --git a/src/lightning/pytorch/loggers/mlflow.py b/src/lightning/pytorch/loggers/mlflow.py index ff9b2b0d7e542..b08f1bd999040 100644 --- a/src/lightning/pytorch/loggers/mlflow.py +++ b/src/lightning/pytorch/loggers/mlflow.py @@ -109,6 +109,12 @@ def any_lightning_module_function_or_hook(self): ModuleNotFoundError: If required MLFlow package is not installed on the device. + Note: + MLFlowLogger skips logging any metric (same name and step) + more than once per run, to prevent database unique constraint violations on + some MLflow backends (such as PostgreSQL). Only the first value for each (metric, step) + pair will be logged per run. This improves robustness for all users. 
+ """ LOGGER_JOIN_CHAR = "-" @@ -151,6 +157,7 @@ def __init__( from mlflow.tracking import MlflowClient self._mlflow_client = MlflowClient(tracking_uri) + self._logged_metrics = set() # Track (key, step) @property @rank_zero_experiment @@ -201,6 +208,7 @@ def experiment(self) -> "MlflowClient": resolve_tags = _get_resolve_tags() run = self._mlflow_client.create_run(experiment_id=self._experiment_id, tags=resolve_tags(self.tags)) self._run_id = run.info.run_id + self._logged_metrics.clear() self._initialized = True return self._mlflow_client @@ -266,6 +274,12 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) category=RuntimeWarning, ) k = new_k + + metric_id = (k, step or 0) + if metric_id in self._logged_metrics: + continue + self._logged_metrics.add(metric_id) + metrics_list.append(Metric(key=k, value=v, timestamp=timestamp_ms, step=step or 0)) self.experiment.log_batch(run_id=self.run_id, metrics=metrics_list, **self._log_batch_kwargs) diff --git a/tests/tests_pytorch/loggers/test_mlflow.py b/tests/tests_pytorch/loggers/test_mlflow.py index c7f9dbe1fe2c6..a63a617ec4486 100644 --- a/tests/tests_pytorch/loggers/test_mlflow.py +++ b/tests/tests_pytorch/loggers/test_mlflow.py @@ -427,3 +427,50 @@ def test_set_tracking_uri(mlflow_mock): mlflow_mock.set_tracking_uri.assert_not_called() _ = logger.experiment mlflow_mock.set_tracking_uri.assert_called_with("the_tracking_uri") + + +def test_mlflowlogger_metric_deduplication(monkeypatch): + import types + + from lightning.pytorch.loggers.mlflow import MLFlowLogger + + # Dummy MLflow client to record log_batch calls + logged_metrics = [] + + class DummyMlflowClient: + def log_batch(self, run_id, metrics, **kwargs): + logged_metrics.extend(metrics) + + def set_tracking_uri(self, uri): + pass + + def create_run(self, experiment_id, tags): + class Run: + info = types.SimpleNamespace(run_id="dummy_run_id") + + return Run() + + def get_run(self, run_id): + class Run: + info = 
types.SimpleNamespace(experiment_id="dummy_experiment_id") + + return Run() + + def get_experiment_by_name(self, name): + return None + + def create_experiment(self, name, artifact_location=None): + return "dummy_experiment_id" + + # Patch the MLFlowLogger to use DummyMlflowClient + monkeypatch.setattr("mlflow.tracking.MlflowClient", lambda *a, **k: DummyMlflowClient()) + + logger = MLFlowLogger(experiment_name="test_exp") + logger.log_metrics({"foo": 1.0}, step=5) + logger.log_metrics({"foo": 1.0}, step=5) # duplicate + + # Only the first metric should be logged + assert len(logged_metrics) == 1 + assert logged_metrics[0].key == "foo" + assert logged_metrics[0].value == 1.0 + assert logged_metrics[0].step == 5