Skip to content

[keras/src/print.py] Implement backend-specialised print function #21344

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from

Conversation

SamuelMarks
Copy link
Contributor

Closes #21137

@codecov-commenter
Copy link

codecov-commenter commented Jun 2, 2025

Codecov Report

Attention: Patch coverage is 61.11111% with 7 lines in your changes missing coverage. Please review.

Project coverage is 82.67%. Comparing base (ad7bbb8) to head (6c7db62).
Report is 1 commits behind head on master.

Files with missing lines Patch % Lines
keras/src/backend/numpy/core.py 50.00% 2 Missing ⚠️
keras/src/backend/torch/core.py 50.00% 2 Missing ⚠️
keras/api/_tf_keras/keras/ops/__init__.py 0.00% 1 Missing ⚠️
keras/src/backend/openvino/core.py 66.66% 1 Missing ⚠️
keras/src/ops/core.py 66.66% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #21344      +/-   ##
==========================================
- Coverage   82.68%   82.67%   -0.01%     
==========================================
  Files         565      565              
  Lines       54829    54847      +18     
  Branches     8513     8513              
==========================================
+ Hits        45335    45346      +11     
- Misses       7409     7416       +7     
  Partials     2085     2085              
Flag Coverage Δ
keras 82.48% <61.11%> (-0.01%) ⬇️
keras-jax 63.62% <38.88%> (-0.01%) ⬇️
keras-numpy 58.77% <50.00%> (-0.01%) ⬇️
keras-openvino 33.12% <38.88%> (+<0.01%) ⬆️
keras-tensorflow 64.01% <38.88%> (-0.01%) ⬇️
keras-torch 63.65% <38.88%> (-0.01%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Copy link
Collaborator

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR!

If the op is not implemented for a given backend, create the op anyway and raise NotImplementedError.

Also please add a simple unit test.

_print = print


@keras_export("keras.print")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead, please add individual print ops to e.g. keras/src/backend/tensorflow/core.py, keras/src/backend/jax/core.py, etc. Then export the op is keras/src/ops/core.py.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok I'm not sure if you want from jax.debug import print and from tensorflow import print or if you want them with first-class docstrings. I've added the latter; and updated this PR.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe the request was to add a print implementation also for the torch, numpy and openvino backends (even if they're the same).

Then remove the fallback mechanism and _print.

…on implementation across backend-specific internal modules


def print(*args, **kwargs):
"""Prints values and works in staged out JAX functions.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The backend specific docstring is not visible anywhere (neither on keras.io, nor in an IDE). So move to the docstring of keras/src/ops/core.py explaining that this is JAX specific.



def print(*args, **kwargs):
"""Print the specified inputs.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The backend specific docstring is not visible anywhere (neither on keras.io, nor in an IDE). So if you think it is useful, move to the docstring of keras/src/ops/core.py explaining that this is TensorFlow specific.

"""Prints values and works in staged out JAX functions.

This function does *not* work with f-strings because formatting is delayed.
So instead of ``jax.debug.print(f"hello {bar}")``, write
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Single back ticks ` here and on the next line

_print = print


@keras_export("keras.print")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe the request was to add a print implementation also for the torch, numpy and openvino backends (even if they're the same).

Then remove the fallback mechanism and _print.

Copy link
Collaborator

@hertschuh hertschuh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm having second thoughts about this. In its current form, keras.ops.print is not cross-backend in the way we want it. Take this code:

x = keras.ops.array(1.0)
keras.ops.print("x is", x)

This will work on all backends except JAX, it will fail on JAX.

However:

x = keras.ops.array(1.0)
keras.ops.print("x is {x}", x=x)

Will work on JAX, but fail on all other backends.

In fact there is no way to print text and a tensor that will work on JAX and all other backends. This should be a minimal requirement for this feature: have a truly cross backend way to print text and tensors.

This means that the JAX implementation has to be changed. It is probably doable fairly easily using jax.debug.callback instead of jax.debug.print.

Then there is the question of keyword arguments. Python's print has:

  • sep
  • end
  • file
  • flush

Tensorflow supports:

  • sep
  • end
  • output_stream

Instead of **kwargs, I think we should explicitly spell out sep=' ', end='\n', file=None, flush=False.

  • On Tensorflow, flush will be ignored and file will be remapped to output_stream.
  • On Torch, Numpy, OpenVino, things will just work with Python's print
  • On JAX, the new implementation needs to support them



def print(*args, print_options=None, **kwargs):
np.set_printoptions(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The issue is that np.set_printoptions is global, not just for the scope of this call to print.

Let's say somebody does:

np.set_printoptions(threshold=100)

...
keras.ops.print(x)

They would never understand why their setting of 100 is not working.

So let's remove this.

@@ -452,3 +454,10 @@ def remat(f):
"utilize this feature."
)
return f


def print(*args, print_options=None, **kwargs):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The signature needs to be consistent between all the backends, so remove print_options.

@@ -18,6 +18,8 @@
SUPPORTS_RAGGED_TENSORS = False
IS_THREAD_SAFE = True

_print = print
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpick: instead of doing this, let's use __builtins__:

def print(*args, **kwargs):
  return __builtins__.print(*args, **kwargs)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh that's a good one, I didn't know the __builtins__

@@ -1,3 +1,4 @@
from jax.debug import print # noqa
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's be explicit instead, even if it seems redundant:

def print(*args, **kwargs):
  jax.debug.print(*args, **kwargs)

That being said, per my overall comment, the JAX implementation needs to be completely different and not use jax.debug.print.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hold up a moment I'll send JAX a PR; at latest this weekend

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wait, I don't think we should change JAX. I believe we can simply do:

def print(*args, sep=' ', end='\n', file=None, flush=False):
  kwargs = {"sep": sep, "end": end, "file": file, "flush": flush}
  jax.debug.callback(
      lambda *args, **kwargs: print(*args, **kwargs),
      *args, **kwargs)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But what about that whole interpolation discrepancy you referenced above?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess, for consistency, we're using the TF style, which is also the normal Python print style without f-string or format.

So we should document that: you cannot use f-string or format, you have to pass tensors as argument, and give an example: keras.ops.print("x:", x, "y:", y).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hold on let me send JAX a PR and see what they say

@@ -22,6 +22,8 @@
SUPPORTS_RAGGED_TENSORS = False
IS_THREAD_SAFE = True

_print = print
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpick: instead of doing this, let's use __builtins__:

def print(*args, **kwargs):
  return __builtins__.print(*args, **kwargs)

@@ -23,6 +23,8 @@
SUPPORTS_RAGGED_TENSORS = False
IS_THREAD_SAFE = True

_print = print
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpick: instead of doing this, let's use __builtins__:

def print(*args, **kwargs):
  return __builtins__.print(*args, **kwargs)



def print(*args, print_options=None, **kwargs):
torch.set_printoptions(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The issue is that torch.set_printoptions is global, not just for the scope of this call to print.

Let's say somebody does:

torch.set_printoptions(threshold=100)

...
keras.ops.print(x)

They would never understand why their setting of 100 is not working.

So let's remove this.


@keras_export("keras.ops.print")
def print(*args, **kwargs):
"""Backend-specialised print function, oft handles tensors and
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Docstring needs to have the following format:

"""One-line short description.

Optionally more paragraphs of details.
"""

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

keras.ops.print: tf.print; jax.debug.print; …
5 participants