Skip to content

Commit 298e0e7

Browse files
committed
add jax to pyproject.toml
1 parent 421ddb0 commit 298e0e7

File tree

2 files changed

+16
-10
lines changed

2 files changed

+16
-10
lines changed

modules/jax/testcontainers/jax/__init__.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,23 @@ class JAXContainer(DockerContainer):
1111
1212
Example:
1313
14-
.. doctest::
14+
.. doctest::
1515
16-
>>> import jax
17-
>>> from testcontainers.jax import JAXContainer
16+
>>> import jax
17+
>>> from testcontainers.jax import JAXContainer
1818
19-
>>> with JAXContainer("nvcr.io/nvidia/jax:23.08-py3") as jax_container:
20-
... # Connect to the container
21-
... jax_container.connect()
22-
...
23-
... # Run a simple JAX computation
24-
... result = jax.numpy.add(1, 1)
25-
... assert result == 2
19+
>>> with JAXContainer("nvcr.io/nvidia/jax:23.08-py3") as jax_container:
20+
... # Connect to the container
21+
... jax_container.connect()
22+
...
23+
... # Run a simple JAX computation
24+
... result = jax.numpy.add(1, 1)
25+
... assert result == 2
26+
27+
.. auto-class:: JAXContainer
28+
:members:
29+
:undoc-members:
30+
:show-inheritance:
2631
"""
2732

2833
def __init__(self, image="nvcr.io/nvidia/jax:23.08-py3", **kwargs):

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ neo4j = ["neo4j"]
146146
nginx = []
147147
opensearch = ["opensearch-py"]
148148
ollama = []
149+
jax = ["jax"]
149150
oracle = ["sqlalchemy", "oracledb"]
150151
oracle-free = ["sqlalchemy", "oracledb"]
151152
postgres = []

0 commit comments

Comments
 (0)