Skip to content

Commit 512952b

Browse files
committed
Embeddings migration endpoint
1 parent 9167095 commit 512952b

File tree

6 files changed

+323
-29
lines changed

6 files changed

+323
-29
lines changed

shinkai-bin/shinkai-node/src/network/handle_commands_list.rs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1169,8 +1169,24 @@ impl Node {
11691169
NodeCommand::V2ApiHealthCheck { res } => {
11701170
let db_clone = Arc::clone(&self.db);
11711171
let public_https_certificate_clone = self.public_https_certificate.clone();
1172+
let is_migration_in_progress_clone = Arc::clone(&self.is_migration_in_progress);
11721173
tokio::spawn(async move {
1173-
let _ = Node::v2_api_health_check(db_clone, public_https_certificate_clone, res).await;
1174+
let _ = Node::v2_api_health_check(db_clone, public_https_certificate_clone, is_migration_in_progress_clone, res).await;
1175+
});
1176+
}
1177+
NodeCommand::V2ApiTriggerEmbeddingMigration { bearer, payload, res } => {
1178+
let db_clone = Arc::clone(&self.db);
1179+
let embedding_generator_clone = self.embedding_generator.clone();
1180+
let is_migration_in_progress_clone = Arc::clone(&self.is_migration_in_progress);
1181+
tokio::spawn(async move {
1182+
let _ = Node::v2_api_trigger_embedding_migration(db_clone, embedding_generator_clone, is_migration_in_progress_clone, bearer, payload, res).await;
1183+
});
1184+
}
1185+
NodeCommand::V2ApiGetMigrationStatus { bearer, res } => {
1186+
let db_clone = Arc::clone(&self.db);
1187+
let is_migration_in_progress_clone = Arc::clone(&self.is_migration_in_progress);
1188+
tokio::spawn(async move {
1189+
let _ = Node::v2_api_get_migration_status(db_clone, is_migration_in_progress_clone, bearer, res).await;
11741190
});
11751191
}
11761192
NodeCommand::V2ApiScanOllamaModels { bearer, res } => {

shinkai-bin/shinkai-node/src/network/node.rs

Lines changed: 5 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ use shinkai_sqlite::SqliteManager;
4646
use std::fs;
4747
use std::path::Path;
4848
use std::sync::Arc;
49-
use std::{io, net::SocketAddr, time::Duration};
49+
use std::{io, net::SocketAddr, time::Duration, sync::atomic::AtomicBool};
5050
use tokio::sync::Mutex;
5151
use tokio::time::Instant;
5252
use x25519_dalek::{PublicKey as EncryptionPublicKey, StaticSecret as EncryptionStaticKey};
@@ -116,6 +116,8 @@ pub struct Node {
116116
pub default_embedding_model: Arc<Mutex<EmbeddingModelType>>,
117117
// Supported embedding models for profiles
118118
pub supported_embedding_models: Arc<Mutex<Vec<EmbeddingModelType>>>,
119+
// Migration status tracking
120+
pub is_migration_in_progress: Arc<AtomicBool>,
119121
// API V2 Key
120122
#[allow(dead_code)]
121123
pub api_v2_key: String,
@@ -367,6 +369,7 @@ impl Node {
367369
tool_router: Some(tool_router),
368370
default_embedding_model,
369371
supported_embedding_models,
372+
is_migration_in_progress: Arc::new(AtomicBool::new(false)),
370373
api_v2_key,
371374
wallet_manager,
372375
my_agent_payments_manager,
@@ -443,29 +446,7 @@ impl Node {
443446
callback_manager.update_cron_manager(cron_manager.clone());
444447
}
445448

446-
// Perform embedding migration if needed
447-
{
448-
let current_model = {
449-
let default_model_guard = self.default_embedding_model.lock().await;
450-
default_model_guard.clone()
451-
};
452-
453-
let db_clone = Arc::clone(&self.db);
454-
let embedding_generator_clone = self.embedding_generator.clone();
455-
let current_model_clone = current_model.clone();
456-
457-
// Run migration in background without blocking node startup
458-
tokio::spawn(async move {
459-
if let Err(e) = db_clone.migrate_embeddings_to_new_model(&embedding_generator_clone, &current_model_clone).await {
460-
shinkai_log(
461-
ShinkaiLogOption::Node,
462-
ShinkaiLogLevel::Error,
463-
&format!("Embedding migration failed: {e:?}"),
464-
);
465-
}
466-
});
467-
}
468-
449+
// Initialize embedding models
469450
self.initialize_embedding_models().await?;
470451

471452
{

shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands.rs

Lines changed: 190 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ use shinkai_http_api::node_api_router::APIUseRegistrationCodeSuccessResponse;
2828
use shinkai_http_api::{
2929
api_v2::api_v2_handlers_general::InitialRegistrationRequest,
3030
node_api_router::{APIError, GetPublicKeysResponse},
31+
node_commands::EmbeddingMigrationRequest,
3132
};
3233
use shinkai_mcp::mcp_methods::{list_tools_via_command, list_tools_via_http, list_tools_via_sse};
3334
use shinkai_message_primitives::schemas::llm_providers::shinkai_backend::QuotaResponse;
@@ -51,6 +52,7 @@ use shinkai_message_primitives::{
5152
shinkai_utils::{
5253
encryption::{encryption_public_key_to_string, EncryptionMethod},
5354
shinkai_message_builder::ShinkaiMessageBuilder,
55+
shinkai_logging::{shinkai_log, ShinkaiLogLevel, ShinkaiLogOption},
5456
signatures::signature_public_key_to_string,
5557
},
5658
shinkai_utils::{job_scope::MinimalJobScope, shinkai_time::ShinkaiStringTime},
@@ -69,7 +71,7 @@ use shinkai_tools_primitives::tools::{
6971
use std::collections::HashMap;
7072
use std::process::Command;
7173
use std::time::Instant;
72-
use std::{env, sync::Arc};
74+
use std::{env, sync::{Arc, atomic::{AtomicBool, Ordering}}};
7375
use tokio::sync::Mutex;
7476
use tokio::time::Duration;
7577
use x25519_dalek::PublicKey as EncryptionPublicKey;
@@ -783,6 +785,7 @@ impl Node {
783785
pub async fn v2_api_health_check(
784786
db: Arc<SqliteManager>,
785787
public_https_certificate: Option<String>,
788+
is_migration_in_progress: Arc<AtomicBool>,
786789
res: Sender<Result<serde_json::Value, APIError>>,
787790
) -> Result<(), NodeError> {
788791
let public_https_certificate = match public_https_certificate {
@@ -806,18 +809,204 @@ impl Node {
806809
}
807810
};
808811

812+
let is_updating = is_migration_in_progress.load(Ordering::Relaxed);
809813
let _ = res
810814
.send(Ok(serde_json::json!({
811815
"is_pristine": !db.has_any_profile().unwrap_or(false),
812816
"public_https_certificate": public_https_certificate,
813817
"version": version,
814818
"update_requires_reset": needs_global_reset,
815819
"docker_status": "not-installed",
820+
"updating": is_updating,
821+
"ready": !is_updating,
816822
})))
817823
.await;
818824
Ok(())
819825
}
820826

827+
pub async fn v2_api_trigger_embedding_migration(
828+
db: Arc<SqliteManager>,
829+
embedding_generator: RemoteEmbeddingGenerator,
830+
is_migration_in_progress: Arc<AtomicBool>,
831+
bearer: String,
832+
payload: EmbeddingMigrationRequest,
833+
res: Sender<Result<serde_json::Value, APIError>>,
834+
) -> Result<(), NodeError> {
835+
// Validate the bearer token
836+
if Self::validate_bearer_token(&bearer, db.clone(), &res).await.is_err() {
837+
return Ok(());
838+
}
839+
840+
// Check if migration is already in progress (unless forced)
841+
if !payload.force && is_migration_in_progress.load(Ordering::Relaxed) {
842+
let _ = res
843+
.send(Err(APIError {
844+
code: StatusCode::CONFLICT.as_u16(),
845+
error: "Migration In Progress".to_string(),
846+
message: "Embedding migration is already in progress".to_string(),
847+
}))
848+
.await;
849+
return Ok(());
850+
}
851+
852+
// Parse the requested embedding model
853+
let requested_model = match EmbeddingModelType::from_string(&payload.embedding_model) {
854+
Ok(model) => model,
855+
Err(_) => {
856+
let _ = res
857+
.send(Err(APIError {
858+
code: StatusCode::BAD_REQUEST.as_u16(),
859+
error: "Invalid Embedding Model".to_string(),
860+
message: format!("Invalid embedding model: {}", payload.embedding_model),
861+
}))
862+
.await;
863+
return Ok(());
864+
}
865+
};
866+
867+
// Trigger the migration using the internal helper
868+
match Self::internal_trigger_embedding_migration(
869+
Arc::clone(&db),
870+
embedding_generator,
871+
requested_model,
872+
payload.force,
873+
Arc::clone(&is_migration_in_progress),
874+
true, // Check Ollama availability for API calls
875+
).await {
876+
Ok(_) => {
877+
// Migration started successfully
878+
}
879+
Err(err_msg) => {
880+
let (status_code, error_type) = if err_msg.contains("not available in Ollama") {
881+
(StatusCode::BAD_REQUEST, "Model Not Available")
882+
} else if err_msg.contains("Cannot connect to Ollama") {
883+
(StatusCode::SERVICE_UNAVAILABLE, "Ollama Unavailable")
884+
} else {
885+
(StatusCode::INTERNAL_SERVER_ERROR, "Migration Error")
886+
};
887+
888+
let _ = res
889+
.send(Err(APIError {
890+
code: status_code.as_u16(),
891+
error: error_type.to_string(),
892+
message: err_msg,
893+
}))
894+
.await;
895+
return Ok(());
896+
}
897+
}
898+
899+
// Send success response immediately (migration runs in background)
900+
let _ = res
901+
.send(Ok(serde_json::json!({
902+
"status": "success",
903+
"message": "Embedding migration has been triggered and is running in the background",
904+
"migration_in_progress": true,
905+
"target_model": payload.embedding_model,
906+
"force": payload.force
907+
})))
908+
.await;
909+
910+
Ok(())
911+
}
912+
913+
/// Internal helper function for triggering embedding migrations
914+
/// Used by both startup migration and API endpoint
915+
pub async fn internal_trigger_embedding_migration(
916+
db: Arc<SqliteManager>,
917+
embedding_generator: RemoteEmbeddingGenerator,
918+
target_model: EmbeddingModelType,
919+
force: bool,
920+
is_migration_in_progress: Arc<AtomicBool>,
921+
check_ollama_availability: bool,
922+
) -> Result<(), String> {
923+
// Check if model is available in Ollama (if requested)
924+
if check_ollama_availability {
925+
match Self::internal_scan_ollama_models().await {
926+
Ok(available_models) => {
927+
let model_name = target_model.to_string();
928+
let model_available = available_models.iter().any(|model| {
929+
model["name"].as_str()
930+
.map(|name| name == model_name)
931+
.unwrap_or(false)
932+
});
933+
934+
if !model_available {
935+
return Err(format!("Embedding model '{}' is not available in Ollama", model_name));
936+
}
937+
},
938+
Err(_) => {
939+
return Err("Cannot connect to Ollama to verify model availability".to_string());
940+
}
941+
}
942+
}
943+
944+
// Set migration status to in progress
945+
is_migration_in_progress.store(true, Ordering::Relaxed);
946+
947+
// Create a new embedding generator configured for the target model
948+
let mut target_embedding_generator = embedding_generator.clone();
949+
target_embedding_generator.set_model_type(target_model.clone());
950+
951+
// Clone necessary data for the migration task
952+
let db_clone = Arc::clone(&db);
953+
let target_model_clone = target_model.clone();
954+
let migration_status_clone = Arc::clone(&is_migration_in_progress);
955+
956+
// Spawn migration task
957+
tokio::spawn(async move {
958+
match db_clone.migrate_embeddings_to_new_model(&target_embedding_generator, &target_model_clone, force).await {
959+
Ok(_) => {
960+
shinkai_log(
961+
ShinkaiLogOption::Node,
962+
ShinkaiLogLevel::Info,
963+
&format!("Embedding migration to {} completed successfully", target_model_clone),
964+
);
965+
}
966+
Err(e) => {
967+
shinkai_log(
968+
ShinkaiLogOption::Node,
969+
ShinkaiLogLevel::Error,
970+
&format!("Embedding migration to {} failed: {e:?}", target_model_clone),
971+
);
972+
}
973+
}
974+
// Set migration status back to false when done
975+
migration_status_clone.store(false, Ordering::Relaxed);
976+
});
977+
978+
Ok(())
979+
}
980+
981+
pub async fn v2_api_get_migration_status(
982+
db: Arc<SqliteManager>,
983+
is_migration_in_progress: Arc<AtomicBool>,
984+
bearer: String,
985+
res: Sender<Result<serde_json::Value, APIError>>,
986+
) -> Result<(), NodeError> {
987+
// Validate the bearer token
988+
if Self::validate_bearer_token(&bearer, db.clone(), &res).await.is_err() {
989+
return Ok(());
990+
}
991+
992+
let is_migrating = is_migration_in_progress.load(Ordering::Relaxed);
993+
994+
// Get current embedding model for context
995+
let current_model = db.get_default_embedding_model()
996+
.unwrap_or_else(|_| EmbeddingModelType::default());
997+
998+
let _ = res
999+
.send(Ok(serde_json::json!({
1000+
"migration_in_progress": is_migrating,
1001+
"ready": !is_migrating,
1002+
"current_embedding_model": current_model.to_string(),
1003+
"status": if is_migrating { "migrating" } else { "ready" }
1004+
})))
1005+
.await;
1006+
1007+
Ok(())
1008+
}
1009+
8211010
pub async fn v2_api_scan_ollama_models(
8221011
db: Arc<SqliteManager>,
8231012
bearer: String,

0 commit comments

Comments
 (0)