@@ -139,6 +139,7 @@ typedef struct cu_ops_t {
139
139
CUresult (* cuGetErrorName )(CUresult error , const char * * pStr );
140
140
CUresult (* cuGetErrorString )(CUresult error , const char * * pStr );
141
141
CUresult (* cuCtxGetCurrent )(CUcontext * pctx );
142
+ CUresult (* cuCtxGetDevice )(CUdevice * device );
142
143
CUresult (* cuCtxSetCurrent )(CUcontext ctx );
143
144
CUresult (* cuIpcGetMemHandle )(CUipcMemHandle * pHandle , CUdeviceptr dptr );
144
145
CUresult (* cuIpcOpenMemHandle )(CUdeviceptr * pdptr , CUipcMemHandle handle ,
@@ -221,6 +222,8 @@ static void init_cu_global_state(void) {
221
222
utils_get_symbol_addr (lib_handle , "cuGetErrorString" , lib_name );
222
223
* (void * * )& g_cu_ops .cuCtxGetCurrent =
223
224
utils_get_symbol_addr (lib_handle , "cuCtxGetCurrent" , lib_name );
225
+ * (void * * )& g_cu_ops .cuCtxGetDevice =
226
+ utils_get_symbol_addr (lib_handle , "cuCtxGetDevice" , lib_name );
224
227
* (void * * )& g_cu_ops .cuCtxSetCurrent =
225
228
utils_get_symbol_addr (lib_handle , "cuCtxSetCurrent" , lib_name );
226
229
* (void * * )& g_cu_ops .cuIpcGetMemHandle =
@@ -234,9 +237,9 @@ static void init_cu_global_state(void) {
234
237
!g_cu_ops .cuMemHostAlloc || !g_cu_ops .cuMemAllocManaged ||
235
238
!g_cu_ops .cuMemFree || !g_cu_ops .cuMemFreeHost ||
236
239
!g_cu_ops .cuGetErrorName || !g_cu_ops .cuGetErrorString ||
237
- !g_cu_ops .cuCtxGetCurrent || !g_cu_ops .cuCtxSetCurrent ||
238
- !g_cu_ops .cuIpcGetMemHandle || !g_cu_ops .cuIpcOpenMemHandle ||
239
- !g_cu_ops .cuIpcCloseMemHandle ) {
240
+ !g_cu_ops .cuCtxGetCurrent || !g_cu_ops .cuCtxGetDevice ||
241
+ !g_cu_ops .cuCtxSetCurrent || !g_cu_ops .cuIpcGetMemHandle ||
242
+ !g_cu_ops .cuIpcOpenMemHandle || ! g_cu_ops . cuIpcCloseMemHandle ) {
240
243
LOG_FATAL ("Required CUDA symbols not found." );
241
244
Init_cu_global_state_failed = true;
242
245
utils_close_library (lib_handle );
@@ -260,8 +263,29 @@ umf_result_t umfCUDAMemoryProviderParamsCreate(
260
263
return UMF_RESULT_ERROR_OUT_OF_HOST_MEMORY ;
261
264
}
262
265
263
- params_data -> cuda_context_handle = NULL ;
264
- params_data -> cuda_device_handle = -1 ;
266
+ utils_init_once (& cu_is_initialized , init_cu_global_state );
267
+ if (Init_cu_global_state_failed ) {
268
+ LOG_FATAL ("Loading CUDA symbols failed" );
269
+ return UMF_RESULT_ERROR_DEPENDENCY_UNAVAILABLE ;
270
+ }
271
+
272
+ // initialize context and device to the current ones
273
+ CUcontext current_ctx = NULL ;
274
+ CUresult cu_result = g_cu_ops .cuCtxGetCurrent (& current_ctx );
275
+ if (cu_result == CUDA_SUCCESS ) {
276
+ params_data -> cuda_context_handle = current_ctx ;
277
+ } else {
278
+ params_data -> cuda_context_handle = NULL ;
279
+ }
280
+
281
+ CUdevice current_device = -1 ;
282
+ cu_result = g_cu_ops .cuCtxGetDevice (& current_device );
283
+ if (cu_result == CUDA_SUCCESS ) {
284
+ params_data -> cuda_device_handle = current_device ;
285
+ } else {
286
+ params_data -> cuda_device_handle = -1 ;
287
+ }
288
+
265
289
params_data -> memory_type = UMF_MEMORY_TYPE_UNKNOWN ;
266
290
params_data -> alloc_flags = 0 ;
267
291
@@ -342,6 +366,12 @@ static umf_result_t cu_memory_provider_initialize(void *params,
342
366
}
343
367
344
368
if (cu_params -> cuda_context_handle == NULL ) {
369
+ LOG_ERR ("Invalid context handle" );
370
+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
371
+ }
372
+
373
+ if (cu_params -> cuda_device_handle < 0 ) {
374
+ LOG_ERR ("Invalid device handle" );
345
375
return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
346
376
}
347
377
0 commit comments