From d31c128189cbf697e7c6625b131b4581e07dcd1a Mon Sep 17 00:00:00 2001 From: Samuel Marks <807580+SamuelMarks@users.noreply.github.com> Date: Mon, 2 Jun 2025 14:09:39 -0600 Subject: [PATCH 1/4] [keras/src/print.py] Implement backend-specialised `print` function --- keras/src/__init__.py | 1 + keras/src/print.py | 27 +++++++++++++++++++++++++++ 2 files changed, 28 insertions(+) create mode 100644 keras/src/print.py diff --git a/keras/src/__init__.py b/keras/src/__init__.py index 9778bcd4d63a..9674ef6bb2ed 100644 --- a/keras/src/__init__.py +++ b/keras/src/__init__.py @@ -17,4 +17,5 @@ from keras.src.models import Functional from keras.src.models import Model from keras.src.models import Sequential +from keras.src.print import print from keras.src.version import __version__ diff --git a/keras/src/print.py b/keras/src/print.py new file mode 100644 index 000000000000..e1658c47cdbd --- /dev/null +++ b/keras/src/print.py @@ -0,0 +1,27 @@ +import keras.backend +from keras.src.api_export import keras_export + +# Unique source of truth for the version number. +__version__ = "3.10.0" +_print = print + + +@keras_export("keras.print") +def print(*args, **kwargs): + backend = keras.backend.backend() + if backend == "jax": + import jax # noqa: E402 + + print_fn = jax.debug.print + elif backend == "tensorflow": + import tensorflow as tf # noqa: E402 + + print_fn = tf.print + else: + print_fn = _print + # TODO: + # "torch" + # pytorch.org/docs/stable/generated/torch.set_printoptions.html ? + # "openvino" + # "numpy" + return print_fn(*args, **kwargs) From 4571fee7031d33a32daccd74c1d854cd656d5130 Mon Sep 17 00:00:00 2001 From: Samuel Marks <807580+SamuelMarks@users.noreply.github.com> Date: Mon, 2 Jun 2025 14:16:51 -0600 Subject: [PATCH 2/4] [keras/src/print.py] Use `keras.src.backend` import ; [*] `pre-commit run --all-files` --- keras/api/__init__.py | 1 + keras/api/_tf_keras/keras/__init__.py | 1 + keras/src/print.py | 6 ++---- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/keras/api/__init__.py b/keras/api/__init__.py index dee6cea5bb19..9050d21f9124 100644 --- a/keras/api/__init__.py +++ b/keras/api/__init__.py @@ -60,6 +60,7 @@ from keras.src.ops.function import Function as Function from keras.src.ops.operation import Operation as Operation from keras.src.optimizers.optimizer import Optimizer as Optimizer +from keras.src.print import print as print from keras.src.quantizers.quantizers import Quantizer as Quantizer from keras.src.regularizers.regularizers import Regularizer as Regularizer from keras.src.version import __version__ as __version__ diff --git a/keras/api/_tf_keras/keras/__init__.py b/keras/api/_tf_keras/keras/__init__.py index 67d4738a0f3c..f890bc07b793 100644 --- a/keras/api/_tf_keras/keras/__init__.py +++ b/keras/api/_tf_keras/keras/__init__.py @@ -58,6 +58,7 @@ from keras.src.ops.function import Function as Function from keras.src.ops.operation import Operation as Operation from keras.src.optimizers.optimizer import Optimizer as Optimizer +from keras.src.print import print as print from keras.src.quantizers.quantizers import Quantizer as Quantizer from keras.src.regularizers.regularizers import Regularizer as Regularizer from keras.src.version import __version__ as __version__ diff --git a/keras/src/print.py b/keras/src/print.py index e1658c47cdbd..f4ad2a7fa599 100644 --- a/keras/src/print.py +++ b/keras/src/print.py @@ -1,14 +1,12 @@ -import keras.backend from keras.src.api_export import keras_export +from keras.src.backend import backend as keras_backend -# Unique source of truth for the version number. -__version__ = "3.10.0" _print = print @keras_export("keras.print") def print(*args, **kwargs): - backend = keras.backend.backend() + backend = keras_backend() if backend == "jax": import jax # noqa: E402 From 726837cd76e13baaa40c455268f52c351f1e20fb Mon Sep 17 00:00:00 2001 From: Samuel Marks <807580+SamuelMarks@users.noreply.github.com> Date: Mon, 2 Jun 2025 16:31:29 -0600 Subject: [PATCH 3/4] [keras/src/ops/core.py] Refactor to put `print` in ops and distribution implementation across backend-specific internal modules --- keras/api/__init__.py | 1 - keras/api/_tf_keras/keras/__init__.py | 1 - keras/api/_tf_keras/keras/ops/__init__.py | 1 + keras/api/ops/__init__.py | 1 + keras/src/__init__.py | 1 - keras/src/backend/jax/core.py | 10 +++++++++ keras/src/backend/tensorflow/core.py | 11 ++++++++++ keras/src/ops/core.py | 10 +++++++++ keras/src/print.py | 25 ----------------------- 9 files changed, 33 insertions(+), 28 deletions(-) delete mode 100644 keras/src/print.py diff --git a/keras/api/__init__.py b/keras/api/__init__.py index 9050d21f9124..dee6cea5bb19 100644 --- a/keras/api/__init__.py +++ b/keras/api/__init__.py @@ -60,7 +60,6 @@ from keras.src.ops.function import Function as Function from keras.src.ops.operation import Operation as Operation from keras.src.optimizers.optimizer import Optimizer as Optimizer -from keras.src.print import print as print from keras.src.quantizers.quantizers import Quantizer as Quantizer from keras.src.regularizers.regularizers import Regularizer as Regularizer from keras.src.version import __version__ as __version__ diff --git a/keras/api/_tf_keras/keras/__init__.py b/keras/api/_tf_keras/keras/__init__.py index f890bc07b793..67d4738a0f3c 100644 --- a/keras/api/_tf_keras/keras/__init__.py +++ b/keras/api/_tf_keras/keras/__init__.py @@ -58,7 +58,6 @@ from keras.src.ops.function import Function as Function from keras.src.ops.operation import Operation as Operation from keras.src.optimizers.optimizer import Optimizer as Optimizer -from keras.src.print import print as print from keras.src.quantizers.quantizers import Quantizer as Quantizer from keras.src.regularizers.regularizers import Regularizer as Regularizer from keras.src.version import __version__ as __version__ diff --git a/keras/api/_tf_keras/keras/ops/__init__.py b/keras/api/_tf_keras/keras/ops/__init__.py index c5770a93c49d..da80815a3c65 100644 --- a/keras/api/_tf_keras/keras/ops/__init__.py +++ b/keras/api/_tf_keras/keras/ops/__init__.py @@ -18,6 +18,7 @@ from keras.src.ops.core import fori_loop as fori_loop from keras.src.ops.core import is_tensor as is_tensor from keras.src.ops.core import map as map +from keras.src.ops.core import print as print from keras.src.ops.core import saturate_cast as saturate_cast from keras.src.ops.core import scan as scan from keras.src.ops.core import scatter as scatter diff --git a/keras/api/ops/__init__.py b/keras/api/ops/__init__.py index c5770a93c49d..da80815a3c65 100644 --- a/keras/api/ops/__init__.py +++ b/keras/api/ops/__init__.py @@ -18,6 +18,7 @@ from keras.src.ops.core import fori_loop as fori_loop from keras.src.ops.core import is_tensor as is_tensor from keras.src.ops.core import map as map +from keras.src.ops.core import print as print from keras.src.ops.core import saturate_cast as saturate_cast from keras.src.ops.core import scan as scan from keras.src.ops.core import scatter as scatter diff --git a/keras/src/__init__.py b/keras/src/__init__.py index 9674ef6bb2ed..9778bcd4d63a 100644 --- a/keras/src/__init__.py +++ b/keras/src/__init__.py @@ -17,5 +17,4 @@ from keras.src.models import Functional from keras.src.models import Model from keras.src.models import Sequential -from keras.src.print import print from keras.src.version import __version__ diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py index 747c5881106b..173968c22455 100644 --- a/keras/src/backend/jax/core.py +++ b/keras/src/backend/jax/core.py @@ -428,3 +428,13 @@ def device_scope(device_name): else: jax_device = device_name return jax.default_device(jax_device) + + +def print(*args, **kwargs): + """Prints values and works in staged out JAX functions. + + This function does *not* work with f-strings because formatting is delayed. + So instead of ``jax.debug.print(f"hello {bar}")``, write + ``jax.debug.print("hello {bar}", bar=bar)``. + """ + return jax.debug.print(*args, **kwargs) diff --git a/keras/src/backend/tensorflow/core.py b/keras/src/backend/tensorflow/core.py index 6896b74c519c..02352d4ee7f5 100644 --- a/keras/src/backend/tensorflow/core.py +++ b/keras/src/backend/tensorflow/core.py @@ -696,3 +696,14 @@ def __exit__(self, *args, **kwargs): def device_scope(device_name): return tf.device(device_name) + + +def print(*args, **kwargs): + """Print the specified inputs. + + A TensorFlow operator that prints the specified inputs to a desired + output stream or logging level. The inputs may be dense or sparse Tensors, + primitive python objects, data structures that contain tensors, and + printable Python objects. Printed tensors will recursively show the first + and last elements of each dimension to summarize.""" + return tf.print(*args, **kwargs) diff --git a/keras/src/ops/core.py b/keras/src/ops/core.py index 74807b280eae..b8bb128080b6 100644 --- a/keras/src/ops/core.py +++ b/keras/src/ops/core.py @@ -1183,3 +1183,13 @@ def grad(*args, upstream): ``` """ return backend.core.custom_gradient(f) + + +_print = print + + +@keras_export("keras.ops.print") +def print(*args, **kwargs): + return (backend.core.print if hasattr(backend.core, "print") else _print)( + *args, **kwargs + ) diff --git a/keras/src/print.py b/keras/src/print.py deleted file mode 100644 index f4ad2a7fa599..000000000000 --- a/keras/src/print.py +++ /dev/null @@ -1,25 +0,0 @@ -from keras.src.api_export import keras_export -from keras.src.backend import backend as keras_backend - -_print = print - - -@keras_export("keras.print") -def print(*args, **kwargs): - backend = keras_backend() - if backend == "jax": - import jax # noqa: E402 - - print_fn = jax.debug.print - elif backend == "tensorflow": - import tensorflow as tf # noqa: E402 - - print_fn = tf.print - else: - print_fn = _print - # TODO: - # "torch" - # pytorch.org/docs/stable/generated/torch.set_printoptions.html ? - # "openvino" - # "numpy" - return print_fn(*args, **kwargs) From 6c7db6295ab54c9b662baa1f87dfcfd781fecd85 Mon Sep 17 00:00:00 2001 From: Samuel Marks <807580+SamuelMarks@users.noreply.github.com> Date: Tue, 3 Jun 2025 21:25:07 -0600 Subject: [PATCH 4/4] [keras/src/backend/*/core.py] Add `print` impl for each backend --- keras/src/backend/jax/core.py | 11 +---------- keras/src/backend/numpy/core.py | 9 +++++++++ keras/src/backend/openvino/core.py | 6 ++++++ keras/src/backend/tensorflow/core.py | 12 +----------- keras/src/backend/torch/core.py | 9 +++++++++ keras/src/ops/core.py | 9 +++------ 6 files changed, 29 insertions(+), 27 deletions(-) diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py index 173968c22455..7db9fae3ea16 100644 --- a/keras/src/backend/jax/core.py +++ b/keras/src/backend/jax/core.py @@ -1,3 +1,4 @@ +from jax.debug import print # noqa import jax import jax.experimental.sparse as jax_sparse import jax.numpy as jnp @@ -428,13 +429,3 @@ def device_scope(device_name): else: jax_device = device_name return jax.default_device(jax_device) - - -def print(*args, **kwargs): - """Prints values and works in staged out JAX functions. - - This function does *not* work with f-strings because formatting is delayed. - So instead of ``jax.debug.print(f"hello {bar}")``, write - ``jax.debug.print("hello {bar}", bar=bar)``. - """ - return jax.debug.print(*args, **kwargs) diff --git a/keras/src/backend/numpy/core.py b/keras/src/backend/numpy/core.py index 16b2303e5e43..a9441af891cc 100644 --- a/keras/src/backend/numpy/core.py +++ b/keras/src/backend/numpy/core.py @@ -18,6 +18,8 @@ SUPPORTS_RAGGED_TENSORS = False IS_THREAD_SAFE = True +_print = print + class Variable(KerasVariable): def _initialize(self, value): @@ -452,3 +454,10 @@ def remat(f): "utilize this feature." ) return f + + +def print(*args, print_options=None, **kwargs): + np.set_printoptions( + **{"threshold": 1000} if print_options is None else print_options + ) + return _print(*args, **kwargs) diff --git a/keras/src/backend/openvino/core.py b/keras/src/backend/openvino/core.py index ec990c376bf3..0595f600bc0f 100644 --- a/keras/src/backend/openvino/core.py +++ b/keras/src/backend/openvino/core.py @@ -22,6 +22,8 @@ SUPPORTS_RAGGED_TENSORS = False IS_THREAD_SAFE = True +_print = print + OPENVINO_DTYPES = { "float16": ov.Type.f16, "float32": ov.Type.f32, @@ -664,3 +666,7 @@ def remat(f): "utilize this feature." ) return f + + +def print(*args, **kwargs): + return _print(*args, **kwargs) diff --git a/keras/src/backend/tensorflow/core.py b/keras/src/backend/tensorflow/core.py index 02352d4ee7f5..7843c309cb7e 100644 --- a/keras/src/backend/tensorflow/core.py +++ b/keras/src/backend/tensorflow/core.py @@ -2,6 +2,7 @@ import numpy as np import tensorflow as tf +from tensorflow import print # noqa from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice from keras.src import tree @@ -696,14 +697,3 @@ def __exit__(self, *args, **kwargs): def device_scope(device_name): return tf.device(device_name) - - -def print(*args, **kwargs): - """Print the specified inputs. - - A TensorFlow operator that prints the specified inputs to a desired - output stream or logging level. The inputs may be dense or sparse Tensors, - primitive python objects, data structures that contain tensors, and - printable Python objects. Printed tensors will recursively show the first - and last elements of each dimension to summarize.""" - return tf.print(*args, **kwargs) diff --git a/keras/src/backend/torch/core.py b/keras/src/backend/torch/core.py index 6fb2ab4eeebb..c747595c8796 100644 --- a/keras/src/backend/torch/core.py +++ b/keras/src/backend/torch/core.py @@ -23,6 +23,8 @@ SUPPORTS_RAGGED_TENSORS = False IS_THREAD_SAFE = True +_print = print + # Some operators such as 'aten::_foreach_mul_.Scalar' # are not currently implemented for the MPS device. # check https://github.com/pytorch/pytorch/issues/77764. @@ -733,3 +735,10 @@ def backward(ctx, grad_output): if not isinstance(grads, tuple): grads = (grads,) return (None,) + grads + + +def print(*args, print_options=None, **kwargs): + torch.set_printoptions( + **{"threshold": 1000} if print_options is None else print_options + ) + return _print(*args, **kwargs) diff --git a/keras/src/ops/core.py b/keras/src/ops/core.py index b8bb128080b6..cc1076f3febf 100644 --- a/keras/src/ops/core.py +++ b/keras/src/ops/core.py @@ -1185,11 +1185,8 @@ def grad(*args, upstream): return backend.core.custom_gradient(f) -_print = print - - @keras_export("keras.ops.print") def print(*args, **kwargs): - return (backend.core.print if hasattr(backend.core, "print") else _print)( - *args, **kwargs - ) + """Backend-specialised print function, oft handles tensors and + other backend-specific types.""" + return backend.core.print(*args, **kwargs)