@@ -35,32 +35,57 @@ def __init__(self, model_name: str = "openai/whisper-large-v2", hf_token: Option
3535 super ().__init__ ("nvcr.io/nvidia/jax:23.08-py3" , ** kwargs )
3636 self .model_name = model_name
3737 self .hf_token = hf_token
38- self .with_exposed_ports (8888 ) # Expose Jupyter notebook port
3938 self .with_env ("NVIDIA_VISIBLE_DEVICES" , "all" )
4039 self .with_env ("CUDA_VISIBLE_DEVICES" , "all" )
4140 self .with_kwargs (runtime = "nvidia" ) # Use NVIDIA runtime for GPU support
41+ self .start_timeout = 600 # 10 minutes
42+ self .connection_retries = 5
43+ self .connection_retry_delay = 10 # seconds
4244
4345 # Install required dependencies
4446 self .with_command ("sh -c '"
4547 "pip install --no-cache-dir git+https://github.com/sanchit-gandhi/whisper-jax.git && "
4648 "pip install --no-cache-dir numpy soundfile youtube_dl transformers datasets pyannote.audio && "
47- "python -m pip install --upgrade --no-cache-dir jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html && "
48- "jupyter notebook --ip 0.0.0.0 --port 8888 --allow-root --NotebookApp.token='' --NotebookApp.password=''"
49+ "python -m pip install --upgrade --no-cache-dir jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"
4950 "'" )
5051
@wait_container_is_ready(URLError)
def _connect(self):
    """
    Verify that the container's JAX / Whisper-JAX / Pyannote environment works.

    A small probe script is handed to ``self.run_command`` and its output is
    inspected for the expected markers. The probe is retried up to
    ``self.connection_retries`` times, sleeping ``self.connection_retry_delay``
    seconds between attempts.

    Returns:
        True once the probe output contains both expected markers. False only
        when ``self.connection_retries`` is not positive, so no attempt ran.

    Raises:
        Exception: if the final attempt fails; the underlying error is
            attached as ``__cause__``.
    """
    # NOTE(review): the probe is Python source, not a shell command; this
    # assumes run_command executes it with a Python interpreter - confirm.
    probe = (
        "import jax; import whisper_jax; import pyannote.audio; "
        "print(f'JAX version: {jax.__version__}'); "
        "print(f'Whisper-JAX version: {whisper_jax.__version__}'); "
        "print(f'Pyannote Audio version: {pyannote.audio.__version__}'); "
        "print(f'Available devices: {jax.devices()}'); "
        "print(jax.numpy.add(1, 1))"
    )

    final_attempt = self.connection_retries - 1
    for attempt in range(self.connection_retries):
        try:
            result = self.run_command(probe)
            # Decode once and reuse for both the check and the log message.
            output = result.output.decode()
            if "JAX version" not in output or "Available devices" not in output:
                raise Exception("JAX-Whisper-Diarization environment check failed")
        except Exception as e:
            if attempt >= final_attempt:
                # Chain the cause so the original traceback is preserved.
                raise Exception(
                    f"Failed to connect to JAX-Whisper-Diarization container after {self.connection_retries} attempts: {str(e)}"
                ) from e
            logging.warning(
                f"Connection attempt {attempt + 1} failed. Retrying in {self.connection_retry_delay} seconds..."
            )
            time.sleep(self.connection_retry_delay)
        else:
            logging.info(f"JAX-Whisper-Diarization environment verified:\n{output}")
            return True

    return False
5780
def connect(self):
    """
    Connect to the JAX-Whisper-Diarization container and ensure it's ready.

    Delegates to ``_connect``, which checks that JAX, Whisper-JAX and
    Pyannote Audio import correctly inside the container and lists the
    devices JAX can see, then logs a success message.
    """
    success_message = (
        "Successfully connected to JAX-Whisper-Diarization container "
        "and verified the environment"
    )
    self._connect()
    logging.info(success_message)
6489
6590 def run_command (self , command : str ):
6691 """
@@ -242,8 +267,27 @@ def align(transcription, segments, group_by_speaker=True):
242267
def start(self):
    """
    Start the JAX-Whisper-Diarization container and wait for it to be ready.

    Returns:
        This container instance, so the call can be chained.

    Raises:
        Exception: whatever the readiness wait raises. The container is
            stopped before the error propagates so it is not left running.
    """
    super().start()
    try:
        self._wait_for_container_to_be_ready()
    except Exception:
        # The container is already running at this point; without cleanup a
        # failed readiness wait would leave it behind, because a context
        # manager's __exit__ is not invoked when __enter__ fails.
        logging.exception("JAX-Whisper-Diarization container failed to become ready; stopping it.")
        try:
            self.stop()
        except Exception:
            # Do not let a cleanup failure hide the original readiness error.
            logging.exception("Failed to stop JAX-Whisper-Diarization container after readiness failure.")
        raise
    logging.info("JAX-Whisper-Diarization container started and ready.")
    return self
276+
def _wait_for_container_to_be_ready(self):
    """
    Block until the container logs the message that marks it as ready.
    """
    # NOTE(review): the startup command visible in __init__ only runs pip
    # installs and never prints this marker - confirm something emits it,
    # otherwise this wait cannot succeed.
    # NOTE(review): self.start_timeout is not passed here - confirm whether
    # wait_for_logs should receive it.
    ready_marker = "Installation completed"
    self.wait_for_logs(ready_marker)
280+
def stop(self, force=True, *args, **kwargs):
    """
    Stop the JAX-Whisper-Diarization container.

    Args:
        force: passed to the parent ``stop`` as its first argument.
        *args, **kwargs: any further arguments accepted by the parent
            ``stop`` are forwarded unchanged, so this override does not
            narrow the parent's signature.
    """
    super().stop(force, *args, **kwargs)
    logging.info("JAX-Whisper-Diarization container stopped.")
287+
@property
def timeout(self):
    """
    Get the container start timeout.

    Read-only view of ``self.start_timeout``.
    """
    configured_timeout = self.start_timeout
    return configured_timeout
0 commit comments