Skip to content

Commit 660604d

Browse files
committed
Fixes
1 parent 858782c commit 660604d

File tree

2 files changed

+69
-48
lines changed

2 files changed

+69
-48
lines changed

shinkai-libs/shinkai-sqlite/src/inbox_manager.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ impl SqliteManager {
2525
let smart_inbox_name = format!("New Inbox: {}", inbox_name);
2626
let conn = self.get_connection()?;
2727
conn.execute(
28-
"INSERT INTO inboxes (inbox_name, smart_inbox_name, last_modified, is_hidden) VALUES (?1, ?2, ?3, ?4)",
28+
"INSERT OR IGNORE INTO inboxes (inbox_name, smart_inbox_name, last_modified, is_hidden) VALUES (?1, ?2, ?3, ?4)",
2929
params![
3030
inbox_name,
3131
smart_inbox_name,
@@ -53,9 +53,8 @@ impl SqliteManager {
5353
return Err(SqliteManagerError::SomeError("Inbox name is empty".to_string()));
5454
}
5555

56-
if !self.does_inbox_exist(&inbox_name)? {
57-
self.create_empty_inbox(inbox_name.clone(), None)?;
58-
}
56+
// Create inbox if it doesn't exist (idempotent operation)
57+
self.create_empty_inbox(inbox_name.clone(), None)?;
5958

6059
// If this message has a parent, add this message as a child of the parent
6160
let parent_key = match maybe_parent_message_key {

shinkai-libs/shinkai-sqlite/src/shinkai_tool_manager.rs

Lines changed: 66 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,23 @@ impl SqliteManager {
1515
pub async fn add_tool(&self, tool: ShinkaiTool) -> Result<ShinkaiTool, SqliteManagerError> {
1616
let model_type = self.get_default_embedding_model()
1717
.unwrap_or_else(|_| EmbeddingModelType::default());
18-
let expected_dimensions = model_type.vector_dimensions().unwrap_or(768);
19-
18+
2019
let embedding = match tool.get_embedding() {
2120
Some(existing_embedding) => {
22-
// Check if existing embedding has correct dimensions
23-
if existing_embedding.len() == expected_dimensions {
24-
existing_embedding
25-
} else {
26-
// Dimension mismatch - regenerate with current model
27-
self.generate_embeddings(&tool.format_embedding_string()).await?
21+
// Check if existing embedding has correct dimensions (skip check for custom models)
22+
match model_type.vector_dimensions() {
23+
Ok(expected_dimensions) => {
24+
if existing_embedding.len() == expected_dimensions {
25+
existing_embedding
26+
} else {
27+
// Dimension mismatch - regenerate with current model
28+
self.generate_embeddings(&tool.format_embedding_string()).await?
29+
}
30+
}
31+
Err(_) => {
32+
// Unknown dimensions for custom models - use existing embedding
33+
existing_embedding
34+
}
2835
}
2936
}
3037
None => {
@@ -62,16 +69,18 @@ impl SqliteManager {
6269
let tool_type = tool.tool_type().to_string();
6370
let tool_header = serde_json::to_vec(&tool.to_header()).unwrap();
6471

65-
// Validate embedding dimensions before storing
72+
// Validate embedding dimensions before storing (skip check for custom models with unknown dimensions)
6673
let model_type = self.get_default_embedding_model()
6774
.unwrap_or_else(|_| EmbeddingModelType::default());
68-
let expected_dimensions = model_type.vector_dimensions().unwrap_or(768);
69-
if embedding.len() != expected_dimensions {
70-
return Err(SqliteManagerError::SomeError(format!(
71-
"Embedding dimension mismatch: expected {} dimensions but received {}",
72-
expected_dimensions, embedding.len()
73-
)));
75+
if let Ok(expected_dimensions) = model_type.vector_dimensions() {
76+
if embedding.len() != expected_dimensions {
77+
return Err(SqliteManagerError::SomeError(format!(
78+
"Embedding dimension mismatch: expected {} dimensions but received {}",
79+
expected_dimensions, embedding.len()
80+
)));
81+
}
7482
}
83+
// Skip dimension validation for custom models where dimensions are unknown
7584

7685
// Clone the tool to make it mutable
7786
let mut tool_clone = tool.clone();
@@ -171,16 +180,18 @@ impl SqliteManager {
171180
new_tool: ShinkaiTool,
172181
embedding: Vec<f32>,
173182
) -> Result<ShinkaiTool, SqliteManagerError> {
174-
// Validate embedding dimensions before upgrading
183+
// Validate embedding dimensions before upgrading (skip check for custom models with unknown dimensions)
175184
let model_type = self.get_default_embedding_model()
176185
.unwrap_or_else(|_| EmbeddingModelType::default());
177-
let expected_dimensions = model_type.vector_dimensions().unwrap_or(768);
178-
if embedding.len() != expected_dimensions {
179-
return Err(SqliteManagerError::SomeError(format!(
180-
"Embedding dimension mismatch: expected {} dimensions but received {}",
181-
expected_dimensions, embedding.len()
182-
)));
186+
if let Ok(expected_dimensions) = model_type.vector_dimensions() {
187+
if embedding.len() != expected_dimensions {
188+
return Err(SqliteManagerError::SomeError(format!(
189+
"Embedding dimension mismatch: expected {} dimensions but received {}",
190+
expected_dimensions, embedding.len()
191+
)));
192+
}
183193
}
194+
// Skip dimension validation for custom models where dimensions are unknown
184195

185196
// Use the tool_router_key (without version) to locate the old version
186197
let tool_key = new_tool.tool_router_key().to_string_without_version();
@@ -475,16 +486,18 @@ impl SqliteManager {
475486
tool: ShinkaiTool,
476487
embedding: Vec<f32>,
477488
) -> Result<ShinkaiTool, SqliteManagerError> {
478-
// Validate embedding dimensions before updating
489+
// Validate embedding dimensions before updating (skip check for custom models with unknown dimensions)
479490
let model_type = self.get_default_embedding_model()
480491
.unwrap_or_else(|_| EmbeddingModelType::default());
481-
let expected_dimensions = model_type.vector_dimensions().unwrap_or(768);
482-
if embedding.len() != expected_dimensions {
483-
return Err(SqliteManagerError::SomeError(format!(
484-
"Embedding dimension mismatch: expected {} dimensions but received {}",
485-
expected_dimensions, embedding.len()
486-
)));
492+
if let Ok(expected_dimensions) = model_type.vector_dimensions() {
493+
if embedding.len() != expected_dimensions {
494+
return Err(SqliteManagerError::SomeError(format!(
495+
"Embedding dimension mismatch: expected {} dimensions but received {}",
496+
expected_dimensions, embedding.len()
497+
)));
498+
}
487499
}
500+
// Skip dimension validation for custom models where dimensions are unknown
488501

489502
let mut conn = self.get_connection()?;
490503
let tx = conn.transaction()?;
@@ -582,16 +595,23 @@ impl SqliteManager {
582595
pub async fn update_tool(&self, tool: ShinkaiTool) -> Result<ShinkaiTool, SqliteManagerError> {
583596
let model_type = self.get_default_embedding_model()
584597
.unwrap_or_else(|_| EmbeddingModelType::default());
585-
let expected_dimensions = model_type.vector_dimensions().unwrap_or(768);
586-
598+
587599
let embedding = match tool.get_embedding() {
588600
Some(existing_embedding) => {
589-
// Check if existing embedding has correct dimensions
590-
if existing_embedding.len() == expected_dimensions {
591-
existing_embedding
592-
} else {
593-
// Dimension mismatch - regenerate with current model
594-
self.generate_embeddings(&tool.format_embedding_string()).await?
601+
// Check if existing embedding has correct dimensions (skip check for custom models)
602+
match model_type.vector_dimensions() {
603+
Ok(expected_dimensions) => {
604+
if existing_embedding.len() == expected_dimensions {
605+
existing_embedding
606+
} else {
607+
// Dimension mismatch - regenerate with current model
608+
self.generate_embeddings(&tool.format_embedding_string()).await?
609+
}
610+
}
611+
Err(_) => {
612+
// Unknown dimensions for custom models - use existing embedding
613+
existing_embedding
614+
}
595615
}
596616
}
597617
None => {
@@ -936,16 +956,18 @@ impl SqliteManager {
936956
tool_key: &str,
937957
embedding: Vec<f32>,
938958
) -> Result<(), SqliteManagerError> {
939-
// Validate embedding dimensions before updating vector
959+
// Validate embedding dimensions before updating vector (skip check for custom models with unknown dimensions)
940960
let model_type = self.get_default_embedding_model()
941961
.unwrap_or_else(|_| EmbeddingModelType::default());
942-
let expected_dimensions = model_type.vector_dimensions().unwrap_or(768);
943-
if embedding.len() != expected_dimensions {
944-
return Err(SqliteManagerError::SomeError(format!(
945-
"Embedding dimension mismatch: expected {} dimensions but received {}",
946-
expected_dimensions, embedding.len()
947-
)));
962+
if let Ok(expected_dimensions) = model_type.vector_dimensions() {
963+
if embedding.len() != expected_dimensions {
964+
return Err(SqliteManagerError::SomeError(format!(
965+
"Embedding dimension mismatch: expected {} dimensions but received {}",
966+
expected_dimensions, embedding.len()
967+
)));
968+
}
948969
}
970+
// Skip dimension validation for custom models where dimensions are unknown
949971

950972
// Get is_enabled and is_network from the main database
951973
let (is_enabled, is_network): (i32, i32) = tx.query_row(

0 commit comments

Comments
 (0)