Skip to content

Commit 7b74711

Browse files
author
80025731
committed
refactor: update the implement of search and insert method.
1 parent cb00bcf commit 7b74711

File tree

1 file changed

+13
-17
lines changed

1 file changed

+13
-17
lines changed

mem4j-core/src/main/java/io/github/mem4j/vectorstores/MilvusVectorStoreService.java

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -336,9 +336,9 @@ else if (fieldData.getType() == DataType.FloatVector) {
336336
// 获取结果数量
337337
int resultCount = fieldMap.get("id").size();
338338

339-
// 构建MemoryItem对象并计算相似度
340-
for (int i = 0; i < resultCount; i++) {
341-
MemoryItem item = buildMemoryItemFromFieldMap(fieldMap, i, resultCount);
339+
// 构建MemoryItem对象并计算相似度
340+
for (int i = 0; i < resultCount; i++) {
341+
MemoryItem item = buildMemoryItemFromFieldMap(fieldMap, i, resultCount);
342342

343343
// 计算相似度分数
344344
Double[] itemEmbedding = item.getEmbedding();
@@ -600,10 +600,8 @@ private String buildSearchExpression(Map<String, Object> filters) {
600600
}
601601

602602
// 定义有效的集合字段
603-
Set<String> validFields = new HashSet<>(Arrays.asList(
604-
"user_id", "agent_id", "run_id", "actor_id", "memory_type",
605-
"created_at", "updated_at", "content"
606-
));
603+
Set<String> validFields = new HashSet<>(Arrays.asList("user_id", "agent_id", "run_id", "actor_id",
604+
"memory_type", "created_at", "updated_at", "content"));
607605

608606
List<String> conditions = new ArrayList<>();
609607
for (Map.Entry<String, Object> filter : filters.entrySet()) {
@@ -653,9 +651,7 @@ private String buildSearchExpression(Map<String, Object> filters) {
653651
* @return 余弦相似度值
654652
*/
655653
private Double cosineSimilarity(Double[] a, Double[] b) {
656-
logger.debug("计算余弦相似度 - 查询向量长度: {}, 存储向量长度: {}",
657-
a != null ? a.length : "null",
658-
b != null ? b.length : "null");
654+
logger.debug("计算余弦相似度 - 查询向量长度: {}, 存储向量长度: {}", a != null ? a.length : "null", b != null ? b.length : "null");
659655

660656
if (a == null || b == null) {
661657
logger.warn("向量为null - 查询向量: {}, 存储向量: {}", a == null, b == null);
@@ -741,23 +737,23 @@ private MemoryItem buildMemoryItemFromFieldMap(Map<String, List<Object>> fieldMa
741737
if (vectorData != null && !vectorData.isEmpty()) {
742738
// 计算每个向量的维度
743739
int vectorDimension = vectorData.size() / resultCount;
744-
logger.debug("向量数据总数: {}, 结果数量: {}, 每个向量维度: {}",
745-
vectorData.size(), resultCount, vectorDimension);
746-
740+
logger.debug("向量数据总数: {}, 结果数量: {}, 每个向量维度: {}", vectorData.size(), resultCount, vectorDimension);
741+
747742
// 提取当前结果对应的向量数据
748743
int startIndex = index * vectorDimension;
749744
int endIndex = startIndex + vectorDimension;
750-
745+
751746
if (endIndex <= vectorData.size()) {
752747
List<Object> currentVectorData = vectorData.subList(startIndex, endIndex);
753748
Double[] embedding = currentVectorData.stream()
754749
.map(obj -> ((Float) obj).doubleValue())
755750
.toArray(Double[]::new);
756751
item.setEmbedding(embedding);
757752
logger.debug("提取向量成功 - index: {}, 向量长度: {}", index, embedding.length);
758-
} else {
759-
logger.warn("向量数据索引越界 - index: {}, startIndex: {}, endIndex: {}, vectorDataSize: {}",
760-
index, startIndex, endIndex, vectorData.size());
753+
}
754+
else {
755+
logger.warn("向量数据索引越界 - index: {}, startIndex: {}, endIndex: {}, vectorDataSize: {}", index,
756+
startIndex, endIndex, vectorData.size());
761757
}
762758
}
763759
}

0 commit comments

Comments
 (0)