@@ -139,6 +139,7 @@ typedef struct cu_ops_t {
139
139
CUresult (* cuGetErrorName )(CUresult error , const char * * pStr );
140
140
CUresult (* cuGetErrorString )(CUresult error , const char * * pStr );
141
141
CUresult (* cuCtxGetCurrent )(CUcontext * pctx );
142
+ CUresult (* cuCtxGetDevice )(CUdevice * device );
142
143
CUresult (* cuCtxSetCurrent )(CUcontext ctx );
143
144
CUresult (* cuIpcGetMemHandle )(CUipcMemHandle * pHandle , CUdeviceptr dptr );
144
145
CUresult (* cuIpcOpenMemHandle )(CUdeviceptr * pdptr , CUipcMemHandle handle ,
@@ -224,6 +225,8 @@ static void init_cu_global_state(void) {
224
225
utils_get_symbol_addr (lib_handle , "cuGetErrorString" , lib_name );
225
226
* (void * * )& g_cu_ops .cuCtxGetCurrent =
226
227
utils_get_symbol_addr (lib_handle , "cuCtxGetCurrent" , lib_name );
228
+ * (void * * )& g_cu_ops .cuCtxGetDevice =
229
+ utils_get_symbol_addr (lib_handle , "cuCtxGetDevice" , lib_name );
227
230
* (void * * )& g_cu_ops .cuCtxSetCurrent =
228
231
utils_get_symbol_addr (lib_handle , "cuCtxSetCurrent" , lib_name );
229
232
* (void * * )& g_cu_ops .cuIpcGetMemHandle =
@@ -237,9 +240,9 @@ static void init_cu_global_state(void) {
237
240
!g_cu_ops .cuMemHostAlloc || !g_cu_ops .cuMemAllocManaged ||
238
241
!g_cu_ops .cuMemFree || !g_cu_ops .cuMemFreeHost ||
239
242
!g_cu_ops .cuGetErrorName || !g_cu_ops .cuGetErrorString ||
240
- !g_cu_ops .cuCtxGetCurrent || !g_cu_ops .cuCtxSetCurrent ||
241
- !g_cu_ops .cuIpcGetMemHandle || !g_cu_ops .cuIpcOpenMemHandle ||
242
- !g_cu_ops .cuIpcCloseMemHandle ) {
243
+ !g_cu_ops .cuCtxGetCurrent || !g_cu_ops .cuCtxGetDevice ||
244
+ !g_cu_ops .cuCtxSetCurrent || !g_cu_ops .cuIpcGetMemHandle ||
245
+ !g_cu_ops .cuIpcOpenMemHandle || ! g_cu_ops . cuIpcCloseMemHandle ) {
243
246
LOG_FATAL ("Required CUDA symbols not found." );
244
247
Init_cu_global_state_failed = true;
245
248
utils_close_library (lib_handle );
@@ -263,8 +266,29 @@ umf_result_t umfCUDAMemoryProviderParamsCreate(
263
266
return UMF_RESULT_ERROR_OUT_OF_HOST_MEMORY ;
264
267
}
265
268
266
- params_data -> cuda_context_handle = NULL ;
267
- params_data -> cuda_device_handle = -1 ;
269
+ utils_init_once (& cu_is_initialized , init_cu_global_state );
270
+ if (Init_cu_global_state_failed ) {
271
+ LOG_FATAL ("Loading CUDA symbols failed" );
272
+ return UMF_RESULT_ERROR_DEPENDENCY_UNAVAILABLE ;
273
+ }
274
+
275
+ // default the context and device to the current ones; fall back to NULL / -1 when the query does not succeed
276
+ CUcontext current_ctx = NULL ;
277
+ CUresult cu_result = g_cu_ops .cuCtxGetCurrent (& current_ctx );
278
+ if (cu_result == CUDA_SUCCESS ) {
279
+ params_data -> cuda_context_handle = current_ctx ;
280
+ } else {
281
+ params_data -> cuda_context_handle = NULL ;
282
+ }
283
+
284
+ CUdevice current_device = -1 ;
285
+ cu_result = g_cu_ops .cuCtxGetDevice (& current_device );
286
+ if (cu_result == CUDA_SUCCESS ) {
287
+ params_data -> cuda_device_handle = current_device ;
288
+ } else {
289
+ params_data -> cuda_device_handle = -1 ;
290
+ }
291
+
268
292
params_data -> memory_type = UMF_MEMORY_TYPE_UNKNOWN ;
269
293
params_data -> alloc_flags = 0 ;
270
294
@@ -345,6 +369,12 @@ static umf_result_t cu_memory_provider_initialize(void *params,
345
369
}
346
370
347
371
if (cu_params -> cuda_context_handle == NULL ) {
372
+ LOG_ERR ("Invalid context handle" );
373
+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
374
+ }
375
+
376
+ if (cu_params -> cuda_device_handle < 0 ) {
377
+ LOG_ERR ("Invalid device handle" );
348
378
return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
349
379
}
350
380
0 commit comments