diff --git a/pom.xml b/pom.xml
index c5d20c1..a63a015 100644
--- a/pom.xml
+++ b/pom.xml
@@ -49,7 +49,8 @@
3.11.0
-
+
+ 2.44.5
3.5.3
3.0.0
1.2.22
@@ -249,6 +250,24 @@
+
+ com.diffplug.spotless
+ spotless-maven-plugin
+ ${spotless-maven-plugin.version}
+
+
+
+
+
+
+
+ compile
+
+ apply
+
+
+
+
org.apache.maven.plugins
maven-compiler-plugin
diff --git a/spring-ai-alibaba-data-agent-chat/README.md b/spring-ai-alibaba-data-agent-chat/README.md
index 62c7c24..8713d60 100644
--- a/spring-ai-alibaba-data-agent-chat/README.md
+++ b/spring-ai-alibaba-data-agent-chat/README.md
@@ -465,7 +465,7 @@ import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.StateGraph;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.request.SchemaInitRequest;
-import com.alibaba.cloud.ai.service.vectorstore.impls.SimpleVectorStoreService;
+import com.alibaba.cloud.ai.service.vectorstore.impls.SimpleAgentVectorStoreService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
@@ -491,7 +491,7 @@ public class Nl2sqlController {
private final CompiledGraph compiledGraph;
@Autowired
- private SimpleVectorStoreService simpleVectorStoreService;
+ private SimpleAgentVectorStoreService simpleVectorStoreService;
@Autowired
private DbConfig dbConfig;
diff --git a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/constant/Constant.java b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/constant/Constant.java
index 42a899f..35bd698 100644
--- a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/constant/Constant.java
+++ b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/constant/Constant.java
@@ -144,4 +144,16 @@ private Constant() {
// 人类复核相关
public static final String HUMAN_REVIEW_ENABLED = "HUMAN_REVIEW_ENABLED";
+ // column
+ public static final String COLUMN = "column";
+
+ // table
+ public static final String TABLE = "table";
+
+ // vectorType
+ public static final String VECTOR_TYPE = "vectorType";
+
+ // knowledgeId
+ public static final String KNOWLEDGE_ID = "knowledgeId";
+
}
diff --git a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/node/TableRelationNode.java b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/node/TableRelationNode.java
index b9f2afc..0ac91a2 100644
--- a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/node/TableRelationNode.java
+++ b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/node/TableRelationNode.java
@@ -20,7 +20,6 @@
import com.alibaba.cloud.ai.constant.Constant;
import com.alibaba.cloud.ai.dto.BusinessKnowledgeDTO;
import com.alibaba.cloud.ai.dto.SemanticModelDTO;
-import com.alibaba.cloud.ai.entity.AgentDatasource;
import com.alibaba.cloud.ai.entity.Datasource;
import com.alibaba.cloud.ai.enums.StreamResponseType;
import com.alibaba.cloud.ai.graph.OverAllState;
@@ -32,6 +31,7 @@
import com.alibaba.cloud.ai.service.schema.SchemaService;
import com.alibaba.cloud.ai.service.semantic.SemanticModelRecallService;
import com.alibaba.cloud.ai.util.ChatResponseUtil;
+import com.alibaba.cloud.ai.util.SchemaProcessorUtil;
import com.alibaba.cloud.ai.util.StateUtil;
import com.alibaba.cloud.ai.util.StreamingChatGeneratorUtil;
import org.slf4j.Logger;
@@ -40,6 +40,7 @@
import org.springframework.ai.document.Document;
import org.springframework.dao.DataAccessException;
import org.springframework.stereotype.Component;
+import org.springframework.util.StringUtils;
import reactor.core.publisher.Flux;
import java.util.List;
@@ -98,21 +99,20 @@ public Map apply(OverAllState state) throws Exception {
COLUMN_DOCUMENTS_BY_KEYWORDS_OUTPUT);
String dataSetId = StateUtil.getStringValue(state, Constant.AGENT_ID);
String agentIdStr = StateUtil.getStringValue(state, AGENT_ID);
- long agentId = -1L;
- if (!agentIdStr.isEmpty()) {
- agentId = Long.parseLong(agentIdStr);
- }
+ if (!StringUtils.hasText(agentIdStr))
+ throw new RuntimeException("Agent ID is empty.");
// Execute business logic first - get final result immediately
- SchemaDTO schemaDTO = buildInitialSchema(columnDocumentsByKeywords, tableDocuments);
- SchemaDTO result = processSchemaSelection(schemaDTO, input, evidenceList, state);
+ DbConfig agentDbConfig = getAgentDbConfig(Integer.valueOf(agentIdStr));
+ SchemaDTO schemaDTO = buildInitialSchema(agentIdStr, columnDocumentsByKeywords, tableDocuments, agentDbConfig);
+ SchemaDTO result = processSchemaSelection(schemaDTO, input, evidenceList, state, agentDbConfig);
List businessKnowledges;
List semanticModel;
try {
// Extract business knowledge and semantic model
businessKnowledges = businessKnowledgeRecallService.getFieldByDataSetId(dataSetId);
- semanticModel = semanticModelRecallService.getFieldByDataSetId(String.valueOf(agentId));
+ semanticModel = semanticModelRecallService.getFieldByDataSetId(dataSetId);
}
catch (DataAccessException e) {
logger.warn("Database query failed (attempt {}): {}", retryCount + 1, e.getMessage());
@@ -164,59 +164,27 @@ private String classifyDatabaseError(DataAccessException e) {
/**
* Builds initial schema from column and table documents.
*/
- private SchemaDTO buildInitialSchema(List> columnDocumentsByKeywords,
- List tableDocuments) {
+ private SchemaDTO buildInitialSchema(String agentId, List> columnDocumentsByKeywords,
+ List tableDocuments, DbConfig agentDbConfig) {
SchemaDTO schemaDTO = new SchemaDTO();
- schemaService.extractDatabaseName(schemaDTO);
- schemaService.buildSchemaFromDocuments(columnDocumentsByKeywords, tableDocuments, schemaDTO);
+
+ schemaService.extractDatabaseName(schemaDTO, agentDbConfig);
+ schemaService.buildSchemaFromDocuments(agentId, columnDocumentsByKeywords, tableDocuments, schemaDTO);
return schemaDTO;
}
- /**
- * Dynamically get the data source configuration for an agent
- * @param state The state object containing the agent ID
- * @return The database configuration corresponding to the agent, or null if not found
- */
- private DbConfig getAgentDbConfig(OverAllState state) {
- try {
- // Get the agent ID from the state
- String agentIdStr = StateUtil.getStringValue(state, Constant.AGENT_ID, null);
- if (agentIdStr == null || agentIdStr.trim().isEmpty()) {
- logger.debug("AgentId is null or empty, will use default dbConfig");
- return null;
- }
-
- Integer agentId = Integer.valueOf(agentIdStr);
- logger.debug("Getting datasource config for agent: {}", agentId);
-
- // Get the enabled data source for the agent
- List agentDatasources = datasourceService.getAgentDatasources(agentId);
- if (agentDatasources.isEmpty()) {
- // TODO 调试AgentID不一致,暂时手动处理
- agentDatasources = datasourceService.getAgentDatasources(agentId - 999999);
- }
-
- AgentDatasource activeDatasource = agentDatasources.stream()
- .filter(ad -> ad.getIsActive() == 1)
- .findFirst()
- .orElse(null);
-
- if (activeDatasource == null) {
- logger.debug("Agent {} has no active datasource, will use default dbConfig", agentId);
- return null;
- }
+ private DbConfig getAgentDbConfig(Integer agentId) {
+ // Get the enabled data source for the agent
+ Datasource agentDatasource = datasourceService.getActiveDatasourceByAgentId(agentId);
+ if (agentDatasource == null)
+ throw new RuntimeException("No active datasource found for agent " + agentId);
- // Convert to DbConfig
- DbConfig dbConfig = createDbConfigFromDatasource(activeDatasource.getDatasource());
- logger.debug("Successfully created DbConfig for agent {}: url={}, schema={}, type={}", agentId,
- dbConfig.getUrl(), dbConfig.getSchema(), dbConfig.getDialectType());
+ // Convert to DbConfig
+ DbConfig dbConfig = SchemaProcessorUtil.createDbConfigFromDatasource(agentDatasource);
+ logger.debug("Successfully created DbConfig for agent {}: url={}, schema={}, type={}", agentId,
+ dbConfig.getUrl(), dbConfig.getSchema(), dbConfig.getDialectType());
- return dbConfig;
- }
- catch (Exception e) {
- logger.warn("Failed to get agent datasource config, will use default dbConfig: {}", e.getMessage());
- return null;
- }
+ return dbConfig;
}
/**
@@ -255,13 +223,9 @@ else if ("postgresql".equalsIgnoreCase(datasource.getType())) {
* Processes schema selection based on input, evidence, and optional advice.
*/
private SchemaDTO processSchemaSelection(SchemaDTO schemaDTO, String input, List evidenceList,
- OverAllState state) {
+ OverAllState state, DbConfig agentDbConfig) {
String schemaAdvice = StateUtil.getStringValue(state, SQL_GENERATE_SCHEMA_MISSING_ADVICE, null);
- // 动态获取Agent对应的数据库配置
- DbConfig agentDbConfig = getAgentDbConfig(state);
- logger.debug("Using agent-specific dbConfig: {}", agentDbConfig != null ? agentDbConfig.getUrl() : "default");
-
if (schemaAdvice != null) {
logger.info("[{}] Processing with schema supplement advice: {}", this.getClass().getSimpleName(),
schemaAdvice);
diff --git a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/DatasourceService.java b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/DatasourceService.java
index 9357be3..ee74914 100644
--- a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/DatasourceService.java
+++ b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/DatasourceService.java
@@ -299,4 +299,15 @@ public Map getDatasourceStats() {
return stats;
}
+ public Datasource getActiveDatasourceByAgentId(Integer agentId) {
+ AgentDatasource agentDatasource = getAgentDatasources(agentId).stream()
+ .filter(a -> a.getIsActive() == 1)
+ .findFirst()
+ .orElse(null);
+ if (agentDatasource == null) {
+ return null;
+ }
+ return agentDatasource.getDatasource();
+ }
+
}
diff --git a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/processing/AbstractQueryProcessingService.java b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/processing/AbstractQueryProcessingService.java
index e4e684f..c59276e 100644
--- a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/processing/AbstractQueryProcessingService.java
+++ b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/processing/AbstractQueryProcessingService.java
@@ -19,17 +19,20 @@
import com.alibaba.cloud.ai.dto.schema.SchemaDTO;
import com.alibaba.cloud.ai.prompt.PromptConstant;
import com.alibaba.cloud.ai.prompt.PromptHelper;
+import com.alibaba.cloud.ai.service.DatasourceService;
import com.alibaba.cloud.ai.service.LlmService;
import com.alibaba.cloud.ai.service.nl2sql.Nl2SqlService;
import com.alibaba.cloud.ai.service.schema.SchemaService;
-import com.alibaba.cloud.ai.service.vectorstore.VectorStoreService;
+import com.alibaba.cloud.ai.service.vectorstore.AgentVectorStoreService;
import com.alibaba.cloud.ai.util.JsonUtil;
+import com.alibaba.cloud.ai.util.SchemaProcessorUtil;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.document.Document;
+import org.springframework.util.Assert;
import reactor.core.publisher.Flux;
import java.util.Collections;
@@ -44,11 +47,14 @@ public abstract class AbstractQueryProcessingService implements QueryProcessingS
private final LlmService aiService;
- public AbstractQueryProcessingService(LlmService aiService) {
+ private final DatasourceService datasourceService;
+
+ public AbstractQueryProcessingService(LlmService aiService, DatasourceService datasourceService) {
this.aiService = aiService;
+ this.datasourceService = datasourceService;
}
- protected abstract VectorStoreService getVectorStoreService();
+ protected abstract AgentVectorStoreService getVectorStoreService();
protected abstract SchemaService getSchemaService();
@@ -57,13 +63,9 @@ public AbstractQueryProcessingService(LlmService aiService) {
@Override
public List extractEvidences(String query, String agentId) {
logger.debug("Extracting evidences for query: {} with agentId: {}", query, agentId);
- List evidenceDocuments;
- if (agentId != null && !agentId.trim().isEmpty()) {
- evidenceDocuments = getVectorStoreService().getDocumentsForAgent(agentId, query, "evidence");
- }
- else {
- evidenceDocuments = getVectorStoreService().getDocuments(query, "evidence");
- }
+ Assert.notNull(agentId, "AgentId cannot be null");
+ List evidenceDocuments = getVectorStoreService().getDocumentsForAgent(agentId, query, "evidence");
+
List evidences = evidenceDocuments.stream().map(Document::getText).collect(Collectors.toList());
logger.debug("Extracted {} evidences: {}", evidences.size(), evidences);
return evidences;
@@ -176,17 +178,20 @@ private String processTimeExpressions(String query) {
}
private SchemaDTO select(String query, List evidenceList, String agentId) throws Exception {
+ Assert.notNull(agentId, "AgentId cannot be null");
logger.debug("Starting schema selection for query: {} with {} evidences and agentId: {}", query,
evidenceList.size(), agentId);
List keywords = extractKeywords(query, evidenceList);
logger.debug("Using {} keywords for schema selection", keywords != null ? keywords.size() : 0);
- SchemaDTO schemaDTO;
- if (agentId != null) {
- schemaDTO = getSchemaService().mixRagForAgent(agentId, query, keywords);
- }
- else {
- schemaDTO = getSchemaService().mixRag(query, keywords);
+
+ com.alibaba.cloud.ai.entity.Datasource datasource = datasourceService
+ .getActiveDatasourceByAgentId(Integer.valueOf(agentId));
+ if (datasource == null) {
+ throw new RuntimeException("No active datasource found for agentId: " + agentId);
}
+ SchemaDTO schemaDTO = getSchemaService().mixRagForAgent(agentId, query, keywords,
+ SchemaProcessorUtil.createDbConfigFromDatasource(datasource));
+
logger.debug("Retrieved schema with {} tables", schemaDTO.getTable() != null ? schemaDTO.getTable().size() : 0);
SchemaDTO result = fineSelect(schemaDTO, query, evidenceList);
logger.debug("Fine selection completed, final schema has {} tables",
diff --git a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/processing/impls/QueryProcessingServiceImpl.java b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/processing/impls/QueryProcessingServiceImpl.java
index d97579e..83efe51 100644
--- a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/processing/impls/QueryProcessingServiceImpl.java
+++ b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/processing/impls/QueryProcessingServiceImpl.java
@@ -16,32 +16,33 @@
package com.alibaba.cloud.ai.service.processing.impls;
+import com.alibaba.cloud.ai.service.DatasourceService;
import com.alibaba.cloud.ai.service.LlmService;
import com.alibaba.cloud.ai.service.nl2sql.Nl2SqlService;
import com.alibaba.cloud.ai.service.schema.SchemaService;
import com.alibaba.cloud.ai.service.processing.AbstractQueryProcessingService;
-import com.alibaba.cloud.ai.service.vectorstore.VectorStoreService;
+import com.alibaba.cloud.ai.service.vectorstore.AgentVectorStoreService;
import org.springframework.stereotype.Service;
@Service
public class QueryProcessingServiceImpl extends AbstractQueryProcessingService {
- private final VectorStoreService vectorStoreService;
+ private final AgentVectorStoreService vectorStoreService;
private final SchemaService schemaService;
private final Nl2SqlService nl2SqlService;
- public QueryProcessingServiceImpl(LlmService aiService, VectorStoreService vectorStoreService,
- SchemaService schemaService, Nl2SqlService nl2SqlService) {
- super(aiService);
+ public QueryProcessingServiceImpl(LlmService aiService, AgentVectorStoreService vectorStoreService,
+ SchemaService schemaService, Nl2SqlService nl2SqlService, DatasourceService datasourceService) {
+ super(aiService, datasourceService);
this.vectorStoreService = vectorStoreService;
this.schemaService = schemaService;
this.nl2SqlService = nl2SqlService;
}
@Override
- protected VectorStoreService getVectorStoreService() {
+ protected AgentVectorStoreService getVectorStoreService() {
return this.vectorStoreService;
}
diff --git a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/schema/AbstractSchemaService.java b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/schema/AbstractSchemaService.java
index ee0eae5..e90730a 100644
--- a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/schema/AbstractSchemaService.java
+++ b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/schema/AbstractSchemaService.java
@@ -15,29 +15,28 @@
*/
package com.alibaba.cloud.ai.service.schema;
+import com.alibaba.cloud.ai.constant.Constant;
import com.alibaba.cloud.ai.enums.BizDataSourceTypeEnum;
import com.alibaba.cloud.ai.connector.config.DbConfig;
-import com.alibaba.cloud.ai.request.SearchRequest;
import com.alibaba.cloud.ai.dto.schema.ColumnDTO;
import com.alibaba.cloud.ai.dto.schema.SchemaDTO;
import com.alibaba.cloud.ai.dto.schema.TableDTO;
-import com.alibaba.cloud.ai.service.vectorstore.VectorStoreService;
+import com.alibaba.cloud.ai.service.vectorstore.AgentVectorStoreService;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
+import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.ai.document.Document;
+import org.springframework.util.Assert;
import java.util.ArrayList;
import java.util.Arrays;
-import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
-import java.util.Optional;
import java.util.Set;
-import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
@@ -45,19 +44,17 @@
/**
* Schema service base class, providing common method implementations
*/
+@Slf4j
public abstract class AbstractSchemaService implements SchemaService {
- protected final DbConfig dbConfig;
-
protected final ObjectMapper objectMapper;
/**
* Vector storage service
*/
- protected final VectorStoreService vectorStoreService;
+ protected final AgentVectorStoreService vectorStoreService;
- public AbstractSchemaService(DbConfig dbConfig, ObjectMapper objectMapper, VectorStoreService vectorStoreService) {
- this.dbConfig = dbConfig;
+ public AbstractSchemaService(ObjectMapper objectMapper, AgentVectorStoreService vectorStoreService) {
this.objectMapper = objectMapper;
this.vectorStoreService = vectorStoreService;
}
@@ -67,42 +64,44 @@ public AbstractSchemaService(DbConfig dbConfig, ObjectMapper objectMapper, Vecto
* @param agentId agent ID
* @param query query
* @param keywords keyword list
+ * @param dbConfig Database configuration
* @return SchemaDTO
*/
@Override
- public SchemaDTO mixRagForAgent(String agentId, String query, List keywords) {
+ public SchemaDTO mixRagForAgent(String agentId, String query, List keywords, DbConfig dbConfig) {
SchemaDTO schemaDTO = new SchemaDTO();
- extractDatabaseName(schemaDTO); // Set database name or schema name
+ extractDatabaseName(schemaDTO, dbConfig); // Set database name or schema name
+
+ // Get table documents
+ List tableDocuments = getTableDocumentsForAgent(agentId, query);
- List tableDocuments = getTableDocuments(query, agentId); // Get table
- // documents
- List> columnDocumentList = getColumnDocumentsByKeywords(keywords, agentId); // Get
- // column
- // document
- // list
+ // Get column documents
+ List> columnDocumentList = getColumnDocumentsByKeywordsForAgent(agentId, keywords);
- buildSchemaFromDocuments(columnDocumentList, tableDocuments, schemaDTO);
+ buildSchemaFromDocuments(agentId, columnDocumentList, tableDocuments, schemaDTO);
return schemaDTO;
}
@Override
- public void buildSchemaFromDocuments(List> columnDocumentList, List tableDocuments,
- SchemaDTO schemaDTO) {
+ public void buildSchemaFromDocuments(String agentId, List> columnDocumentList,
+ List tableDocuments, SchemaDTO schemaDTO) {
// Process column weights and sort by table association
- processColumnWeights(columnDocumentList, tableDocuments);
+ updateAndSortColumnScoresByTableWeights(columnDocumentList, tableDocuments);
// Initialize column selector, TODO upper limit 100 has issues
- Map weightedColumns = selectWeightedColumns(columnDocumentList, 100);
+ Map weightedColumns = selectColumnsByRoundRobin(columnDocumentList, 100);
- Set foreignKeySet = extractForeignKeyRelations(tableDocuments);
+ // 如果外键关系是"订单表.订单ID=订单详情表.订单ID",那么 relatedNamesFromForeignKeys
+ // 将包含"订单表.订单ID"和"订单详情表.订单ID"
+ Set relatedNamesFromForeignKeys = extractRelatedNamesFromForeignKeys(tableDocuments);
// Build table list
List tableList = buildTableListFromDocuments(tableDocuments);
// Supplement missing foreign key corresponding tables
- expandTableDocumentsWithForeignKeys(tableDocuments, foreignKeySet, "table");
- expandColumnDocumentsWithForeignKeys(weightedColumns, foreignKeySet, "column");
+ expandTableDocumentsWithForeignKeys(agentId, tableDocuments, relatedNamesFromForeignKeys);
+ expandColumnDocumentsWithForeignKeys(agentId, weightedColumns, relatedNamesFromForeignKeys);
// Attach weighted columns to corresponding tables
attachColumnsToTables(weightedColumns, tableList);
@@ -118,38 +117,13 @@ public void buildSchemaFromDocuments(List> columnDocumentList, Li
schemaDTO.setForeignKeys(List.of(new ArrayList<>(foreignKeys)));
}
- /**
- * Get all table documents by keywords - supports agent isolation
- */
- public List getTableDocuments(String query, String agentId) {
- if (agentId != null && !agentId.trim().isEmpty()) {
- return vectorStoreService.getDocumentsForAgent(agentId, query, "table");
- }
- else {
- return vectorStoreService.getDocuments(query, "table");
- }
- }
-
/**
* Get all table documents by keywords for specified agent
*/
@Override
public List getTableDocumentsForAgent(String agentId, String query) {
- return vectorStoreService.getDocumentsForAgent(agentId, query, "table");
- }
-
- /**
- * Get all column documents by keywords - supports agent isolation
- */
- public List> getColumnDocumentsByKeywords(List keywords, String agentId) {
- if (agentId != null) {
- return getColumnDocumentsByKeywordsForAgent(agentId, keywords);
- }
- else {
- return keywords.stream()
- .map(kw -> vectorStoreService.getDocuments(kw, "column"))
- .collect(Collectors.toList());
- }
+ Assert.notNull(agentId, "agentId cannot be null");
+ return vectorStoreService.getDocumentsForAgent(agentId, query, Constant.TABLE);
}
/**
@@ -157,21 +131,19 @@ public List> getColumnDocumentsByKeywords(List keywords,
*/
@Override
public List> getColumnDocumentsByKeywordsForAgent(String agentId, List keywords) {
- if (agentId == null) {
- return keywords.stream()
- .map(kw -> vectorStoreService.getDocuments(kw, "column"))
- .collect(Collectors.toList());
- }
+
+ Assert.notNull(agentId, "agentId cannot be null");
+
return keywords.stream()
- .map(kw -> vectorStoreService.getDocumentsForAgent(agentId, kw, "column"))
+ .map(kw -> vectorStoreService.getDocumentsForAgent(agentId, kw, Constant.COLUMN))
.collect(Collectors.toList());
}
/**
* Expand column documents (supplement missing columns through foreign keys)
*/
- private void expandColumnDocumentsWithForeignKeys(Map weightedColumns, Set foreignKeySet,
- String vectorType) {
+ private void expandColumnDocumentsWithForeignKeys(String agentId, Map weightedColumns,
+ Set foreignKeySet) {
Set existingColumnNames = weightedColumns.keySet();
Set missingColumns = new HashSet<>();
@@ -182,7 +154,7 @@ private void expandColumnDocumentsWithForeignKeys(Map weighted
}
for (String columnName : missingColumns) {
- addColumnsDocument(weightedColumns, columnName, vectorType);
+ addColumnsDocument(agentId, weightedColumns, columnName);
}
}
@@ -190,8 +162,8 @@ private void expandColumnDocumentsWithForeignKeys(Map weighted
/**
* Expand table documents (supplement missing tables through foreign keys)
*/
- private void expandTableDocumentsWithForeignKeys(List tableDocuments, Set foreignKeySet,
- String vectorType) {
+ private void expandTableDocumentsWithForeignKeys(String agentId, List tableDocuments,
+ Set foreignKeySet) {
Set uniqueTableNames = tableDocuments.stream()
.map(doc -> (String) doc.getMetadata().get("name"))
.collect(Collectors.toSet());
@@ -208,46 +180,67 @@ private void expandTableDocumentsWithForeignKeys(List tableDocuments,
}
for (String tableName : missingTables) {
- addTableDocument(tableDocuments, tableName, vectorType);
+ addTableDocument(agentId, tableDocuments, tableName);
}
}
- /**
- * Add missing table documents
- * @param tableDocuments
- * @param tableName
- * @param vectorType
- */
- protected abstract void addTableDocument(List tableDocuments, String tableName, String vectorType);
+ protected void addTableDocument(String agentId, List tableDocuments, String tableName) {
+ List documentsForAgent = vectorStoreService.getDocumentsForAgent(agentId, tableName, Constant.TABLE);
+ if (documentsForAgent != null && !documentsForAgent.isEmpty())
+ tableDocuments.addAll(documentsForAgent);
+ }
- protected abstract void addColumnsDocument(Map weightedColumns, String columnName,
- String vectorType);
+ protected void addColumnsDocument(String agentId, Map weightedColumns, String columnName) {
+ List documentsForAgent = vectorStoreService.getDocumentsForAgent(agentId, columnName,
+ Constant.COLUMN);
+ if (documentsForAgent != null && !documentsForAgent.isEmpty()) {
+ for (Document document : documentsForAgent)
+ weightedColumns.putIfAbsent(document.getId(), document);
+ }
+ }
/**
- * Select up to maxCount columns by weight
+ * Select up to maxCount columns by weight using a round-robin approach to ensure
+ * balanced selection across different tables
*/
- protected Map selectWeightedColumns(List> columnDocumentList, int maxCount) {
- Map result = new HashMap<>();
- int index = 0;
-
- while (result.size() < maxCount) {
- boolean completed = true;
- for (List docs : columnDocumentList) {
- if (index < docs.size()) {
- Document doc = docs.get(index);
- String id = doc.getId();
- if (!result.containsKey(id)) {
- result.put(id, doc);
+ protected Map selectColumnsByRoundRobin(List> columnDocumentList, int maxCount) {
+ Map selectedColumns = new HashMap<>();
+ int currentRound = 0;
+
+ // Continue selecting columns until we reach maxCount or exhaust all columns
+ while (selectedColumns.size() < maxCount) {
+ boolean hasMoreColumnsInAnyList = false;
+
+ // Process each table's column list in the current round
+ for (List tableColumns : columnDocumentList) {
+ if (currentRound < tableColumns.size()) {
+ // Get the column at current position (already sorted by weight)
+ Document column = tableColumns.get(currentRound);
+ String columnId = column.getId();
+
+ // Add to selection if not already selected
+ if (!selectedColumns.containsKey(columnId)) {
+ selectedColumns.put(columnId, column);
+
+ // Stop if we've reached the maximum count
+ if (selectedColumns.size() >= maxCount) {
+ break;
+ }
}
- completed = false;
+
+ hasMoreColumnsInAnyList = true;
}
}
- index++;
- if (completed) {
+
+ // If no more columns in any list, exit the loop
+ if (!hasMoreColumnsInAnyList) {
break;
}
+
+ currentRound++;
}
- return result;
+
+ return selectedColumns;
}
/**
@@ -284,34 +277,147 @@ else if (primaryKeyObj instanceof String) {
/**
* Score each column (combining with its table's score)
*/
- public void processColumnWeights(List> columnDocuments, List tableDocuments) {
- columnDocuments.replaceAll(docs -> docs.stream()
- .filter(column -> tableDocuments.stream()
- .anyMatch(table -> table.getMetadata().get("name").equals(column.getMetadata().get("tableName"))))
- .peek(column -> {
- Optional matchingTable = tableDocuments.stream()
- .filter(table -> table.getMetadata().get("name").equals(column.getMetadata().get("tableName")))
- .findFirst();
- matchingTable.ifPresent(tableDoc -> {
- Double tableScore = Optional.ofNullable((Double) tableDoc.getMetadata().get("score"))
- .orElse(tableDoc.getScore());
- Double columnScore = Optional.ofNullable((Double) column.getMetadata().get("score"))
- .orElse(column.getScore());
- if (tableScore != null && columnScore != null) {
- column.getMetadata().put("score", columnScore * tableScore);
- }
- });
- })
- .sorted(Comparator.comparing((Document d) -> (Double) d.getMetadata().get("score")).reversed())
- .collect(Collectors.toList()));
+ public void updateAndSortColumnScoresByTableWeights(List> columnDocuments,
+ List tableDocuments) {
+ for (int i = 0; i < columnDocuments.size(); i++) {
+ List processedColumns = processSingleTableColumns(columnDocuments.get(i), tableDocuments);
+ columnDocuments.set(i, processedColumns);
+ }
+ }
+
+ /**
+ * Process columns for a single table, filtering and updating scores
+ */
+ private List processSingleTableColumns(List columns, List tableDocuments) {
+ // Step 1: Filter columns to only include those that have a matching table
+ List filteredColumns = filterColumnsWithMatchingTables(columns, tableDocuments);
+
+ // Step 2: Update column scores by multiplying with their table scores
+ updateColumnScoresWithTableScores(filteredColumns, tableDocuments);
+
+ // Step 3: Sort columns by their new scores in descending order
+ return sortColumnsByScoreDescending(filteredColumns);
+ }
+
+ /**
+ * Filter columns to only include those that have a matching table
+ */
+ private List filterColumnsWithMatchingTables(List columns, List tableDocuments) {
+ List result = new ArrayList<>();
+
+ for (Document column : columns) {
+ String columnTableName = (String) column.getMetadata().get("tableName");
+ if (hasMatchingTable(tableDocuments, columnTableName)) {
+ result.add(column);
+ }
+ }
+
+ return result;
+ }
+
+ /**
+ * Check if there's a table with the given name in the table documents
+ */
+ private boolean hasMatchingTable(List tableDocuments, String tableName) {
+ if (StringUtils.isBlank(tableName)) {
+ return false;
+ }
+
+ for (Document table : tableDocuments) {
+ String table_name = (String) table.getMetadata().get("name");
+ if (tableName.equals(table_name)) {
+ return true;
+ }
+ }
+
+ return false;
+ }
+
+ /**
+ * Update column scores by multiplying with their table scores
+ */
+ private void updateColumnScoresWithTableScores(List columns, List tableDocuments) {
+ for (Document column : columns) {
+ String columnTableName = (String) column.getMetadata().get("tableName");
+ Document matchingTable = findTableByName(tableDocuments, columnTableName);
+
+ if (matchingTable != null) {
+ Double tableScore = getTableScore(matchingTable);
+ Double columnScore = getColumnScore(column);
+
+ if (tableScore != null && columnScore != null) {
+ Double newScore = columnScore * tableScore;
+ column.getMetadata().put("score", newScore);
+ }
+ }
+ }
+ }
+
+ /**
+ * Find a table document by its name
+ */
+ private Document findTableByName(List tableDocuments, String tableName) {
+ if (StringUtils.isBlank(tableName)) {
+ return null;
+ }
+
+ for (Document table : tableDocuments) {
+ String table_name = (String) table.getMetadata().get("name");
+ if (tableName.equals(table_name)) {
+ return table;
+ }
+ }
+
+ return null;
+ }
+
+ /**
+ * Get the score from a table document
+ */
+ private Double getTableScore(Document tableDoc) {
+ Double scoreFromMetadata = (Double) tableDoc.getMetadata().get("score");
+ return scoreFromMetadata != null ? scoreFromMetadata : tableDoc.getScore();
}
/**
- * Extract foreign key relationships
+ * Get the score from a column document
+ */
+ private Double getColumnScore(Document columnDoc) {
+ Double scoreFromMetadata = (Double) columnDoc.getMetadata().get("score");
+ return scoreFromMetadata != null ? scoreFromMetadata : columnDoc.getScore();
+ }
+
+ /**
+ * Sort columns by their scores in descending order
+ */
+ private List sortColumnsByScoreDescending(List columns) {
+ List sortedColumns = new ArrayList<>(columns);
+
+ sortedColumns.sort((doc1, doc2) -> {
+ Double score1 = (Double) doc1.getMetadata().get("score");
+ Double score2 = (Double) doc2.getMetadata().get("score");
+
+ // Handle null scores
+ if (score1 == null && score2 == null)
+ return 0;
+ if (score1 == null)
+ return 1;
+ if (score2 == null)
+ return -1;
+
+ // Sort in descending order
+ return score2.compareTo(score1);
+ });
+
+ return sortedColumns;
+ }
+
+ /**
+ * Extract related table and column names from foreign key relationships
* @param tableDocuments table document list
- * @return foreign key relationship set
+ * @return set of related names in format "tableName.columnName"
*/
- protected Set extractForeignKeyRelations(List tableDocuments) {
+ protected Set extractRelatedNamesFromForeignKeys(List tableDocuments) {
Set result = new HashSet<>();
for (Document doc : tableDocuments) {
@@ -352,7 +458,8 @@ protected void attachColumnsToTables(Map weightedColumns, List
});
columnDTO.setData(samples);
}
- catch (Exception ignore) {
+ catch (Exception e) {
+ log.error("Failed to parse samples: {}", samplesStr, e);
}
}
@@ -364,28 +471,13 @@ protected void attachColumnsToTables(Map weightedColumns, List
}
}
- /**
- * Get table metadata
- * @param tableName table name
- * @return table metadata
- */
- protected Map getTableMetadata(String tableName) {
- List tableDocuments = getTableDocuments(tableName);
- for (Document doc : tableDocuments) {
- Map metadata = doc.getMetadata();
- if (tableName.equals(metadata.get("name"))) {
- return metadata;
- }
- }
- return null;
- }
-
/**
* Extract database name
* @param schemaDTO SchemaDTO
+ * @param dbConfig Database configuration
*/
@Override
- public void extractDatabaseName(SchemaDTO schemaDTO) {
+ public void extractDatabaseName(SchemaDTO schemaDTO, DbConfig dbConfig) {
String pattern = ":\\d+/([^/?&]+)";
if (BizDataSourceTypeEnum.isMysqlDialect(dbConfig.getDialectType())) {
Pattern regex = Pattern.compile(pattern);
@@ -399,31 +491,4 @@ else if (BizDataSourceTypeEnum.isPgDialect(dbConfig.getDialectType())) {
}
}
- /**
- * Common document query processing template to reduce subclass redundant code.
- */
- protected void handleDocumentQuery(List targetList, String key, String vectorType,
- Function requestBuilder, Function> searchFunc) {
- SearchRequest request = requestBuilder.apply(key);
- request.setVectorType(vectorType);
- request.setTopK(10);
- List docs = searchFunc.apply(request);
- if (CollectionUtils.isNotEmpty(docs)) {
- targetList.addAll(docs);
- }
- }
-
- protected void handleDocumentQuery(Map targetMap, String key, String vectorType,
- Function requestBuilder, Function> searchFunc) {
- SearchRequest request = requestBuilder.apply(key);
- request.setVectorType(vectorType);
- request.setTopK(10);
- List docs = searchFunc.apply(request);
- if (CollectionUtils.isNotEmpty(docs)) {
- for (Document doc : docs) {
- targetMap.putIfAbsent(doc.getId(), doc);
- }
- }
- }
-
}
diff --git a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/schema/SchemaService.java b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/schema/SchemaService.java
index 47c93a0..7c64b9b 100644
--- a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/schema/SchemaService.java
+++ b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/schema/SchemaService.java
@@ -16,6 +16,7 @@
package com.alibaba.cloud.ai.service.schema;
+import com.alibaba.cloud.ai.connector.config.DbConfig;
import com.alibaba.cloud.ai.dto.schema.SchemaDTO;
import org.springframework.ai.document.Document;
@@ -27,12 +28,12 @@ public interface SchemaService {
List> getColumnDocumentsByKeywordsForAgent(String agentId, List keywords);
- void extractDatabaseName(SchemaDTO schemaDTO);
+ void extractDatabaseName(SchemaDTO schemaDTO, DbConfig dbConfig);
- void buildSchemaFromDocuments(List> columnDocumentList, List tableDocuments,
- SchemaDTO schemaDTO);
+ void buildSchemaFromDocuments(String agentId, List> columnDocumentList,
+ List tableDocuments, SchemaDTO schemaDTO);
- SchemaDTO mixRagForAgent(String agentId, String query, List keywords);
+ SchemaDTO mixRagForAgent(String agentId, String query, List keywords, DbConfig dbConfig);
default List getTableDocuments(String query) {
return getTableDocumentsForAgent(null, query);
@@ -42,8 +43,8 @@ default List> getColumnDocumentsByKeywords(List keywords)
return getColumnDocumentsByKeywordsForAgent(null, keywords);
}
- default SchemaDTO mixRag(String query, List keywords) {
- return mixRagForAgent(null, query, keywords);
+ default SchemaDTO mixRag(String query, List keywords, DbConfig dbConfig) {
+ return mixRagForAgent(null, query, keywords, dbConfig);
}
}
diff --git a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/schema/SchemaServiceFactory.java b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/schema/SchemaServiceFactory.java
index bf0c858..250c99f 100644
--- a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/schema/SchemaServiceFactory.java
+++ b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/schema/SchemaServiceFactory.java
@@ -16,37 +16,65 @@
package com.alibaba.cloud.ai.service.schema;
-import com.alibaba.cloud.ai.connector.config.DbConfig;
import com.alibaba.cloud.ai.service.schema.impls.AnalyticSchemaService;
import com.alibaba.cloud.ai.service.schema.impls.SimpleSchemaService;
-import com.alibaba.cloud.ai.service.vectorstore.VectorStoreService;
+import com.alibaba.cloud.ai.service.vectorstore.AgentVectorStoreService;
+import com.alibaba.cloud.ai.service.vectorstore.VectorStoreType;
+import com.alibaba.cloud.ai.service.vectorstore.impls.AnalyticAgentVectorStoreService;
+import com.alibaba.cloud.ai.service.vectorstore.impls.SimpleAgentVectorStoreService;
import com.alibaba.cloud.ai.util.JsonUtil;
+import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.FactoryBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
+import org.springframework.context.annotation.DependsOn;
import org.springframework.stereotype.Component;
+import java.util.HashMap;
+import java.util.Map;
+
+@Slf4j
@Component
+@DependsOn("agentVectorStoreServiceFactory")
public class SchemaServiceFactory implements FactoryBean {
- // todo: 改为枚举,由用户配置决定实现类
- @Value("${spring.ai.vectorstore.analytic.enabled:false}")
- private Boolean analyticEnabled;
+ @Value("${spring.ai.vectorstore.type:SIMPLE}")
+ private VectorStoreType vectorStoreType;
+
+ @Autowired(required = false)
+ private AgentVectorStoreService agentVectorStoreService;
+
+ @FunctionalInterface
+ private interface SchemaServiceCreator {
+
+ SchemaService create();
- @Autowired
- private DbConfig dbConfig;
+ }
+
+ private final Map serviceCreators = new HashMap<>();
- @Autowired
- private VectorStoreService vectorStoreService;
+ public SchemaServiceFactory() {
+ // 初始化各种向量存储类型的创建策略
+ serviceCreators.put(VectorStoreType.ANALYTIC_DB, this::createAnalyticSchemaService);
+ serviceCreators.put(VectorStoreType.SIMPLE, this::createSimpleSchemaService);
+
+ }
@Override
public SchemaService getObject() {
- if (Boolean.TRUE.equals(analyticEnabled)) {
- return new AnalyticSchemaService(dbConfig, JsonUtil.getObjectMapper(), vectorStoreService);
+ if (agentVectorStoreService == null) {
+ throw new IllegalStateException("AgentVectorStoreService is not initialized.");
}
- else {
- return new SimpleSchemaService(dbConfig, JsonUtil.getObjectMapper(), vectorStoreService);
+
+ // 根据配置的向量存储类型获取对应的创建策略
+ SchemaServiceCreator creator = serviceCreators.get(vectorStoreType);
+ if (creator == null) {
+ log.warn("Unsupported vector store type: {}, falling back to SIMPLE", vectorStoreType);
+ creator = serviceCreators.get(VectorStoreType.SIMPLE);
}
+
+ // 使用选定的策略创建SchemaService实例
+ return creator.create();
}
@Override
@@ -54,4 +82,32 @@ public Class> getObjectType() {
return SchemaService.class;
}
+ /**
+ * 创建分析型数据库的SchemaService
+ * @return AnalyticSchemaService实例
+ */
+ private SchemaService createAnalyticSchemaService() {
+ log.info("Using AnalyticSchemaService");
+ if (!(agentVectorStoreService instanceof AnalyticAgentVectorStoreService)) {
+ throw new IllegalStateException(
+ "AgentVectorStoreService is not an instance of AnalyticAgentVectorStoreService");
+ }
+ return new AnalyticSchemaService(JsonUtil.getObjectMapper(),
+ (AnalyticAgentVectorStoreService) agentVectorStoreService);
+ }
+
+ /**
+ * 创建简单内存存储的SchemaService
+ * @return SimpleSchemaService实例
+ */
+ private SchemaService createSimpleSchemaService() {
+ log.info("Using SimpleSchemaService");
+ if (!(agentVectorStoreService instanceof SimpleAgentVectorStoreService)) {
+ throw new IllegalStateException(
+ "AgentVectorStoreService is not an instance of SimpleAgentVectorStoreService");
+ }
+ return new SimpleSchemaService(JsonUtil.getObjectMapper(),
+ (SimpleAgentVectorStoreService) agentVectorStoreService);
+ }
+
}
diff --git a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/schema/impls/AnalyticSchemaService.java b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/schema/impls/AnalyticSchemaService.java
index bc20ad5..599c940 100644
--- a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/schema/impls/AnalyticSchemaService.java
+++ b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/schema/impls/AnalyticSchemaService.java
@@ -15,45 +15,17 @@
*/
package com.alibaba.cloud.ai.service.schema.impls;
-import com.alibaba.cloud.ai.connector.config.DbConfig;
-import com.alibaba.cloud.ai.request.SearchRequest;
import com.alibaba.cloud.ai.service.schema.AbstractSchemaService;
-import com.alibaba.cloud.ai.service.vectorstore.VectorStoreService;
+import com.alibaba.cloud.ai.service.vectorstore.impls.AnalyticAgentVectorStoreService;
import com.fasterxml.jackson.databind.ObjectMapper;
-import org.springframework.ai.document.Document;
-
-import java.util.List;
-import java.util.Map;
/**
* Schema building service, supports RAG-based hybrid queries.
*/
public class AnalyticSchemaService extends AbstractSchemaService {
- public AnalyticSchemaService(DbConfig dbConfig, ObjectMapper objectMapper, VectorStoreService vectorStoreService) {
- super(dbConfig, objectMapper, vectorStoreService);
- }
-
- @Override
- protected void addTableDocument(List tableDocuments, String tableName, String vectorType) {
- handleDocumentQuery(tableDocuments, tableName, vectorType, name -> {
- SearchRequest req = new SearchRequest();
- req.setQuery(null);
- req.setFilterFormatted("jsonb_extract_path_text(metadata, 'vectorType') = '" + vectorType
- + "' and refdocid = '" + name + "'");
- return req;
- }, vectorStoreService::searchWithFilter);
- }
-
- @Override
- protected void addColumnsDocument(Map weightedColumns, String columnName, String vectorType) {
- handleDocumentQuery(weightedColumns, columnName, vectorType, name -> {
- SearchRequest req = new SearchRequest();
- req.setQuery(null);
- req.setFilterFormatted("jsonb_extract_path_text(metadata, 'vectorType') = '" + vectorType
- + "' and refdocid = '" + name + "'");
- return req;
- }, vectorStoreService::searchWithFilter);
+ public AnalyticSchemaService(ObjectMapper objectMapper, AnalyticAgentVectorStoreService vectorStoreService) {
+ super(objectMapper, vectorStoreService);
}
}
diff --git a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/schema/impls/SimpleSchemaService.java b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/schema/impls/SimpleSchemaService.java
index ebc356d..03c07bc 100644
--- a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/schema/impls/SimpleSchemaService.java
+++ b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/schema/impls/SimpleSchemaService.java
@@ -15,38 +15,14 @@
*/
package com.alibaba.cloud.ai.service.schema.impls;
-import com.alibaba.cloud.ai.connector.config.DbConfig;
-import com.alibaba.cloud.ai.request.SearchRequest;
import com.alibaba.cloud.ai.service.schema.AbstractSchemaService;
-import com.alibaba.cloud.ai.service.vectorstore.VectorStoreService;
+import com.alibaba.cloud.ai.service.vectorstore.impls.SimpleAgentVectorStoreService;
import com.fasterxml.jackson.databind.ObjectMapper;
-import org.springframework.ai.document.Document;
-
-import java.util.List;
-import java.util.Map;
public class SimpleSchemaService extends AbstractSchemaService {
- public SimpleSchemaService(DbConfig dbConfig, ObjectMapper objectMapper, VectorStoreService vectorStoreService) {
- super(dbConfig, objectMapper, vectorStoreService);
- }
-
- @Override
- protected void addTableDocument(List tableDocuments, String tableName, String vectorType) {
- handleDocumentQuery(tableDocuments, tableName, vectorType, name -> {
- SearchRequest req = new SearchRequest();
- req.setName(name);
- return req;
- }, vectorStoreService::searchTableByNameAndVectorType);
- }
-
- @Override
- protected void addColumnsDocument(Map weightedColumns, String tableName, String vectorType) {
- handleDocumentQuery(weightedColumns, tableName, vectorType, name -> {
- SearchRequest req = new SearchRequest();
- req.setName(name);
- return req;
- }, vectorStoreService::searchTableByNameAndVectorType);
+ public SimpleSchemaService(ObjectMapper objectMapper, SimpleAgentVectorStoreService vectorStoreService) {
+ super(objectMapper, vectorStoreService);
}
}
diff --git a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/AbstractAgentVectorStoreService.java b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/AbstractAgentVectorStoreService.java
new file mode 100644
index 0000000..6e59d56
--- /dev/null
+++ b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/AbstractAgentVectorStoreService.java
@@ -0,0 +1,318 @@
+/*
+ * Copyright 2024-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.alibaba.cloud.ai.service.vectorstore;
+
+import static com.alibaba.cloud.ai.util.DocumentConverterUtil.convertColumnsToDocuments;
+import static com.alibaba.cloud.ai.util.DocumentConverterUtil.convertTablesToDocuments;
+
+import java.util.*;
+import java.util.stream.Collectors;
+
+import org.springframework.ai.document.Document;
+import org.springframework.ai.embedding.EmbeddingModel;
+import org.springframework.ai.vectorstore.VectorStore;
+import org.springframework.beans.factory.annotation.Value;
+import org.springframework.util.Assert;
+import org.springframework.util.StringUtils;
+
+import com.alibaba.cloud.ai.connector.accessor.Accessor;
+import com.alibaba.cloud.ai.connector.accessor.AccessorFactory;
+import com.alibaba.cloud.ai.connector.bo.DbQueryParameter;
+import com.alibaba.cloud.ai.connector.bo.ForeignKeyInfoBO;
+import com.alibaba.cloud.ai.connector.bo.TableInfoBO;
+import com.alibaba.cloud.ai.connector.config.DbConfig;
+import com.alibaba.cloud.ai.constant.Constant;
+import com.alibaba.cloud.ai.request.AgentSearchRequest;
+import com.alibaba.cloud.ai.request.SchemaInitRequest;
+import com.alibaba.cloud.ai.util.JsonUtil;
+import com.alibaba.cloud.ai.util.SchemaProcessorUtil;
+
+import lombok.extern.slf4j.Slf4j;
+
+@Slf4j
+public abstract class AbstractAgentVectorStoreService implements AgentVectorStoreService {
+
+ /**
+ * 相似度阈值配置,用于过滤相似度分数大于等于此阈值的文档
+ */
+ @Value("${spring.ai.vectorstore.similarityThreshold:0.6}")
+ protected double similarityThreshold;
+
+ /**
+ * Get embedding model
+ */
+ protected abstract EmbeddingModel getEmbeddingModel();
+
+ protected final AccessorFactory accessorFactory;
+
+ public AbstractAgentVectorStoreService(AccessorFactory accessorFactory) {
+ this.accessorFactory = accessorFactory;
+ }
+
+ @Override
+ public List search(AgentSearchRequest searchRequest) {
+ Assert.notNull(searchRequest, "SearchRequest cannot be null");
+ Assert.notNull(searchRequest.getAgentId(), "agentId cannot be null");
+ org.springframework.ai.vectorstore.SearchRequest.Builder builder = org.springframework.ai.vectorstore.SearchRequest
+ .builder();
+
+ if (StringUtils.hasText(searchRequest.getQuery()))
+ builder.query(searchRequest.getQuery());
+
+ if (Objects.nonNull(searchRequest.getTopK()))
+ builder.topK(searchRequest.getTopK());
+
+ String filterFormatted = buildFilterExpressionString(searchRequest.getMetadataFilter());
+ if (StringUtils.hasText(filterFormatted))
+ builder.filterExpression(filterFormatted);
+ builder.similarityThreshold(similarityThreshold);
+ List results = getVectorStore().similaritySearch(builder.build());
+ log.info("Search completed. Found {} documents for SearchRequest: {}", results.size(), searchRequest);
+ return results;
+
+ }
+
+ // 模板方法 - 通用schema处理流程
+ @Override
+ public final Boolean schema(String agentId, SchemaInitRequest schemaInitRequest) throws Exception {
+ try {
+
+ DbConfig config = schemaInitRequest.getDbConfig();
+ DbQueryParameter dqp = DbQueryParameter.from(config)
+ .setSchema(config.getSchema())
+ .setTables(schemaInitRequest.getTables());
+
+ // 根据当前DbConfig获取Accessor
+ Accessor dbAccessor = accessorFactory.getAccessorByDbConfig(config);
+
+ // 清理旧数据
+ clearSchemaDataForAgent(agentId);
+
+ // 处理外键
+ List foreignKeys = dbAccessor.showForeignKeys(config, dqp);
+ Map> foreignKeyMap = buildForeignKeyMap(foreignKeys);
+
+ // 处理表和列
+ List tables = dbAccessor.fetchTables(config, dqp);
+ for (TableInfoBO table : tables) {
+ SchemaProcessorUtil.enrichTableMetadata(table, dqp, config, dbAccessor, JsonUtil.getObjectMapper(),
+ foreignKeyMap);
+ }
+
+ // 转换为文档
+ List columnDocs = convertColumnsToDocuments(agentId, tables);
+ List tableDocs = convertTablesToDocuments(agentId, tables);
+
+ // 存储文档
+ return storeSchemaDocuments(columnDocs, tableDocs);
+ }
+ catch (Exception e) {
+ log.error("Failed to process schema ", e);
+ return false;
+ }
+ }
+
+ protected Boolean storeSchemaDocuments(List columns, List tables) {
+ try {
+ getVectorStore().add(columns);
+ getVectorStore().add(tables);
+ return true;
+ }
+ catch (Exception e) {
+ log.error("add document to vectorstore error", e);
+ return false;
+ }
+
+ }
+
+ protected Map> buildForeignKeyMap(List foreignKeys) {
+ Map> map = new HashMap<>();
+ for (ForeignKeyInfoBO fk : foreignKeys) {
+ String key = fk.getTable() + "." + fk.getColumn() + "=" + fk.getReferencedTable() + "."
+ + fk.getReferencedColumn();
+
+ map.computeIfAbsent(fk.getTable(), k -> new ArrayList<>()).add(key);
+ map.computeIfAbsent(fk.getReferencedTable(), k -> new ArrayList<>()).add(key);
+ }
+ return map;
+ }
+
+ protected abstract VectorStore getVectorStore();
+
+ protected void clearSchemaDataForAgent(String agentId) throws Exception {
+ deleteDocumentsByVectorType(agentId, Constant.COLUMN);
+ deleteDocumentsByVectorType(agentId, Constant.TABLE);
+ }
+
+ @Override
+ public Boolean deleteDocumentsByVectorType(String agentId, String vectorType) throws Exception {
+ Assert.notNull(agentId, "AgentId cannot be null.");
+ Assert.notNull(vectorType, "VectorType cannot be null.");
+
+ Map metadata = new HashMap<>(
+ Map.ofEntries(Map.entry(Constant.AGENT_ID, agentId), Map.entry(Constant.VECTOR_TYPE, vectorType)));
+
+ return this.deleteDocumentsByMetedata(agentId, metadata);
+ }
+
+ @Override
+ public void addDocuments(String agentId, List documents) {
+ Assert.notNull(agentId, "AgentId cannot be null.");
+ Assert.notEmpty(documents, "Documents cannot be empty.");
+ getVectorStore().add(documents);
+ }
+
+ @Override
+ public int estimateDocuments(String agentId) {
+ // 初略估算文档数目
+ List docs = getVectorStore()
+ .similaritySearch(org.springframework.ai.vectorstore.SearchRequest.builder()
+ .query("")
+ .filterExpression(buildFilterExpressionString(Map.of(Constant.AGENT_ID, agentId)))
+ .topK(Integer.MAX_VALUE) // 获取所有匹配的文档
+ .build());
+ return docs.size();
+ }
+
+ @Override
+ public Boolean deleteDocumentsByMetedata(String agentId, Map metadata) throws Exception {
+ Assert.notNull(agentId, "AgentId cannot be null.");
+ Assert.notNull(metadata, "Metadata cannot be null.");
+ // 添加agentId元数据过滤条件, 用于删除指定agentId下的所有数据,因为metadata中用户调用可能忘记添加agentId
+ metadata.put(Constant.AGENT_ID, agentId);
+ String filterExpression = buildFilterExpressionString(metadata);
+
+ // TODO 后续改成getVectorStore().delete(filterExpression);
+ // TODO 目前不支持通过元数据删除,使用会抛出UnsupportedOperationException,后续spring
+ // TODO ai发布1.1.0正式版本后再修改,现在是通过id删除
+
+ // 先搜索要删除的文档
+ List documentsToDelete = getVectorStore()
+ .similaritySearch(org.springframework.ai.vectorstore.SearchRequest.builder()
+ .query("")
+ .filterExpression(filterExpression)
+ .topK(Integer.MAX_VALUE)
+ .build());
+
+ // 提取文档ID并删除
+ if (!documentsToDelete.isEmpty()) {
+ List idsToDelete = documentsToDelete.stream().map(Document::getId).collect(Collectors.toList());
+ getVectorStore().delete(idsToDelete);
+ }
+
+ return true;
+ }
+
+ /**
+ * 构建过滤表达式字符串,目前FilterExpressionBuilder 不支持链式拼接元数据过滤,所以只能使用字符串拼接
+ * @param filterMap
+ * @return
+ */
+ protected final String buildFilterExpressionString(Map filterMap) {
+ if (filterMap == null || filterMap.isEmpty()) {
+ return null;
+ }
+
+ // 验证键名是否合法(只包含字母、数字和下划线)
+ for (String key : filterMap.keySet()) {
+ if (!key.matches("[a-zA-Z_][a-zA-Z0-9_]*")) {
+ throw new IllegalArgumentException("Invalid key name: " + key
+ + ". Keys must start with a letter or underscore and contain only alphanumeric characters and underscores.");
+ }
+ }
+
+ return filterMap.entrySet().stream().map(entry -> {
+ String key = entry.getKey();
+ Object value = entry.getValue();
+
+ // 处理空值
+ if (value == null) {
+ return key + " == null";
+ }
+
+ // 根据值的类型决定如何格式化
+ if (value instanceof String) {
+ // 转义字符串中的特殊字符
+ String escapedValue = escapeStringLiteral((String) value);
+ return key + " == '" + escapedValue + "'";
+ }
+ else if (value instanceof Number) {
+ // 数字类型直接使用
+ return key + " == " + value;
+ }
+ else if (value instanceof Boolean) {
+ // 布尔值使用小写形式
+ return key + " == " + ((Boolean) value).toString().toLowerCase();
+ }
+ else if (value instanceof Enum) {
+ // 枚举类型,转换为字符串并转义
+ String enumValue = ((Enum>) value).name();
+ String escapedValue = escapeStringLiteral(enumValue);
+ return key + " == '" + escapedValue + "'";
+ }
+ else {
+ // 其他类型尝试转换为字符串并转义
+ String stringValue = value.toString();
+ String escapedValue = escapeStringLiteral(stringValue);
+ return key + " == '" + escapedValue + "'";
+ }
+ }).collect(Collectors.joining(" && "));
+ }
+
+ /**
+ * 转义字符串字面量中的特殊字符
+ */
+ private String escapeStringLiteral(String input) {
+ if (input == null) {
+ return "";
+ }
+
+ // 转义反斜杠和单引号
+ String escaped = input.replace("\\", "\\\\").replace("'", "\\'");
+
+ // 转义其他特殊字符
+ escaped = escaped.replace("\n", "\\n")
+ .replace("\r", "\\r")
+ .replace("\t", "\\t")
+ .replace("\b", "\\b")
+ .replace("\f", "\\f");
+
+ return escaped;
+ }
+
+ @Override
+ public List getDocumentsForAgent(String agentId, String query, String vectorType) {
+ AgentSearchRequest searchRequest = AgentSearchRequest.getInstance(agentId);
+ searchRequest.setQuery(query);
+ searchRequest.setTopK(20);
+ searchRequest.setMetadataFilter(Map.of(Constant.VECTOR_TYPE, vectorType));
+
+ return search(searchRequest);
+ }
+
+ @Override
+ public boolean hasDocuments(String agentId) {
+ // 类似 MySQL 的 LIMIT 1,只检查是否存在文档
+ List docs = getVectorStore()
+ .similaritySearch(org.springframework.ai.vectorstore.SearchRequest.builder()
+ .query("")
+ .filterExpression(buildFilterExpressionString(Map.of(Constant.AGENT_ID, agentId)))
+ .topK(1) // 只获取1个文档
+ .build());
+ return !docs.isEmpty();
+ }
+
+}
diff --git a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/AbstractVectorStoreService.java b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/AbstractVectorStoreService.java
deleted file mode 100644
index 7d3d8ed..0000000
--- a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/AbstractVectorStoreService.java
+++ /dev/null
@@ -1,77 +0,0 @@
-/*
- * Copyright 2024-2025 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package com.alibaba.cloud.ai.service.vectorstore;
-
-import com.alibaba.cloud.ai.request.SearchRequest;
-import org.springframework.ai.document.Document;
-import org.springframework.ai.embedding.EmbeddingModel;
-
-import java.util.ArrayList;
-import java.util.List;
-import java.util.stream.Collectors;
-import java.util.stream.IntStream;
-
-public abstract class AbstractVectorStoreService implements VectorStoreService {
-
- /**
- * Get embedding model
- */
- protected abstract EmbeddingModel getEmbeddingModel();
-
- /**
- * Convert text to Double type vector
- */
- public List embedDouble(String text) {
- return convertToDoubleList(getEmbeddingModel().embed(text));
- }
-
- /**
- * Convert text to Float type vector
- */
- public List embedFloat(String text) {
- return convertToFloatList(getEmbeddingModel().embed(text));
- }
-
- /**
- * Get documents from vector store
- */
- @Override
- public List getDocuments(String query, String vectorType) {
- SearchRequest request = new SearchRequest();
- request.setQuery(query);
- request.setVectorType(vectorType);
- request.setTopK(100);
- return new ArrayList<>(searchWithVectorType(request));
- }
-
- /**
- * Convert float[] to Double List
- */
- protected List convertToDoubleList(float[] array) {
- return IntStream.range(0, array.length)
- .mapToDouble(i -> (double) array[i])
- .boxed()
- .collect(Collectors.toList());
- }
-
- /**
- * Convert float[] to Float List
- */
- protected List convertToFloatList(float[] array) {
- return IntStream.range(0, array.length).mapToObj(i -> array[i]).collect(Collectors.toList());
- }
-
-}
diff --git a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/AgentVectorStoreManager.java b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/AgentVectorStoreManager.java
deleted file mode 100644
index 3b1aea4..0000000
--- a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/AgentVectorStoreManager.java
+++ /dev/null
@@ -1,262 +0,0 @@
-/*
- * Copyright 2024-2025 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package com.alibaba.cloud.ai.service.vectorstore;
-
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-import org.springframework.ai.document.Document;
-import org.springframework.ai.embedding.EmbeddingModel;
-import org.springframework.ai.vectorstore.SimpleVectorStore;
-import org.springframework.ai.vectorstore.filter.Filter;
-import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder;
-import org.springframework.stereotype.Service;
-
-import java.util.Collections;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-import java.util.concurrent.ConcurrentHashMap;
-
-/**
- * Agent Vector Storage Manager Provides independent vector storage instances for each
- * agent, ensuring data isolation
- */
-@Service
-public class AgentVectorStoreManager {
-
- private static final Logger log = LoggerFactory.getLogger(AgentVectorStoreManager.class);
-
- private final Map agentStores = new ConcurrentHashMap<>();
-
- private final EmbeddingModel embeddingModel;
-
- public AgentVectorStoreManager(EmbeddingModel embeddingModel) {
- this.embeddingModel = embeddingModel;
- log.info("AgentVectorStoreManager initialized with EmbeddingModel: {}",
- embeddingModel.getClass().getSimpleName());
- }
-
- /**
- * Get or create agent-specific vector storage
- * @param agentId agent ID
- * @return agent-specific SimpleVectorStore instance
- */
- public SimpleVectorStore getOrCreateVectorStore(String agentId) {
- if (agentId == null || agentId.trim().isEmpty()) {
- throw new IllegalArgumentException("Agent ID cannot be null or empty");
- }
-
- return agentStores.computeIfAbsent(agentId, id -> {
- log.info("Creating new vector store for agent: {}", id);
- return SimpleVectorStore.builder(embeddingModel).build();
- });
- }
-
- /**
- * Add documents for specified agent
- * @param agentId agent ID
- * @param documents list of documents to add
- */
- public void addDocuments(String agentId, List documents) {
- if (documents == null || documents.isEmpty()) {
- log.warn("No documents to add for agent: {}", agentId);
- return;
- }
-
- SimpleVectorStore store = getOrCreateVectorStore(agentId);
- store.add(documents);
- log.info("Added {} documents to vector store for agent: {}", documents.size(), agentId);
- }
-
- /**
- * Search similar documents for specified agent
- * @param agentId agent ID
- * @param query query text
- * @param topK number of results to return
- * @return list of similar documents
- */
- public List similaritySearch(String agentId, String query, int topK) {
- SimpleVectorStore store = agentStores.get(agentId);
- if (store == null) {
- log.warn("No vector store found for agent: {}", agentId);
- return Collections.emptyList();
- }
-
- List results = store.similaritySearch(
- org.springframework.ai.vectorstore.SearchRequest.builder().query(query).topK(topK).build());
- log.debug("Found {} similar documents for agent: {} with query: {}", results.size(), agentId, query);
- return results;
- }
-
- /**
- * Search similar documents for specified agent (with filter conditions)
- * @param agentId agent ID
- * @param query query text
- * @param topK number of results to return
- * @param vectorType vector type filter
- * @return list of similar documents
- */
- public List similaritySearchWithFilter(String agentId, String query, int topK, String vectorType) {
- SimpleVectorStore store = agentStores.get(agentId);
- if (store == null) {
- log.warn("No vector store found for agent: {}", agentId);
- return Collections.emptyList();
- }
-
- FilterExpressionBuilder builder = new FilterExpressionBuilder();
- Filter.Expression expression = builder.eq("vectorType", vectorType).build();
-
- List results = store.similaritySearch(org.springframework.ai.vectorstore.SearchRequest.builder()
- .query(query)
- .topK(topK)
- .filterExpression(expression)
- .build());
-
- log.debug("Found {} filtered documents for agent: {} with query: {} and vectorType: {}", results.size(),
- agentId, query, vectorType);
- return results;
- }
-
- /**
- * Delete all data of specified agent
- * @param agentId agent ID
- */
- public void deleteAgentData(String agentId) {
- SimpleVectorStore removed = agentStores.remove(agentId);
- if (removed != null) {
- log.info("Deleted all vector data for agent: {}", agentId);
- }
- else {
- log.warn("No vector store found to delete for agent: {}", agentId);
- }
- }
-
- /**
- * Delete specific documents of specified agent
- * @param agentId agent ID
- * @param documentIds list of document IDs to delete
- */
- public void deleteDocuments(String agentId, List documentIds) {
- SimpleVectorStore store = agentStores.get(agentId);
- if (store == null) {
- log.warn("No vector store found for agent: {}", agentId);
- return;
- }
-
- if (documentIds != null && !documentIds.isEmpty()) {
- store.delete(documentIds);
- log.info("Deleted {} documents from vector store for agent: {}", documentIds.size(), agentId);
- }
- }
-
- /**
- * Delete specific type documents of specified agent
- * @param agentId agent ID
- * @param vectorType vector type
- */
- public void deleteDocumentsByType(String agentId, String vectorType) {
- SimpleVectorStore store = agentStores.get(agentId);
- if (store == null) {
- log.warn("No vector store found for agent: {}", agentId);
- return;
- }
-
- try {
- FilterExpressionBuilder builder = new FilterExpressionBuilder();
- Filter.Expression expression = builder.eq("vectorType", vectorType).build();
-
- List documents = store.similaritySearch(org.springframework.ai.vectorstore.SearchRequest.builder()
- .query("")
- .topK(Integer.MAX_VALUE)
- .filterExpression(expression)
- .build());
-
- if (!documents.isEmpty()) {
- List documentIds = documents.stream().map(Document::getId).toList();
- store.delete(documentIds);
- log.info("Deleted {} documents of type '{}' for agent: {}", documents.size(), vectorType, agentId);
- }
- else {
- log.info("No documents of type '{}' found for agent: {}", vectorType, agentId);
- }
- }
- catch (Exception e) {
- log.error("Failed to delete documents by type for agent: {}", agentId, e);
- throw new RuntimeException("Failed to delete documents by type: " + e.getMessage(), e);
- }
- }
-
- /**
- * Check if agent has vector data
- * @param agentId agent ID
- * @return whether has data
- */
- public boolean hasAgentData(String agentId) {
- return agentStores.containsKey(agentId);
- }
-
- /**
- * Get document count of agent (estimated)
- * @param agentId agent ID
- * @return document count
- */
- public int getDocumentCount(String agentId) {
- SimpleVectorStore store = agentStores.get(agentId);
- if (store == null) {
- return 0;
- }
-
- try {
- // Estimate quantity by searching all documents
- List allDocs = store.similaritySearch(org.springframework.ai.vectorstore.SearchRequest.builder()
- .query("")
- .topK(Integer.MAX_VALUE)
- .build());
- return allDocs.size();
- }
- catch (Exception e) {
- log.warn("Failed to get document count for agent: {}", agentId, e);
- return 0;
- }
- }
-
- /**
- * Get all agent IDs with data
- * @return set of agent IDs
- */
- public Set getAllAgentIds() {
- return Set.copyOf(agentStores.keySet());
- }
-
- /**
- * Get vector storage statistics
- * @return statistics
- */
- public Map getStatistics() {
- Map stats = new ConcurrentHashMap<>();
- stats.put("totalAgents", agentStores.size());
- stats.put("agentIds", getAllAgentIds());
-
- Map agentDocCounts = new ConcurrentHashMap<>();
- agentStores.forEach((agentId, store) -> {
- agentDocCounts.put(agentId, getDocumentCount(agentId));
- });
- stats.put("documentCounts", agentDocCounts);
-
- return stats;
- }
-
-}
diff --git a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/AgentVectorStoreService.java b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/AgentVectorStoreService.java
new file mode 100644
index 0000000..5485890
--- /dev/null
+++ b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/AgentVectorStoreService.java
@@ -0,0 +1,51 @@
+/*
+ * Copyright 2024-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.cloud.ai.service.vectorstore;
+
+import java.util.List;
+import java.util.Map;
+
+import org.springframework.ai.document.Document;
+
+import com.alibaba.cloud.ai.request.AgentSearchRequest;
+import com.alibaba.cloud.ai.request.SchemaInitRequest;
+
+public interface AgentVectorStoreService {
+
+ /**
+ * 查询某个Agent的文档 总入口
+ */
+ List search(AgentSearchRequest searchRequest);
+
+ Boolean schema(String agentId, SchemaInitRequest schemaInitRequest) throws Exception;
+
+ Boolean deleteDocumentsByVectorType(String agentId, String vectorType) throws Exception;
+
+ Boolean deleteDocumentsByMetedata(String agentId, Map metadata) throws Exception;
+
+ /**
+ * Get documents for specified agent
+ */
+ List getDocumentsForAgent(String agentId, String query, String vectorType);
+
+ boolean hasDocuments(String agentId);
+
+ void addDocuments(String agentId, List documents);
+
+ int estimateDocuments(String agentId);
+
+}
diff --git a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/AgentVectorStoreServiceFactory.java b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/AgentVectorStoreServiceFactory.java
new file mode 100644
index 0000000..58756bb
--- /dev/null
+++ b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/AgentVectorStoreServiceFactory.java
@@ -0,0 +1,115 @@
+/*
+ * Copyright 2024-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.cloud.ai.service.vectorstore;
+
+import com.alibaba.cloud.ai.connector.accessor.AccessorFactory;
+import com.alibaba.cloud.ai.service.vectorstore.impls.AnalyticAgentVectorStoreService;
+import com.alibaba.cloud.ai.service.vectorstore.impls.SimpleAgentVectorStoreService;
+import com.alibaba.cloud.ai.vectorstore.analyticdb.AnalyticDbVectorStore;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.ai.embedding.EmbeddingModel;
+import org.springframework.ai.vectorstore.VectorStore;
+import org.springframework.beans.factory.FactoryBean;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.beans.factory.annotation.Value;
+import org.springframework.stereotype.Component;
+
+import java.util.HashMap;
+import java.util.Map;
+
+@Slf4j
+@Component
+public class AgentVectorStoreServiceFactory implements FactoryBean {
+
+ // 使用枚举配置向量存储类型,由用户配置决定实现类
+ @Value("${spring.ai.vectorstore.type:SIMPLE}")
+ private VectorStoreType vectorStoreType;
+
+ @Autowired
+ private EmbeddingModel embeddingModel;
+
+ @Autowired
+ private AccessorFactory accessorFactory;
+
+ // 通用的向量存储bean,由Spring AI自动配置
+ @Autowired(required = false)
+ private VectorStore vectorStore;
+
+ @FunctionalInterface
+ private interface AgentVectorStoreServiceCreator {
+
+ AgentVectorStoreService create();
+
+ }
+
+ /**
+ * 存储不同向量存储类型对应的创建策略 使用Map可以方便地扩展新的向量存储类型
+ */
+ private final Map serviceCreators = new HashMap<>();
+
+ public AgentVectorStoreServiceFactory() {
+ // 初始化各种向量存储类型的创建策略
+ // 这里使用显式的匿名类实现而不是方法引用,以提高代码可读性
+ serviceCreators.put(VectorStoreType.ANALYTIC_DB, this::createAnalyticAgentVectorStoreService);
+
+ serviceCreators.put(VectorStoreType.SIMPLE, this::createSimpleAgentVectorStoreService);
+
+ // TODO 后续其他向量存储类型扩展处
+
+ }
+
+ @Override
+ public AgentVectorStoreService getObject() {
+ // 根据配置的向量存储类型获取对应的创建策略
+ AgentVectorStoreServiceCreator creator = serviceCreators.get(vectorStoreType);
+ if (creator == null) {
+ log.warn("Unsupported vector store type: {}, falling back to SIMPLE", vectorStoreType);
+ creator = serviceCreators.get(VectorStoreType.SIMPLE);
+ }
+
+ // 使用选定的策略创建AgentVectorStoreService实例
+ return creator.create();
+ }
+
+ @Override
+ public Class> getObjectType() {
+ return AgentVectorStoreService.class;
+ }
+
+ private AgentVectorStoreService createAnalyticAgentVectorStoreService() {
+ if (vectorStore == null) {
+ throw new IllegalStateException(
+ "AnalyticDbVectorStore is not configured. Please check your configuration.");
+ }
+ log.info("Using AnalyticDbVectorStoreService");
+ if (!(vectorStore instanceof AnalyticDbVectorStore)) {
+ throw new IllegalStateException("VectorStore is not an instance of AnalyticDbVectorStore");
+ }
+ return new AnalyticAgentVectorStoreService(embeddingModel, (AnalyticDbVectorStore) vectorStore,
+ accessorFactory);
+ }
+
+ /**
+ * 创建简单内存存储的AgentVectorStoreService
+ * @return SimpleAgentVectorStoreService实例
+ */
+ private AgentVectorStoreService createSimpleAgentVectorStoreService() {
+ log.info("Using SimpleVectorStoreService");
+ return new SimpleAgentVectorStoreService(embeddingModel, accessorFactory);
+ }
+
+}
diff --git a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/VectorStoreService.java b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/VectorStoreService.java
deleted file mode 100644
index 422641f..0000000
--- a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/VectorStoreService.java
+++ /dev/null
@@ -1,66 +0,0 @@
-/*
- * Copyright 2024-2025 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.alibaba.cloud.ai.service.vectorstore;
-
-import com.alibaba.cloud.ai.request.SchemaInitRequest;
-import com.alibaba.cloud.ai.request.SearchRequest;
-import org.springframework.ai.document.Document;
-
-import java.util.List;
-
-public interface VectorStoreService {
-
- /**
- * Search interface with default filter
- */
- List searchWithVectorType(SearchRequest searchRequestDTO);
-
- /**
- * Search interface with custom filter
- */
- List searchWithFilter(SearchRequest searchRequestDTO);
-
- List getDocuments(String query, String vectorType);
-
- /**
- * Get documents for tables
- */
- default List searchTableByNameAndVectorType(SearchRequest searchRequestDTO) {
- throw new UnsupportedOperationException("Not implemented.");
- }
-
- /**
- * Get documents from vector store for specified agent
- */
- default List getDocumentsForAgent(String agentId, String query, String vectorType) {
- // Default implementation: if subclass doesn't override, use global search
- return getDocuments(query, vectorType);
- }
-
- default AgentVectorStoreManager getAgentVectorStoreManager() {
- throw new UnsupportedOperationException("Not implemented.");
- }
-
- default Boolean schemaForAgent(String agentId, SchemaInitRequest schemaInitRequest) throws Exception {
- throw new UnsupportedOperationException("Not implemented.");
- }
-
- default Boolean schema(SchemaInitRequest schemaInitRequest) throws Exception {
- throw new UnsupportedOperationException("Not implemented.");
- }
-
-}
diff --git a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/VectorStoreServiceFactory.java b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/VectorStoreServiceFactory.java
deleted file mode 100644
index d501219..0000000
--- a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/VectorStoreServiceFactory.java
+++ /dev/null
@@ -1,67 +0,0 @@
-/*
- * Copyright 2024-2025 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.alibaba.cloud.ai.service.vectorstore;
-
-import com.alibaba.cloud.ai.connector.accessor.AccessorFactory;
-import com.alibaba.cloud.ai.service.vectorstore.impls.AnalyticVectorStoreService;
-import com.alibaba.cloud.ai.service.vectorstore.impls.SimpleVectorStoreService;
-import com.alibaba.cloud.ai.vectorstore.analyticdb.AnalyticDbVectorStoreProperties;
-import com.aliyun.gpdb20160503.Client;
-import org.springframework.ai.embedding.EmbeddingModel;
-import org.springframework.beans.factory.FactoryBean;
-import org.springframework.beans.factory.annotation.Autowired;
-import org.springframework.beans.factory.annotation.Value;
-import org.springframework.stereotype.Component;
-
-@Component
-public class VectorStoreServiceFactory implements FactoryBean {
-
- // todo: 改为枚举,由用户配置决定实现类
- @Value("${spring.ai.vectorstore.analytic.enabled:false}")
- private Boolean analyticEnabled;
-
- @Autowired
- private EmbeddingModel embeddingModel;
-
- @Autowired(required = false)
- private AnalyticDbVectorStoreProperties analyticDbVectorStoreProperties;
-
- @Autowired(required = false)
- private Client client;
-
- @Autowired
- private AccessorFactory accessorFactory;
-
- @Autowired
- private AgentVectorStoreManager agentVectorStoreManager;
-
- @Override
- public VectorStoreService getObject() {
- if (Boolean.TRUE.equals(analyticEnabled)) {
- return new AnalyticVectorStoreService(analyticDbVectorStoreProperties, embeddingModel, client);
- }
- else {
- return new SimpleVectorStoreService(embeddingModel, accessorFactory, agentVectorStoreManager);
- }
- }
-
- @Override
- public Class> getObjectType() {
- return VectorStoreService.class;
- }
-
-}
diff --git a/spring-ai-alibaba-data-agent-management/src/main/java/com/alibaba/cloud/ai/service/VectorStoreManagementService.java b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/VectorStoreType.java
similarity index 57%
rename from spring-ai-alibaba-data-agent-management/src/main/java/com/alibaba/cloud/ai/service/VectorStoreManagementService.java
rename to spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/VectorStoreType.java
index f182e5e..9881152 100644
--- a/spring-ai-alibaba-data-agent-management/src/main/java/com/alibaba/cloud/ai/service/VectorStoreManagementService.java
+++ b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/VectorStoreType.java
@@ -13,20 +13,32 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package com.alibaba.cloud.ai.service;
-import com.alibaba.cloud.ai.request.DeleteRequest;
-import com.alibaba.cloud.ai.request.EvidenceRequest;
-import com.alibaba.cloud.ai.request.SchemaInitRequest;
+package com.alibaba.cloud.ai.service.vectorstore;
-import java.util.List;
+/**
+ * 向量存储类型枚举 用于配置使用哪种向量存储实现
+ */
+public enum VectorStoreType {
-public interface VectorStoreManagementService {
+ /**
+ * 简单向量存储(内存存储)
+ */
+ SIMPLE,
- Boolean addEvidence(List evidenceRequests);
+ /**
+ * 分析型数据库向量存储
+ */
+ ANALYTIC_DB,
- Boolean deleteDocuments(DeleteRequest deleteRequest) throws Exception;
+ /**
+ * Milvus 向量存储
+ */
+ MILVUS,
- Boolean schema(SchemaInitRequest schemaInitRequest) throws Exception;
+ /**
+ * PostgreSQL PGVector 向量存储
+ */
+ PGVECTOR
}
diff --git a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/impls/AnalyticAgentVectorStoreService.java b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/impls/AnalyticAgentVectorStoreService.java
new file mode 100644
index 0000000..fcf6051
--- /dev/null
+++ b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/impls/AnalyticAgentVectorStoreService.java
@@ -0,0 +1,49 @@
+/*
+ * Copyright 2024-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.cloud.ai.service.vectorstore.impls;
+
+import com.alibaba.cloud.ai.service.vectorstore.AbstractAgentVectorStoreService;
+import org.springframework.ai.embedding.EmbeddingModel;
+import org.springframework.ai.vectorstore.VectorStore;
+
+import com.alibaba.cloud.ai.connector.accessor.AccessorFactory;
+import com.alibaba.cloud.ai.vectorstore.analyticdb.AnalyticDbVectorStore;
+
+public class AnalyticAgentVectorStoreService extends AbstractAgentVectorStoreService {
+
+ private final EmbeddingModel embeddingModel;
+
+ private final AnalyticDbVectorStore analyticDbVectorStore;
+
+ public AnalyticAgentVectorStoreService(EmbeddingModel embeddingModel, AnalyticDbVectorStore vectorStore,
+ AccessorFactory accessorFactory) {
+ super(accessorFactory);
+ this.embeddingModel = embeddingModel;
+ this.analyticDbVectorStore = vectorStore;
+ }
+
+ @Override
+ protected EmbeddingModel getEmbeddingModel() {
+ return embeddingModel;
+ }
+
+ @Override
+ protected VectorStore getVectorStore() {
+ return analyticDbVectorStore;
+ }
+
+}
diff --git a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/impls/AnalyticVectorStoreService.java b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/impls/AnalyticVectorStoreService.java
deleted file mode 100644
index 70f282a..0000000
--- a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/impls/AnalyticVectorStoreService.java
+++ /dev/null
@@ -1,143 +0,0 @@
-/*
- * Copyright 2024-2025 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.alibaba.cloud.ai.service.vectorstore.impls;
-
-import com.alibaba.cloud.ai.request.SearchRequest;
-import com.alibaba.cloud.ai.service.vectorstore.AbstractVectorStoreService;
-import com.alibaba.cloud.ai.vectorstore.analyticdb.AnalyticDbVectorStoreProperties;
-import com.aliyun.gpdb20160503.Client;
-import com.aliyun.gpdb20160503.models.QueryCollectionDataRequest;
-import com.aliyun.gpdb20160503.models.QueryCollectionDataResponse;
-import com.fasterxml.jackson.core.type.TypeReference;
-import com.fasterxml.jackson.databind.ObjectMapper;
-import org.springframework.ai.document.Document;
-import org.springframework.ai.embedding.EmbeddingModel;
-
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.stream.Collectors;
-
-public class AnalyticVectorStoreService extends AbstractVectorStoreService {
-
- private static final String CONTENT_FIELD_NAME = "content";
-
- private static final String METADATA_FIELD_NAME = "metadata";
-
- private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
-
- private final EmbeddingModel embeddingModel;
-
- private final AnalyticDbVectorStoreProperties analyticDbVectorStoreProperties;
-
- private final Client client;
-
- public AnalyticVectorStoreService(AnalyticDbVectorStoreProperties analyticDbVectorStoreProperties,
- EmbeddingModel embeddingModel, Client client) {
- this.analyticDbVectorStoreProperties = analyticDbVectorStoreProperties;
- this.embeddingModel = embeddingModel;
- this.client = client;
- }
-
- @Override
- protected EmbeddingModel getEmbeddingModel() {
- return embeddingModel;
- }
-
- /**
- * Search interface with default filter
- */
- @Override
- public List searchWithVectorType(SearchRequest searchRequestDTO) {
- String filter = String.format("jsonb_extract_path_text(metadata, 'vectorType') = '%s'",
- searchRequestDTO.getVectorType());
-
- QueryCollectionDataRequest request = buildBaseRequest(searchRequestDTO).setFilter(filter);
-
- return executeQuery(request);
- }
-
- /**
- * Search interface with custom filter
- */
- @Override
- public List searchWithFilter(SearchRequest searchRequestDTO) {
- QueryCollectionDataRequest request = buildBaseRequest(searchRequestDTO)
- .setFilter(searchRequestDTO.getFilterFormatted());
- return executeQuery(request);
- }
-
- /**
- * Build basic query request object
- */
- private QueryCollectionDataRequest buildBaseRequest(SearchRequest searchRequestDTO) {
- QueryCollectionDataRequest queryCollectionDataRequest = new QueryCollectionDataRequest()
- .setDBInstanceId(analyticDbVectorStoreProperties.getDbInstanceId())
- .setRegionId(analyticDbVectorStoreProperties.getRegionId())
- .setNamespace(analyticDbVectorStoreProperties.getNamespace())
- .setNamespacePassword(analyticDbVectorStoreProperties.getNamespacePassword())
- .setCollection(analyticDbVectorStoreProperties.getCollectName())
- .setIncludeValues(false)
- .setMetrics(analyticDbVectorStoreProperties.getMetrics())
- .setTopK((long) searchRequestDTO.getTopK());
- if (searchRequestDTO.getQuery() != null) {
- queryCollectionDataRequest.setVector(embedDouble(searchRequestDTO.getQuery()));
- queryCollectionDataRequest.setContent(searchRequestDTO.getQuery());
- }
- return queryCollectionDataRequest;
- }
-
- /**
- * Execute actual query and parse results
- */
- private List executeQuery(QueryCollectionDataRequest request) {
- try {
- QueryCollectionDataResponse response = client.queryCollectionData(request);
- return parseDocuments(response);
- }
- catch (Exception e) {
- throw new RuntimeException("向量数据库查询失败: " + e.getMessage(), e);
- }
- }
-
- /**
- * Parse response data into Document list
- */
- private List parseDocuments(QueryCollectionDataResponse response) throws Exception {
- return response.getBody()
- .getMatches()
- .getMatch()
- .stream()
- .filter(match -> match.getScore() == null || match.getScore() > 0.1 || match.getScore() == 0.0)
- .map(match -> {
- Map metadata = match.getMetadata();
- try {
- Map metadataJson = OBJECT_MAPPER.readValue(metadata.get(METADATA_FIELD_NAME),
- new TypeReference>() {
- });
- metadataJson.put("score", match.getScore());
-
- return new Document(match.getId(), metadata.get(CONTENT_FIELD_NAME), metadataJson);
- }
- catch (Exception e) {
- throw new RuntimeException("解析元数据失败: " + e.getMessage(), e);
- }
- })
- .collect(Collectors.toList());
- }
-
-}
diff --git a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/impls/SimpleAgentVectorStoreService.java b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/impls/SimpleAgentVectorStoreService.java
new file mode 100644
index 0000000..da2d758
--- /dev/null
+++ b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/impls/SimpleAgentVectorStoreService.java
@@ -0,0 +1,47 @@
+/*
+ * Copyright 2024-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.alibaba.cloud.ai.service.vectorstore.impls;
+
+import org.springframework.ai.embedding.EmbeddingModel;
+import org.springframework.ai.vectorstore.SimpleVectorStore;
+import org.springframework.ai.vectorstore.VectorStore;
+
+import com.alibaba.cloud.ai.connector.accessor.AccessorFactory;
+import com.alibaba.cloud.ai.service.vectorstore.AbstractAgentVectorStoreService;
+
+public class SimpleAgentVectorStoreService extends AbstractAgentVectorStoreService {
+
+ private final SimpleVectorStore vectorStore;
+
+ private final EmbeddingModel embeddingModel;
+
+ public SimpleAgentVectorStoreService(EmbeddingModel embeddingModel, AccessorFactory accessorFactory) {
+ super(accessorFactory);
+ this.embeddingModel = embeddingModel;
+ this.vectorStore = SimpleVectorStore.builder(embeddingModel).build();
+ }
+
+ @Override
+ protected EmbeddingModel getEmbeddingModel() {
+ return embeddingModel;
+ }
+
+ @Override
+ protected VectorStore getVectorStore() {
+ return this.vectorStore;
+ }
+
+}
diff --git a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/impls/SimpleVectorStoreService.java b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/impls/SimpleVectorStoreService.java
deleted file mode 100644
index 4c5e59f..0000000
--- a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/impls/SimpleVectorStoreService.java
+++ /dev/null
@@ -1,562 +0,0 @@
-/*
- * Copyright 2024-2025 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package com.alibaba.cloud.ai.service.vectorstore.impls;
-
-import com.alibaba.cloud.ai.connector.accessor.Accessor;
-import com.alibaba.cloud.ai.connector.accessor.AccessorFactory;
-import com.alibaba.cloud.ai.connector.bo.ColumnInfoBO;
-import com.alibaba.cloud.ai.connector.bo.DbQueryParameter;
-import com.alibaba.cloud.ai.connector.bo.ForeignKeyInfoBO;
-import com.alibaba.cloud.ai.connector.bo.TableInfoBO;
-import com.alibaba.cloud.ai.connector.config.DbConfig;
-import com.alibaba.cloud.ai.request.DeleteRequest;
-import com.alibaba.cloud.ai.request.SchemaInitRequest;
-import com.alibaba.cloud.ai.request.SearchRequest;
-import com.alibaba.cloud.ai.service.vectorstore.AgentVectorStoreManager;
-import com.alibaba.cloud.ai.service.vectorstore.AbstractVectorStoreService;
-import com.alibaba.cloud.ai.util.JsonUtil;
-import com.fasterxml.jackson.core.JsonProcessingException;
-import com.fasterxml.jackson.databind.ObjectMapper;
-import org.apache.commons.collections.CollectionUtils;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-import org.springframework.ai.document.Document;
-import org.springframework.ai.embedding.EmbeddingModel;
-import org.springframework.ai.vectorstore.SimpleVectorStore;
-import org.springframework.ai.vectorstore.filter.Filter;
-import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder;
-
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.Objects;
-import java.util.Optional;
-import java.util.stream.Collectors;
-
-public class SimpleVectorStoreService extends AbstractVectorStoreService {
-
- private static final Logger log = LoggerFactory.getLogger(SimpleVectorStoreService.class);
-
- private final SimpleVectorStore vectorStore; // Keep original global storage for
- // backward compatibility
-
- private final AgentVectorStoreManager agentVectorStoreManager; // New agent vector
- // storage manager
-
- private final ObjectMapper objectMapper;
-
- // todo: 根据数据源动态获取Accessor
- private final Accessor dbAccessor;
-
- private final EmbeddingModel embeddingModel;
-
- public SimpleVectorStoreService(EmbeddingModel embeddingModel, AccessorFactory accessorFactory,
- AgentVectorStoreManager agentVectorStoreManager) {
- this.objectMapper = JsonUtil.getObjectMapper();
- this.dbAccessor = accessorFactory.getAccessorByDbConfig(null);
- this.embeddingModel = embeddingModel;
- this.agentVectorStoreManager = agentVectorStoreManager;
- this.vectorStore = SimpleVectorStore.builder(embeddingModel).build();
- }
-
- @Override
- protected EmbeddingModel getEmbeddingModel() {
- return embeddingModel;
- }
-
- /**
- * Initialize database schema to vector store
- * @param schemaInitRequest schema initialization request
- * @throws Exception if an error occurs
- */
- @Override
- public Boolean schema(SchemaInitRequest schemaInitRequest) throws Exception {
- log.info("Starting schema initialization for database: {}, schema: {}, tables: {}",
- schemaInitRequest.getDbConfig().getUrl(), schemaInitRequest.getDbConfig().getSchema(),
- schemaInitRequest.getTables());
-
- DbConfig dbConfig = schemaInitRequest.getDbConfig();
- DbQueryParameter dqp = DbQueryParameter.from(dbConfig)
- .setSchema(dbConfig.getSchema())
- .setTables(schemaInitRequest.getTables());
-
- // Clean up old schema data
- DeleteRequest deleteRequest = new DeleteRequest();
- deleteRequest.setVectorType("column");
- deleteDocuments(deleteRequest);
- deleteRequest.setVectorType("table");
- deleteDocuments(deleteRequest);
-
- log.debug("Fetching foreign keys from database");
- List foreignKeyInfoBOS = dbAccessor.showForeignKeys(dbConfig, dqp);
- log.debug("Found {} foreign keys", foreignKeyInfoBOS.size());
- Map> foreignKeyMap = buildForeignKeyMap(foreignKeyInfoBOS);
-
- log.debug("Fetching tables from database");
- List tableInfoBOS = dbAccessor.fetchTables(dbConfig, dqp);
- log.info("Found {} tables to process", tableInfoBOS.size());
-
- for (TableInfoBO tableInfoBO : tableInfoBOS) {
- log.debug("Processing table: {}", tableInfoBO.getName());
- processTable(tableInfoBO, dqp, dbConfig, foreignKeyMap);
- }
-
- log.debug("Converting columns to documents");
- List columnDocuments = tableInfoBOS.stream().flatMap(table -> {
- try {
- dqp.setTable(table.getName());
- return dbAccessor.showColumns(dbConfig, dqp).stream().map(column -> convertToDocument(table, column));
- }
- catch (Exception e) {
- log.error("Error processing columns for table: {}", table.getName(), e);
- throw new RuntimeException(e);
- }
- }).collect(Collectors.toList());
-
- log.info("Adding {} column documents to vector store", columnDocuments.size());
- vectorStore.add(columnDocuments);
-
- log.debug("Converting tables to documents");
- List tableDocuments = tableInfoBOS.stream()
- .map(this::convertTableToDocument)
- .collect(Collectors.toList());
-
- log.info("Adding {} table documents to vector store", tableDocuments.size());
- vectorStore.add(tableDocuments);
-
- log.info("Schema initialization completed successfully. Total documents added: {}",
- columnDocuments.size() + tableDocuments.size());
- return true;
- }
-
- private void processTable(TableInfoBO tableInfoBO, DbQueryParameter dqp, DbConfig dbConfig,
- Map> foreignKeyMap) throws Exception {
- dqp.setTable(tableInfoBO.getName());
- List columnInfoBOS = dbAccessor.showColumns(dbConfig, dqp);
- for (ColumnInfoBO columnInfoBO : columnInfoBOS) {
- dqp.setColumn(columnInfoBO.getName());
- List sampleColumn = dbAccessor.sampleColumn(dbConfig, dqp);
- sampleColumn = Optional.ofNullable(sampleColumn)
- .orElse(new ArrayList<>())
- .stream()
- .filter(Objects::nonNull)
- .distinct()
- .limit(3)
- .filter(s -> s.length() <= 100)
- .toList();
-
- columnInfoBO.setTableName(tableInfoBO.getName());
- try {
- columnInfoBO.setSamples(objectMapper.writeValueAsString(sampleColumn));
- }
- catch (JsonProcessingException e) {
- columnInfoBO.setSamples("[]");
- }
- }
-
- List targetPrimaryList = columnInfoBOS.stream()
- .filter(ColumnInfoBO::isPrimary)
- .collect(Collectors.toList());
- if (CollectionUtils.isNotEmpty(targetPrimaryList)) {
- List columnNames = targetPrimaryList.stream()
- .map(ColumnInfoBO::getName)
- .collect(Collectors.toList());
- tableInfoBO.setPrimaryKeys(columnNames);
- }
- else {
- tableInfoBO.setPrimaryKeys(new ArrayList<>());
- }
- tableInfoBO
- .setForeignKey(String.join("、", foreignKeyMap.getOrDefault(tableInfoBO.getName(), new ArrayList<>())));
- }
-
- public Document convertToDocument(TableInfoBO tableInfoBO, ColumnInfoBO columnInfoBO) {
- log.debug("Converting column to document: table={}, column={}", tableInfoBO.getName(), columnInfoBO.getName());
-
- String text = Optional.ofNullable(columnInfoBO.getDescription()).orElse(columnInfoBO.getName());
- String id = tableInfoBO.getName() + "." + columnInfoBO.getName();
- Map metadata = new HashMap<>();
- metadata.put("id", id);
- metadata.put("name", columnInfoBO.getName());
- metadata.put("tableName", tableInfoBO.getName());
- metadata.put("description", Optional.ofNullable(columnInfoBO.getDescription()).orElse(""));
- metadata.put("type", columnInfoBO.getType());
- metadata.put("primary", columnInfoBO.isPrimary());
- metadata.put("notnull", columnInfoBO.isNotnull());
- metadata.put("vectorType", "column");
- if (columnInfoBO.getSamples() != null) {
- metadata.put("samples", columnInfoBO.getSamples());
- }
- // Multi-table duplicate field data will be deduplicated, using table name + field
- // name as unique identifier
- Document document = new Document(id, text, metadata);
- log.debug("Created column document with ID: {}", id);
- return document;
- }
-
- public Document convertTableToDocument(TableInfoBO tableInfoBO) {
- log.debug("Converting table to document: {}", tableInfoBO.getName());
-
- String text = Optional.ofNullable(tableInfoBO.getDescription()).orElse(tableInfoBO.getName());
- Map metadata = new HashMap<>();
- metadata.put("schema", Optional.ofNullable(tableInfoBO.getSchema()).orElse(""));
- metadata.put("name", tableInfoBO.getName());
- metadata.put("description", Optional.ofNullable(tableInfoBO.getDescription()).orElse(""));
- metadata.put("foreignKey", Optional.ofNullable(tableInfoBO.getForeignKey()).orElse(""));
- metadata.put("primaryKey", Optional.ofNullable(tableInfoBO.getPrimaryKeys()).orElse(new ArrayList<>()));
- metadata.put("vectorType", "table");
- Document document = new Document(tableInfoBO.getName(), text, metadata);
- log.debug("Created table document with ID: {}", tableInfoBO.getName());
- return document;
- }
-
- private Map> buildForeignKeyMap(List foreignKeyInfoBOS) {
- Map> foreignKeyMap = new HashMap<>();
- for (ForeignKeyInfoBO fk : foreignKeyInfoBOS) {
- String key = fk.getTable() + "." + fk.getColumn() + "=" + fk.getReferencedTable() + "."
- + fk.getReferencedColumn();
-
- foreignKeyMap.computeIfAbsent(fk.getTable(), k -> new ArrayList<>()).add(key);
- foreignKeyMap.computeIfAbsent(fk.getReferencedTable(), k -> new ArrayList<>()).add(key);
- }
- return foreignKeyMap;
- }
-
- /**
- * Delete vector data with specified conditions
- * @param deleteRequest delete request
- * @return whether deletion succeeded
- */
- public Boolean deleteDocuments(DeleteRequest deleteRequest) throws Exception {
- log.info("Starting delete operation with request: id={}, vectorType={}", deleteRequest.getId(),
- deleteRequest.getVectorType());
-
- try {
- if (deleteRequest.getId() != null && !deleteRequest.getId().isEmpty()) {
- log.debug("Deleting documents by ID: {}", deleteRequest.getId());
- vectorStore.delete(Arrays.asList(deleteRequest.getId()));
- log.info("Successfully deleted documents by ID");
- }
- else if (deleteRequest.getVectorType() != null && !deleteRequest.getVectorType().isEmpty()) {
- log.debug("Deleting documents by vectorType: {}", deleteRequest.getVectorType());
- FilterExpressionBuilder b = new FilterExpressionBuilder();
- Filter.Expression expression = b.eq("vectorType", deleteRequest.getVectorType()).build();
- List documents = vectorStore
- .similaritySearch(org.springframework.ai.vectorstore.SearchRequest.builder()
- .topK(Integer.MAX_VALUE)
- .filterExpression(expression)
- .build());
- if (documents != null && !documents.isEmpty()) {
- log.info("Found {} documents to delete with vectorType: {}", documents.size(),
- deleteRequest.getVectorType());
- vectorStore.delete(documents.stream().map(Document::getId).toList());
- log.info("Successfully deleted {} documents", documents.size());
- }
- else {
- log.info("No documents found to delete with vectorType: {}", deleteRequest.getVectorType());
- }
- }
- else {
- log.warn("Invalid delete request: either id or vectorType must be specified");
- throw new IllegalArgumentException("Either id or vectorType must be specified.");
- }
- return true;
- }
- catch (Exception e) {
- log.error("Failed to delete documents: {}", e.getMessage(), e);
- throw new Exception("Failed to delete collection data by filterExpression: " + e.getMessage(), e);
- }
- }
-
- /**
- * Search interface with default filter
- */
- @Override
- public List searchWithVectorType(SearchRequest searchRequestDTO) {
- log.debug("Searching with vectorType: {}, query: {}, topK: {}", searchRequestDTO.getVectorType(),
- searchRequestDTO.getQuery(), searchRequestDTO.getTopK());
-
- FilterExpressionBuilder b = new FilterExpressionBuilder();
- Filter.Expression expression = b.eq("vectorType", searchRequestDTO.getVectorType()).build();
-
- List results = vectorStore.similaritySearch(org.springframework.ai.vectorstore.SearchRequest.builder()
- .query(searchRequestDTO.getQuery())
- .topK(searchRequestDTO.getTopK())
- .filterExpression(expression)
- .build());
-
- if (results == null) {
- results = new ArrayList<>();
- }
-
- log.info("Search completed. Found {} documents for vectorType: {}", results.size(),
- searchRequestDTO.getVectorType());
- return results;
- }
-
- /**
- * Search interface with custom filter
- */
- @Override
- public List searchWithFilter(SearchRequest searchRequestDTO) {
- log.debug("Searching with custom filter: vectorType={}, query={}, topK={}", searchRequestDTO.getVectorType(),
- searchRequestDTO.getQuery(), searchRequestDTO.getTopK());
-
- // Need to parse filterFormatted field according to actual situation here, convert
- // to FilterExpressionBuilder expression
- // Simplified implementation, for demonstration only
- FilterExpressionBuilder b = new FilterExpressionBuilder();
- Filter.Expression expression = b.eq("vectorType", searchRequestDTO.getVectorType()).build();
-
- List results = vectorStore.similaritySearch(org.springframework.ai.vectorstore.SearchRequest.builder()
- .query(searchRequestDTO.getQuery())
- .topK(searchRequestDTO.getTopK())
- .filterExpression(expression)
- .build());
-
- if (results == null) {
- results = new ArrayList<>();
- }
-
- log.info("Search with filter completed. Found {} documents", results.size());
- return results;
- }
-
- @Override
- public List searchTableByNameAndVectorType(SearchRequest searchRequestDTO) {
- log.debug("Searching table by name and vectorType: name={}, vectorType={}, topK={}", searchRequestDTO.getName(),
- searchRequestDTO.getVectorType(), searchRequestDTO.getTopK());
-
- FilterExpressionBuilder b = new FilterExpressionBuilder();
- Filter.Expression expression = b
- .and(b.eq("vectorType", searchRequestDTO.getVectorType()), b.eq("id", searchRequestDTO.getName()))
- .build();
-
- List results = vectorStore.similaritySearch(org.springframework.ai.vectorstore.SearchRequest.builder()
- .topK(searchRequestDTO.getTopK())
- .filterExpression(expression)
- .build());
-
- if (results == null) {
- results = new ArrayList<>();
- }
-
- log.info("Search by name completed. Found {} documents for name: {}", results.size(),
- searchRequestDTO.getName());
- return results;
- }
-
- // ==================== 智能体相关的新方法 ====================
-
- /**
- * Initialize database schema to vector store for specified agent
- * @param agentId agent ID
- * @param schemaInitRequest schema initialization request
- * @throws Exception if an error occurs
- */
- @Override
- public Boolean schemaForAgent(String agentId, SchemaInitRequest schemaInitRequest) throws Exception {
- log.info("Starting schema initialization for agent: {}, database: {}, schema: {}, tables: {}", agentId,
- schemaInitRequest.getDbConfig().getUrl(), schemaInitRequest.getDbConfig().getSchema(),
- schemaInitRequest.getTables());
-
- DbConfig dbConfig = schemaInitRequest.getDbConfig();
- DbQueryParameter dqp = DbQueryParameter.from(dbConfig)
- .setSchema(dbConfig.getSchema())
- .setTables(schemaInitRequest.getTables());
-
- // Clean up agent's old data
- agentVectorStoreManager.deleteDocumentsByType(agentId, "column");
- agentVectorStoreManager.deleteDocumentsByType(agentId, "table");
-
- log.debug("Fetching foreign keys from database for agent: {}", agentId);
- List foreignKeyInfoBOS = dbAccessor.showForeignKeys(dbConfig, dqp);
- log.debug("Found {} foreign keys for agent: {}", foreignKeyInfoBOS.size(), agentId);
- Map> foreignKeyMap = buildForeignKeyMap(foreignKeyInfoBOS);
-
- log.debug("Fetching tables from database for agent: {}", agentId);
- List tableInfoBOS = dbAccessor.fetchTables(dbConfig, dqp);
- log.info("Found {} tables to process for agent: {}", tableInfoBOS.size(), agentId);
-
- for (TableInfoBO tableInfoBO : tableInfoBOS) {
- log.debug("Processing table: {} for agent: {}", tableInfoBO.getName(), agentId);
- processTable(tableInfoBO, dqp, dbConfig, foreignKeyMap);
- }
-
- log.debug("Converting columns to documents for agent: {}", agentId);
- List columnDocuments = tableInfoBOS.stream().flatMap(table -> {
- try {
- dqp.setTable(table.getName());
- return dbAccessor.showColumns(dbConfig, dqp)
- .stream()
- .map(column -> convertToDocumentForAgent(agentId, table, column));
- }
- catch (Exception e) {
- log.error("Error processing columns for table: {} and agent: {}", table.getName(), agentId, e);
- throw new RuntimeException(e);
- }
- }).collect(Collectors.toList());
-
- log.info("Adding {} column documents to vector store for agent: {}", columnDocuments.size(), agentId);
- agentVectorStoreManager.addDocuments(agentId, columnDocuments);
-
- log.debug("Converting tables to documents for agent: {}", agentId);
- List tableDocuments = tableInfoBOS.stream()
- .map(table -> convertTableToDocumentForAgent(agentId, table))
- .collect(Collectors.toList());
-
- log.info("Adding {} table documents to vector store for agent: {}", tableDocuments.size(), agentId);
- agentVectorStoreManager.addDocuments(agentId, tableDocuments);
-
- log.info("Schema initialization completed successfully for agent: {}. Total documents added: {}", agentId,
- columnDocuments.size() + tableDocuments.size());
- return true;
- }
-
- /**
- * Convert column information to documents for agent
- */
- private Document convertToDocumentForAgent(String agentId, TableInfoBO tableInfoBO, ColumnInfoBO columnInfoBO) {
- log.debug("Converting column to document for agent: {}, table={}, column={}", agentId, tableInfoBO.getName(),
- columnInfoBO.getName());
-
- String text = Optional.ofNullable(columnInfoBO.getDescription()).orElse(columnInfoBO.getName());
- String id = agentId + ":" + tableInfoBO.getName() + "." + columnInfoBO.getName();
- Map metadata = new HashMap<>();
- metadata.put("id", id);
- metadata.put("agentId", agentId);
- metadata.put("name", columnInfoBO.getName());
- metadata.put("tableName", tableInfoBO.getName());
- metadata.put("description", Optional.ofNullable(columnInfoBO.getDescription()).orElse(""));
- metadata.put("type", columnInfoBO.getType());
- metadata.put("primary", columnInfoBO.isPrimary());
- metadata.put("notnull", columnInfoBO.isNotnull());
- metadata.put("vectorType", "column");
- if (columnInfoBO.getSamples() != null) {
- metadata.put("samples", columnInfoBO.getSamples());
- }
-
- Document document = new Document(id, text, metadata);
- log.debug("Created column document with ID: {} for agent: {}", id, agentId);
- return document;
- }
-
- /**
- * Convert table information to documents for agent
- */
- private Document convertTableToDocumentForAgent(String agentId, TableInfoBO tableInfoBO) {
- log.debug("Converting table to document for agent: {}, table: {}", agentId, tableInfoBO.getName());
-
- String text = Optional.ofNullable(tableInfoBO.getDescription()).orElse(tableInfoBO.getName());
- String id = agentId + ":" + tableInfoBO.getName();
- Map metadata = new HashMap<>();
- metadata.put("agentId", agentId);
- metadata.put("schema", Optional.ofNullable(tableInfoBO.getSchema()).orElse(""));
- metadata.put("name", tableInfoBO.getName());
- metadata.put("description", Optional.ofNullable(tableInfoBO.getDescription()).orElse(""));
- metadata.put("foreignKey", Optional.ofNullable(tableInfoBO.getForeignKey()).orElse(""));
- metadata.put("primaryKey", Optional.ofNullable(tableInfoBO.getPrimaryKeys()).orElse(new ArrayList<>()));
- metadata.put("vectorType", "table");
-
- Document document = new Document(id, text, metadata);
- log.debug("Created table document with ID: {} for agent: {}", id, agentId);
- return document;
- }
-
- /**
- * Search vector data for specified agent
- */
- public List searchWithVectorTypeForAgent(String agentId, SearchRequest searchRequestDTO) {
- log.debug("Searching for agent: {}, vectorType: {}, query: {}, topK: {}", agentId,
- searchRequestDTO.getVectorType(), searchRequestDTO.getQuery(), searchRequestDTO.getTopK());
-
- List results = agentVectorStoreManager.similaritySearchWithFilter(agentId,
- searchRequestDTO.getQuery(), searchRequestDTO.getTopK(), searchRequestDTO.getVectorType());
-
- log.info("Search completed for agent: {}. Found {} documents for vectorType: {}", agentId, results.size(),
- searchRequestDTO.getVectorType());
- return results;
- }
-
- /**
- * Delete vector data for specified agent
- */
- public Boolean deleteDocumentsForAgent(String agentId, DeleteRequest deleteRequest) throws Exception {
- log.info("Starting delete operation for agent: {}, id={}, vectorType={}", agentId, deleteRequest.getId(),
- deleteRequest.getVectorType());
-
- try {
- if (deleteRequest.getId() != null && !deleteRequest.getId().isEmpty()) {
- log.debug("Deleting documents by ID for agent: {}, ID: {}", agentId, deleteRequest.getId());
- agentVectorStoreManager.deleteDocuments(agentId, Arrays.asList(deleteRequest.getId()));
- log.info("Successfully deleted documents by ID for agent: {}", agentId);
- }
- else if (deleteRequest.getVectorType() != null && !deleteRequest.getVectorType().isEmpty()) {
- log.debug("Deleting documents by vectorType for agent: {}, vectorType: {}", agentId,
- deleteRequest.getVectorType());
- agentVectorStoreManager.deleteDocumentsByType(agentId, deleteRequest.getVectorType());
- log.info("Successfully deleted documents by vectorType for agent: {}", agentId);
- }
- else {
- log.warn("Invalid delete request for agent: {}: either id or vectorType must be specified", agentId);
- throw new IllegalArgumentException("Either id or vectorType must be specified.");
- }
- return true;
- }
- catch (Exception e) {
- log.error("Failed to delete documents for agent: {}: {}", agentId, e.getMessage(), e);
- throw new Exception("Failed to delete collection data for agent " + agentId + ": " + e.getMessage(), e);
- }
- }
-
- /**
- * Get agent vector storage manager (for other services to use)
- */
- @Override
- public AgentVectorStoreManager getAgentVectorStoreManager() {
- return agentVectorStoreManager;
- }
-
- /**
- * Get documents from vector store for specified agent Override parent method, use
- * agent-specific vector storage
- */
- @Override
- public List getDocumentsForAgent(String agentId, String query, String vectorType) {
- log.debug("Getting documents for agent: {}, query: {}, vectorType: {}", agentId, query, vectorType);
-
- if (agentId == null || agentId.trim().isEmpty()) {
- log.warn("AgentId is null or empty, falling back to global search");
- return getDocuments(query, vectorType);
- }
-
- try {
- // Use agent vector storage manager for search
- List results = agentVectorStoreManager.similaritySearchWithFilter(agentId, query, 100, // topK
- vectorType);
-
- log.info("Found {} documents for agent: {}, vectorType: {}", results.size(), agentId, vectorType);
- return results;
- }
- catch (Exception e) {
- log.error("Error getting documents for agent: {}, falling back to global search", agentId, e);
- return getDocuments(query, vectorType);
- }
- }
-
-}
diff --git a/spring-ai-alibaba-data-agent-management/src/main/java/com/alibaba/cloud/ai/util/DocumentConverterUtil.java b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/util/DocumentConverterUtil.java
similarity index 74%
rename from spring-ai-alibaba-data-agent-management/src/main/java/com/alibaba/cloud/ai/util/DocumentConverterUtil.java
rename to spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/util/DocumentConverterUtil.java
index d16d8db..fc1e781 100644
--- a/spring-ai-alibaba-data-agent-management/src/main/java/com/alibaba/cloud/ai/util/DocumentConverterUtil.java
+++ b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/util/DocumentConverterUtil.java
@@ -15,12 +15,14 @@
*/
package com.alibaba.cloud.ai.util;
+import java.util.*;
+import java.util.stream.Collectors;
+
+import org.springframework.ai.document.Document;
+
import com.alibaba.cloud.ai.connector.bo.ColumnInfoBO;
import com.alibaba.cloud.ai.connector.bo.TableInfoBO;
import com.alibaba.cloud.ai.request.EvidenceRequest;
-import org.springframework.ai.document.Document;
-
-import java.util.*;
/**
* Utility class for converting business objects to Document objects. Provides common
@@ -28,13 +30,28 @@
*/
public class DocumentConverterUtil {
+ public static List convertColumnsToDocuments(String agentId, List tables) {
+ List documents = new ArrayList<>();
+ for (TableInfoBO table : tables) {
+ // 使用已经处理过的列数据,避免重复查询
+ List columns = table.getColumns();
+ if (columns != null) {
+ for (ColumnInfoBO column : columns) {
+ documents.add(DocumentConverterUtil.convertColumnToDocumentForAgent(agentId, table, column));
+ }
+ }
+ }
+ return documents;
+ }
+
/**
* Converts a column info object to a Document for vector storage.
* @param tableInfoBO the table information containing schema details
* @param columnInfoBO the column information to convert
* @return Document object with column metadata
*/
- public static Document convertColumnToDocument(TableInfoBO tableInfoBO, ColumnInfoBO columnInfoBO) {
+ public static Document convertColumnToDocumentForAgent(String agentId, TableInfoBO tableInfoBO,
+ ColumnInfoBO columnInfoBO) {
String text = Optional.ofNullable(columnInfoBO.getDescription()).orElse(columnInfoBO.getName());
Map metadata = new HashMap<>();
metadata.put("name", columnInfoBO.getName());
@@ -44,6 +61,7 @@ public static Document convertColumnToDocument(TableInfoBO tableInfoBO, ColumnIn
metadata.put("primary", columnInfoBO.isPrimary());
metadata.put("notnull", columnInfoBO.isNotnull());
metadata.put("vectorType", "column");
+ metadata.put("agentId", agentId);
if (columnInfoBO.getSamples() != null) {
metadata.put("samples", columnInfoBO.getSamples());
@@ -57,7 +75,7 @@ public static Document convertColumnToDocument(TableInfoBO tableInfoBO, ColumnIn
* @param tableInfoBO the table information to convert
* @return Document object with table metadata
*/
- public static Document convertTableToDocument(TableInfoBO tableInfoBO) {
+ public static Document convertTableToDocumentForAgent(String agentId, TableInfoBO tableInfoBO) {
String text = Optional.ofNullable(tableInfoBO.getDescription()).orElse(tableInfoBO.getName());
Map metadata = new HashMap<>();
metadata.put("schema", Optional.ofNullable(tableInfoBO.getSchema()).orElse(""));
@@ -66,21 +84,29 @@ public static Document convertTableToDocument(TableInfoBO tableInfoBO) {
metadata.put("foreignKey", Optional.ofNullable(tableInfoBO.getForeignKey()).orElse(""));
metadata.put("primaryKey", Optional.ofNullable(tableInfoBO.getPrimaryKeys()).orElse(new ArrayList<>()));
metadata.put("vectorType", "table");
-
+ metadata.put("agentId", agentId);
return new Document(tableInfoBO.getName(), text, metadata);
}
+ public static List convertTablesToDocuments(String agentId, List tables) {
+ return tables.stream()
+ .map(table -> DocumentConverterUtil.convertTableToDocumentForAgent(agentId, table))
+ .collect(Collectors.toList());
+ }
+
/**
* Converts evidence requests to Documents for vector storage.
* @param evidenceRequests list of evidence requests
* @return list of Document objects
*/
- public static List convertEvidenceToDocuments(List evidenceRequests) {
+ public static List convertEvidenceToDocumentsForAgent(String agentId,
+ List evidenceRequests) {
return evidenceRequests.stream().map(evidenceRequest -> {
Map metadata = new HashMap<>();
metadata.put("evidenceType", evidenceRequest.getType());
metadata.put("vectorType", "evidence");
+ metadata.put("agentId", agentId);
return new Document(UUID.randomUUID().toString(), evidenceRequest.getContent(), metadata);
}).toList();
}
diff --git a/spring-ai-alibaba-data-agent-management/src/main/java/com/alibaba/cloud/ai/util/SchemaProcessorUtil.java b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/util/SchemaProcessorUtil.java
similarity index 75%
rename from spring-ai-alibaba-data-agent-management/src/main/java/com/alibaba/cloud/ai/util/SchemaProcessorUtil.java
rename to spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/util/SchemaProcessorUtil.java
index dd8a1f8..358e038 100644
--- a/spring-ai-alibaba-data-agent-management/src/main/java/com/alibaba/cloud/ai/util/SchemaProcessorUtil.java
+++ b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/util/SchemaProcessorUtil.java
@@ -15,6 +15,9 @@
*/
package com.alibaba.cloud.ai.util;
+import java.util.*;
+import java.util.stream.Collectors;
+
import com.alibaba.cloud.ai.connector.accessor.Accessor;
import com.alibaba.cloud.ai.connector.bo.ColumnInfoBO;
import com.alibaba.cloud.ai.connector.bo.DbQueryParameter;
@@ -22,10 +25,6 @@
import com.alibaba.cloud.ai.connector.config.DbConfig;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
-import org.apache.commons.collections.CollectionUtils;
-
-import java.util.*;
-import java.util.stream.Collectors;
/**
* Utility class for processing database schema information. Provides common schema
@@ -71,12 +70,10 @@ public static void enrichTableMetadata(TableInfoBO tableInfoBO, DbQueryParameter
// 保存处理过的列数据到TableInfoBO,供后续使用
tableInfoBO.setColumns(columnInfoBOS);
- List targetPrimaryList = columnInfoBOS.stream()
- .filter(ColumnInfoBO::isPrimary)
- .collect(Collectors.toList());
+ List primaryKeyColumns = columnInfoBOS.stream().filter(ColumnInfoBO::isPrimary).toList();
- if (CollectionUtils.isNotEmpty(targetPrimaryList)) {
- List columnNames = targetPrimaryList.stream()
+ if (!primaryKeyColumns.isEmpty()) {
+ List columnNames = primaryKeyColumns.stream()
.map(ColumnInfoBO::getName)
.collect(Collectors.toList());
tableInfoBO.setPrimaryKeys(columnNames);
@@ -89,6 +86,31 @@ public static void enrichTableMetadata(TableInfoBO tableInfoBO, DbQueryParameter
.setForeignKey(String.join("、", foreignKeyMap.getOrDefault(tableInfoBO.getName(), new ArrayList<>())));
}
+ public static DbConfig createDbConfigFromDatasource(com.alibaba.cloud.ai.entity.Datasource datasource) {
+ DbConfig dbConfig = new DbConfig();
+
+ // Set basic connection information
+ dbConfig.setUrl(datasource.getConnectionUrl());
+ dbConfig.setUsername(datasource.getUsername());
+ dbConfig.setPassword(datasource.getPassword());
+
+ // TODO Set database type need to be optimized
+ if ("mysql".equalsIgnoreCase(datasource.getType())) {
+ dbConfig.setConnectionType("jdbc");
+ dbConfig.setDialectType("mysql");
+ }
+ else if ("h2".equalsIgnoreCase(datasource.getType())) {
+ dbConfig.setConnectionType("jdbc");
+ dbConfig.setDialectType("h2");
+ }
+ // Support for other database types can be extended here
+
+ // Set Schema as the database name of the data source
+ dbConfig.setSchema(datasource.getDatabaseName());
+
+ return dbConfig;
+ }
+
/**
* Private constructor to prevent instantiation.
*/
diff --git a/spring-ai-alibaba-data-agent-common/src/main/java/com/alibaba/cloud/ai/request/AgentSearchRequest.java b/spring-ai-alibaba-data-agent-common/src/main/java/com/alibaba/cloud/ai/request/AgentSearchRequest.java
new file mode 100644
index 0000000..1e9066e
--- /dev/null
+++ b/spring-ai-alibaba-data-agent-common/src/main/java/com/alibaba/cloud/ai/request/AgentSearchRequest.java
@@ -0,0 +1,128 @@
+/*
+ * Copyright 2024-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.cloud.ai.request;
+
+import java.io.Serial;
+import java.io.Serializable;
+import java.util.Map;
+import java.util.Objects;
+
+public class AgentSearchRequest implements Serializable {
+
+ @Serial
+ private static final long serialVersionUID = 1L;
+
+ private final String agentId;
+
+ private String query;
+
+ private Integer topK;
+
+ private Map metadataFilter;
+
+ /**
+ * 私有构造函数,禁止直接实例化 此构造函数仅内部使用,外部代码必须通过getInstance()方法创建实例
+ */
+ private AgentSearchRequest(String agentId) {
+ Objects.requireNonNull(agentId, "Agent ID cannot be null");
+ this.agentId = agentId;
+ // 初始化metadataFilter,确保始终包含agentId
+ this.metadataFilter = Map.of("agentId", agentId);
+ }
+
+ /**
+ * 创建AgentSearchRequest实例的工厂方法
+ * @param agentId 代理ID,不能为空
+ * @return AgentSearchRequest实例
+ * @throws IllegalArgumentException 如果agentId为空
+ */
+ public static AgentSearchRequest getInstance(String agentId) {
+ return new AgentSearchRequest(agentId);
+ }
+
+ /**
+ * 创建AgentSearchRequest实例的工厂方法
+ * @param agentId 代理ID,不能为空
+ * @param query 查询内容
+ * @return AgentSearchRequest实例
+ * @throws IllegalArgumentException 如果agentId为空
+ */
+ public static AgentSearchRequest getInstance(String agentId, String query) {
+ AgentSearchRequest request = new AgentSearchRequest(agentId);
+ request.setQuery(query);
+ return request;
+ }
+
+ /**
+ * 创建AgentSearchRequest实例的工厂方法
+ * @param agentId 代理ID,不能为空
+ * @param query 查询内容
+ * @param topK 返回结果数量
+ * @return AgentSearchRequest实例
+ * @throws IllegalArgumentException 如果agentId为空
+ */
+ public static AgentSearchRequest getInstance(String agentId, String query, Integer topK) {
+ AgentSearchRequest request = new AgentSearchRequest(agentId);
+ request.setQuery(query);
+ request.setTopK(topK);
+ return request;
+ }
+
+ public String getAgentId() {
+ return agentId;
+ }
+
+ public Map getMetadataFilter() {
+ return metadataFilter;
+ }
+
+ public void setMetadataFilter(Map metadataFilter) {
+ if (metadataFilter == null) {
+ // 如果传入null,则创建只包含agentId的map
+ this.metadataFilter = Map.of("agentId", agentId);
+ }
+ else {
+ // 创建新的map,包含传入的所有参数和agentId
+ Map newFilter = new java.util.HashMap<>(metadataFilter);
+ newFilter.put("agentId", agentId);
+ this.metadataFilter = Map.copyOf(newFilter); // 创建不可变副本
+ }
+ }
+
+ public String getQuery() {
+ return query;
+ }
+
+ public void setQuery(String query) {
+ this.query = query;
+ }
+
+ public Integer getTopK() {
+ return topK;
+ }
+
+ public void setTopK(Integer topK) {
+ this.topK = topK;
+ }
+
+ @Override
+ public String toString() {
+ return "AgentSearchRequest{" + "agentId='" + agentId + '\'' + ", query='" + query + '\'' + ", topK=" + topK
+ + ", metadataFilter=" + metadataFilter + '}';
+ }
+
+}
diff --git a/spring-ai-alibaba-data-agent-management/src/main/java/com/alibaba/cloud/ai/controller/Nl2sqlForGraphController.java b/spring-ai-alibaba-data-agent-management/src/main/java/com/alibaba/cloud/ai/controller/Nl2sqlForGraphController.java
index f4cd255..56b72da 100644
--- a/spring-ai-alibaba-data-agent-management/src/main/java/com/alibaba/cloud/ai/controller/Nl2sqlForGraphController.java
+++ b/spring-ai-alibaba-data-agent-management/src/main/java/com/alibaba/cloud/ai/controller/Nl2sqlForGraphController.java
@@ -16,17 +16,14 @@
package com.alibaba.cloud.ai.controller;
-import com.alibaba.cloud.ai.connector.config.DbConfig;
import com.alibaba.cloud.ai.constant.Constant;
import com.alibaba.cloud.ai.graph.*;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.streaming.StreamingOutput;
import com.alibaba.cloud.ai.graph.NodeOutput;
import com.alibaba.cloud.ai.graph.state.StateSnapshot;
-import com.alibaba.cloud.ai.request.SchemaInitRequest;
-import com.alibaba.cloud.ai.service.vectorstore.VectorStoreService;
+import com.alibaba.cloud.ai.service.vectorstore.AgentVectorStoreService;
import com.alibaba.cloud.ai.service.DatasourceService;
-import com.alibaba.cloud.ai.entity.Datasource;
import com.alibaba.cloud.ai.service.AgentService;
import com.alibaba.cloud.ai.util.JsonUtil;
import com.fasterxml.jackson.databind.ObjectMapper;
@@ -43,12 +40,10 @@
import reactor.core.publisher.Flux;
import reactor.core.publisher.Sinks;
-import java.util.Arrays;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
-import static com.alibaba.cloud.ai.constant.Constant.AGENT_ID;
import static com.alibaba.cloud.ai.constant.Constant.HUMAN_FEEDBACK_NODE;
import static com.alibaba.cloud.ai.constant.Constant.INPUT_KEY;
import static com.alibaba.cloud.ai.constant.Constant.RESULT;
@@ -66,7 +61,7 @@ public class Nl2sqlForGraphController {
private final CompiledGraph compiledGraph;
- private final VectorStoreService vectorStoreService;
+ private final AgentVectorStoreService vectorStoreService;
private final DatasourceService datasourceService;
@@ -75,7 +70,7 @@ public class Nl2sqlForGraphController {
private final ObjectMapper objectMapper = JsonUtil.getObjectMapper();
public Nl2sqlForGraphController(@Qualifier("nl2sqlGraph") StateGraph stateGraph,
- VectorStoreService vectorStoreService, DatasourceService datasourceService, AgentService agentService)
+ AgentVectorStoreService vectorStoreService, DatasourceService datasourceService, AgentService agentService)
throws GraphStateException {
this.compiledGraph = stateGraph.compile(CompileConfig.builder().interruptBefore(HUMAN_FEEDBACK_NODE).build());
this.compiledGraph.setMaxIterations(100);
@@ -84,106 +79,6 @@ public Nl2sqlForGraphController(@Qualifier("nl2sqlGraph") StateGraph stateGraph,
this.agentService = agentService;
}
- @GetMapping("/search")
- public String search(
- @RequestParam(value = "query", required = false,
- defaultValue = "查询每个分类下已经成交且销量最高的商品及其销售总量,每个分类只返回销量最高的商品。") String query,
- @RequestParam(value = "dataSetId", required = false, defaultValue = "1") String dataSetId,
- @RequestParam(value = "agentId", required = false, defaultValue = "1") String agentId) throws Exception {
- // Get the data source configuration for an agent for vector initialization
- DbConfig dbConfig = getDbConfigForAgent(Integer.valueOf(agentId));
-
- SchemaInitRequest schemaInitRequest = new SchemaInitRequest();
- schemaInitRequest.setDbConfig(dbConfig);
- schemaInitRequest
- .setTables(Arrays.asList("categories", "order_items", "orders", "products", "users", "product_categories"));
- vectorStoreService.schema(schemaInitRequest);
-
- boolean humanReviewEnabled = false;
- try {
- var agent = agentService.findById(Long.valueOf(agentId));
- humanReviewEnabled = agent != null && agent.getHumanReviewEnabled() != null
- && agent.getHumanReviewEnabled() == 1;
- }
- catch (Exception ignore) {
- }
-
- Optional invoke = compiledGraph
- .call(Map.of(INPUT_KEY, query, AGENT_ID, agentId, HUMAN_REVIEW_ENABLED, humanReviewEnabled));
- OverAllState overAllState = invoke.get();
- // 注意:在新的人类反馈实现中,计划内容通过流式处理发送给前端
- // 这里不再需要单独获取计划内容
- return overAllState.value(RESULT).map(Object::toString).orElse("");
- }
-
- @GetMapping("/init")
- public void init(@RequestParam(value = "agentId", required = false, defaultValue = "1") Integer agentId)
- throws Exception {
- // Get the data source configuration for an agent for vector initialization
- DbConfig dbConfig = getDbConfigForAgent(agentId);
-
- SchemaInitRequest schemaInitRequest = new SchemaInitRequest();
- schemaInitRequest.setDbConfig(dbConfig);
- schemaInitRequest
- .setTables(Arrays.asList("categories", "order_items", "orders", "products", "users", "product_categories"));
- vectorStoreService.schema(schemaInitRequest);
- }
-
- /**
- * Get database configuration by agent ID
- */
- private DbConfig getDbConfigForAgent(Integer agentId) {
- try {
- // Get the enabled data source for an agent
- var agentDatasources = datasourceService.getAgentDatasources(agentId);
- var activeDatasource = agentDatasources.stream()
- .filter(ad -> ad.getIsActive() == 1)
- .findFirst()
- .orElseThrow(() -> new RuntimeException("智能体 " + agentId + " 未配置启用的数据源"));
-
- // Convert to DbConfig
- return createDbConfigFromDatasource(activeDatasource.getDatasource());
- }
- catch (Exception e) {
- logger.error("Failed to get agent datasource config for agent: {}", agentId, e);
- throw new RuntimeException("获取智能体数据源配置失败: " + e.getMessage(), e);
- }
- }
-
- /**
- * Create database configuration from data source entity
- */
- private DbConfig createDbConfigFromDatasource(Datasource datasource) {
- DbConfig dbConfig = new DbConfig();
-
- // Set basic connection information
- dbConfig.setUrl(datasource.getConnectionUrl());
- dbConfig.setUsername(datasource.getUsername());
- dbConfig.setPassword(datasource.getPassword());
-
- // Set database type
- if ("mysql".equalsIgnoreCase(datasource.getType())) {
- dbConfig.setConnectionType("jdbc");
- dbConfig.setDialectType("mysql");
- }
- else if ("postgresql".equalsIgnoreCase(datasource.getType())) {
- dbConfig.setConnectionType("jdbc");
- dbConfig.setDialectType("postgresql");
- }
- else if ("h2".equalsIgnoreCase(datasource.getType())) {
- dbConfig.setConnectionType("jdbc");
- dbConfig.setDialectType("h2");
- }
- else {
- throw new RuntimeException("不支持的数据库类型: " + datasource.getType());
- }
-
- // Set Schema to the database name of the data source
- dbConfig.setSchema(datasource.getDatabaseName());
-
- return dbConfig;
- }
-
@GetMapping(value = "/stream/search", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public Flux> streamSearch(@RequestParam(value = "query") String query,
@RequestParam(value = "agentId") String agentId,
diff --git a/spring-ai-alibaba-data-agent-management/src/main/java/com/alibaba/cloud/ai/service/AbstractVectorStoreManagementService.java b/spring-ai-alibaba-data-agent-management/src/main/java/com/alibaba/cloud/ai/service/AbstractVectorStoreManagementService.java
deleted file mode 100644
index 92ea915..0000000
--- a/spring-ai-alibaba-data-agent-management/src/main/java/com/alibaba/cloud/ai/service/AbstractVectorStoreManagementService.java
+++ /dev/null
@@ -1,177 +0,0 @@
-/*
- * Copyright 2024-2025 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package com.alibaba.cloud.ai.service;
-
-import com.alibaba.cloud.ai.connector.accessor.Accessor;
-import com.alibaba.cloud.ai.connector.accessor.AccessorFactory;
-import com.alibaba.cloud.ai.connector.bo.ColumnInfoBO;
-import com.alibaba.cloud.ai.connector.bo.DbQueryParameter;
-import com.alibaba.cloud.ai.connector.bo.ForeignKeyInfoBO;
-import com.alibaba.cloud.ai.connector.bo.TableInfoBO;
-import com.alibaba.cloud.ai.connector.config.DbConfig;
-import com.alibaba.cloud.ai.request.DeleteRequest;
-import com.alibaba.cloud.ai.request.EvidenceRequest;
-import com.alibaba.cloud.ai.request.SchemaInitRequest;
-import com.alibaba.cloud.ai.util.DocumentConverterUtil;
-import com.alibaba.cloud.ai.util.JsonUtil;
-import com.alibaba.cloud.ai.util.SchemaProcessorUtil;
-import jakarta.annotation.PostConstruct;
-import lombok.extern.slf4j.Slf4j;
-import org.springframework.ai.document.Document;
-import org.springframework.ai.vectorstore.VectorStore;
-import org.springframework.ai.vectorstore.filter.Filter;
-import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder;
-import org.springframework.beans.factory.annotation.Autowired;
-
-import java.util.*;
-import java.util.stream.Collectors;
-
-@Slf4j
-public abstract class AbstractVectorStoreManagementService implements VectorStoreManagementService {
-
- @Autowired
- protected AccessorFactory accessorFactory;
-
- protected Accessor dbAccessor;
-
- protected DbConfig dbConfig;
-
- @PostConstruct
- public void init() {
- this.dbAccessor = accessorFactory.getAccessorByDbConfig(dbConfig);
- }
-
- @Override
- public Boolean addEvidence(List evidenceRequests) {
- List evidences = DocumentConverterUtil.convertEvidenceToDocuments(evidenceRequests);
- getVectorStore().add(evidences);
- return true;
- }
-
- @Override
- public Boolean deleteDocuments(DeleteRequest deleteRequest) throws Exception {
- try {
- if (deleteRequest.getId() != null && !deleteRequest.getId().isEmpty()) {
- getVectorStore().delete(List.of(deleteRequest.getId()));
- }
- else if (deleteRequest.getVectorType() != null && !deleteRequest.getVectorType().isEmpty()) {
- FilterExpressionBuilder builder = new FilterExpressionBuilder();
- Filter.Expression filterExpression = builder.eq("vectorType", deleteRequest.getVectorType()).build();
-
- getVectorStore().delete(filterExpression);
- }
- else {
- throw new IllegalArgumentException("Either id or vectorType must be specified.");
- }
- return true;
- }
- catch (Exception e) {
- throw new Exception("Failed to delete collection data by filterExpression: " + e.getMessage(), e);
- }
- }
-
- protected abstract VectorStore getVectorStore();
-
- // 模板方法 - 通用schema处理流程
- @Override
- public Boolean schema(SchemaInitRequest schemaInitRequest) throws Exception {
- try {
-
- DbConfig config = schemaInitRequest.getDbConfig();
- DbQueryParameter dqp = DbQueryParameter.from(config)
- .setSchema(config.getSchema())
- .setTables(schemaInitRequest.getTables());
-
- // 清理旧数据
- clearSchemaData();
-
- // 处理外键
- List foreignKeys = dbAccessor.showForeignKeys(config, dqp);
- Map> foreignKeyMap = buildForeignKeyMap(foreignKeys);
-
- // 处理表和列
- List tables = dbAccessor.fetchTables(config, dqp);
- for (TableInfoBO table : tables) {
- SchemaProcessorUtil.enrichTableMetadata(table, dqp, config, dbAccessor, JsonUtil.getObjectMapper(),
- foreignKeyMap);
- }
-
- // 转换为文档
- List columnDocs = convertColumnsToDocuments(tables);
- List tableDocs = convertTablesToDocuments(tables);
-
- // 存储文档
- return storeSchemaDocuments(columnDocs, tableDocs);
- }
- catch (Exception e) {
- log.error("Failed to process schema ", e);
- return false;
- }
- }
-
- // 通用辅助方法
- protected Map> buildForeignKeyMap(List foreignKeys) {
- Map> map = new HashMap<>();
- for (ForeignKeyInfoBO fk : foreignKeys) {
- String key = fk.getTable() + "." + fk.getColumn() + "=" + fk.getReferencedTable() + "."
- + fk.getReferencedColumn();
-
- map.computeIfAbsent(fk.getTable(), k -> new ArrayList<>()).add(key);
- map.computeIfAbsent(fk.getReferencedTable(), k -> new ArrayList<>()).add(key);
- }
- return map;
- }
-
- private List convertColumnsToDocuments(List