Skip to content
Open
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
9cb5e82
first commit for a decorator that transforms JAX to pytensor
jdehning Dec 12, 2024
d3a277e
Add more tests
jdehning Dec 12, 2024
5914162
Define JAXOp outside of the decorator
jdehning Dec 15, 2024
b810e99
Added comment regarding flattening of inputs
jdehning Dec 15, 2024
df76e73
Add as_jax_op to pytensor.__init__.py and to documentation
jdehning Dec 15, 2024
c2338fb
Add [jax] requirement to readthedocs in order to read the docstring o…
jdehning Dec 15, 2024
89474ae
Added an example to the docstring of as_jax_op
jdehning Dec 16, 2024
4e2e005
Use infer_static_shape, currently still with the possibility to use t…
jdehning Feb 3, 2025
e6b52d6
Remove `sol` in variable names
jdehning Feb 3, 2025
abf99f1
Rename tests and make static variables test more meaningfull
jdehning Feb 4, 2025
21d252d
More test renaming, forgot a few
jdehning Feb 5, 2025
7291696
Refactoring of ops.py: code is in general cleaner, and JAXOp can now …
jdehning Feb 5, 2025
3e2949d
Clean up tests
jdehning Feb 6, 2025
38b17b5
Add to some tests a direct call to JAXOp
jdehning Feb 6, 2025
bb75938
temporary as_jax_op fix
aseyboldt May 6, 2025
d04f41d
Simplify as_jax_op
aseyboldt Sep 16, 2025
8ecf45c
optionally eval shapes in as_jax_op
aseyboldt Sep 16, 2025
abd668b
minor coding style changes in as_jax_op
aseyboldt Sep 16, 2025
fcff09b
set output shape to None if not statically known in as_jax_op
aseyboldt Sep 16, 2025
866b4ba
clean up global import
aseyboldt Sep 16, 2025
0ae53a0
remove name from JaxOp.__props__
aseyboldt Sep 16, 2025
07a2c43
don't compile in shape eval of as_jax_op
aseyboldt Sep 16, 2025
f832126
more tests for as_jax_op
aseyboldt Sep 16, 2025
bf2d0b3
changes based on review
aseyboldt Sep 24, 2025
1e93af9
rename as_jax_op to wrap_jax
aseyboldt Sep 24, 2025
10d2097
remove equinox dependency
aseyboldt Sep 24, 2025
fceed2c
rename as_op to wrap_py
aseyboldt Sep 24, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ jobs:
micromamba install --yes -q "python~=${PYTHON_VERSION}" mkl "numpy${NUMPY_VERSION}" scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock pytest-sphinx;
fi
if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi
if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro && pip install tfp-nightly; fi
if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro equinox && pip install tfp-nightly; fi
if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi
if [[ $INSTALL_XARRAY == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" xarray xarray-einstats; fi

Expand Down
1 change: 1 addition & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
"jax": ("https://jax.readthedocs.io/en/latest", None),
"numpy": ("https://numpy.org/doc/stable", None),
"torch": ("https://pytorch.org/docs/stable", None),
"equinox": ("https://docs.kidger.site/equinox/", None),
}

needs_sphinx = "3"
Expand Down
2 changes: 1 addition & 1 deletion doc/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@ dependencies:
- ablog
- pip
- pip:
- -e ..
- -e ..[jax]
7 changes: 7 additions & 0 deletions doc/library/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ Convert to Variable

.. autofunction:: pytensor.as_symbolic(...)

Wrap JAX functions
==================

.. autofunction:: as_jax_op(...)

Alias for :func:`pytensor.link.jax.ops.as_jax_op`

Debug
=====

Expand Down
12 changes: 12 additions & 0 deletions pytensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,18 @@ def get_underlying_scalar_constant(v):
from pytensor.scan.views import foldl, foldr, map, reduce
from pytensor.compile.builders import OpFromGraph

try:
import pytensor.link.jax.ops
from pytensor.link.jax.ops import as_jax_op
except ImportError as e:
import_error_as_jax_op = e

def as_jax_op(jax_function=None, allow_eval=True):
raise ImportError(
"JAX and/or equinox are not installed. Install them"
" to use this function: pip install pytensor[jax]"
) from import_error_as_jax_op

# isort: on


Expand Down
6 changes: 6 additions & 0 deletions pytensor/link/jax/dispatch/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from pytensor.graph import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.ifelse import IfElse
from pytensor.link.jax.ops import JAXOp
from pytensor.link.utils import fgraph_to_python
from pytensor.raise_op import CheckAndRaise

Expand Down Expand Up @@ -142,3 +143,8 @@ def opfromgraph(*inputs):
return fgraph_fn(*inputs)

return opfromgraph


@jax_funcify.register(JAXOp)
def jax_op_funcify(op, **kwargs):
return op.perform_jax
Loading
Loading