
Commit 1978c85

feat(mlx): add gemma3 example (second-state#189)
1 parent 29adf39 commit 1978c85

File tree

8 files changed: +257 -4 lines
File renamed without changes.

wasmedge-mlx/README.md renamed to wasmedge-mlx/llama/README.md

Lines changed: 4 additions & 4 deletions
````diff
@@ -41,10 +41,10 @@ wget https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0/resolve/main/toke
 
 ## Build wasm
 
-Run the following command to build wasm, the output WASM file will be at `target/wasm32-wasi/release/`
+Run the following command to build wasm, the output WASM file will be at `target/wasm32-wasip1/release/`
 
 ```bash
-cargo build --target wasm32-wasi --release
+cargo build --target wasm32-wasip1 --release
 ```
 ## Execute
 
@@ -53,7 +53,7 @@ Execute the WASM with the `wasmedge` using nn-preload to load model.
 ``` bash
 wasmedge --dir .:. \
 --nn-preload default:mlx:AUTO:model.safetensors \
-./target/wasm32-wasi/release/wasmedge-mlx.wasm default
+./target/wasm32-wasip1/release/wasmedge-mlx.wasm default
 
 ```
 
@@ -63,7 +63,7 @@ For example:
 ``` bash
 wasmedge --dir .:. \
 --nn-preload default:mlx:AUTO:llama2-7b/model-00001-of-00002.safetensors:llama2-7b/model-00002-of-00002.safetensors \
-./target/wasm32-wasi/release/wasmedge-mlx.wasm default
+./target/wasm32-wasip1/release/wasmedge-mlx.wasm default
 ```
 
 ## Other
````
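Since the build target changed from `wasm32-wasi` to `wasm32-wasip1`, the renamed target must be present in your Rust toolchain before building. A minimal sketch, assuming a `rustup`-managed toolchain:

```bash
# Install the wasm32-wasip1 target (assumes rustup manages your toolchain).
rustup target add wasm32-wasip1
cargo build --target wasm32-wasip1 --release
```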
File renamed without changes.

wasmedge-mlx/vlm/Cargo.toml

Lines changed: 8 additions & 0 deletions
```toml
[package]
name = "wasmedge-vlm"
version = "0.1.0"
edition = "2024"

[dependencies]
serde_json = "1.0"
wasmedge-wasi-nn = "0.8.0"
```

wasmedge-mlx/vlm/README.md

Lines changed: 90 additions & 0 deletions
# VLM example with WasmEdge WASI-NN MLX plugin

This example demonstrates how to use the WasmEdge WASI-NN MLX plugin to perform an inference task with a vision-language model (VLM).

## Supported Models

| Family | Models |
|--------|--------|
| Gemma 3 | gemma-3-4b-pt-bf16 |

## Install WasmEdge with WASI-NN MLX plugin

The MLX backend relies on [MLX](https://github.com/ml-explore/mlx), which is downloaded automatically when you build WasmEdge, so you do not need to install it yourself. If you want to use a custom MLX build, install it yourself and set the `CMAKE_PREFIX_PATH` variable when configuring CMake.

Build and install WasmEdge from source:

``` bash
cd <path/to/your/wasmedge/source/folder>

cmake -GNinja -Bbuild -DCMAKE_BUILD_TYPE=Release -DWASMEDGE_PLUGIN_WASI_NN_BACKEND="mlx"
cmake --build build

# For the WASI-NN plugin, you should install this project.
cmake --install build
```
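If you point the build at your own MLX installation instead of the auto-downloaded one, the configure step might look like the following minimal sketch (the `/opt/mlx` install prefix is a hypothetical example):

```bash
# Assumes MLX was installed with `cmake --install` under /opt/mlx (hypothetical path).
cmake -GNinja -Bbuild -DCMAKE_BUILD_TYPE=Release \
  -DWASMEDGE_PLUGIN_WASI_NN_BACKEND="mlx" \
  -DCMAKE_PREFIX_PATH=/opt/mlx
cmake --build build
```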
After installation, you will have the `wasmedge` runtime executable under `/usr/local/bin` and the WASI-NN plug-in with the MLX backend at `/usr/local/lib/wasmedge/libwasmedgePluginWasiNN.so`.

## Install dependencies

Currently, we use the Python `transformers` library to encode the prompt and image into input tokens. You can use any other library for this step.

``` bash
sudo apt install python3 python3-pip
pip install transformers pillow mlx
```

## Download the model and tokenizer

In this example, we will use `gemma-3-4b-pt-bf16`.

``` bash
git clone https://huggingface.co/mlx-community/gemma-3-4b-pt-bf16
```

## Build wasm

Run the following command to build the wasm; the output WASM file will be at `target/wasm32-wasip1/release/`:

```bash
cargo build --target wasm32-wasip1 --release
```

## Execute

Execute the WASM with `wasmedge`, using `--nn-preload` to load the model.

``` bash
# Download sample image
wget https://github.com/WasmEdge/WasmEdge/raw/master/docs/wasmedge-runtime-logo.png

# python encode.py <model_path> <image_path> <prompt>
python encode.py gemma-3-4b-pt-bf16 wasmedge-runtime-logo.png "What is this icon?"

wasmedge --dir .:. \
  --nn-preload default:mlx:AUTO:model.safetensors \
  ./target/wasm32-wasip1/release/wasmedge-vlm.wasm default

# python decode.py <model_path> <output mlx array path>
python decode.py gemma-3-4b-pt-bf16 Answer.npy
```

If your model has multiple weight files, you need to provide all of them in the `--nn-preload` option.

For example:

``` bash
wasmedge --dir .:. \
  --nn-preload default:mlx:AUTO:gemma-3-4b-pt-bf16/model-00001-of-00002.safetensors:gemma-3-4b-pt-bf16/model-00002-of-00002.safetensors \
  ./target/wasm32-wasip1/release/wasmedge-vlm.wasm default
```

## Other

There are some metadata options for the MLX plugin that you can set; a config example follows this list.

### Basic setting

- `model_type` (required): the model type.
- `max_token` (optional): the maximum number of generated tokens; default is 1024.
- `enable_debug_log` (optional): whether to print the debug log; default is false.
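With the Rust `wasmedge-wasi-nn` crate used in this example, these options are passed as a JSON string to the graph builder. A minimal sketch mirroring `src/main.rs`; the option values and the `"default"` cache name are illustrative:

```rust
use serde_json::json;
use wasmedge_wasi_nn::{ExecutionTarget, GraphBuilder, GraphEncoding};

fn main() {
    // Illustrative values: `model_type` is required, the others are optional.
    let config = serde_json::to_string(&json!({
        "model_type": "gemma3",
        "max_token": 1024,
        "enable_debug_log": false
    }))
    .expect("Failed to serialize options");

    let _graph = GraphBuilder::new(GraphEncoding::Mlx, ExecutionTarget::AUTO)
        .config(config)
        .build_from_cache("default")
        .expect("Failed to build graph");
}
```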

wasmedge-mlx/vlm/decode.py

Lines changed: 51 additions & 0 deletions
```python
from transformers import AutoProcessor
import mlx.core as mx
import sys


def _remove_space(x):
    if x and x[0] == " ":
        return x[1:]
    return x


class Detokenizer:
    """Maps token ids back to text, handling SentencePiece-style word markers."""

    def __init__(self, tokenizer, trim_space=True):
        self.trim_space = trim_space
        # Build an id -> token string lookup table from the vocabulary.
        self.tokenmap = [None] * len(tokenizer.vocab)
        for value, tokenid in tokenizer.vocab.items():
            self.tokenmap[tokenid] = value
        # Byte tokens such as "<0x0A>" decode to the corresponding character.
        for i in range(len(self.tokenmap)):
            if self.tokenmap[i].startswith("<0x"):
                self.tokenmap[i] = chr(int(self.tokenmap[i][3:5], 16))

        self._unflushed = ""
        self.text = ""

    def _flush(self):
        text = self._unflushed.replace("\u2581", " ")
        if self.text or not self.trim_space:
            self.text += text
        else:
            self.text = _remove_space(text)

    def add_token(self, token):
        v = self.tokenmap[token]
        if v[0] == "\u2581":
            # "\u2581" marks the start of a new word: flush the previous one.
            self._flush()
            self._unflushed = v
        else:
            self._unflushed += v

    def finalize(self):
        # Emit whatever is still buffered after the last token.
        self._flush()
        self._unflushed = ""


def decode(tokens: list, model_path: str, **kwargs):
    processor = AutoProcessor.from_pretrained(model_path, **kwargs)
    detokenizer = Detokenizer(processor.tokenizer)
    for token in tokens:
        detokenizer.add_token(token)
    detokenizer.finalize()
    return detokenizer.text


if __name__ == "__main__":
    model_path, output = sys.argv[1:]
    token_list = mx.load(output)
    print(decode(token_list.tolist(), model_path))
```

wasmedge-mlx/vlm/encode.py

Lines changed: 48 additions & 0 deletions
```python
from transformers import AutoProcessor
import mlx.core as mx
from PIL import Image, ImageOps
import sys


def encode(processor, image, prompts):
    model_inputs = {}
    processor.tokenizer.pad_token = processor.tokenizer.eos_token

    # Load the image, honor any EXIF rotation, and normalize to RGB.
    image = Image.open(image)
    image = ImageOps.exif_transpose(image)
    image = image.convert("RGB")
    images = [image]
    inputs = processor(
        text=prompts, images=images, padding=True, return_tensors="mlx"
    )
    if "images" in inputs:
        inputs["pixel_values"] = inputs["images"]
        inputs.pop("images")

    if isinstance(inputs["pixel_values"], list):
        pixel_values = inputs["pixel_values"]
    else:
        pixel_values = mx.array(inputs["pixel_values"])

    model_inputs["pixel_values"] = pixel_values
    model_inputs["attention_mask"] = (
        mx.array(inputs["attention_mask"]) if "attention_mask" in inputs else None
    )
    # Convert the remaining inputs to model_inputs with mx.array if present.
    for key, value in inputs.items():
        if key not in model_inputs and not isinstance(value, (str, list)):
            model_inputs[key] = mx.array(value)
    # Persist the tensors so the WASM module can load them by file name.
    mx.save("input_ids.npy", model_inputs["input_ids"])
    mx.save("pixel_values.npy", model_inputs["pixel_values"])
    mx.save("mask.npy", model_inputs["attention_mask"])


if __name__ == "__main__":
    model_path, image, prompts = sys.argv[1:]
    processor = AutoProcessor.from_pretrained(model_path)
    formatted_prompt = (
        f"<bos><start_of_turn>user\n"
        f"{prompts}<start_of_image><end_of_turn>\n"
        f"<start_of_turn>model"
    )
    encode(processor, image, formatted_prompt)
```
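Note the handshake between the two halves of the pipeline: `encode.py` writes `input_ids.npy`, `pixel_values.npy`, and `mask.npy`, and the Rust program below passes exactly those file names to the MLX plugin as its three inputs.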

wasmedge-mlx/vlm/src/main.rs

Lines changed: 56 additions & 0 deletions
```rust
use serde_json::json;
use std::env;
use wasmedge_wasi_nn::{
    self, ExecutionTarget, GraphBuilder, GraphEncoding, GraphExecutionContext, TensorType,
};

fn get_data_from_context(context: &GraphExecutionContext, index: usize) -> String {
    // Reserve room for 4096 tokens with an average token length of 6 bytes.
    const MAX_OUTPUT_BUFFER_SIZE: usize = 4096 * 6;
    let mut output_buffer = vec![0u8; MAX_OUTPUT_BUFFER_SIZE];
    let mut output_size = context
        .get_output(index, &mut output_buffer)
        .expect("Failed to get output");
    output_size = std::cmp::min(MAX_OUTPUT_BUFFER_SIZE, output_size);

    String::from_utf8_lossy(&output_buffer[..output_size]).to_string()
}

fn get_output_from_context(context: &GraphExecutionContext) -> String {
    get_data_from_context(context, 0)
}

fn main() {
    // prompt: "What is this icon?";
    // image: "wasmedge-runtime-logo.png";
    let args: Vec<String> = env::args().collect();
    let model_name: &str = &args[1];
    let graph = GraphBuilder::new(GraphEncoding::Mlx, ExecutionTarget::AUTO)
        .config(
            serde_json::to_string(&json!({"model_type": "gemma3", "max_token": 250}))
                .expect("Failed to serialize options"),
        )
        .build_from_cache(model_name)
        .expect("Failed to build graph");

    let mut context = graph
        .init_execution_context()
        .expect("Failed to init context");

    // Each input tensor carries the *file name* of an .npy file written by
    // encode.py; the MLX plugin loads the actual array from that file.
    let tensor_data = "input_ids.npy".as_bytes().to_vec();
    context
        .set_input(0, TensorType::U8, &[1], &tensor_data)
        .expect("Failed to set input");
    let tensor_data = "pixel_values.npy".as_bytes().to_vec();
    context
        .set_input(1, TensorType::U8, &[1], &tensor_data)
        .expect("Failed to set input");
    let tensor_data = "mask.npy".as_bytes().to_vec();
    context
        .set_input(2, TensorType::U8, &[1], &tensor_data)
        .expect("Failed to set input");

    context.compute().expect("Failed to compute");
    let output = get_output_from_context(&context);
    println!("{}", output.trim());
}
```
