Open
Labels
enhancement (New feature or request), good first issue (Good for newcomers)
Description
Currently, jax.jvp, jax.vjp, jax.grad and friends all call Tesseract endpoints without first checking that they exist. All error handling is therefore left to the Tesseract client, which, by default in debug mode, produces very verbose error messages. I suggest checking for the required endpoints in the jvp/transpose rules themselves, using a static cache of available endpoints, and raising a short, simple error message early if an endpoint is missing.
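To make the suggestion concrete, here is a minimal sketch of a cached endpoint check. It is not the tesseract-jax implementation: FakeClient, get_endpoints and require_endpoint are hypothetical names, and the available_endpoints attribute is an assumption about what the client exposes. The idea is that the jvp and transpose rules would call something like require_endpoint before binding the primitive.

```python
import weakref


class FakeClient:
    """Hypothetical stand-in for a Tesseract client (assumption: the real
    client can list its endpoint names via an attribute like this one)."""

    def __init__(self, endpoints):
        self._endpoints = tuple(endpoints)
        self.lookups = 0  # counts how often the endpoint list is queried

    @property
    def available_endpoints(self):
        self.lookups += 1
        return list(self._endpoints)


# One cached set of endpoint names per client object. Weak keys let the
# entry disappear when the client itself is garbage collected.
_ENDPOINT_CACHE = weakref.WeakKeyDictionary()


def get_endpoints(client):
    """Return the client's endpoint names, querying the client only once."""
    if client not in _ENDPOINT_CACHE:
        _ENDPOINT_CACHE[client] = frozenset(client.available_endpoints)
    return _ENDPOINT_CACHE[client]


def require_endpoint(client, name):
    """Raise a short error if `name` is not offered by the client."""
    if name not in get_endpoints(client):
        raise NotImplementedError(
            f"Endpoint '{name}' is not implemented for this Tesseract."
        )


# Usage: a client without a VJP endpoint fails fast with a one-line message.
client = FakeClient(["apply", "jacobian_vector_product"])
require_endpoint(client, "apply")  # passes silently
try:
    require_endpoint(client, "vector_jacobian_product")
except NotImplementedError as err:
    message = str(err)
```

With a check like this inside the differentiation rules, the error would be raised while JAX traces the rule, rather than surfacing from inside the callback wrapped in an XlaRuntimeError. The sketch assumes the client object is hashable and weak-referenceable; if it is not, the cache would need a different key.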
Reproduction:
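>>> import jax
>>> import jax.numpy as jnp
>>> from tesseract_core import Tesseract
>>> from tesseract_jax import apply_tesseract
>>> t = Tesseract.from_image("vectoradd")
>>> t.serve()
>>> inputs = {"a": jnp.array([3.0]), "b": jnp.array([5.0])}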
>>> f = lambda inputs: apply_tesseract(t, inputs)["result"].sum()
>>> jax.grad(f)
<function <lambda> at 0x10c441da0> # No error
>>> jax.jit(jax.grad(f))
<PjitFunction of <function <lambda> at 0x10c4982c0>> # No error
>>> jax.grad(f)(inputs)
Traceback (most recent call last):
File "<frozen runpy>", line 198, in _run_module_as_main
File "<frozen runpy>", line 88, in _run_code
File "/opt/homebrew/Cellar/python@3.13/3.13.0_1/Frameworks/Python.framework/Versions/3.13/lib/python3.13/_pyrepl/__main__.py", line 6, in <module>
__pyrepl_interactive_console()
File "/opt/homebrew/Cellar/python@3.13/3.13.0_1/Frameworks/Python.framework/Versions/3.13/lib/python3.13/_pyrepl/main.py", line 59, in interactive_console
run_multiline_interactive_console(console)
File "/opt/homebrew/Cellar/python@3.13/3.13.0_1/Frameworks/Python.framework/Versions/3.13/lib/python3.13/_pyrepl/simple_interact.py", line 160, in run_multiline_interactive_console
more = console.push(_strip_final_indent(statement), filename=input_name, _symbol="single") # type: ignore[call-arg]
File "/opt/homebrew/Cellar/python@3.13/3.13.0_1/Frameworks/Python.framework/Versions/3.13/lib/python3.13/code.py", line 313, in push
more = self.runsource(source, filename, symbol=_symbol)
File "/opt/homebrew/Cellar/python@3.13/3.13.0_1/Frameworks/Python.framework/Versions/3.13/lib/python3.13/_pyrepl/console.py", line 205, in runsource
self.runcode(code)
File "/opt/homebrew/Cellar/python@3.13/3.13.0_1/Frameworks/Python.framework/Versions/3.13/lib/python3.13/code.py", line 92, in runcode
exec(code, self.locals)
File "<python-input-21>", line 1, in <module>
jax.grad(f)({"a": jnp.array([3.0]), "b": jnp.array([5.0])})
File "<python-input-20>", line 1, in <lambda>
f = lambda inputs : apply_tesseract(t, inputs=inputs)["result"].sum()
File "/Users/jonathanbrodrick/pasteurcodes/tesseract-jax/tesseract_jax/primitive.py", line 346, in apply_tesseract
out = tesseract_dispatch_p.bind(
File "/Users/jonathanbrodrick/pasteurcodes/tesseract-jax/tesseract_jax/primitive.py", line 119, in tesseract_dispatch_jvp_rule
jvp = tesseract_dispatch_p.bind(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: jaxlib.xla_extension.XlaRuntimeError: INTERNAL: CpuCallback error: Traceback (most recent call last):
File "<frozen runpy>", line 198, in _run_module_as_main
.
.
.
File "/Users/jonathanbrodrick/pasteurcodes/tesseract-core/tesseract_core/sdk/tesseract.py", line 437, in vector_jacobian_product
NotImplementedError: Vector Jacobian Product (VJP) not implemented for this Tesseract.
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<python-input-21>", line 1, in <module>
jax.grad(f)({"a": jnp.array([3.0]), "b": jnp.array([5.0])})
~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/jonathanbrodrick/pasteurcodes/tesseract-jax/tesseract_jax/primitive.py", line 174, in tesseract_dispatch_transpose_rule
vjp = tesseract_dispatch_p.bind(
*args,
...<7 lines>...
eval_func="vector_jacobian_product",
)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: CpuCallback error: Traceback (most recent call last):
File "<frozen runpy>", line 198, in _run_module_as_main
.
.
.
File "/Users/jonathanbrodrick/pasteurcodes/tesseract-core/tesseract_core/sdk/tesseract.py", line 437, in vector_jacobian_product
NotImplementedError: Vector Jacobian Product (VJP) not implemented for this Tesseract.
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
>>> jax.jit(jax.grad(f))({"a": jnp.array([3.0]), "b": jnp.array([5.0])})
Traceback (most recent call last):
File "<python-input-22>", line 1, in <module>
jax.jit(jax.grad(f))({"a": jnp.array([3.0]), "b": jnp.array([5.0])})
~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: CpuCallback error: Traceback (most recent call last):
File "<frozen runpy>", line 198, in _run_module_as_main
File "<frozen runpy>", line 88, in _run_code
File "/opt/homebrew/Cellar/python@3.13/3.13.0_1/Frameworks/Python.framework/Versions/3.13/lib/python3.13/_pyrepl/__main__.py", line 6, in <module>
File "/opt/homebrew/Cellar/python@3.13/3.13.0_1/Frameworks/Python.framework/Versions/3.13/lib/python3.13/_pyrepl/main.py", line 59, in interactive_console
File "/opt/homebrew/Cellar/python@3.13/3.13.0_1/Frameworks/Python.framework/Versions/3.13/lib/python3.13/_pyrepl/simple_interact.py", line 160, in run_multiline_interactive_console
File "/opt/homebrew/Cellar/python@3.13/3.13.0_1/Frameworks/Python.framework/Versions/3.13/lib/python3.13/code.py", line 313, in push
File "/opt/homebrew/Cellar/python@3.13/3.13.0_1/Frameworks/Python.framework/Versions/3.13/lib/python3.13/_pyrepl/console.py", line 205, in runsource
File "/opt/homebrew/Cellar/python@3.13/3.13.0_1/Frameworks/Python.framework/Versions/3.13/lib/python3.13/code.py", line 92, in runcode
File "<python-input-22>", line 1, in <module>
File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/pjit.py", line 339, in cache_miss
File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/pjit.py", line 194, in _python_pjit_helper
File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/pjit.py", line 1681, in _pjit_call_impl_python
File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/profiler.py", line 334, in wrapper
File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/interpreters/pxla.py", line 1288, in __call__
File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/callback.py", line 778, in _wrapped_callback
File "/Users/jonathanbrodrick/pasteurcodes/tesseract-jax/tesseract_jax/primitive.py", line 213, in _dispatch
File "/Users/jonathanbrodrick/pasteurcodes/tesseract-jax/tesseract_jax/tesseract_compat.py", line 243, in vector_jacobian_product
File "/Users/jonathanbrodrick/pasteurcodes/tesseract-core/tesseract_core/sdk/tesseract.py", line 50, in wrapper
File "/Users/jonathanbrodrick/pasteurcodes/tesseract-core/tesseract_core/sdk/tesseract.py", line 437, in vector_jacobian_product
NotImplementedError: Vector Jacobian Product (VJP) not implemented for this Tesseract.
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.