-
Notifications
You must be signed in to change notification settings - Fork 19.6k
[keras/src/print.py] Implement backend-specialised print
function
#21344
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
d31c128
4571fee
726837c
6c7db62
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,6 +18,8 @@ | |
SUPPORTS_RAGGED_TENSORS = False | ||
IS_THREAD_SAFE = True | ||
|
||
_print = print | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nitpick: instead of doing this, let's use def print(*args, **kwargs):
return __builtins__.print(*args, **kwargs) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh that's a good one, I didn't know the |
||
|
||
|
||
class Variable(KerasVariable): | ||
def _initialize(self, value): | ||
|
@@ -452,3 +454,10 @@ def remat(f): | |
"utilize this feature." | ||
) | ||
return f | ||
|
||
|
||
def print(*args, print_options=None, **kwargs): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The signature needs to be consistent between all the backends, so remove |
||
np.set_printoptions( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The issue is that Let's say somebody does: np.set_printoptions(threshold=100)
...
keras.ops.print(x) They would never understand why their setting of 100 is not working. So let's remove this. |
||
**{"threshold": 1000} if print_options is None else print_options | ||
) | ||
return _print(*args, **kwargs) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,6 +22,8 @@ | |
SUPPORTS_RAGGED_TENSORS = False | ||
IS_THREAD_SAFE = True | ||
|
||
_print = print | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nitpick: instead of doing this, let's use def print(*args, **kwargs):
return __builtins__.print(*args, **kwargs) |
||
|
||
OPENVINO_DTYPES = { | ||
"float16": ov.Type.f16, | ||
"float32": ov.Type.f32, | ||
|
@@ -664,3 +666,7 @@ def remat(f): | |
"utilize this feature." | ||
) | ||
return f | ||
|
||
|
||
def print(*args, **kwargs): | ||
return _print(*args, **kwargs) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,6 +23,8 @@ | |
SUPPORTS_RAGGED_TENSORS = False | ||
IS_THREAD_SAFE = True | ||
|
||
_print = print | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nitpick: instead of doing this, let's use def print(*args, **kwargs):
return __builtins__.print(*args, **kwargs) |
||
|
||
# Some operators such as 'aten::_foreach_mul_.Scalar' | ||
# are not currently implemented for the MPS device. | ||
# check https://github.com/pytorch/pytorch/issues/77764. | ||
|
@@ -733,3 +735,10 @@ def backward(ctx, grad_output): | |
if not isinstance(grads, tuple): | ||
grads = (grads,) | ||
return (None,) + grads | ||
|
||
|
||
def print(*args, print_options=None, **kwargs): | ||
torch.set_printoptions( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The issue is that Let's say somebody does: torch.set_printoptions(threshold=100)
...
keras.ops.print(x) They would never understand why their setting of 100 is not working. So let's remove this. |
||
**{"threshold": 1000} if print_options is None else print_options | ||
) | ||
return _print(*args, **kwargs) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1183,3 +1183,10 @@ def grad(*args, upstream): | |
``` | ||
""" | ||
return backend.core.custom_gradient(f) | ||
|
||
|
||
@keras_export("keras.ops.print") | ||
def print(*args, **kwargs): | ||
"""Backend-specialised print function, oft handles tensors and | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Docstring needs to have the following format:
|
||
other backend-specific types.""" | ||
return backend.core.print(*args, **kwargs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's be explicit instead, even if it seems redundant:
That being said, per my overall comment, the JAX implementation needs to be completely different and not use
jax.debug.print
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hold up a moment I'll send JAX a PR; at latest this weekend
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wait, I don't think we should change JAX. I believe we can simply do:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But what about that whole interpolation discrepancy you referenced above?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess, for consistency, we're using the TF style, which is also the normal Python
print
style without f-string or format.So we should document that: you cannot use f-string or format, you have to pass tensors as argument, and give an example:
keras.ops.print("x:", x, "y:", y)
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hold on let me send JAX a PR and see what they say