Skip to content

Commit e6741a8

Browse files
authored
[Feat] add mcp node (#1028)
1 parent 0638659 commit e6741a8

File tree

10 files changed

+725
-9
lines changed

10 files changed

+725
-9
lines changed

community/document-readers/spring-ai-alibaba-starter-document-reader-gpt-repo/pom.xml

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,6 @@
4949
<version>${project.parent.version}</version>
5050
</dependency>
5151

52-
<dependency>
53-
<groupId>com.google.guava</groupId>
54-
<artifactId>guava</artifactId>
55-
<version>${guava.version}</version>
56-
</dependency>
57-
5852
<!-- test dependencies -->
5953
<dependency>
6054
<groupId>org.springframework.ai</groupId>

community/document-readers/spring-ai-alibaba-starter-document-reader-gpt-repo/src/main/java/com/alibaba/cloud/ai/reader/gptrepo/GptRepoDocumentReader.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ public FileVisitResult visitFile(Path file, @NotNull BasicFileAttributes attrs)
197197

198198
// Check file extension
199199
if (extensions != null && !extensions.isEmpty()) {
200-
String ext = com.google.common.io.Files.getFileExtension(file.toString());
200+
String ext = getFileExtension(file.toString());
201201
if (!extensions.contains(ext)) {
202202
return FileVisitResult.CONTINUE;
203203
}
@@ -225,6 +225,12 @@ public FileVisitResult visitFile(Path file, @NotNull BasicFileAttributes attrs)
225225
return results;
226226
}
227227

228+
private String getFileExtension(String fullName) {
229+
String fileName = (new File(fullName)).getName();
230+
int dotIndex = fileName.lastIndexOf(46);
231+
return dotIndex == -1 ? "" : fileName.substring(dotIndex + 1);
232+
}
233+
228234
/**
229235
* Format file content
230236
*/

pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@
227227

228228
<mysql.version>8.0.32</mysql.version>
229229

230-
<guava.version>33.4.0-jre</guava.version>
230+
<mcp.version>0.10.0</mcp.version>
231231

232232
<!-- CheckStyle Plugin -->
233233
<disable.checks>true</disable.checks>

spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/pom.xml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,12 @@
165165
<optional>true</optional>
166166
</dependency>
167167

168+
<dependency>
169+
<groupId>io.modelcontextprotocol.sdk</groupId>
170+
<artifactId>mcp</artifactId>
171+
<version>${mcp.version}</version>
172+
</dependency>
173+
168174
<dependency>
169175
<groupId>junit</groupId>
170176
<artifactId>junit</artifactId>
Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
/*
2+
* Copyright 2024-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.alibaba.cloud.ai.graph.node;
18+
19+
import com.alibaba.cloud.ai.graph.OverAllState;
20+
import com.alibaba.cloud.ai.graph.action.NodeAction;
21+
import io.modelcontextprotocol.client.McpClient;
22+
import io.modelcontextprotocol.client.McpSyncClient;
23+
import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
24+
import io.modelcontextprotocol.spec.McpSchema;
25+
import io.modelcontextprotocol.spec.McpSchema.CallToolResult;
26+
import io.modelcontextprotocol.spec.McpSchema.TextContent;
27+
import org.slf4j.Logger;
28+
import org.slf4j.LoggerFactory;
29+
import org.springframework.util.CollectionUtils;
30+
import org.springframework.util.StringUtils;
31+
32+
import java.util.HashMap;
33+
import java.util.List;
34+
import java.util.Map;
35+
import java.util.regex.Matcher;
36+
import java.util.regex.Pattern;
37+
38+
/**
39+
* MCP Node: Node for calling MCP Server
40+
*/
41+
public class McpNode implements NodeAction {
42+
43+
private static final Pattern VARIABLE_PATTERN = Pattern.compile("\\$\\{(.+?)\\}");
44+
45+
private static final Logger log = LoggerFactory.getLogger(McpNode.class);
46+
47+
private final String url;
48+
49+
private final String tool;
50+
51+
private final Map<String, String> headers;
52+
53+
private final Map<String, Object> params;
54+
55+
private final String outputKey;
56+
57+
private final List<String> inputParamKeys;
58+
59+
private HttpClientSseClientTransport transport;
60+
61+
private McpSyncClient client;
62+
63+
private McpNode(Builder builder) {
64+
this.url = builder.url;
65+
this.tool = builder.tool;
66+
this.headers = builder.headers;
67+
this.params = builder.params;
68+
this.outputKey = builder.outputKey;
69+
this.inputParamKeys = builder.inputParamKeys;
70+
}
71+
72+
@Override
73+
public Map<String, Object> apply(OverAllState state) throws Exception {
74+
log.info(
75+
"[McpNode] Start executing apply, original configuration: url={}, tool={}, headers={}, inputParamKeys={}",
76+
url, tool, headers, inputParamKeys);
77+
78+
// Build transport and client
79+
HttpClientSseClientTransport.Builder transportBuilder = HttpClientSseClientTransport.builder(this.url);
80+
if (this.headers != null && !this.headers.isEmpty()) {
81+
transportBuilder.customizeRequest(req -> this.headers.forEach(req::header));
82+
}
83+
this.transport = transportBuilder.build();
84+
this.client = McpClient.sync(this.transport).build();
85+
this.client.initialize();
86+
// Variable replacement
87+
String finalTool = replaceVariables(tool, state);
88+
Map<String, Object> finalParams = new HashMap<>();
89+
// 1. First read from inputParamKeys
90+
if (inputParamKeys != null) {
91+
for (String key : inputParamKeys) {
92+
Object value = state.value(key).orElse(null);
93+
if (value != null) {
94+
finalParams.put(key, value);
95+
}
96+
}
97+
}
98+
// 2. Then use params (after variable replacement) to overwrite
99+
Map<String, Object> replacedParams = replaceVariablesObj(params, state);
100+
if (replacedParams != null) {
101+
finalParams.putAll(replacedParams);
102+
}
103+
log.info("[McpNode] after replace params: url={}, tool={}, headers={}, params={}", url, finalTool, headers,
104+
finalParams);
105+
106+
// Directly use the already initialized client
107+
CallToolResult result;
108+
try {
109+
McpSchema.CallToolRequest request = new McpSchema.CallToolRequest(finalTool, finalParams);
110+
log.info("[McpNode] CallToolRequest: {}", request);
111+
result = client.callTool(request);
112+
log.info("[McpNode] tool call result: {}", result);
113+
}
114+
catch (Exception e) {
115+
log.error("[McpNode] MCP call fail:", e);
116+
throw new McpNodeException("MCP call fail: " + e.getMessage(), e);
117+
}
118+
119+
// Result handling
120+
Map<String, Object> updatedState = new HashMap<>();
121+
// updatedState.put("mcp_result", result.content());
122+
updatedState.put("messages", result.content());
123+
if (StringUtils.hasLength(this.outputKey)) {
124+
Object content = result.content();
125+
if (content instanceof List<?> list && !CollectionUtils.isEmpty(list)) {
126+
Object first = list.get(0);
127+
// Compatible with the text field of TextContent
128+
if (first instanceof TextContent textContent) {
129+
updatedState.put(this.outputKey, textContent.text());
130+
}
131+
else if (first instanceof Map<?, ?> map && map.containsKey("text")) {
132+
updatedState.put(this.outputKey, map.get("text"));
133+
}
134+
else {
135+
updatedState.put(this.outputKey, first);
136+
}
137+
}
138+
else {
139+
updatedState.put(this.outputKey, content);
140+
}
141+
}
142+
log.info("[McpNode] update state: {}", updatedState);
143+
return updatedState;
144+
}
145+
146+
private String replaceVariables(String template, OverAllState state) {
147+
if (template == null)
148+
return null;
149+
Matcher matcher = VARIABLE_PATTERN.matcher(template);
150+
StringBuilder result = new StringBuilder();
151+
while (matcher.find()) {
152+
String key = matcher.group(1);
153+
Object value = state.value(key).orElse("");
154+
log.info("[McpNode] replace param: {} -> {}", key, value);
155+
matcher.appendReplacement(result, value.toString());
156+
}
157+
matcher.appendTail(result);
158+
return result.toString();
159+
}
160+
161+
private Map<String, Object> replaceVariablesObj(Map<String, Object> map, OverAllState state) {
162+
if (map == null)
163+
return null;
164+
Map<String, Object> result = new HashMap<>();
165+
map.forEach((k, v) -> {
166+
if (v instanceof String) {
167+
result.put(k, replaceVariables((String) v, state));
168+
}
169+
else {
170+
result.put(k, v);
171+
}
172+
});
173+
return result;
174+
}
175+
176+
public static Builder builder() {
177+
return new Builder();
178+
}
179+
180+
public static class Builder {
181+
182+
private String url;
183+
184+
private String tool;
185+
186+
private Map<String, String> headers = new HashMap<>();
187+
188+
private Map<String, Object> params = new HashMap<>();
189+
190+
private String outputKey;
191+
192+
private List<String> inputParamKeys;
193+
194+
public Builder url(String url) {
195+
this.url = url;
196+
return this;
197+
}
198+
199+
public Builder tool(String tool) {
200+
this.tool = tool;
201+
return this;
202+
}
203+
204+
public Builder header(String name, String value) {
205+
this.headers.put(name, value);
206+
return this;
207+
}
208+
209+
public Builder param(String name, Object value) {
210+
this.params.put(name, value);
211+
return this;
212+
}
213+
214+
public Builder outputKey(String outputKey) {
215+
this.outputKey = outputKey;
216+
return this;
217+
}
218+
219+
public Builder inputParamKeys(List<String> inputParamKeys) {
220+
this.inputParamKeys = inputParamKeys;
221+
return this;
222+
}
223+
224+
public McpNode build() {
225+
return new McpNode(this);
226+
}
227+
228+
}
229+
230+
public static class McpNodeException extends RuntimeException {
231+
232+
public McpNodeException(String message, Throwable cause) {
233+
super(message, cause);
234+
}
235+
236+
}
237+
238+
}

spring-ai-alibaba-graph/spring-ai-alibaba-graph-example/pom.xml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,17 @@
4848
<version>1.18.3</version>
4949
</dependency>
5050

51+
<dependency>
52+
<groupId>io.modelcontextprotocol.sdk</groupId>
53+
<artifactId>mcp</artifactId>
54+
<version>${mcp.version}</version>
55+
</dependency>
56+
57+
<dependency>
58+
<groupId>org.springframework.ai</groupId>
59+
<artifactId>spring-ai-starter-mcp-server-webmvc</artifactId>
60+
</dependency>
61+
5162
<dependency>
5263
<groupId>org.springframework.boot</groupId>
5364
<artifactId>spring-boot-starter-web</artifactId>
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/*
2+
* Copyright 2024-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.alibaba.cloud.ai.example.graph.mcp;
18+
19+
import com.alibaba.cloud.ai.graph.GraphRepresentation;
20+
import com.alibaba.cloud.ai.graph.OverAllState;
21+
import com.alibaba.cloud.ai.graph.OverAllStateFactory;
22+
import com.alibaba.cloud.ai.graph.StateGraph;
23+
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
24+
import com.alibaba.cloud.ai.graph.node.McpNode;
25+
import com.alibaba.cloud.ai.graph.state.strategy.ReplaceStrategy;
26+
import com.google.common.collect.Lists;
27+
import org.springframework.ai.tool.ToolCallbackProvider;
28+
import org.springframework.ai.tool.method.MethodToolCallbackProvider;
29+
import org.springframework.context.annotation.Bean;
30+
import org.springframework.context.annotation.Configuration;
31+
32+
import static com.alibaba.cloud.ai.graph.StateGraph.END;
33+
import static com.alibaba.cloud.ai.graph.StateGraph.START;
34+
import static com.alibaba.cloud.ai.graph.action.AsyncNodeAction.node_async;
35+
36+
@Configuration
37+
public class McpAutoConfiguration {
38+
39+
@Bean
40+
public ToolCallbackProvider weatherTools(OpenMeteoService openMeteoService) {
41+
return MethodToolCallbackProvider.builder().toolObjects(openMeteoService).build();
42+
}
43+
44+
@Bean
45+
public StateGraph mcpGraph() throws GraphStateException {
46+
47+
OverAllStateFactory stateFactory = () -> {
48+
OverAllState state = new OverAllState();
49+
state.registerKeyAndStrategy("latitude", new ReplaceStrategy());
50+
state.registerKeyAndStrategy("longitude", new ReplaceStrategy());
51+
state.registerKeyAndStrategy("mcp_result", new ReplaceStrategy());
52+
return state;
53+
};
54+
55+
// 示例:添加 MCP Node
56+
McpNode mcpNode = McpNode.builder()
57+
.url("http://localhost:18080/sse") // MCP Server SSE 地址
58+
.tool("getWeatherForecastByLocation") // MCP 工具名(需根据实际 MCP Server 配置)
59+
.inputParamKeys(Lists.newArrayList("latitude", "longitude")) // 输入参数键
60+
// .param("latitude",39.9042) // 工具参数
61+
// .param("longitude",116.4074) // 工具参数
62+
63+
.header("clientId", "111222") // 可选:添加请求头
64+
.outputKey("mcp_result")
65+
.build();
66+
67+
StateGraph stateGraph = new StateGraph(stateFactory).addNode("mcp_node", node_async(mcpNode))
68+
.addEdge(START, "mcp_node")
69+
.addEdge("mcp_node", END);
70+
71+
GraphRepresentation graphRepresentation = stateGraph.getGraph(GraphRepresentation.Type.PLANTUML, "mcp graph");
72+
73+
System.out.println("\n\n");
74+
System.out.println(graphRepresentation.content());
75+
System.out.println("\n\n");
76+
77+
return stateGraph;
78+
}
79+
80+
}

0 commit comments

Comments
 (0)