Skip to content

Commit 2aa5844

Browse files
authored
fix: Update CUDA imports (#847)
1 parent 42c5e60 commit 2aa5844

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

src/python/library/tritonclient/utils/cuda_shared_memory/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
# Check for dependency before other import so other imports can assume
2929
# the module is available (drop "try ... except .."")
3030
try:
31-
from cuda import cudart
31+
import cuda.bindings.runtime as cudart
3232
except ModuleNotFoundError as error:
3333
raise RuntimeError(
3434
"CUDA shared memory utilities require Python package 'cuda-python'"

src/python/library/tritonclient/utils/cuda_shared_memory/_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@
2626

2727
from typing import Any
2828

29-
from cuda import cuda as cuda_driver
30-
from cuda import cudart
29+
import cuda.bindings.driver as cuda_driver
30+
import cuda.bindings.runtime as cudart
3131

3232

3333
def call_cuda_function(function, *argv):

0 commit comments

Comments
 (0)