Commit 43cb47e

feat: add text embedding model and pipeline.
1 parent 1108a9a commit 43cb47e

5 files changed (+263 -0 lines changed)
Lines changed: 58 additions & 0 deletions
import { TextEmbedding } from "../models/text-embedding";
import { InferenceSession } from "onnxruntime-react-native";

describe("TextEmbedding Model", () => {
  let model: TextEmbedding;

  beforeEach(() => {
    model = new TextEmbedding();
  });

  afterEach(async () => {
    await model.release();
  });

  it("should initialize properly", () => {
    expect(model).toBeInstanceOf(TextEmbedding);
  });

  it("should throw error when session is undefined", async () => {
    await expect(model.embed([1n, 2n, 3n])).rejects.toThrow(
      "Session is undefined",
    );
  });

  it("should throw error when no embedding output is found", async () => {
    // Mock session run to return empty outputs
    const mockRun = jest.fn().mockResolvedValue({});
    (model as any).sess = {
      run: mockRun,
      release: jest.fn().mockResolvedValue(undefined),
    } as Partial<InferenceSession>;

    await expect(model.embed([1n, 2n, 3n])).rejects.toThrow(
      "No embedding output found in model outputs",
    );
  });

  it("should properly calculate mean embeddings", async () => {
    // Mock session run to return sample embeddings
    const mockEmbeddings = new Float32Array([1, 2, 3, 4, 5, 6]); // 2 tokens, 3 dimensions
    const mockRun = jest.fn().mockResolvedValue({
      last_hidden_state: {
        data: mockEmbeddings,
        dims: [1, 2, 3], // [batch_size, sequence_length, hidden_size]
      },
    });
    (model as any).sess = {
      run: mockRun,
      release: jest.fn().mockResolvedValue(undefined),
    } as Partial<InferenceSession>;

    const result = await model.embed([1n, 2n]);

    // Expected mean values: [2.5, 3.5, 4.5]
    expect(Array.from(result)).toEqual([2.5, 3.5, 4.5]);
    expect(mockRun).toHaveBeenCalled();
  });
});

Lines changed: 59 additions & 0 deletions
import TextEmbeddingPipeline from "../pipelines/text-embedding";

// Mock the TextEmbedding model
jest.mock("../models/text-embedding", () => {
  return {
    TextEmbedding: jest.fn().mockImplementation(() => ({
      load: jest.fn().mockResolvedValue(undefined),
      embed: jest.fn().mockResolvedValue(new Float32Array([0.1, 0.2, 0.3])),
      release: jest.fn().mockResolvedValue(undefined),
    })),
  };
});

// Create a callable tokenizer mock
const createCallableTokenizer = () => {
  const tokenizer = jest.fn().mockResolvedValue({
    input_ids: [1n, 2n, 3n],
  });
  return tokenizer;
};

jest.mock("@xenova/transformers", () => ({
  env: {
    allowRemoteModels: true,
    allowLocalModels: false,
  },
  AutoTokenizer: {
    from_pretrained: jest.fn().mockResolvedValue(createCallableTokenizer()),
  },
}));

describe("TextEmbedding Pipeline", () => {
  beforeEach(() => {
    jest.clearAllMocks();
  });

  afterEach(async () => {
    await TextEmbeddingPipeline.release();
  });

  it("should throw error when not initialized", async () => {
    await expect(TextEmbeddingPipeline.embed("test text")).rejects.toThrow(
      "Tokenizer undefined, please initialize first",
    );
  });

  it("should initialize properly", async () => {
    await expect(
      TextEmbeddingPipeline.init("test-model", "model.onnx"),
    ).resolves.not.toThrow();
  });

  it("should generate embeddings", async () => {
    await TextEmbeddingPipeline.init("test-model", "model.onnx");
    const embeddings = await TextEmbeddingPipeline.embed("test text");
    expect(embeddings).toBeInstanceOf(Float32Array);
    expect(embeddings.length).toBe(3);
  });
});

src/index.tsx

Lines changed: 6 additions & 0 deletions
@@ -1,12 +1,16 @@
 import { TextGeneration } from "./models/text-generation";
+import { TextEmbedding } from "./models/text-embedding";
 import TextGenerationPipeline from "./pipelines/text-generation";
+import TextEmbeddingPipeline from "./pipelines/text-embedding";
 
 export const Pipeline = {
   TextGeneration: TextGenerationPipeline,
+  TextEmbedding: TextEmbeddingPipeline,
 };
 
 export const Model = {
   TextGeneration,
+  TextEmbedding,
 };
 
 export default {
@@ -16,4 +20,6 @@ export default {
 
 export type * from "./models/base";
 export type * from "./models/text-generation";
+export type * from "./models/text-embedding";
 export type * from "./pipelines/text-generation";
+export type * from "./pipelines/text-embedding";
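
These root exports wire the new module into the package surface: Model.TextEmbedding exposes the low-level ONNX wrapper directly, while Pipeline.TextEmbedding bundles it with a tokenizer, mirroring the existing TextGeneration pair.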

src/models/text-embedding.tsx

Lines changed: 61 additions & 0 deletions
import "text-encoding-polyfill";
import { Tensor } from "onnxruntime-react-native";
import { Base } from "./base";

/**
 * Class to handle text embedding model on top of onnxruntime
 */
export class TextEmbedding extends Base {
  /**
   * Generate embeddings from input tokens
   *
   * @param tokens Input tokens to generate embeddings from
   * @returns Float32Array containing the embedding vector
   */
  public async embed(tokens: bigint[]): Promise<Float32Array> {
    const feed = this.feed;
    const inputIdsTensor = new Tensor(
      "int64",
      BigInt64Array.from(tokens.map(BigInt)),
      [1, tokens.length],
    );
    feed.input_ids = inputIdsTensor;

    // Create attention mask (1 for all tokens)
    feed.attention_mask = new Tensor(
      "int64",
      BigInt64Array.from({ length: tokens.length }, () => 1n),
      [1, tokens.length],
    );

    if (!this.sess) {
      throw new Error("Session is undefined");
    }

    // Run inference to get embeddings
    const outputs = await this.sess.run(feed);

    // The model typically outputs the embeddings as 'last_hidden_state' or 'embeddings'.
    // We take the mean of the token embeddings to get a single vector.
    const embeddings = outputs.last_hidden_state || outputs.embeddings;

    if (!embeddings) {
      throw new Error("No embedding output found in model outputs");
    }

    // Calculate mean across the token dimension (dim 1) to get a single embedding vector
    const data = embeddings.data as Float32Array;
    const [, seqLen, hiddenSize] = embeddings.dims;
    const result = new Float32Array(hiddenSize);

    for (let h = 0; h < hiddenSize; h++) {
      let sum = 0;
      for (let s = 0; s < seqLen; s++) {
        sum += data[s * hiddenSize + h];
      }
      result[h] = sum / seqLen;
    }

    return result;
  }
}
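
Design note: embed() mean-pools the token embeddings, averaging each hidden dimension across the sequence to collapse a [1, seqLen, hiddenSize] output into a single hiddenSize-length vector. For dims [1, 2, 3] with data [1, 2, 3, 4, 5, 6], that yields [(1+4)/2, (2+5)/2, (3+6)/2] = [2.5, 3.5, 4.5], which is exactly what the unit test above asserts. Because the attention mask is hard-coded to all ones, a plain mean and a mask-weighted mean coincide here; if padded batches were ever fed in, the pooling loop would need to weight by the mask instead.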

src/pipelines/text-embedding.tsx

Lines changed: 79 additions & 0 deletions
import { env, AutoTokenizer, PreTrainedTokenizer } from "@xenova/transformers";
import { TextEmbedding as Model } from "../models/text-embedding";
import { LoadOptions } from "../models/base";

/** Initialization Options for Text Embedding */
export interface TextEmbeddingOptions extends LoadOptions {
  /** Shows special tokens in the output. */
  show_special: boolean;
}

// Set up environment for transformers.js tokenizer
env.allowRemoteModels = true;
env.allowLocalModels = false;

// Declare tokenizer and model
let tokenizer: PreTrainedTokenizer;
const model: Model = new Model();

// Initialize options with default values
let _options: TextEmbeddingOptions = {
  show_special: false,
  max_tokens: 512, // typical max length for embedding models
  fetch: async (url) => url,
  verbose: false,
  externalData: false,
  executionProviders: ["cpu"],
};

/**
 * Generates embeddings from the input text.
 *
 * @param text - The input text to generate embeddings from.
 * @returns Float32Array containing the embedding vector.
 */
async function embed(text: string): Promise<Float32Array> {
  if (!tokenizer) {
    throw new Error("Tokenizer undefined, please initialize first.");
  }

  const { input_ids } = await tokenizer(text, {
    return_tensor: false,
    padding: true,
    truncation: true,
    max_length: _options.max_tokens,
  });

  return await model.embed(input_ids);
}

/**
 * Loads the model and tokenizer with the specified options.
 *
 * @param model_name - The name of the model to load.
 * @param onnx_path - The path to the ONNX model.
 * @param options - Optional initialization options.
 */
async function init(
  model_name: string,
  onnx_path: string,
  options?: Partial<TextEmbeddingOptions>,
): Promise<void> {
  _options = { ..._options, ...options };
  tokenizer = await AutoTokenizer.from_pretrained(model_name);
  await model.load(model_name, onnx_path, _options);
}

/**
 * Releases the resources used by the model.
 */
async function release(): Promise<void> {
  await model.release();
}

// Export functions for external use
export default {
  init,
  embed,
  release,
};
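
Taken together, the pieces above support an end-to-end flow: initialize the pipeline once, embed any number of strings, compare the resulting vectors, and release the session. A minimal sketch of that flow follows; the import path, the model id "Xenova/all-MiniLM-L6-v2", and the ONNX filename are assumptions for illustration, not values from this commit.

// Hypothetical usage sketch; model id and ONNX path below are assumed, not part of this commit.
import { Pipeline } from "."; // the root export defined in src/index.tsx above

async function demo(): Promise<void> {
  // Load the tokenizer (via @xenova/transformers) and the ONNX session once.
  await Pipeline.TextEmbedding.init(
    "Xenova/all-MiniLM-L6-v2", // assumed Hugging Face model id
    "onnx/model.onnx", // assumed path to the ONNX weights
  );

  const a = await Pipeline.TextEmbedding.embed("The cat sat on the mat.");
  const b = await Pipeline.TextEmbedding.embed("A feline rested on the rug.");

  // Cosine similarity between the two mean-pooled vectors.
  let dot = 0;
  let normA = 0;
  let normB = 0;
  for (let i = 0; i < a.length; i++) {
    dot += a[i] * b[i];
    normA += a[i] * a[i];
    normB += b[i] * b[i];
  }
  console.log("similarity:", dot / (Math.sqrt(normA) * Math.sqrt(normB)));

  // Free the native session when done.
  await Pipeline.TextEmbedding.release();
}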
