From dc4914088de86f5ecb4deee8a89912d2deb96db9 Mon Sep 17 00:00:00 2001 From: Kavyansh Tyagi <142140238+KAVYANSHTYAGI@users.noreply.github.com> Date: Mon, 2 Jun 2025 18:53:05 +0530 Subject: [PATCH 1/5] Update mlflow.py --- src/lightning/pytorch/loggers/mlflow.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/src/lightning/pytorch/loggers/mlflow.py b/src/lightning/pytorch/loggers/mlflow.py index ff9b2b0d7e542..eca624223fff9 100644 --- a/src/lightning/pytorch/loggers/mlflow.py +++ b/src/lightning/pytorch/loggers/mlflow.py @@ -109,6 +109,12 @@ def any_lightning_module_function_or_hook(self): ModuleNotFoundError: If required MLFlow package is not installed on the device. + Note: + As of vX.XX, MLFlowLogger will skip logging any metric (same name and step) + more than once per run, to prevent database unique constraint violations on + some MLflow backends (such as PostgreSQL). Only the first value for each (metric, step) + pair will be logged per run. This improves robustness for all users. + """ LOGGER_JOIN_CHAR = "-" @@ -126,6 +132,7 @@ def __init__( run_id: Optional[str] = None, synchronous: Optional[bool] = None, ): + if not _MLFLOW_AVAILABLE: raise ModuleNotFoundError(str(_MLFLOW_AVAILABLE)) if synchronous is not None and not _MLFLOW_SYNCHRONOUS_AVAILABLE: @@ -151,6 +158,7 @@ def __init__( from mlflow.tracking import MlflowClient self._mlflow_client = MlflowClient(tracking_uri) + self._logged_metrics = set() # Track (key, step) @property @rank_zero_experiment @@ -201,6 +209,7 @@ def experiment(self) -> "MlflowClient": resolve_tags = _get_resolve_tags() run = self._mlflow_client.create_run(experiment_id=self._experiment_id, tags=resolve_tags(self.tags)) self._run_id = run.info.run_id + self._logged_metrics.clear() self._initialized = True return self._mlflow_client @@ -257,7 +266,7 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) if isinstance(v, str): log.warning(f"Discarding metric with string value {k}={v}.") continue - + new_k = re.sub("[^a-zA-Z0-9_/. -]+", "", k) if k != new_k: rank_zero_warn( @@ -266,8 +275,15 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) category=RuntimeWarning, ) k = new_k + + metric_id = (k, step or 0) + if metric_id in self._logged_metrics: + continue + self._logged_metrics.add(metric_id) + metrics_list.append(Metric(key=k, value=v, timestamp=timestamp_ms, step=step or 0)) + self.experiment.log_batch(run_id=self.run_id, metrics=metrics_list, **self._log_batch_kwargs) @override From 7d786af6f1829c4ffff472cc885af9ad1e778780 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 2 Jun 2025 13:30:39 +0000 Subject: [PATCH 2/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/pytorch/loggers/mlflow.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/lightning/pytorch/loggers/mlflow.py b/src/lightning/pytorch/loggers/mlflow.py index eca624223fff9..b08f1bd999040 100644 --- a/src/lightning/pytorch/loggers/mlflow.py +++ b/src/lightning/pytorch/loggers/mlflow.py @@ -132,7 +132,6 @@ def __init__( run_id: Optional[str] = None, synchronous: Optional[bool] = None, ): - if not _MLFLOW_AVAILABLE: raise ModuleNotFoundError(str(_MLFLOW_AVAILABLE)) if synchronous is not None and not _MLFLOW_SYNCHRONOUS_AVAILABLE: @@ -209,7 +208,7 @@ def experiment(self) -> "MlflowClient": resolve_tags = _get_resolve_tags() run = self._mlflow_client.create_run(experiment_id=self._experiment_id, tags=resolve_tags(self.tags)) self._run_id = run.info.run_id - self._logged_metrics.clear() + self._logged_metrics.clear() self._initialized = True return self._mlflow_client @@ -266,7 +265,7 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) if isinstance(v, str): log.warning(f"Discarding metric with string value {k}={v}.") continue - + new_k = re.sub("[^a-zA-Z0-9_/. -]+", "", k) if k != new_k: rank_zero_warn( @@ -275,14 +274,13 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) category=RuntimeWarning, ) k = new_k - + metric_id = (k, step or 0) if metric_id in self._logged_metrics: - continue + continue self._logged_metrics.add(metric_id) - - metrics_list.append(Metric(key=k, value=v, timestamp=timestamp_ms, step=step or 0)) + metrics_list.append(Metric(key=k, value=v, timestamp=timestamp_ms, step=step or 0)) self.experiment.log_batch(run_id=self.run_id, metrics=metrics_list, **self._log_batch_kwargs) From fd4bafa6c8d3440acda4ff1366e82142fcc60c5e Mon Sep 17 00:00:00 2001 From: Kavyansh Tyagi <142140238+KAVYANSHTYAGI@users.noreply.github.com> Date: Mon, 2 Jun 2025 19:02:06 +0530 Subject: [PATCH 3/5] Update test_mlflow.py --- tests/tests_pytorch/loggers/test_mlflow.py | 32 ++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/tests_pytorch/loggers/test_mlflow.py b/tests/tests_pytorch/loggers/test_mlflow.py index c7f9dbe1fe2c6..d99449eb9f893 100644 --- a/tests/tests_pytorch/loggers/test_mlflow.py +++ b/tests/tests_pytorch/loggers/test_mlflow.py @@ -427,3 +427,35 @@ def test_set_tracking_uri(mlflow_mock): mlflow_mock.set_tracking_uri.assert_not_called() _ = logger.experiment mlflow_mock.set_tracking_uri.assert_called_with("the_tracking_uri") + +def test_mlflowlogger_metric_deduplication(monkeypatch): + import types + from lightning.pytorch.loggers.mlflow import MLFlowLogger + + # Dummy MLflow client to record log_batch calls + logged_metrics = [] + class DummyMlflowClient: + def log_batch(self, run_id, metrics, **kwargs): + logged_metrics.extend(metrics) + def set_tracking_uri(self, uri): pass + def create_run(self, experiment_id, tags): + class Run: info = types.SimpleNamespace(run_id="dummy_run_id") + return Run() + def get_run(self, run_id): + class Run: info = types.SimpleNamespace(experiment_id="dummy_experiment_id") + return Run() + def get_experiment_by_name(self, name): return None + def create_experiment(self, name, artifact_location=None): return "dummy_experiment_id" + + # Patch the MLFlowLogger to use DummyMlflowClient + monkeypatch.setattr("mlflow.tracking.MlflowClient", lambda *a, **k: DummyMlflowClient()) + + logger = MLFlowLogger(experiment_name="test_exp") + logger.log_metrics({'foo': 1.0}, step=5) + logger.log_metrics({'foo': 1.0}, step=5) # duplicate + + # Only the first metric should be logged + assert len(logged_metrics) == 1 + assert logged_metrics[0].key == "foo" + assert logged_metrics[0].value == 1.0 + assert logged_metrics[0].step == 5 From 49f8f4ccff18f2dcc112b9baa746eddf5dd363d0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 2 Jun 2025 13:32:26 +0000 Subject: [PATCH 4/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/tests_pytorch/loggers/test_mlflow.py | 31 ++++++++++++++++------ 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/tests/tests_pytorch/loggers/test_mlflow.py b/tests/tests_pytorch/loggers/test_mlflow.py index d99449eb9f893..a63a617ec4486 100644 --- a/tests/tests_pytorch/loggers/test_mlflow.py +++ b/tests/tests_pytorch/loggers/test_mlflow.py @@ -428,31 +428,46 @@ def test_set_tracking_uri(mlflow_mock): _ = logger.experiment mlflow_mock.set_tracking_uri.assert_called_with("the_tracking_uri") + def test_mlflowlogger_metric_deduplication(monkeypatch): import types + from lightning.pytorch.loggers.mlflow import MLFlowLogger # Dummy MLflow client to record log_batch calls logged_metrics = [] + class DummyMlflowClient: def log_batch(self, run_id, metrics, **kwargs): logged_metrics.extend(metrics) - def set_tracking_uri(self, uri): pass - def create_run(self, experiment_id, tags): - class Run: info = types.SimpleNamespace(run_id="dummy_run_id") + + def set_tracking_uri(self, uri): + pass + + def create_run(self, experiment_id, tags): + class Run: + info = types.SimpleNamespace(run_id="dummy_run_id") + return Run() + def get_run(self, run_id): - class Run: info = types.SimpleNamespace(experiment_id="dummy_experiment_id") + class Run: + info = types.SimpleNamespace(experiment_id="dummy_experiment_id") + return Run() - def get_experiment_by_name(self, name): return None - def create_experiment(self, name, artifact_location=None): return "dummy_experiment_id" + + def get_experiment_by_name(self, name): + return None + + def create_experiment(self, name, artifact_location=None): + return "dummy_experiment_id" # Patch the MLFlowLogger to use DummyMlflowClient monkeypatch.setattr("mlflow.tracking.MlflowClient", lambda *a, **k: DummyMlflowClient()) logger = MLFlowLogger(experiment_name="test_exp") - logger.log_metrics({'foo': 1.0}, step=5) - logger.log_metrics({'foo': 1.0}, step=5) # duplicate + logger.log_metrics({"foo": 1.0}, step=5) + logger.log_metrics({"foo": 1.0}, step=5) # duplicate # Only the first metric should be logged assert len(logged_metrics) == 1 From 609fcbb41d40e3793b07303cb4219992638314a3 Mon Sep 17 00:00:00 2001 From: Rittik Panda Date: Wed, 11 Jun 2025 22:29:13 +0530 Subject: [PATCH 5/5] Update test.txt --- requirements/pytorch/test.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements/pytorch/test.txt b/requirements/pytorch/test.txt index fd4237ef74e66..e59a404849056 100644 --- a/requirements/pytorch/test.txt +++ b/requirements/pytorch/test.txt @@ -17,3 +17,4 @@ fastapi # for `ServableModuleValidator` # not setting version as re-defined in uvicorn # for `ServableModuleValidator` # not setting version as re-defined in App tensorboard >=2.9.1, <2.20.0 # for `TensorBoardLogger` +mlflow