Skip to content

Commit eb93b21

Browse files
authored
FEATURE: Add BGE-M3 embeddings support (#569)
BAAI/bge-m3 is an interesting model that is multilingual and has a context size of 8192. Even with a 16x larger context, it is only 4x slower to compute its embeddings in the worst-case scenario. Also includes a minor refactor of the rake task, including setting model and concurrency levels when running the backfill task.
1 parent 6de9c53 commit eb93b21

File tree

9 files changed

+1000347
-19
lines changed

9 files changed

+1000347
-19
lines changed

config/settings.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,7 @@ discourse_ai:
264264
- multilingual-e5-large
265265
- bge-large-en
266266
- gemini
267+
- bge-m3
267268
validator: "DiscourseAi::Configuration::EmbeddingsModelValidator"
268269
ai_embeddings_per_post_enabled:
269270
default: false
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# frozen_string_literal: true
2+
3+
class AddEmbeddingsTablesforBgeM3 < ActiveRecord::Migration[7.0]
  # Creates the three embeddings tables for the bge-m3 model (model id 8,
  # strategy version 1, 1024-dimensional vectors): one each for topics,
  # posts, and RAG document fragments.
  def change
    # [table name, foreign-key column, explicit index name (nil => Rails default)]
    # The fragment index gets an explicit name because the auto-generated one
    # would exceed Postgres' identifier length limit.
    tables = [
      [:ai_topic_embeddings_8_1, :topic_id, nil],
      [:ai_post_embeddings_8_1, :post_id, nil],
      [
        :ai_document_fragment_embeddings_8_1,
        :rag_document_fragment_id,
        "rag_document_fragment_id_embeddings_8_1",
      ],
    ]

    tables.each do |table_name, fk_column, index_name|
      create_table table_name, id: false do |t|
        t.integer fk_column, null: false
        t.integer :model_version, null: false
        t.integer :strategy_version, null: false
        t.text :digest, null: false
        t.column :embeddings, "vector(1024)", null: false
        t.timestamps

        if index_name
          t.index fk_column, unique: true, name: index_name
        else
          t.index fk_column, unique: true
        end
      end
    end
  end
end

lib/embeddings/vector_representations/base.rb

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,12 @@ def find_representation(model_name)
1111
[
1212
DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2,
1313
DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn,
14+
DiscourseAi::Embeddings::VectorRepresentations::BgeM3,
1415
DiscourseAi::Embeddings::VectorRepresentations::Gemini,
1516
DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Large,
16-
DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002,
17-
DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Small,
1817
DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Large,
18+
DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Small,
19+
DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002,
1920
].find { _1.name == model_name }
2021
end
2122

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
# frozen_string_literal: true

module DiscourseAi
  module Embeddings
    module VectorRepresentations
      # Vector representation backed by BAAI/bge-m3, served through a
      # HuggingFace Text Embeddings Inference endpoint. Multilingual,
      # 1024-dimensional, 8192-token context.
      class BgeM3 < Base
        def self.name
          "bge-m3"
        end

        def self.correctly_configured?
          DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
        end

        # Site settings that must be present for this representation to work.
        def self.dependant_setting_names
          %w[ai_hugging_face_tei_endpoint_srv ai_hugging_face_tei_endpoint]
        end

        # Unique model id used to name the backing tables (ai_*_embeddings_8_1).
        def id
          8
        end

        def version
          1
        end

        def dimensions
          1024
        end

        def max_sequence_length
          8192
        end

        # Inner-product distance operator / index opclass (pgvector).
        def pg_function
          "<#>"
        end

        def pg_index_type
          "vector_ip_ops"
        end

        def tokenizer
          DiscourseAi::Tokenizer::BgeM3Tokenizer
        end

        # Computes the embedding for +text+, truncated to fit the model
        # context (minus 2 reserved special tokens). The asymetric keyword is
        # accepted for interface parity with other representations but is
        # unused by this model.
        def vector_from(text, asymetric: false)
          fitted_text = tokenizer.truncate(text, max_sequence_length - 2)
          DiscourseAi::Inference::HuggingFaceTextEmbeddings.perform!(fitted_text).first
        end
      end
    end
  end
end

lib/tasks/modules/embeddings/database.rake

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,48 @@
11
# frozen_string_literal: true
22

33
desc "Backfill embeddings for all topics and posts"
4-
task "ai:embeddings:backfill" => [:environment] do
4+
task "ai:embeddings:backfill", %i[model concurrency] => [:environment] do |_, args|
55
public_categories = Category.where(read_restricted: false).pluck(:id)
66

77
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
8-
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
8+
if args[:model].present?
9+
vector_rep =
10+
DiscourseAi::Embeddings::VectorRepresentations::Base.find_representation(args[:model]).new(
11+
strategy,
12+
)
13+
else
14+
vector_rep =
15+
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
16+
end
917
table_name = vector_rep.topic_table_name
1018

11-
Topic
12-
.joins("LEFT JOIN #{table_name} ON #{table_name}.topic_id = topics.id")
13-
.where("#{table_name}.topic_id IS NULL")
14-
.where("category_id IN (?)", public_categories)
15-
.where(deleted_at: nil)
16-
.order("topics.id DESC")
17-
.find_each do |t|
18-
print "."
19+
topics =
20+
Topic
21+
.joins("LEFT JOIN #{table_name} ON #{table_name}.topic_id = topics.id")
22+
.where("#{table_name}.topic_id IS NULL")
23+
.where("category_id IN (?)", public_categories)
24+
.where(deleted_at: nil)
25+
.order("topics.id DESC")
26+
27+
Parallel.each(topics.all, in_processes: args[:concurrency].to_i, progress: "Topics") do |t|
28+
ActiveRecord::Base.connection_pool.with_connection do
1929
vector_rep.generate_representation_from(t)
2030
end
31+
end
2132

2233
table_name = vector_rep.post_table_name
23-
Post
24-
.joins("LEFT JOIN #{table_name} ON #{table_name}.post_id = posts.id")
25-
.where("#{table_name}.post_id IS NULL")
26-
.where(deleted_at: nil)
27-
.order("posts.id DESC")
28-
.find_each do |t|
29-
print "."
34+
posts =
35+
Post
36+
.joins("LEFT JOIN #{table_name} ON #{table_name}.post_id = posts.id")
37+
.where("#{table_name}.post_id IS NULL")
38+
.where(deleted_at: nil)
39+
.order("posts.id DESC")
40+
41+
Parallel.each(posts.all, in_processes: args[:concurrency].to_i, progress: "Posts") do |t|
42+
ActiveRecord::Base.connection_pool.with_connection do
3043
vector_rep.generate_representation_from(t)
3144
end
45+
end
3246
end
3347

3448
desc "Creates indexes for embeddings"

lib/tokenizer/bge_m3_tokenizer.rb

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# frozen_string_literal: true
2+
3+
module DiscourseAi
  module Tokenizer
    class BgeM3Tokenizer < BasicTokenizer
      # Lazily loads and memoizes the bge-m3 tokenizer definition, once per
      # process.
      #
      # Uses a class *instance* variable (@tokenizer) rather than a class
      # variable (@@tokenizer): class variables are resolved through the whole
      # BasicTokenizer ancestry, so whichever tokenizer subclass memoized
      # first could leak its (wrong) tokenizer into this class. A class
      # instance variable is scoped to BgeM3Tokenizer alone.
      # NOTE(review): safe here because this method is the only accessor of
      # the memo — confirm the parent class does not read @@tokenizer directly.
      def self.tokenizer
        @tokenizer ||= Tokenizers.from_file("./plugins/discourse-ai/tokenizers/bge-m3.json")
      end
    end
  end
end

spec/shared/tokenizer_spec.rb

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,3 +176,32 @@
176176
end
177177
end
178178
end
179+
180+
describe DiscourseAi::Tokenizer::BgeM3Tokenizer do
  describe "#size" do
    describe "returns a token count" do
      it "for a sentence with punctuation and capitalization and numbers" do
        # Fixture count comes from the bge-m3 vocabulary; do not adjust
        # without re-running the real tokenizer.
        expect(described_class.size("Hello, World! 123")).to eq(7)
      end
    end
  end

  describe "#truncate" do
    it "truncates a sentence" do
      sentence = "foo bar baz qux quux corge grault garply waldo fred plugh xyzzy thud"
      expect(described_class.truncate(sentence, 3)).to eq("foo")
    end

    it "truncates a sentence successfully at a multibyte unicode character" do
      # Emoji family sequences span several tokens; truncation must cut on a
      # valid character boundary rather than mid-sequence.
      sentence = "foo bar πŸ‘¨πŸΏβ€πŸ‘©πŸΏβ€πŸ‘§πŸΏβ€πŸ‘§πŸΏ baz qux quux corge grault garply waldo fred plugh xyzzy thud"
      expect(described_class.truncate(sentence, 7)).to eq("foo bar πŸ‘¨πŸΏ")
    end

    it "truncates unicode characters properly when they use more than one token per char" do
      # CJK characters may tokenize to multiple tokens each, so we only
      # assert that truncation strictly shrinks the token count.
      sentence = "ζˆ‘ε–œζ¬’εƒζ―”θ¨"
      original_size = described_class.size(sentence)
      expect(described_class.size(described_class.truncate(sentence, original_size - 2))).to be <
        original_size
    end
  end
end

tokenizers/README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,7 @@ Licensed under MIT License
2525
## mixtral
2626

2727
Licensed under Apache 2.0 License
28+
29+
## bge-m3
30+
31+
Licensed under MIT License

0 commit comments

Comments
Β (0)