Skip to content

Commit 4d2aee4

Browse files
committed
feat(deepresearch): optimize
1 parent f394869 commit 4d2aee4

File tree

8 files changed

+50
-77
lines changed

8 files changed

+50
-77
lines changed

spring-ai-alibaba-deepresearch/src/main/java/com/alibaba/cloud/ai/example/deepresearch/agents/McpAssignNodeConfiguration.java

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
import org.springframework.context.annotation.Configuration;
3333
import org.springframework.core.io.Resource;
3434
import org.springframework.core.io.ResourceLoader;
35-
import org.springframework.web.reactive.function.client.WebClient;
3635

3736
import java.io.IOException;
3837
import java.io.InputStream;
@@ -56,18 +55,12 @@ public class McpAssignNodeConfiguration {
5655
@Autowired
5756
private McpAssignNodeProperties mcpAssignNodeProperties;
5857

59-
@Autowired
60-
private McpClientCommonProperties commonProperties;
61-
6258
@Autowired
6359
private ResourceLoader resourceLoader;
6460

6561
@Autowired
6662
private ObjectMapper objectMapper;
6763

68-
@Autowired
69-
private WebClient.Builder webClientBuilderTemplate;
70-
7164
/**
7265
* 读取JSON配置文件
7366
*/

spring-ai-alibaba-deepresearch/src/main/java/com/alibaba/cloud/ai/example/deepresearch/node/BackgroundInvestigationNode.java

Lines changed: 28 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -91,55 +91,50 @@ public Map<String, Object> apply(OverAllState state) throws Exception {
9191
// 使用统一的智能搜索选择方法
9292
SmartAgentUtil.SearchSelectionResult searchSelection = smartAgentSelectionHelper
9393
.intelligentSearchSelection(state, query);
94-
List<Map<String, String>> results = new ArrayList<>();
94+
List<Map<String, String>> results;
9595

9696
// 使用支持工具调用的搜索方法
97-
results = searchInfoService.searchInfo(state.value("enable_search_filter", true),
98-
searchSelection.getSearchEnum(), query, searchSelection.getSearchPlatform());
97+
results = searchInfoService.searchInfo(StateUtil.isSearchFilter(state), searchSelection.getSearchEnum(),
98+
query, searchSelection.getSearchPlatform());
9999
resultMap.put("site_information", results);
100100
resultsList.add(results);
101101
}
102102

103-
if (!resultsList.isEmpty()) {
104-
List<String> backgroundResults = new ArrayList<>();
105-
assert resultsList.size() != queries.size();
103+
List<String> backgroundResults = new ArrayList<>();
104+
assert resultsList.size() != queries.size();
106105

107-
for (int i = 0; i < resultsList.size(); i++) {
108-
List<Map<String, String>> searchResults = resultsList.get(i);
106+
for (int i = 0; i < resultsList.size(); i++) {
107+
List<Map<String, String>> searchResults = resultsList.get(i);
109108

110-
String query = queries.get(i);
109+
String query = queries.get(i);
111110

112-
Message messages = new UserMessage(
113-
"搜索问题:" + query + "\n" + "以下是搜索结果:\n\n" + searchResults.stream().map(r -> {
114-
return String.format("标题: %s\n权重: %s\n内容: %s\n", r.get("title"), r.get("weight"),
115-
r.get("content"));
116-
}).collect(Collectors.joining("\n\n")));
111+
Message messages = new UserMessage(
112+
"搜索问题:" + query + "\n" + "以下是搜索结果:\n\n" + searchResults.stream().map(r -> {
113+
return String.format("标题: %s\n权重: %s\n内容: %s\n", r.get("title"), r.get("weight"),
114+
r.get("content"));
115+
}).collect(Collectors.joining("\n\n")));
117116

118-
String sessionId = state.value("session_id", String.class).orElse("__default__");
119-
List<SessionHistory> reports = sessionContextService.getRecentReports(sessionId);
120-
Message lastReportMessage;
121-
if (reports != null && !reports.isEmpty()) {
122-
lastReportMessage = new AssistantMessage("这是用户前几次使用DeepResearch的报告:\r\n"
123-
+ reports.stream().map(SessionHistory::toString).collect(Collectors.joining("\r\n\r\n")));
124-
}
125-
else {
126-
lastReportMessage = new AssistantMessage("这是用户的第一次询问,因此没有上下文。");
127-
}
117+
String sessionId = state.value("session_id", String.class).orElse("__default__");
118+
List<SessionHistory> reports = sessionContextService.getRecentReports(sessionId);
119+
Message lastReportMessage;
120+
if (reports != null && !reports.isEmpty()) {
121+
lastReportMessage = new AssistantMessage("这是用户前几次使用DeepResearch的报告:\r\n"
122+
+ reports.stream().map(SessionHistory::toString).collect(Collectors.joining("\r\n\r\n")));
123+
}
124+
else {
125+
lastReportMessage = new AssistantMessage("这是用户的第一次询问,因此没有上下文。");
126+
}
128127

129-
String content = backgroundAgent.prompt().messages(lastReportMessage, messages).call().content();
128+
String content = backgroundAgent.prompt().messages(lastReportMessage, messages).call().content();
130129

131-
backgroundResults.add(content);
130+
backgroundResults.add(content);
132131

133-
logger.info("背景调查报告生成已完成: {}", backgroundResults.size());
134-
}
135-
resultMap.put("background_investigation_results", backgroundResults);
136-
}
137-
else {
138-
logger.warn("⚠️ 搜索失败");
132+
logger.info("背景调查报告生成已完成: {}", backgroundResults.size());
139133
}
134+
resultMap.put("background_investigation_results", backgroundResults);
140135

141136
String nextStep = "planner";
142-
if (!state.value("enable_deepresearch", true)) {
137+
if (!StateUtil.isDeepresearch(state)) {
143138
nextStep = "reporter";
144139
}
145140
resultMap.put("background_investigation_next_node", nextStep);

spring-ai-alibaba-deepresearch/src/main/java/com/alibaba/cloud/ai/example/deepresearch/node/CoordinatorNode.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ public Map<String, Object> apply(OverAllState state) throws Exception {
6565
messages.add(TemplateUtil.getMessage("coordinator"));
6666

6767
// 添加前几次同一会话的报告
68-
String sessionId = state.value("session_id", String.class).orElse("__default__");
68+
String sessionId = StateUtil.getSessionId(state);
6969
List<SessionHistory> reports = sessionContextService.getRecentReports(sessionId);
7070
Message lastReportMessage;
7171
if (reports != null && !reports.isEmpty()) {

spring-ai-alibaba-deepresearch/src/main/java/com/alibaba/cloud/ai/example/deepresearch/node/ProfessionalKbDecisionNode.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package com.alibaba.cloud.ai.example.deepresearch.node;
1818

1919
import com.alibaba.cloud.ai.example.deepresearch.config.rag.RagProperties;
20+
import com.alibaba.cloud.ai.example.deepresearch.util.StateUtil;
2021
import com.alibaba.cloud.ai.graph.OverAllState;
2122
import com.alibaba.cloud.ai.graph.action.NodeAction;
2223
import org.slf4j.Logger;
@@ -48,7 +49,7 @@ public ProfessionalKbDecisionNode(ChatClient chatClient, RagProperties ragProper
4849
@Override
4950
public Map<String, Object> apply(OverAllState state) throws Exception {
5051
logger.info("Professional KB decision node is running.");
51-
String query = state.value("query", "");
52+
String query = StateUtil.getQuery(state);
5253
Map<String, Object> updated = new HashMap<>();
5354

5455
// 如果没有启用专业知识库决策,直接返回不使用

spring-ai-alibaba-deepresearch/src/main/java/com/alibaba/cloud/ai/example/deepresearch/node/RagNode.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import com.alibaba.cloud.ai.example.deepresearch.rag.core.HybridRagProcessor;
2020
import com.alibaba.cloud.ai.example.deepresearch.rag.strategy.FusionStrategy;
2121
import com.alibaba.cloud.ai.example.deepresearch.rag.strategy.RetrievalStrategy;
22+
import com.alibaba.cloud.ai.example.deepresearch.util.StateUtil;
2223
import com.alibaba.cloud.ai.graph.OverAllState;
2324
import com.alibaba.cloud.ai.graph.action.NodeAction;
2425
import com.alibaba.cloud.ai.graph.streaming.StreamingChatGenerator;
@@ -79,8 +80,7 @@ public RagNode(HybridRagProcessor hybridRagProcessor, ChatClient ragAgent) {
7980
@Override
8081
public Map<String, Object> apply(OverAllState state) throws Exception {
8182
logger.info("rag_node is running.");
82-
String queryText = state.value("query", String.class)
83-
.orElseThrow(() -> new IllegalArgumentException("Query is missing from state"));
83+
String queryText = StateUtil.getQuery(state);
8484

8585
Map<String, Object> options = new HashMap<>();
8686
state.value("session_id", String.class).ifPresent(v -> options.put("session_id", v));

spring-ai-alibaba-deepresearch/src/main/java/com/alibaba/cloud/ai/example/deepresearch/node/ReporterNode.java

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,8 @@ public Map<String, Object> apply(OverAllState state) throws Exception {
7171
logger.info("reporter node is running.");
7272

7373
// 从 OverAllState 中获取线程ID
74-
String threadId = state.value("thread_id", String.class)
75-
.orElseThrow(() -> new IllegalArgumentException("thread_id is missing from state"));
76-
String sessionId = state.value("session_id", String.class)
77-
.orElseThrow(() -> new IllegalArgumentException("session_id is missing from state"));
74+
String threadId = StateUtil.getThreadId(state);
75+
String sessionId = StateUtil.getSessionId(state);
7876
logger.info("Thread ID from state: {}", threadId);
7977
logger.info("Session ID from state: {}", sessionId);
8078

@@ -93,8 +91,7 @@ public Map<String, Object> apply(OverAllState state) throws Exception {
9391

9492
// 添加深度研究信息
9593
if (state.value("enable_deepresearch", true)) {
96-
Plan currentPlan = state.value("current_plan", Plan.class)
97-
.orElseThrow(() -> new IllegalArgumentException("current_plan is missing"));
94+
Plan currentPlan = StateUtil.getPlan(state);
9895

9996
// 1.1 研究报告格式消息
10097
messages.add(new UserMessage(

spring-ai-alibaba-deepresearch/src/main/java/com/alibaba/cloud/ai/example/deepresearch/node/RewriteAndMultiQueryNode.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ public Map<String, Object> apply(OverAllState state) throws Exception {
7272
Query rewriteQuery = queryTransformer.transform(query);
7373

7474
// 查询拓展
75-
int optimizeQueryNum = state.value("optimize_query_num", 3);
75+
int optimizeQueryNum = StateUtil.getOptimizeQueryNum(state);
7676
optimizeQueryNum = Math.max(MinOptimizeQueryNum, Math.min(MaxOptimizeQueryNum, optimizeQueryNum));
7777
QueryExpander queryExpander = MultiQueryExpander.builder()
7878
.chatClientBuilder(rewriteAndMultiQueryAgentBuilder)

spring-ai-alibaba-deepresearch/src/main/java/com/alibaba/cloud/ai/example/deepresearch/util/StateUtil.java

Lines changed: 13 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,10 @@
1818

1919
import com.alibaba.cloud.ai.example.deepresearch.model.dto.Plan;
2020
import com.alibaba.cloud.ai.graph.OverAllState;
21-
import org.slf4j.Logger;
22-
import org.slf4j.LoggerFactory;
2321

2422
import java.util.ArrayList;
25-
import java.util.Collections;
2623
import java.util.List;
27-
import java.util.Map;
2824
import java.util.Optional;
29-
import java.util.concurrent.ExecutorService;
30-
import java.util.concurrent.Executors;
3125

3226
/**
3327
* @author yingzi
@@ -36,10 +30,6 @@
3630

3731
public class StateUtil {
3832

39-
private static final Logger logger = LoggerFactory.getLogger(StateUtil.class);
40-
41-
private static final ExecutorService executor = Executors.newFixedThreadPool(10);
42-
4333
public static final String EXECUTION_STATUS_ASSIGNED_PREFIX = "assigned_";
4434

4535
public static final String EXECUTION_STATUS_PROCESSING_PREFIX = "processing_";
@@ -66,10 +56,6 @@ public static void handleStepError(Plan.Step step, String nodeName, Throwable er
6656
logger.error("{} failed: {}", nodeName, error.getMessage(), error);
6757
}
6858

69-
public static List<String> getMessagesByType(OverAllState state, String name) {
70-
return state.value(name, List.class).map(obj -> new ArrayList<>((List<String>) obj)).orElseGet(ArrayList::new);
71-
}
72-
7359
public static List<String> getParallelMessages(OverAllState state, List<String> researcherTeam, int count) {
7460
List<String> resList = new ArrayList<>();
7561

@@ -112,8 +98,16 @@ public static Integer getMaxStepNum(OverAllState state) {
11298
return state.value("max_step_num", 3);
11399
}
114100

101+
public static Integer getOptimizeQueryNum(OverAllState state) {
102+
return state.value("optimize_query_num", 3);
103+
}
104+
115105
public static String getThreadId(OverAllState state) {
116-
return state.value("thread_id", "__default__");
106+
return state.value("thread_id", "");
107+
}
108+
109+
public static String getSessionId(OverAllState state) {
110+
return state.value("session_id", "__default__");
117111
}
118112

119113
public static boolean getAutoAcceptedPlan(OverAllState state) {
@@ -124,19 +118,12 @@ public static String getRagContent(OverAllState state) {
124118
return state.value("rag_content", "");
125119
}
126120

127-
/**
128-
* 获取MCP设置
129-
*/
130-
public static Map<String, Object> getMcpSettings(OverAllState state) {
131-
return state.value("mcp_settings", Map.class).orElse(Collections.emptyMap());
121+
public static boolean isSearchFilter(OverAllState state) {
122+
return state.value("search_filter", true);
132123
}
133124

134-
/**
135-
* 检查是否有运行时MCP配置
136-
*/
137-
public static boolean hasRuntimeMcpConfig(OverAllState state) {
138-
Map<String, Object> mcpSettings = getMcpSettings(state);
139-
return !mcpSettings.isEmpty();
125+
public static boolean isDeepresearch(OverAllState state) {
126+
return state.value("enable_deepresearch", true);
140127
}
141128

142129
}

0 commit comments

Comments
 (0)