Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@

<!-- Maven Compiler Plugin -->
<maven-compiler-plugin.version>3.11.0</maven-compiler-plugin.version>

<!-- spotless version-->
<spotless-maven-plugin.version>2.44.5</spotless-maven-plugin.version>
<docker-java.version>3.5.3</docker-java.version>
<gpdb.version>3.0.0</gpdb.version>
<druid.version>1.2.22</druid.version>
Expand Down Expand Up @@ -249,6 +250,24 @@
</execution>
</executions>
</plugin>
<plugin>
<groupId>com.diffplug.spotless</groupId>
<artifactId>spotless-maven-plugin</artifactId>
<version>${spotless-maven-plugin.version}</version>
<configuration>
<java>
<removeUnusedImports/>
</java>
</configuration>
<executions>
<execution>
<phase>compile</phase>
<goals>
<goal>apply</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,4 +144,16 @@ private Constant() {
// 人类复核相关
public static final String HUMAN_REVIEW_ENABLED = "HUMAN_REVIEW_ENABLED";

// column
public static final String COLUMN = "column";

// table
public static final String TABLE = "table";

// vectorType
public static final String VECTOR_TYPE = "vectorType";

// knowledgeId
public static final String KNOWLEDGE_ID = "knowledgeId";

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import com.alibaba.cloud.ai.constant.Constant;
import com.alibaba.cloud.ai.dto.BusinessKnowledgeDTO;
import com.alibaba.cloud.ai.dto.SemanticModelDTO;
import com.alibaba.cloud.ai.entity.AgentDatasource;
import com.alibaba.cloud.ai.entity.Datasource;
import com.alibaba.cloud.ai.enums.StreamResponseType;
import com.alibaba.cloud.ai.graph.OverAllState;
Expand All @@ -32,6 +31,7 @@
import com.alibaba.cloud.ai.service.schema.SchemaService;
import com.alibaba.cloud.ai.service.semantic.SemanticModelRecallService;
import com.alibaba.cloud.ai.util.ChatResponseUtil;
import com.alibaba.cloud.ai.util.SchemaProcessorUtil;
import com.alibaba.cloud.ai.util.StateUtil;
import com.alibaba.cloud.ai.util.StreamingChatGeneratorUtil;
import org.slf4j.Logger;
Expand All @@ -40,6 +40,7 @@
import org.springframework.ai.document.Document;
import org.springframework.dao.DataAccessException;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
import reactor.core.publisher.Flux;

import java.util.List;
Expand Down Expand Up @@ -98,21 +99,20 @@ public Map<String, Object> apply(OverAllState state) throws Exception {
COLUMN_DOCUMENTS_BY_KEYWORDS_OUTPUT);
String dataSetId = StateUtil.getStringValue(state, Constant.AGENT_ID);
String agentIdStr = StateUtil.getStringValue(state, AGENT_ID);
long agentId = -1L;
if (!agentIdStr.isEmpty()) {
agentId = Long.parseLong(agentIdStr);
}
if (!StringUtils.hasText(agentIdStr))
throw new RuntimeException("Agent ID is empty.");

// Execute business logic first - get final result immediately
SchemaDTO schemaDTO = buildInitialSchema(columnDocumentsByKeywords, tableDocuments);
SchemaDTO result = processSchemaSelection(schemaDTO, input, evidenceList, state);
DbConfig agentDbConfig = getAgentDbConfig(Integer.valueOf(agentIdStr));
SchemaDTO schemaDTO = buildInitialSchema(agentIdStr, columnDocumentsByKeywords, tableDocuments, agentDbConfig);
SchemaDTO result = processSchemaSelection(schemaDTO, input, evidenceList, state, agentDbConfig);

List<BusinessKnowledgeDTO> businessKnowledges;
List<SemanticModelDTO> semanticModel;
try {
// Extract business knowledge and semantic model
businessKnowledges = businessKnowledgeRecallService.getFieldByDataSetId(dataSetId);
semanticModel = semanticModelRecallService.getFieldByDataSetId(String.valueOf(agentId));
semanticModel = semanticModelRecallService.getFieldByDataSetId(dataSetId);
}
catch (DataAccessException e) {
logger.warn("Database query failed (attempt {}): {}", retryCount + 1, e.getMessage());
Expand Down Expand Up @@ -164,59 +164,27 @@ private String classifyDatabaseError(DataAccessException e) {
/**
* Builds initial schema from column and table documents.
*/
private SchemaDTO buildInitialSchema(List<List<Document>> columnDocumentsByKeywords,
List<Document> tableDocuments) {
private SchemaDTO buildInitialSchema(String agentId, List<List<Document>> columnDocumentsByKeywords,
List<Document> tableDocuments, DbConfig agentDbConfig) {
SchemaDTO schemaDTO = new SchemaDTO();
schemaService.extractDatabaseName(schemaDTO);
schemaService.buildSchemaFromDocuments(columnDocumentsByKeywords, tableDocuments, schemaDTO);

schemaService.extractDatabaseName(schemaDTO, agentDbConfig);
schemaService.buildSchemaFromDocuments(agentId, columnDocumentsByKeywords, tableDocuments, schemaDTO);
return schemaDTO;
}

/**
* Dynamically get the data source configuration for an agent
* @param state The state object containing the agent ID
* @return The database configuration corresponding to the agent, or null if not found
*/
private DbConfig getAgentDbConfig(OverAllState state) {
try {
// Get the agent ID from the state
String agentIdStr = StateUtil.getStringValue(state, Constant.AGENT_ID, null);
if (agentIdStr == null || agentIdStr.trim().isEmpty()) {
logger.debug("AgentId is null or empty, will use default dbConfig");
return null;
}

Integer agentId = Integer.valueOf(agentIdStr);
logger.debug("Getting datasource config for agent: {}", agentId);

// Get the enabled data source for the agent
List<AgentDatasource> agentDatasources = datasourceService.getAgentDatasources(agentId);
if (agentDatasources.isEmpty()) {
// TODO 调试AgentID不一致,暂时手动处理
agentDatasources = datasourceService.getAgentDatasources(agentId - 999999);
}

AgentDatasource activeDatasource = agentDatasources.stream()
.filter(ad -> ad.getIsActive() == 1)
.findFirst()
.orElse(null);

if (activeDatasource == null) {
logger.debug("Agent {} has no active datasource, will use default dbConfig", agentId);
return null;
}
private DbConfig getAgentDbConfig(Integer agentId) {
// Get the enabled data source for the agent
Datasource agentDatasource = datasourceService.getActiveDatasourceByAgentId(agentId);
if (agentDatasource == null)
throw new RuntimeException("No active datasource found for agent " + agentId);

// Convert to DbConfig
DbConfig dbConfig = createDbConfigFromDatasource(activeDatasource.getDatasource());
logger.debug("Successfully created DbConfig for agent {}: url={}, schema={}, type={}", agentId,
dbConfig.getUrl(), dbConfig.getSchema(), dbConfig.getDialectType());
// Convert to DbConfig
DbConfig dbConfig = SchemaProcessorUtil.createDbConfigFromDatasource(agentDatasource);
logger.debug("Successfully created DbConfig for agent {}: url={}, schema={}, type={}", agentId,
dbConfig.getUrl(), dbConfig.getSchema(), dbConfig.getDialectType());

return dbConfig;
}
catch (Exception e) {
logger.warn("Failed to get agent datasource config, will use default dbConfig: {}", e.getMessage());
return null;
}
return dbConfig;
}

/**
Expand Down Expand Up @@ -255,13 +223,9 @@ else if ("postgresql".equalsIgnoreCase(datasource.getType())) {
* Processes schema selection based on input, evidence, and optional advice.
*/
private SchemaDTO processSchemaSelection(SchemaDTO schemaDTO, String input, List<String> evidenceList,
OverAllState state) {
OverAllState state, DbConfig agentDbConfig) {
String schemaAdvice = StateUtil.getStringValue(state, SQL_GENERATE_SCHEMA_MISSING_ADVICE, null);

// 动态获取Agent对应的数据库配置
DbConfig agentDbConfig = getAgentDbConfig(state);
logger.debug("Using agent-specific dbConfig: {}", agentDbConfig != null ? agentDbConfig.getUrl() : "default");

if (schemaAdvice != null) {
logger.info("[{}] Processing with schema supplement advice: {}", this.getClass().getSimpleName(),
schemaAdvice);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -299,4 +299,15 @@ public Map<String, Object> getDatasourceStats() {
return stats;
}

public Datasource getActiveDatasourceByAgentId(Integer agentId) {
AgentDatasource agentDatasource = getAgentDatasources(agentId).stream()
.filter(a -> a.getIsActive() == 1)
.findFirst()
.orElse(null);
if (agentDatasource == null) {
return null;
}
return agentDatasource.getDatasource();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,20 @@
import com.alibaba.cloud.ai.dto.schema.SchemaDTO;
import com.alibaba.cloud.ai.prompt.PromptConstant;
import com.alibaba.cloud.ai.prompt.PromptHelper;
import com.alibaba.cloud.ai.service.DatasourceService;
import com.alibaba.cloud.ai.service.LlmService;
import com.alibaba.cloud.ai.service.nl2sql.Nl2SqlService;
import com.alibaba.cloud.ai.service.schema.SchemaService;
import com.alibaba.cloud.ai.service.vectorstore.VectorStoreService;
import com.alibaba.cloud.ai.service.vectorstore.AgentVectorStoreService;
import com.alibaba.cloud.ai.util.JsonUtil;
import com.alibaba.cloud.ai.util.SchemaProcessorUtil;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.document.Document;
import org.springframework.util.Assert;
import reactor.core.publisher.Flux;

import java.util.Collections;
Expand All @@ -44,11 +47,14 @@ public abstract class AbstractQueryProcessingService implements QueryProcessingS

private final LlmService aiService;

public AbstractQueryProcessingService(LlmService aiService) {
private final DatasourceService datasourceService;

public AbstractQueryProcessingService(LlmService aiService, DatasourceService datasourceService) {
this.aiService = aiService;
this.datasourceService = datasourceService;
}

protected abstract VectorStoreService getVectorStoreService();
protected abstract AgentVectorStoreService getVectorStoreService();

protected abstract SchemaService getSchemaService();

Expand All @@ -57,13 +63,9 @@ public AbstractQueryProcessingService(LlmService aiService) {
@Override
public List<String> extractEvidences(String query, String agentId) {
logger.debug("Extracting evidences for query: {} with agentId: {}", query, agentId);
List<Document> evidenceDocuments;
if (agentId != null && !agentId.trim().isEmpty()) {
evidenceDocuments = getVectorStoreService().getDocumentsForAgent(agentId, query, "evidence");
}
else {
evidenceDocuments = getVectorStoreService().getDocuments(query, "evidence");
}
Assert.notNull(agentId, "AgentId cannot be null");
List<Document> evidenceDocuments = getVectorStoreService().getDocumentsForAgent(agentId, query, "evidence");

List<String> evidences = evidenceDocuments.stream().map(Document::getText).collect(Collectors.toList());
logger.debug("Extracted {} evidences: {}", evidences.size(), evidences);
return evidences;
Expand Down Expand Up @@ -176,17 +178,20 @@ private String processTimeExpressions(String query) {
}

private SchemaDTO select(String query, List<String> evidenceList, String agentId) throws Exception {
Assert.notNull(agentId, "AgentId cannot be null");
logger.debug("Starting schema selection for query: {} with {} evidences and agentId: {}", query,
evidenceList.size(), agentId);
List<String> keywords = extractKeywords(query, evidenceList);
logger.debug("Using {} keywords for schema selection", keywords != null ? keywords.size() : 0);
SchemaDTO schemaDTO;
if (agentId != null) {
schemaDTO = getSchemaService().mixRagForAgent(agentId, query, keywords);
}
else {
schemaDTO = getSchemaService().mixRag(query, keywords);

com.alibaba.cloud.ai.entity.Datasource datasource = datasourceService
.getActiveDatasourceByAgentId(Integer.valueOf(agentId));
if (datasource == null) {
throw new RuntimeException("No active datasource found for agentId: " + agentId);
}
SchemaDTO schemaDTO = getSchemaService().mixRagForAgent(agentId, query, keywords,
SchemaProcessorUtil.createDbConfigFromDatasource(datasource));

logger.debug("Retrieved schema with {} tables", schemaDTO.getTable() != null ? schemaDTO.getTable().size() : 0);
SchemaDTO result = fineSelect(schemaDTO, query, evidenceList);
logger.debug("Fine selection completed, final schema has {} tables",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,32 +16,33 @@

package com.alibaba.cloud.ai.service.processing.impls;

import com.alibaba.cloud.ai.service.DatasourceService;
import com.alibaba.cloud.ai.service.LlmService;
import com.alibaba.cloud.ai.service.nl2sql.Nl2SqlService;
import com.alibaba.cloud.ai.service.schema.SchemaService;
import com.alibaba.cloud.ai.service.processing.AbstractQueryProcessingService;
import com.alibaba.cloud.ai.service.vectorstore.VectorStoreService;
import com.alibaba.cloud.ai.service.vectorstore.AgentVectorStoreService;
import org.springframework.stereotype.Service;

@Service
public class QueryProcessingServiceImpl extends AbstractQueryProcessingService {

private final VectorStoreService vectorStoreService;
private final AgentVectorStoreService vectorStoreService;

private final SchemaService schemaService;

private final Nl2SqlService nl2SqlService;

public QueryProcessingServiceImpl(LlmService aiService, VectorStoreService vectorStoreService,
SchemaService schemaService, Nl2SqlService nl2SqlService) {
super(aiService);
public QueryProcessingServiceImpl(LlmService aiService, AgentVectorStoreService vectorStoreService,
SchemaService schemaService, Nl2SqlService nl2SqlService, DatasourceService datasourceService) {
super(aiService, datasourceService);
this.vectorStoreService = vectorStoreService;
this.schemaService = schemaService;
this.nl2SqlService = nl2SqlService;
}

@Override
protected VectorStoreService getVectorStoreService() {
protected AgentVectorStoreService getVectorStoreService() {
return this.vectorStoreService;
}

Expand Down
Loading
Loading