Skip to content

Commit 6784773

Browse files
authored
Merge pull request #1286 from dcSpark/feature/change-embedding-model
Allows to change the embedding model. Recreate embeddings if model ch…
2 parents eceeaf1 + 223c2a6 commit 6784773

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

50 files changed

+102654
-51249
lines changed

Cargo.lock

Lines changed: 12 additions & 12 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ members = [
1919
resolver = "2"
2020

2121
[workspace.package]
22-
version = "1.1.9"
22+
version = "1.1.10"
2323
edition = "2021"
2424
authors = ["Nico Arqueros <nico@shinkai.com>"]
2525

shinkai-bin/shinkai-node/src/llm_provider/providers/gemini.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -535,9 +535,7 @@ mod tests {
535535
let temp_file = NamedTempFile::new().unwrap();
536536
let db_path = std::path::PathBuf::from(temp_file.path());
537537
let api_url = String::new();
538-
let model_type = EmbeddingModelType::OllamaTextEmbeddingsInference(
539-
OllamaTextEmbeddingsInference::SnowflakeArcticEmbedM
540-
);
538+
let model_type = EmbeddingModelType::default();
541539

542540
SqliteManager::new(db_path, api_url, model_type).unwrap()
543541
}

shinkai-bin/shinkai-node/src/network/agent_payments_manager/my_agent_offerings_manager.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1002,7 +1002,7 @@ mod tests {
10021002
// let fs_db_path = format!("db_tests/{}", "vector_fs");
10031003
// let profile_list = vec![default_test_profile()];
10041004
// let supported_embedding_models = vec![EmbeddingModelType::OllamaTextEmbeddingsInference(
1005-
// OllamaTextEmbeddingsInference::SnowflakeArcticEmbed_M,
1005+
// OllamaTextEmbeddingsInference::EmbeddingGemma300M,
10061006
// )];
10071007

10081008
// VectorFS::new(

shinkai-bin/shinkai-node/src/network/node.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,25 @@ impl Node {
443443
callback_manager.update_cron_manager(cron_manager.clone());
444444
}
445445

446+
// Perform embedding migration if needed (BEFORE updating DB with current model)
447+
{
448+
let current_model = {
449+
let default_model_guard = self.default_embedding_model.lock().await;
450+
default_model_guard.clone()
451+
};
452+
453+
if let Err(e) = self.db.migrate_embeddings_to_new_model(&self.embedding_generator, &current_model).await {
454+
shinkai_log(
455+
ShinkaiLogOption::Node,
456+
ShinkaiLogLevel::Error,
457+
&format!("Embedding migration failed: {e:?}"),
458+
);
459+
// Note: We continue even if migration fails to allow the node to start
460+
}
461+
}
462+
446463
self.initialize_embedding_models().await?;
464+
447465
{
448466
// Starting the WebSocket server
449467
if let (Some(ws_manager), Some(ws_address)) = (&self.ws_manager, self.ws_address) {

shinkai-bin/shinkai-node/src/network/zip_export_import/zip_export_import.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1376,7 +1376,7 @@ mod tests {
13761376
let db_path = PathBuf::from(temp_file.path());
13771377
let api_url = String::new();
13781378
let model_type =
1379-
EmbeddingModelType::OllamaTextEmbeddingsInference(OllamaTextEmbeddingsInference::SnowflakeArcticEmbedM);
1379+
EmbeddingModelType::default();
13801380
println!("Creating test db at {:?}", db_path);
13811381
SqliteManager::new(db_path, api_url, model_type).unwrap()
13821382
}

shinkai-bin/shinkai-node/src/tools/tool_implementation/native_tools/config_setup.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ mod tests {
322322
let db_path = PathBuf::from(temp_file.path());
323323
let api_url = String::new();
324324
let model_type =
325-
EmbeddingModelType::OllamaTextEmbeddingsInference(OllamaTextEmbeddingsInference::SnowflakeArcticEmbedM);
325+
EmbeddingModelType::default();
326326

327327
SqliteManager::new(db_path, api_url, model_type).unwrap()
328328
}
@@ -388,7 +388,7 @@ mod tests {
388388
},
389389
true,
390390
);
391-
initial_tool.set_embedding(vec![0.0; 384]);
391+
initial_tool.set_embedding(vec![0.0; EmbeddingModelType::default().vector_dimensions().unwrap_or(768)]);
392392
initial_tool
393393
}
394394

shinkai-bin/shinkai-node/src/utils/environment.rs

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::env;
22
use std::net::{IpAddr, SocketAddr};
33
use std::str::FromStr;
44

5-
use shinkai_embedding::model_type::{EmbeddingModelType, OllamaTextEmbeddingsInference};
5+
use shinkai_embedding::model_type::EmbeddingModelType;
66
use shinkai_message_primitives::schemas::llm_providers::serialized_llm_provider::{
77
LLMProviderInterface, SerializedLLMProvider
88
};
@@ -169,11 +169,7 @@ pub fn fetch_node_environment() -> NodeEnvironment {
169169
}
170170

171171
// Fetch the default embedding model
172-
let default_embedding_model: EmbeddingModelType = env::var("DEFAULT_EMBEDDING_MODEL")
173-
.map(|s| EmbeddingModelType::from_string(&s).expect("Failed to parse DEFAULT_EMBEDDING_MODEL"))
174-
.unwrap_or_else(|_| {
175-
EmbeddingModelType::OllamaTextEmbeddingsInference(OllamaTextEmbeddingsInference::SnowflakeArcticEmbedM)
176-
});
172+
let default_embedding_model: EmbeddingModelType = EmbeddingModelType::default();
177173

178174
// Fetch the supported embedding models
179175
let supported_embedding_models: Vec<EmbeddingModelType> = env::var("SUPPORTED_EMBEDDING_MODELS")
@@ -183,9 +179,7 @@ pub fn fetch_node_environment() -> NodeEnvironment {
183179
.collect()
184180
})
185181
.unwrap_or_else(|_| {
186-
vec![EmbeddingModelType::OllamaTextEmbeddingsInference(
187-
OllamaTextEmbeddingsInference::SnowflakeArcticEmbedM,
188-
)]
182+
vec![EmbeddingModelType::default()]
189183
});
190184

191185
// Fetch the API_V2_KEY environment variable

shinkai-bin/shinkai-node/tests/it/db_identity_tests.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ fn setup_test_db() -> SqliteManager {
2323
let db_path = PathBuf::from(temp_file.path());
2424
let api_url = String::new();
2525
let model_type =
26-
EmbeddingModelType::OllamaTextEmbeddingsInference(OllamaTextEmbeddingsInference::SnowflakeArcticEmbedM);
26+
EmbeddingModelType::default();
2727

2828
SqliteManager::new(db_path, api_url, model_type).unwrap()
2929
}

shinkai-bin/shinkai-node/tests/it/db_inbox_tests.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ fn setup_test_db() -> SqliteManager {
2323
let db_path = PathBuf::from(temp_file.path());
2424
let api_url = String::new();
2525
let model_type =
26-
EmbeddingModelType::OllamaTextEmbeddingsInference(OllamaTextEmbeddingsInference::SnowflakeArcticEmbedM);
26+
EmbeddingModelType::default();
2727

2828
SqliteManager::new(db_path, api_url, model_type).unwrap()
2929
}

0 commit comments

Comments
 (0)