Skip to content

Commit 396c129

Browse files
committed
add huggingface amd jax
1 parent ba043ee commit 396c129

File tree

5 files changed

+96
-2
lines changed

5 files changed

+96
-2
lines changed
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import logging
2+
import urllib.request
3+
from urllib.error import URLError
4+
5+
from core.testcontainers.core.container import DockerContainer
6+
from core.testcontainers.core.waiting_utils import wait_container_is_ready
7+
from core.testcontainers.core.config import testcontainers_config
8+
from core.testcontainers.core.waiting_utils import wait_for_logs
9+
10+
class JAXContainer(DockerContainer):
11+
"""
12+
JAX container for GPU-accelerated numerical computing and machine learning.
13+
14+
Example:
15+
16+
.. doctest::
17+
18+
>>> import jax
19+
>>> from testcontainers.jax import JAXContainer
20+
21+
>>> with JAXContainer("nvcr.io/nvidia/jax:23.08-py3") as jax_container:
22+
... # Connect to the container
23+
... jax_container.connect()
24+
...
25+
... # Run a simple JAX computation
26+
... result = jax.numpy.add(1, 1)
27+
... assert result == 2
28+
29+
.. auto-class:: JAXContainer
30+
:members:
31+
:undoc-members:
32+
:show-inheritance:
33+
"""
34+
35+
def __init__(self, image="huggingface/transformers-jax-light:latest", **kwargs):
36+
super().__init__(image, **kwargs)
37+
self.with_exposed_ports(8888) # Expose Jupyter notebook port
38+
self.with_env("NVIDIA_VISIBLE_DEVICES", "all")
39+
self.with_env("CUDA_VISIBLE_DEVICES", "all")
40+
self.with_kwargs(runtime="nvidia") # Use NVIDIA runtime for GPU support
41+
self.start_timeout = 600 # 10 minutes
42+
43+
@wait_container_is_ready(URLError)
44+
def _connect(self):
45+
url = f"http://{self.get_container_host_ip()}:{self.get_exposed_port(8888)}"
46+
res = urllib.request.urlopen(url, timeout=self.start_timeout)
47+
if res.status != 200:
48+
raise Exception(f"Failed to connect to JAX container. Status: {res.status}")
49+
50+
def connect(self):
51+
"""
52+
Connect to the JAX container and ensure it's ready.
53+
"""
54+
self._connect()
55+
logging.info("Successfully connected to JAX container")
56+
57+
def get_jupyter_url(self):
58+
"""
59+
Get the URL for accessing the Jupyter notebook server.
60+
"""
61+
return f"http://{self.get_container_host_ip()}:{self.get_exposed_port(8888)}"
62+
63+
def run_jax_command(self, command):
64+
"""
65+
Run a JAX command inside the container.
66+
"""
67+
exec_result = self.exec(f"python -c '{command}'")
68+
return exec_result
69+
70+
def _wait_for_container_to_be_ready(self):
71+
wait_for_logs(self, "Jupyter Server", timeout=self.start_timeout)
72+
73+
def start(self):
74+
"""
75+
Start the JAX container and wait for it to be ready.
76+
"""
77+
super().start()
78+
self._wait_for_container_to_be_ready()
79+
logging.info(f"JAX container started. Jupyter URL: {self.get_jupyter_url()}")
80+
return self
81+
82+
def stop(self, force=True):
83+
"""
84+
Stop the JAX container.
85+
"""
86+
super().stop(force)
87+
logging.info("JAX container stopped.")
88+
89+
@property
90+
def timeout(self):
91+
"""
92+
Get the container start timeout.
93+
"""
94+
return self.start_timeout

modules/jax/tests/test_whisper_diarization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import pytest
2-
from testcontainers.jax_whisper_diarization import JAXWhisperDiarizationContainer
2+
from modules.jax.testcontainers.whisper_cuda.whisper_diarization import JAXWhisperDiarizationContainer
33

44
@pytest.fixture(scope="module")
55
def hf_token():

modules/jax/tests/test_whisper_jax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import pytest
2-
from testcontainers.whisper_jax import WhisperJAXContainer
2+
from modules.jax.testcontainers.whisper_cuda.whisper_transcription import WhisperJAXContainer
33

44
@pytest.mark.parametrize("model_name", ["openai/whisper-tiny", "openai/whisper-base"])
55
def test_whisper_jax_container(model_name):

0 commit comments

Comments
 (0)