Skip to content

Add memory usage monitor callback #21245

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 20 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions keras/api/_tf_keras/keras/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
from keras.src.callbacks.learning_rate_scheduler import (
LearningRateScheduler as LearningRateScheduler,
)
from keras.src.callbacks.memory_usage_callback import (
MemoryUsageCallback as MemoryUsageCallback,
)
from keras.src.callbacks.model_checkpoint import (
ModelCheckpoint as ModelCheckpoint,
)
Expand Down
3 changes: 3 additions & 0 deletions keras/api/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
from keras.src.callbacks.learning_rate_scheduler import (
LearningRateScheduler as LearningRateScheduler,
)
from keras.src.callbacks.memory_usage_callback import (
MemoryUsageCallback as MemoryUsageCallback,
)
from keras.src.callbacks.model_checkpoint import (
ModelCheckpoint as ModelCheckpoint,
)
Expand Down
1 change: 1 addition & 0 deletions keras/src/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from keras.src.callbacks.history import History
from keras.src.callbacks.lambda_callback import LambdaCallback
from keras.src.callbacks.learning_rate_scheduler import LearningRateScheduler
from keras.src.callbacks.memory_usage_callback import MemoryUsageCallback
from keras.src.callbacks.model_checkpoint import ModelCheckpoint
from keras.src.callbacks.monitor_callback import MonitorCallback
from keras.src.callbacks.progbar_logger import ProgbarLogger
Expand Down
192 changes: 192 additions & 0 deletions keras/src/callbacks/memory_usage_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
import os
import warnings

from keras.src import backend as K
from keras.src.api_export import keras_export
from keras.src.callbacks.callback import Callback

# Attempt to import psutil for CPU memory monitoring
try:
import psutil
except ImportError:
psutil = None


@keras_export("keras.callbacks.MemoryUsageCallback")
class MemoryUsageCallback(Callback):
"""Monitors and logs memory usage
(CPU + optional GPU/TPU/OpenVINO) during training.

This callback measures:

- **CPU**: via psutil.Process().memory_info().rss
- **GPU/TPU**: via backend‐specific APIs
(TensorFlow, PyTorch, JAX)

Logs are printed to stdout at the start/end of each epoch and,
if `log_every_batch=True`, after every batch.
If `tensorboard_log_dir` is provided, scalars are also written
via `tf.summary` (TensorBoard).

Args:
monitor_gpu (bool): If True, attempt to measure accelerator memory.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add automatic detection instead of this arg?
if running_on_gpu:
..
if running_on_tpu:
...

logic to add these detections

def running_on_tpu():
    backend = keras.config.backend()
    if backend == "jax":
        import jax

        devices = jax.devices()
        return any(d.platform == "tpu" for d in devices)
    elif backend == "tensorflow":
        import tensorflow as tf

        return bool(tf.config.list_logical_devices("TPU"))
    elif backend == "torch":
        return False


def running_on_gpu():
    backend = keras.config.backend()
    if backend == "jax":
        import jax

        devices = jax.devices()
        return any(d.platform == "gpu" for d in devices)
    elif backend == "tensorflow":
        import tensorflow as tf

        return bool(tf.config.list_logical_devices("GPU"))
    elif backend == "torch":
        import torch

        return torch.cuda.is_available()

log_every_batch (bool): If True, also log after each batch.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is the default behavior?
Log at the end of each epoch? - You would document default behavior in docstring

tensorboard_log_dir (str|None): Directory for TensorBoard logs;
if None, no TF summary writer is created.

Raises:
ImportError: If `psutil` is not installed (required for CPU logging).
"""

def __init__(
self,
monitor_gpu=True,
log_every_batch=False,
tensorboard_log_dir=None,
):
super().__init__()
if psutil is None:
raise ImportError(
"MemoryUsageCallback requires the 'psutil' library. "
"Install via `pip install psutil`."
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT : "To install please use pip install psutil"

)
self.monitor_gpu = monitor_gpu
self.log_every_batch = log_every_batch
self._proc = psutil.Process()
self._step_counter = 0
self._writer = None

if tensorboard_log_dir:
try:
import tensorflow as tf

logdir = os.path.expanduser(tensorboard_log_dir)
self._writer = tf.summary.create_file_writer(logdir)
print(f"MemoryUsageCallback: TensorBoard logs → {logdir}")
except Exception as e:
warnings.warn(
f"Could not initialize TensorBoard writer: {e}",
RuntimeWarning,
)
self._writer = None

def on_train_begin(self, logs=None):
self._step_counter = 0

def on_epoch_begin(self, epoch, logs=None):
self._log_epoch("start", epoch)

def on_epoch_end(self, epoch, logs=None):
self._log_epoch("end", epoch, offset=1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from the colab output I am observing that Epoch end is not logged when log_every_batch is False


def on_batch_end(self, batch, logs=None):
if self.log_every_batch:
self._log_step(
f"Batch {self._step_counter} end", self._step_counter
)
self._step_counter += 1

def on_train_end(self, logs=None):
if self._writer:
self._writer.close()

def _log_epoch(self, when, epoch, offset=0):
label = f"Epoch {epoch} {when}"
step = epoch + offset
self._log_step(label, step)

def _log_step(self, label, step):
cpu_mb = self._get_cpu_memory()
gpu_mb = self._get_gpu_memory() if self.monitor_gpu else None

msg = f"{label} - CPU Memory: {cpu_mb:.2f} MB"
if gpu_mb is not None:
msg += f"; GPU Memory: {gpu_mb:.2f} MB"
print(msg)

if self._writer:
import tensorflow as tf # noqa: E501

with self._writer.as_default(step=int(step)):
tf.summary.scalar("Memory/CPU_MB", cpu_mb)
if gpu_mb is not None:
tf.summary.scalar("Memory/GPU_MB", gpu_mb)
# flush happens inside writer

def _get_cpu_memory(self):
return self._proc.memory_info().rss / (1024**2)

def _get_gpu_memory(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

another function to get tpu memory would be needed as well

backend_name = K.backend()
try:
if backend_name == "tensorflow":
import tensorflow as tf

gpus = tf.config.list_physical_devices("GPU")
if not gpus:
return None
total = sum(
tf.config.experimental.get_memory_info(g.name)["current"]
for g in gpus
)
return total / (1024**2)

elif backend_name == "torch":
import torch

if not torch.cuda.is_available():
return None
total = sum(
torch.cuda.memory_allocated(i)
for i in range(torch.cuda.device_count())
)
return total / (1024**2)

elif backend_name == "jax":
import jax

devs = [d for d in jax.devices() if d.platform.upper() == "GPU"]
if not devs:
return None
total = 0
for d in devs:
stats = getattr(d, "memory_stats", lambda: {})()
total += stats.get("bytes_in_use", 0)
return total / (1024**2)

elif backend_name == "openvino":
# OpenVINO provides no memory-stats API:
if not hasattr(self, "_warn_openvino"):
warnings.warn(
" OpenVINO does not expose memory stats; "
"GPU monitoring disabled.",
RuntimeWarning,
)
self._warn_openvino = True
return None

else:
if not hasattr(self, "_warn_backend"):
warnings.warn(
f"MemoryUsageCallback: no backend '{backend_name}'",
RuntimeWarning,
)
self._warn_backend = True
return None

except ImportError as imp_err:
if not hasattr(self, "_warn_import"):
warnings.warn(
f"Could not import for backend '{backend_name}': {imp_err}",
RuntimeWarning,
)
self._warn_import = True
return None

except Exception as exc:
if not hasattr(self, "_warn_exc"):
warnings.warn(
f"Error retrieving GPU memory: {exc}", RuntimeWarning
)
self._warn_exc = True
return None
165 changes: 165 additions & 0 deletions keras/src/callbacks/memory_usage_callback_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
import glob
import os
import re
import sys
import tempfile
from contextlib import redirect_stdout
from importlib import reload
from io import StringIO
from unittest.mock import MagicMock
from unittest.mock import patch

import numpy as np
import pytest

from keras.src.callbacks.memory_usage_callback import MemoryUsageCallback
from keras.src.layers import Dense
from keras.src.models import Sequential
from keras.src.testing import TestCase

try:
import psutil
except ImportError:
psutil = None


@pytest.mark.skipif(psutil is None, reason="psutil is required")
class MemoryUsageCallbackTest(TestCase):
def setUp(self):
super().setUp()
# Prepare 20 samples of 10-dim data → 4 batches @ bs=5

self.x = np.random.random((20, 10)).astype(np.float32)
self.y = np.random.randint(0, 2, (20, 1)).astype(np.float32)
self.model = Sequential(
[
Dense(5, activation="relu", input_shape=(10,)),
Dense(1, activation="sigmoid"),
]
)
self.model.compile(optimizer="adam", loss="binary_crossentropy")
self.epochs = 2
self.bs = 5
self.steps = len(self.x) // self.bs

@pytest.mark.requires_trainable_backend
def test_epoch_logging_stdout(self):
"""Epoch-level logs appear with correct format."""
buf = StringIO()
with redirect_stdout(buf):
cb = MemoryUsageCallback(monitor_gpu=False)
self.model.fit(
self.x,
self.y,
epochs=self.epochs,
batch_size=self.bs,
callbacks=[cb],
verbose=0,
)
out = buf.getvalue()
for e in range(self.epochs):
assert f"Epoch {e} start" in out
assert f"Epoch {e} end" in out
assert re.search(rf"Epoch {e} start - CPU Memory: [\d\.]+ MB", out)
assert re.search(rf"Epoch {e} end - CPU Memory: [\d\.]+ MB", out)

@pytest.mark.requires_trainable_backend
def test_batch_logging_stdout(self):
"""Batch-level logs appear when log_every_batch=True."""
buf = StringIO()
with redirect_stdout(buf):
cb = MemoryUsageCallback(monitor_gpu=False, log_every_batch=True)
self.model.fit(
self.x,
self.y,
epochs=1,
batch_size=self.bs,
callbacks=[cb],
verbose=0,
)
lines = buf.getvalue().splitlines()
batch_lines = [l for l in lines if l.startswith("Batch ")]
assert len(batch_lines) == self.steps
assert all(
re.match(r"Batch \d+ end - CPU Memory: [\d\.]+ MB", l)
for l in batch_lines
)

@pytest.mark.requires_trainable_backend
def test_tensorboard_writes_files(self):
"""TensorBoard event files are created."""
tmp = tempfile.TemporaryDirectory()
logdir = os.path.join(tmp.name, "tb")
buf = StringIO()
with redirect_stdout(buf):
cb = MemoryUsageCallback(
monitor_gpu=False, tensorboard_log_dir=logdir
)
self.model.fit(
self.x,
self.y,
epochs=1,
batch_size=self.bs,
callbacks=[cb],
verbose=0,
)
files = glob.glob(os.path.join(logdir, "events.out.tfevents.*"))
assert files, "No TensorBoard event files generated"

@pytest.mark.requires_trainable_backend
def test_missing_psutil_raises(self):
"""Constructor raises if psutil is missing."""
mod = sys.modules["keras.src.callbacks.memory_usage_callback"]
orig = getattr(mod, "psutil", None)
with patch.dict(sys.modules, {"psutil": None}):
reload(mod)
with pytest.raises(ImportError):
_ = mod.MemoryUsageCallback(monitor_gpu=False)
# restore

if orig is not None:
sys.modules["psutil"] = orig
reload(mod)


@pytest.mark.requires_trainable_backend
def test_torch_backend_gpu_memory(monkeypatch):
"""Simulate PyTorch backend and verify GPU memory sum."""
import keras.src.backend as B

monkeypatch.setattr(B, "backend", lambda: "torch")

fake_torch = MagicMock()
fake_torch.cuda.is_available.return_value = True
fake_torch.cuda.device_count.return_value = 2
fake_torch.cuda.memory_allocated.side_effect = [
100 * 1024**2,
150 * 1024**2,
]
monkeypatch.setitem(sys.modules, "torch", fake_torch)

cb = MemoryUsageCallback(monitor_gpu=True)
mem = cb._get_gpu_memory()
assert pytest.approx(250, rel=1e-6) == mem


@pytest.mark.requires_trainable_backend
def test_jax_backend_gpu_memory(monkeypatch):
"""Simulate JAX backend and verify GPU memory sum."""
import keras.src.backend as B

monkeypatch.setattr(B, "backend", lambda: "jax")

class FakeDev:
platform = "gpu"

def memory_stats(self):
return {"bytes_in_use": 200 * 1024**2}

fake_jax = MagicMock()
fake_jax.devices.return_value = [FakeDev(), FakeDev()]
monkeypatch.setitem(sys.modules, "jax", fake_jax)

cb = MemoryUsageCallback(monitor_gpu=True)
mem = cb._get_gpu_memory()
assert pytest.approx(400, rel=1e-6) == mem
2 changes: 1 addition & 1 deletion requirements-common.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@ dm_tree
coverage!=7.6.5 # 7.6.5 breaks CI
# for onnx_test.py
onnxruntime
openvino
psutil
Loading