@@ -397,6 +397,43 @@ detect_and_load_backend(graph_encoding backend_hint,
397
397
return ret ;
398
398
}
399
399
400
+ static wasi_nn_error
401
+ ensure_backend (wasm_module_inst_t instance , graph_encoding encoding ,
402
+ WASINNContext * * wasi_nn_ctx_ptr )
403
+ {
404
+ wasi_nn_error res ;
405
+
406
+ graph_encoding loaded_backend = autodetect ;
407
+ if (!detect_and_load_backend (encoding , & loaded_backend )) {
408
+ res = invalid_encoding ;
409
+ NN_ERR_PRINTF ("load backend failed" );
410
+ goto fail ;
411
+ }
412
+
413
+ WASINNContext * wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx (instance );
414
+ if (wasi_nn_ctx -> is_backend_ctx_initialized ) {
415
+ if (wasi_nn_ctx -> backend != loaded_backend ) {
416
+ res = unsupported_operation ;
417
+ goto fail ;
418
+ }
419
+ }
420
+ else {
421
+ wasi_nn_ctx -> backend = loaded_backend ;
422
+
423
+ /* init() the backend */
424
+ call_wasi_nn_func (wasi_nn_ctx -> backend , init , res ,
425
+ & wasi_nn_ctx -> backend_ctx );
426
+ if (res != success )
427
+ goto fail ;
428
+
429
+ wasi_nn_ctx -> is_backend_ctx_initialized = true;
430
+ }
431
+ * wasi_nn_ctx_ptr = wasi_nn_ctx ;
432
+ return success ;
433
+ fail :
434
+ return res ;
435
+ }
436
+
400
437
/* WASI-NN implementation */
401
438
402
439
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
@@ -410,14 +447,15 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
410
447
graph_encoding encoding , execution_target target , graph * g )
411
448
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
412
449
{
450
+ wasi_nn_error res ;
451
+
413
452
NN_DBG_PRINTF ("[WASI NN] LOAD [encoding=%d, target=%d]..." , encoding ,
414
453
target );
415
454
416
455
wasm_module_inst_t instance = wasm_runtime_get_module_inst (exec_env );
417
456
if (!instance )
418
457
return runtime_error ;
419
458
420
- wasi_nn_error res ;
421
459
graph_builder_array builder_native = { 0 };
422
460
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
423
461
if (success
@@ -438,19 +476,8 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
438
476
goto fail ;
439
477
}
440
478
441
- graph_encoding loaded_backend = autodetect ;
442
- if (!detect_and_load_backend (encoding , & loaded_backend )) {
443
- res = invalid_encoding ;
444
- NN_ERR_PRINTF ("load backend failed" );
445
- goto fail ;
446
- }
447
-
448
- WASINNContext * wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx (instance );
449
- wasi_nn_ctx -> backend = loaded_backend ;
450
-
451
- /* init() the backend */
452
- call_wasi_nn_func (wasi_nn_ctx -> backend , init , res ,
453
- & wasi_nn_ctx -> backend_ctx );
479
+ WASINNContext * wasi_nn_ctx ;
480
+ res = ensure_backend (instance , encoding , & wasi_nn_ctx );
454
481
if (res != success )
455
482
goto fail ;
456
483
@@ -473,6 +500,8 @@ wasi_nn_error
473
500
wasi_nn_load_by_name (wasm_exec_env_t exec_env , char * name , uint32_t name_len ,
474
501
graph * g )
475
502
{
503
+ wasi_nn_error res ;
504
+
476
505
wasm_module_inst_t instance = wasm_runtime_get_module_inst (exec_env );
477
506
if (!instance ) {
478
507
return runtime_error ;
@@ -496,19 +525,8 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
496
525
497
526
NN_DBG_PRINTF ("[WASI NN] LOAD_BY_NAME %s..." , name );
498
527
499
- graph_encoding loaded_backend = autodetect ;
500
- if (!detect_and_load_backend (autodetect , & loaded_backend )) {
501
- NN_ERR_PRINTF ("load backend failed" );
502
- return invalid_encoding ;
503
- }
504
-
505
- WASINNContext * wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx (instance );
506
- wasi_nn_ctx -> backend = loaded_backend ;
507
-
508
- wasi_nn_error res ;
509
- /* init() the backend */
510
- call_wasi_nn_func (wasi_nn_ctx -> backend , init , res ,
511
- & wasi_nn_ctx -> backend_ctx );
528
+ WASINNContext * wasi_nn_ctx ;
529
+ res = ensure_backend (instance , autodetect , & wasi_nn_ctx );
512
530
if (res != success )
513
531
return res ;
514
532
@@ -526,6 +544,8 @@ wasi_nn_load_by_name_with_config(wasm_exec_env_t exec_env, char *name,
526
544
int32_t name_len , char * config ,
527
545
int32_t config_len , graph * g )
528
546
{
547
+ wasi_nn_error res ;
548
+
529
549
wasm_module_inst_t instance = wasm_runtime_get_module_inst (exec_env );
530
550
if (!instance ) {
531
551
return runtime_error ;
@@ -554,19 +574,8 @@ wasi_nn_load_by_name_with_config(wasm_exec_env_t exec_env, char *name,
554
574
555
575
NN_DBG_PRINTF ("[WASI NN] LOAD_BY_NAME_WITH_CONFIG %s %s..." , name , config );
556
576
557
- graph_encoding loaded_backend = autodetect ;
558
- if (!detect_and_load_backend (autodetect , & loaded_backend )) {
559
- NN_ERR_PRINTF ("load backend failed" );
560
- return invalid_encoding ;
561
- }
562
-
563
- WASINNContext * wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx (instance );
564
- wasi_nn_ctx -> backend = loaded_backend ;
565
-
566
- wasi_nn_error res ;
567
- /* init() the backend */
568
- call_wasi_nn_func (wasi_nn_ctx -> backend , init , res ,
569
- & wasi_nn_ctx -> backend_ctx );
577
+ WASINNContext * wasi_nn_ctx ;
578
+ res = ensure_backend (instance , autodetect , & wasi_nn_ctx );
570
579
if (res != success )
571
580
return res ;
572
581
0 commit comments