Skip to content

Commit cb00bcf

Browse files
author
80025731
committed
refactor: update the implement of search and insert method.
1 parent 24934bd commit cb00bcf

File tree

1 file changed

+261
-58
lines changed

1 file changed

+261
-58
lines changed

mem4j-core/src/main/java/io/github/mem4j/vectorstores/MilvusVectorStoreService.java

Lines changed: 261 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,7 @@
2020
import io.github.mem4j.memory.MemoryItem;
2121
import io.milvus.client.MilvusClient;
2222
import io.milvus.client.MilvusServiceClient;
23-
import io.milvus.grpc.DataType;
24-
import io.milvus.grpc.MutationResult;
25-
import io.milvus.grpc.QueryResults;
26-
import io.milvus.grpc.SearchResults;
23+
import io.milvus.grpc.*;
2724
import io.milvus.param.ConnectParam;
2825
import io.milvus.param.MetricType;
2926
import io.milvus.param.R;
@@ -35,8 +32,6 @@
3532
import io.milvus.param.dml.InsertParam;
3633
import io.milvus.param.dml.QueryParam;
3734
import io.milvus.param.dml.SearchParam;
38-
import io.milvus.response.QueryResultsWrapper;
39-
import io.milvus.response.SearchResultsWrapper;
4035
import lombok.AllArgsConstructor;
4136
import org.slf4j.Logger;
4237
import org.slf4j.LoggerFactory;
@@ -289,7 +284,7 @@ public List<MemoryItem> search(Double[] queryEmbedding, Map<String, Object> filt
289284
.withCollectionName(collectionName)
290285
.withMetricType(MetricType.COSINE)
291286
.withOutFields(Arrays.asList("id", "content", "memory_type", "user_id", "agent_id", "run_id",
292-
"actor_id", "created_at", "updated_at"))
287+
"actor_id", "created_at", "updated_at", "vector"))
293288
.withTopK(limit != null ? limit : 10)
294289
.withVectors(Collections.singletonList(queryVector))
295290
.withVectorFieldName("vector")
@@ -303,27 +298,64 @@ public List<MemoryItem> search(Double[] queryEmbedding, Map<String, Object> filt
303298
throw new RuntimeException("Search failed: " + response.getMessage());
304299
}
305300

306-
// 解析搜索结果 - 使用简化的方式
301+
// 解析搜索结果 - 正确实现
307302
List<MemoryItem> results = new ArrayList<>();
308303
SearchResults searchResults = response.getData();
309304

310-
// 获取搜索结果的数量
311-
int resultCount = Math.min((int) searchResults.getResults().getNumQueries(), limit != null ? limit : 10);
312-
313-
for (int i = 0; i < resultCount; i++) {
314-
// 构建MemoryItem - 简化实现
315-
MemoryItem item = new MemoryItem();
316-
item.setId("memory_" + i); // 简化ID生成
317-
item.setContent("Search result " + i);
318-
item.setMemoryType("factual");
319-
item.setUserId("default_user");
320-
item.setAgentId("default_agent");
321-
item.setRunId("default_run");
322-
item.setActorId("default_actor");
323-
item.setCreatedAt(Instant.now());
324-
item.setUpdatedAt(Instant.now());
325-
326-
results.add(item);
305+
if (searchResults != null && searchResults.getResults() != null) {
306+
// 获取搜索结果的字段数据
307+
List<FieldData> fieldsData = searchResults.getResults().getFieldsDataList();
308+
309+
if (fieldsData != null && !fieldsData.isEmpty()) {
310+
// 创建字段映射
311+
Map<String, List<Object>> fieldMap = new HashMap<>();
312+
313+
for (FieldData fieldData : fieldsData) {
314+
String fieldName = fieldData.getFieldName();
315+
List<Object> values = new ArrayList<>();
316+
317+
if (fieldData.getType() == DataType.VarChar) {
318+
// 字符串字段
319+
List<String> stringData = fieldData.getScalars().getStringData().getDataList();
320+
values.addAll(stringData);
321+
}
322+
else if (fieldData.getType() == DataType.Int64) {
323+
// 长整型字段(时间戳)
324+
List<Long> longData = fieldData.getScalars().getLongData().getDataList();
325+
values.addAll(longData);
326+
}
327+
else if (fieldData.getType() == DataType.FloatVector) {
328+
// 向量字段
329+
List<Float> vectorData = fieldData.getVectors().getFloatVector().getDataList();
330+
values.addAll(vectorData);
331+
}
332+
333+
fieldMap.put(fieldName, values);
334+
}
335+
336+
// 获取结果数量
337+
int resultCount = fieldMap.get("id").size();
338+
339+
// 构建MemoryItem对象并计算相似度
340+
for (int i = 0; i < resultCount; i++) {
341+
MemoryItem item = buildMemoryItemFromFieldMap(fieldMap, i, resultCount);
342+
343+
// 计算相似度分数
344+
Double[] itemEmbedding = item.getEmbedding();
345+
if (itemEmbedding != null) {
346+
double similarity = cosineSimilarity(queryEmbedding, itemEmbedding);
347+
item.setScore(similarity);
348+
349+
// 根据阈值过滤结果
350+
if (threshold == null || similarity >= threshold) {
351+
results.add(item);
352+
}
353+
}
354+
}
355+
356+
// 按相似度分数降序排序
357+
results.sort((a, b) -> Double.compare(b.getScore(), a.getScore()));
358+
}
327359
}
328360

329361
logger.debug("Found {} similar memories", results.size());
@@ -356,27 +388,43 @@ public List<MemoryItem> getAll(Map<String, Object> filters, Integer limit) {
356388
throw new RuntimeException("Query failed: " + response.getMessage());
357389
}
358390

359-
// 解析查询结果 - 使用简化的方式
391+
// 解析查询结果 - 正确实现
360392
List<MemoryItem> results = new ArrayList<>();
361393
QueryResults queryResults = response.getData();
362394

363-
// 获取查询结果的数量 - 简化实现
364-
int resultCount = limit != null ? limit : 100;
365-
366-
for (int i = 0; i < resultCount; i++) {
367-
// 构建MemoryItem - 简化实现
368-
MemoryItem item = new MemoryItem();
369-
item.setId("memory_" + i); // 简化ID生成
370-
item.setContent("Query result " + i);
371-
item.setMemoryType("factual");
372-
item.setUserId("default_user");
373-
item.setAgentId("default_agent");
374-
item.setRunId("default_run");
375-
item.setActorId("default_actor");
376-
item.setCreatedAt(Instant.now());
377-
item.setUpdatedAt(Instant.now());
378-
379-
results.add(item);
395+
if (queryResults != null && queryResults.getFieldsDataCount() > 0) {
396+
// 直接获取字段数据列表
397+
List<FieldData> fieldsData = queryResults.getFieldsDataList();
398+
399+
// 创建字段映射
400+
Map<String, List<Object>> fieldMap = new HashMap<>();
401+
402+
for (FieldData fieldData : fieldsData) {
403+
String fieldName = fieldData.getFieldName();
404+
List<Object> values = new ArrayList<>();
405+
406+
if (fieldData.getType() == DataType.VarChar) {
407+
// 字符串字段
408+
List<String> stringData = fieldData.getScalars().getStringData().getDataList();
409+
values.addAll(stringData);
410+
}
411+
else if (fieldData.getType() == DataType.Int64) {
412+
// 长整型字段(时间戳)
413+
List<Long> longData = fieldData.getScalars().getLongData().getDataList();
414+
values.addAll(longData);
415+
}
416+
417+
fieldMap.put(fieldName, values);
418+
}
419+
420+
// 获取结果数量
421+
int resultCount = fieldMap.get("id").size();
422+
423+
// 构建MemoryItem对象
424+
for (int i = 0; i < resultCount; i++) {
425+
MemoryItem item = buildMemoryItemFromFieldMap(fieldMap, i, resultCount);
426+
results.add(item);
427+
}
380428
}
381429

382430
logger.debug("Retrieved {} memories", results.size());
@@ -409,25 +457,45 @@ public MemoryItem get(String memoryId) {
409457
throw new RuntimeException("Query failed: " + response.getMessage());
410458
}
411459

412-
// 解析查询结果 - 使用简化的方式
460+
// 解析查询结果 - 正确实现
413461
QueryResults queryResults = response.getData();
414-
// 简化实现,假设查询成功就返回结果
415-
if (queryResults == null) {
462+
if (queryResults == null || queryResults.getFieldsDataCount() == 0) {
463+
logger.debug("Memory not found: {}", memoryId);
464+
return null;
465+
}
466+
467+
// 获取字段数据列表
468+
List<FieldData> fieldsData = queryResults.getFieldsDataList();
469+
470+
// 创建字段映射
471+
Map<String, List<Object>> fieldMap = new HashMap<>();
472+
473+
for (FieldData fieldData : fieldsData) {
474+
String fieldName = fieldData.getFieldName();
475+
List<Object> values = new ArrayList<>();
476+
477+
if (fieldData.getType() == DataType.VarChar) {
478+
// 字符串字段
479+
List<String> stringData = fieldData.getScalars().getStringData().getDataList();
480+
values.addAll(stringData);
481+
}
482+
else if (fieldData.getType() == DataType.Int64) {
483+
// 长整型字段(时间戳)
484+
List<Long> longData = fieldData.getScalars().getLongData().getDataList();
485+
values.addAll(longData);
486+
}
487+
488+
fieldMap.put(fieldName, values);
489+
}
490+
491+
// 检查是否有结果
492+
if (fieldMap.get("id") == null || fieldMap.get("id").isEmpty()) {
416493
logger.debug("Memory not found: {}", memoryId);
417494
return null;
418495
}
419496

420-
// 构建MemoryItem - 简化实现
421-
MemoryItem item = new MemoryItem();
422-
item.setId(memoryId);
423-
item.setContent("Memory content for " + memoryId);
424-
item.setMemoryType("factual");
425-
item.setUserId("default_user");
426-
item.setAgentId("default_agent");
427-
item.setRunId("default_run");
428-
item.setActorId("default_actor");
429-
item.setCreatedAt(Instant.now());
430-
item.setUpdatedAt(Instant.now());
497+
// 构建MemoryItem对象
498+
MemoryItem item = buildMemoryItemFromFieldMap(fieldMap, 0, 1);
431499

432500
logger.debug("Retrieved memory: {}", memoryId);
433501
return item;
@@ -531,11 +599,23 @@ private String buildSearchExpression(Map<String, Object> filters) {
531599
return "";
532600
}
533601

602+
// 定义有效的集合字段
603+
Set<String> validFields = new HashSet<>(Arrays.asList(
604+
"user_id", "agent_id", "run_id", "actor_id", "memory_type",
605+
"created_at", "updated_at", "content"
606+
));
607+
534608
List<String> conditions = new ArrayList<>();
535609
for (Map.Entry<String, Object> filter : filters.entrySet()) {
536610
String key = filter.getKey();
537611
String value = filter.getValue().toString();
538612

613+
// 只处理有效的集合字段,忽略其他参数如query、limit、threshold等
614+
if (!validFields.contains(key)) {
615+
logger.debug("Skipping invalid field in search expression: {}", key);
616+
continue;
617+
}
618+
539619
// 根据字段类型构建不同的表达式
540620
switch (key) {
541621
case "user_id":
@@ -552,8 +632,12 @@ private String buildSearchExpression(Map<String, Object> filters) {
552632
conditions.add(key + " == " + value);
553633
}
554634
break;
635+
case "content":
636+
// 内容字段使用模糊匹配
637+
conditions.add(key + " like \"%" + value + "%\"");
638+
break;
555639
default:
556-
// 其他字段使用模糊匹配
640+
// 其他有效字段使用模糊匹配
557641
conditions.add(key + " like \"%" + value + "%\"");
558642
break;
559643
}
@@ -562,4 +646,123 @@ private String buildSearchExpression(Map<String, Object> filters) {
562646
return String.join(" && ", conditions);
563647
}
564648

649+
/**
650+
* 计算余弦相似度
651+
* @param a 向量a
652+
* @param b 向量b
653+
* @return 余弦相似度值
654+
*/
655+
private Double cosineSimilarity(Double[] a, Double[] b) {
656+
logger.debug("计算余弦相似度 - 查询向量长度: {}, 存储向量长度: {}",
657+
a != null ? a.length : "null",
658+
b != null ? b.length : "null");
659+
660+
if (a == null || b == null) {
661+
logger.warn("向量为null - 查询向量: {}, 存储向量: {}", a == null, b == null);
662+
return 0.0;
663+
}
664+
665+
if (a.length != b.length) {
666+
logger.error("向量长度不匹配! 查询向量长度: {}, 存储向量长度: {}", a.length, b.length);
667+
logger.error("查询向量前5个值: {}", Arrays.toString(Arrays.copyOf(a, Math.min(5, a.length))));
668+
logger.error("存储向量前5个值: {}", Arrays.toString(Arrays.copyOf(b, Math.min(5, b.length))));
669+
return 0.0;
670+
}
671+
672+
double dotProduct = 0.0;
673+
double normA = 0.0;
674+
double normB = 0.0;
675+
676+
for (int i = 0; i < a.length; i++) {
677+
dotProduct += a[i] * b[i];
678+
normA += a[i] * a[i];
679+
normB += b[i] * b[i];
680+
}
681+
682+
if (normA == 0.0 || normB == 0.0) {
683+
logger.warn("向量范数为0 - normA: {}, normB: {}", normA, normB);
684+
return 0.0;
685+
}
686+
687+
double similarity = dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
688+
logger.debug("余弦相似度计算结果: {}", similarity);
689+
return similarity;
690+
}
691+
692+
/**
693+
* 从字段映射构建MemoryItem对象
694+
* @param fieldMap 字段映射
695+
* @param index 数据索引
696+
* @param resultCount 结果总数(用于计算向量维度)
697+
* @return MemoryItem对象
698+
*/
699+
private MemoryItem buildMemoryItemFromFieldMap(Map<String, List<Object>> fieldMap, int index, int resultCount) {
700+
MemoryItem item = new MemoryItem();
701+
702+
// 设置ID
703+
String id = (String) fieldMap.get("id").get(index);
704+
item.setId(id);
705+
706+
// 设置内容
707+
String content = (String) fieldMap.get("content").get(index);
708+
item.setContent(content);
709+
710+
// 设置记忆类型
711+
String memoryType = (String) fieldMap.get("memory_type").get(index);
712+
item.setMemoryType(memoryType);
713+
714+
// 设置用户ID
715+
String userId = (String) fieldMap.get("user_id").get(index);
716+
item.setUserId(userId);
717+
718+
// 设置代理ID
719+
String agentId = (String) fieldMap.get("agent_id").get(index);
720+
item.setAgentId(agentId);
721+
722+
// 设置运行ID
723+
String runId = (String) fieldMap.get("run_id").get(index);
724+
item.setRunId(runId);
725+
726+
// 设置演员ID
727+
String actorId = (String) fieldMap.get("actor_id").get(index);
728+
item.setActorId(actorId);
729+
730+
// 设置创建时间
731+
Long createdAt = (Long) fieldMap.get("created_at").get(index);
732+
item.setCreatedAt(Instant.ofEpochMilli(createdAt));
733+
734+
// 设置更新时间
735+
Long updatedAt = (Long) fieldMap.get("updated_at").get(index);
736+
item.setUpdatedAt(Instant.ofEpochMilli(updatedAt));
737+
738+
// 设置向量嵌入
739+
if (fieldMap.containsKey("vector")) {
740+
List<Object> vectorData = fieldMap.get("vector");
741+
if (vectorData != null && !vectorData.isEmpty()) {
742+
// 计算每个向量的维度
743+
int vectorDimension = vectorData.size() / resultCount;
744+
logger.debug("向量数据总数: {}, 结果数量: {}, 每个向量维度: {}",
745+
vectorData.size(), resultCount, vectorDimension);
746+
747+
// 提取当前结果对应的向量数据
748+
int startIndex = index * vectorDimension;
749+
int endIndex = startIndex + vectorDimension;
750+
751+
if (endIndex <= vectorData.size()) {
752+
List<Object> currentVectorData = vectorData.subList(startIndex, endIndex);
753+
Double[] embedding = currentVectorData.stream()
754+
.map(obj -> ((Float) obj).doubleValue())
755+
.toArray(Double[]::new);
756+
item.setEmbedding(embedding);
757+
logger.debug("提取向量成功 - index: {}, 向量长度: {}", index, embedding.length);
758+
} else {
759+
logger.warn("向量数据索引越界 - index: {}, startIndex: {}, endIndex: {}, vectorDataSize: {}",
760+
index, startIndex, endIndex, vectorData.size());
761+
}
762+
}
763+
}
764+
765+
return item;
766+
}
767+
565768
}

0 commit comments

Comments
 (0)