Skip to content

Commit 77f033a

Browse files
committed
feat(a2a): a2a server support stream.
1 parent ac5c090 commit 77f033a

File tree

4 files changed

+179
-28
lines changed

4 files changed

+179
-28
lines changed

spring-ai-alibaba-a2a/spring-ai-alibaba-a2a-common/src/main/java/com/alibaba/cloud/ai/a2a/route/JsonRpcA2aRouterProvider.java

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,16 @@
1616

1717
package com.alibaba.cloud.ai.a2a.route;
1818

19+
import java.io.IOException;
1920
import java.time.Duration;
20-
import java.util.concurrent.Flow;
21+
import java.util.function.Consumer;
2122

2223
import com.alibaba.cloud.ai.a2a.server.JsonRpcA2aRequestHandler;
24+
import io.a2a.spec.JSONRPCResponse;
25+
import io.a2a.util.Utils;
2326
import org.slf4j.Logger;
2427
import org.slf4j.LoggerFactory;
28+
import reactor.core.publisher.Flux;
2529

2630
import org.springframework.http.HttpStatus;
2731
import org.springframework.web.servlet.function.HandlerFunction;
@@ -96,8 +100,8 @@ public ServerResponse handle(ServerRequest request) throws Exception {
96100
try {
97101
String bodyString = request.body(String.class);
98102
Object result = a2aRequestHandler.onHandler(bodyString, request.headers());
99-
if (result instanceof Flow.Publisher) {
100-
return buildSseResponse(result);
103+
if (result instanceof Flux<?>) {
104+
return buildSseResponse((Flux<?>) result);
101105
}
102106
else {
103107
return buildJsonRpcResponse(result);
@@ -113,12 +117,27 @@ private ServerResponse buildJsonRpcResponse(Object result) {
113117
return ServerResponse.ok().body(result);
114118
}
115119

116-
private ServerResponse buildSseResponse(Object result) {
117-
// TODO
118-
return ServerResponse.sse((sseBuilder) -> {
120+
private ServerResponse buildSseResponse(Flux<?> result) {
121+
return ServerResponse.sse(sseBuilder -> {
119122
sseBuilder.onComplete(() -> {
123+
log.debug("Agent SSE connection completed.");
120124
});
121125
sseBuilder.onTimeout(() -> {
126+
log.debug("Agent SSE connection timeout.");
127+
});
128+
result.subscribe((Consumer<Object>) o -> {
129+
if (o instanceof JSONRPCResponse) {
130+
try {
131+
String sseBody = Utils.OBJECT_MAPPER.writeValueAsString(o);
132+
if (log.isDebugEnabled()) {
133+
log.debug("send sse body to agent: {}", sseBody);
134+
}
135+
sseBuilder.data(sseBody);
136+
}
137+
catch (IOException e) {
138+
sseBuilder.error(e);
139+
}
140+
}
122141
});
123142
}, Duration.ZERO);
124143
}

spring-ai-alibaba-a2a/spring-ai-alibaba-a2a-common/src/main/java/com/alibaba/cloud/ai/a2a/server/GraphAgentExecutor.java

Lines changed: 126 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,31 @@
1616

1717
package com.alibaba.cloud.ai.a2a.server;
1818

19+
import java.util.Collections;
1920
import java.util.List;
2021
import java.util.Map;
22+
import java.util.Set;
2123
import java.util.UUID;
24+
import java.util.concurrent.TimeUnit;
25+
import java.util.concurrent.atomic.AtomicInteger;
26+
import java.util.function.Consumer;
2227

28+
import com.alibaba.cloud.ai.graph.NodeOutput;
2329
import com.alibaba.cloud.ai.graph.agent.BaseAgent;
30+
import com.alibaba.cloud.ai.graph.async.AsyncGenerator;
31+
import com.alibaba.cloud.ai.graph.exception.GraphRunnerException;
32+
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
33+
import com.alibaba.cloud.ai.graph.streaming.StreamingOutput;
34+
import com.alibaba.fastjson.JSON;
35+
import com.alibaba.fastjson.JSONObject;
2436
import io.a2a.A2A;
2537
import io.a2a.server.agentexecution.AgentExecutor;
2638
import io.a2a.server.agentexecution.RequestContext;
2739
import io.a2a.server.events.EventQueue;
2840
import io.a2a.server.tasks.TaskUpdater;
2941
import io.a2a.spec.JSONRPCError;
3042
import io.a2a.spec.Message;
43+
import io.a2a.spec.MessageSendParams;
3144
import io.a2a.spec.Part;
3245
import io.a2a.spec.Task;
3346
import io.a2a.spec.TaskState;
@@ -37,11 +50,16 @@
3750
import org.slf4j.LoggerFactory;
3851

3952
import org.springframework.ai.chat.messages.UserMessage;
53+
import org.springframework.util.StringUtils;
4054

4155
public class GraphAgentExecutor implements AgentExecutor {
4256

4357
private static final Logger LOGGER = LoggerFactory.getLogger(GraphAgentExecutor.class);
4458

59+
private static final Set<String> IGNORE_NODE_TYPE = Set.of("preLlm", "postLlm", "preTool", "tool", "postTool");
60+
61+
public static final String STREAMING_METADATA_KEY = "isStreaming";
62+
4563
private final BaseAgent executeAgent;
4664

4765
public GraphAgentExecutor(BaseAgent executeAgent) {
@@ -72,28 +90,11 @@ public void execute(RequestContext context, EventQueue eventQueue) throws JSONRP
7290
}
7391
// TODO adapter for all agent type, now only support react agent
7492
Map<String, Object> input = Map.of("messages", List.of(new UserMessage(sb.toString().trim())));
75-
var result = executeAgent.invoke(input);
76-
String outputText = result.get().data().containsKey(executeAgent.outputKey())
77-
? String.valueOf(result.get().data().get(executeAgent.outputKey())) : "No output key in result.";
78-
79-
Task task = context.getTask();
80-
if (task == null) {
81-
task = newTask(context.getMessage());
82-
eventQueue.enqueueEvent(task);
83-
}
84-
TaskUpdater taskUpdater = new TaskUpdater(context, eventQueue);
85-
boolean taskComplete = true;
86-
boolean requireUserInput = false;
87-
if (!taskComplete && !requireUserInput) {
88-
taskUpdater.startWork(taskUpdater.newAgentMessage(List.of(new TextPart(outputText)), Map.of()));
89-
}
90-
else if (requireUserInput) {
91-
taskUpdater.startWork(taskUpdater.newAgentMessage(List.of(new TextPart(outputText)), Map.of()));
93+
if (isStreamRequest(context)) {
94+
executeStreamTask(input, context, eventQueue);
9295
}
9396
else {
94-
taskUpdater.addArtifact(List.of(new TextPart(outputText)), UUID.randomUUID().toString(),
95-
"conversation_result", Map.of("output", outputText));
96-
taskUpdater.complete();
97+
executeForNonStreamTask(input, context, eventQueue);
9798
}
9899
}
99100
catch (Exception e) {
@@ -106,4 +107,109 @@ else if (requireUserInput) {
106107
public void cancel(RequestContext context, EventQueue eventQueue) throws JSONRPCError {
107108
}
108109

110+
private boolean isStreamRequest(RequestContext context) {
111+
MessageSendParams params = context.getParams();
112+
if (null == params.metadata()) {
113+
return false;
114+
}
115+
if (!params.metadata().containsKey(STREAMING_METADATA_KEY)) {
116+
return false;
117+
}
118+
return (boolean) params.metadata().get(STREAMING_METADATA_KEY);
119+
}
120+
121+
private void executeStreamTask(Map<String, Object> input, RequestContext context, EventQueue eventQueue)
122+
throws GraphStateException, GraphRunnerException {
123+
AsyncGenerator<NodeOutput> generator = executeAgent.stream(input);
124+
Task task = context.getTask();
125+
if (task == null) {
126+
task = newTask(context.getMessage());
127+
eventQueue.enqueueEvent(task);
128+
}
129+
TaskUpdater taskUpdater = new TaskUpdater(context, eventQueue);
130+
taskUpdater.submit();
131+
generator.forEachAsync(new ReactAgentNodeOutputConsumer(taskUpdater)).thenAccept(o -> taskUpdater.complete());
132+
waitTaskCompleted(task);
133+
}
134+
135+
private void executeForNonStreamTask(Map<String, Object> input, RequestContext context, EventQueue eventQueue)
136+
throws GraphStateException, GraphRunnerException {
137+
var result = executeAgent.invoke(input);
138+
String outputText = result.get().data().containsKey(executeAgent.outputKey())
139+
? String.valueOf(result.get().data().get(executeAgent.outputKey())) : "No output key in result.";
140+
141+
Task task = context.getTask();
142+
if (task == null) {
143+
task = newTask(context.getMessage());
144+
eventQueue.enqueueEvent(task);
145+
}
146+
TaskUpdater taskUpdater = new TaskUpdater(context, eventQueue);
147+
boolean taskComplete = true;
148+
boolean requireUserInput = false;
149+
if (!taskComplete && !requireUserInput) {
150+
taskUpdater.startWork(taskUpdater.newAgentMessage(List.of(new TextPart(outputText)), Map.of()));
151+
}
152+
else if (requireUserInput) {
153+
taskUpdater.startWork(taskUpdater.newAgentMessage(List.of(new TextPart(outputText)), Map.of()));
154+
}
155+
else {
156+
taskUpdater.addArtifact(List.of(new TextPart(outputText)), UUID.randomUUID().toString(),
157+
"conversation_result", Map.of("output", outputText));
158+
taskUpdater.complete();
159+
}
160+
}
161+
162+
private void waitTaskCompleted(Task task) {
163+
while (!task.getStatus().state().equals(TaskState.COMPLETED)
164+
&& !task.getStatus().state().equals(TaskState.CANCELED)) {
165+
try {
166+
TimeUnit.SECONDS.sleep(1);
167+
}
168+
catch (InterruptedException ignored) {
169+
}
170+
}
171+
}
172+
173+
private static class ReactAgentNodeOutputConsumer implements Consumer<NodeOutput> {
174+
175+
private final TaskUpdater taskUpdater;
176+
177+
private final AtomicInteger artifactNum;
178+
179+
private ReactAgentNodeOutputConsumer(TaskUpdater taskUpdater) {
180+
this.taskUpdater = taskUpdater;
181+
this.artifactNum = new AtomicInteger();
182+
}
183+
184+
@Override
185+
public void accept(NodeOutput nodeOutput) {
186+
if (nodeOutput.isSTART() || nodeOutput.isEND() || IGNORE_NODE_TYPE.contains(nodeOutput.node())) {
187+
if (LOGGER.isDebugEnabled()) {
188+
LOGGER.debug("Agent parts output: {}", buildDebugDetailInfo(nodeOutput));
189+
}
190+
return;
191+
}
192+
193+
String content = "";
194+
if (nodeOutput instanceof StreamingOutput) {
195+
content = ((StreamingOutput) nodeOutput).chunk();
196+
}
197+
198+
if (!StringUtils.hasLength(content)) {
199+
return;
200+
}
201+
202+
taskUpdater.addArtifact(Collections.singletonList(new TextPart(content)), null,
203+
String.valueOf(artifactNum.incrementAndGet()), Map.of());
204+
}
205+
206+
private String buildDebugDetailInfo(NodeOutput nodeOutput) {
207+
JSONObject outputJson = new JSONObject();
208+
outputJson.put("data", nodeOutput.state().data());
209+
outputJson.put("node", nodeOutput.node());
210+
return JSON.toJSONString(outputJson);
211+
}
212+
213+
}
214+
109215
}

spring-ai-alibaba-a2a/spring-ai-alibaba-a2a-common/src/main/java/com/alibaba/cloud/ai/a2a/server/JsonRpcA2aRequestHandler.java

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package com.alibaba.cloud.ai.a2a.server;
1818

1919
import java.time.Duration;
20+
import java.util.Map;
2021
import java.util.concurrent.Flow;
2122
import java.util.function.Function;
2223

@@ -34,6 +35,7 @@
3435
import io.a2a.spec.JSONRPCRequest;
3536
import io.a2a.spec.JSONRPCResponse;
3637
import io.a2a.spec.ListTaskPushNotificationConfigRequest;
38+
import io.a2a.spec.MessageSendParams;
3739
import io.a2a.spec.NonStreamingJSONRPCRequest;
3840
import io.a2a.spec.SendMessageRequest;
3941
import io.a2a.spec.SendStreamingMessageRequest;
@@ -98,7 +100,12 @@ private Flux<?> handleStreamRequest(String body) throws JsonProcessingException
98100
StreamingJSONRPCRequest<?> request = Utils.OBJECT_MAPPER.readValue(body, StreamingJSONRPCRequest.class);
99101
Flow.Publisher<? extends JSONRPCResponse<?>> publisher;
100102
if (request instanceof SendStreamingMessageRequest req) {
101-
publisher = jsonRpcHandler.onMessageSendStream(req);
103+
SendStreamingMessageRequest.Builder newReqBuilder = new SendStreamingMessageRequest.Builder()
104+
.id(req.getId())
105+
.jsonrpc(req.getJsonrpc())
106+
.method(req.getMethod())
107+
.params(injectStreamMetadata(req.getParams(), true));
108+
publisher = jsonRpcHandler.onMessageSendStream(newReqBuilder.build());
102109
LOGGER.info("get Stream publisher {}", publisher);
103110
}
104111
else if (request instanceof TaskResubscriptionRequest req) {
@@ -118,7 +125,11 @@ private JSONRPCResponse<?> handleNonStreamRequest(String body) throws JsonProces
118125
return jsonRpcHandler.onGetTask(req);
119126
}
120127
else if (request instanceof SendMessageRequest req) {
121-
return jsonRpcHandler.onMessageSend(req);
128+
SendMessageRequest.Builder newReqBuilder = new SendMessageRequest.Builder().id(req.getId())
129+
.jsonrpc(req.getJsonrpc())
130+
.method(req.getMethod())
131+
.params(injectStreamMetadata(req.getParams(), false));
132+
return jsonRpcHandler.onMessageSend(newReqBuilder.build());
122133
}
123134
else if (request instanceof CancelTaskRequest req) {
124135
return jsonRpcHandler.onCancelTask(req);
@@ -144,4 +155,18 @@ private static JSONRPCErrorResponse generateErrorResponse(JSONRPCRequest<?> requ
144155
return new JSONRPCErrorResponse(request.getId(), error);
145156
}
146157

158+
private MessageSendParams injectStreamMetadata(MessageSendParams original, boolean isStreaming) {
159+
if (null == original.metadata()) {
160+
MessageSendParams.Builder newBuilder = new MessageSendParams.Builder();
161+
newBuilder.configuration(original.configuration());
162+
newBuilder.metadata(Map.of(GraphAgentExecutor.STREAMING_METADATA_KEY, isStreaming));
163+
newBuilder.message(original.message());
164+
return newBuilder.build();
165+
}
166+
else {
167+
original.metadata().put(GraphAgentExecutor.STREAMING_METADATA_KEY, isStreaming);
168+
return original;
169+
}
170+
}
171+
147172
}

spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/agent/a2a/A2aNode.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ private AsyncGenerator<NodeOutput> createStreamingGenerator(OverAllState state)
105105
try (CloseableHttpClient httpClient = HttpClients.createDefault()) {
106106
HttpPost post = new HttpPost(baseUrl);
107107
post.setHeader("Content-Type", "application/json");
108+
post.setHeader("Accept", "text/event-stream");
108109
post.setEntity(new StringEntity(requestPayload, ContentType.APPLICATION_JSON));
109110

110111
try (CloseableHttpResponse response = httpClient.execute(post)) {

0 commit comments

Comments
 (0)