Skip to content

Commit 4d6b8dc

Browse files
authored
wasi_nn.h: make this compatible with wasi_ephemeral_nn (#4330)
- wasi_nn.h: make this compatible with wasi_ephemeral_nn cf. #4323 - fix WASM_ENABLE_WASI_EPHEMERAL_NN build this structure is used by host logic as well. ideally definitions for wasm and host should be separated. until it happens, check __wasm__ to avoid the breakage.
1 parent 99c75b5 commit 4d6b8dc

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

core/iwasm/libraries/wasi-nn/include/wasi_nn.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,33 @@
1515
#include <stdint.h>
1616
#include "wasi_nn_types.h"
1717

18+
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
19+
#define WASI_NN_IMPORT(name) \
20+
__attribute__((import_module("wasi_ephemeral_nn"), import_name(name)))
21+
#else
1822
#define WASI_NN_IMPORT(name) \
1923
__attribute__((import_module("wasi_nn"), import_name(name)))
24+
#endif
2025

2126
/**
2227
* @brief Load an opaque sequence of bytes to use for inference.
2328
*
2429
* @param builder Model builder.
30+
* @param builder_len The size of model builder.
2531
* @param encoding Model encoding.
2632
* @param target Execution target.
2733
* @param g Graph.
2834
* @return wasi_nn_error Execution status.
2935
*/
36+
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
37+
wasi_nn_error
38+
load(graph_builder *builder, uint32_t builder_len, graph_encoding encoding,
39+
execution_target target, graph *g) WASI_NN_IMPORT("load");
40+
#else
3041
wasi_nn_error
3142
load(graph_builder_array *builder, graph_encoding encoding,
3243
execution_target target, graph *g) WASI_NN_IMPORT("load");
44+
#endif
3345

3446
wasi_nn_error
3547
load_by_name(const char *name, uint32_t name_len, graph *g)
@@ -84,9 +96,16 @@ compute(graph_execution_context ctx) WASI_NN_IMPORT("compute");
8496
* copied number of bytes.
8597
* @return wasi_nn_error Execution status.
8698
*/
99+
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
100+
wasi_nn_error
101+
get_output(graph_execution_context ctx, uint32_t index,
102+
tensor_data output_tensor, uint32_t output_tensor_max_size,
103+
uint32_t *output_tensor_size) WASI_NN_IMPORT("get_output");
104+
#else
87105
wasi_nn_error
88106
get_output(graph_execution_context ctx, uint32_t index,
89107
tensor_data output_tensor, uint32_t *output_tensor_size)
90108
WASI_NN_IMPORT("get_output");
109+
#endif
91110

92111
#endif

core/iwasm/libraries/wasi-nn/include/wasi_nn_types.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,11 @@ typedef struct {
7777
// Describe the size of the tensor (e.g., 2x2x2x2 -> [2, 2, 2, 2]). To
7878
// represent a tensor containing a single value, use `[1]` for the tensor
7979
// dimensions.
80+
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 && defined(__wasm__)
81+
tensor_dimensions dimensions;
82+
#else
8083
tensor_dimensions *dimensions;
84+
#endif
8185
// Describe the type of element in the tensor (e.g., f32).
8286
uint8_t type;
8387
uint8_t _pad[3];

0 commit comments

Comments
 (0)