Add memory usage monitor callback #21245
base: master
keras/src/callbacks/memory_usage_callback.py
@@ -0,0 +1,192 @@
import os
import warnings

from keras.src import backend as K
from keras.src.api_export import keras_export
from keras.src.callbacks.callback import Callback

# Attempt to import psutil for CPU memory monitoring
try:
    import psutil
except ImportError:
    psutil = None


@keras_export("keras.callbacks.MemoryUsageCallback")
class MemoryUsageCallback(Callback):
    """Monitors and logs memory usage
    (CPU + optional GPU/TPU/OpenVINO) during training.

    This callback measures:

    - **CPU**: via `psutil.Process().memory_info().rss`
    - **GPU/TPU**: via backend-specific APIs
      (TensorFlow, PyTorch, JAX)

    Logs are printed to stdout at the start/end of each epoch and,
    if `log_every_batch=True`, after every batch.
    If `tensorboard_log_dir` is provided, scalars are also written
    via `tf.summary` (TensorBoard).

    Args:
        monitor_gpu (bool): If True, attempt to measure accelerator memory.
        log_every_batch (bool): If True, also log after each batch.

Review comment: what is the default behavior?

        tensorboard_log_dir (str|None): Directory for TensorBoard logs;
            if None, no TF summary writer is created.

    Raises:
        ImportError: If `psutil` is not installed (required for CPU logging).
    """

    def __init__(
        self,
        monitor_gpu=True,
        log_every_batch=False,
        tensorboard_log_dir=None,
    ):
        super().__init__()
        if psutil is None:
            raise ImportError(
                "MemoryUsageCallback requires the 'psutil' library. "
                "Install via `pip install psutil`."

Review comment: NIT: "To install please use …

            )
        self.monitor_gpu = monitor_gpu
        self.log_every_batch = log_every_batch
        self._proc = psutil.Process()
        self._step_counter = 0
        self._writer = None

        if tensorboard_log_dir:
            try:
                import tensorflow as tf

                logdir = os.path.expanduser(tensorboard_log_dir)
                self._writer = tf.summary.create_file_writer(logdir)
                print(f"MemoryUsageCallback: TensorBoard logs → {logdir}")
            except Exception as e:
                warnings.warn(
                    f"Could not initialize TensorBoard writer: {e}",
                    RuntimeWarning,
                )
                self._writer = None

    def on_train_begin(self, logs=None):
        self._step_counter = 0

    def on_epoch_begin(self, epoch, logs=None):
        self._log_epoch("start", epoch)

    def on_epoch_end(self, epoch, logs=None):
        self._log_epoch("end", epoch, offset=1)

Review comment: from the colab output I am observing that Epoch end is not logged when log_every_batch is False

    def on_batch_end(self, batch, logs=None):
        if self.log_every_batch:
            self._log_step(
                f"Batch {self._step_counter} end", self._step_counter
            )
        self._step_counter += 1

    def on_train_end(self, logs=None):
        if self._writer:
            self._writer.close()

    def _log_epoch(self, when, epoch, offset=0):
        label = f"Epoch {epoch} {when}"
        step = epoch + offset
        self._log_step(label, step)

    def _log_step(self, label, step):
        cpu_mb = self._get_cpu_memory()
        gpu_mb = self._get_gpu_memory() if self.monitor_gpu else None

        msg = f"{label} - CPU Memory: {cpu_mb:.2f} MB"
        if gpu_mb is not None:
            msg += f"; GPU Memory: {gpu_mb:.2f} MB"
        print(msg)

        if self._writer:
            import tensorflow as tf

            with self._writer.as_default(step=int(step)):
                tf.summary.scalar("Memory/CPU_MB", cpu_mb)
                if gpu_mb is not None:
                    tf.summary.scalar("Memory/GPU_MB", gpu_mb)
            # flush happens inside the writer

    def _get_cpu_memory(self):
        return self._proc.memory_info().rss / (1024**2)

    def _get_gpu_memory(self):

Review comment: another function to get tpu memory would be needed as well

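In the spirit of that comment, a hedged sketch (not part of the PR) of what a companion helper could look like: JAX exposes per-device `memory_stats()` for TPUs, so the body below mirrors the JAX GPU branch further down in this method. The name `_get_tpu_memory` is hypothetical, and only the JAX backend is sketched.

```python
# Hypothetical companion to _get_gpu_memory (editor sketch, not PR code).
def _get_tpu_memory(self):
    """Best-effort TPU memory usage in MB, or None if unavailable."""
    if K.backend() != "jax":
        return None  # Only the JAX backend is sketched here.
    import jax

    devs = [d for d in jax.devices() if d.platform == "tpu"]
    if not devs:
        return None
    # Same pattern as the JAX GPU branch below: memory_stats() may be absent.
    total = sum(
        getattr(d, "memory_stats", lambda: {})().get("bytes_in_use", 0)
        for d in devs
    )
    return total / (1024**2)
```
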
        backend_name = K.backend()
        try:
            if backend_name == "tensorflow":
                import tensorflow as tf

                gpus = tf.config.list_physical_devices("GPU")
                if not gpus:
                    return None
                # get_memory_info expects "GPU:0"-style names, not the
                # "/physical_device:GPU:0" names returned by
                # list_physical_devices.
                total = sum(
                    tf.config.experimental.get_memory_info(
                        g.name.replace("/physical_device:", "")
                    )["current"]
                    for g in gpus
                )
                return total / (1024**2)

            elif backend_name == "torch":
                import torch

                if not torch.cuda.is_available():
                    return None
                total = sum(
                    torch.cuda.memory_allocated(i)
                    for i in range(torch.cuda.device_count())
                )
                return total / (1024**2)

            elif backend_name == "jax":
                import jax

                devs = [
                    d for d in jax.devices() if d.platform.upper() == "GPU"
                ]
                if not devs:
                    return None
                total = 0
                for d in devs:
                    stats = getattr(d, "memory_stats", lambda: {})()
                    total += stats.get("bytes_in_use", 0)
                return total / (1024**2)

            elif backend_name == "openvino":
                # OpenVINO provides no memory-stats API.
                if not hasattr(self, "_warn_openvino"):
                    warnings.warn(
                        "OpenVINO does not expose memory stats; "
                        "GPU monitoring disabled.",
                        RuntimeWarning,
                    )
                    self._warn_openvino = True
                return None

            else:
                if not hasattr(self, "_warn_backend"):
                    warnings.warn(
                        "MemoryUsageCallback: unsupported backend "
                        f"'{backend_name}'; GPU monitoring disabled.",
                        RuntimeWarning,
                    )
                    self._warn_backend = True
                return None

        except ImportError as imp_err:
            if not hasattr(self, "_warn_import"):
                warnings.warn(
                    f"Could not import for backend '{backend_name}': "
                    f"{imp_err}",
                    RuntimeWarning,
                )
                self._warn_import = True
            return None

        except Exception as exc:
            if not hasattr(self, "_warn_exc"):
                warnings.warn(
                    f"Error retrieving GPU memory: {exc}", RuntimeWarning
                )
                self._warn_exc = True
            return None
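To make the intended workflow concrete, here is a minimal usage sketch. It imports from the internal module path the tests use, since the public `keras.callbacks.MemoryUsageCallback` export only exists once this PR merges; the model, data, and log directory are arbitrary placeholders.

```python
import numpy as np

from keras.src.callbacks.memory_usage_callback import MemoryUsageCallback
from keras.src.layers import Dense
from keras.src.models import Sequential

# Tiny model and data, just enough to exercise the callback.
model = Sequential(
    [Dense(4, activation="relu"), Dense(1, activation="sigmoid")]
)
model.compile(optimizer="adam", loss="binary_crossentropy")
x = np.random.random((32, 8)).astype("float32")
y = np.random.randint(0, 2, (32, 1)).astype("float32")

# Print CPU (and accelerator, if present) memory at each epoch and batch,
# and mirror the scalars to TensorBoard under ./logs/memory.
cb = MemoryUsageCallback(
    monitor_gpu=True,
    log_every_batch=True,
    tensorboard_log_dir="./logs/memory",
)
model.fit(x, y, epochs=2, batch_size=8, callbacks=[cb], verbose=0)
```
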
keras/src/callbacks/memory_usage_callback_test.py
@@ -0,0 +1,165 @@
import glob
import os
import re
import sys
import tempfile
from contextlib import redirect_stdout
from importlib import reload
from io import StringIO
from unittest.mock import MagicMock
from unittest.mock import patch

import numpy as np
import pytest

from keras.src.callbacks.memory_usage_callback import MemoryUsageCallback
from keras.src.layers import Dense
from keras.src.models import Sequential
from keras.src.testing import TestCase

try:
    import psutil
except ImportError:
    psutil = None


@pytest.mark.skipif(psutil is None, reason="psutil is required")
class MemoryUsageCallbackTest(TestCase):
    def setUp(self):
        super().setUp()
        # Prepare 20 samples of 10-dim data → 4 batches @ batch_size=5
        self.x = np.random.random((20, 10)).astype(np.float32)
        self.y = np.random.randint(0, 2, (20, 1)).astype(np.float32)
        self.model = Sequential(
            [
                Dense(5, activation="relu", input_shape=(10,)),
                Dense(1, activation="sigmoid"),
            ]
        )
        self.model.compile(optimizer="adam", loss="binary_crossentropy")
        self.epochs = 2
        self.bs = 5
        self.steps = len(self.x) // self.bs
    @pytest.mark.requires_trainable_backend
    def test_epoch_logging_stdout(self):
        """Epoch-level logs appear with the correct format."""
        buf = StringIO()
        with redirect_stdout(buf):
            cb = MemoryUsageCallback(monitor_gpu=False)
            self.model.fit(
                self.x,
                self.y,
                epochs=self.epochs,
                batch_size=self.bs,
                callbacks=[cb],
                verbose=0,
            )
        out = buf.getvalue()
        for e in range(self.epochs):
            assert f"Epoch {e} start" in out
            assert f"Epoch {e} end" in out
            assert re.search(rf"Epoch {e} start - CPU Memory: [\d\.]+ MB", out)
            assert re.search(rf"Epoch {e} end - CPU Memory: [\d\.]+ MB", out)

    @pytest.mark.requires_trainable_backend
    def test_batch_logging_stdout(self):
        """Batch-level logs appear when log_every_batch=True."""
        buf = StringIO()
        with redirect_stdout(buf):
            cb = MemoryUsageCallback(monitor_gpu=False, log_every_batch=True)
            self.model.fit(
                self.x,
                self.y,
                epochs=1,
                batch_size=self.bs,
                callbacks=[cb],
                verbose=0,
            )
        lines = buf.getvalue().splitlines()
        batch_lines = [line for line in lines if line.startswith("Batch ")]
        assert len(batch_lines) == self.steps
        assert all(
            re.match(r"Batch \d+ end - CPU Memory: [\d\.]+ MB", line)
            for line in batch_lines
        )

    @pytest.mark.requires_trainable_backend
    def test_tensorboard_writes_files(self):
        """TensorBoard event files are created."""
        tmp = tempfile.TemporaryDirectory()
        logdir = os.path.join(tmp.name, "tb")
        buf = StringIO()
        with redirect_stdout(buf):
            cb = MemoryUsageCallback(
                monitor_gpu=False, tensorboard_log_dir=logdir
            )
            self.model.fit(
                self.x,
                self.y,
                epochs=1,
                batch_size=self.bs,
                callbacks=[cb],
                verbose=0,
            )
        files = glob.glob(os.path.join(logdir, "events.out.tfevents.*"))
        assert files, "No TensorBoard event files generated"

    @pytest.mark.requires_trainable_backend
    def test_missing_psutil_raises(self):
        """Constructor raises if psutil is missing."""
        mod = sys.modules["keras.src.callbacks.memory_usage_callback"]
        orig = getattr(mod, "psutil", None)
        with patch.dict(sys.modules, {"psutil": None}):
            reload(mod)
            with pytest.raises(ImportError):
                _ = mod.MemoryUsageCallback(monitor_gpu=False)
        # Restore the real psutil and reload the module.
        if orig is not None:
            sys.modules["psutil"] = orig
            reload(mod)
@pytest.mark.requires_trainable_backend
def test_torch_backend_gpu_memory(monkeypatch):
    """Simulate the PyTorch backend and verify the GPU memory sum."""
    import keras.src.backend as B

    monkeypatch.setattr(B, "backend", lambda: "torch")

    fake_torch = MagicMock()
    fake_torch.cuda.is_available.return_value = True
    fake_torch.cuda.device_count.return_value = 2
    fake_torch.cuda.memory_allocated.side_effect = [
        100 * 1024**2,
        150 * 1024**2,
    ]
    monkeypatch.setitem(sys.modules, "torch", fake_torch)

    cb = MemoryUsageCallback(monitor_gpu=True)
    mem = cb._get_gpu_memory()
    assert pytest.approx(250, rel=1e-6) == mem


@pytest.mark.requires_trainable_backend
def test_jax_backend_gpu_memory(monkeypatch):
    """Simulate the JAX backend and verify the GPU memory sum."""
    import keras.src.backend as B

    monkeypatch.setattr(B, "backend", lambda: "jax")

    class FakeDev:
        platform = "gpu"

        def memory_stats(self):
            return {"bytes_in_use": 200 * 1024**2}

    fake_jax = MagicMock()
    fake_jax.devices.return_value = [FakeDev(), FakeDev()]
    monkeypatch.setitem(sys.modules, "jax", fake_jax)

    cb = MemoryUsageCallback(monitor_gpu=True)
    mem = cb._get_gpu_memory()
    assert pytest.approx(400, rel=1e-6) == mem
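As a follow-on to `test_tensorboard_writes_files`, the written scalars can also be read back programmatically. A sketch using TensorBoard's `EventAccumulator`; the tag `Memory/CPU_MB` matches what `_log_step` writes, and `logdir` is whatever was passed as `tensorboard_log_dir`.

```python
from tensorboard.backend.event_processing.event_accumulator import (
    EventAccumulator,
)


def read_cpu_memory_scalars(logdir):
    """Return (step, MB) pairs for the callback's CPU-memory scalar."""
    acc = EventAccumulator(logdir)
    acc.Reload()  # Parse the events.out.tfevents.* files on disk.
    return [(e.step, e.value) for e in acc.Scalars("Memory/CPU_MB")]
```
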
requirements-common.txt
@@ -23,4 +23,4 @@ dm_tree
coverage!=7.6.5 # 7.6.5 breaks CI
# for onnx_test.py
onnxruntime
openvino
psutil
Review comment: can you add automatic detection instead of this arg?

    if running_on_gpu:
        ..
    if running_on_tpu:
        ...

logic to add these detections
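One way to act on that suggestion, as a hedged sketch: probe each backend once for visible accelerators and derive the monitoring behavior from that, rather than taking `monitor_gpu` as an argument. The helper names `running_on_gpu`/`running_on_tpu` come from the review comment and are hypothetical; the probes reuse the same backend APIs `_get_gpu_memory` already relies on.

```python
from keras.src import backend as K


def running_on_gpu():
    """Best-effort check for a visible GPU on the active backend."""
    name = K.backend()
    if name == "tensorflow":
        import tensorflow as tf

        return bool(tf.config.list_physical_devices("GPU"))
    if name == "torch":
        import torch

        return torch.cuda.is_available()
    if name == "jax":
        import jax

        return any(d.platform == "gpu" for d in jax.devices())
    return False


def running_on_tpu():
    """Best-effort check for a visible TPU (TensorFlow or JAX)."""
    name = K.backend()
    if name == "tensorflow":
        import tensorflow as tf

        return bool(tf.config.list_physical_devices("TPU"))
    if name == "jax":
        import jax

        return any(d.platform == "tpu" for d in jax.devices())
    return False
```

With probes like these, `__init__` could default its monitoring flags to `running_on_gpu() or running_on_tpu()` instead of requiring the caller to pass `monitor_gpu` explicitly.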