[keras/src/print.py] Implement backend-specialised print function #21344

Status: Open. Wants to merge 4 commits into master.
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/__init__.py
@@ -18,6 +18,7 @@
from keras.src.ops.core import fori_loop as fori_loop
from keras.src.ops.core import is_tensor as is_tensor
from keras.src.ops.core import map as map
from keras.src.ops.core import print as print
from keras.src.ops.core import saturate_cast as saturate_cast
from keras.src.ops.core import scan as scan
from keras.src.ops.core import scatter as scatter
1 change: 1 addition & 0 deletions keras/api/ops/__init__.py
@@ -18,6 +18,7 @@
from keras.src.ops.core import fori_loop as fori_loop
from keras.src.ops.core import is_tensor as is_tensor
from keras.src.ops.core import map as map
from keras.src.ops.core import print as print
from keras.src.ops.core import saturate_cast as saturate_cast
from keras.src.ops.core import scan as scan
from keras.src.ops.core import scatter as scatter
1 change: 1 addition & 0 deletions keras/src/backend/jax/core.py
@@ -1,3 +1,4 @@
from jax.debug import print # noqa
Collaborator:
Let's be explicit instead, even if it seems redundant:

def print(*args, **kwargs):
  jax.debug.print(*args, **kwargs)

That being said, per my overall comment, the JAX implementation needs to be completely different and not use jax.debug.print.

Contributor (author):
Hold up a moment, I'll send JAX a PR; this weekend at the latest.

Collaborator:
Wait, I don't think we should change JAX. I believe we can simply do:

def print(*args, sep=' ', end='\n', file=None, flush=False):
  kwargs = {"sep": sep, "end": end, "file": file, "flush": flush}
  # Forward through builtins.print: the bare name `print` here
  # would recurse into this wrapper.
  jax.debug.callback(
      lambda *args, **kwargs: builtins.print(*args, **kwargs),
      *args, **kwargs)

Contributor (author):
But what about that whole interpolation discrepancy you referenced above?

Collaborator:
I guess, for consistency, we're using the TF style, which is also the normal Python print style without f-strings or format().

So we should document that: you cannot use f-strings or format(); you have to pass tensors as arguments. And give an example: keras.ops.print("x:", x, "y:", y).
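That documented call style can be sketched without Keras installed; `ops_print` below is a hypothetical stand-in that mirrors the builtin-`print` signature the backends converge on (values passed as positional arguments, never interpolated):

```python
import contextlib
import io

# Hypothetical stand-in for keras.ops.print: tensors/values are passed
# as positional arguments (TF style) rather than via f-strings, which
# would stringify a traced tensor at trace time instead of its value.
def ops_print(*args, sep=" ", end="\n", file=None, flush=False):
    print(*args, sep=sep, end=end, file=file, flush=flush)

buf = io.StringIO()
with contextlib.redirect_stdout(buf):
    ops_print("x:", [1.0, 2.0], "y:", 3.0)

# buf.getvalue() == "x: [1.0, 2.0] y: 3.0\n"
```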

Contributor (author):
Hold on, let me send JAX a PR and see what they say.

import jax
import jax.experimental.sparse as jax_sparse
import jax.numpy as jnp
9 changes: 9 additions & 0 deletions keras/src/backend/numpy/core.py
@@ -18,6 +18,8 @@
SUPPORTS_RAGGED_TENSORS = False
IS_THREAD_SAFE = True

_print = print
Collaborator:
nitpick: instead of doing this, let's use __builtins__:

def print(*args, **kwargs):
  return __builtins__.print(*args, **kwargs)

Contributor (author):
Oh, that's a good one, I didn't know about __builtins__.
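One caveat with `__builtins__`: CPython exposes it as a module in `__main__` but as a plain dict in imported modules, so `import builtins` is the portable spelling. A minimal sketch of the wrapper in that form:

```python
import builtins
import contextlib
import io

def print(*args, **kwargs):
    # Shadowing the builtin is safe as long as we forward through the
    # builtins module rather than the bare name (which would recurse).
    return builtins.print(*args, **kwargs)

buf = io.StringIO()
with contextlib.redirect_stdout(buf):
    print("hello", 42)

# buf.getvalue() == "hello 42\n"
```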



class Variable(KerasVariable):
def _initialize(self, value):
@@ -452,3 +454,10 @@ def remat(f):
"utilize this feature."
)
return f


def print(*args, print_options=None, **kwargs):
Collaborator:
The signature needs to be consistent between all the backends, so remove print_options.

np.set_printoptions(
Collaborator:
The issue is that np.set_printoptions is global, not just for the scope of this call to print.

Let's say somebody does:

np.set_printoptions(threshold=100)

...
keras.ops.print(x)

They would never understand why their setting of 100 is not working.

So let's remove this.

**{"threshold": 1000} if print_options is None else print_options
)
return _print(*args, **kwargs)
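If per-call print options are still wanted later, NumPy offers a scoped alternative to the global setter: `np.printoptions` is a context manager that restores the caller's settings on exit, avoiding the surprise described above.

```python
import numpy as np

x = np.arange(2000)

np.set_printoptions(threshold=100)      # the caller's own global setting
with np.printoptions(threshold=2001):
    full = np.array2string(x)           # no summarization inside the block
truncated = np.array2string(x)          # caller's threshold=100 is back

# "..." (summarization marker) appears only in `truncated`
```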
6 changes: 6 additions & 0 deletions keras/src/backend/openvino/core.py
@@ -22,6 +22,8 @@
SUPPORTS_RAGGED_TENSORS = False
IS_THREAD_SAFE = True

_print = print
Collaborator:
nitpick: instead of doing this, let's use __builtins__:

def print(*args, **kwargs):
  return __builtins__.print(*args, **kwargs)


OPENVINO_DTYPES = {
"float16": ov.Type.f16,
"float32": ov.Type.f32,
@@ -664,3 +666,7 @@ def remat(f):
"utilize this feature."
)
return f


def print(*args, **kwargs):
return _print(*args, **kwargs)
1 change: 1 addition & 0 deletions keras/src/backend/tensorflow/core.py
@@ -2,6 +2,7 @@

import numpy as np
import tensorflow as tf
from tensorflow import print # noqa
from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice

from keras.src import tree
9 changes: 9 additions & 0 deletions keras/src/backend/torch/core.py
@@ -23,6 +23,8 @@
SUPPORTS_RAGGED_TENSORS = False
IS_THREAD_SAFE = True

_print = print
Collaborator:
nitpick: instead of doing this, let's use __builtins__:

def print(*args, **kwargs):
  return __builtins__.print(*args, **kwargs)


# Some operators such as 'aten::_foreach_mul_.Scalar'
# are not currently implemented for the MPS device.
# check https://github.com/pytorch/pytorch/issues/77764.
@@ -733,3 +735,10 @@ def backward(ctx, grad_output):
if not isinstance(grads, tuple):
grads = (grads,)
return (None,) + grads


def print(*args, print_options=None, **kwargs):
torch.set_printoptions(
Collaborator:
The issue is that torch.set_printoptions is global, not just for the scope of this call to print.

Let's say somebody does:

torch.set_printoptions(threshold=100)

...
keras.ops.print(x)

They would never understand why their setting of 100 is not working.

So let's remove this.

**{"threshold": 1000} if print_options is None else print_options
)
return _print(*args, **kwargs)
7 changes: 7 additions & 0 deletions keras/src/ops/core.py
@@ -1183,3 +1183,10 @@ def grad(*args, upstream):
```
"""
return backend.core.custom_gradient(f)


@keras_export("keras.ops.print")
def print(*args, **kwargs):
"""Backend-specialised print function that handles tensors and
Collaborator:
Docstring needs to have the following format:

"""One-line short description.

Optionally more paragraphs of details.
"""

other backend-specific types."""
return backend.core.print(*args, **kwargs)
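The call above resolves to whichever backend module is active. A minimal backend-agnostic sketch of that dispatch (registry and names hypothetical, with `builtins.print` standing in for each backend's implementation, e.g. `tf.print` or the JAX callback wrapper):

```python
import builtins
import contextlib
import io

# Hypothetical registry standing in for the keras.src.backend modules.
_BACKEND_PRINT = {
    "numpy": builtins.print,
    "tensorflow": builtins.print,  # stand-in for tf.print
}
_ACTIVE_BACKEND = "numpy"

def ops_print(*args, **kwargs):
    # Resolve the active backend's print at call time, mirroring how
    # keras.ops.print forwards to backend.core.print in this PR.
    return _BACKEND_PRINT[_ACTIVE_BACKEND](*args, **kwargs)

buf = io.StringIO()
with contextlib.redirect_stdout(buf):
    ops_print("loss:", 0.25)

# buf.getvalue() == "loss: 0.25\n"
```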