Skip to content

Commit 284a688

Browse files
authored
Merge pull request #1290 from dcSpark/feature/embedding-migration-endpoint
Embeddings migration endpoint
2 parents 9167095 + 076420b commit 284a688

File tree

6 files changed

+436
-62
lines changed

6 files changed

+436
-62
lines changed

shinkai-bin/shinkai-node/src/network/handle_commands_list.rs

Lines changed: 66 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,11 @@ impl Node {
140140
let full_identity = self.node_name.clone();
141141
let signing_secret_key = self.identity_secret_key.clone();
142142
let node_env = fetch_node_environment();
143-
let embedding_generator = Arc::new(self.embedding_generator.clone());
143+
let embedding_generator_ref = Arc::clone(&self.embedding_generator);
144+
let embedding_generator = Arc::new({
145+
let generator_guard = embedding_generator_ref.lock().await;
146+
generator_guard.clone()
147+
});
144148
tokio::spawn(async move {
145149
let _ = Node::v2_api_import_agent_url(
146150
db_clone,
@@ -159,7 +163,11 @@ impl Node {
159163
let db_clone = Arc::clone(&self.db);
160164
let node_env = fetch_node_environment();
161165
let full_identity = self.node_name.clone();
162-
let embedding_generator = Arc::new(self.embedding_generator.clone());
166+
let embedding_generator_ref = Arc::clone(&self.embedding_generator);
167+
let embedding_generator = Arc::new({
168+
let generator_guard = embedding_generator_ref.lock().await;
169+
generator_guard.clone()
170+
});
163171
tokio::spawn(async move {
164172
let _ = Node::v2_api_import_agent_zip(
165173
db_clone,
@@ -233,7 +241,11 @@ impl Node {
233241
let node_name_clone = self.node_name.clone();
234242
let encryption_secret_key_clone = self.encryption_secret_key.clone();
235243
let first_device_needs_registration_code = self.first_device_needs_registration_code;
236-
let embedding_generator_clone = Arc::new(self.embedding_generator.clone());
244+
let embedding_generator_ref = Arc::clone(&self.embedding_generator);
245+
let embedding_generator_clone = Arc::new({
246+
let generator_guard = embedding_generator_ref.lock().await;
247+
generator_guard.clone()
248+
});
237249
let encryption_public_key_clone = self.encryption_public_key;
238250
let identity_public_key_clone = self.identity_public_key;
239251
let identity_secret_key_clone = self.identity_secret_key.clone();
@@ -557,7 +569,11 @@ impl Node {
557569
let identity_manager_clone = self.identity_manager.clone();
558570
let node_name_clone = self.node_name.clone();
559571
let first_device_needs_registration_code = self.first_device_needs_registration_code;
560-
let embedding_generator_clone = Arc::new(self.embedding_generator.clone());
572+
let embedding_generator_ref = Arc::clone(&self.embedding_generator);
573+
let embedding_generator_clone = Arc::new({
574+
let generator_guard = embedding_generator_ref.lock().await;
575+
generator_guard.clone()
576+
});
561577
let encryption_public_key_clone = self.encryption_public_key;
562578
let identity_public_key_clone = self.identity_public_key;
563579
let identity_secret_key_clone = self.identity_secret_key.clone();
@@ -872,15 +888,19 @@ impl Node {
872888

873889
NodeCommand::V2ApiSearchItems { bearer, payload, res } => {
874890
let db_clone = Arc::clone(&self.db);
875-
let embedding_generator_clone = self.embedding_generator.clone();
891+
let embedding_generator_ref = Arc::clone(&self.embedding_generator);
876892

877893
let identity_manager_clone = self.identity_manager.clone();
878894
tokio::spawn(async move {
895+
let embedding_generator = {
896+
let generator_guard = embedding_generator_ref.lock().await;
897+
generator_guard.clone()
898+
};
879899
let _ = Node::v2_search_items(
880900
db_clone,
881901
identity_manager_clone,
882902
payload,
883-
Arc::new(embedding_generator_clone),
903+
Arc::new(embedding_generator),
884904
bearer,
885905
res,
886906
)
@@ -954,12 +974,16 @@ impl Node {
954974
let db_clone = Arc::clone(&self.db);
955975

956976
let identity_manager_clone = self.identity_manager.clone();
957-
let embedding_generator_clone = self.embedding_generator.clone();
977+
let embedding_generator_ref = Arc::clone(&self.embedding_generator);
958978
tokio::spawn(async move {
979+
let embedding_generator = {
980+
let generator_guard = embedding_generator_ref.lock().await;
981+
generator_guard.clone()
982+
};
959983
let _ = Node::v2_upload_file_to_folder(
960984
db_clone,
961985
identity_manager_clone,
962-
Arc::new(embedding_generator_clone),
986+
Arc::new(embedding_generator),
963987
bearer,
964988
filename,
965989
file,
@@ -980,12 +1004,16 @@ impl Node {
9801004
} => {
9811005
let db_clone = Arc::clone(&self.db);
9821006
let identity_manager_clone = self.identity_manager.clone();
983-
let embedding_generator_clone = self.embedding_generator.clone();
1007+
let embedding_generator_ref = Arc::clone(&self.embedding_generator);
9841008
tokio::spawn(async move {
1009+
let embedding_generator = {
1010+
let generator_guard = embedding_generator_ref.lock().await;
1011+
generator_guard.clone()
1012+
};
9851013
let _ = Node::v2_upload_file_to_job(
9861014
db_clone,
9871015
identity_manager_clone,
988-
Arc::new(embedding_generator_clone),
1016+
Arc::new(embedding_generator),
9891017
bearer,
9901018
job_id,
9911019
filename,
@@ -1169,8 +1197,25 @@ impl Node {
11691197
NodeCommand::V2ApiHealthCheck { res } => {
11701198
let db_clone = Arc::clone(&self.db);
11711199
let public_https_certificate_clone = self.public_https_certificate.clone();
1200+
let is_migration_in_progress_clone = Arc::clone(&self.is_migration_in_progress);
1201+
tokio::spawn(async move {
1202+
let _ = Node::v2_api_health_check(db_clone, public_https_certificate_clone, is_migration_in_progress_clone, res).await;
1203+
});
1204+
}
1205+
NodeCommand::V2ApiTriggerEmbeddingMigration { bearer, payload, res } => {
1206+
let db_clone = Arc::clone(&self.db);
1207+
let embedding_generator_clone = Arc::clone(&self.embedding_generator);
1208+
let default_embedding_model_clone = Arc::clone(&self.default_embedding_model);
1209+
let is_migration_in_progress_clone = Arc::clone(&self.is_migration_in_progress);
1210+
tokio::spawn(async move {
1211+
let _ = Node::v2_api_trigger_embedding_migration(db_clone, embedding_generator_clone, default_embedding_model_clone, is_migration_in_progress_clone, bearer, payload, res).await;
1212+
});
1213+
}
1214+
NodeCommand::V2ApiGetMigrationStatus { bearer, res } => {
1215+
let db_clone = Arc::clone(&self.db);
1216+
let is_migration_in_progress_clone = Arc::clone(&self.is_migration_in_progress);
11721217
tokio::spawn(async move {
1173-
let _ = Node::v2_api_health_check(db_clone, public_https_certificate_clone, res).await;
1218+
let _ = Node::v2_api_get_migration_status(db_clone, is_migration_in_progress_clone, bearer, res).await;
11741219
});
11751220
}
11761221
NodeCommand::V2ApiScanOllamaModels { bearer, res } => {
@@ -2040,9 +2085,13 @@ impl Node {
20402085
let db_clone = Arc::clone(&self.db);
20412086
let node_env = fetch_node_environment();
20422087
let signing_secret_key = self.identity_secret_key.clone();
2043-
let embedding_generator = self.embedding_generator.clone();
2088+
let embedding_generator_ref = Arc::clone(&self.embedding_generator);
20442089
let full_identity = self.node_name.clone();
20452090
tokio::spawn(async move {
2091+
let embedding_generator = {
2092+
let generator_guard = embedding_generator_ref.lock().await;
2093+
generator_guard.clone()
2094+
};
20462095
let _ = Node::v2_api_import_tool_url(
20472096
db_clone,
20482097
bearer,
@@ -2059,9 +2108,13 @@ impl Node {
20592108
NodeCommand::V2ApiImportToolZip { bearer, file_data, res } => {
20602109
let db_clone = Arc::clone(&self.db);
20612110
let node_env = fetch_node_environment();
2062-
let embedding_generator = self.embedding_generator.clone();
2111+
let embedding_generator_ref = Arc::clone(&self.embedding_generator);
20632112
let full_identity = self.node_name.clone();
20642113
tokio::spawn(async move {
2114+
let embedding_generator = {
2115+
let generator_guard = embedding_generator_ref.lock().await;
2116+
generator_guard.clone()
2117+
};
20652118
let _ = Node::v2_api_import_tool_zip(
20662119
db_clone,
20672120
bearer,

shinkai-bin/shinkai-node/src/network/node.rs

Lines changed: 39 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ use shinkai_sqlite::SqliteManager;
4646
use std::fs;
4747
use std::path::Path;
4848
use std::sync::Arc;
49-
use std::{io, net::SocketAddr, time::Duration};
49+
use std::{io, net::SocketAddr, time::Duration, sync::atomic::AtomicBool};
5050
use tokio::sync::Mutex;
5151
use tokio::time::Instant;
5252
use x25519_dalek::{PublicKey as EncryptionPublicKey, StaticSecret as EncryptionStaticKey};
@@ -97,7 +97,7 @@ pub struct Node {
9797
// Cron Manager
9898
pub cron_manager: Option<Arc<Mutex<CronManager>>>,
9999
// An EmbeddingGenerator initialized with the Node's default embedding model + server info
100-
pub embedding_generator: RemoteEmbeddingGenerator,
100+
pub embedding_generator: Arc<Mutex<RemoteEmbeddingGenerator>>,
101101
// Proxy Address
102102
pub proxy_connection_info: Arc<Mutex<Option<ProxyConnectionInfo>>>,
103103
// Websocket Manager
@@ -116,6 +116,8 @@ pub struct Node {
116116
pub default_embedding_model: Arc<Mutex<EmbeddingModelType>>,
117117
// Supported embedding models for profiles
118118
pub supported_embedding_models: Arc<Mutex<Vec<EmbeddingModelType>>>,
119+
// Migration status tracking
120+
pub is_migration_in_progress: Arc<AtomicBool>,
119121
// API V2 Key
120122
#[allow(dead_code)]
121123
pub api_v2_key: String,
@@ -357,7 +359,7 @@ impl Node {
357359
cron_manager: None,
358360
first_device_needs_registration_code,
359361
initial_llm_providers,
360-
embedding_generator,
362+
embedding_generator: Arc::new(Mutex::new(embedding_generator)),
361363
proxy_connection_info,
362364
ws_manager,
363365
ws_manager_trait,
@@ -367,6 +369,7 @@ impl Node {
367369
tool_router: Some(tool_router),
368370
default_embedding_model,
369371
supported_embedding_models,
372+
is_migration_in_progress: Arc::new(AtomicBool::new(false)),
370373
api_v2_key,
371374
wallet_manager,
372375
my_agent_payments_manager,
@@ -397,22 +400,28 @@ impl Node {
397400
}
398401
}
399402

400-
let job_manager = Arc::new(Mutex::new(
401-
JobManager::new(
402-
db_weak,
403-
Arc::clone(&self.identity_manager),
404-
clone_signature_secret_key(&self.identity_secret_key),
405-
self.node_name.clone(),
406-
self.embedding_generator.clone(),
407-
self.ws_manager_trait.clone(),
408-
self.tool_router.clone(),
409-
self.callback_manager.clone(),
410-
self.my_agent_payments_manager.clone(),
411-
self.ext_agent_payments_manager.clone(),
412-
self.llm_stopper.clone(),
413-
)
414-
.await,
415-
));
403+
let job_manager = {
404+
let embedding_generator = {
405+
let generator_guard = self.embedding_generator.lock().await;
406+
generator_guard.clone()
407+
};
408+
Arc::new(Mutex::new(
409+
JobManager::new(
410+
db_weak,
411+
Arc::clone(&self.identity_manager),
412+
clone_signature_secret_key(&self.identity_secret_key),
413+
self.node_name.clone(),
414+
embedding_generator,
415+
self.ws_manager_trait.clone(),
416+
self.tool_router.clone(),
417+
self.callback_manager.clone(),
418+
self.my_agent_payments_manager.clone(),
419+
self.ext_agent_payments_manager.clone(),
420+
self.llm_stopper.clone(),
421+
)
422+
.await,
423+
))
424+
};
416425
self.job_manager = Some(job_manager.clone());
417426

418427
shinkai_log(
@@ -443,29 +452,7 @@ impl Node {
443452
callback_manager.update_cron_manager(cron_manager.clone());
444453
}
445454

446-
// Perform embedding migration if needed
447-
{
448-
let current_model = {
449-
let default_model_guard = self.default_embedding_model.lock().await;
450-
default_model_guard.clone()
451-
};
452-
453-
let db_clone = Arc::clone(&self.db);
454-
let embedding_generator_clone = self.embedding_generator.clone();
455-
let current_model_clone = current_model.clone();
456-
457-
// Run migration in background without blocking node startup
458-
tokio::spawn(async move {
459-
if let Err(e) = db_clone.migrate_embeddings_to_new_model(&embedding_generator_clone, &current_model_clone).await {
460-
shinkai_log(
461-
ShinkaiLogOption::Node,
462-
ShinkaiLogLevel::Error,
463-
&format!("Embedding migration failed: {e:?}"),
464-
);
465-
}
466-
});
467-
}
468-
455+
// Initialize embedding models
469456
self.initialize_embedding_models().await?;
470457

471458
{
@@ -491,7 +478,10 @@ impl Node {
491478
// Call ToolRouter initialization in a new task
492479
if let Some(tool_router) = &self.tool_router {
493480
let tool_router = tool_router.clone();
494-
let generator = self.embedding_generator.clone();
481+
let generator = {
482+
let generator_guard = self.embedding_generator.lock().await;
483+
generator_guard.clone()
484+
};
495485
let reinstall_tools = std::env::var("REINSTALL_TOOLS").unwrap_or_else(|_| "false".to_string()) == "true";
496486

497487
tokio::spawn(async move {
@@ -699,15 +689,19 @@ impl Node {
699689
let node_name_clone = self.node_name.clone();
700690
let identity_manager_clone = self.identity_manager.clone();
701691
let tool_router_clone = self.tool_router.clone();
702-
let embedding_generator_clone = self.embedding_generator.clone();
692+
let embedding_generator_ref = Arc::clone(&self.embedding_generator);
703693
// Spawn a new task to handle periodic maintenance
704694
tokio::spawn(async move {
695+
let embedding_generator = {
696+
let generator_guard = embedding_generator_ref.lock().await;
697+
generator_guard.clone()
698+
};
705699
let _ = Self::handle_periodic_maintenance(
706700
db_clone,
707701
node_name_clone,
708702
identity_manager_clone,
709703
tool_router_clone,
710-
Arc::new(embedding_generator_clone),
704+
Arc::new(embedding_generator),
711705
).await;
712706
});
713707
},

0 commit comments

Comments
 (0)