1
+ import logging
2
+ import urllib .request
3
+ from urllib .error import URLError
4
+
5
+ from core .testcontainers .core .container import DockerContainer
6
+ from core .testcontainers .core .waiting_utils import wait_container_is_ready
7
+ from core .testcontainers .core .config import testcontainers_config
8
+ from core .testcontainers .core .waiting_utils import wait_for_logs
9
+
10
+ class JAXContainer (DockerContainer ):
11
+ """
12
+ JAX container for GPU-accelerated numerical computing and machine learning.
13
+
14
+ Example:
15
+
16
+ .. doctest::
17
+
18
+ >>> import jax
19
+ >>> from testcontainers.jax import JAXContainer
20
+
21
+ >>> with JAXContainer("nvcr.io/nvidia/jax:23.08-py3") as jax_container:
22
+ ... # Connect to the container
23
+ ... jax_container.connect()
24
+ ...
25
+ ... # Run a simple JAX computation
26
+ ... result = jax.numpy.add(1, 1)
27
+ ... assert result == 2
28
+
29
+ .. auto-class:: JAXContainer
30
+ :members:
31
+ :undoc-members:
32
+ :show-inheritance:
33
+ """
34
+
35
+ def __init__ (self , image = "huggingface/transformers-jax-light:latest" , ** kwargs ):
36
+ super ().__init__ (image , ** kwargs )
37
+ self .with_exposed_ports (8888 ) # Expose Jupyter notebook port
38
+ self .with_env ("NVIDIA_VISIBLE_DEVICES" , "all" )
39
+ self .with_env ("CUDA_VISIBLE_DEVICES" , "all" )
40
+ self .with_kwargs (runtime = "nvidia" ) # Use NVIDIA runtime for GPU support
41
+ self .start_timeout = 600 # 10 minutes
42
+
43
+ @wait_container_is_ready (URLError )
44
+ def _connect (self ):
45
+ url = f"http://{ self .get_container_host_ip ()} :{ self .get_exposed_port (8888 )} "
46
+ res = urllib .request .urlopen (url , timeout = self .start_timeout )
47
+ if res .status != 200 :
48
+ raise Exception (f"Failed to connect to JAX container. Status: { res .status } " )
49
+
50
+ def connect (self ):
51
+ """
52
+ Connect to the JAX container and ensure it's ready.
53
+ """
54
+ self ._connect ()
55
+ logging .info ("Successfully connected to JAX container" )
56
+
57
+ def get_jupyter_url (self ):
58
+ """
59
+ Get the URL for accessing the Jupyter notebook server.
60
+ """
61
+ return f"http://{ self .get_container_host_ip ()} :{ self .get_exposed_port (8888 )} "
62
+
63
+ def run_jax_command (self , command ):
64
+ """
65
+ Run a JAX command inside the container.
66
+ """
67
+ exec_result = self .exec (f"python -c '{ command } '" )
68
+ return exec_result
69
+
70
+ def _wait_for_container_to_be_ready (self ):
71
+ wait_for_logs (self , "Jupyter Server" , timeout = self .start_timeout )
72
+
73
+ def start (self ):
74
+ """
75
+ Start the JAX container and wait for it to be ready.
76
+ """
77
+ super ().start ()
78
+ self ._wait_for_container_to_be_ready ()
79
+ logging .info (f"JAX container started. Jupyter URL: { self .get_jupyter_url ()} " )
80
+ return self
81
+
82
+ def stop (self , force = True ):
83
+ """
84
+ Stop the JAX container.
85
+ """
86
+ super ().stop (force )
87
+ logging .info ("JAX container stopped." )
88
+
89
+ @property
90
+ def timeout (self ):
91
+ """
92
+ Get the container start timeout.
93
+ """
94
+ return self .start_timeout
0 commit comments