Releases: huggingface/transformers.js
3.7.0
🚀 Transformers.js v3.7 — Voxtral, LFM2, ModernBERT Decoder
🤖 New models
This update adds support for 3 new architectures:
Voxtral
Voxtral Mini is an enhancement of Ministral 3B, incorporating state-of-the-art audio input capabilities while retaining best-in-class text performance. It excels at speech transcription, translation and audio understanding. ONNX weights for Voxtral-Mini-3B-2507 can be found here. Learn more about Voxtral in the release blog post.
Try it out with our online demo.
Example: Audio transcription
import { VoxtralForConditionalGeneration, VoxtralProcessor, TextStreamer, read_audio } from "@huggingface/transformers";
// Load the processor and model
const model_id = "onnx-community/Voxtral-Mini-3B-2507-ONNX";
const processor = await VoxtralProcessor.from_pretrained(model_id);
const model = await VoxtralForConditionalGeneration.from_pretrained(
model_id,
{
dtype: {
embed_tokens: "fp16", // "fp32", "fp16", "q8", "q4"
audio_encoder: "q4", // "fp32", "fp16", "q8", "q4", "q4f16"
decoder_model_merged: "q4", // "q4", "q4f16"
},
device: "webgpu",
},
);
// Prepare the conversation
const conversation = [
{
"role": "user",
"content": [
{ "type": "audio" },
{ "type": "text", "text": "lang:en [TRANSCRIBE]" },
],
}
];
const text = processor.apply_chat_template(conversation, { tokenize: false });
const audio = await read_audio("http://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/mlk.wav", 16000);
const inputs = await processor(text, audio);
// Generate the response
const generated_ids = await model.generate({
...inputs,
max_new_tokens: 256,
streamer: new TextStreamer(processor.tokenizer, { skip_special_tokens: true, skip_prompt: true }),
});
// Decode the generated tokens
const new_tokens = generated_ids.slice(null, [inputs.input_ids.dims.at(-1), null]);
const generated_texts = processor.batch_decode(
new_tokens,
{ skip_special_tokens: true },
);
console.log(generated_texts[0]);
// I have a dream that one day this nation will rise up and live out the true meaning of its creed.
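The decode step keeps only the tokens generated after the prompt. `generate` returns the prompt tokens followed by the new ones, and `slice(null, [inputs.input_ids.dims.at(-1), null])` drops the first `promptLength` columns of every row. Below is a minimal sketch of the same idea on plain nested arrays. The helper name `stripPromptTokens` is hypothetical and is not part of the library.

```javascript
// Hypothetical helper (not a Transformers.js API): drop the first `promptLength`
// token ids from every sequence in a batch, mirroring
// generated_ids.slice(null, [promptLength, null]) on a [batch, seq_len] tensor.
function stripPromptTokens(sequences, promptLength) {
  return sequences.map((row) => row.slice(promptLength));
}

// Toy batch (invented ids): 3 prompt tokens followed by 2 generated tokens per row.
const batch = [
  [101, 7, 8, 42, 43],
  [101, 9, 10, 44, 45],
];
console.log(stripPromptTokens(batch, 3));
```

A single `promptLength` works for every row here because batched inputs are padded to a common length before generation.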
LFM2
LFM2 is a new generation of hybrid models developed by Liquid AI, specifically designed for edge AI and on-device deployment. It sets a new standard in terms of quality, speed, and memory efficiency.
The models, which we have converted to ONNX, come in three different sizes: 350M, 700M, and 1.2B parameters.
Example: Text-generation with LFM2-350M:
import { pipeline, TextStreamer } from "@huggingface/transformers";
// Create a text generation pipeline
const generator = await pipeline(
"text-generation",
"onnx-community/LFM2-350M-ONNX",
{ dtype: "q4" },
);
// Define the list of messages
const messages = [
{ role: "system", content: "You are a helpful assistant." },
{ role: "user", content: "What is the capital of France?" },
];
// Generate a response
const output = await generator(messages, {
max_new_tokens: 512,
do_sample: false,
streamer: new TextStreamer(generator.tokenizer, { skip_prompt: true, skip_special_tokens: true }),
});
console.log(output[0].generated_text.at(-1).content);
// The capital of France is Paris. It is a vibrant city known for its historical landmarks, art, fashion, and gastronomy.
ModernBERT Decoder
These models form part of the Ettin suite: the first collection of paired encoder-only and decoder-only models trained with identical data, architecture, and training recipes. Ettin enables fair comparisons between encoder and decoder architectures across multiple scales, providing state-of-the-art performance for open-data models in their respective size categories.
The list of supported models can be found here.
import { pipeline, TextStreamer } from "@huggingface/transformers";
// Create a text generation pipeline
const generator = await pipeline(
"text-generation",
"onnx-community/ettin-decoder-150m-ONNX",
{ dtype: "fp32" },
);
// Generate a response
const text = "Q: What is the capital of France?\nA:";
const output = await generator(text, {
max_new_tokens: 128,
streamer: new TextStreamer(generator.tokenizer, { skip_prompt: true, skip_special_tokens: true }),
});
console.log(output[0].generated_text);
Added in #1371.
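The Ettin decoder example uses a raw `Q: ...\nA:` completion prompt instead of a chat template, so `generated_text` is a plain string. A base model may also continue past the answer, for example by writing a new `Q:` line. Below is a hedged post-processing sketch. The helper `extractAnswer` is hypothetical. It removes the prompt if the output echoes it, and it assumes the answer ends at the first newline.

```javascript
// Hypothetical helper (not part of Transformers.js): return the first line of the
// continuation that follows `prompt` in `generatedText`.
function extractAnswer(prompt, generatedText) {
  // Remove the echoed prompt when present; otherwise treat the whole string as continuation.
  const continuation = generatedText.startsWith(prompt)
    ? generatedText.slice(prompt.length)
    : generatedText;
  // Base (non-chat) models often keep going with new "Q:" lines, so keep only the first line.
  return continuation.trimStart().split("\n")[0].trim();
}

// Toy string invented for illustration, not real model output.
const prompt = "Q: What is the capital of France?\nA:";
const sample = prompt + " Paris.\nQ: What is the capital of Spain?\nA: Madrid.";
console.log(extractAnswer(prompt, sample));
```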
🛠️ Other improvements
- Add special tokens in the text-generation pipeline when the tokenizer requires them in #1370
Full Changelog: 3.6.3...3.7.0
3.6.3
3.6.2
What's new?
- Add support for SmolLM3 in #1359
SmolLM3 is a 3B parameter language model designed to push the boundaries of small models. It supports 6 languages, advanced reasoning and long context. SmolLM3 is a fully open model that offers strong performance at the 3B–4B scale.
Example:
import { pipeline, TextStreamer } from "@huggingface/transformers";
// Create a text generation pipeline
const generator = await pipeline(
"text-generation",
"HuggingFaceTB/SmolLM3-3B-ONNX",
{ dtype: "q4f16" },
);
// Define the list of messages
const messages = [
{ role: "system", content: "You are SmolLM, a language model created by Hugging Face. If asked by the user, here is some information about you: SmolLM has 3 billion parameters and can converse in 6 languages: English, Spanish, German, French, Italian, and Portuguese. SmolLM is a fully open model and was trained on a diverse mix of public datasets./think" },
{ role: "user", content: "Solve the equation x^2 - 3x + 2 = 0" },
];
// Generate a response
const output = await generator(messages, {
max_new_tokens: 1024,
do_sample: false,
streamer: new TextStreamer(generator.tokenizer, { skip_prompt: true, skip_special_tokens: true }),
});
console.log(output[0].generated_text.at(-1).content);
- Add support for ERNIE-4.5 in #1354
Example:
import { pipeline, TextStreamer } from "@huggingface/transformers";
// Create a text generation pipeline
const generator = await pipeline(
"text-generation",
"onnx-community/ERNIE-4.5-0.3B-ONNX",
{ dtype: "fp32" }, // Options: "fp32", "fp16", "q8", "q4", "q4f16"
);
// Define the list of messages
const messages = [
{ role: "system", content: "You are a helpful assistant." },
{ role: "user", content: "What is the capital of France?" },
];
// Generate a response
const output = await generator(messages, {
max_new_tokens: 512,
do_sample: false,
streamer: new TextStreamer(generator.tokenizer, { skip_prompt: true, skip_special_tokens: true }),
});
console.log(output[0].generated_text.at(-1).content);
// The capital of France is Paris.
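In the SmolLM3 example, the system prompt ends with `/think`, which turns on the model's extended reasoning mode. The SmolLM3 model card also describes a `/no_think` flag that turns it off. Below is a small sketch of a hypothetical helper that builds the message list with either flag. It assumes those flag conventions and appends the flag directly after the system text, as the example does.

```javascript
// Hypothetical helper (not part of Transformers.js): build a SmolLM3-style message
// list. Assumes "/think" enables and "/no_think" disables extended reasoning,
// as described in the SmolLM3 model card.
function buildMessages(systemPrompt, userPrompt, { think = true } = {}) {
  const flag = think ? "/think" : "/no_think";
  return [
    // The flag is appended with no separator, matching the example's system prompt.
    { role: "system", content: `${systemPrompt}${flag}` },
    { role: "user", content: userPrompt },
  ];
}

const messages = buildMessages(
  "You are SmolLM, a language model created by Hugging Face.",
  "Solve the equation x^2 - 3x + 2 = 0",
  { think: false },
);
console.log(messages[0].content);
```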
Full Changelog: 3.6.1...3.6.2
3.6.1
What's new?
- Add support for NeoBERT in #1350
import { pipeline } from "@huggingface/transformers";
// Create feature extraction pipeline
const extractor = await pipeline("feature-extraction", "onnx-community/NeoBERT-ONNX");
// Compute embeddings
const text = "NeoBERT is the most efficient model of its kind!";
const embedding = await extractor(text, { pooling: "cls" });
console.log(embedding.dims); // [1, 768]
- Improve web worker detection to support ServiceWorker and SharedWorker by @aungKhantPaing in #1346
- Fix optional from_pretrained types in #1352
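In the NeoBERT example, `pooling: "cls"` uses the hidden state of the first token as the sentence embedding instead of averaging over all tokens (`"mean"`). Below is a minimal sketch of both strategies on a plain `[seq_len][hidden]` array. The function names are hypothetical, and this is not the library's implementation. The mean variant takes an optional attention mask so padding positions are skipped.

```javascript
// Sketch of two pooling strategies over token embeddings shaped [seqLen][hidden].
// Function names are hypothetical; this is not the Transformers.js implementation.
function clsPool(tokenEmbeddings) {
  // CLS pooling: the first token's vector is the sentence embedding.
  return [...tokenEmbeddings[0]];
}

function meanPool(tokenEmbeddings, attentionMask = tokenEmbeddings.map(() => 1)) {
  // Mean pooling: average the vectors of positions whose mask value is 1.
  const hidden = tokenEmbeddings[0].length;
  const sum = new Array(hidden).fill(0);
  let count = 0;
  tokenEmbeddings.forEach((vec, i) => {
    if (attentionMask[i] === 1) {
      for (let j = 0; j < hidden; ++j) sum[j] += vec[j];
      count += 1;
    }
  });
  return sum.map((v) => v / count);
}

// Toy embeddings: three positions with hidden size 2; the last one is padding.
const tokens = [
  [1, 2],
  [3, 4],
  [5, 6],
];
console.log(clsPool(tokens));
console.log(meanPool(tokens, [1, 1, 0]));
```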
New Contributors
- @aungKhantPaing made their first contribution in #1346
- @fidoriel made their first contribution in #1351
Full Changelog: 3.6.0...3.6.1
3.6.0
🚀 Transformers.js v3.6 — Gemma 3n, Qwen3-Embedding, Llava-Qwen2
🤖 New models
Gemma 3n
Gemma 3n, which was announced as a preview during Google I/O, is a model that is designed from the ground up to run locally on your hardware. On top of that, it's natively multimodal, supporting image, text, audio, and video inputs 🤯
Gemma 3n models have multiple architecture innovations:
- They are available in two sizes based on effective parameters. While the raw parameter count of this model is 6B, the architecture design allows the model to be run with a memory footprint comparable to a traditional 2B model by offloading low-utilization matrices from the accelerator.
- They use a MatFormer architecture that allows nesting sub-models within the E4B model. We provide one sub-model (this model repository), or you can access a spectrum of custom-sized models using the Mix-and-Match method.
Learn more about these techniques in the technical blog post and the Gemma documentation.
As part of the release, we are releasing ONNX weights for the gemma-3n-E2B-it variant (link), making it compatible with Transformers.js:
Warning: Due to the model's large size, we currently only support Node.js, Deno, and Bun execution. In-browser WebGPU support is actively being worked on, so stay tuned for an update!
Example: Caption an image
import {
AutoProcessor,
AutoModelForImageTextToText,
load_image,
TextStreamer,
} from "@huggingface/transformers";
// Load processor and model
const model_id = "onnx-community/gemma-3n-E2B-it-ONNX";
const processor = await AutoProcessor.from_pretrained(model_id);
const model = await AutoModelForImageTextToText.from_pretrained(model_id, {
dtype: {
embed_tokens: "q8",
audio_encoder: "q8",
vision_encoder: "fp16",
decoder_model_merged: "q4",
},
device: "cpu", // NOTE: WebGPU support coming soon!
});
// Prepare prompt
const messages = [
{
role: "user",
content: [
{ type: "image" },
{ type: "text", text: "Describe this image in detail." },
],
},
];
const prompt = processor.apply_chat_template(messages, {
add_generation_prompt: true,
});
// Prepare inputs
const url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg";
const image = await load_image(url);
const audio = null;
const inputs = await processor(prompt, image, audio, {
add_special_tokens: false,
});
// Generate output
const outputs = await model.generate({
...inputs,
max_new_tokens: 512,
do_sample: false,
streamer: new TextStreamer(processor.tokenizer, {
skip_prompt: true,
skip_special_tokens: false,
// callback_function: (text) => { /* Do something with the streamed output */ },
}),
});
// Decode output
const decoded = processor.batch_decode(
outputs.slice(null, [inputs.input_ids.dims.at(-1), null]),
{ skip_special_tokens: true },
);
console.log(decoded[0]);
Example output:
The image is a close-up, slightly macro shot of a cluster of vibrant pink cosmos flowers in full bloom. The flowers are the focal point, with their delicate, slightly ruffled petals radiating outwards. They have a soft, almost pastel pink hue, and their edges are subtly veined.
A small, dark-colored bee is actively visiting one of the pink flowers, its body positioned near the center of the bloom. The bee appears to be collecting pollen or nectar.
The flowers are attached to slender, brownish-green stems, and some of the surrounding foliage is visible in a blurred background, suggesting a natural outdoor setting. There are also hints of other flowers in the background, including some red ones, adding a touch of contrast to the pink.
The lighting in the image seems to be natural daylight, casting soft shadows and highlighting the textures of the petals and the bee. The overall impression is one of delicate beauty and the gentle activity of nature.
Example: Transcribe audio
import {
AutoProcessor,
AutoModelForImageTextToText,
TextStreamer,
} from "@huggingface/transformers";
import wavefile from "wavefile";
// Load processor and model
const model_id = "onnx-community/gemma-3n-E2B-it-ONNX";
const processor = await AutoProcessor.from_pretrained(model_id);
const model = await AutoModelForImageTextToText.from_pretrained(model_id, {
dtype: {
embed_tokens: "q8",
audio_encoder: "q4",
vision_encoder: "fp16",
decoder_model_merged: "q4",
},
device: "cpu", // NOTE: WebGPU support coming soon!
});
// Prepare prompt
const messages = [
{
role: "user",
content: [
{ type: "audio" },
{ type: "text", text: "Transcribe this audio verbatim." },
],
},
];
const prompt = processor.apply_chat_template(messages, {
add_generation_prompt: true,
});
// Prepare inputs
const url = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav";
const buffer = Buffer.from(await fetch(url).then((x) => x.arrayBuffer()));
const wav = new wavefile.WaveFile(buffer);
wav.toBitDepth("32f"); // The processor expects input as a Float32Array
wav.toSampleRate(processor.feature_extractor.config.sampling_rate);
let audioData = wav.getSamples();
if (Array.isArray(audioData)) {
if (audioData.length > 1) {
for (let i = 0; i < audioData[0].length; ++i) {
audioData[0][i] = (Math.sqrt(2) * (audioData[0][i] + audioData[1][i])) / 2;
}
}
audioData = audioData[0];
}
const image = null;
const audio = audioData;
const inputs = await processor(prompt, image, audio, {
add_special_tokens: false,
});
// Generate output
const outputs = await model.generate({
...inputs,
max_new_tokens: 512,
do_sample: false,
streamer: new TextStreamer(processor.tokenizer, {
skip_prompt: true,
skip_special_tokens: false,
// callback_function: (text) => { /* Do something with the streamed output */ },
}),
});
// Decode output
const decoded = processor.batch_decode(
outputs.slice(null, [inputs.input_ids.dims.at(-1), null]),
{ skip_special_tokens: true },
);
console.log(decoded[0]);
Example output:
And so, my fellow Americans, ask not what your country can do for you. Ask what you can do for your country.
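For stereo files, `wav.getSamples()` returns one array per channel. The transcription example then folds the first two channels into one, scaling each summed sample by `Math.sqrt(2) / 2`. Below, the same step is written as a standalone, non-mutating function. The name `downmixToMono` is hypothetical. Like the example, it ignores any channels beyond the first two.

```javascript
// Hypothetical helper: fold per-channel sample arrays into one mono Float32Array,
// using the same scaling as the example: sqrt(2) * (left + right) / 2.
// Channels beyond the first two are ignored, as in the original loop.
function downmixToMono(channels) {
  if (channels.length === 1) return Float32Array.from(channels[0]);
  const [left, right] = channels;
  const mono = new Float32Array(left.length);
  for (let i = 0; i < left.length; ++i) {
    mono[i] = (Math.SQRT2 * (left[i] + right[i])) / 2;
  }
  return mono;
}

// Toy two-channel signal (invented values).
console.log(downmixToMono([[0.5, 0.25], [0.5, -0.25]]));
```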
Qwen3-Embedding
The Qwen3 Embedding model series is the latest in-house model series of the Qwen family, specifically designed for text embedding and ranking tasks. Building upon the dense foundational models of the Qwen3 series, it provides a comprehensive range of text embeddings and reranking models in various sizes (0.6B, 4B, and 8B). This series inherits the exceptional multilingual capabilities, long-text understanding, and reasoning skills of its foundational model.
You can run it with Transformers.js as follows:
import { pipeline, matmul } from "@huggingface/transformers";
// Create a feature extraction pipeline
const extractor = await pipeline(
"feature-extraction",
"onnx-community/Qwen3-Embedding-0.6B-ONNX",
{
dtype: "fp32", // Options: "fp32", "fp16", "q8"
// device: "webgpu",
},
);
function get_detailed_instruct(task_description, query) {
return `Instruct: ${task_description}\nQuery:${query}`;
}
// Each query must come with a one-sentence instruction that describes the task
const task = "Given a web search query, retrieve relevant passages that answer the query";
const queries = [
get_detailed_instruct(task, "What is the capital of China?"),
get_detailed_instruct(task, "Explain gravity"),
];
// No need to add instruction for retrieval documents
const documents = [
"The capital of China is Beijing.",
"Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
];
const input_texts = [...queries, ...documents];
// Extract embeddings for queries and documents
const output = await extractor(input_texts, {
pooling: "last_token",
normalize: true,
});
const scores = await matmul(
output.slice([0, queries.length]), // Query embeddings
output.slice([queries.length, null]).transpose(1, 0), // Document embeddings
);
console.log(scores.tolist());
// [
// [ 0.7645590305328369, 0.14142560958862305 ],
// [ 0.13549776375293732, 0.599955141544342 ]
// ]
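With `normalize: true`, each embedding has unit L2 norm, so the `matmul` of query embeddings against transposed document embeddings gives cosine similarities. Below is a dependency-free sketch of the same scoring, plus last-token pooling over a `[seq_len][hidden]` array. All helper names are hypothetical. Pooling picks the last position whose attention-mask value is 1, which handles both left- and right-padded inputs.

```javascript
// Last-token pooling: use the vector at the last position whose mask value is 1.
// For left padding this is the final position; for right padding it is the last real token.
function lastTokenPool(tokenEmbeddings, attentionMask) {
  const idx = attentionMask.lastIndexOf(1);
  return [...tokenEmbeddings[idx]];
}

// L2-normalize a vector so that dot products become cosine similarities.
function l2Normalize(vec) {
  const norm = Math.sqrt(vec.reduce((acc, v) => acc + v * v, 0));
  return vec.map((v) => v / norm);
}

// Score matrix: scores[i][j] = dot(queries[i], documents[j]).
function scoreMatrix(queries, documents) {
  return queries.map((q) =>
    documents.map((d) => q.reduce((acc, v, k) => acc + v * d[k], 0)),
  );
}

// Toy 2-D vectors (invented): the first document matches the query, the second is orthogonal.
const q = [l2Normalize([3, 4])];
const d = [l2Normalize([3, 4]), l2Normalize([-4, 3])];
console.log(scoreMatrix(q, d));
```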
Llava-Qwen2
Finally, we also added support for Llava models with a Qwen2 text backbone:
import {
AutoProcessor,
AutoModelForImageTextToText,
load_image,
TextStreamer,
} from "@huggingface/transformers";
// Load processor and model
const model_id = "onnx-community/FastVLM-0.5B-ONNX";
const processor = await AutoProcessor.from_pretrained(model_id);
const model = await AutoModelForImageTextToText.from_pretrained(model_id, {
dtype: {
embed_tokens: "fp16",
vision_encoder: "q4",
decoder_model_merged: "q4",
},
});
// Prepare prompt
const messages = [
{
role: "user",
content: "<image>Describe this image in detail.",
},
];
const prompt = processor.apply_cha...
3.5.2
What's new?
- Update paper links to HF papers by @qgallouedec in #1318
- Allow older (legacy) BPE models to be detected even when the type is not specified in #1314
- Fix WhisperTextStreamer when return_timestamps is true (correctly ignore printing of timestamp tokens) in #1327
- Improve TypeScript exports and expose common types in #1325
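With `return_timestamps` enabled, Whisper interleaves timestamp tokens such as `<|0.00|>` with the text tokens. #1327 makes `WhisperTextStreamer` skip them when printing. Below is a minimal sketch of the same filtering on already-decoded text. The function name is hypothetical, and the regex is only an approximation of the streamer's token-level handling.

```javascript
// Hypothetical helper: remove Whisper-style timestamp tokens such as "<|1.08|>"
// from decoded text. A regex approximation, not the streamer's token-level logic.
const TIMESTAMP_TOKEN = /<\|\d+\.\d+\|>/g;

function stripTimestampTokens(text) {
  return text.replace(TIMESTAMP_TOKEN, "");
}

console.log(stripTimestampTokens("<|0.00|> And so, my fellow Americans<|3.20|>"));
```

Other special tokens like `<|endoftext|>` do not match the pattern, so they still need `skip_special_tokens`.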
New Contributors
- @qgallouedec made their first contribution in #1318
Full Changelog: 3.5.1...3.5.2
3.5.1
What's new?
- Add support for Qwen3 in #1300.
Example usage:
import { pipeline, TextStreamer } from "@huggingface/transformers";
// Create a text generation pipeline
const generator = await pipeline(
"text-generation",
"onnx-community/Qwen3-0.6B-ONNX",
{ dtype: "q4f16", device: "webgpu" },
);
// Define the list of messages
const messages = [
{ role: "user", content: "If 5 brog 5 is 1, and 4 brog 2 is 2, what is 3 brog 1?" },
];
// Generate a response
const output = await generator(messages, {
max_new_tokens: 1024,
do_sample: true,
top_k: 20,
temperature: 0.7,
streamer: new TextStreamer(generator.tokenizer, { skip_prompt: true, skip_special_tokens: true }),
});
console.log(output[0].generated_text.at(-1).content);
Try out the online demo.
- Add support for D-FINE in #1303
Example usage:
import { pipeline } from "@huggingface/transformers";
// Create an object detection pipeline with a D-FINE checkpoint
const detector = await pipeline("object-detection", "onnx-community/dfine_s_coco-ONNX");
// Run detection, keeping only predictions that pass the 0.5 score threshold
const image = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/cats.jpg";
const output = await detector(image, { threshold: 0.5 });
console.log(output);
See list of supported models: https://huggingface.co/models?library=transformers.js&other=d_fine&sort=trending
- Introduce global inference chain (+ other WebGPU fixes) in #1293
- Fix RawImage.fromURL error when the input is a file URL by @himself65 in #1288
- [bugfix] Tokenizers respect padding: true with non-null max_length by @dwisdom0 in #1284
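The Qwen3 example samples with `do_sample: true`, `top_k: 20` and `temperature: 0.7`. Below is a rough illustration of what those two settings do to one decoding step's logits. It is a generic top-k/temperature sampler, not Transformers.js's logits-processor code. `sampleTopK` and the injectable `rng` are hypothetical.

```javascript
// Generic top-k sampling with temperature over one step's logits (not library code).
// `rng` returns a number in [0, 1); it is injectable so results can be reproduced.
function sampleTopK(logits, { topK = 20, temperature = 0.7, rng = Math.random } = {}) {
  // Keep the indices of the k largest logits.
  const candidates = logits
    .map((logit, index) => ({ index, logit }))
    .sort((a, b) => b.logit - a.logit)
    .slice(0, topK);
  // Dividing by a temperature below 1 sharpens the distribution; above 1 flattens it.
  const scaled = candidates.map((c) => c.logit / temperature);
  // Numerically stable softmax weights over the surviving candidates.
  const max = Math.max(...scaled);
  const weights = scaled.map((s) => Math.exp(s - max));
  const total = weights.reduce((a, b) => a + b, 0);
  // Inverse-CDF sampling: walk the cumulative weights until the draw is covered.
  let r = rng() * total;
  for (let i = 0; i < candidates.length; ++i) {
    r -= weights[i];
    if (r < 0) return candidates[i].index;
  }
  return candidates[candidates.length - 1].index;
}

// Toy logits for a 4-token vocabulary (invented values).
console.log(sampleTopK([1.0, 3.0, 2.5, -1.0], { topK: 2 }));
```

With `topK: 2`, only tokens 1 and 2 can ever be chosen. A lower temperature pushes more of the probability mass onto token 1.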
New Contributors
- @himself65 made their first contribution in #1288
- @dwisdom0 made their first contribution in #1284
Full Changelog: 3.5.0...3.5.1
3.5.0
🔥 Transformers.js v3.5
🛠️ Improvements
- Fix error when dtype in config is unset by @hans00 in #1271
- [audio utils] fix fft_bin_width computation in #1274
- Fix bad words logits processor in #1278
- Implement LRU cache for BPE tokenizer in #1283
- Return buffer instead of file_path if cache unavailable for model loading by @PrafulB in #1280
- Use custom cache over FSCache if specified by @PrafulB in #1285
- Support device-level configuration across all devices by @ibelem in #1276
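#1283 adds an LRU (least recently used) cache to the BPE tokenizer, so frequently seen words can skip re-running the merge loop. Below is the general technique, sketched with a JavaScript `Map`, which iterates keys in insertion order. The class name, capacity handling, and cached values are illustrative. This is not the library's code.

```javascript
// Minimal LRU cache: Map iteration follows insertion order, so the first key
// is always the least recently used entry.
class LRUCache {
  constructor(capacity) {
    this.capacity = capacity;
    this.cache = new Map();
  }

  get(key) {
    if (!this.cache.has(key)) return undefined;
    // Re-insert the entry to mark it as most recently used.
    const value = this.cache.get(key);
    this.cache.delete(key);
    this.cache.set(key, value);
    return value;
  }

  put(key, value) {
    if (this.cache.has(key)) this.cache.delete(key);
    this.cache.set(key, value);
    if (this.cache.size > this.capacity) {
      // Evict the least recently used entry (the oldest key).
      const oldest = this.cache.keys().next().value;
      this.cache.delete(oldest);
    }
  }
}

// Usage: cache invented BPE splits for a few words.
const cache = new LRUCache(2);
cache.put("hello", ["hel", "lo"]);
cache.put("world", ["wor", "ld"]);
cache.get("hello"); // "hello" becomes the most recently used entry
cache.put("again", ["ag", "ain"]); // evicts "world"
console.log(cache.get("world")); // undefined
```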
🤗 New contributors
Full Changelog: 3.4.2...3.5.0
3.4.2
3.4.1
What's new?
- Add support for SNAC (Multi-Scale Neural Audio Codec) in #1251
- Add support for Metric3D (v1 & v2) in #1254
- Add support for Gemma 3 text in #1229. Note: Only Node.js execution is supported for now.
- Safeguard against background removal pipeline precision issues in #1255. Thanks to @LuSrodri for reporting the issue!
- Allow RawImage to read from all types of supported sources by @BritishWerewolf in #1244
- Update pipelines.md api docs in #1256
- Update extension example to use latest version by @fs-eire in #1213
Full Changelog: 3.4.0...3.4.1