Description
As per our discussion earlier today, @fchollet: a backend-agnostic `print` that dispatches to the backend's own print op, e.g.:
https://www.tensorflow.org/api_docs/python/tf/print
https://docs.jax.dev/en/latest/_autosummary/jax.debug.print.html
Where do you think this function should live?
- `keras.utils.summary_utils`?
- `keras.ops.print`?
- Somewhere else?
Roughly:
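```python
import builtins

import keras


def print(*args, **kwargs):
    """Backend-aware print that also works inside traced/compiled functions."""
    # Don't rebind the name `backend` here: assigning it locally would shadow
    # the imported module and raise UnboundLocalError.
    backend_name = keras.backend.backend()  # e.g. "jax", "tensorflow", "torch"
    if backend_name == "jax":
        import jax  # lazy import: only require the active backend

        # Note: jax.debug.print expects a format string first, e.g. ("x={}", x)
        print_fn = jax.debug.print
    elif backend_name == "tensorflow":
        import tensorflow as tf

        print_fn = tf.print
    else:
        # "torch": https://pytorch.org/docs/stable/generated/torch.set_printoptions.html ?
        # "openvino", "numpy": fall back to Python's built-in print for now
        print_fn = builtins.print
    return print_fn(*args, **kwargs)
```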
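One wrinkle with a plain passthrough: `tf.print` takes values positionally like the built-in `print` (and accepts `sep`/`end`), but `jax.debug.print` expects its first argument to be a format string, as in `jax.debug.print("x: {}", x)`. So `print(x, y)` would not forward cleanly on JAX. A minimal sketch of an adapter, assuming we want to keep the built-in calling convention; `make_format_string` and `jax_print` are hypothetical names, not existing Keras APIs:

```python
def make_format_string(num_args, sep=" "):
    """Build a str.format-style template with one "{}" per positional arg.

    jax.debug.print takes a format string followed by the values, so a
    builtin-style call print(a, b) maps to jax.debug.print("{} {}", a, b).
    """
    # Escape literal braces in the separator so str.format keeps them as text.
    safe_sep = sep.replace("{", "{{").replace("}", "}}")
    return safe_sep.join(["{}"] * num_args)


def jax_print(*args, sep=" "):
    """Hypothetical adapter: builtin print call style -> jax.debug.print."""
    import jax  # imported lazily so this module loads without JAX installed

    jax.debug.print(make_format_string(len(args), sep=sep), *args)
```

Inside `jax.jit`, the built-in `print` would only show tracers, whereas `jax.debug.print` formats the concrete values at runtime, which is why the template is built up front rather than formatting the values in Python.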