-
Notifications
You must be signed in to change notification settings - Fork 19.6k
[keras/src/print.py] Implement backend-specialised print
function
#21344
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
… run --all-files`
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #21344 +/- ##
==========================================
- Coverage 82.68% 82.67% -0.01%
==========================================
Files 565 565
Lines 54829 54847 +18
Branches 8513 8513
==========================================
+ Hits 45335 45346 +11
- Misses 7409 7416 +7
Partials 2085 2085
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR!
If the op is not implemented for a given backend, create the op anyway and raise NotImplementedError
.
Also please add a simple unit test.
keras/src/print.py
Outdated
_print = print | ||
|
||
|
||
@keras_export("keras.print") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead, please add individual print ops to e.g. keras/src/backend/tensorflow/core.py
, keras/src/backend/jax/core.py
, etc. Then export the op is keras/src/ops/core.py
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok I'm not sure if you want from jax.debug import print
and from tensorflow import print
or if you want them with first-class docstrings. I've added the latter; and updated this PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believe the request was to add a print
implementation also for the torch, numpy and openvino backends (even if they're the same).
Then remove the fallback mechanism and _print
.
…on implementation across backend-specific internal modules
keras/src/backend/jax/core.py
Outdated
|
||
|
||
def print(*args, **kwargs): | ||
"""Prints values and works in staged out JAX functions. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The backend specific docstring is not visible anywhere (neither on keras.io, nor in an IDE). So move to the docstring of keras/src/ops/core.py
explaining that this is JAX specific.
keras/src/backend/tensorflow/core.py
Outdated
|
||
|
||
def print(*args, **kwargs): | ||
"""Print the specified inputs. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The backend specific docstring is not visible anywhere (neither on keras.io, nor in an IDE). So if you think it is useful, move to the docstring of keras/src/ops/core.py
explaining that this is TensorFlow specific.
keras/src/backend/jax/core.py
Outdated
"""Prints values and works in staged out JAX functions. | ||
|
||
This function does *not* work with f-strings because formatting is delayed. | ||
So instead of ``jax.debug.print(f"hello {bar}")``, write |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Single back ticks ` here and on the next line
keras/src/print.py
Outdated
_print = print | ||
|
||
|
||
@keras_export("keras.print") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believe the request was to add a print
implementation also for the torch, numpy and openvino backends (even if they're the same).
Then remove the fallback mechanism and _print
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm having second thoughts about this. In its current form, keras.ops.print
is not cross-backend in the way we want it. Take this code:
x = keras.ops.array(1.0)
keras.ops.print("x is", x)
This will work on all backends except JAX, it will fail on JAX.
However:
x = keras.ops.array(1.0)
keras.ops.print("x is {x}", x=x)
Will work on JAX, but fail on all other backends.
In fact there is no way to print text and a tensor that will work on JAX and all other backends. This should be a minimal requirement for this feature: have a truly cross backend way to print text and tensors.
This means that the JAX implementation has to be changed. It is probably doable fairly easily using jax.debug.callback
instead of jax.debug.print
.
Then there is the question of keyword arguments. Python's print
has:
sep
end
file
flush
Tensorflow supports:
sep
end
output_stream
Instead of **kwargs
, I think we should explicitly spell out sep=' ', end='\n', file=None, flush=False
.
- On Tensorflow,
flush
will be ignored andfile
will be remapped tooutput_stream
. - On Torch, Numpy, OpenVino, things will just work with Python's
print
- On JAX, the new implementation needs to support them
|
||
|
||
def print(*args, print_options=None, **kwargs): | ||
np.set_printoptions( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The issue is that np.set_printoptions
is global, not just for the scope of this call to print
.
Let's say somebody does:
np.set_printoptions(threshold=100)
...
keras.ops.print(x)
They would never understand why their setting of 100 is not working.
So let's remove this.
@@ -452,3 +454,10 @@ def remat(f): | |||
"utilize this feature." | |||
) | |||
return f | |||
|
|||
|
|||
def print(*args, print_options=None, **kwargs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The signature needs to be consistent between all the backends, so remove print_options
.
@@ -18,6 +18,8 @@ | |||
SUPPORTS_RAGGED_TENSORS = False | |||
IS_THREAD_SAFE = True | |||
|
|||
_print = print |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nitpick: instead of doing this, let's use __builtins__
:
def print(*args, **kwargs):
return __builtins__.print(*args, **kwargs)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh that's a good one, I didn't know the __builtins__
@@ -1,3 +1,4 @@ | |||
from jax.debug import print # noqa |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's be explicit instead, even if it seems redundant:
def print(*args, **kwargs):
jax.debug.print(*args, **kwargs)
That being said, per my overall comment, the JAX implementation needs to be completely different and not use jax.debug.print
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hold up a moment I'll send JAX a PR; at latest this weekend
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wait, I don't think we should change JAX. I believe we can simply do:
def print(*args, sep=' ', end='\n', file=None, flush=False):
kwargs = {"sep": sep, "end": end, "file": file, "flush": flush}
jax.debug.callback(
lambda *args, **kwargs: print(*args, **kwargs),
*args, **kwargs)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But what about that whole interpolation discrepancy you referenced above?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess, for consistency, we're using the TF style, which is also the normal Python print
style without f-string or format.
So we should document that: you cannot use f-string or format, you have to pass tensors as argument, and give an example: keras.ops.print("x:", x, "y:", y)
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hold on let me send JAX a PR and see what they say
@@ -22,6 +22,8 @@ | |||
SUPPORTS_RAGGED_TENSORS = False | |||
IS_THREAD_SAFE = True | |||
|
|||
_print = print |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nitpick: instead of doing this, let's use __builtins__
:
def print(*args, **kwargs):
return __builtins__.print(*args, **kwargs)
@@ -23,6 +23,8 @@ | |||
SUPPORTS_RAGGED_TENSORS = False | |||
IS_THREAD_SAFE = True | |||
|
|||
_print = print |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nitpick: instead of doing this, let's use __builtins__
:
def print(*args, **kwargs):
return __builtins__.print(*args, **kwargs)
|
||
|
||
def print(*args, print_options=None, **kwargs): | ||
torch.set_printoptions( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The issue is that torch.set_printoptions
is global, not just for the scope of this call to print
.
Let's say somebody does:
torch.set_printoptions(threshold=100)
...
keras.ops.print(x)
They would never understand why their setting of 100 is not working.
So let's remove this.
|
||
@keras_export("keras.ops.print") | ||
def print(*args, **kwargs): | ||
"""Backend-specialised print function, oft handles tensors and |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Docstring needs to have the following format:
"""One-line short description.
Optionally more paragraphs of details.
"""
Closes #21137