diff --git a/pom.xml b/pom.xml index c5d20c1..a63a015 100644 --- a/pom.xml +++ b/pom.xml @@ -49,7 +49,8 @@ 3.11.0 - + + 2.44.5 3.5.3 3.0.0 1.2.22 @@ -249,6 +250,24 @@ + + com.diffplug.spotless + spotless-maven-plugin + ${spotless-maven-plugin.version} + + + + + + + + compile + + apply + + + + org.apache.maven.plugins maven-compiler-plugin diff --git a/spring-ai-alibaba-data-agent-chat/README.md b/spring-ai-alibaba-data-agent-chat/README.md index 62c7c24..8713d60 100644 --- a/spring-ai-alibaba-data-agent-chat/README.md +++ b/spring-ai-alibaba-data-agent-chat/README.md @@ -465,7 +465,7 @@ import com.alibaba.cloud.ai.graph.OverAllState; import com.alibaba.cloud.ai.graph.StateGraph; import com.alibaba.cloud.ai.graph.exception.GraphStateException; import com.alibaba.cloud.ai.request.SchemaInitRequest; -import com.alibaba.cloud.ai.service.vectorstore.impls.SimpleVectorStoreService; +import com.alibaba.cloud.ai.service.vectorstore.impls.SimpleAgentVectorStoreService; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; @@ -491,7 +491,7 @@ public class Nl2sqlController { private final CompiledGraph compiledGraph; @Autowired - private SimpleVectorStoreService simpleVectorStoreService; + private SimpleAgentVectorStoreService simpleVectorStoreService; @Autowired private DbConfig dbConfig; diff --git a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/constant/Constant.java b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/constant/Constant.java index 42a899f..35bd698 100644 --- a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/constant/Constant.java +++ b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/constant/Constant.java @@ -144,4 +144,16 @@ private Constant() { // 人类复核相关 public static final String HUMAN_REVIEW_ENABLED = "HUMAN_REVIEW_ENABLED"; + // column + public static final String COLUMN = "column"; + + // table + public static final String TABLE = "table"; + + // vectorType + public static final String VECTOR_TYPE = "vectorType"; + + // knowledgeId + public static final String KNOWLEDGE_ID = "knowledgeId"; + } diff --git a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/node/TableRelationNode.java b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/node/TableRelationNode.java index b9f2afc..0ac91a2 100644 --- a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/node/TableRelationNode.java +++ b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/node/TableRelationNode.java @@ -20,7 +20,6 @@ import com.alibaba.cloud.ai.constant.Constant; import com.alibaba.cloud.ai.dto.BusinessKnowledgeDTO; import com.alibaba.cloud.ai.dto.SemanticModelDTO; -import com.alibaba.cloud.ai.entity.AgentDatasource; import com.alibaba.cloud.ai.entity.Datasource; import com.alibaba.cloud.ai.enums.StreamResponseType; import com.alibaba.cloud.ai.graph.OverAllState; @@ -32,6 +31,7 @@ import com.alibaba.cloud.ai.service.schema.SchemaService; import com.alibaba.cloud.ai.service.semantic.SemanticModelRecallService; import com.alibaba.cloud.ai.util.ChatResponseUtil; +import com.alibaba.cloud.ai.util.SchemaProcessorUtil; import com.alibaba.cloud.ai.util.StateUtil; import com.alibaba.cloud.ai.util.StreamingChatGeneratorUtil; import org.slf4j.Logger; @@ -40,6 +40,7 @@ import org.springframework.ai.document.Document; import org.springframework.dao.DataAccessException; import org.springframework.stereotype.Component; +import org.springframework.util.StringUtils; import reactor.core.publisher.Flux; import java.util.List; @@ -98,21 +99,20 @@ public Map apply(OverAllState state) throws Exception { COLUMN_DOCUMENTS_BY_KEYWORDS_OUTPUT); String dataSetId = StateUtil.getStringValue(state, Constant.AGENT_ID); String agentIdStr = StateUtil.getStringValue(state, AGENT_ID); - long agentId = -1L; - if (!agentIdStr.isEmpty()) { - agentId = Long.parseLong(agentIdStr); - } + if (!StringUtils.hasText(agentIdStr)) + throw new RuntimeException("Agent ID is empty."); // Execute business logic first - get final result immediately - SchemaDTO schemaDTO = buildInitialSchema(columnDocumentsByKeywords, tableDocuments); - SchemaDTO result = processSchemaSelection(schemaDTO, input, evidenceList, state); + DbConfig agentDbConfig = getAgentDbConfig(Integer.valueOf(agentIdStr)); + SchemaDTO schemaDTO = buildInitialSchema(agentIdStr, columnDocumentsByKeywords, tableDocuments, agentDbConfig); + SchemaDTO result = processSchemaSelection(schemaDTO, input, evidenceList, state, agentDbConfig); List businessKnowledges; List semanticModel; try { // Extract business knowledge and semantic model businessKnowledges = businessKnowledgeRecallService.getFieldByDataSetId(dataSetId); - semanticModel = semanticModelRecallService.getFieldByDataSetId(String.valueOf(agentId)); + semanticModel = semanticModelRecallService.getFieldByDataSetId(dataSetId); } catch (DataAccessException e) { logger.warn("Database query failed (attempt {}): {}", retryCount + 1, e.getMessage()); @@ -164,59 +164,27 @@ private String classifyDatabaseError(DataAccessException e) { /** * Builds initial schema from column and table documents. */ - private SchemaDTO buildInitialSchema(List> columnDocumentsByKeywords, - List tableDocuments) { + private SchemaDTO buildInitialSchema(String agentId, List> columnDocumentsByKeywords, + List tableDocuments, DbConfig agentDbConfig) { SchemaDTO schemaDTO = new SchemaDTO(); - schemaService.extractDatabaseName(schemaDTO); - schemaService.buildSchemaFromDocuments(columnDocumentsByKeywords, tableDocuments, schemaDTO); + + schemaService.extractDatabaseName(schemaDTO, agentDbConfig); + schemaService.buildSchemaFromDocuments(agentId, columnDocumentsByKeywords, tableDocuments, schemaDTO); return schemaDTO; } - /** - * Dynamically get the data source configuration for an agent - * @param state The state object containing the agent ID - * @return The database configuration corresponding to the agent, or null if not found - */ - private DbConfig getAgentDbConfig(OverAllState state) { - try { - // Get the agent ID from the state - String agentIdStr = StateUtil.getStringValue(state, Constant.AGENT_ID, null); - if (agentIdStr == null || agentIdStr.trim().isEmpty()) { - logger.debug("AgentId is null or empty, will use default dbConfig"); - return null; - } - - Integer agentId = Integer.valueOf(agentIdStr); - logger.debug("Getting datasource config for agent: {}", agentId); - - // Get the enabled data source for the agent - List agentDatasources = datasourceService.getAgentDatasources(agentId); - if (agentDatasources.isEmpty()) { - // TODO 调试AgentID不一致,暂时手动处理 - agentDatasources = datasourceService.getAgentDatasources(agentId - 999999); - } - - AgentDatasource activeDatasource = agentDatasources.stream() - .filter(ad -> ad.getIsActive() == 1) - .findFirst() - .orElse(null); - - if (activeDatasource == null) { - logger.debug("Agent {} has no active datasource, will use default dbConfig", agentId); - return null; - } + private DbConfig getAgentDbConfig(Integer agentId) { + // Get the enabled data source for the agent + Datasource agentDatasource = datasourceService.getActiveDatasourceByAgentId(agentId); + if (agentDatasource == null) + throw new RuntimeException("No active datasource found for agent " + agentId); - // Convert to DbConfig - DbConfig dbConfig = createDbConfigFromDatasource(activeDatasource.getDatasource()); - logger.debug("Successfully created DbConfig for agent {}: url={}, schema={}, type={}", agentId, - dbConfig.getUrl(), dbConfig.getSchema(), dbConfig.getDialectType()); + // Convert to DbConfig + DbConfig dbConfig = SchemaProcessorUtil.createDbConfigFromDatasource(agentDatasource); + logger.debug("Successfully created DbConfig for agent {}: url={}, schema={}, type={}", agentId, + dbConfig.getUrl(), dbConfig.getSchema(), dbConfig.getDialectType()); - return dbConfig; - } - catch (Exception e) { - logger.warn("Failed to get agent datasource config, will use default dbConfig: {}", e.getMessage()); - return null; - } + return dbConfig; } /** @@ -255,13 +223,9 @@ else if ("postgresql".equalsIgnoreCase(datasource.getType())) { * Processes schema selection based on input, evidence, and optional advice. */ private SchemaDTO processSchemaSelection(SchemaDTO schemaDTO, String input, List evidenceList, - OverAllState state) { + OverAllState state, DbConfig agentDbConfig) { String schemaAdvice = StateUtil.getStringValue(state, SQL_GENERATE_SCHEMA_MISSING_ADVICE, null); - // 动态获取Agent对应的数据库配置 - DbConfig agentDbConfig = getAgentDbConfig(state); - logger.debug("Using agent-specific dbConfig: {}", agentDbConfig != null ? agentDbConfig.getUrl() : "default"); - if (schemaAdvice != null) { logger.info("[{}] Processing with schema supplement advice: {}", this.getClass().getSimpleName(), schemaAdvice); diff --git a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/DatasourceService.java b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/DatasourceService.java index 9357be3..ee74914 100644 --- a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/DatasourceService.java +++ b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/DatasourceService.java @@ -299,4 +299,15 @@ public Map getDatasourceStats() { return stats; } + public Datasource getActiveDatasourceByAgentId(Integer agentId) { + AgentDatasource agentDatasource = getAgentDatasources(agentId).stream() + .filter(a -> a.getIsActive() == 1) + .findFirst() + .orElse(null); + if (agentDatasource == null) { + return null; + } + return agentDatasource.getDatasource(); + } + } diff --git a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/processing/AbstractQueryProcessingService.java b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/processing/AbstractQueryProcessingService.java index e4e684f..c59276e 100644 --- a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/processing/AbstractQueryProcessingService.java +++ b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/processing/AbstractQueryProcessingService.java @@ -19,17 +19,20 @@ import com.alibaba.cloud.ai.dto.schema.SchemaDTO; import com.alibaba.cloud.ai.prompt.PromptConstant; import com.alibaba.cloud.ai.prompt.PromptHelper; +import com.alibaba.cloud.ai.service.DatasourceService; import com.alibaba.cloud.ai.service.LlmService; import com.alibaba.cloud.ai.service.nl2sql.Nl2SqlService; import com.alibaba.cloud.ai.service.schema.SchemaService; -import com.alibaba.cloud.ai.service.vectorstore.VectorStoreService; +import com.alibaba.cloud.ai.service.vectorstore.AgentVectorStoreService; import com.alibaba.cloud.ai.util.JsonUtil; +import com.alibaba.cloud.ai.util.SchemaProcessorUtil; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.document.Document; +import org.springframework.util.Assert; import reactor.core.publisher.Flux; import java.util.Collections; @@ -44,11 +47,14 @@ public abstract class AbstractQueryProcessingService implements QueryProcessingS private final LlmService aiService; - public AbstractQueryProcessingService(LlmService aiService) { + private final DatasourceService datasourceService; + + public AbstractQueryProcessingService(LlmService aiService, DatasourceService datasourceService) { this.aiService = aiService; + this.datasourceService = datasourceService; } - protected abstract VectorStoreService getVectorStoreService(); + protected abstract AgentVectorStoreService getVectorStoreService(); protected abstract SchemaService getSchemaService(); @@ -57,13 +63,9 @@ public AbstractQueryProcessingService(LlmService aiService) { @Override public List extractEvidences(String query, String agentId) { logger.debug("Extracting evidences for query: {} with agentId: {}", query, agentId); - List evidenceDocuments; - if (agentId != null && !agentId.trim().isEmpty()) { - evidenceDocuments = getVectorStoreService().getDocumentsForAgent(agentId, query, "evidence"); - } - else { - evidenceDocuments = getVectorStoreService().getDocuments(query, "evidence"); - } + Assert.notNull(agentId, "AgentId cannot be null"); + List evidenceDocuments = getVectorStoreService().getDocumentsForAgent(agentId, query, "evidence"); + List evidences = evidenceDocuments.stream().map(Document::getText).collect(Collectors.toList()); logger.debug("Extracted {} evidences: {}", evidences.size(), evidences); return evidences; @@ -176,17 +178,20 @@ private String processTimeExpressions(String query) { } private SchemaDTO select(String query, List evidenceList, String agentId) throws Exception { + Assert.notNull(agentId, "AgentId cannot be null"); logger.debug("Starting schema selection for query: {} with {} evidences and agentId: {}", query, evidenceList.size(), agentId); List keywords = extractKeywords(query, evidenceList); logger.debug("Using {} keywords for schema selection", keywords != null ? keywords.size() : 0); - SchemaDTO schemaDTO; - if (agentId != null) { - schemaDTO = getSchemaService().mixRagForAgent(agentId, query, keywords); - } - else { - schemaDTO = getSchemaService().mixRag(query, keywords); + + com.alibaba.cloud.ai.entity.Datasource datasource = datasourceService + .getActiveDatasourceByAgentId(Integer.valueOf(agentId)); + if (datasource == null) { + throw new RuntimeException("No active datasource found for agentId: " + agentId); } + SchemaDTO schemaDTO = getSchemaService().mixRagForAgent(agentId, query, keywords, + SchemaProcessorUtil.createDbConfigFromDatasource(datasource)); + logger.debug("Retrieved schema with {} tables", schemaDTO.getTable() != null ? schemaDTO.getTable().size() : 0); SchemaDTO result = fineSelect(schemaDTO, query, evidenceList); logger.debug("Fine selection completed, final schema has {} tables", diff --git a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/processing/impls/QueryProcessingServiceImpl.java b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/processing/impls/QueryProcessingServiceImpl.java index d97579e..83efe51 100644 --- a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/processing/impls/QueryProcessingServiceImpl.java +++ b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/processing/impls/QueryProcessingServiceImpl.java @@ -16,32 +16,33 @@ package com.alibaba.cloud.ai.service.processing.impls; +import com.alibaba.cloud.ai.service.DatasourceService; import com.alibaba.cloud.ai.service.LlmService; import com.alibaba.cloud.ai.service.nl2sql.Nl2SqlService; import com.alibaba.cloud.ai.service.schema.SchemaService; import com.alibaba.cloud.ai.service.processing.AbstractQueryProcessingService; -import com.alibaba.cloud.ai.service.vectorstore.VectorStoreService; +import com.alibaba.cloud.ai.service.vectorstore.AgentVectorStoreService; import org.springframework.stereotype.Service; @Service public class QueryProcessingServiceImpl extends AbstractQueryProcessingService { - private final VectorStoreService vectorStoreService; + private final AgentVectorStoreService vectorStoreService; private final SchemaService schemaService; private final Nl2SqlService nl2SqlService; - public QueryProcessingServiceImpl(LlmService aiService, VectorStoreService vectorStoreService, - SchemaService schemaService, Nl2SqlService nl2SqlService) { - super(aiService); + public QueryProcessingServiceImpl(LlmService aiService, AgentVectorStoreService vectorStoreService, + SchemaService schemaService, Nl2SqlService nl2SqlService, DatasourceService datasourceService) { + super(aiService, datasourceService); this.vectorStoreService = vectorStoreService; this.schemaService = schemaService; this.nl2SqlService = nl2SqlService; } @Override - protected VectorStoreService getVectorStoreService() { + protected AgentVectorStoreService getVectorStoreService() { return this.vectorStoreService; } diff --git a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/schema/AbstractSchemaService.java b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/schema/AbstractSchemaService.java index ee0eae5..e90730a 100644 --- a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/schema/AbstractSchemaService.java +++ b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/schema/AbstractSchemaService.java @@ -15,29 +15,28 @@ */ package com.alibaba.cloud.ai.service.schema; +import com.alibaba.cloud.ai.constant.Constant; import com.alibaba.cloud.ai.enums.BizDataSourceTypeEnum; import com.alibaba.cloud.ai.connector.config.DbConfig; -import com.alibaba.cloud.ai.request.SearchRequest; import com.alibaba.cloud.ai.dto.schema.ColumnDTO; import com.alibaba.cloud.ai.dto.schema.SchemaDTO; import com.alibaba.cloud.ai.dto.schema.TableDTO; -import com.alibaba.cloud.ai.service.vectorstore.VectorStoreService; +import com.alibaba.cloud.ai.service.vectorstore.AgentVectorStoreService; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; +import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections.CollectionUtils; import org.apache.commons.lang3.StringUtils; import org.springframework.ai.document.Document; +import org.springframework.util.Assert; import java.util.ArrayList; import java.util.Arrays; -import java.util.Comparator; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.Set; -import java.util.function.Function; import java.util.regex.Matcher; import java.util.regex.Pattern; import java.util.stream.Collectors; @@ -45,19 +44,17 @@ /** * Schema service base class, providing common method implementations */ +@Slf4j public abstract class AbstractSchemaService implements SchemaService { - protected final DbConfig dbConfig; - protected final ObjectMapper objectMapper; /** * Vector storage service */ - protected final VectorStoreService vectorStoreService; + protected final AgentVectorStoreService vectorStoreService; - public AbstractSchemaService(DbConfig dbConfig, ObjectMapper objectMapper, VectorStoreService vectorStoreService) { - this.dbConfig = dbConfig; + public AbstractSchemaService(ObjectMapper objectMapper, AgentVectorStoreService vectorStoreService) { this.objectMapper = objectMapper; this.vectorStoreService = vectorStoreService; } @@ -67,42 +64,44 @@ public AbstractSchemaService(DbConfig dbConfig, ObjectMapper objectMapper, Vecto * @param agentId agent ID * @param query query * @param keywords keyword list + * @param dbConfig Database configuration * @return SchemaDTO */ @Override - public SchemaDTO mixRagForAgent(String agentId, String query, List keywords) { + public SchemaDTO mixRagForAgent(String agentId, String query, List keywords, DbConfig dbConfig) { SchemaDTO schemaDTO = new SchemaDTO(); - extractDatabaseName(schemaDTO); // Set database name or schema name + extractDatabaseName(schemaDTO, dbConfig); // Set database name or schema name + + // Get table documents + List tableDocuments = getTableDocumentsForAgent(agentId, query); - List tableDocuments = getTableDocuments(query, agentId); // Get table - // documents - List> columnDocumentList = getColumnDocumentsByKeywords(keywords, agentId); // Get - // column - // document - // list + // Get column documents + List> columnDocumentList = getColumnDocumentsByKeywordsForAgent(agentId, keywords); - buildSchemaFromDocuments(columnDocumentList, tableDocuments, schemaDTO); + buildSchemaFromDocuments(agentId, columnDocumentList, tableDocuments, schemaDTO); return schemaDTO; } @Override - public void buildSchemaFromDocuments(List> columnDocumentList, List tableDocuments, - SchemaDTO schemaDTO) { + public void buildSchemaFromDocuments(String agentId, List> columnDocumentList, + List tableDocuments, SchemaDTO schemaDTO) { // Process column weights and sort by table association - processColumnWeights(columnDocumentList, tableDocuments); + updateAndSortColumnScoresByTableWeights(columnDocumentList, tableDocuments); // Initialize column selector, TODO upper limit 100 has issues - Map weightedColumns = selectWeightedColumns(columnDocumentList, 100); + Map weightedColumns = selectColumnsByRoundRobin(columnDocumentList, 100); - Set foreignKeySet = extractForeignKeyRelations(tableDocuments); + // 如果外键关系是"订单表.订单ID=订单详情表.订单ID",那么 relatedNamesFromForeignKeys + // 将包含"订单表.订单ID"和"订单详情表.订单ID" + Set relatedNamesFromForeignKeys = extractRelatedNamesFromForeignKeys(tableDocuments); // Build table list List tableList = buildTableListFromDocuments(tableDocuments); // Supplement missing foreign key corresponding tables - expandTableDocumentsWithForeignKeys(tableDocuments, foreignKeySet, "table"); - expandColumnDocumentsWithForeignKeys(weightedColumns, foreignKeySet, "column"); + expandTableDocumentsWithForeignKeys(agentId, tableDocuments, relatedNamesFromForeignKeys); + expandColumnDocumentsWithForeignKeys(agentId, weightedColumns, relatedNamesFromForeignKeys); // Attach weighted columns to corresponding tables attachColumnsToTables(weightedColumns, tableList); @@ -118,38 +117,13 @@ public void buildSchemaFromDocuments(List> columnDocumentList, Li schemaDTO.setForeignKeys(List.of(new ArrayList<>(foreignKeys))); } - /** - * Get all table documents by keywords - supports agent isolation - */ - public List getTableDocuments(String query, String agentId) { - if (agentId != null && !agentId.trim().isEmpty()) { - return vectorStoreService.getDocumentsForAgent(agentId, query, "table"); - } - else { - return vectorStoreService.getDocuments(query, "table"); - } - } - /** * Get all table documents by keywords for specified agent */ @Override public List getTableDocumentsForAgent(String agentId, String query) { - return vectorStoreService.getDocumentsForAgent(agentId, query, "table"); - } - - /** - * Get all column documents by keywords - supports agent isolation - */ - public List> getColumnDocumentsByKeywords(List keywords, String agentId) { - if (agentId != null) { - return getColumnDocumentsByKeywordsForAgent(agentId, keywords); - } - else { - return keywords.stream() - .map(kw -> vectorStoreService.getDocuments(kw, "column")) - .collect(Collectors.toList()); - } + Assert.notNull(agentId, "agentId cannot be null"); + return vectorStoreService.getDocumentsForAgent(agentId, query, Constant.TABLE); } /** @@ -157,21 +131,19 @@ public List> getColumnDocumentsByKeywords(List keywords, */ @Override public List> getColumnDocumentsByKeywordsForAgent(String agentId, List keywords) { - if (agentId == null) { - return keywords.stream() - .map(kw -> vectorStoreService.getDocuments(kw, "column")) - .collect(Collectors.toList()); - } + + Assert.notNull(agentId, "agentId cannot be null"); + return keywords.stream() - .map(kw -> vectorStoreService.getDocumentsForAgent(agentId, kw, "column")) + .map(kw -> vectorStoreService.getDocumentsForAgent(agentId, kw, Constant.COLUMN)) .collect(Collectors.toList()); } /** * Expand column documents (supplement missing columns through foreign keys) */ - private void expandColumnDocumentsWithForeignKeys(Map weightedColumns, Set foreignKeySet, - String vectorType) { + private void expandColumnDocumentsWithForeignKeys(String agentId, Map weightedColumns, + Set foreignKeySet) { Set existingColumnNames = weightedColumns.keySet(); Set missingColumns = new HashSet<>(); @@ -182,7 +154,7 @@ private void expandColumnDocumentsWithForeignKeys(Map weighted } for (String columnName : missingColumns) { - addColumnsDocument(weightedColumns, columnName, vectorType); + addColumnsDocument(agentId, weightedColumns, columnName); } } @@ -190,8 +162,8 @@ private void expandColumnDocumentsWithForeignKeys(Map weighted /** * Expand table documents (supplement missing tables through foreign keys) */ - private void expandTableDocumentsWithForeignKeys(List tableDocuments, Set foreignKeySet, - String vectorType) { + private void expandTableDocumentsWithForeignKeys(String agentId, List tableDocuments, + Set foreignKeySet) { Set uniqueTableNames = tableDocuments.stream() .map(doc -> (String) doc.getMetadata().get("name")) .collect(Collectors.toSet()); @@ -208,46 +180,67 @@ private void expandTableDocumentsWithForeignKeys(List tableDocuments, } for (String tableName : missingTables) { - addTableDocument(tableDocuments, tableName, vectorType); + addTableDocument(agentId, tableDocuments, tableName); } } - /** - * Add missing table documents - * @param tableDocuments - * @param tableName - * @param vectorType - */ - protected abstract void addTableDocument(List tableDocuments, String tableName, String vectorType); + protected void addTableDocument(String agentId, List tableDocuments, String tableName) { + List documentsForAgent = vectorStoreService.getDocumentsForAgent(agentId, tableName, Constant.TABLE); + if (documentsForAgent != null && !documentsForAgent.isEmpty()) + tableDocuments.addAll(documentsForAgent); + } - protected abstract void addColumnsDocument(Map weightedColumns, String columnName, - String vectorType); + protected void addColumnsDocument(String agentId, Map weightedColumns, String columnName) { + List documentsForAgent = vectorStoreService.getDocumentsForAgent(agentId, columnName, + Constant.COLUMN); + if (documentsForAgent != null && !documentsForAgent.isEmpty()) { + for (Document document : documentsForAgent) + weightedColumns.putIfAbsent(document.getId(), document); + } + } /** - * Select up to maxCount columns by weight + * Select up to maxCount columns by weight using a round-robin approach to ensure + * balanced selection across different tables */ - protected Map selectWeightedColumns(List> columnDocumentList, int maxCount) { - Map result = new HashMap<>(); - int index = 0; - - while (result.size() < maxCount) { - boolean completed = true; - for (List docs : columnDocumentList) { - if (index < docs.size()) { - Document doc = docs.get(index); - String id = doc.getId(); - if (!result.containsKey(id)) { - result.put(id, doc); + protected Map selectColumnsByRoundRobin(List> columnDocumentList, int maxCount) { + Map selectedColumns = new HashMap<>(); + int currentRound = 0; + + // Continue selecting columns until we reach maxCount or exhaust all columns + while (selectedColumns.size() < maxCount) { + boolean hasMoreColumnsInAnyList = false; + + // Process each table's column list in the current round + for (List tableColumns : columnDocumentList) { + if (currentRound < tableColumns.size()) { + // Get the column at current position (already sorted by weight) + Document column = tableColumns.get(currentRound); + String columnId = column.getId(); + + // Add to selection if not already selected + if (!selectedColumns.containsKey(columnId)) { + selectedColumns.put(columnId, column); + + // Stop if we've reached the maximum count + if (selectedColumns.size() >= maxCount) { + break; + } } - completed = false; + + hasMoreColumnsInAnyList = true; } } - index++; - if (completed) { + + // If no more columns in any list, exit the loop + if (!hasMoreColumnsInAnyList) { break; } + + currentRound++; } - return result; + + return selectedColumns; } /** @@ -284,34 +277,147 @@ else if (primaryKeyObj instanceof String) { /** * Score each column (combining with its table's score) */ - public void processColumnWeights(List> columnDocuments, List tableDocuments) { - columnDocuments.replaceAll(docs -> docs.stream() - .filter(column -> tableDocuments.stream() - .anyMatch(table -> table.getMetadata().get("name").equals(column.getMetadata().get("tableName")))) - .peek(column -> { - Optional matchingTable = tableDocuments.stream() - .filter(table -> table.getMetadata().get("name").equals(column.getMetadata().get("tableName"))) - .findFirst(); - matchingTable.ifPresent(tableDoc -> { - Double tableScore = Optional.ofNullable((Double) tableDoc.getMetadata().get("score")) - .orElse(tableDoc.getScore()); - Double columnScore = Optional.ofNullable((Double) column.getMetadata().get("score")) - .orElse(column.getScore()); - if (tableScore != null && columnScore != null) { - column.getMetadata().put("score", columnScore * tableScore); - } - }); - }) - .sorted(Comparator.comparing((Document d) -> (Double) d.getMetadata().get("score")).reversed()) - .collect(Collectors.toList())); + public void updateAndSortColumnScoresByTableWeights(List> columnDocuments, + List tableDocuments) { + for (int i = 0; i < columnDocuments.size(); i++) { + List processedColumns = processSingleTableColumns(columnDocuments.get(i), tableDocuments); + columnDocuments.set(i, processedColumns); + } + } + + /** + * Process columns for a single table, filtering and updating scores + */ + private List processSingleTableColumns(List columns, List tableDocuments) { + // Step 1: Filter columns to only include those that have a matching table + List filteredColumns = filterColumnsWithMatchingTables(columns, tableDocuments); + + // Step 2: Update column scores by multiplying with their table scores + updateColumnScoresWithTableScores(filteredColumns, tableDocuments); + + // Step 3: Sort columns by their new scores in descending order + return sortColumnsByScoreDescending(filteredColumns); + } + + /** + * Filter columns to only include those that have a matching table + */ + private List filterColumnsWithMatchingTables(List columns, List tableDocuments) { + List result = new ArrayList<>(); + + for (Document column : columns) { + String columnTableName = (String) column.getMetadata().get("tableName"); + if (hasMatchingTable(tableDocuments, columnTableName)) { + result.add(column); + } + } + + return result; + } + + /** + * Check if there's a table with the given name in the table documents + */ + private boolean hasMatchingTable(List tableDocuments, String tableName) { + if (StringUtils.isBlank(tableName)) { + return false; + } + + for (Document table : tableDocuments) { + String table_name = (String) table.getMetadata().get("name"); + if (tableName.equals(table_name)) { + return true; + } + } + + return false; + } + + /** + * Update column scores by multiplying with their table scores + */ + private void updateColumnScoresWithTableScores(List columns, List tableDocuments) { + for (Document column : columns) { + String columnTableName = (String) column.getMetadata().get("tableName"); + Document matchingTable = findTableByName(tableDocuments, columnTableName); + + if (matchingTable != null) { + Double tableScore = getTableScore(matchingTable); + Double columnScore = getColumnScore(column); + + if (tableScore != null && columnScore != null) { + Double newScore = columnScore * tableScore; + column.getMetadata().put("score", newScore); + } + } + } + } + + /** + * Find a table document by its name + */ + private Document findTableByName(List tableDocuments, String tableName) { + if (StringUtils.isBlank(tableName)) { + return null; + } + + for (Document table : tableDocuments) { + String table_name = (String) table.getMetadata().get("name"); + if (tableName.equals(table_name)) { + return table; + } + } + + return null; + } + + /** + * Get the score from a table document + */ + private Double getTableScore(Document tableDoc) { + Double scoreFromMetadata = (Double) tableDoc.getMetadata().get("score"); + return scoreFromMetadata != null ? scoreFromMetadata : tableDoc.getScore(); } /** - * Extract foreign key relationships + * Get the score from a column document + */ + private Double getColumnScore(Document columnDoc) { + Double scoreFromMetadata = (Double) columnDoc.getMetadata().get("score"); + return scoreFromMetadata != null ? scoreFromMetadata : columnDoc.getScore(); + } + + /** + * Sort columns by their scores in descending order + */ + private List sortColumnsByScoreDescending(List columns) { + List sortedColumns = new ArrayList<>(columns); + + sortedColumns.sort((doc1, doc2) -> { + Double score1 = (Double) doc1.getMetadata().get("score"); + Double score2 = (Double) doc2.getMetadata().get("score"); + + // Handle null scores + if (score1 == null && score2 == null) + return 0; + if (score1 == null) + return 1; + if (score2 == null) + return -1; + + // Sort in descending order + return score2.compareTo(score1); + }); + + return sortedColumns; + } + + /** + * Extract related table and column names from foreign key relationships * @param tableDocuments table document list - * @return foreign key relationship set + * @return set of related names in format "tableName.columnName" */ - protected Set extractForeignKeyRelations(List tableDocuments) { + protected Set extractRelatedNamesFromForeignKeys(List tableDocuments) { Set result = new HashSet<>(); for (Document doc : tableDocuments) { @@ -352,7 +458,8 @@ protected void attachColumnsToTables(Map weightedColumns, List }); columnDTO.setData(samples); } - catch (Exception ignore) { + catch (Exception e) { + log.error("Failed to parse samples: {}", samplesStr, e); } } @@ -364,28 +471,13 @@ protected void attachColumnsToTables(Map weightedColumns, List } } - /** - * Get table metadata - * @param tableName table name - * @return table metadata - */ - protected Map getTableMetadata(String tableName) { - List tableDocuments = getTableDocuments(tableName); - for (Document doc : tableDocuments) { - Map metadata = doc.getMetadata(); - if (tableName.equals(metadata.get("name"))) { - return metadata; - } - } - return null; - } - /** * Extract database name * @param schemaDTO SchemaDTO + * @param dbConfig Database configuration */ @Override - public void extractDatabaseName(SchemaDTO schemaDTO) { + public void extractDatabaseName(SchemaDTO schemaDTO, DbConfig dbConfig) { String pattern = ":\\d+/([^/?&]+)"; if (BizDataSourceTypeEnum.isMysqlDialect(dbConfig.getDialectType())) { Pattern regex = Pattern.compile(pattern); @@ -399,31 +491,4 @@ else if (BizDataSourceTypeEnum.isPgDialect(dbConfig.getDialectType())) { } } - /** - * Common document query processing template to reduce subclass redundant code. - */ - protected void handleDocumentQuery(List targetList, String key, String vectorType, - Function requestBuilder, Function> searchFunc) { - SearchRequest request = requestBuilder.apply(key); - request.setVectorType(vectorType); - request.setTopK(10); - List docs = searchFunc.apply(request); - if (CollectionUtils.isNotEmpty(docs)) { - targetList.addAll(docs); - } - } - - protected void handleDocumentQuery(Map targetMap, String key, String vectorType, - Function requestBuilder, Function> searchFunc) { - SearchRequest request = requestBuilder.apply(key); - request.setVectorType(vectorType); - request.setTopK(10); - List docs = searchFunc.apply(request); - if (CollectionUtils.isNotEmpty(docs)) { - for (Document doc : docs) { - targetMap.putIfAbsent(doc.getId(), doc); - } - } - } - } diff --git a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/schema/SchemaService.java b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/schema/SchemaService.java index 47c93a0..7c64b9b 100644 --- a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/schema/SchemaService.java +++ b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/schema/SchemaService.java @@ -16,6 +16,7 @@ package com.alibaba.cloud.ai.service.schema; +import com.alibaba.cloud.ai.connector.config.DbConfig; import com.alibaba.cloud.ai.dto.schema.SchemaDTO; import org.springframework.ai.document.Document; @@ -27,12 +28,12 @@ public interface SchemaService { List> getColumnDocumentsByKeywordsForAgent(String agentId, List keywords); - void extractDatabaseName(SchemaDTO schemaDTO); + void extractDatabaseName(SchemaDTO schemaDTO, DbConfig dbConfig); - void buildSchemaFromDocuments(List> columnDocumentList, List tableDocuments, - SchemaDTO schemaDTO); + void buildSchemaFromDocuments(String agentId, List> columnDocumentList, + List tableDocuments, SchemaDTO schemaDTO); - SchemaDTO mixRagForAgent(String agentId, String query, List keywords); + SchemaDTO mixRagForAgent(String agentId, String query, List keywords, DbConfig dbConfig); default List getTableDocuments(String query) { return getTableDocumentsForAgent(null, query); @@ -42,8 +43,8 @@ default List> getColumnDocumentsByKeywords(List keywords) return getColumnDocumentsByKeywordsForAgent(null, keywords); } - default SchemaDTO mixRag(String query, List keywords) { - return mixRagForAgent(null, query, keywords); + default SchemaDTO mixRag(String query, List keywords, DbConfig dbConfig) { + return mixRagForAgent(null, query, keywords, dbConfig); } } diff --git a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/schema/SchemaServiceFactory.java b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/schema/SchemaServiceFactory.java index bf0c858..250c99f 100644 --- a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/schema/SchemaServiceFactory.java +++ b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/schema/SchemaServiceFactory.java @@ -16,37 +16,65 @@ package com.alibaba.cloud.ai.service.schema; -import com.alibaba.cloud.ai.connector.config.DbConfig; import com.alibaba.cloud.ai.service.schema.impls.AnalyticSchemaService; import com.alibaba.cloud.ai.service.schema.impls.SimpleSchemaService; -import com.alibaba.cloud.ai.service.vectorstore.VectorStoreService; +import com.alibaba.cloud.ai.service.vectorstore.AgentVectorStoreService; +import com.alibaba.cloud.ai.service.vectorstore.VectorStoreType; +import com.alibaba.cloud.ai.service.vectorstore.impls.AnalyticAgentVectorStoreService; +import com.alibaba.cloud.ai.service.vectorstore.impls.SimpleAgentVectorStoreService; import com.alibaba.cloud.ai.util.JsonUtil; +import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.FactoryBean; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; +import org.springframework.context.annotation.DependsOn; import org.springframework.stereotype.Component; +import java.util.HashMap; +import java.util.Map; + +@Slf4j @Component +@DependsOn("agentVectorStoreServiceFactory") public class SchemaServiceFactory implements FactoryBean { - // todo: 改为枚举,由用户配置决定实现类 - @Value("${spring.ai.vectorstore.analytic.enabled:false}") - private Boolean analyticEnabled; + @Value("${spring.ai.vectorstore.type:SIMPLE}") + private VectorStoreType vectorStoreType; + + @Autowired(required = false) + private AgentVectorStoreService agentVectorStoreService; + + @FunctionalInterface + private interface SchemaServiceCreator { + + SchemaService create(); - @Autowired - private DbConfig dbConfig; + } + + private final Map serviceCreators = new HashMap<>(); - @Autowired - private VectorStoreService vectorStoreService; + public SchemaServiceFactory() { + // 初始化各种向量存储类型的创建策略 + serviceCreators.put(VectorStoreType.ANALYTIC_DB, this::createAnalyticSchemaService); + serviceCreators.put(VectorStoreType.SIMPLE, this::createSimpleSchemaService); + + } @Override public SchemaService getObject() { - if (Boolean.TRUE.equals(analyticEnabled)) { - return new AnalyticSchemaService(dbConfig, JsonUtil.getObjectMapper(), vectorStoreService); + if (agentVectorStoreService == null) { + throw new IllegalStateException("AgentVectorStoreService is not initialized."); } - else { - return new SimpleSchemaService(dbConfig, JsonUtil.getObjectMapper(), vectorStoreService); + + // 根据配置的向量存储类型获取对应的创建策略 + SchemaServiceCreator creator = serviceCreators.get(vectorStoreType); + if (creator == null) { + log.warn("Unsupported vector store type: {}, falling back to SIMPLE", vectorStoreType); + creator = serviceCreators.get(VectorStoreType.SIMPLE); } + + // 使用选定的策略创建SchemaService实例 + return creator.create(); } @Override @@ -54,4 +82,32 @@ public Class getObjectType() { return SchemaService.class; } + /** + * 创建分析型数据库的SchemaService + * @return AnalyticSchemaService实例 + */ + private SchemaService createAnalyticSchemaService() { + log.info("Using AnalyticSchemaService"); + if (!(agentVectorStoreService instanceof AnalyticAgentVectorStoreService)) { + throw new IllegalStateException( + "AgentVectorStoreService is not an instance of AnalyticAgentVectorStoreService"); + } + return new AnalyticSchemaService(JsonUtil.getObjectMapper(), + (AnalyticAgentVectorStoreService) agentVectorStoreService); + } + + /** + * 创建简单内存存储的SchemaService + * @return SimpleSchemaService实例 + */ + private SchemaService createSimpleSchemaService() { + log.info("Using SimpleSchemaService"); + if (!(agentVectorStoreService instanceof SimpleAgentVectorStoreService)) { + throw new IllegalStateException( + "AgentVectorStoreService is not an instance of SimpleAgentVectorStoreService"); + } + return new SimpleSchemaService(JsonUtil.getObjectMapper(), + (SimpleAgentVectorStoreService) agentVectorStoreService); + } + } diff --git a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/schema/impls/AnalyticSchemaService.java b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/schema/impls/AnalyticSchemaService.java index bc20ad5..599c940 100644 --- a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/schema/impls/AnalyticSchemaService.java +++ b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/schema/impls/AnalyticSchemaService.java @@ -15,45 +15,17 @@ */ package com.alibaba.cloud.ai.service.schema.impls; -import com.alibaba.cloud.ai.connector.config.DbConfig; -import com.alibaba.cloud.ai.request.SearchRequest; import com.alibaba.cloud.ai.service.schema.AbstractSchemaService; -import com.alibaba.cloud.ai.service.vectorstore.VectorStoreService; +import com.alibaba.cloud.ai.service.vectorstore.impls.AnalyticAgentVectorStoreService; import com.fasterxml.jackson.databind.ObjectMapper; -import org.springframework.ai.document.Document; - -import java.util.List; -import java.util.Map; /** * Schema building service, supports RAG-based hybrid queries. */ public class AnalyticSchemaService extends AbstractSchemaService { - public AnalyticSchemaService(DbConfig dbConfig, ObjectMapper objectMapper, VectorStoreService vectorStoreService) { - super(dbConfig, objectMapper, vectorStoreService); - } - - @Override - protected void addTableDocument(List tableDocuments, String tableName, String vectorType) { - handleDocumentQuery(tableDocuments, tableName, vectorType, name -> { - SearchRequest req = new SearchRequest(); - req.setQuery(null); - req.setFilterFormatted("jsonb_extract_path_text(metadata, 'vectorType') = '" + vectorType - + "' and refdocid = '" + name + "'"); - return req; - }, vectorStoreService::searchWithFilter); - } - - @Override - protected void addColumnsDocument(Map weightedColumns, String columnName, String vectorType) { - handleDocumentQuery(weightedColumns, columnName, vectorType, name -> { - SearchRequest req = new SearchRequest(); - req.setQuery(null); - req.setFilterFormatted("jsonb_extract_path_text(metadata, 'vectorType') = '" + vectorType - + "' and refdocid = '" + name + "'"); - return req; - }, vectorStoreService::searchWithFilter); + public AnalyticSchemaService(ObjectMapper objectMapper, AnalyticAgentVectorStoreService vectorStoreService) { + super(objectMapper, vectorStoreService); } } diff --git a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/schema/impls/SimpleSchemaService.java b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/schema/impls/SimpleSchemaService.java index ebc356d..03c07bc 100644 --- a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/schema/impls/SimpleSchemaService.java +++ b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/schema/impls/SimpleSchemaService.java @@ -15,38 +15,14 @@ */ package com.alibaba.cloud.ai.service.schema.impls; -import com.alibaba.cloud.ai.connector.config.DbConfig; -import com.alibaba.cloud.ai.request.SearchRequest; import com.alibaba.cloud.ai.service.schema.AbstractSchemaService; -import com.alibaba.cloud.ai.service.vectorstore.VectorStoreService; +import com.alibaba.cloud.ai.service.vectorstore.impls.SimpleAgentVectorStoreService; import com.fasterxml.jackson.databind.ObjectMapper; -import org.springframework.ai.document.Document; - -import java.util.List; -import java.util.Map; public class SimpleSchemaService extends AbstractSchemaService { - public SimpleSchemaService(DbConfig dbConfig, ObjectMapper objectMapper, VectorStoreService vectorStoreService) { - super(dbConfig, objectMapper, vectorStoreService); - } - - @Override - protected void addTableDocument(List tableDocuments, String tableName, String vectorType) { - handleDocumentQuery(tableDocuments, tableName, vectorType, name -> { - SearchRequest req = new SearchRequest(); - req.setName(name); - return req; - }, vectorStoreService::searchTableByNameAndVectorType); - } - - @Override - protected void addColumnsDocument(Map weightedColumns, String tableName, String vectorType) { - handleDocumentQuery(weightedColumns, tableName, vectorType, name -> { - SearchRequest req = new SearchRequest(); - req.setName(name); - return req; - }, vectorStoreService::searchTableByNameAndVectorType); + public SimpleSchemaService(ObjectMapper objectMapper, SimpleAgentVectorStoreService vectorStoreService) { + super(objectMapper, vectorStoreService); } } diff --git a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/AbstractAgentVectorStoreService.java b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/AbstractAgentVectorStoreService.java new file mode 100644 index 0000000..6e59d56 --- /dev/null +++ b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/AbstractAgentVectorStoreService.java @@ -0,0 +1,318 @@ +/* + * Copyright 2024-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alibaba.cloud.ai.service.vectorstore; + +import static com.alibaba.cloud.ai.util.DocumentConverterUtil.convertColumnsToDocuments; +import static com.alibaba.cloud.ai.util.DocumentConverterUtil.convertTablesToDocuments; + +import java.util.*; +import java.util.stream.Collectors; + +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.vectorstore.VectorStore; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; + +import com.alibaba.cloud.ai.connector.accessor.Accessor; +import com.alibaba.cloud.ai.connector.accessor.AccessorFactory; +import com.alibaba.cloud.ai.connector.bo.DbQueryParameter; +import com.alibaba.cloud.ai.connector.bo.ForeignKeyInfoBO; +import com.alibaba.cloud.ai.connector.bo.TableInfoBO; +import com.alibaba.cloud.ai.connector.config.DbConfig; +import com.alibaba.cloud.ai.constant.Constant; +import com.alibaba.cloud.ai.request.AgentSearchRequest; +import com.alibaba.cloud.ai.request.SchemaInitRequest; +import com.alibaba.cloud.ai.util.JsonUtil; +import com.alibaba.cloud.ai.util.SchemaProcessorUtil; + +import lombok.extern.slf4j.Slf4j; + +@Slf4j +public abstract class AbstractAgentVectorStoreService implements AgentVectorStoreService { + + /** + * 相似度阈值配置,用于过滤相似度分数大于等于此阈值的文档 + */ + @Value("${spring.ai.vectorstore.similarityThreshold:0.6}") + protected double similarityThreshold; + + /** + * Get embedding model + */ + protected abstract EmbeddingModel getEmbeddingModel(); + + protected final AccessorFactory accessorFactory; + + public AbstractAgentVectorStoreService(AccessorFactory accessorFactory) { + this.accessorFactory = accessorFactory; + } + + @Override + public List search(AgentSearchRequest searchRequest) { + Assert.notNull(searchRequest, "SearchRequest cannot be null"); + Assert.notNull(searchRequest.getAgentId(), "agentId cannot be null"); + org.springframework.ai.vectorstore.SearchRequest.Builder builder = org.springframework.ai.vectorstore.SearchRequest + .builder(); + + if (StringUtils.hasText(searchRequest.getQuery())) + builder.query(searchRequest.getQuery()); + + if (Objects.nonNull(searchRequest.getTopK())) + builder.topK(searchRequest.getTopK()); + + String filterFormatted = buildFilterExpressionString(searchRequest.getMetadataFilter()); + if (StringUtils.hasText(filterFormatted)) + builder.filterExpression(filterFormatted); + builder.similarityThreshold(similarityThreshold); + List results = getVectorStore().similaritySearch(builder.build()); + log.info("Search completed. Found {} documents for SearchRequest: {}", results.size(), searchRequest); + return results; + + } + + // 模板方法 - 通用schema处理流程 + @Override + public final Boolean schema(String agentId, SchemaInitRequest schemaInitRequest) throws Exception { + try { + + DbConfig config = schemaInitRequest.getDbConfig(); + DbQueryParameter dqp = DbQueryParameter.from(config) + .setSchema(config.getSchema()) + .setTables(schemaInitRequest.getTables()); + + // 根据当前DbConfig获取Accessor + Accessor dbAccessor = accessorFactory.getAccessorByDbConfig(config); + + // 清理旧数据 + clearSchemaDataForAgent(agentId); + + // 处理外键 + List foreignKeys = dbAccessor.showForeignKeys(config, dqp); + Map> foreignKeyMap = buildForeignKeyMap(foreignKeys); + + // 处理表和列 + List tables = dbAccessor.fetchTables(config, dqp); + for (TableInfoBO table : tables) { + SchemaProcessorUtil.enrichTableMetadata(table, dqp, config, dbAccessor, JsonUtil.getObjectMapper(), + foreignKeyMap); + } + + // 转换为文档 + List columnDocs = convertColumnsToDocuments(agentId, tables); + List tableDocs = convertTablesToDocuments(agentId, tables); + + // 存储文档 + return storeSchemaDocuments(columnDocs, tableDocs); + } + catch (Exception e) { + log.error("Failed to process schema ", e); + return false; + } + } + + protected Boolean storeSchemaDocuments(List columns, List tables) { + try { + getVectorStore().add(columns); + getVectorStore().add(tables); + return true; + } + catch (Exception e) { + log.error("add document to vectorstore error", e); + return false; + } + + } + + protected Map> buildForeignKeyMap(List foreignKeys) { + Map> map = new HashMap<>(); + for (ForeignKeyInfoBO fk : foreignKeys) { + String key = fk.getTable() + "." + fk.getColumn() + "=" + fk.getReferencedTable() + "." + + fk.getReferencedColumn(); + + map.computeIfAbsent(fk.getTable(), k -> new ArrayList<>()).add(key); + map.computeIfAbsent(fk.getReferencedTable(), k -> new ArrayList<>()).add(key); + } + return map; + } + + protected abstract VectorStore getVectorStore(); + + protected void clearSchemaDataForAgent(String agentId) throws Exception { + deleteDocumentsByVectorType(agentId, Constant.COLUMN); + deleteDocumentsByVectorType(agentId, Constant.TABLE); + } + + @Override + public Boolean deleteDocumentsByVectorType(String agentId, String vectorType) throws Exception { + Assert.notNull(agentId, "AgentId cannot be null."); + Assert.notNull(vectorType, "VectorType cannot be null."); + + Map metadata = new HashMap<>( + Map.ofEntries(Map.entry(Constant.AGENT_ID, agentId), Map.entry(Constant.VECTOR_TYPE, vectorType))); + + return this.deleteDocumentsByMetedata(agentId, metadata); + } + + @Override + public void addDocuments(String agentId, List documents) { + Assert.notNull(agentId, "AgentId cannot be null."); + Assert.notEmpty(documents, "Documents cannot be empty."); + getVectorStore().add(documents); + } + + @Override + public int estimateDocuments(String agentId) { + // 初略估算文档数目 + List docs = getVectorStore() + .similaritySearch(org.springframework.ai.vectorstore.SearchRequest.builder() + .query("") + .filterExpression(buildFilterExpressionString(Map.of(Constant.AGENT_ID, agentId))) + .topK(Integer.MAX_VALUE) // 获取所有匹配的文档 + .build()); + return docs.size(); + } + + @Override + public Boolean deleteDocumentsByMetedata(String agentId, Map metadata) throws Exception { + Assert.notNull(agentId, "AgentId cannot be null."); + Assert.notNull(metadata, "Metadata cannot be null."); + // 添加agentId元数据过滤条件, 用于删除指定agentId下的所有数据,因为metadata中用户调用可能忘记添加agentId + metadata.put(Constant.AGENT_ID, agentId); + String filterExpression = buildFilterExpressionString(metadata); + + // TODO 后续改成getVectorStore().delete(filterExpression); + // TODO 目前不支持通过元数据删除,使用会抛出UnsupportedOperationException,后续spring + // TODO ai发布1.1.0正式版本后再修改,现在是通过id删除 + + // 先搜索要删除的文档 + List documentsToDelete = getVectorStore() + .similaritySearch(org.springframework.ai.vectorstore.SearchRequest.builder() + .query("") + .filterExpression(filterExpression) + .topK(Integer.MAX_VALUE) + .build()); + + // 提取文档ID并删除 + if (!documentsToDelete.isEmpty()) { + List idsToDelete = documentsToDelete.stream().map(Document::getId).collect(Collectors.toList()); + getVectorStore().delete(idsToDelete); + } + + return true; + } + + /** + * 构建过滤表达式字符串,目前FilterExpressionBuilder 不支持链式拼接元数据过滤,所以只能使用字符串拼接 + * @param filterMap + * @return + */ + protected final String buildFilterExpressionString(Map filterMap) { + if (filterMap == null || filterMap.isEmpty()) { + return null; + } + + // 验证键名是否合法(只包含字母、数字和下划线) + for (String key : filterMap.keySet()) { + if (!key.matches("[a-zA-Z_][a-zA-Z0-9_]*")) { + throw new IllegalArgumentException("Invalid key name: " + key + + ". Keys must start with a letter or underscore and contain only alphanumeric characters and underscores."); + } + } + + return filterMap.entrySet().stream().map(entry -> { + String key = entry.getKey(); + Object value = entry.getValue(); + + // 处理空值 + if (value == null) { + return key + " == null"; + } + + // 根据值的类型决定如何格式化 + if (value instanceof String) { + // 转义字符串中的特殊字符 + String escapedValue = escapeStringLiteral((String) value); + return key + " == '" + escapedValue + "'"; + } + else if (value instanceof Number) { + // 数字类型直接使用 + return key + " == " + value; + } + else if (value instanceof Boolean) { + // 布尔值使用小写形式 + return key + " == " + ((Boolean) value).toString().toLowerCase(); + } + else if (value instanceof Enum) { + // 枚举类型,转换为字符串并转义 + String enumValue = ((Enum) value).name(); + String escapedValue = escapeStringLiteral(enumValue); + return key + " == '" + escapedValue + "'"; + } + else { + // 其他类型尝试转换为字符串并转义 + String stringValue = value.toString(); + String escapedValue = escapeStringLiteral(stringValue); + return key + " == '" + escapedValue + "'"; + } + }).collect(Collectors.joining(" && ")); + } + + /** + * 转义字符串字面量中的特殊字符 + */ + private String escapeStringLiteral(String input) { + if (input == null) { + return ""; + } + + // 转义反斜杠和单引号 + String escaped = input.replace("\\", "\\\\").replace("'", "\\'"); + + // 转义其他特殊字符 + escaped = escaped.replace("\n", "\\n") + .replace("\r", "\\r") + .replace("\t", "\\t") + .replace("\b", "\\b") + .replace("\f", "\\f"); + + return escaped; + } + + @Override + public List getDocumentsForAgent(String agentId, String query, String vectorType) { + AgentSearchRequest searchRequest = AgentSearchRequest.getInstance(agentId); + searchRequest.setQuery(query); + searchRequest.setTopK(20); + searchRequest.setMetadataFilter(Map.of(Constant.VECTOR_TYPE, vectorType)); + + return search(searchRequest); + } + + @Override + public boolean hasDocuments(String agentId) { + // 类似 MySQL 的 LIMIT 1,只检查是否存在文档 + List docs = getVectorStore() + .similaritySearch(org.springframework.ai.vectorstore.SearchRequest.builder() + .query("") + .filterExpression(buildFilterExpressionString(Map.of(Constant.AGENT_ID, agentId))) + .topK(1) // 只获取1个文档 + .build()); + return !docs.isEmpty(); + } + +} diff --git a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/AbstractVectorStoreService.java b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/AbstractVectorStoreService.java deleted file mode 100644 index 7d3d8ed..0000000 --- a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/AbstractVectorStoreService.java +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Copyright 2024-2025 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.alibaba.cloud.ai.service.vectorstore; - -import com.alibaba.cloud.ai.request.SearchRequest; -import org.springframework.ai.document.Document; -import org.springframework.ai.embedding.EmbeddingModel; - -import java.util.ArrayList; -import java.util.List; -import java.util.stream.Collectors; -import java.util.stream.IntStream; - -public abstract class AbstractVectorStoreService implements VectorStoreService { - - /** - * Get embedding model - */ - protected abstract EmbeddingModel getEmbeddingModel(); - - /** - * Convert text to Double type vector - */ - public List embedDouble(String text) { - return convertToDoubleList(getEmbeddingModel().embed(text)); - } - - /** - * Convert text to Float type vector - */ - public List embedFloat(String text) { - return convertToFloatList(getEmbeddingModel().embed(text)); - } - - /** - * Get documents from vector store - */ - @Override - public List getDocuments(String query, String vectorType) { - SearchRequest request = new SearchRequest(); - request.setQuery(query); - request.setVectorType(vectorType); - request.setTopK(100); - return new ArrayList<>(searchWithVectorType(request)); - } - - /** - * Convert float[] to Double List - */ - protected List convertToDoubleList(float[] array) { - return IntStream.range(0, array.length) - .mapToDouble(i -> (double) array[i]) - .boxed() - .collect(Collectors.toList()); - } - - /** - * Convert float[] to Float List - */ - protected List convertToFloatList(float[] array) { - return IntStream.range(0, array.length).mapToObj(i -> array[i]).collect(Collectors.toList()); - } - -} diff --git a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/AgentVectorStoreManager.java b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/AgentVectorStoreManager.java deleted file mode 100644 index 3b1aea4..0000000 --- a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/AgentVectorStoreManager.java +++ /dev/null @@ -1,262 +0,0 @@ -/* - * Copyright 2024-2025 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.alibaba.cloud.ai.service.vectorstore; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.ai.document.Document; -import org.springframework.ai.embedding.EmbeddingModel; -import org.springframework.ai.vectorstore.SimpleVectorStore; -import org.springframework.ai.vectorstore.filter.Filter; -import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder; -import org.springframework.stereotype.Service; - -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; - -/** - * Agent Vector Storage Manager Provides independent vector storage instances for each - * agent, ensuring data isolation - */ -@Service -public class AgentVectorStoreManager { - - private static final Logger log = LoggerFactory.getLogger(AgentVectorStoreManager.class); - - private final Map agentStores = new ConcurrentHashMap<>(); - - private final EmbeddingModel embeddingModel; - - public AgentVectorStoreManager(EmbeddingModel embeddingModel) { - this.embeddingModel = embeddingModel; - log.info("AgentVectorStoreManager initialized with EmbeddingModel: {}", - embeddingModel.getClass().getSimpleName()); - } - - /** - * Get or create agent-specific vector storage - * @param agentId agent ID - * @return agent-specific SimpleVectorStore instance - */ - public SimpleVectorStore getOrCreateVectorStore(String agentId) { - if (agentId == null || agentId.trim().isEmpty()) { - throw new IllegalArgumentException("Agent ID cannot be null or empty"); - } - - return agentStores.computeIfAbsent(agentId, id -> { - log.info("Creating new vector store for agent: {}", id); - return SimpleVectorStore.builder(embeddingModel).build(); - }); - } - - /** - * Add documents for specified agent - * @param agentId agent ID - * @param documents list of documents to add - */ - public void addDocuments(String agentId, List documents) { - if (documents == null || documents.isEmpty()) { - log.warn("No documents to add for agent: {}", agentId); - return; - } - - SimpleVectorStore store = getOrCreateVectorStore(agentId); - store.add(documents); - log.info("Added {} documents to vector store for agent: {}", documents.size(), agentId); - } - - /** - * Search similar documents for specified agent - * @param agentId agent ID - * @param query query text - * @param topK number of results to return - * @return list of similar documents - */ - public List similaritySearch(String agentId, String query, int topK) { - SimpleVectorStore store = agentStores.get(agentId); - if (store == null) { - log.warn("No vector store found for agent: {}", agentId); - return Collections.emptyList(); - } - - List results = store.similaritySearch( - org.springframework.ai.vectorstore.SearchRequest.builder().query(query).topK(topK).build()); - log.debug("Found {} similar documents for agent: {} with query: {}", results.size(), agentId, query); - return results; - } - - /** - * Search similar documents for specified agent (with filter conditions) - * @param agentId agent ID - * @param query query text - * @param topK number of results to return - * @param vectorType vector type filter - * @return list of similar documents - */ - public List similaritySearchWithFilter(String agentId, String query, int topK, String vectorType) { - SimpleVectorStore store = agentStores.get(agentId); - if (store == null) { - log.warn("No vector store found for agent: {}", agentId); - return Collections.emptyList(); - } - - FilterExpressionBuilder builder = new FilterExpressionBuilder(); - Filter.Expression expression = builder.eq("vectorType", vectorType).build(); - - List results = store.similaritySearch(org.springframework.ai.vectorstore.SearchRequest.builder() - .query(query) - .topK(topK) - .filterExpression(expression) - .build()); - - log.debug("Found {} filtered documents for agent: {} with query: {} and vectorType: {}", results.size(), - agentId, query, vectorType); - return results; - } - - /** - * Delete all data of specified agent - * @param agentId agent ID - */ - public void deleteAgentData(String agentId) { - SimpleVectorStore removed = agentStores.remove(agentId); - if (removed != null) { - log.info("Deleted all vector data for agent: {}", agentId); - } - else { - log.warn("No vector store found to delete for agent: {}", agentId); - } - } - - /** - * Delete specific documents of specified agent - * @param agentId agent ID - * @param documentIds list of document IDs to delete - */ - public void deleteDocuments(String agentId, List documentIds) { - SimpleVectorStore store = agentStores.get(agentId); - if (store == null) { - log.warn("No vector store found for agent: {}", agentId); - return; - } - - if (documentIds != null && !documentIds.isEmpty()) { - store.delete(documentIds); - log.info("Deleted {} documents from vector store for agent: {}", documentIds.size(), agentId); - } - } - - /** - * Delete specific type documents of specified agent - * @param agentId agent ID - * @param vectorType vector type - */ - public void deleteDocumentsByType(String agentId, String vectorType) { - SimpleVectorStore store = agentStores.get(agentId); - if (store == null) { - log.warn("No vector store found for agent: {}", agentId); - return; - } - - try { - FilterExpressionBuilder builder = new FilterExpressionBuilder(); - Filter.Expression expression = builder.eq("vectorType", vectorType).build(); - - List documents = store.similaritySearch(org.springframework.ai.vectorstore.SearchRequest.builder() - .query("") - .topK(Integer.MAX_VALUE) - .filterExpression(expression) - .build()); - - if (!documents.isEmpty()) { - List documentIds = documents.stream().map(Document::getId).toList(); - store.delete(documentIds); - log.info("Deleted {} documents of type '{}' for agent: {}", documents.size(), vectorType, agentId); - } - else { - log.info("No documents of type '{}' found for agent: {}", vectorType, agentId); - } - } - catch (Exception e) { - log.error("Failed to delete documents by type for agent: {}", agentId, e); - throw new RuntimeException("Failed to delete documents by type: " + e.getMessage(), e); - } - } - - /** - * Check if agent has vector data - * @param agentId agent ID - * @return whether has data - */ - public boolean hasAgentData(String agentId) { - return agentStores.containsKey(agentId); - } - - /** - * Get document count of agent (estimated) - * @param agentId agent ID - * @return document count - */ - public int getDocumentCount(String agentId) { - SimpleVectorStore store = agentStores.get(agentId); - if (store == null) { - return 0; - } - - try { - // Estimate quantity by searching all documents - List allDocs = store.similaritySearch(org.springframework.ai.vectorstore.SearchRequest.builder() - .query("") - .topK(Integer.MAX_VALUE) - .build()); - return allDocs.size(); - } - catch (Exception e) { - log.warn("Failed to get document count for agent: {}", agentId, e); - return 0; - } - } - - /** - * Get all agent IDs with data - * @return set of agent IDs - */ - public Set getAllAgentIds() { - return Set.copyOf(agentStores.keySet()); - } - - /** - * Get vector storage statistics - * @return statistics - */ - public Map getStatistics() { - Map stats = new ConcurrentHashMap<>(); - stats.put("totalAgents", agentStores.size()); - stats.put("agentIds", getAllAgentIds()); - - Map agentDocCounts = new ConcurrentHashMap<>(); - agentStores.forEach((agentId, store) -> { - agentDocCounts.put(agentId, getDocumentCount(agentId)); - }); - stats.put("documentCounts", agentDocCounts); - - return stats; - } - -} diff --git a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/AgentVectorStoreService.java b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/AgentVectorStoreService.java new file mode 100644 index 0000000..5485890 --- /dev/null +++ b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/AgentVectorStoreService.java @@ -0,0 +1,51 @@ +/* + * Copyright 2024-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.cloud.ai.service.vectorstore; + +import java.util.List; +import java.util.Map; + +import org.springframework.ai.document.Document; + +import com.alibaba.cloud.ai.request.AgentSearchRequest; +import com.alibaba.cloud.ai.request.SchemaInitRequest; + +public interface AgentVectorStoreService { + + /** + * 查询某个Agent的文档 总入口 + */ + List search(AgentSearchRequest searchRequest); + + Boolean schema(String agentId, SchemaInitRequest schemaInitRequest) throws Exception; + + Boolean deleteDocumentsByVectorType(String agentId, String vectorType) throws Exception; + + Boolean deleteDocumentsByMetedata(String agentId, Map metadata) throws Exception; + + /** + * Get documents for specified agent + */ + List getDocumentsForAgent(String agentId, String query, String vectorType); + + boolean hasDocuments(String agentId); + + void addDocuments(String agentId, List documents); + + int estimateDocuments(String agentId); + +} diff --git a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/AgentVectorStoreServiceFactory.java b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/AgentVectorStoreServiceFactory.java new file mode 100644 index 0000000..58756bb --- /dev/null +++ b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/AgentVectorStoreServiceFactory.java @@ -0,0 +1,115 @@ +/* + * Copyright 2024-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.cloud.ai.service.vectorstore; + +import com.alibaba.cloud.ai.connector.accessor.AccessorFactory; +import com.alibaba.cloud.ai.service.vectorstore.impls.AnalyticAgentVectorStoreService; +import com.alibaba.cloud.ai.service.vectorstore.impls.SimpleAgentVectorStoreService; +import com.alibaba.cloud.ai.vectorstore.analyticdb.AnalyticDbVectorStore; +import lombok.extern.slf4j.Slf4j; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.vectorstore.VectorStore; +import org.springframework.beans.factory.FactoryBean; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.stereotype.Component; + +import java.util.HashMap; +import java.util.Map; + +@Slf4j +@Component +public class AgentVectorStoreServiceFactory implements FactoryBean { + + // 使用枚举配置向量存储类型,由用户配置决定实现类 + @Value("${spring.ai.vectorstore.type:SIMPLE}") + private VectorStoreType vectorStoreType; + + @Autowired + private EmbeddingModel embeddingModel; + + @Autowired + private AccessorFactory accessorFactory; + + // 通用的向量存储bean,由Spring AI自动配置 + @Autowired(required = false) + private VectorStore vectorStore; + + @FunctionalInterface + private interface AgentVectorStoreServiceCreator { + + AgentVectorStoreService create(); + + } + + /** + * 存储不同向量存储类型对应的创建策略 使用Map可以方便地扩展新的向量存储类型 + */ + private final Map serviceCreators = new HashMap<>(); + + public AgentVectorStoreServiceFactory() { + // 初始化各种向量存储类型的创建策略 + // 这里使用显式的匿名类实现而不是方法引用,以提高代码可读性 + serviceCreators.put(VectorStoreType.ANALYTIC_DB, this::createAnalyticAgentVectorStoreService); + + serviceCreators.put(VectorStoreType.SIMPLE, this::createSimpleAgentVectorStoreService); + + // TODO 后续其他向量存储类型扩展处 + + } + + @Override + public AgentVectorStoreService getObject() { + // 根据配置的向量存储类型获取对应的创建策略 + AgentVectorStoreServiceCreator creator = serviceCreators.get(vectorStoreType); + if (creator == null) { + log.warn("Unsupported vector store type: {}, falling back to SIMPLE", vectorStoreType); + creator = serviceCreators.get(VectorStoreType.SIMPLE); + } + + // 使用选定的策略创建AgentVectorStoreService实例 + return creator.create(); + } + + @Override + public Class getObjectType() { + return AgentVectorStoreService.class; + } + + private AgentVectorStoreService createAnalyticAgentVectorStoreService() { + if (vectorStore == null) { + throw new IllegalStateException( + "AnalyticDbVectorStore is not configured. Please check your configuration."); + } + log.info("Using AnalyticDbVectorStoreService"); + if (!(vectorStore instanceof AnalyticDbVectorStore)) { + throw new IllegalStateException("VectorStore is not an instance of AnalyticDbVectorStore"); + } + return new AnalyticAgentVectorStoreService(embeddingModel, (AnalyticDbVectorStore) vectorStore, + accessorFactory); + } + + /** + * 创建简单内存存储的AgentVectorStoreService + * @return SimpleAgentVectorStoreService实例 + */ + private AgentVectorStoreService createSimpleAgentVectorStoreService() { + log.info("Using SimpleVectorStoreService"); + return new SimpleAgentVectorStoreService(embeddingModel, accessorFactory); + } + +} diff --git a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/VectorStoreService.java b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/VectorStoreService.java deleted file mode 100644 index 422641f..0000000 --- a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/VectorStoreService.java +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Copyright 2024-2025 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.alibaba.cloud.ai.service.vectorstore; - -import com.alibaba.cloud.ai.request.SchemaInitRequest; -import com.alibaba.cloud.ai.request.SearchRequest; -import org.springframework.ai.document.Document; - -import java.util.List; - -public interface VectorStoreService { - - /** - * Search interface with default filter - */ - List searchWithVectorType(SearchRequest searchRequestDTO); - - /** - * Search interface with custom filter - */ - List searchWithFilter(SearchRequest searchRequestDTO); - - List getDocuments(String query, String vectorType); - - /** - * Get documents for tables - */ - default List searchTableByNameAndVectorType(SearchRequest searchRequestDTO) { - throw new UnsupportedOperationException("Not implemented."); - } - - /** - * Get documents from vector store for specified agent - */ - default List getDocumentsForAgent(String agentId, String query, String vectorType) { - // Default implementation: if subclass doesn't override, use global search - return getDocuments(query, vectorType); - } - - default AgentVectorStoreManager getAgentVectorStoreManager() { - throw new UnsupportedOperationException("Not implemented."); - } - - default Boolean schemaForAgent(String agentId, SchemaInitRequest schemaInitRequest) throws Exception { - throw new UnsupportedOperationException("Not implemented."); - } - - default Boolean schema(SchemaInitRequest schemaInitRequest) throws Exception { - throw new UnsupportedOperationException("Not implemented."); - } - -} diff --git a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/VectorStoreServiceFactory.java b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/VectorStoreServiceFactory.java deleted file mode 100644 index d501219..0000000 --- a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/VectorStoreServiceFactory.java +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Copyright 2024-2025 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.alibaba.cloud.ai.service.vectorstore; - -import com.alibaba.cloud.ai.connector.accessor.AccessorFactory; -import com.alibaba.cloud.ai.service.vectorstore.impls.AnalyticVectorStoreService; -import com.alibaba.cloud.ai.service.vectorstore.impls.SimpleVectorStoreService; -import com.alibaba.cloud.ai.vectorstore.analyticdb.AnalyticDbVectorStoreProperties; -import com.aliyun.gpdb20160503.Client; -import org.springframework.ai.embedding.EmbeddingModel; -import org.springframework.beans.factory.FactoryBean; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.beans.factory.annotation.Value; -import org.springframework.stereotype.Component; - -@Component -public class VectorStoreServiceFactory implements FactoryBean { - - // todo: 改为枚举,由用户配置决定实现类 - @Value("${spring.ai.vectorstore.analytic.enabled:false}") - private Boolean analyticEnabled; - - @Autowired - private EmbeddingModel embeddingModel; - - @Autowired(required = false) - private AnalyticDbVectorStoreProperties analyticDbVectorStoreProperties; - - @Autowired(required = false) - private Client client; - - @Autowired - private AccessorFactory accessorFactory; - - @Autowired - private AgentVectorStoreManager agentVectorStoreManager; - - @Override - public VectorStoreService getObject() { - if (Boolean.TRUE.equals(analyticEnabled)) { - return new AnalyticVectorStoreService(analyticDbVectorStoreProperties, embeddingModel, client); - } - else { - return new SimpleVectorStoreService(embeddingModel, accessorFactory, agentVectorStoreManager); - } - } - - @Override - public Class getObjectType() { - return VectorStoreService.class; - } - -} diff --git a/spring-ai-alibaba-data-agent-management/src/main/java/com/alibaba/cloud/ai/service/VectorStoreManagementService.java b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/VectorStoreType.java similarity index 57% rename from spring-ai-alibaba-data-agent-management/src/main/java/com/alibaba/cloud/ai/service/VectorStoreManagementService.java rename to spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/VectorStoreType.java index f182e5e..9881152 100644 --- a/spring-ai-alibaba-data-agent-management/src/main/java/com/alibaba/cloud/ai/service/VectorStoreManagementService.java +++ b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/VectorStoreType.java @@ -13,20 +13,32 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.alibaba.cloud.ai.service; -import com.alibaba.cloud.ai.request.DeleteRequest; -import com.alibaba.cloud.ai.request.EvidenceRequest; -import com.alibaba.cloud.ai.request.SchemaInitRequest; +package com.alibaba.cloud.ai.service.vectorstore; -import java.util.List; +/** + * 向量存储类型枚举 用于配置使用哪种向量存储实现 + */ +public enum VectorStoreType { -public interface VectorStoreManagementService { + /** + * 简单向量存储(内存存储) + */ + SIMPLE, - Boolean addEvidence(List evidenceRequests); + /** + * 分析型数据库向量存储 + */ + ANALYTIC_DB, - Boolean deleteDocuments(DeleteRequest deleteRequest) throws Exception; + /** + * Milvus 向量存储 + */ + MILVUS, - Boolean schema(SchemaInitRequest schemaInitRequest) throws Exception; + /** + * PostgreSQL PGVector 向量存储 + */ + PGVECTOR } diff --git a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/impls/AnalyticAgentVectorStoreService.java b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/impls/AnalyticAgentVectorStoreService.java new file mode 100644 index 0000000..fcf6051 --- /dev/null +++ b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/impls/AnalyticAgentVectorStoreService.java @@ -0,0 +1,49 @@ +/* + * Copyright 2024-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.cloud.ai.service.vectorstore.impls; + +import com.alibaba.cloud.ai.service.vectorstore.AbstractAgentVectorStoreService; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.vectorstore.VectorStore; + +import com.alibaba.cloud.ai.connector.accessor.AccessorFactory; +import com.alibaba.cloud.ai.vectorstore.analyticdb.AnalyticDbVectorStore; + +public class AnalyticAgentVectorStoreService extends AbstractAgentVectorStoreService { + + private final EmbeddingModel embeddingModel; + + private final AnalyticDbVectorStore analyticDbVectorStore; + + public AnalyticAgentVectorStoreService(EmbeddingModel embeddingModel, AnalyticDbVectorStore vectorStore, + AccessorFactory accessorFactory) { + super(accessorFactory); + this.embeddingModel = embeddingModel; + this.analyticDbVectorStore = vectorStore; + } + + @Override + protected EmbeddingModel getEmbeddingModel() { + return embeddingModel; + } + + @Override + protected VectorStore getVectorStore() { + return analyticDbVectorStore; + } + +} diff --git a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/impls/AnalyticVectorStoreService.java b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/impls/AnalyticVectorStoreService.java deleted file mode 100644 index 70f282a..0000000 --- a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/impls/AnalyticVectorStoreService.java +++ /dev/null @@ -1,143 +0,0 @@ -/* - * Copyright 2024-2025 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.alibaba.cloud.ai.service.vectorstore.impls; - -import com.alibaba.cloud.ai.request.SearchRequest; -import com.alibaba.cloud.ai.service.vectorstore.AbstractVectorStoreService; -import com.alibaba.cloud.ai.vectorstore.analyticdb.AnalyticDbVectorStoreProperties; -import com.aliyun.gpdb20160503.Client; -import com.aliyun.gpdb20160503.models.QueryCollectionDataRequest; -import com.aliyun.gpdb20160503.models.QueryCollectionDataResponse; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; -import org.springframework.ai.document.Document; -import org.springframework.ai.embedding.EmbeddingModel; - -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; - -public class AnalyticVectorStoreService extends AbstractVectorStoreService { - - private static final String CONTENT_FIELD_NAME = "content"; - - private static final String METADATA_FIELD_NAME = "metadata"; - - private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); - - private final EmbeddingModel embeddingModel; - - private final AnalyticDbVectorStoreProperties analyticDbVectorStoreProperties; - - private final Client client; - - public AnalyticVectorStoreService(AnalyticDbVectorStoreProperties analyticDbVectorStoreProperties, - EmbeddingModel embeddingModel, Client client) { - this.analyticDbVectorStoreProperties = analyticDbVectorStoreProperties; - this.embeddingModel = embeddingModel; - this.client = client; - } - - @Override - protected EmbeddingModel getEmbeddingModel() { - return embeddingModel; - } - - /** - * Search interface with default filter - */ - @Override - public List searchWithVectorType(SearchRequest searchRequestDTO) { - String filter = String.format("jsonb_extract_path_text(metadata, 'vectorType') = '%s'", - searchRequestDTO.getVectorType()); - - QueryCollectionDataRequest request = buildBaseRequest(searchRequestDTO).setFilter(filter); - - return executeQuery(request); - } - - /** - * Search interface with custom filter - */ - @Override - public List searchWithFilter(SearchRequest searchRequestDTO) { - QueryCollectionDataRequest request = buildBaseRequest(searchRequestDTO) - .setFilter(searchRequestDTO.getFilterFormatted()); - return executeQuery(request); - } - - /** - * Build basic query request object - */ - private QueryCollectionDataRequest buildBaseRequest(SearchRequest searchRequestDTO) { - QueryCollectionDataRequest queryCollectionDataRequest = new QueryCollectionDataRequest() - .setDBInstanceId(analyticDbVectorStoreProperties.getDbInstanceId()) - .setRegionId(analyticDbVectorStoreProperties.getRegionId()) - .setNamespace(analyticDbVectorStoreProperties.getNamespace()) - .setNamespacePassword(analyticDbVectorStoreProperties.getNamespacePassword()) - .setCollection(analyticDbVectorStoreProperties.getCollectName()) - .setIncludeValues(false) - .setMetrics(analyticDbVectorStoreProperties.getMetrics()) - .setTopK((long) searchRequestDTO.getTopK()); - if (searchRequestDTO.getQuery() != null) { - queryCollectionDataRequest.setVector(embedDouble(searchRequestDTO.getQuery())); - queryCollectionDataRequest.setContent(searchRequestDTO.getQuery()); - } - return queryCollectionDataRequest; - } - - /** - * Execute actual query and parse results - */ - private List executeQuery(QueryCollectionDataRequest request) { - try { - QueryCollectionDataResponse response = client.queryCollectionData(request); - return parseDocuments(response); - } - catch (Exception e) { - throw new RuntimeException("向量数据库查询失败: " + e.getMessage(), e); - } - } - - /** - * Parse response data into Document list - */ - private List parseDocuments(QueryCollectionDataResponse response) throws Exception { - return response.getBody() - .getMatches() - .getMatch() - .stream() - .filter(match -> match.getScore() == null || match.getScore() > 0.1 || match.getScore() == 0.0) - .map(match -> { - Map metadata = match.getMetadata(); - try { - Map metadataJson = OBJECT_MAPPER.readValue(metadata.get(METADATA_FIELD_NAME), - new TypeReference>() { - }); - metadataJson.put("score", match.getScore()); - - return new Document(match.getId(), metadata.get(CONTENT_FIELD_NAME), metadataJson); - } - catch (Exception e) { - throw new RuntimeException("解析元数据失败: " + e.getMessage(), e); - } - }) - .collect(Collectors.toList()); - } - -} diff --git a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/impls/SimpleAgentVectorStoreService.java b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/impls/SimpleAgentVectorStoreService.java new file mode 100644 index 0000000..da2d758 --- /dev/null +++ b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/impls/SimpleAgentVectorStoreService.java @@ -0,0 +1,47 @@ +/* + * Copyright 2024-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alibaba.cloud.ai.service.vectorstore.impls; + +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.vectorstore.SimpleVectorStore; +import org.springframework.ai.vectorstore.VectorStore; + +import com.alibaba.cloud.ai.connector.accessor.AccessorFactory; +import com.alibaba.cloud.ai.service.vectorstore.AbstractAgentVectorStoreService; + +public class SimpleAgentVectorStoreService extends AbstractAgentVectorStoreService { + + private final SimpleVectorStore vectorStore; + + private final EmbeddingModel embeddingModel; + + public SimpleAgentVectorStoreService(EmbeddingModel embeddingModel, AccessorFactory accessorFactory) { + super(accessorFactory); + this.embeddingModel = embeddingModel; + this.vectorStore = SimpleVectorStore.builder(embeddingModel).build(); + } + + @Override + protected EmbeddingModel getEmbeddingModel() { + return embeddingModel; + } + + @Override + protected VectorStore getVectorStore() { + return this.vectorStore; + } + +} diff --git a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/impls/SimpleVectorStoreService.java b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/impls/SimpleVectorStoreService.java deleted file mode 100644 index 4c5e59f..0000000 --- a/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/service/vectorstore/impls/SimpleVectorStoreService.java +++ /dev/null @@ -1,562 +0,0 @@ -/* - * Copyright 2024-2025 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.alibaba.cloud.ai.service.vectorstore.impls; - -import com.alibaba.cloud.ai.connector.accessor.Accessor; -import com.alibaba.cloud.ai.connector.accessor.AccessorFactory; -import com.alibaba.cloud.ai.connector.bo.ColumnInfoBO; -import com.alibaba.cloud.ai.connector.bo.DbQueryParameter; -import com.alibaba.cloud.ai.connector.bo.ForeignKeyInfoBO; -import com.alibaba.cloud.ai.connector.bo.TableInfoBO; -import com.alibaba.cloud.ai.connector.config.DbConfig; -import com.alibaba.cloud.ai.request.DeleteRequest; -import com.alibaba.cloud.ai.request.SchemaInitRequest; -import com.alibaba.cloud.ai.request.SearchRequest; -import com.alibaba.cloud.ai.service.vectorstore.AgentVectorStoreManager; -import com.alibaba.cloud.ai.service.vectorstore.AbstractVectorStoreService; -import com.alibaba.cloud.ai.util.JsonUtil; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.ObjectMapper; -import org.apache.commons.collections.CollectionUtils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.ai.document.Document; -import org.springframework.ai.embedding.EmbeddingModel; -import org.springframework.ai.vectorstore.SimpleVectorStore; -import org.springframework.ai.vectorstore.filter.Filter; -import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Optional; -import java.util.stream.Collectors; - -public class SimpleVectorStoreService extends AbstractVectorStoreService { - - private static final Logger log = LoggerFactory.getLogger(SimpleVectorStoreService.class); - - private final SimpleVectorStore vectorStore; // Keep original global storage for - // backward compatibility - - private final AgentVectorStoreManager agentVectorStoreManager; // New agent vector - // storage manager - - private final ObjectMapper objectMapper; - - // todo: 根据数据源动态获取Accessor - private final Accessor dbAccessor; - - private final EmbeddingModel embeddingModel; - - public SimpleVectorStoreService(EmbeddingModel embeddingModel, AccessorFactory accessorFactory, - AgentVectorStoreManager agentVectorStoreManager) { - this.objectMapper = JsonUtil.getObjectMapper(); - this.dbAccessor = accessorFactory.getAccessorByDbConfig(null); - this.embeddingModel = embeddingModel; - this.agentVectorStoreManager = agentVectorStoreManager; - this.vectorStore = SimpleVectorStore.builder(embeddingModel).build(); - } - - @Override - protected EmbeddingModel getEmbeddingModel() { - return embeddingModel; - } - - /** - * Initialize database schema to vector store - * @param schemaInitRequest schema initialization request - * @throws Exception if an error occurs - */ - @Override - public Boolean schema(SchemaInitRequest schemaInitRequest) throws Exception { - log.info("Starting schema initialization for database: {}, schema: {}, tables: {}", - schemaInitRequest.getDbConfig().getUrl(), schemaInitRequest.getDbConfig().getSchema(), - schemaInitRequest.getTables()); - - DbConfig dbConfig = schemaInitRequest.getDbConfig(); - DbQueryParameter dqp = DbQueryParameter.from(dbConfig) - .setSchema(dbConfig.getSchema()) - .setTables(schemaInitRequest.getTables()); - - // Clean up old schema data - DeleteRequest deleteRequest = new DeleteRequest(); - deleteRequest.setVectorType("column"); - deleteDocuments(deleteRequest); - deleteRequest.setVectorType("table"); - deleteDocuments(deleteRequest); - - log.debug("Fetching foreign keys from database"); - List foreignKeyInfoBOS = dbAccessor.showForeignKeys(dbConfig, dqp); - log.debug("Found {} foreign keys", foreignKeyInfoBOS.size()); - Map> foreignKeyMap = buildForeignKeyMap(foreignKeyInfoBOS); - - log.debug("Fetching tables from database"); - List tableInfoBOS = dbAccessor.fetchTables(dbConfig, dqp); - log.info("Found {} tables to process", tableInfoBOS.size()); - - for (TableInfoBO tableInfoBO : tableInfoBOS) { - log.debug("Processing table: {}", tableInfoBO.getName()); - processTable(tableInfoBO, dqp, dbConfig, foreignKeyMap); - } - - log.debug("Converting columns to documents"); - List columnDocuments = tableInfoBOS.stream().flatMap(table -> { - try { - dqp.setTable(table.getName()); - return dbAccessor.showColumns(dbConfig, dqp).stream().map(column -> convertToDocument(table, column)); - } - catch (Exception e) { - log.error("Error processing columns for table: {}", table.getName(), e); - throw new RuntimeException(e); - } - }).collect(Collectors.toList()); - - log.info("Adding {} column documents to vector store", columnDocuments.size()); - vectorStore.add(columnDocuments); - - log.debug("Converting tables to documents"); - List tableDocuments = tableInfoBOS.stream() - .map(this::convertTableToDocument) - .collect(Collectors.toList()); - - log.info("Adding {} table documents to vector store", tableDocuments.size()); - vectorStore.add(tableDocuments); - - log.info("Schema initialization completed successfully. Total documents added: {}", - columnDocuments.size() + tableDocuments.size()); - return true; - } - - private void processTable(TableInfoBO tableInfoBO, DbQueryParameter dqp, DbConfig dbConfig, - Map> foreignKeyMap) throws Exception { - dqp.setTable(tableInfoBO.getName()); - List columnInfoBOS = dbAccessor.showColumns(dbConfig, dqp); - for (ColumnInfoBO columnInfoBO : columnInfoBOS) { - dqp.setColumn(columnInfoBO.getName()); - List sampleColumn = dbAccessor.sampleColumn(dbConfig, dqp); - sampleColumn = Optional.ofNullable(sampleColumn) - .orElse(new ArrayList<>()) - .stream() - .filter(Objects::nonNull) - .distinct() - .limit(3) - .filter(s -> s.length() <= 100) - .toList(); - - columnInfoBO.setTableName(tableInfoBO.getName()); - try { - columnInfoBO.setSamples(objectMapper.writeValueAsString(sampleColumn)); - } - catch (JsonProcessingException e) { - columnInfoBO.setSamples("[]"); - } - } - - List targetPrimaryList = columnInfoBOS.stream() - .filter(ColumnInfoBO::isPrimary) - .collect(Collectors.toList()); - if (CollectionUtils.isNotEmpty(targetPrimaryList)) { - List columnNames = targetPrimaryList.stream() - .map(ColumnInfoBO::getName) - .collect(Collectors.toList()); - tableInfoBO.setPrimaryKeys(columnNames); - } - else { - tableInfoBO.setPrimaryKeys(new ArrayList<>()); - } - tableInfoBO - .setForeignKey(String.join("、", foreignKeyMap.getOrDefault(tableInfoBO.getName(), new ArrayList<>()))); - } - - public Document convertToDocument(TableInfoBO tableInfoBO, ColumnInfoBO columnInfoBO) { - log.debug("Converting column to document: table={}, column={}", tableInfoBO.getName(), columnInfoBO.getName()); - - String text = Optional.ofNullable(columnInfoBO.getDescription()).orElse(columnInfoBO.getName()); - String id = tableInfoBO.getName() + "." + columnInfoBO.getName(); - Map metadata = new HashMap<>(); - metadata.put("id", id); - metadata.put("name", columnInfoBO.getName()); - metadata.put("tableName", tableInfoBO.getName()); - metadata.put("description", Optional.ofNullable(columnInfoBO.getDescription()).orElse("")); - metadata.put("type", columnInfoBO.getType()); - metadata.put("primary", columnInfoBO.isPrimary()); - metadata.put("notnull", columnInfoBO.isNotnull()); - metadata.put("vectorType", "column"); - if (columnInfoBO.getSamples() != null) { - metadata.put("samples", columnInfoBO.getSamples()); - } - // Multi-table duplicate field data will be deduplicated, using table name + field - // name as unique identifier - Document document = new Document(id, text, metadata); - log.debug("Created column document with ID: {}", id); - return document; - } - - public Document convertTableToDocument(TableInfoBO tableInfoBO) { - log.debug("Converting table to document: {}", tableInfoBO.getName()); - - String text = Optional.ofNullable(tableInfoBO.getDescription()).orElse(tableInfoBO.getName()); - Map metadata = new HashMap<>(); - metadata.put("schema", Optional.ofNullable(tableInfoBO.getSchema()).orElse("")); - metadata.put("name", tableInfoBO.getName()); - metadata.put("description", Optional.ofNullable(tableInfoBO.getDescription()).orElse("")); - metadata.put("foreignKey", Optional.ofNullable(tableInfoBO.getForeignKey()).orElse("")); - metadata.put("primaryKey", Optional.ofNullable(tableInfoBO.getPrimaryKeys()).orElse(new ArrayList<>())); - metadata.put("vectorType", "table"); - Document document = new Document(tableInfoBO.getName(), text, metadata); - log.debug("Created table document with ID: {}", tableInfoBO.getName()); - return document; - } - - private Map> buildForeignKeyMap(List foreignKeyInfoBOS) { - Map> foreignKeyMap = new HashMap<>(); - for (ForeignKeyInfoBO fk : foreignKeyInfoBOS) { - String key = fk.getTable() + "." + fk.getColumn() + "=" + fk.getReferencedTable() + "." - + fk.getReferencedColumn(); - - foreignKeyMap.computeIfAbsent(fk.getTable(), k -> new ArrayList<>()).add(key); - foreignKeyMap.computeIfAbsent(fk.getReferencedTable(), k -> new ArrayList<>()).add(key); - } - return foreignKeyMap; - } - - /** - * Delete vector data with specified conditions - * @param deleteRequest delete request - * @return whether deletion succeeded - */ - public Boolean deleteDocuments(DeleteRequest deleteRequest) throws Exception { - log.info("Starting delete operation with request: id={}, vectorType={}", deleteRequest.getId(), - deleteRequest.getVectorType()); - - try { - if (deleteRequest.getId() != null && !deleteRequest.getId().isEmpty()) { - log.debug("Deleting documents by ID: {}", deleteRequest.getId()); - vectorStore.delete(Arrays.asList(deleteRequest.getId())); - log.info("Successfully deleted documents by ID"); - } - else if (deleteRequest.getVectorType() != null && !deleteRequest.getVectorType().isEmpty()) { - log.debug("Deleting documents by vectorType: {}", deleteRequest.getVectorType()); - FilterExpressionBuilder b = new FilterExpressionBuilder(); - Filter.Expression expression = b.eq("vectorType", deleteRequest.getVectorType()).build(); - List documents = vectorStore - .similaritySearch(org.springframework.ai.vectorstore.SearchRequest.builder() - .topK(Integer.MAX_VALUE) - .filterExpression(expression) - .build()); - if (documents != null && !documents.isEmpty()) { - log.info("Found {} documents to delete with vectorType: {}", documents.size(), - deleteRequest.getVectorType()); - vectorStore.delete(documents.stream().map(Document::getId).toList()); - log.info("Successfully deleted {} documents", documents.size()); - } - else { - log.info("No documents found to delete with vectorType: {}", deleteRequest.getVectorType()); - } - } - else { - log.warn("Invalid delete request: either id or vectorType must be specified"); - throw new IllegalArgumentException("Either id or vectorType must be specified."); - } - return true; - } - catch (Exception e) { - log.error("Failed to delete documents: {}", e.getMessage(), e); - throw new Exception("Failed to delete collection data by filterExpression: " + e.getMessage(), e); - } - } - - /** - * Search interface with default filter - */ - @Override - public List searchWithVectorType(SearchRequest searchRequestDTO) { - log.debug("Searching with vectorType: {}, query: {}, topK: {}", searchRequestDTO.getVectorType(), - searchRequestDTO.getQuery(), searchRequestDTO.getTopK()); - - FilterExpressionBuilder b = new FilterExpressionBuilder(); - Filter.Expression expression = b.eq("vectorType", searchRequestDTO.getVectorType()).build(); - - List results = vectorStore.similaritySearch(org.springframework.ai.vectorstore.SearchRequest.builder() - .query(searchRequestDTO.getQuery()) - .topK(searchRequestDTO.getTopK()) - .filterExpression(expression) - .build()); - - if (results == null) { - results = new ArrayList<>(); - } - - log.info("Search completed. Found {} documents for vectorType: {}", results.size(), - searchRequestDTO.getVectorType()); - return results; - } - - /** - * Search interface with custom filter - */ - @Override - public List searchWithFilter(SearchRequest searchRequestDTO) { - log.debug("Searching with custom filter: vectorType={}, query={}, topK={}", searchRequestDTO.getVectorType(), - searchRequestDTO.getQuery(), searchRequestDTO.getTopK()); - - // Need to parse filterFormatted field according to actual situation here, convert - // to FilterExpressionBuilder expression - // Simplified implementation, for demonstration only - FilterExpressionBuilder b = new FilterExpressionBuilder(); - Filter.Expression expression = b.eq("vectorType", searchRequestDTO.getVectorType()).build(); - - List results = vectorStore.similaritySearch(org.springframework.ai.vectorstore.SearchRequest.builder() - .query(searchRequestDTO.getQuery()) - .topK(searchRequestDTO.getTopK()) - .filterExpression(expression) - .build()); - - if (results == null) { - results = new ArrayList<>(); - } - - log.info("Search with filter completed. Found {} documents", results.size()); - return results; - } - - @Override - public List searchTableByNameAndVectorType(SearchRequest searchRequestDTO) { - log.debug("Searching table by name and vectorType: name={}, vectorType={}, topK={}", searchRequestDTO.getName(), - searchRequestDTO.getVectorType(), searchRequestDTO.getTopK()); - - FilterExpressionBuilder b = new FilterExpressionBuilder(); - Filter.Expression expression = b - .and(b.eq("vectorType", searchRequestDTO.getVectorType()), b.eq("id", searchRequestDTO.getName())) - .build(); - - List results = vectorStore.similaritySearch(org.springframework.ai.vectorstore.SearchRequest.builder() - .topK(searchRequestDTO.getTopK()) - .filterExpression(expression) - .build()); - - if (results == null) { - results = new ArrayList<>(); - } - - log.info("Search by name completed. Found {} documents for name: {}", results.size(), - searchRequestDTO.getName()); - return results; - } - - // ==================== 智能体相关的新方法 ==================== - - /** - * Initialize database schema to vector store for specified agent - * @param agentId agent ID - * @param schemaInitRequest schema initialization request - * @throws Exception if an error occurs - */ - @Override - public Boolean schemaForAgent(String agentId, SchemaInitRequest schemaInitRequest) throws Exception { - log.info("Starting schema initialization for agent: {}, database: {}, schema: {}, tables: {}", agentId, - schemaInitRequest.getDbConfig().getUrl(), schemaInitRequest.getDbConfig().getSchema(), - schemaInitRequest.getTables()); - - DbConfig dbConfig = schemaInitRequest.getDbConfig(); - DbQueryParameter dqp = DbQueryParameter.from(dbConfig) - .setSchema(dbConfig.getSchema()) - .setTables(schemaInitRequest.getTables()); - - // Clean up agent's old data - agentVectorStoreManager.deleteDocumentsByType(agentId, "column"); - agentVectorStoreManager.deleteDocumentsByType(agentId, "table"); - - log.debug("Fetching foreign keys from database for agent: {}", agentId); - List foreignKeyInfoBOS = dbAccessor.showForeignKeys(dbConfig, dqp); - log.debug("Found {} foreign keys for agent: {}", foreignKeyInfoBOS.size(), agentId); - Map> foreignKeyMap = buildForeignKeyMap(foreignKeyInfoBOS); - - log.debug("Fetching tables from database for agent: {}", agentId); - List tableInfoBOS = dbAccessor.fetchTables(dbConfig, dqp); - log.info("Found {} tables to process for agent: {}", tableInfoBOS.size(), agentId); - - for (TableInfoBO tableInfoBO : tableInfoBOS) { - log.debug("Processing table: {} for agent: {}", tableInfoBO.getName(), agentId); - processTable(tableInfoBO, dqp, dbConfig, foreignKeyMap); - } - - log.debug("Converting columns to documents for agent: {}", agentId); - List columnDocuments = tableInfoBOS.stream().flatMap(table -> { - try { - dqp.setTable(table.getName()); - return dbAccessor.showColumns(dbConfig, dqp) - .stream() - .map(column -> convertToDocumentForAgent(agentId, table, column)); - } - catch (Exception e) { - log.error("Error processing columns for table: {} and agent: {}", table.getName(), agentId, e); - throw new RuntimeException(e); - } - }).collect(Collectors.toList()); - - log.info("Adding {} column documents to vector store for agent: {}", columnDocuments.size(), agentId); - agentVectorStoreManager.addDocuments(agentId, columnDocuments); - - log.debug("Converting tables to documents for agent: {}", agentId); - List tableDocuments = tableInfoBOS.stream() - .map(table -> convertTableToDocumentForAgent(agentId, table)) - .collect(Collectors.toList()); - - log.info("Adding {} table documents to vector store for agent: {}", tableDocuments.size(), agentId); - agentVectorStoreManager.addDocuments(agentId, tableDocuments); - - log.info("Schema initialization completed successfully for agent: {}. Total documents added: {}", agentId, - columnDocuments.size() + tableDocuments.size()); - return true; - } - - /** - * Convert column information to documents for agent - */ - private Document convertToDocumentForAgent(String agentId, TableInfoBO tableInfoBO, ColumnInfoBO columnInfoBO) { - log.debug("Converting column to document for agent: {}, table={}, column={}", agentId, tableInfoBO.getName(), - columnInfoBO.getName()); - - String text = Optional.ofNullable(columnInfoBO.getDescription()).orElse(columnInfoBO.getName()); - String id = agentId + ":" + tableInfoBO.getName() + "." + columnInfoBO.getName(); - Map metadata = new HashMap<>(); - metadata.put("id", id); - metadata.put("agentId", agentId); - metadata.put("name", columnInfoBO.getName()); - metadata.put("tableName", tableInfoBO.getName()); - metadata.put("description", Optional.ofNullable(columnInfoBO.getDescription()).orElse("")); - metadata.put("type", columnInfoBO.getType()); - metadata.put("primary", columnInfoBO.isPrimary()); - metadata.put("notnull", columnInfoBO.isNotnull()); - metadata.put("vectorType", "column"); - if (columnInfoBO.getSamples() != null) { - metadata.put("samples", columnInfoBO.getSamples()); - } - - Document document = new Document(id, text, metadata); - log.debug("Created column document with ID: {} for agent: {}", id, agentId); - return document; - } - - /** - * Convert table information to documents for agent - */ - private Document convertTableToDocumentForAgent(String agentId, TableInfoBO tableInfoBO) { - log.debug("Converting table to document for agent: {}, table: {}", agentId, tableInfoBO.getName()); - - String text = Optional.ofNullable(tableInfoBO.getDescription()).orElse(tableInfoBO.getName()); - String id = agentId + ":" + tableInfoBO.getName(); - Map metadata = new HashMap<>(); - metadata.put("agentId", agentId); - metadata.put("schema", Optional.ofNullable(tableInfoBO.getSchema()).orElse("")); - metadata.put("name", tableInfoBO.getName()); - metadata.put("description", Optional.ofNullable(tableInfoBO.getDescription()).orElse("")); - metadata.put("foreignKey", Optional.ofNullable(tableInfoBO.getForeignKey()).orElse("")); - metadata.put("primaryKey", Optional.ofNullable(tableInfoBO.getPrimaryKeys()).orElse(new ArrayList<>())); - metadata.put("vectorType", "table"); - - Document document = new Document(id, text, metadata); - log.debug("Created table document with ID: {} for agent: {}", id, agentId); - return document; - } - - /** - * Search vector data for specified agent - */ - public List searchWithVectorTypeForAgent(String agentId, SearchRequest searchRequestDTO) { - log.debug("Searching for agent: {}, vectorType: {}, query: {}, topK: {}", agentId, - searchRequestDTO.getVectorType(), searchRequestDTO.getQuery(), searchRequestDTO.getTopK()); - - List results = agentVectorStoreManager.similaritySearchWithFilter(agentId, - searchRequestDTO.getQuery(), searchRequestDTO.getTopK(), searchRequestDTO.getVectorType()); - - log.info("Search completed for agent: {}. Found {} documents for vectorType: {}", agentId, results.size(), - searchRequestDTO.getVectorType()); - return results; - } - - /** - * Delete vector data for specified agent - */ - public Boolean deleteDocumentsForAgent(String agentId, DeleteRequest deleteRequest) throws Exception { - log.info("Starting delete operation for agent: {}, id={}, vectorType={}", agentId, deleteRequest.getId(), - deleteRequest.getVectorType()); - - try { - if (deleteRequest.getId() != null && !deleteRequest.getId().isEmpty()) { - log.debug("Deleting documents by ID for agent: {}, ID: {}", agentId, deleteRequest.getId()); - agentVectorStoreManager.deleteDocuments(agentId, Arrays.asList(deleteRequest.getId())); - log.info("Successfully deleted documents by ID for agent: {}", agentId); - } - else if (deleteRequest.getVectorType() != null && !deleteRequest.getVectorType().isEmpty()) { - log.debug("Deleting documents by vectorType for agent: {}, vectorType: {}", agentId, - deleteRequest.getVectorType()); - agentVectorStoreManager.deleteDocumentsByType(agentId, deleteRequest.getVectorType()); - log.info("Successfully deleted documents by vectorType for agent: {}", agentId); - } - else { - log.warn("Invalid delete request for agent: {}: either id or vectorType must be specified", agentId); - throw new IllegalArgumentException("Either id or vectorType must be specified."); - } - return true; - } - catch (Exception e) { - log.error("Failed to delete documents for agent: {}: {}", agentId, e.getMessage(), e); - throw new Exception("Failed to delete collection data for agent " + agentId + ": " + e.getMessage(), e); - } - } - - /** - * Get agent vector storage manager (for other services to use) - */ - @Override - public AgentVectorStoreManager getAgentVectorStoreManager() { - return agentVectorStoreManager; - } - - /** - * Get documents from vector store for specified agent Override parent method, use - * agent-specific vector storage - */ - @Override - public List getDocumentsForAgent(String agentId, String query, String vectorType) { - log.debug("Getting documents for agent: {}, query: {}, vectorType: {}", agentId, query, vectorType); - - if (agentId == null || agentId.trim().isEmpty()) { - log.warn("AgentId is null or empty, falling back to global search"); - return getDocuments(query, vectorType); - } - - try { - // Use agent vector storage manager for search - List results = agentVectorStoreManager.similaritySearchWithFilter(agentId, query, 100, // topK - vectorType); - - log.info("Found {} documents for agent: {}, vectorType: {}", results.size(), agentId, vectorType); - return results; - } - catch (Exception e) { - log.error("Error getting documents for agent: {}, falling back to global search", agentId, e); - return getDocuments(query, vectorType); - } - } - -} diff --git a/spring-ai-alibaba-data-agent-management/src/main/java/com/alibaba/cloud/ai/util/DocumentConverterUtil.java b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/util/DocumentConverterUtil.java similarity index 74% rename from spring-ai-alibaba-data-agent-management/src/main/java/com/alibaba/cloud/ai/util/DocumentConverterUtil.java rename to spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/util/DocumentConverterUtil.java index d16d8db..fc1e781 100644 --- a/spring-ai-alibaba-data-agent-management/src/main/java/com/alibaba/cloud/ai/util/DocumentConverterUtil.java +++ b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/util/DocumentConverterUtil.java @@ -15,12 +15,14 @@ */ package com.alibaba.cloud.ai.util; +import java.util.*; +import java.util.stream.Collectors; + +import org.springframework.ai.document.Document; + import com.alibaba.cloud.ai.connector.bo.ColumnInfoBO; import com.alibaba.cloud.ai.connector.bo.TableInfoBO; import com.alibaba.cloud.ai.request.EvidenceRequest; -import org.springframework.ai.document.Document; - -import java.util.*; /** * Utility class for converting business objects to Document objects. Provides common @@ -28,13 +30,28 @@ */ public class DocumentConverterUtil { + public static List convertColumnsToDocuments(String agentId, List tables) { + List documents = new ArrayList<>(); + for (TableInfoBO table : tables) { + // 使用已经处理过的列数据,避免重复查询 + List columns = table.getColumns(); + if (columns != null) { + for (ColumnInfoBO column : columns) { + documents.add(DocumentConverterUtil.convertColumnToDocumentForAgent(agentId, table, column)); + } + } + } + return documents; + } + /** * Converts a column info object to a Document for vector storage. * @param tableInfoBO the table information containing schema details * @param columnInfoBO the column information to convert * @return Document object with column metadata */ - public static Document convertColumnToDocument(TableInfoBO tableInfoBO, ColumnInfoBO columnInfoBO) { + public static Document convertColumnToDocumentForAgent(String agentId, TableInfoBO tableInfoBO, + ColumnInfoBO columnInfoBO) { String text = Optional.ofNullable(columnInfoBO.getDescription()).orElse(columnInfoBO.getName()); Map metadata = new HashMap<>(); metadata.put("name", columnInfoBO.getName()); @@ -44,6 +61,7 @@ public static Document convertColumnToDocument(TableInfoBO tableInfoBO, ColumnIn metadata.put("primary", columnInfoBO.isPrimary()); metadata.put("notnull", columnInfoBO.isNotnull()); metadata.put("vectorType", "column"); + metadata.put("agentId", agentId); if (columnInfoBO.getSamples() != null) { metadata.put("samples", columnInfoBO.getSamples()); @@ -57,7 +75,7 @@ public static Document convertColumnToDocument(TableInfoBO tableInfoBO, ColumnIn * @param tableInfoBO the table information to convert * @return Document object with table metadata */ - public static Document convertTableToDocument(TableInfoBO tableInfoBO) { + public static Document convertTableToDocumentForAgent(String agentId, TableInfoBO tableInfoBO) { String text = Optional.ofNullable(tableInfoBO.getDescription()).orElse(tableInfoBO.getName()); Map metadata = new HashMap<>(); metadata.put("schema", Optional.ofNullable(tableInfoBO.getSchema()).orElse("")); @@ -66,21 +84,29 @@ public static Document convertTableToDocument(TableInfoBO tableInfoBO) { metadata.put("foreignKey", Optional.ofNullable(tableInfoBO.getForeignKey()).orElse("")); metadata.put("primaryKey", Optional.ofNullable(tableInfoBO.getPrimaryKeys()).orElse(new ArrayList<>())); metadata.put("vectorType", "table"); - + metadata.put("agentId", agentId); return new Document(tableInfoBO.getName(), text, metadata); } + public static List convertTablesToDocuments(String agentId, List tables) { + return tables.stream() + .map(table -> DocumentConverterUtil.convertTableToDocumentForAgent(agentId, table)) + .collect(Collectors.toList()); + } + /** * Converts evidence requests to Documents for vector storage. * @param evidenceRequests list of evidence requests * @return list of Document objects */ - public static List convertEvidenceToDocuments(List evidenceRequests) { + public static List convertEvidenceToDocumentsForAgent(String agentId, + List evidenceRequests) { return evidenceRequests.stream().map(evidenceRequest -> { Map metadata = new HashMap<>(); metadata.put("evidenceType", evidenceRequest.getType()); metadata.put("vectorType", "evidence"); + metadata.put("agentId", agentId); return new Document(UUID.randomUUID().toString(), evidenceRequest.getContent(), metadata); }).toList(); } diff --git a/spring-ai-alibaba-data-agent-management/src/main/java/com/alibaba/cloud/ai/util/SchemaProcessorUtil.java b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/util/SchemaProcessorUtil.java similarity index 75% rename from spring-ai-alibaba-data-agent-management/src/main/java/com/alibaba/cloud/ai/util/SchemaProcessorUtil.java rename to spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/util/SchemaProcessorUtil.java index dd8a1f8..358e038 100644 --- a/spring-ai-alibaba-data-agent-management/src/main/java/com/alibaba/cloud/ai/util/SchemaProcessorUtil.java +++ b/spring-ai-alibaba-data-agent-chat/src/main/java/com/alibaba/cloud/ai/util/SchemaProcessorUtil.java @@ -15,6 +15,9 @@ */ package com.alibaba.cloud.ai.util; +import java.util.*; +import java.util.stream.Collectors; + import com.alibaba.cloud.ai.connector.accessor.Accessor; import com.alibaba.cloud.ai.connector.bo.ColumnInfoBO; import com.alibaba.cloud.ai.connector.bo.DbQueryParameter; @@ -22,10 +25,6 @@ import com.alibaba.cloud.ai.connector.config.DbConfig; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; -import org.apache.commons.collections.CollectionUtils; - -import java.util.*; -import java.util.stream.Collectors; /** * Utility class for processing database schema information. Provides common schema @@ -71,12 +70,10 @@ public static void enrichTableMetadata(TableInfoBO tableInfoBO, DbQueryParameter // 保存处理过的列数据到TableInfoBO,供后续使用 tableInfoBO.setColumns(columnInfoBOS); - List targetPrimaryList = columnInfoBOS.stream() - .filter(ColumnInfoBO::isPrimary) - .collect(Collectors.toList()); + List primaryKeyColumns = columnInfoBOS.stream().filter(ColumnInfoBO::isPrimary).toList(); - if (CollectionUtils.isNotEmpty(targetPrimaryList)) { - List columnNames = targetPrimaryList.stream() + if (!primaryKeyColumns.isEmpty()) { + List columnNames = primaryKeyColumns.stream() .map(ColumnInfoBO::getName) .collect(Collectors.toList()); tableInfoBO.setPrimaryKeys(columnNames); @@ -89,6 +86,31 @@ public static void enrichTableMetadata(TableInfoBO tableInfoBO, DbQueryParameter .setForeignKey(String.join("、", foreignKeyMap.getOrDefault(tableInfoBO.getName(), new ArrayList<>()))); } + public static DbConfig createDbConfigFromDatasource(com.alibaba.cloud.ai.entity.Datasource datasource) { + DbConfig dbConfig = new DbConfig(); + + // Set basic connection information + dbConfig.setUrl(datasource.getConnectionUrl()); + dbConfig.setUsername(datasource.getUsername()); + dbConfig.setPassword(datasource.getPassword()); + + // TODO Set database type need to be optimized + if ("mysql".equalsIgnoreCase(datasource.getType())) { + dbConfig.setConnectionType("jdbc"); + dbConfig.setDialectType("mysql"); + } + else if ("h2".equalsIgnoreCase(datasource.getType())) { + dbConfig.setConnectionType("jdbc"); + dbConfig.setDialectType("h2"); + } + // Support for other database types can be extended here + + // Set Schema as the database name of the data source + dbConfig.setSchema(datasource.getDatabaseName()); + + return dbConfig; + } + /** * Private constructor to prevent instantiation. */ diff --git a/spring-ai-alibaba-data-agent-common/src/main/java/com/alibaba/cloud/ai/request/AgentSearchRequest.java b/spring-ai-alibaba-data-agent-common/src/main/java/com/alibaba/cloud/ai/request/AgentSearchRequest.java new file mode 100644 index 0000000..1e9066e --- /dev/null +++ b/spring-ai-alibaba-data-agent-common/src/main/java/com/alibaba/cloud/ai/request/AgentSearchRequest.java @@ -0,0 +1,128 @@ +/* + * Copyright 2024-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.cloud.ai.request; + +import java.io.Serial; +import java.io.Serializable; +import java.util.Map; +import java.util.Objects; + +public class AgentSearchRequest implements Serializable { + + @Serial + private static final long serialVersionUID = 1L; + + private final String agentId; + + private String query; + + private Integer topK; + + private Map metadataFilter; + + /** + * 私有构造函数,禁止直接实例化 此构造函数仅内部使用,外部代码必须通过getInstance()方法创建实例 + */ + private AgentSearchRequest(String agentId) { + Objects.requireNonNull(agentId, "Agent ID cannot be null"); + this.agentId = agentId; + // 初始化metadataFilter,确保始终包含agentId + this.metadataFilter = Map.of("agentId", agentId); + } + + /** + * 创建AgentSearchRequest实例的工厂方法 + * @param agentId 代理ID,不能为空 + * @return AgentSearchRequest实例 + * @throws IllegalArgumentException 如果agentId为空 + */ + public static AgentSearchRequest getInstance(String agentId) { + return new AgentSearchRequest(agentId); + } + + /** + * 创建AgentSearchRequest实例的工厂方法 + * @param agentId 代理ID,不能为空 + * @param query 查询内容 + * @return AgentSearchRequest实例 + * @throws IllegalArgumentException 如果agentId为空 + */ + public static AgentSearchRequest getInstance(String agentId, String query) { + AgentSearchRequest request = new AgentSearchRequest(agentId); + request.setQuery(query); + return request; + } + + /** + * 创建AgentSearchRequest实例的工厂方法 + * @param agentId 代理ID,不能为空 + * @param query 查询内容 + * @param topK 返回结果数量 + * @return AgentSearchRequest实例 + * @throws IllegalArgumentException 如果agentId为空 + */ + public static AgentSearchRequest getInstance(String agentId, String query, Integer topK) { + AgentSearchRequest request = new AgentSearchRequest(agentId); + request.setQuery(query); + request.setTopK(topK); + return request; + } + + public String getAgentId() { + return agentId; + } + + public Map getMetadataFilter() { + return metadataFilter; + } + + public void setMetadataFilter(Map metadataFilter) { + if (metadataFilter == null) { + // 如果传入null,则创建只包含agentId的map + this.metadataFilter = Map.of("agentId", agentId); + } + else { + // 创建新的map,包含传入的所有参数和agentId + Map newFilter = new java.util.HashMap<>(metadataFilter); + newFilter.put("agentId", agentId); + this.metadataFilter = Map.copyOf(newFilter); // 创建不可变副本 + } + } + + public String getQuery() { + return query; + } + + public void setQuery(String query) { + this.query = query; + } + + public Integer getTopK() { + return topK; + } + + public void setTopK(Integer topK) { + this.topK = topK; + } + + @Override + public String toString() { + return "AgentSearchRequest{" + "agentId='" + agentId + '\'' + ", query='" + query + '\'' + ", topK=" + topK + + ", metadataFilter=" + metadataFilter + '}'; + } + +} diff --git a/spring-ai-alibaba-data-agent-management/src/main/java/com/alibaba/cloud/ai/controller/Nl2sqlForGraphController.java b/spring-ai-alibaba-data-agent-management/src/main/java/com/alibaba/cloud/ai/controller/Nl2sqlForGraphController.java index f4cd255..56b72da 100644 --- a/spring-ai-alibaba-data-agent-management/src/main/java/com/alibaba/cloud/ai/controller/Nl2sqlForGraphController.java +++ b/spring-ai-alibaba-data-agent-management/src/main/java/com/alibaba/cloud/ai/controller/Nl2sqlForGraphController.java @@ -16,17 +16,14 @@ package com.alibaba.cloud.ai.controller; -import com.alibaba.cloud.ai.connector.config.DbConfig; import com.alibaba.cloud.ai.constant.Constant; import com.alibaba.cloud.ai.graph.*; import com.alibaba.cloud.ai.graph.exception.GraphStateException; import com.alibaba.cloud.ai.graph.streaming.StreamingOutput; import com.alibaba.cloud.ai.graph.NodeOutput; import com.alibaba.cloud.ai.graph.state.StateSnapshot; -import com.alibaba.cloud.ai.request.SchemaInitRequest; -import com.alibaba.cloud.ai.service.vectorstore.VectorStoreService; +import com.alibaba.cloud.ai.service.vectorstore.AgentVectorStoreService; import com.alibaba.cloud.ai.service.DatasourceService; -import com.alibaba.cloud.ai.entity.Datasource; import com.alibaba.cloud.ai.service.AgentService; import com.alibaba.cloud.ai.util.JsonUtil; import com.fasterxml.jackson.databind.ObjectMapper; @@ -43,12 +40,10 @@ import reactor.core.publisher.Flux; import reactor.core.publisher.Sinks; -import java.util.Arrays; import java.util.Map; import java.util.Optional; import java.util.concurrent.CompletableFuture; -import static com.alibaba.cloud.ai.constant.Constant.AGENT_ID; import static com.alibaba.cloud.ai.constant.Constant.HUMAN_FEEDBACK_NODE; import static com.alibaba.cloud.ai.constant.Constant.INPUT_KEY; import static com.alibaba.cloud.ai.constant.Constant.RESULT; @@ -66,7 +61,7 @@ public class Nl2sqlForGraphController { private final CompiledGraph compiledGraph; - private final VectorStoreService vectorStoreService; + private final AgentVectorStoreService vectorStoreService; private final DatasourceService datasourceService; @@ -75,7 +70,7 @@ public class Nl2sqlForGraphController { private final ObjectMapper objectMapper = JsonUtil.getObjectMapper(); public Nl2sqlForGraphController(@Qualifier("nl2sqlGraph") StateGraph stateGraph, - VectorStoreService vectorStoreService, DatasourceService datasourceService, AgentService agentService) + AgentVectorStoreService vectorStoreService, DatasourceService datasourceService, AgentService agentService) throws GraphStateException { this.compiledGraph = stateGraph.compile(CompileConfig.builder().interruptBefore(HUMAN_FEEDBACK_NODE).build()); this.compiledGraph.setMaxIterations(100); @@ -84,106 +79,6 @@ public Nl2sqlForGraphController(@Qualifier("nl2sqlGraph") StateGraph stateGraph, this.agentService = agentService; } - @GetMapping("/search") - public String search( - @RequestParam(value = "query", required = false, - defaultValue = "查询每个分类下已经成交且销量最高的商品及其销售总量,每个分类只返回销量最高的商品。") String query, - @RequestParam(value = "dataSetId", required = false, defaultValue = "1") String dataSetId, - @RequestParam(value = "agentId", required = false, defaultValue = "1") String agentId) throws Exception { - // Get the data source configuration for an agent for vector initialization - DbConfig dbConfig = getDbConfigForAgent(Integer.valueOf(agentId)); - - SchemaInitRequest schemaInitRequest = new SchemaInitRequest(); - schemaInitRequest.setDbConfig(dbConfig); - schemaInitRequest - .setTables(Arrays.asList("categories", "order_items", "orders", "products", "users", "product_categories")); - vectorStoreService.schema(schemaInitRequest); - - boolean humanReviewEnabled = false; - try { - var agent = agentService.findById(Long.valueOf(agentId)); - humanReviewEnabled = agent != null && agent.getHumanReviewEnabled() != null - && agent.getHumanReviewEnabled() == 1; - } - catch (Exception ignore) { - } - - Optional invoke = compiledGraph - .call(Map.of(INPUT_KEY, query, AGENT_ID, agentId, HUMAN_REVIEW_ENABLED, humanReviewEnabled)); - OverAllState overAllState = invoke.get(); - // 注意:在新的人类反馈实现中,计划内容通过流式处理发送给前端 - // 这里不再需要单独获取计划内容 - return overAllState.value(RESULT).map(Object::toString).orElse(""); - } - - @GetMapping("/init") - public void init(@RequestParam(value = "agentId", required = false, defaultValue = "1") Integer agentId) - throws Exception { - // Get the data source configuration for an agent for vector initialization - DbConfig dbConfig = getDbConfigForAgent(agentId); - - SchemaInitRequest schemaInitRequest = new SchemaInitRequest(); - schemaInitRequest.setDbConfig(dbConfig); - schemaInitRequest - .setTables(Arrays.asList("categories", "order_items", "orders", "products", "users", "product_categories")); - vectorStoreService.schema(schemaInitRequest); - } - - /** - * Get database configuration by agent ID - */ - private DbConfig getDbConfigForAgent(Integer agentId) { - try { - // Get the enabled data source for an agent - var agentDatasources = datasourceService.getAgentDatasources(agentId); - var activeDatasource = agentDatasources.stream() - .filter(ad -> ad.getIsActive() == 1) - .findFirst() - .orElseThrow(() -> new RuntimeException("智能体 " + agentId + " 未配置启用的数据源")); - - // Convert to DbConfig - return createDbConfigFromDatasource(activeDatasource.getDatasource()); - } - catch (Exception e) { - logger.error("Failed to get agent datasource config for agent: {}", agentId, e); - throw new RuntimeException("获取智能体数据源配置失败: " + e.getMessage(), e); - } - } - - /** - * Create database configuration from data source entity - */ - private DbConfig createDbConfigFromDatasource(Datasource datasource) { - DbConfig dbConfig = new DbConfig(); - - // Set basic connection information - dbConfig.setUrl(datasource.getConnectionUrl()); - dbConfig.setUsername(datasource.getUsername()); - dbConfig.setPassword(datasource.getPassword()); - - // Set database type - if ("mysql".equalsIgnoreCase(datasource.getType())) { - dbConfig.setConnectionType("jdbc"); - dbConfig.setDialectType("mysql"); - } - else if ("postgresql".equalsIgnoreCase(datasource.getType())) { - dbConfig.setConnectionType("jdbc"); - dbConfig.setDialectType("postgresql"); - } - else if ("h2".equalsIgnoreCase(datasource.getType())) { - dbConfig.setConnectionType("jdbc"); - dbConfig.setDialectType("h2"); - } - else { - throw new RuntimeException("不支持的数据库类型: " + datasource.getType()); - } - - // Set Schema to the database name of the data source - dbConfig.setSchema(datasource.getDatabaseName()); - - return dbConfig; - } - @GetMapping(value = "/stream/search", produces = MediaType.TEXT_EVENT_STREAM_VALUE) public Flux> streamSearch(@RequestParam(value = "query") String query, @RequestParam(value = "agentId") String agentId, diff --git a/spring-ai-alibaba-data-agent-management/src/main/java/com/alibaba/cloud/ai/service/AbstractVectorStoreManagementService.java b/spring-ai-alibaba-data-agent-management/src/main/java/com/alibaba/cloud/ai/service/AbstractVectorStoreManagementService.java deleted file mode 100644 index 92ea915..0000000 --- a/spring-ai-alibaba-data-agent-management/src/main/java/com/alibaba/cloud/ai/service/AbstractVectorStoreManagementService.java +++ /dev/null @@ -1,177 +0,0 @@ -/* - * Copyright 2024-2025 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.alibaba.cloud.ai.service; - -import com.alibaba.cloud.ai.connector.accessor.Accessor; -import com.alibaba.cloud.ai.connector.accessor.AccessorFactory; -import com.alibaba.cloud.ai.connector.bo.ColumnInfoBO; -import com.alibaba.cloud.ai.connector.bo.DbQueryParameter; -import com.alibaba.cloud.ai.connector.bo.ForeignKeyInfoBO; -import com.alibaba.cloud.ai.connector.bo.TableInfoBO; -import com.alibaba.cloud.ai.connector.config.DbConfig; -import com.alibaba.cloud.ai.request.DeleteRequest; -import com.alibaba.cloud.ai.request.EvidenceRequest; -import com.alibaba.cloud.ai.request.SchemaInitRequest; -import com.alibaba.cloud.ai.util.DocumentConverterUtil; -import com.alibaba.cloud.ai.util.JsonUtil; -import com.alibaba.cloud.ai.util.SchemaProcessorUtil; -import jakarta.annotation.PostConstruct; -import lombok.extern.slf4j.Slf4j; -import org.springframework.ai.document.Document; -import org.springframework.ai.vectorstore.VectorStore; -import org.springframework.ai.vectorstore.filter.Filter; -import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder; -import org.springframework.beans.factory.annotation.Autowired; - -import java.util.*; -import java.util.stream.Collectors; - -@Slf4j -public abstract class AbstractVectorStoreManagementService implements VectorStoreManagementService { - - @Autowired - protected AccessorFactory accessorFactory; - - protected Accessor dbAccessor; - - protected DbConfig dbConfig; - - @PostConstruct - public void init() { - this.dbAccessor = accessorFactory.getAccessorByDbConfig(dbConfig); - } - - @Override - public Boolean addEvidence(List evidenceRequests) { - List evidences = DocumentConverterUtil.convertEvidenceToDocuments(evidenceRequests); - getVectorStore().add(evidences); - return true; - } - - @Override - public Boolean deleteDocuments(DeleteRequest deleteRequest) throws Exception { - try { - if (deleteRequest.getId() != null && !deleteRequest.getId().isEmpty()) { - getVectorStore().delete(List.of(deleteRequest.getId())); - } - else if (deleteRequest.getVectorType() != null && !deleteRequest.getVectorType().isEmpty()) { - FilterExpressionBuilder builder = new FilterExpressionBuilder(); - Filter.Expression filterExpression = builder.eq("vectorType", deleteRequest.getVectorType()).build(); - - getVectorStore().delete(filterExpression); - } - else { - throw new IllegalArgumentException("Either id or vectorType must be specified."); - } - return true; - } - catch (Exception e) { - throw new Exception("Failed to delete collection data by filterExpression: " + e.getMessage(), e); - } - } - - protected abstract VectorStore getVectorStore(); - - // 模板方法 - 通用schema处理流程 - @Override - public Boolean schema(SchemaInitRequest schemaInitRequest) throws Exception { - try { - - DbConfig config = schemaInitRequest.getDbConfig(); - DbQueryParameter dqp = DbQueryParameter.from(config) - .setSchema(config.getSchema()) - .setTables(schemaInitRequest.getTables()); - - // 清理旧数据 - clearSchemaData(); - - // 处理外键 - List foreignKeys = dbAccessor.showForeignKeys(config, dqp); - Map> foreignKeyMap = buildForeignKeyMap(foreignKeys); - - // 处理表和列 - List tables = dbAccessor.fetchTables(config, dqp); - for (TableInfoBO table : tables) { - SchemaProcessorUtil.enrichTableMetadata(table, dqp, config, dbAccessor, JsonUtil.getObjectMapper(), - foreignKeyMap); - } - - // 转换为文档 - List columnDocs = convertColumnsToDocuments(tables); - List tableDocs = convertTablesToDocuments(tables); - - // 存储文档 - return storeSchemaDocuments(columnDocs, tableDocs); - } - catch (Exception e) { - log.error("Failed to process schema ", e); - return false; - } - } - - // 通用辅助方法 - protected Map> buildForeignKeyMap(List foreignKeys) { - Map> map = new HashMap<>(); - for (ForeignKeyInfoBO fk : foreignKeys) { - String key = fk.getTable() + "." + fk.getColumn() + "=" + fk.getReferencedTable() + "." - + fk.getReferencedColumn(); - - map.computeIfAbsent(fk.getTable(), k -> new ArrayList<>()).add(key); - map.computeIfAbsent(fk.getReferencedTable(), k -> new ArrayList<>()).add(key); - } - return map; - } - - private List convertColumnsToDocuments(List tables) throws Exception { - List documents = new ArrayList<>(); - for (TableInfoBO table : tables) { - // 使用已经处理过的列数据,避免重复查询 - List columns = table.getColumns(); - if (columns != null) { - for (ColumnInfoBO column : columns) { - documents.add(DocumentConverterUtil.convertColumnToDocument(table, column)); - } - } - } - return documents; - } - - private List convertTablesToDocuments(List tables) { - return tables.stream().map(DocumentConverterUtil::convertTableToDocument).collect(Collectors.toList()); - } - - protected Boolean storeSchemaDocuments(List columns, List tables) throws Exception { - try { - getVectorStore().add(columns); - getVectorStore().add(tables); - return true; - } - catch (Exception e) { - log.error("vectorstore schemaDocuments error", e); - return false; - } - - } - - protected void clearSchemaData() throws Exception { - DeleteRequest deleteRequest = new DeleteRequest(); - deleteRequest.setVectorType("column"); - deleteDocuments(deleteRequest); - deleteRequest.setVectorType("table"); - deleteDocuments(deleteRequest); - } - -} diff --git a/spring-ai-alibaba-data-agent-management/src/main/java/com/alibaba/cloud/ai/service/impl/AgentStartupInitializationService.java b/spring-ai-alibaba-data-agent-management/src/main/java/com/alibaba/cloud/ai/service/impl/AgentStartupInitializationService.java index 7c96bf1..9d27338 100644 --- a/spring-ai-alibaba-data-agent-management/src/main/java/com/alibaba/cloud/ai/service/impl/AgentStartupInitializationService.java +++ b/spring-ai-alibaba-data-agent-management/src/main/java/com/alibaba/cloud/ai/service/impl/AgentStartupInitializationService.java @@ -28,7 +28,6 @@ import org.springframework.stereotype.Service; import java.util.List; -import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -127,13 +126,10 @@ private boolean initializeAgentDataSource(Agent agent) { try { Long agentId = agent.getId(); - Map statistics = agentVectorService.getVectorStatistics(agentId); - boolean hasData = (Boolean) statistics.getOrDefault("hasData", false); - int documentCount = (Integer) statistics.getOrDefault("documentCount", 0); + boolean hasData = agentVectorService.isAlreadyInitialized(agentId); - if (hasData && documentCount > 0) { - log.info("Agent {} already has vector data (documents: {}), skipping initialization", agentId, - documentCount); + if (hasData) { + log.info("Agent {} already has vector data , skipping initialization", agentId); return true; } diff --git a/spring-ai-alibaba-data-agent-management/src/main/java/com/alibaba/cloud/ai/service/impl/AgentVectorService.java b/spring-ai-alibaba-data-agent-management/src/main/java/com/alibaba/cloud/ai/service/impl/AgentVectorService.java index 410f100..fa238b2 100644 --- a/spring-ai-alibaba-data-agent-management/src/main/java/com/alibaba/cloud/ai/service/impl/AgentVectorService.java +++ b/spring-ai-alibaba-data-agent-management/src/main/java/com/alibaba/cloud/ai/service/impl/AgentVectorService.java @@ -16,13 +16,15 @@ package com.alibaba.cloud.ai.service.impl; import com.alibaba.cloud.ai.connector.accessor.AccessorFactory; +import com.alibaba.cloud.ai.constant.Constant; import com.alibaba.cloud.ai.entity.AgentKnowledge; import com.alibaba.cloud.ai.request.SchemaInitRequest; import com.alibaba.cloud.ai.service.DatasourceService; -import com.alibaba.cloud.ai.service.vectorstore.VectorStoreService; +import com.alibaba.cloud.ai.service.vectorstore.AgentVectorStoreService; import com.alibaba.cloud.ai.connector.bo.DbQueryParameter; import com.alibaba.cloud.ai.connector.bo.TableInfoBO; import com.alibaba.cloud.ai.connector.config.DbConfig; +import com.alibaba.cloud.ai.util.SchemaProcessorUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.ai.document.Document; @@ -31,10 +33,8 @@ import java.util.ArrayList; import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Set; /** * Agent Vector Storage Service Specializes in handling agent-related vector storage @@ -45,28 +45,8 @@ public class AgentVectorService { private static final Logger log = LoggerFactory.getLogger(AgentVectorService.class); - // todo: 提取实现类的方法到接口 @Autowired - private VectorStoreService vectorStoreService; - - /** - * Initialize database Schema for agent - * @param agentId agent ID - * @param schemaInitRequest Schema initialization request - * @return success status - */ - public Boolean initializeSchemaForAgent(Long agentId, SchemaInitRequest schemaInitRequest) { - try { - String agentIdStr = String.valueOf(agentId); - log.info("Initializing schema for agent: {}", agentIdStr); - - return vectorStoreService.schemaForAgent(agentIdStr, schemaInitRequest); - } - catch (Exception e) { - log.error("Failed to initialize schema for agent: {}", agentId, e); - throw new RuntimeException("Failed to initialize schema for agent " + agentId + ": " + e.getMessage(), e); - } - } + private AgentVectorStoreService vectorStoreService; /** * Add knowledge document to vector store for agent @@ -82,7 +62,7 @@ public void addKnowledgeToVector(Long agentId, AgentKnowledge knowledge) { Document document = createDocumentFromKnowledge(agentIdStr, knowledge); // Add to vector store - vectorStoreService.getAgentVectorStoreManager().addDocuments(agentIdStr, List.of(document)); + vectorStoreService.addDocuments(agentIdStr, List.of(document)); log.info("Successfully added knowledge to vector store for agent: {}", agentIdStr); } @@ -93,89 +73,6 @@ public void addKnowledgeToVector(Long agentId, AgentKnowledge knowledge) { } } - /** - * Batch add knowledge documents to vector store for agent - * @param agentId agent ID - * @param knowledgeList knowledge list - */ - public void addKnowledgeListToVector(Long agentId, List knowledgeList) { - if (knowledgeList == null || knowledgeList.isEmpty()) { - log.warn("No knowledge to add for agent: {}", agentId); - return; - } - - try { - String agentIdStr = String.valueOf(agentId); - log.info("Adding {} knowledge items to vector store for agent: {}", knowledgeList.size(), agentIdStr); - - // Create document列表 - List documents = knowledgeList.stream() - .map(knowledge -> createDocumentFromKnowledge(agentIdStr, knowledge)) - .toList(); - - // Batch add to vector store - vectorStoreService.getAgentVectorStoreManager().addDocuments(agentIdStr, documents); - - log.info("Successfully added {} knowledge items to vector store for agent: {}", documents.size(), - agentIdStr); - } - catch (Exception e) { - log.error("Failed to add knowledge list to vector store for agent: {}", agentId, e); - throw new RuntimeException("Failed to add knowledge list to vector store: " + e.getMessage(), e); - } - } - - /** - * Search related knowledge from vector store - * @param agentId agent ID - * @param query query text - * @param topK number of results to return - * @return list of related documents - */ - public List searchKnowledge(Long agentId, String query, int topK) { - try { - String agentIdStr = String.valueOf(agentId); - log.debug("Searching knowledge for agent: {}, query: {}, topK: {}", agentIdStr, query, topK); - - List results = vectorStoreService.getAgentVectorStoreManager() - .similaritySearch(agentIdStr, query, topK); - - log.info("Found {} knowledge documents for agent: {}", results.size(), agentIdStr); - return results; - } - catch (Exception e) { - log.error("Failed to search knowledge for agent: {}", agentId, e); - throw new RuntimeException("Failed to search knowledge: " + e.getMessage(), e); - } - } - - /** - * Search specific type of knowledge from vector store - * @param agentId agent ID - * @param query query text - * @param topK number of results to return - * @param knowledgeType knowledge type - * @return list of related documents - */ - public List searchKnowledgeByType(Long agentId, String query, int topK, String knowledgeType) { - try { - String agentIdStr = String.valueOf(agentId); - log.debug("Searching knowledge by type for agent: {}, query: {}, topK: {}, type: {}", agentIdStr, query, - topK, knowledgeType); - - List results = vectorStoreService.getAgentVectorStoreManager() - .similaritySearchWithFilter(agentIdStr, query, topK, "knowledge:" + knowledgeType); - - log.info("Found {} knowledge documents of type '{}' for agent: {}", results.size(), knowledgeType, - agentIdStr); - return results; - } - catch (Exception e) { - log.error("Failed to search knowledge by type for agent: {}", agentId, e); - throw new RuntimeException("Failed to search knowledge by type: " + e.getMessage(), e); - } - } - /** * Delete specific knowledge document of agent * @param agentId agent ID @@ -188,7 +85,10 @@ public void deleteKnowledgeFromVector(Long agentId, Integer knowledgeId) { log.info("Deleting knowledge from vector store for agent: {}, knowledge ID: {}", agentIdStr, knowledgeId); - vectorStoreService.getAgentVectorStoreManager().deleteDocuments(agentIdStr, List.of(documentId)); + Map metadata = new HashMap<>(Map.ofEntries(Map.entry(Constant.AGENT_ID, agentIdStr), + Map.entry(Constant.KNOWLEDGE_ID, knowledgeId))); + + vectorStoreService.deleteDocumentsByMetedata(agentIdStr, metadata); log.info("Successfully deleted knowledge from vector store for agent: {}", agentIdStr); } @@ -208,7 +108,7 @@ public void deleteAllVectorDataForAgent(Long agentId) { String agentIdStr = String.valueOf(agentId); log.info("Deleting all vector data for agent: {}", agentIdStr); - vectorStoreService.getAgentVectorStoreManager().deleteAgentData(agentIdStr); + vectorStoreService.deleteDocumentsByMetedata(String.valueOf(agentId), new HashMap<>()); log.info("Successfully deleted all vector data for agent: {}", agentIdStr); } @@ -218,6 +118,17 @@ public void deleteAllVectorDataForAgent(Long agentId) { } } + public boolean isAlreadyInitialized(Long agentId) { + try { + String agentIdStr = String.valueOf(agentId); + return vectorStoreService.hasDocuments(agentIdStr); + } + catch (Exception e) { + log.error("Failed to check initialization status for agent: {}, assuming not initialized", agentId, e); + return false; + } + } + /** * Get agent vector storage statistics * @param agentId agent ID @@ -228,15 +139,13 @@ public Map getVectorStatistics(Long agentId) { String agentIdStr = String.valueOf(agentId); Map stats = new HashMap<>(); - boolean hasData = vectorStoreService.getAgentVectorStoreManager().hasAgentData(agentIdStr); - int documentCount = vectorStoreService.getAgentVectorStoreManager().getDocumentCount(agentIdStr); + int docNum = vectorStoreService.estimateDocuments(agentIdStr); + stats.put("docNum", docNum); stats.put("agentId", agentId); - stats.put("hasData", hasData); - stats.put("documentCount", documentCount); + stats.put("hasData", docNum > 0); - log.debug("Vector statistics for agent {}: hasData={}, documentCount={}", agentIdStr, hasData, - documentCount); + log.info("Successfully retrieved vector statistics for agent: {}, detail: {}", agentIdStr, stats); return stats; } @@ -356,7 +265,7 @@ public List getDatasourceTables(Integer datasourceId) { // } // Create database configuration - DbConfig dbConfig = createDbConfigFromDatasource(datasource); + DbConfig dbConfig = SchemaProcessorUtil.createDbConfigFromDatasource(datasource); // Create query parameters DbQueryParameter queryParam = DbQueryParameter.from(dbConfig); @@ -382,41 +291,6 @@ public List getDatasourceTables(Integer datasourceId) { } } - /** - * Create database configuration from data source entity - */ - private DbConfig createDbConfigFromDatasource(com.alibaba.cloud.ai.entity.Datasource datasource) { - DbConfig dbConfig = new DbConfig(); - - // Set basic connection information - dbConfig.setUrl(datasource.getConnectionUrl()); - dbConfig.setUsername(datasource.getUsername()); - dbConfig.setPassword(datasource.getPassword()); - - // Set database type - if ("mysql".equalsIgnoreCase(datasource.getType())) { - dbConfig.setConnectionType("jdbc"); - dbConfig.setDialectType("mysql"); - } - else if ("h2".equalsIgnoreCase(datasource.getType())) { - dbConfig.setConnectionType("jdbc"); - dbConfig.setDialectType("h2"); - } - // Support for other database types can be extended here - // else if ("postgresql".equalsIgnoreCase(datasource.getType())) { - // dbConfig.setConnectionType("jdbc"); - // dbConfig.setDialectType("postgresql"); - // } - - // Set Schema as the database name of the data source - dbConfig.setSchema(datasource.getDatabaseName()); - - log.debug("Created DbConfig for datasource {}: url={}, schema={}, type={}", datasource.getId(), - dbConfig.getUrl(), dbConfig.getSchema(), dbConfig.getDialectType()); - - return dbConfig; - } - /** * Initialize database Schema for agent using data source ID * @param agentId agent ID @@ -437,7 +311,7 @@ public Boolean initializeSchemaForAgentWithDatasource(Long agentId, Integer data } // Create database configuration - DbConfig dbConfig = createDbConfigFromDatasource(datasource); + DbConfig dbConfig = SchemaProcessorUtil.createDbConfigFromDatasource(datasource); // Create SchemaInitRequest SchemaInitRequest schemaInitRequest = new SchemaInitRequest(); @@ -447,7 +321,7 @@ public Boolean initializeSchemaForAgentWithDatasource(Long agentId, Integer data log.info("Created SchemaInitRequest for agent: {}, dbConfig: {}, tables: {}", agentIdStr, dbConfig, tables); // Call the original initialization method - return vectorStoreService.schemaForAgent(agentIdStr, schemaInitRequest); + return vectorStoreService.schema(agentIdStr, schemaInitRequest); } catch (Exception e) { @@ -456,160 +330,4 @@ public Boolean initializeSchemaForAgentWithDatasource(Long agentId, Integer data } } - /** - * Agent chat function - * @param agentId agent ID - * @param query user query - * @return agent response - */ - public String chatWithAgent(Long agentId, String query) { - try { - String agentIdStr = String.valueOf(agentId); - log.info("Processing chat request for agent: {}, query: {}", agentIdStr, query); - - // Check if agent has been initialized - boolean hasData = vectorStoreService.getAgentVectorStoreManager().hasAgentData(agentIdStr); - if (!hasData) { - return "智能体尚未初始化数据源,请先在「初始化信息源」中配置数据源和表结构。"; - } - - // Get agent's data source information - List> datasources = getAgentDatasources(agentId); - if (datasources.isEmpty()) { - return "智能体没有配置可用的数据源,请先配置数据源。"; - } - - // Use the first active data source - Map datasource = datasources.get(0); - - // Create database configuration - com.alibaba.cloud.ai.entity.Datasource dsEntity = datasourceService - .getDatasourceById((Integer) datasource.get("id")); - if (dsEntity == null) { - return "数据源配置不存在,请检查数据源配置。"; - } - - DbConfig dbConfig = createDbConfigFromDatasource(dsEntity); - - // Use SimpleNl2SqlService to process query - // Note: SimpleNl2SqlService needs to be injected here, but for simplicity, we - // return a basic response first - String response = processAgentQuery(agentIdStr, query, dbConfig); - - log.info("Generated response for agent: {}", agentIdStr); - return response; - - } - catch (Exception e) { - log.error("Failed to process chat request for agent: {}", agentId, e); - return "处理查询时发生错误:" + e.getMessage() + "。请检查数据源配置和网络连接。"; - } - } - - /** - * Process agent query (simplified version) - */ - private String processAgentQuery(String agentId, String query, DbConfig dbConfig) { - try { - // This is a simplified implementation - // In actual applications, the complete NL2SQL processing flow should be - // integrated - - // 1. 检查是否是简单的问候语 - if (isGreeting(query)) { - return "您好!我是您的数据分析助手。您可以用自然语言询问数据相关的问题,我会帮您查询和分析数据。\n\n" + "例如:\n" + "• 查询用户总数\n" + "• 显示最近一周的订单统计\n" - + "• 分析销售趋势\n\n" + "请告诉我您想了解什么数据信息?"; - } - - // 2. 获取相关的表和列信息 - List relevantDocs = vectorStoreService - .getAgentVectorStoreManager() - .similaritySearch(agentId, query, 10); - - if (relevantDocs.isEmpty()) { - return "抱歉,我没有找到与您的问题相关的数据表信息。请确保已正确初始化数据源,或者尝试用不同的方式描述您的问题。"; - } - - // 3. 构建响应 - StringBuilder response = new StringBuilder(); - response.append("根据您的问题「").append(query).append("」,我找到了以下相关信息:\n\n"); - - // Analyze related tables and columns - Set tables = new HashSet<>(); - List columns = new ArrayList<>(); - - for (org.springframework.ai.document.Document doc : relevantDocs) { - Map metadata = doc.getMetadata(); - String vectorType = (String) metadata.get("vectorType"); - - if ("table".equals(vectorType)) { - tables.add((String) metadata.get("name")); - } - else if ("column".equals(vectorType)) { - String tableName = (String) metadata.get("tableName"); - String columnName = (String) metadata.get("name"); - String description = (String) metadata.get("description"); - - tables.add(tableName); - columns.add(String.format("• %s.%s%s", tableName, columnName, - description != null && !description.isEmpty() ? " - " + description : "")); - } - } - - if (!tables.isEmpty()) { - response.append("📊 **相关数据表:**\n"); - for (String table : tables) { - response.append("• ").append(table).append("\n"); - } - response.append("\n"); - } - - if (!columns.isEmpty()) { - response.append("📋 **相关字段:**\n"); - for (String column : columns.subList(0, Math.min(columns.size(), 8))) { // Limit - // display - // quantity - response.append(column).append("\n"); - } - if (columns.size() > 8) { - response.append("... 还有 ").append(columns.size() - 8).append(" 个相关字段\n"); - } - response.append("\n"); - } - - response.append("💡 **建议:**\n"); - response.append("基于找到的数据结构,您可以询问更具体的问题,比如:\n"); - if (tables.contains("users")) { - response.append("• 用户总数是多少?\n"); - response.append("• 最近注册的用户有哪些?\n"); - } - if (tables.contains("orders")) { - response.append("• 今天的订单数量是多少?\n"); - response.append("• 最近一周的销售额是多少?\n"); - } - if (tables.contains("products")) { - response.append("• 有哪些产品分类?\n"); - response.append("• 最受欢迎的产品是什么?\n"); - } - - response.append("\n⚠️ **注意:** 当前为调试模式,显示的是数据结构分析。完整的SQL查询和数据分析功能正在开发中。"); - - return response.toString(); - - } - catch (Exception e) { - log.error("Error processing agent query: {}", e.getMessage(), e); - return "处理查询时发生错误:" + e.getMessage(); - } - } - - /** - * Check if it is a greeting - */ - private boolean isGreeting(String query) { - String lowerQuery = query.toLowerCase().trim(); - return lowerQuery.matches(".*(你好|hello|hi|您好|嗨|hey).*") || lowerQuery.equals("你好") || lowerQuery.equals("您好") - || lowerQuery.equals("hello") || lowerQuery.equals("hi"); - } - } diff --git a/spring-ai-alibaba-data-agent-management/src/main/java/com/alibaba/cloud/ai/service/impl/AnalyticDbVectorStoreManagementService.java b/spring-ai-alibaba-data-agent-management/src/main/java/com/alibaba/cloud/ai/service/impl/AnalyticDbVectorStoreManagementService.java deleted file mode 100644 index 515c4ec..0000000 --- a/spring-ai-alibaba-data-agent-management/src/main/java/com/alibaba/cloud/ai/service/impl/AnalyticDbVectorStoreManagementService.java +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright 2024-2025 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.alibaba.cloud.ai.service.impl; - -import com.alibaba.cloud.ai.annotation.ConditionalOnADBEnabled; -import com.alibaba.cloud.ai.connector.accessor.Accessor; -import com.alibaba.cloud.ai.connector.accessor.AccessorFactory; -import com.alibaba.cloud.ai.connector.config.DbConfig; -import com.alibaba.cloud.ai.service.AbstractVectorStoreManagementService; -import com.alibaba.cloud.ai.vectorstore.analyticdb.AnalyticDbVectorStore; -import org.springframework.ai.vectorstore.VectorStore; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.stereotype.Service; - -/** - * Core vector database operation service, providing vector writing, querying, deletion, - * schema initialization and other functions. - */ - -@Service -@ConditionalOnADBEnabled -public class AnalyticDbVectorStoreManagementService extends AbstractVectorStoreManagementService { - - @Autowired - private AnalyticDbVectorStore vectorStore; - - private final Accessor dbAccessor; - - public AnalyticDbVectorStoreManagementService(AccessorFactory accessorFactory, DbConfig dbConfig) { - this.dbAccessor = accessorFactory.getAccessorByDbConfig(dbConfig); - } - - @Override - protected VectorStore getVectorStore() { - return vectorStore; - } - -} diff --git a/spring-ai-alibaba-data-agent-management/src/main/java/com/alibaba/cloud/ai/service/impl/SimpleVectorStoreManagementService.java b/spring-ai-alibaba-data-agent-management/src/main/java/com/alibaba/cloud/ai/service/impl/SimpleVectorStoreManagementService.java deleted file mode 100644 index 1b3fc0b..0000000 --- a/spring-ai-alibaba-data-agent-management/src/main/java/com/alibaba/cloud/ai/service/impl/SimpleVectorStoreManagementService.java +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Copyright 2024-2025 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.alibaba.cloud.ai.service.impl; - -import com.alibaba.cloud.ai.connector.accessor.Accessor; -import com.alibaba.cloud.ai.connector.accessor.AccessorFactory; -import com.alibaba.cloud.ai.dashscope.api.DashScopeApi; -import com.alibaba.cloud.ai.dashscope.embedding.DashScopeEmbeddingModel; -import com.alibaba.cloud.ai.dashscope.embedding.DashScopeEmbeddingOptions; -import com.alibaba.cloud.ai.connector.config.DbConfig; -import com.alibaba.cloud.ai.service.AbstractVectorStoreManagementService; -import lombok.extern.slf4j.Slf4j; -import org.springframework.ai.document.MetadataMode; -import org.springframework.ai.vectorstore.SimpleVectorStore; -import org.springframework.ai.vectorstore.VectorStore; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.beans.factory.annotation.Value; -import org.springframework.stereotype.Service; - -@Slf4j -@Service -public class SimpleVectorStoreManagementService extends AbstractVectorStoreManagementService { - - private final SimpleVectorStore vectorStore; - - private final Accessor dbAccessor; - - private final DbConfig dbConfig; - - @Autowired - public SimpleVectorStoreManagementService(@Value("${spring.ai.dashscope.api-key:default_api_key}") String apiKey, - AccessorFactory accessorFactory, DbConfig dbConfig) { - this.dbAccessor = accessorFactory.getAccessorByDbConfig(dbConfig); - this.dbConfig = dbConfig; - - DashScopeApi dashScopeApi = DashScopeApi.builder().apiKey(apiKey).build(); - DashScopeEmbeddingModel dashScopeEmbeddingModel = new DashScopeEmbeddingModel(dashScopeApi, MetadataMode.EMBED, - DashScopeEmbeddingOptions.builder() - .withModel(DashScopeApi.EmbeddingModel.EMBEDDING_V4.getValue()) - .build()); - this.vectorStore = SimpleVectorStore.builder(dashScopeEmbeddingModel).build(); - } - - @Override - protected VectorStore getVectorStore() { - return vectorStore; - } - -}