Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ public StateGraph nl2sqlGraph(ChatClient.Builder chatClientBuilder) throws Graph
.addNode(SCHEMA_RECALL_NODE, node_async(new SchemaRecallNode(schemaService)))
.addNode(TABLE_RELATION_NODE,
node_async(new TableRelationNode(schemaService, nl2SqlService, businessKnowledgeRecallService,
semanticModelRecallService)))
semanticModelRecallService, datasourceService)))
.addNode(SQL_GENERATE_NODE, node_async(new SqlGenerateNode(chatClientBuilder, nl2SqlService)))
.addNode(PLANNER_NODE, node_async(new PlannerNode(chatClientBuilder)))
.addNode(PLAN_EXECUTOR_NODE, node_async(new PlanExecutorNode()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import static com.alibaba.cloud.ai.constant.Constant.IS_ONLY_NL2SQL;
import static com.alibaba.cloud.ai.constant.Constant.PLANNER_NODE_OUTPUT;
import static com.alibaba.cloud.ai.constant.Constant.PLAN_VALIDATION_ERROR;
import static com.alibaba.cloud.ai.constant.Constant.QUERY_REWRITE_NODE_OUTPUT;
import static com.alibaba.cloud.ai.constant.Constant.SEMANTIC_MODEL;
import static com.alibaba.cloud.ai.constant.Constant.TABLE_RELATION_OUTPUT;

Expand All @@ -56,6 +57,11 @@ public PlannerNode(ChatClient.Builder chatClientBuilder) {
@Override
public Map<String, Object> apply(OverAllState state) throws Exception {
String input = (String) state.value(INPUT_KEY).orElseThrow();
// 使用经过时间表达式处理的重写查询,如果没有则回退到原始输入
String processedQuery = StateUtils.getStringValue(state, QUERY_REWRITE_NODE_OUTPUT, input);
logger.info("Using processed query for planning: {}", processedQuery);

// 是否为NL2SQL模式
Boolean onlyNl2sql = state.value(IS_ONLY_NL2SQL, false);

// 检查是否为修复模式
Expand All @@ -74,7 +80,7 @@ public Map<String, Object> apply(OverAllState state) throws Exception {
String schemaStr = PromptHelper.buildMixMacSqlDbPrompt(schemaDTO, true);

// 构建用户提示
String userPrompt = buildUserPrompt(input, validationError, state);
String userPrompt = buildUserPrompt(processedQuery, validationError, state);

// 构建模板参数
Map<String, Object> params = Map.of("user_question", userPrompt, "schema", schemaStr, "business_knowledge",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,17 @@

package com.alibaba.cloud.ai.node;

import com.alibaba.cloud.ai.connector.config.DbConfig;
import com.alibaba.cloud.ai.constant.Constant;
import com.alibaba.cloud.ai.dto.BusinessKnowledgeDTO;
import com.alibaba.cloud.ai.dto.SemanticModelDTO;
import com.alibaba.cloud.ai.dto.schema.SchemaDTO;
import com.alibaba.cloud.ai.entity.AgentDatasource;
import com.alibaba.cloud.ai.entity.Datasource;
import com.alibaba.cloud.ai.enums.StreamResponseType;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import com.alibaba.cloud.ai.dto.schema.SchemaDTO;
import com.alibaba.cloud.ai.service.DatasourceService;
import com.alibaba.cloud.ai.service.base.BaseNl2SqlService;
import com.alibaba.cloud.ai.service.base.BaseSchemaService;
import com.alibaba.cloud.ai.service.business.BusinessKnowledgeRecallService;
Expand Down Expand Up @@ -66,13 +70,16 @@ public class TableRelationNode implements NodeAction {

private final SemanticModelRecallService semanticModelRecallService;

private final DatasourceService datasourceService;

public TableRelationNode(BaseSchemaService baseSchemaService, BaseNl2SqlService baseNl2SqlService,
BusinessKnowledgeRecallService businessKnowledgeRecallService,
SemanticModelRecallService semanticModelRecallService) {
SemanticModelRecallService semanticModelRecallService, DatasourceService datasourceService) {
this.baseSchemaService = baseSchemaService;
this.baseNl2SqlService = baseNl2SqlService;
this.businessKnowledgeRecallService = businessKnowledgeRecallService;
this.semanticModelRecallService = semanticModelRecallService;
this.datasourceService = datasourceService;
}

@Override
Expand Down Expand Up @@ -163,21 +170,104 @@ private SchemaDTO buildInitialSchema(List<List<Document>> columnDocumentsByKeywo
return schemaDTO;
}

/**
* Dynamically get the data source configuration for an agent
* @param state The state object containing the agent ID
* @return The database configuration corresponding to the agent, or null if not found
*/
private DbConfig getAgentDbConfig(OverAllState state) {
try {
// Get the agent ID from the state
String agentIdStr = StateUtils.getStringValue(state, Constant.AGENT_ID, null);
if (agentIdStr == null || agentIdStr.trim().isEmpty()) {
logger.debug("AgentId is null or empty, will use default dbConfig");
return null;
}

Integer agentId = Integer.valueOf(agentIdStr);
logger.debug("Getting datasource config for agent: {}", agentId);

// Get the enabled data source for the agent
List<AgentDatasource> agentDatasources = datasourceService.getAgentDatasources(agentId);
if (agentDatasources.isEmpty()) {
// TODO 调试AgentID不一致,暂时手动处理
agentDatasources = datasourceService.getAgentDatasources(agentId - 999999);
Comment on lines +175 to +176
Copy link
Preview

Copilot AI Aug 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This hardcoded magic number 999999 and TODO comment indicate a temporary workaround. Consider implementing a proper agent ID mapping mechanism or configuration property instead of this hardcoded offset.

Suggested change
// TODO 调试AgentID不一致,暂时手动处理
agentDatasources = datasourceService.getAgentDatasources(agentId - 999999);
// Use configurable agentIdOffset to handle AgentID inconsistency
agentDatasources = datasourceService.getAgentDatasources(agentId - agentIdOffset);

Copilot uses AI. Check for mistakes.

}

AgentDatasource activeDatasource = agentDatasources.stream()
.filter(ad -> ad.getIsActive() == 1)
.findFirst()
.orElse(null);

if (activeDatasource == null) {
logger.debug("Agent {} has no active datasource, will use default dbConfig", agentId);
return null;
}

// Convert to DbConfig
DbConfig dbConfig = createDbConfigFromDatasource(activeDatasource.getDatasource());
logger.debug("Successfully created DbConfig for agent {}: url={}, schema={}, type={}", agentId,
dbConfig.getUrl(), dbConfig.getSchema(), dbConfig.getDialectType());

return dbConfig;
}
catch (Exception e) {
logger.warn("Failed to get agent datasource config, will use default dbConfig: {}", e.getMessage());
return null;
}
}

/**
* Create database configuration from data source entity
* @param datasource The data source entity
* @return The database configuration object
*/
private DbConfig createDbConfigFromDatasource(Datasource datasource) {
DbConfig dbConfig = new DbConfig();

// Set basic connection information
dbConfig.setUrl(datasource.getConnectionUrl());
dbConfig.setUsername(datasource.getUsername());
dbConfig.setPassword(datasource.getPassword());

// Set database type
if ("mysql".equalsIgnoreCase(datasource.getType())) {
dbConfig.setConnectionType("jdbc");
dbConfig.setDialectType("mysql");
}
else if ("postgresql".equalsIgnoreCase(datasource.getType())) {
dbConfig.setConnectionType("jdbc");
dbConfig.setDialectType("postgresql");
}
else {
throw new RuntimeException("不支持的数据库类型: " + datasource.getType());
}

// Set Schema to the database name of the data source
dbConfig.setSchema(datasource.getDatabaseName());

return dbConfig;
}

/**
* Processes schema selection based on input, evidence, and optional advice.
*/
private SchemaDTO processSchemaSelection(SchemaDTO schemaDTO, String input, List<String> evidenceList,
OverAllState state) {
String schemaAdvice = StateUtils.getStringValue(state, SQL_GENERATE_SCHEMA_MISSING_ADVICE, null);

// 动态获取Agent对应的数据库配置
DbConfig agentDbConfig = getAgentDbConfig(state);
logger.debug("Using agent-specific dbConfig: {}", agentDbConfig != null ? agentDbConfig.getUrl() : "default");

if (schemaAdvice != null) {
logger.info("[{}] Processing with schema supplement advice: {}", this.getClass().getSimpleName(),
schemaAdvice);
return baseNl2SqlService.fineSelect(schemaDTO, input, evidenceList, schemaAdvice);
return baseNl2SqlService.fineSelect(schemaDTO, input, evidenceList, schemaAdvice, agentDbConfig);
}
else {
logger.info("[{}] Executing regular schema selection", this.getClass().getSimpleName());
return baseNl2SqlService.fineSelect(schemaDTO, input, evidenceList);
return baseNl2SqlService.fineSelect(schemaDTO, input, evidenceList, null, agentDbConfig);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,4 +94,8 @@ public static PromptTemplate getSemanticModelPromptTemplate() {
return new PromptTemplate(PromptLoader.loadPrompt("semantic-model"));
}

public static PromptTemplate getTimeConversionPromptTemplate() {
return new PromptTemplate(PromptLoader.loadPrompt("time-conversion"));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.apache.commons.collections.CollectionUtils;

import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.*;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -102,6 +104,21 @@ public static String buildDateTimeExtractPrompt(String question) {
return PromptConstant.getExtractDatetimePromptTemplate().render(params);
}

/**
* 构建时间转换提示词
* @param query 用户查询
* @return 时间转换提示词
*/
public static String buildTimeConversionPrompt(String query) {
Map<String, Object> promptMap = new HashMap<>();
promptMap.put("current_time_info",
LocalDateTime.now().format(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss")));
promptMap.put("query", query);

PromptTemplate promptTemplate = PromptConstant.getTimeConversionPromptTemplate();
return promptTemplate.render(promptMap);
}

public static String buildMixMacSqlDbPrompt(SchemaDTO schemaDTO, Boolean withColumnType) {
StringBuilder sb = new StringBuilder();
sb.append("【DB_ID】 ").append(schemaDTO.getName() == null ? "" : schemaDTO.getName()).append("\n");
Expand Down
Loading