Skip to content

Commit ba043ee

Browse files
committed
increase timeout and use get_logs, refactor whisper-cuda into folder
1 parent 298e0e7 commit ba043ee

File tree

5 files changed

+150
-4
lines changed

5 files changed

+150
-4
lines changed

modules/jax/testcontainers/jax/__init__.py renamed to modules/jax/testcontainers/jax_cuda/__init__.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
from core.testcontainers.core.container import DockerContainer
66
from core.testcontainers.core.waiting_utils import wait_container_is_ready
7+
from core.testcontainers.core.config import testcontainers_config
8+
from core.testcontainers.core.waiting_utils import wait_for_logs
79

810
class JAXContainer(DockerContainer):
911
"""
@@ -36,11 +38,12 @@ def __init__(self, image="nvcr.io/nvidia/jax:23.08-py3", **kwargs):
3638
self.with_env("NVIDIA_VISIBLE_DEVICES", "all")
3739
self.with_env("CUDA_VISIBLE_DEVICES", "all")
3840
self.with_kwargs(runtime="nvidia") # Use NVIDIA runtime for GPU support
41+
self.start_timeout = 600 # 10 minutes
3942

4043
@wait_container_is_ready(URLError)
4144
def _connect(self):
4245
url = f"http://{self.get_container_host_ip()}:{self.get_exposed_port(8888)}"
43-
res = urllib.request.urlopen(url)
46+
res = urllib.request.urlopen(url, timeout=self.start_timeout)
4447
if res.status != 200:
4548
raise Exception(f"Failed to connect to JAX container. Status: {res.status}")
4649

@@ -64,10 +67,28 @@ def run_jax_command(self, command):
6467
exec_result = self.exec(f"python -c '{command}'")
6568
return exec_result
6669

70+
def _wait_for_container_to_be_ready(self):
71+
wait_for_logs(self, "Jupyter Server", timeout=self.start_timeout)
72+
6773
def start(self):
6874
"""
69-
Start the JAX container.
75+
Start the JAX container and wait for it to be ready.
7076
"""
7177
super().start()
78+
self._wait_for_container_to_be_ready()
7279
logging.info(f"JAX container started. Jupyter URL: {self.get_jupyter_url()}")
7380
return self
81+
82+
def stop(self, force=True):
83+
"""
84+
Stop the JAX container.
85+
"""
86+
super().stop(force)
87+
logging.info("JAX container stopped.")
88+
89+
@property
90+
def timeout(self):
91+
"""
92+
Get the container start timeout.
93+
"""
94+
return self.start_timeout
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
import logging
2+
import tempfile
3+
import time
4+
from typing import Optional
5+
6+
from core.testcontainers.core.container import DockerContainer
7+
from core.testcontainers.core.waiting_utils import wait_container_is_ready
8+
from urllib.error import URLError
9+
10+
class WhisperJAXContainer(DockerContainer):
11+
"""
12+
Whisper-JAX container for fast speech recognition and transcription.
13+
14+
Example:
15+
16+
.. doctest::
17+
18+
>>> from testcontainers.whisper_jax import WhisperJAXContainer
19+
20+
>>> with WhisperJAXContainer("openai/whisper-large-v2") as whisper:
21+
... # Connect to the container
22+
... whisper.connect()
23+
...
24+
... # Transcribe an audio file
25+
... result = whisper.transcribe_file("path/to/audio/file.wav")
26+
... print(result['text'])
27+
...
28+
... # Transcribe a YouTube video
29+
... result = whisper.transcribe_youtube("https://www.youtube.com/watch?v=dQw4w9WgXcQ")
30+
... print(result['text'])
31+
"""
32+
33+
def __init__(self, model_name: str = "openai/whisper-large-v2", **kwargs):
34+
super().__init__("nvcr.io/nvidia/jax:23.08-py3", **kwargs)
35+
self.model_name = model_name
36+
self.with_exposed_ports(8888) # Expose Jupyter notebook port
37+
self.with_env("NVIDIA_VISIBLE_DEVICES", "all")
38+
self.with_env("CUDA_VISIBLE_DEVICES", "all")
39+
self.with_kwargs(runtime="nvidia") # Use NVIDIA runtime for GPU support
40+
41+
# Install required dependencies
42+
self.with_command("sh -c '"
43+
"pip install --no-cache-dir git+https://github.com/sanchit-gandhi/whisper-jax.git && "
44+
"pip install --no-cache-dir numpy soundfile youtube_dl transformers datasets && "
45+
"python -m pip install --upgrade --no-cache-dir jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html && "
46+
"jupyter notebook --ip 0.0.0.0 --port 8888 --allow-root --NotebookApp.token='' --NotebookApp.password=''"
47+
"'")
48+
49+
@wait_container_is_ready(URLError)
50+
def _connect(self):
51+
url = f"http://{self.get_container_host_ip()}:{self.get_exposed_port(8888)}"
52+
res = urllib.request.urlopen(url)
53+
if res.status != 200:
54+
raise Exception(f"Failed to connect to Whisper-JAX container. Status: {res.status}")
55+
56+
def connect(self):
57+
"""
58+
Connect to the Whisper-JAX container and ensure it's ready.
59+
"""
60+
self._connect()
61+
logging.info("Successfully connected to Whisper-JAX container")
62+
63+
def run_command(self, command: str):
64+
"""
65+
Run a Python command inside the container.
66+
"""
67+
exec_result = self.exec(f"python -c '{command}'")
68+
return exec_result
69+
70+
def transcribe_file(self, file_path: str, task: str = "transcribe", return_timestamps: bool = False):
71+
"""
72+
Transcribe an audio file using Whisper-JAX.
73+
"""
74+
command = f"""
75+
import soundfile as sf
76+
from whisper_jax import FlaxWhisperPipline
77+
import jax.numpy as jnp
78+
79+
pipeline = FlaxWhisperPipline("{self.model_name}", dtype=jnp.bfloat16, batch_size=16)
80+
audio, sr = sf.read("{file_path}")
81+
result = pipeline({{"array": audio, "sampling_rate": sr}}, task="{task}", return_timestamps={return_timestamps})
82+
print(result)
83+
"""
84+
return self.run_command(command)
85+
86+
def transcribe_youtube(self, youtube_url: str, task: str = "transcribe", return_timestamps: bool = False):
87+
"""
88+
Transcribe a YouTube video using Whisper-JAX.
89+
"""
90+
command = f"""
91+
import tempfile
92+
import youtube_dl
93+
import soundfile as sf
94+
from whisper_jax import FlaxWhisperPipline
95+
import jax.numpy as jnp
96+
97+
def download_youtube_audio(youtube_url, output_file):
98+
ydl_opts = {{
99+
'format': 'bestaudio/best',
100+
'postprocessors': [{{
101+
'key': 'FFmpegExtractAudio',
102+
'preferredcodec': 'wav',
103+
'preferredquality': '192',
104+
}}],
105+
'outtmpl': output_file,
106+
}}
107+
with youtube_dl.YoutubeDL(ydl_opts) as ydl:
108+
ydl.download([youtube_url])
109+
110+
pipeline = FlaxWhisperPipline("{self.model_name}", dtype=jnp.bfloat16, batch_size=16)
111+
112+
with tempfile.NamedTemporaryFile(suffix=".wav") as temp_file:
113+
download_youtube_audio("{youtube_url}", temp_file.name)
114+
audio, sr = sf.read(temp_file.name)
115+
result = pipeline({{"array": audio, "sampling_rate": sr}}, task="{task}", return_timestamps={return_timestamps})
116+
print(result)
117+
"""
118+
return self.run_command(command)
119+
120+
def start(self):
121+
"""
122+
Start the Whisper-JAX container.
123+
"""
124+
super().start()
125+
logging.info(f"Whisper-JAX container started. Jupyter URL: http://{self.get_container_host_ip()}:{self.get_exposed_port(8888)}")
126+
return self

modules/jax/testcontainers/whisper-jax/__init__.py

Lines changed: 0 additions & 1 deletion
This file was deleted.

modules/jax/tests/test_jax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import pytest
2-
from testcontainers.jax import JAXContainer
2+
from modules.jax.testcontainers.jax_cuda import JAXContainer
33

44
def test_jax_container():
55
with JAXContainer() as jax_container:

0 commit comments

Comments
 (0)