Skip to content

Commit 3d5e5e8

Browse files
committed
Updated Cohere example for binary embeddings [skip ci]
1 parent 0854820 commit 3d5e5e8

File tree

2 files changed

+21
-9
lines changed

2 files changed

+21
-9
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ Follow the instructions for your database library:
1616
Or check out some examples:
1717

1818
- [Embeddings](examples/openai.zig) with OpenAI
19-
- [Embeddings](examples/cohere.zig) with Cohere
19+
- [Binary embeddings](examples/cohere.zig) with Cohere
2020
- [Hybrid search](examples/hybrid.zig) with Ollama (Reciprocal Rank Fusion)
2121
- [Sparse search](examples/sparse.zig) with Text Embeddings Inference
2222

examples/cohere.zig

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ const Embeddings = struct {
55
parsed: std.json.Parsed(EmbedResponse),
66

77
const EmbeddingsObject = struct {
8-
float: []const []const f32,
8+
ubinary: []const []const u8,
99
};
1010

1111
const EmbedResponse = struct {
@@ -16,8 +16,8 @@ const Embeddings = struct {
1616
self.parsed.deinit();
1717
}
1818

19-
pub fn get(self: *Embeddings, index: usize) ?[]const f32 {
20-
const data = self.parsed.value.embeddings.float;
19+
pub fn get(self: *Embeddings, index: usize) ?[]const u8 {
20+
const data = self.parsed.value.embeddings.ubinary;
2121
return if (index < data.len) data[index] else null;
2222
}
2323
};
@@ -31,7 +31,7 @@ fn embed(allocator: std.mem.Allocator, texts: []const []const u8, inputType: []c
3131
.texts = texts,
3232
.model = "embed-english-v3.0",
3333
.input_type = inputType,
34-
.embedding_types = [_][]const u8{"float"},
34+
.embedding_types = [_][]const u8{"ubinary"},
3535
};
3636

3737
var authorization = std.ArrayList(u8).init(allocator);
@@ -59,6 +59,14 @@ fn embed(allocator: std.mem.Allocator, texts: []const []const u8, inputType: []c
5959
return Embeddings{ .parsed = parsed };
6060
}
6161

62+
/// Renders `data` as a string of ASCII '0'/'1' characters: every input byte
/// becomes exactly eight binary digits, zero-padded on the left, so the
/// result is always `data.len * 8` characters long (e.g. `{0x05}` yields
/// "00000101"). The callers in this file pass `.items` of the result as the
/// text parameter for the `bit(1024)` column and for the `<~>` query.
///
/// The returned list is backed by `allocator`; the caller owns it and must
/// call `deinit()` on it. Returns an error if allocation fails, in which
/// case nothing is leaked.
fn bitString(allocator: std.mem.Allocator, data: []const u8) !std.ArrayList(u8) {
    var buf = std.ArrayList(u8).init(allocator);
    // Free the partially built buffer if any step below fails; on success
    // ownership passes to the caller and this does not run.
    errdefer buf.deinit();

    // Each byte expands to exactly eight characters, so reserve once instead
    // of growing the list repeatedly inside the loop.
    try buf.ensureTotalCapacity(data.len * 8);

    for (data) |v| {
        // The fill character is only honoured when an alignment follows it:
        // "0>8" means pad with '0', right-align, width 8. A bare "08" is read
        // as width 8 with the default space fill, which is not a valid bit
        // string.
        try buf.writer().print("{b:0>8}", .{v});
    }
    return buf;
}
69+
6270
pub fn main() !void {
6371
const apiKey = std.posix.getenv("CO_API_KEY") orelse {
6472
std.debug.print("Set CO_API_KEY\n", .{});
@@ -80,20 +88,24 @@ pub fn main() !void {
8088

8189
_ = try conn.exec("CREATE EXTENSION IF NOT EXISTS vector", .{});
8290
_ = try conn.exec("DROP TABLE IF EXISTS documents", .{});
83-
_ = try conn.exec("CREATE TABLE documents (id bigserial PRIMARY KEY, content text, embedding vector(1024))", .{});
91+
_ = try conn.exec("CREATE TABLE documents (id bigserial PRIMARY KEY, content text, embedding bit(1024))", .{});
8492

8593
const documents = [_][]const u8{ "The dog is barking", "The cat is purring", "The bear is growling" };
8694
var documentEmbeddings = try embed(allocator, &documents, "search_document", apiKey);
8795
defer documentEmbeddings.deinit();
8896
for (&documents, 0..) |content, i| {
89-
const params = .{ content, documentEmbeddings.get(i) };
90-
_ = try conn.exec("INSERT INTO documents (content, embedding) VALUES ($1, $2::float4[])", params);
97+
var bit = try bitString(allocator, documentEmbeddings.get(i).?);
98+
defer bit.deinit();
99+
const params = .{ content, bit.items };
100+
_ = try conn.exec("INSERT INTO documents (content, embedding) VALUES ($1, $2)", params);
91101
}
92102

93103
const query = "forest";
94104
var queryEmbeddings = try embed(allocator, &[_][]const u8{query}, "search_query", apiKey);
95105
defer queryEmbeddings.deinit();
96-
var result = try conn.query("SELECT content FROM documents ORDER BY embedding <=> $1::float4[]::vector LIMIT 5", .{queryEmbeddings.get(0)});
106+
var queryBit = try bitString(allocator, queryEmbeddings.get(0).?);
107+
defer queryBit.deinit();
108+
var result = try conn.query("SELECT content FROM documents ORDER BY embedding <~> $1 LIMIT 5", .{queryBit.items});
97109
defer result.deinit();
98110
while (try result.next()) |row| {
99111
const content = row.get([]const u8, 0);

0 commit comments

Comments
 (0)