@@ -15,16 +15,23 @@ impl SqliteManager {
15
15
pub async fn add_tool ( & self , tool : ShinkaiTool ) -> Result < ShinkaiTool , SqliteManagerError > {
16
16
let model_type = self . get_default_embedding_model ( )
17
17
. unwrap_or_else ( |_| EmbeddingModelType :: default ( ) ) ;
18
- let expected_dimensions = model_type. vector_dimensions ( ) . unwrap_or ( 768 ) ;
19
-
18
+
20
19
let embedding = match tool. get_embedding ( ) {
21
20
Some ( existing_embedding) => {
22
- // Check if existing embedding has correct dimensions
23
- if existing_embedding. len ( ) == expected_dimensions {
24
- existing_embedding
25
- } else {
26
- // Dimension mismatch - regenerate with current model
27
- self . generate_embeddings ( & tool. format_embedding_string ( ) ) . await ?
21
+ // Check if existing embedding has correct dimensions (skip check for custom models)
22
+ match model_type. vector_dimensions ( ) {
23
+ Ok ( expected_dimensions) => {
24
+ if existing_embedding. len ( ) == expected_dimensions {
25
+ existing_embedding
26
+ } else {
27
+ // Dimension mismatch - regenerate with current model
28
+ self . generate_embeddings ( & tool. format_embedding_string ( ) ) . await ?
29
+ }
30
+ }
31
+ Err ( _) => {
32
+ // Unknown dimensions for custom models - use existing embedding
33
+ existing_embedding
34
+ }
28
35
}
29
36
}
30
37
None => {
@@ -62,16 +69,18 @@ impl SqliteManager {
62
69
let tool_type = tool. tool_type ( ) . to_string ( ) ;
63
70
let tool_header = serde_json:: to_vec ( & tool. to_header ( ) ) . unwrap ( ) ;
64
71
65
- // Validate embedding dimensions before storing
72
+ // Validate embedding dimensions before storing (skip check for custom models with unknown dimensions)
66
73
let model_type = self . get_default_embedding_model ( )
67
74
. unwrap_or_else ( |_| EmbeddingModelType :: default ( ) ) ;
68
- let expected_dimensions = model_type. vector_dimensions ( ) . unwrap_or ( 768 ) ;
69
- if embedding. len ( ) != expected_dimensions {
70
- return Err ( SqliteManagerError :: SomeError ( format ! (
71
- "Embedding dimension mismatch: expected {} dimensions but received {}" ,
72
- expected_dimensions, embedding. len( )
73
- ) ) ) ;
75
+ if let Ok ( expected_dimensions) = model_type. vector_dimensions ( ) {
76
+ if embedding. len ( ) != expected_dimensions {
77
+ return Err ( SqliteManagerError :: SomeError ( format ! (
78
+ "Embedding dimension mismatch: expected {} dimensions but received {}" ,
79
+ expected_dimensions, embedding. len( )
80
+ ) ) ) ;
81
+ }
74
82
}
83
+ // Skip dimension validation for custom models where dimensions are unknown
75
84
76
85
// Clone the tool to make it mutable
77
86
let mut tool_clone = tool. clone ( ) ;
@@ -171,16 +180,18 @@ impl SqliteManager {
171
180
new_tool : ShinkaiTool ,
172
181
embedding : Vec < f32 > ,
173
182
) -> Result < ShinkaiTool , SqliteManagerError > {
174
- // Validate embedding dimensions before upgrading
183
+ // Validate embedding dimensions before upgrading (skip check for custom models with unknown dimensions)
175
184
let model_type = self . get_default_embedding_model ( )
176
185
. unwrap_or_else ( |_| EmbeddingModelType :: default ( ) ) ;
177
- let expected_dimensions = model_type. vector_dimensions ( ) . unwrap_or ( 768 ) ;
178
- if embedding. len ( ) != expected_dimensions {
179
- return Err ( SqliteManagerError :: SomeError ( format ! (
180
- "Embedding dimension mismatch: expected {} dimensions but received {}" ,
181
- expected_dimensions, embedding. len( )
182
- ) ) ) ;
186
+ if let Ok ( expected_dimensions) = model_type. vector_dimensions ( ) {
187
+ if embedding. len ( ) != expected_dimensions {
188
+ return Err ( SqliteManagerError :: SomeError ( format ! (
189
+ "Embedding dimension mismatch: expected {} dimensions but received {}" ,
190
+ expected_dimensions, embedding. len( )
191
+ ) ) ) ;
192
+ }
183
193
}
194
+ // Skip dimension validation for custom models where dimensions are unknown
184
195
185
196
// Use the tool_router_key (without version) to locate the old version
186
197
let tool_key = new_tool. tool_router_key ( ) . to_string_without_version ( ) ;
@@ -475,16 +486,18 @@ impl SqliteManager {
475
486
tool : ShinkaiTool ,
476
487
embedding : Vec < f32 > ,
477
488
) -> Result < ShinkaiTool , SqliteManagerError > {
478
- // Validate embedding dimensions before updating
489
+ // Validate embedding dimensions before updating (skip check for custom models with unknown dimensions)
479
490
let model_type = self . get_default_embedding_model ( )
480
491
. unwrap_or_else ( |_| EmbeddingModelType :: default ( ) ) ;
481
- let expected_dimensions = model_type. vector_dimensions ( ) . unwrap_or ( 768 ) ;
482
- if embedding. len ( ) != expected_dimensions {
483
- return Err ( SqliteManagerError :: SomeError ( format ! (
484
- "Embedding dimension mismatch: expected {} dimensions but received {}" ,
485
- expected_dimensions, embedding. len( )
486
- ) ) ) ;
492
+ if let Ok ( expected_dimensions) = model_type. vector_dimensions ( ) {
493
+ if embedding. len ( ) != expected_dimensions {
494
+ return Err ( SqliteManagerError :: SomeError ( format ! (
495
+ "Embedding dimension mismatch: expected {} dimensions but received {}" ,
496
+ expected_dimensions, embedding. len( )
497
+ ) ) ) ;
498
+ }
487
499
}
500
+ // Skip dimension validation for custom models where dimensions are unknown
488
501
489
502
let mut conn = self . get_connection ( ) ?;
490
503
let tx = conn. transaction ( ) ?;
@@ -582,16 +595,23 @@ impl SqliteManager {
582
595
pub async fn update_tool ( & self , tool : ShinkaiTool ) -> Result < ShinkaiTool , SqliteManagerError > {
583
596
let model_type = self . get_default_embedding_model ( )
584
597
. unwrap_or_else ( |_| EmbeddingModelType :: default ( ) ) ;
585
- let expected_dimensions = model_type. vector_dimensions ( ) . unwrap_or ( 768 ) ;
586
-
598
+
587
599
let embedding = match tool. get_embedding ( ) {
588
600
Some ( existing_embedding) => {
589
- // Check if existing embedding has correct dimensions
590
- if existing_embedding. len ( ) == expected_dimensions {
591
- existing_embedding
592
- } else {
593
- // Dimension mismatch - regenerate with current model
594
- self . generate_embeddings ( & tool. format_embedding_string ( ) ) . await ?
601
+ // Check if existing embedding has correct dimensions (skip check for custom models)
602
+ match model_type. vector_dimensions ( ) {
603
+ Ok ( expected_dimensions) => {
604
+ if existing_embedding. len ( ) == expected_dimensions {
605
+ existing_embedding
606
+ } else {
607
+ // Dimension mismatch - regenerate with current model
608
+ self . generate_embeddings ( & tool. format_embedding_string ( ) ) . await ?
609
+ }
610
+ }
611
+ Err ( _) => {
612
+ // Unknown dimensions for custom models - use existing embedding
613
+ existing_embedding
614
+ }
595
615
}
596
616
}
597
617
None => {
@@ -936,16 +956,18 @@ impl SqliteManager {
936
956
tool_key : & str ,
937
957
embedding : Vec < f32 > ,
938
958
) -> Result < ( ) , SqliteManagerError > {
939
- // Validate embedding dimensions before updating vector
959
+ // Validate embedding dimensions before updating vector (skip check for custom models with unknown dimensions)
940
960
let model_type = self . get_default_embedding_model ( )
941
961
. unwrap_or_else ( |_| EmbeddingModelType :: default ( ) ) ;
942
- let expected_dimensions = model_type. vector_dimensions ( ) . unwrap_or ( 768 ) ;
943
- if embedding. len ( ) != expected_dimensions {
944
- return Err ( SqliteManagerError :: SomeError ( format ! (
945
- "Embedding dimension mismatch: expected {} dimensions but received {}" ,
946
- expected_dimensions, embedding. len( )
947
- ) ) ) ;
962
+ if let Ok ( expected_dimensions) = model_type. vector_dimensions ( ) {
963
+ if embedding. len ( ) != expected_dimensions {
964
+ return Err ( SqliteManagerError :: SomeError ( format ! (
965
+ "Embedding dimension mismatch: expected {} dimensions but received {}" ,
966
+ expected_dimensions, embedding. len( )
967
+ ) ) ) ;
968
+ }
948
969
}
970
+ // Skip dimension validation for custom models where dimensions are unknown
949
971
950
972
// Get is_enabled and is_network from the main database
951
973
let ( is_enabled, is_network) : ( i32 , i32 ) = tx. query_row (
0 commit comments