Skip to content

Commit 0d001c4

Browse files
authored
wasi-nn: fix backend leak on multiple loads (#4366)
cf. #4340
1 parent 8e60feb commit 0d001c4

File tree

2 files changed

+50
-40
lines changed

2 files changed

+50
-40
lines changed

core/iwasm/libraries/wasi-nn/src/wasi_nn.c

Lines changed: 49 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,43 @@ detect_and_load_backend(graph_encoding backend_hint,
397397
return ret;
398398
}
399399

400+
static wasi_nn_error
401+
ensure_backend(wasm_module_inst_t instance, graph_encoding encoding,
402+
WASINNContext **wasi_nn_ctx_ptr)
403+
{
404+
wasi_nn_error res;
405+
406+
graph_encoding loaded_backend = autodetect;
407+
if (!detect_and_load_backend(encoding, &loaded_backend)) {
408+
res = invalid_encoding;
409+
NN_ERR_PRINTF("load backend failed");
410+
goto fail;
411+
}
412+
413+
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
414+
if (wasi_nn_ctx->is_backend_ctx_initialized) {
415+
if (wasi_nn_ctx->backend != loaded_backend) {
416+
res = unsupported_operation;
417+
goto fail;
418+
}
419+
}
420+
else {
421+
wasi_nn_ctx->backend = loaded_backend;
422+
423+
/* init() the backend */
424+
call_wasi_nn_func(wasi_nn_ctx->backend, init, res,
425+
&wasi_nn_ctx->backend_ctx);
426+
if (res != success)
427+
goto fail;
428+
429+
wasi_nn_ctx->is_backend_ctx_initialized = true;
430+
}
431+
*wasi_nn_ctx_ptr = wasi_nn_ctx;
432+
return success;
433+
fail:
434+
return res;
435+
}
436+
400437
/* WASI-NN implementation */
401438

402439
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
@@ -410,14 +447,15 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
410447
graph_encoding encoding, execution_target target, graph *g)
411448
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
412449
{
450+
wasi_nn_error res;
451+
413452
NN_DBG_PRINTF("[WASI NN] LOAD [encoding=%d, target=%d]...", encoding,
414453
target);
415454

416455
wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
417456
if (!instance)
418457
return runtime_error;
419458

420-
wasi_nn_error res;
421459
graph_builder_array builder_native = { 0 };
422460
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
423461
if (success
@@ -438,19 +476,8 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
438476
goto fail;
439477
}
440478

441-
graph_encoding loaded_backend = autodetect;
442-
if (!detect_and_load_backend(encoding, &loaded_backend)) {
443-
res = invalid_encoding;
444-
NN_ERR_PRINTF("load backend failed");
445-
goto fail;
446-
}
447-
448-
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
449-
wasi_nn_ctx->backend = loaded_backend;
450-
451-
/* init() the backend */
452-
call_wasi_nn_func(wasi_nn_ctx->backend, init, res,
453-
&wasi_nn_ctx->backend_ctx);
479+
WASINNContext *wasi_nn_ctx;
480+
res = ensure_backend(instance, encoding, &wasi_nn_ctx);
454481
if (res != success)
455482
goto fail;
456483

@@ -473,6 +500,8 @@ wasi_nn_error
473500
wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
474501
graph *g)
475502
{
503+
wasi_nn_error res;
504+
476505
wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
477506
if (!instance) {
478507
return runtime_error;
@@ -496,19 +525,8 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
496525

497526
NN_DBG_PRINTF("[WASI NN] LOAD_BY_NAME %s...", name);
498527

499-
graph_encoding loaded_backend = autodetect;
500-
if (!detect_and_load_backend(autodetect, &loaded_backend)) {
501-
NN_ERR_PRINTF("load backend failed");
502-
return invalid_encoding;
503-
}
504-
505-
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
506-
wasi_nn_ctx->backend = loaded_backend;
507-
508-
wasi_nn_error res;
509-
/* init() the backend */
510-
call_wasi_nn_func(wasi_nn_ctx->backend, init, res,
511-
&wasi_nn_ctx->backend_ctx);
528+
WASINNContext *wasi_nn_ctx;
529+
res = ensure_backend(instance, autodetect, &wasi_nn_ctx);
512530
if (res != success)
513531
return res;
514532

@@ -526,6 +544,8 @@ wasi_nn_load_by_name_with_config(wasm_exec_env_t exec_env, char *name,
526544
int32_t name_len, char *config,
527545
int32_t config_len, graph *g)
528546
{
547+
wasi_nn_error res;
548+
529549
wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
530550
if (!instance) {
531551
return runtime_error;
@@ -554,19 +574,8 @@ wasi_nn_load_by_name_with_config(wasm_exec_env_t exec_env, char *name,
554574

555575
NN_DBG_PRINTF("[WASI NN] LOAD_BY_NAME_WITH_CONFIG %s %s...", name, config);
556576

557-
graph_encoding loaded_backend = autodetect;
558-
if (!detect_and_load_backend(autodetect, &loaded_backend)) {
559-
NN_ERR_PRINTF("load backend failed");
560-
return invalid_encoding;
561-
}
562-
563-
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
564-
wasi_nn_ctx->backend = loaded_backend;
565-
566-
wasi_nn_error res;
567-
/* init() the backend */
568-
call_wasi_nn_func(wasi_nn_ctx->backend, init, res,
569-
&wasi_nn_ctx->backend_ctx);
577+
WASINNContext *wasi_nn_ctx;
578+
res = ensure_backend(instance, autodetect, &wasi_nn_ctx);
570579
if (res != success)
571580
return res;
572581

core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "wasm_export.h"
1111

1212
typedef struct {
13+
bool is_backend_ctx_initialized;
1314
bool is_model_loaded;
1415
graph_encoding backend;
1516
void *backend_ctx;

0 commit comments

Comments
 (0)