@@ -28,6 +28,7 @@
 # Check for dependency before other import so other imports can assume
 # the module is available (drop "try ... except ..")
 try:
+    import cuda.bindings.driver as cuda_driver
     import cuda.bindings.runtime as cudart
 except ModuleNotFoundError as error:
     raise RuntimeError(
@@ -73,15 +74,15 @@ def _get_or_create_global_cuda_stream(device_id):
 def _support_uva(shm_device_id, ext_device_id):
     try:
         support_uva = call_cuda_function(
-            cudart.cudaDeviceGetAttribute,
-            cudart.cudaDeviceAttr.cudaDevAttrUnifiedAddressing,
-            shm_device_id,
+            cudart.cudaDeviceGetAttribute(
+                cudart.cudaDeviceAttr.cudaDevAttrUnifiedAddressing, shm_device_id
+            )
         )
         if (support_uva != 0) and (ext_device_id != -1):
             support_uva = call_cuda_function(
-                cudart.cudaDeviceGetAttribute,
-                cudart.cudaDeviceAttr.cudaDevAttrUnifiedAddressing,
-                ext_device_id,
+                cudart.cudaDeviceGetAttribute(
+                    cudart.cudaDeviceAttr.cudaDevAttrUnifiedAddressing, ext_device_id
+                )
             )
         if support_uva == 0:
             raise CudaSharedMemoryException(
@@ -127,10 +128,11 @@ def create_shared_memory_region(triton_shm_name, byte_size, device_id):
     """
     prev_device = None
     try:
-        prev_device = call_cuda_function(cudart.cudaGetDevice)
-        call_cuda_function(cudart.cudaSetDevice, device_id)
-        device_ptr = call_cuda_function(cudart.cudaMalloc, byte_size)
-        cuda_shm_handle = call_cuda_function(cudart.cudaIpcGetMemHandle, device_ptr)
+        cuda_driver.cuInit(device_id)
+        prev_device = call_cuda_function(cudart.cudaGetDevice())
+        call_cuda_function(cudart.cudaSetDevice(device_id))
+        device_ptr = call_cuda_function(cudart.cudaMalloc(byte_size))
+        cuda_shm_handle = call_cuda_function(cudart.cudaIpcGetMemHandle(device_ptr))
         triton_shm_handle = CudaSharedMemoryRegion(
             triton_shm_name, cuda_shm_handle, device_ptr, byte_size, device_id
         )
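
For orientation, here is a hedged usage sketch of the client-facing API this file implements, written against the `tritonclient.utils.cuda_shared_memory` module that the second diff below belongs to (call pattern assumed from the function signatures in this PR; error handling elided):

import numpy as np
import tritonclient.utils.cuda_shared_memory as cudashm

# Create a 64-byte CUDA shared-memory region on GPU 0.
handle = cudashm.create_shared_memory_region("input_data", 64, 0)

# Write a tensor into the region, then read it back to verify the round trip.
cudashm.set_shared_memory_region(handle, [np.arange(16, dtype=np.float32)])
roundtrip = cudashm.get_contents_as_numpy(handle, np.float32, [16])

# Release the region (cudaFree under the hood).
cudashm.destroy_shared_memory_region(handle)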
@@ -210,25 +212,27 @@ def set_shared_memory_region(cuda_shm_handle, input_values):
                 input_value = input_value.item()
                 byte_size = np.dtype(np.byte).itemsize * len(input_value)
                 call_cuda_function(
-                    cudart.cudaMemcpyAsync,
-                    cuda_shm_handle._base_addr + offset_current,
-                    input_value,
-                    byte_size,
-                    cudart.cudaMemcpyKind.cudaMemcpyDefault,
-                    stream,
+                    cudart.cudaMemcpyAsync(
+                        cuda_shm_handle._base_addr + offset_current,
+                        input_value,
+                        byte_size,
+                        cudart.cudaMemcpyKind.cudaMemcpyDefault,
+                        stream,
+                    )
                 )
             else:
                 byte_size = input_value.size * input_value.itemsize
                 call_cuda_function(
-                    cudart.cudaMemcpyAsync,
-                    cuda_shm_handle._base_addr + offset_current,
-                    input_value.ctypes.data,
-                    byte_size,
-                    cudart.cudaMemcpyKind.cudaMemcpyDefault,
-                    stream,
+                    cudart.cudaMemcpyAsync(
+                        cuda_shm_handle._base_addr + offset_current,
+                        input_value.ctypes.data,
+                        byte_size,
+                        cudart.cudaMemcpyKind.cudaMemcpyDefault,
+                        stream,
+                    )
                 )
             offset_current += byte_size
-        call_cuda_function(cudart.cudaStreamSynchronize, stream)
+        call_cuda_function(cudart.cudaStreamSynchronize(stream))
     except Exception as ex:
         if not isinstance(ex, CudaSharedMemoryException):
             raise CudaSharedMemoryException(
@@ -265,15 +269,16 @@ def get_contents_as_numpy(cuda_shm_handle, datatype, shape):
         # Numpy can only read from host buffer.
         host_buffer = (ctypes.c_char * cuda_shm_handle._byte_size)()
         call_cuda_function(
-            cudart.cudaMemcpyAsync,
-            host_buffer,
-            cuda_shm_handle._base_addr,
-            cuda_shm_handle._byte_size,
-            cudart.cudaMemcpyKind.cudaMemcpyDefault,
-            stream,
+            cudart.cudaMemcpyAsync(
+                host_buffer,
+                cuda_shm_handle._base_addr,
+                cuda_shm_handle._byte_size,
+                cudart.cudaMemcpyKind.cudaMemcpyDefault,
+                stream,
+            )
         )
         # Sync to ensure the host buffer is ready
-        call_cuda_function(cudart.cudaStreamSynchronize, stream)
+        call_cuda_function(cudart.cudaStreamSynchronize(stream))
     except Exception as ex:
         if not isinstance(ex, CudaSharedMemoryException):
             raise CudaSharedMemoryException(
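
The hunk above keeps the existing read-back pattern: stage device memory into a host buffer, synchronize, then wrap it with numpy. As a reference, a minimal standalone sketch of that path under cuda-python's tuple-return convention (the function name and assert-based checking are illustrative, not from the PR):

import ctypes
import numpy as np
import cuda.bindings.runtime as cudart

def read_device_buffer(dev_ptr, byte_size, dtype, shape, stream):
    # Numpy cannot wrap device memory, so stage through a host buffer.
    host = (ctypes.c_char * byte_size)()
    err, = cudart.cudaMemcpyAsync(
        host, dev_ptr, byte_size, cudart.cudaMemcpyKind.cudaMemcpyDefault, stream
    )
    assert err == cudart.cudaError_t.cudaSuccess
    # The copy is asynchronous; synchronize before touching the host buffer.
    err, = cudart.cudaStreamSynchronize(stream)
    assert err == cudart.cudaError_t.cudaSuccess
    return np.frombuffer(host, dtype=dtype).reshape(shape)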
@@ -368,14 +373,15 @@ def set_shared_memory_region_from_dlpack(cuda_shm_handle, input_values):

         try:
             call_cuda_function(
-                cudart.cudaMemcpyAsync,
-                cuda_shm_handle._base_addr + offset_current,
-                data_ptr,
-                byte_size,
-                cudart.cudaMemcpyKind.cudaMemcpyDefault,
-                stream,
+                cudart.cudaMemcpyAsync(
+                    cuda_shm_handle._base_addr + offset_current,
+                    data_ptr,
+                    byte_size,
+                    cudart.cudaMemcpyKind.cudaMemcpyDefault,
+                    stream,
+                )
             )
-            call_cuda_function(cudart.cudaStreamSynchronize, stream)
+            call_cuda_function(cudart.cudaStreamSynchronize(stream))
         except Exception as ex:
             if not isinstance(ex, CudaSharedMemoryException):
                 raise CudaSharedMemoryException(
src/python/library/tritonclient/utils/cuda_shared_memory/_utils.py (26 changes: 14 additions & 12 deletions)
@@ -30,8 +30,7 @@
 import cuda.bindings.runtime as cudart


-def call_cuda_function(function, *argv):
-    res = function(*argv)
+def call_cuda_function(res):
     err = res[0]
     if isinstance(err, cudart.cudaError_t):
         if err != cudart.cudaError_t.cudaSuccess:

Contributor:
Why do you need to change this place? It seems the only difference is where the CUDA function actually gets called (inside vs. outside call_cuda_function).

Contributor (Author):
I only saved my local changes; this is not a real change. I have been testing locally, looking for some specific output. This is not a working change, otherwise the PR would be in the open state.
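
To make the exchange above concrete, here is a hedged side-by-side sketch of the two calling conventions (the `_old`/`_new` names and the `_check` helper are illustrative, not from the PR), assuming cuda-python's behavior of returning an `(err, ...)` tuple from every binding call:

import cuda.bindings.runtime as cudart

def _check(res):
    # cuda-python returns (err, value...) tuples; raise on any failure.
    err = res[0]
    if err != cudart.cudaError_t.cudaSuccess:
        raise RuntimeError(f"CUDA call failed with {err}")
    return res[1] if len(res) == 2 else res[1:]

# Old convention: the wrapper receives the function and invokes it itself.
def call_cuda_function_old(function, *argv):
    return _check(function(*argv))

# New convention (this diff): the CUDA function is invoked at the call
# site, and the wrapper only validates the already-returned tuple.
def call_cuda_function_new(res):
    return _check(res)

device = call_cuda_function_old(cudart.cudaGetDevice)    # wrapper calls it
device = call_cuda_function_new(cudart.cudaGetDevice())  # caller calls it

Either way the error check is identical; the diff only moves the invocation out of the wrapper, which is what the reviewer is pointing at.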
@@ -92,9 +91,10 @@ def __del__(self):
             return
         prev_device = None
         try:
-            prev_device = call_cuda_function(cudart.cudaGetDevice)
-            call_cuda_function(cudart.cudaSetDevice, self._device_id)
-            call_cuda_function(cudart.cudaFree, self._base_addr)
+            cuda_driver.cuInit(self._device_id)
+            prev_device = call_cuda_function(cudart.cudaGetDevice())
+            call_cuda_function(cudart.cudaSetDevice(self._device_id))
+            call_cuda_function(cudart.cudaFree(self._base_addr))
         finally:
             if prev_device is not None:
                 maybe_set_device(prev_device)
@@ -104,9 +104,10 @@ class CudaStream:
     def __init__(self, device_id):
        prev_device = None
        try:
-            prev_device = call_cuda_function(cudart.cudaGetDevice)
-            call_cuda_function(cudart.cudaSetDevice, device_id)
-            self._stream = call_cuda_function(cudart.cudaStreamCreate)
+            cuda_driver.cuInit(device_id)
+            prev_device = call_cuda_function(cudart.cudaGetDevice())
+            call_cuda_function(cudart.cudaSetDevice(device_id))
+            self._stream = call_cuda_function(cudart.cudaStreamCreate())
         finally:
             if prev_device is not None:
                 maybe_set_device(prev_device)
@@ -117,12 +118,13 @@ def __del__(self):
         if not hasattr(self, "_stream") or self._stream is None:
             return
         # [FIXME] __del__ is not the best place for releasing resources
-        call_cuda_function(cudart.cudaStreamDestroy, self._stream)
+        call_cuda_function(cudart.cudaStreamDestroy(self._stream))
         self._stream = None


 def maybe_set_device(device_id):
-    device = call_cuda_function(cuda_driver.cuDeviceGet, device_id)
-    _, active = call_cuda_function(cuda_driver.cuDevicePrimaryCtxGetState, device)
+    cuda_driver.cuInit(device_id)
+    call_cuda_function(cuda_driver.cuDeviceGet(device_id))
+    _, active = call_cuda_function(cuda_driver.cuDevicePrimaryCtxGetState(device_id))
     if active:
-        call_cuda_function(cudart.cudaSetDevice, device_id)
+        call_cuda_function(cudart.cudaSetDevice(device_id))
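
A side note on the cuInit calls added in this diff: in the CUDA driver API, cuInit takes a Flags argument that currently must be 0, not a device ordinal, so passing device_id is only coincidentally valid for device 0. A minimal sketch of the primary-context check under that reading (the helper name is illustrative and error checks are elided):

import cuda.bindings.driver as cuda_driver
import cuda.bindings.runtime as cudart

def maybe_set_device_sketch(device_id):
    # cuInit initializes the driver as a whole, not one device; its Flags
    # argument must currently be 0.
    err, = cuda_driver.cuInit(0)
    # Map the ordinal to a CUdevice handle for the driver-API query.
    err, device = cuda_driver.cuDeviceGet(device_id)
    # Only switch devices if the primary context is already active, so the
    # query itself does not create a context as a side effect.
    err, flags, active = cuda_driver.cuDevicePrimaryCtxGetState(device)
    if active:
        err, = cudart.cudaSetDevice(device_id)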