20
20
import io .github .mem4j .memory .MemoryItem ;
21
21
import io .milvus .client .MilvusClient ;
22
22
import io .milvus .client .MilvusServiceClient ;
23
- import io .milvus .grpc .DataType ;
24
- import io .milvus .grpc .MutationResult ;
25
- import io .milvus .grpc .QueryResults ;
26
- import io .milvus .grpc .SearchResults ;
23
+ import io .milvus .grpc .*;
27
24
import io .milvus .param .ConnectParam ;
28
25
import io .milvus .param .MetricType ;
29
26
import io .milvus .param .R ;
35
32
import io .milvus .param .dml .InsertParam ;
36
33
import io .milvus .param .dml .QueryParam ;
37
34
import io .milvus .param .dml .SearchParam ;
38
- import io .milvus .response .QueryResultsWrapper ;
39
- import io .milvus .response .SearchResultsWrapper ;
40
35
import lombok .AllArgsConstructor ;
41
36
import org .slf4j .Logger ;
42
37
import org .slf4j .LoggerFactory ;
@@ -289,7 +284,7 @@ public List<MemoryItem> search(Double[] queryEmbedding, Map<String, Object> filt
289
284
.withCollectionName (collectionName )
290
285
.withMetricType (MetricType .COSINE )
291
286
.withOutFields (Arrays .asList ("id" , "content" , "memory_type" , "user_id" , "agent_id" , "run_id" ,
292
- "actor_id" , "created_at" , "updated_at" ))
287
+ "actor_id" , "created_at" , "updated_at" , "vector" ))
293
288
.withTopK (limit != null ? limit : 10 )
294
289
.withVectors (Collections .singletonList (queryVector ))
295
290
.withVectorFieldName ("vector" )
@@ -303,27 +298,64 @@ public List<MemoryItem> search(Double[] queryEmbedding, Map<String, Object> filt
303
298
throw new RuntimeException ("Search failed: " + response .getMessage ());
304
299
}
305
300
306
- // 解析搜索结果 - 使用简化的方式
301
+ // 解析搜索结果 - 正确实现
307
302
List <MemoryItem > results = new ArrayList <>();
308
303
SearchResults searchResults = response .getData ();
309
304
310
- // 获取搜索结果的数量
311
- int resultCount = Math .min ((int ) searchResults .getResults ().getNumQueries (), limit != null ? limit : 10 );
312
-
313
- for (int i = 0 ; i < resultCount ; i ++) {
314
- // 构建MemoryItem - 简化实现
315
- MemoryItem item = new MemoryItem ();
316
- item .setId ("memory_" + i ); // 简化ID生成
317
- item .setContent ("Search result " + i );
318
- item .setMemoryType ("factual" );
319
- item .setUserId ("default_user" );
320
- item .setAgentId ("default_agent" );
321
- item .setRunId ("default_run" );
322
- item .setActorId ("default_actor" );
323
- item .setCreatedAt (Instant .now ());
324
- item .setUpdatedAt (Instant .now ());
325
-
326
- results .add (item );
305
+ if (searchResults != null && searchResults .getResults () != null ) {
306
+ // 获取搜索结果的字段数据
307
+ List <FieldData > fieldsData = searchResults .getResults ().getFieldsDataList ();
308
+
309
+ if (fieldsData != null && !fieldsData .isEmpty ()) {
310
+ // 创建字段映射
311
+ Map <String , List <Object >> fieldMap = new HashMap <>();
312
+
313
+ for (FieldData fieldData : fieldsData ) {
314
+ String fieldName = fieldData .getFieldName ();
315
+ List <Object > values = new ArrayList <>();
316
+
317
+ if (fieldData .getType () == DataType .VarChar ) {
318
+ // 字符串字段
319
+ List <String > stringData = fieldData .getScalars ().getStringData ().getDataList ();
320
+ values .addAll (stringData );
321
+ }
322
+ else if (fieldData .getType () == DataType .Int64 ) {
323
+ // 长整型字段(时间戳)
324
+ List <Long > longData = fieldData .getScalars ().getLongData ().getDataList ();
325
+ values .addAll (longData );
326
+ }
327
+ else if (fieldData .getType () == DataType .FloatVector ) {
328
+ // 向量字段
329
+ List <Float > vectorData = fieldData .getVectors ().getFloatVector ().getDataList ();
330
+ values .addAll (vectorData );
331
+ }
332
+
333
+ fieldMap .put (fieldName , values );
334
+ }
335
+
336
+ // 获取结果数量
337
+ int resultCount = fieldMap .get ("id" ).size ();
338
+
339
+ // 构建MemoryItem对象并计算相似度
340
+ for (int i = 0 ; i < resultCount ; i ++) {
341
+ MemoryItem item = buildMemoryItemFromFieldMap (fieldMap , i , resultCount );
342
+
343
+ // 计算相似度分数
344
+ Double [] itemEmbedding = item .getEmbedding ();
345
+ if (itemEmbedding != null ) {
346
+ double similarity = cosineSimilarity (queryEmbedding , itemEmbedding );
347
+ item .setScore (similarity );
348
+
349
+ // 根据阈值过滤结果
350
+ if (threshold == null || similarity >= threshold ) {
351
+ results .add (item );
352
+ }
353
+ }
354
+ }
355
+
356
+ // 按相似度分数降序排序
357
+ results .sort ((a , b ) -> Double .compare (b .getScore (), a .getScore ()));
358
+ }
327
359
}
328
360
329
361
logger .debug ("Found {} similar memories" , results .size ());
@@ -356,27 +388,43 @@ public List<MemoryItem> getAll(Map<String, Object> filters, Integer limit) {
356
388
throw new RuntimeException ("Query failed: " + response .getMessage ());
357
389
}
358
390
359
- // 解析查询结果 - 使用简化的方式
391
+ // 解析查询结果 - 正确实现
360
392
List <MemoryItem > results = new ArrayList <>();
361
393
QueryResults queryResults = response .getData ();
362
394
363
- // 获取查询结果的数量 - 简化实现
364
- int resultCount = limit != null ? limit : 100 ;
365
-
366
- for (int i = 0 ; i < resultCount ; i ++) {
367
- // 构建MemoryItem - 简化实现
368
- MemoryItem item = new MemoryItem ();
369
- item .setId ("memory_" + i ); // 简化ID生成
370
- item .setContent ("Query result " + i );
371
- item .setMemoryType ("factual" );
372
- item .setUserId ("default_user" );
373
- item .setAgentId ("default_agent" );
374
- item .setRunId ("default_run" );
375
- item .setActorId ("default_actor" );
376
- item .setCreatedAt (Instant .now ());
377
- item .setUpdatedAt (Instant .now ());
378
-
379
- results .add (item );
395
+ if (queryResults != null && queryResults .getFieldsDataCount () > 0 ) {
396
+ // 直接获取字段数据列表
397
+ List <FieldData > fieldsData = queryResults .getFieldsDataList ();
398
+
399
+ // 创建字段映射
400
+ Map <String , List <Object >> fieldMap = new HashMap <>();
401
+
402
+ for (FieldData fieldData : fieldsData ) {
403
+ String fieldName = fieldData .getFieldName ();
404
+ List <Object > values = new ArrayList <>();
405
+
406
+ if (fieldData .getType () == DataType .VarChar ) {
407
+ // 字符串字段
408
+ List <String > stringData = fieldData .getScalars ().getStringData ().getDataList ();
409
+ values .addAll (stringData );
410
+ }
411
+ else if (fieldData .getType () == DataType .Int64 ) {
412
+ // 长整型字段(时间戳)
413
+ List <Long > longData = fieldData .getScalars ().getLongData ().getDataList ();
414
+ values .addAll (longData );
415
+ }
416
+
417
+ fieldMap .put (fieldName , values );
418
+ }
419
+
420
+ // 获取结果数量
421
+ int resultCount = fieldMap .get ("id" ).size ();
422
+
423
+ // 构建MemoryItem对象
424
+ for (int i = 0 ; i < resultCount ; i ++) {
425
+ MemoryItem item = buildMemoryItemFromFieldMap (fieldMap , i , resultCount );
426
+ results .add (item );
427
+ }
380
428
}
381
429
382
430
logger .debug ("Retrieved {} memories" , results .size ());
@@ -409,25 +457,45 @@ public MemoryItem get(String memoryId) {
409
457
throw new RuntimeException ("Query failed: " + response .getMessage ());
410
458
}
411
459
412
- // 解析查询结果 - 使用简化的方式
460
+ // 解析查询结果 - 正确实现
413
461
QueryResults queryResults = response .getData ();
414
- // 简化实现,假设查询成功就返回结果
415
- if (queryResults == null ) {
462
+ if (queryResults == null || queryResults .getFieldsDataCount () == 0 ) {
463
+ logger .debug ("Memory not found: {}" , memoryId );
464
+ return null ;
465
+ }
466
+
467
+ // 获取字段数据列表
468
+ List <FieldData > fieldsData = queryResults .getFieldsDataList ();
469
+
470
+ // 创建字段映射
471
+ Map <String , List <Object >> fieldMap = new HashMap <>();
472
+
473
+ for (FieldData fieldData : fieldsData ) {
474
+ String fieldName = fieldData .getFieldName ();
475
+ List <Object > values = new ArrayList <>();
476
+
477
+ if (fieldData .getType () == DataType .VarChar ) {
478
+ // 字符串字段
479
+ List <String > stringData = fieldData .getScalars ().getStringData ().getDataList ();
480
+ values .addAll (stringData );
481
+ }
482
+ else if (fieldData .getType () == DataType .Int64 ) {
483
+ // 长整型字段(时间戳)
484
+ List <Long > longData = fieldData .getScalars ().getLongData ().getDataList ();
485
+ values .addAll (longData );
486
+ }
487
+
488
+ fieldMap .put (fieldName , values );
489
+ }
490
+
491
+ // 检查是否有结果
492
+ if (fieldMap .get ("id" ) == null || fieldMap .get ("id" ).isEmpty ()) {
416
493
logger .debug ("Memory not found: {}" , memoryId );
417
494
return null ;
418
495
}
419
496
420
- // 构建MemoryItem - 简化实现
421
- MemoryItem item = new MemoryItem ();
422
- item .setId (memoryId );
423
- item .setContent ("Memory content for " + memoryId );
424
- item .setMemoryType ("factual" );
425
- item .setUserId ("default_user" );
426
- item .setAgentId ("default_agent" );
427
- item .setRunId ("default_run" );
428
- item .setActorId ("default_actor" );
429
- item .setCreatedAt (Instant .now ());
430
- item .setUpdatedAt (Instant .now ());
497
+ // 构建MemoryItem对象
498
+ MemoryItem item = buildMemoryItemFromFieldMap (fieldMap , 0 , 1 );
431
499
432
500
logger .debug ("Retrieved memory: {}" , memoryId );
433
501
return item ;
@@ -531,11 +599,23 @@ private String buildSearchExpression(Map<String, Object> filters) {
531
599
return "" ;
532
600
}
533
601
602
+ // 定义有效的集合字段
603
+ Set <String > validFields = new HashSet <>(Arrays .asList (
604
+ "user_id" , "agent_id" , "run_id" , "actor_id" , "memory_type" ,
605
+ "created_at" , "updated_at" , "content"
606
+ ));
607
+
534
608
List <String > conditions = new ArrayList <>();
535
609
for (Map .Entry <String , Object > filter : filters .entrySet ()) {
536
610
String key = filter .getKey ();
537
611
String value = filter .getValue ().toString ();
538
612
613
+ // 只处理有效的集合字段,忽略其他参数如query、limit、threshold等
614
+ if (!validFields .contains (key )) {
615
+ logger .debug ("Skipping invalid field in search expression: {}" , key );
616
+ continue ;
617
+ }
618
+
539
619
// 根据字段类型构建不同的表达式
540
620
switch (key ) {
541
621
case "user_id" :
@@ -552,8 +632,12 @@ private String buildSearchExpression(Map<String, Object> filters) {
552
632
conditions .add (key + " == " + value );
553
633
}
554
634
break ;
635
+ case "content" :
636
+ // 内容字段使用模糊匹配
637
+ conditions .add (key + " like \" %" + value + "%\" " );
638
+ break ;
555
639
default :
556
- // 其他字段使用模糊匹配
640
+ // 其他有效字段使用模糊匹配
557
641
conditions .add (key + " like \" %" + value + "%\" " );
558
642
break ;
559
643
}
@@ -562,4 +646,123 @@ private String buildSearchExpression(Map<String, Object> filters) {
562
646
return String .join (" && " , conditions );
563
647
}
564
648
649
+ /**
650
+ * 计算余弦相似度
651
+ * @param a 向量a
652
+ * @param b 向量b
653
+ * @return 余弦相似度值
654
+ */
655
+ private Double cosineSimilarity (Double [] a , Double [] b ) {
656
+ logger .debug ("计算余弦相似度 - 查询向量长度: {}, 存储向量长度: {}" ,
657
+ a != null ? a .length : "null" ,
658
+ b != null ? b .length : "null" );
659
+
660
+ if (a == null || b == null ) {
661
+ logger .warn ("向量为null - 查询向量: {}, 存储向量: {}" , a == null , b == null );
662
+ return 0.0 ;
663
+ }
664
+
665
+ if (a .length != b .length ) {
666
+ logger .error ("向量长度不匹配! 查询向量长度: {}, 存储向量长度: {}" , a .length , b .length );
667
+ logger .error ("查询向量前5个值: {}" , Arrays .toString (Arrays .copyOf (a , Math .min (5 , a .length ))));
668
+ logger .error ("存储向量前5个值: {}" , Arrays .toString (Arrays .copyOf (b , Math .min (5 , b .length ))));
669
+ return 0.0 ;
670
+ }
671
+
672
+ double dotProduct = 0.0 ;
673
+ double normA = 0.0 ;
674
+ double normB = 0.0 ;
675
+
676
+ for (int i = 0 ; i < a .length ; i ++) {
677
+ dotProduct += a [i ] * b [i ];
678
+ normA += a [i ] * a [i ];
679
+ normB += b [i ] * b [i ];
680
+ }
681
+
682
+ if (normA == 0.0 || normB == 0.0 ) {
683
+ logger .warn ("向量范数为0 - normA: {}, normB: {}" , normA , normB );
684
+ return 0.0 ;
685
+ }
686
+
687
+ double similarity = dotProduct / (Math .sqrt (normA ) * Math .sqrt (normB ));
688
+ logger .debug ("余弦相似度计算结果: {}" , similarity );
689
+ return similarity ;
690
+ }
691
+
692
+ /**
693
+ * 从字段映射构建MemoryItem对象
694
+ * @param fieldMap 字段映射
695
+ * @param index 数据索引
696
+ * @param resultCount 结果总数(用于计算向量维度)
697
+ * @return MemoryItem对象
698
+ */
699
+ private MemoryItem buildMemoryItemFromFieldMap (Map <String , List <Object >> fieldMap , int index , int resultCount ) {
700
+ MemoryItem item = new MemoryItem ();
701
+
702
+ // 设置ID
703
+ String id = (String ) fieldMap .get ("id" ).get (index );
704
+ item .setId (id );
705
+
706
+ // 设置内容
707
+ String content = (String ) fieldMap .get ("content" ).get (index );
708
+ item .setContent (content );
709
+
710
+ // 设置记忆类型
711
+ String memoryType = (String ) fieldMap .get ("memory_type" ).get (index );
712
+ item .setMemoryType (memoryType );
713
+
714
+ // 设置用户ID
715
+ String userId = (String ) fieldMap .get ("user_id" ).get (index );
716
+ item .setUserId (userId );
717
+
718
+ // 设置代理ID
719
+ String agentId = (String ) fieldMap .get ("agent_id" ).get (index );
720
+ item .setAgentId (agentId );
721
+
722
+ // 设置运行ID
723
+ String runId = (String ) fieldMap .get ("run_id" ).get (index );
724
+ item .setRunId (runId );
725
+
726
+ // 设置演员ID
727
+ String actorId = (String ) fieldMap .get ("actor_id" ).get (index );
728
+ item .setActorId (actorId );
729
+
730
+ // 设置创建时间
731
+ Long createdAt = (Long ) fieldMap .get ("created_at" ).get (index );
732
+ item .setCreatedAt (Instant .ofEpochMilli (createdAt ));
733
+
734
+ // 设置更新时间
735
+ Long updatedAt = (Long ) fieldMap .get ("updated_at" ).get (index );
736
+ item .setUpdatedAt (Instant .ofEpochMilli (updatedAt ));
737
+
738
+ // 设置向量嵌入
739
+ if (fieldMap .containsKey ("vector" )) {
740
+ List <Object > vectorData = fieldMap .get ("vector" );
741
+ if (vectorData != null && !vectorData .isEmpty ()) {
742
+ // 计算每个向量的维度
743
+ int vectorDimension = vectorData .size () / resultCount ;
744
+ logger .debug ("向量数据总数: {}, 结果数量: {}, 每个向量维度: {}" ,
745
+ vectorData .size (), resultCount , vectorDimension );
746
+
747
+ // 提取当前结果对应的向量数据
748
+ int startIndex = index * vectorDimension ;
749
+ int endIndex = startIndex + vectorDimension ;
750
+
751
+ if (endIndex <= vectorData .size ()) {
752
+ List <Object > currentVectorData = vectorData .subList (startIndex , endIndex );
753
+ Double [] embedding = currentVectorData .stream ()
754
+ .map (obj -> ((Float ) obj ).doubleValue ())
755
+ .toArray (Double []::new );
756
+ item .setEmbedding (embedding );
757
+ logger .debug ("提取向量成功 - index: {}, 向量长度: {}" , index , embedding .length );
758
+ } else {
759
+ logger .warn ("向量数据索引越界 - index: {}, startIndex: {}, endIndex: {}, vectorDataSize: {}" ,
760
+ index , startIndex , endIndex , vectorData .size ());
761
+ }
762
+ }
763
+ }
764
+
765
+ return item ;
766
+ }
767
+
565
768
}
0 commit comments