
Improve error messages and testing for missing endpoints #17

@jpbrodrick89

Description


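Currently, jax.jvp, jax.vjp, jax.grad and related transformations call Tesseract endpoints without first checking that those endpoints exist. All error handling is therefore left to the Tesseract client, which by default, in debug mode, produces very verbose error messages. In the reproduction below, the NotImplementedError for the missing VJP endpoint is raised inside a CpuCallback and resurfaces as an XlaRuntimeError buried in a long traceback. I suggest we check for the required endpoint in the jvp/transpose rules themselves, using a static cache of each Tesseract's available endpoints, and raise a short, simple error early if it is missing.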
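A minimal sketch of the proposed check, assuming a client can report which endpoints it serves. `FakeTesseractClient`, its `available_endpoints()` method, `REQUIRED_ENDPOINT`, `get_endpoints`, `missing_endpoint_message` and `require_endpoint` are all hypothetical names for illustration, not the tesseract-core or tesseract-jax API. The rule-to-endpoint mapping is also an assumption (the traceback only confirms that the transpose rule calls `vector_jacobian_product`).

```python
import weakref

# Hypothetical mapping from the JAX rule to the Tesseract endpoint it calls.
# Assumption: jvp -> jacobian_vector_product; the transpose entry matches
# the traceback in this issue.
REQUIRED_ENDPOINT = {
    "jvp": "jacobian_vector_product",
    "transpose": "vector_jacobian_product",
}

# Static cache: endpoints are fetched once per client object. A
# WeakKeyDictionary avoids keeping discarded clients alive.
_ENDPOINT_CACHE = weakref.WeakKeyDictionary()


class FakeTesseractClient:
    """Hypothetical stand-in for a Tesseract client that lists its endpoints."""

    def __init__(self, endpoints):
        self._endpoints = tuple(endpoints)
        self.queries = 0  # how many times the (possibly remote) list was fetched

    def available_endpoints(self):
        self.queries += 1
        return self._endpoints


def get_endpoints(client):
    """Return the client's endpoints, fetching them at most once per client."""
    if client not in _ENDPOINT_CACHE:
        _ENDPOINT_CACHE[client] = frozenset(client.available_endpoints())
    return _ENDPOINT_CACHE[client]


def missing_endpoint_message(client, rule):
    """Return a short error message if `rule` cannot run on `client`, else None."""
    endpoint = REQUIRED_ENDPOINT[rule]
    if endpoint in get_endpoints(client):
        return None
    return (
        f"The {rule} rule needs the '{endpoint}' endpoint, "
        f"which this Tesseract does not implement."
    )


def require_endpoint(client, rule):
    """Raise a one-line NotImplementedError before any callback is dispatched."""
    message = missing_endpoint_message(client, rule)
    if message is not None:
        raise NotImplementedError(message)


client = FakeTesseractClient(["apply", "jacobian_vector_product"])
require_endpoint(client, "jvp")  # passes silently
print(missing_endpoint_message(client, "transpose"))
# The transpose rule needs the 'vector_jacobian_product' endpoint, which this Tesseract does not implement.
print(client.queries)  # 1: the second lookup hits the cache
```

In the real rules, require_endpoint would be called at the top of the jvp/transpose rule, before tesseract_dispatch_p.bind, so the error is raised in plain Python rather than from inside the host callback.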

Reproduction:

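>>> import jax
>>> import jax.numpy as jnp
>>> from tesseract_core import Tesseract
>>> from tesseract_jax import apply_tesseract
>>> t = Tesseract.from_image("vectoradd")
>>> t.serve()
>>> inputs = {"a": jnp.array([3.0]), "b": jnp.array([5.0])}
>>> f = lambda inputs: apply_tesseract(t, inputs)["result"].sum()
>>> jax.grad(f)
<function <lambda> at 0x10c441da0> # No error
>>> jax.jit(jax.grad(f))
<PjitFunction of <function <lambda> at 0x10c4982c0>> # No error
>>> jax.grad(f)(inputs)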
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/opt/homebrew/Cellar/python@3.13/3.13.0_1/Frameworks/Python.framework/Versions/3.13/lib/python3.13/_pyrepl/__main__.py", line 6, in <module>
    __pyrepl_interactive_console()
  File "/opt/homebrew/Cellar/python@3.13/3.13.0_1/Frameworks/Python.framework/Versions/3.13/lib/python3.13/_pyrepl/main.py", line 59, in interactive_console
    run_multiline_interactive_console(console)
  File "/opt/homebrew/Cellar/python@3.13/3.13.0_1/Frameworks/Python.framework/Versions/3.13/lib/python3.13/_pyrepl/simple_interact.py", line 160, in run_multiline_interactive_console
    more = console.push(_strip_final_indent(statement), filename=input_name, _symbol="single")  # type: ignore[call-arg]
  File "/opt/homebrew/Cellar/python@3.13/3.13.0_1/Frameworks/Python.framework/Versions/3.13/lib/python3.13/code.py", line 313, in push
    more = self.runsource(source, filename, symbol=_symbol)
  File "/opt/homebrew/Cellar/python@3.13/3.13.0_1/Frameworks/Python.framework/Versions/3.13/lib/python3.13/_pyrepl/console.py", line 205, in runsource
    self.runcode(code)
  File "/opt/homebrew/Cellar/python@3.13/3.13.0_1/Frameworks/Python.framework/Versions/3.13/lib/python3.13/code.py", line 92, in runcode
    exec(code, self.locals)
  File "<python-input-21>", line 1, in <module>
    jax.grad(f)({"a": jnp.array([3.0]), "b": jnp.array([5.0])})
  File "<python-input-20>", line 1, in <lambda>
    f = lambda inputs : apply_tesseract(t, inputs=inputs)["result"].sum()
  File "/Users/jonathanbrodrick/pasteurcodes/tesseract-jax/tesseract_jax/primitive.py", line 346, in apply_tesseract
    out = tesseract_dispatch_p.bind(
  File "/Users/jonathanbrodrick/pasteurcodes/tesseract-jax/tesseract_jax/primitive.py", line 119, in tesseract_dispatch_jvp_rule
    jvp = tesseract_dispatch_p.bind(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: jaxlib.xla_extension.XlaRuntimeError: INTERNAL: CpuCallback error: Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
 .
 .
 .
  File "/Users/jonathanbrodrick/pasteurcodes/tesseract-core/tesseract_core/sdk/tesseract.py", line 437, in vector_jacobian_product
NotImplementedError: Vector Jacobian Product (VJP) not implemented for this Tesseract.
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "<python-input-21>", line 1, in <module>
    jax.grad(f)({"a": jnp.array([3.0]), "b": jnp.array([5.0])})
    ~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jonathanbrodrick/pasteurcodes/tesseract-jax/tesseract_jax/primitive.py", line 174, in tesseract_dispatch_transpose_rule
    vjp = tesseract_dispatch_p.bind(
        *args,
    ...<7 lines>...
        eval_func="vector_jacobian_product",
    )
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: CpuCallback error: Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
 .
 .
 .
  File "/Users/jonathanbrodrick/pasteurcodes/tesseract-core/tesseract_core/sdk/tesseract.py", line 437, in vector_jacobian_product
NotImplementedError: Vector Jacobian Product (VJP) not implemented for this Tesseract.
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
>>> jax.jit(jax.grad(f))({"a": jnp.array([3.0]), "b": jnp.array([5.0])})
Traceback (most recent call last):
  File "<python-input-22>", line 1, in <module>
    jax.jit(jax.grad(f))({"a": jnp.array([3.0]), "b": jnp.array([5.0])})
    ~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: CpuCallback error: Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/opt/homebrew/Cellar/python@3.13/3.13.0_1/Frameworks/Python.framework/Versions/3.13/lib/python3.13/_pyrepl/__main__.py", line 6, in <module>
  File "/opt/homebrew/Cellar/python@3.13/3.13.0_1/Frameworks/Python.framework/Versions/3.13/lib/python3.13/_pyrepl/main.py", line 59, in interactive_console
  File "/opt/homebrew/Cellar/python@3.13/3.13.0_1/Frameworks/Python.framework/Versions/3.13/lib/python3.13/_pyrepl/simple_interact.py", line 160, in run_multiline_interactive_console
  File "/opt/homebrew/Cellar/python@3.13/3.13.0_1/Frameworks/Python.framework/Versions/3.13/lib/python3.13/code.py", line 313, in push
  File "/opt/homebrew/Cellar/python@3.13/3.13.0_1/Frameworks/Python.framework/Versions/3.13/lib/python3.13/_pyrepl/console.py", line 205, in runsource
  File "/opt/homebrew/Cellar/python@3.13/3.13.0_1/Frameworks/Python.framework/Versions/3.13/lib/python3.13/code.py", line 92, in runcode
  File "<python-input-22>", line 1, in <module>
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/pjit.py", line 339, in cache_miss
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/pjit.py", line 194, in _python_pjit_helper
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/pjit.py", line 1681, in _pjit_call_impl_python
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/profiler.py", line 334, in wrapper
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/interpreters/pxla.py", line 1288, in __call__
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/callback.py", line 778, in _wrapped_callback
  File "/Users/jonathanbrodrick/pasteurcodes/tesseract-jax/tesseract_jax/primitive.py", line 213, in _dispatch
  File "/Users/jonathanbrodrick/pasteurcodes/tesseract-jax/tesseract_jax/tesseract_compat.py", line 243, in vector_jacobian_product
  File "/Users/jonathanbrodrick/pasteurcodes/tesseract-core/tesseract_core/sdk/tesseract.py", line 50, in wrapper
  File "/Users/jonathanbrodrick/pasteurcodes/tesseract-core/tesseract_core/sdk/tesseract.py", line 437, in vector_jacobian_product
NotImplementedError: Vector Jacobian Product (VJP) not implemented for this Tesseract.
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
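To confirm that raising inside a differentiation rule gives a short error under both jax.grad and jax.jit(jax.grad(...)), here is a toy sketch. It uses `jax.custom_vjp` as a stand-in for tesseract-jax's dispatch primitive and its transpose rule; `fake_apply`, `require`, `AVAILABLE_ENDPOINTS` and `grad_error` are hypothetical names. A helper like `grad_error` could also serve as a regression test for missing endpoints.

```python
import jax
import jax.numpy as jnp

# Hypothetical endpoint set for a Tesseract that, like "vectoradd" above,
# does not implement vector_jacobian_product.
AVAILABLE_ENDPOINTS = frozenset({"apply", "jacobian_vector_product"})


def require(endpoint):
    """Raise a one-line error at trace time if `endpoint` is unavailable."""
    if endpoint not in AVAILABLE_ENDPOINTS:
        raise NotImplementedError(f"'{endpoint}' is not implemented for this Tesseract.")


@jax.custom_vjp
def fake_apply(x):
    # Stand-in for the forward "apply" call (here simply 2 * x).
    return 2.0 * x


def fake_apply_fwd(x):
    return fake_apply(x), None  # a linear map needs no residuals


def fake_apply_bwd(_residuals, g):
    # Check in Python while JAX builds the backward pass, instead of letting
    # a host callback fail later with a long XlaRuntimeError traceback.
    require("vector_jacobian_product")
    return (2.0 * g,)


fake_apply.defvjp(fake_apply_fwd, fake_apply_bwd)


def f(x):
    return fake_apply(x).sum()


def grad_error(fn, x):
    """Return the NotImplementedError message raised by fn(x), or None."""
    try:
        fn(x)
    except NotImplementedError as err:
        return str(err)
    return None


x = jnp.array([3.0, 5.0])
print(f(x))  # the forward pass is unaffected
print(grad_error(jax.grad(f), x))
print(grad_error(jax.jit(jax.grad(f)), x))
```

Because the check runs while JAX traces the backward pass, the jitted case fails during tracing rather than during execution, and no callback is ever dispatched.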
