
keras.ops.print: tf.print; jax.debug.print; … #21137

Open
@SamuelMarks


As per our discussion earlier today, @fchollet.

https://www.tensorflow.org/api_docs/python/tf/print
https://docs.jax.dev/en/latest/_autosummary/jax.debug.print.html

Where do you think this function should live: keras.utils.summary_utils, keras.ops.print, or somewhere else?

Roughly:

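import builtins

from keras import backend as keras_backend

_print = builtins.print


def print(*args, **kwargs):
    backend_name = keras_backend.backend()  # e.g. "jax", "tensorflow", "torch"
    if backend_name == "jax":
        import jax  # imported lazily so only the active backend must be installed
        # jax.debug.print expects a format string first, e.g. ("x = {}", x)
        print_fn = jax.debug.print
    elif backend_name == "tensorflow":
        import tensorflow as tf
        # tf.print takes the values positionally, e.g. ("x =", x)
        print_fn = tf.print
    else:
        # "torch" https://pytorch.org/docs/stable/generated/torch.set_printoptions.html ?
        # "openvino"
        # "numpy"
        print_fn = _print
    return print_fn(*args, **kwargs)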
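One wrinkle with forwarding `*args` unchanged is that the two backend functions do not share a calling convention: `jax.debug.print` takes a format string followed by the values, while `tf.print` takes the values themselves plus a `sep` keyword. Below is a minimal sketch, assuming a builtin-`print`-style signature (`*values, sep=" "`) is the desired public API, that adapts each backend to it. The names `placeholder_format` and `backend_print` are hypothetical, and the backend name is passed in explicitly here; inside Keras it would presumably come from `keras.backend.backend()`.

```python
import builtins


def placeholder_format(n_values, sep=" "):
    """Build a str.format template with one "{}" per value, joined by sep.

    Braces inside sep are doubled so they print literally instead of being
    read as format fields by jax.debug.print.
    """
    escaped_sep = sep.replace("{", "{{").replace("}", "}}")
    return escaped_sep.join(["{}"] * n_values)


def _jax_print(*values, sep=" "):
    import jax  # lazy import: only needed when the JAX backend is active
    jax.debug.print(placeholder_format(len(values), sep), *values)


def _tf_print(*values, sep=" "):
    import tensorflow as tf  # lazy import: only needed for TensorFlow
    tf.print(*values, sep=sep)


def _python_print(*values, sep=" "):
    builtins.print(*values, sep=sep)


# Backends with a trace-safe printer; everything else falls back to print().
_PRINT_FNS = {"jax": _jax_print, "tensorflow": _tf_print}


def backend_print(*values, backend_name, sep=" "):
    """Print values using the given backend's printer (hypothetical helper)."""
    print_fn = _PRINT_FNS.get(backend_name, _python_print)
    print_fn(*values, sep=sep)
```

For example, `backend_print("loss:", 0.5, backend_name="numpy")` falls through to the built-in `print` and writes `loss: 0.5`. Keeping the public signature close to the built-in `print` means callers never have to know which backend needs a format string.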

Labels: type:support (User is asking for help / asking an implementation question. Stackoverflow would be better suited.)