16
16
17
17
package com .alibaba .cloud .ai .a2a .server ;
18
18
19
+ import java .util .Collections ;
19
20
import java .util .List ;
20
21
import java .util .Map ;
22
+ import java .util .Set ;
21
23
import java .util .UUID ;
24
+ import java .util .concurrent .TimeUnit ;
25
+ import java .util .concurrent .atomic .AtomicInteger ;
26
+ import java .util .function .Consumer ;
22
27
28
+ import com .alibaba .cloud .ai .graph .NodeOutput ;
23
29
import com .alibaba .cloud .ai .graph .agent .BaseAgent ;
30
+ import com .alibaba .cloud .ai .graph .async .AsyncGenerator ;
31
+ import com .alibaba .cloud .ai .graph .exception .GraphRunnerException ;
32
+ import com .alibaba .cloud .ai .graph .exception .GraphStateException ;
33
+ import com .alibaba .cloud .ai .graph .streaming .StreamingOutput ;
34
+ import com .alibaba .fastjson .JSON ;
35
+ import com .alibaba .fastjson .JSONObject ;
24
36
import io .a2a .A2A ;
25
37
import io .a2a .server .agentexecution .AgentExecutor ;
26
38
import io .a2a .server .agentexecution .RequestContext ;
27
39
import io .a2a .server .events .EventQueue ;
28
40
import io .a2a .server .tasks .TaskUpdater ;
29
41
import io .a2a .spec .JSONRPCError ;
30
42
import io .a2a .spec .Message ;
43
+ import io .a2a .spec .MessageSendParams ;
31
44
import io .a2a .spec .Part ;
32
45
import io .a2a .spec .Task ;
33
46
import io .a2a .spec .TaskState ;
37
50
import org .slf4j .LoggerFactory ;
38
51
39
52
import org .springframework .ai .chat .messages .UserMessage ;
53
+ import org .springframework .util .StringUtils ;
40
54
41
55
public class GraphAgentExecutor implements AgentExecutor {
42
56
43
57
private static final Logger LOGGER = LoggerFactory .getLogger (GraphAgentExecutor .class );
44
58
59
+ private static final Set <String > IGNORE_NODE_TYPE = Set .of ("preLlm" , "postLlm" , "preTool" , "tool" , "postTool" );
60
+
61
+ public static final String STREAMING_METADATA_KEY = "isStreaming" ;
62
+
45
63
private final BaseAgent executeAgent ;
46
64
47
65
public GraphAgentExecutor (BaseAgent executeAgent ) {
@@ -72,28 +90,11 @@ public void execute(RequestContext context, EventQueue eventQueue) throws JSONRP
72
90
}
73
91
// TODO adapter for all agent type, now only support react agent
74
92
Map <String , Object > input = Map .of ("messages" , List .of (new UserMessage (sb .toString ().trim ())));
75
- var result = executeAgent .invoke (input );
76
- String outputText = result .get ().data ().containsKey (executeAgent .outputKey ())
77
- ? String .valueOf (result .get ().data ().get (executeAgent .outputKey ())) : "No output key in result." ;
78
-
79
- Task task = context .getTask ();
80
- if (task == null ) {
81
- task = newTask (context .getMessage ());
82
- eventQueue .enqueueEvent (task );
83
- }
84
- TaskUpdater taskUpdater = new TaskUpdater (context , eventQueue );
85
- boolean taskComplete = true ;
86
- boolean requireUserInput = false ;
87
- if (!taskComplete && !requireUserInput ) {
88
- taskUpdater .startWork (taskUpdater .newAgentMessage (List .of (new TextPart (outputText )), Map .of ()));
89
- }
90
- else if (requireUserInput ) {
91
- taskUpdater .startWork (taskUpdater .newAgentMessage (List .of (new TextPart (outputText )), Map .of ()));
93
+ if (isStreamRequest (context )) {
94
+ executeStreamTask (input , context , eventQueue );
92
95
}
93
96
else {
94
- taskUpdater .addArtifact (List .of (new TextPart (outputText )), UUID .randomUUID ().toString (),
95
- "conversation_result" , Map .of ("output" , outputText ));
96
- taskUpdater .complete ();
97
+ executeForNonStreamTask (input , context , eventQueue );
97
98
}
98
99
}
99
100
catch (Exception e ) {
@@ -106,4 +107,109 @@ else if (requireUserInput) {
106
107
public void cancel (RequestContext context , EventQueue eventQueue ) throws JSONRPCError {
107
108
}
108
109
110
+ private boolean isStreamRequest (RequestContext context ) {
111
+ MessageSendParams params = context .getParams ();
112
+ if (null == params .metadata ()) {
113
+ return false ;
114
+ }
115
+ if (!params .metadata ().containsKey (STREAMING_METADATA_KEY )) {
116
+ return false ;
117
+ }
118
+ return (boolean ) params .metadata ().get (STREAMING_METADATA_KEY );
119
+ }
120
+
121
+ private void executeStreamTask (Map <String , Object > input , RequestContext context , EventQueue eventQueue )
122
+ throws GraphStateException , GraphRunnerException {
123
+ AsyncGenerator <NodeOutput > generator = executeAgent .stream (input );
124
+ Task task = context .getTask ();
125
+ if (task == null ) {
126
+ task = newTask (context .getMessage ());
127
+ eventQueue .enqueueEvent (task );
128
+ }
129
+ TaskUpdater taskUpdater = new TaskUpdater (context , eventQueue );
130
+ taskUpdater .submit ();
131
+ generator .forEachAsync (new ReactAgentNodeOutputConsumer (taskUpdater )).thenAccept (o -> taskUpdater .complete ());
132
+ waitTaskCompleted (task );
133
+ }
134
+
135
+ private void executeForNonStreamTask (Map <String , Object > input , RequestContext context , EventQueue eventQueue )
136
+ throws GraphStateException , GraphRunnerException {
137
+ var result = executeAgent .invoke (input );
138
+ String outputText = result .get ().data ().containsKey (executeAgent .outputKey ())
139
+ ? String .valueOf (result .get ().data ().get (executeAgent .outputKey ())) : "No output key in result." ;
140
+
141
+ Task task = context .getTask ();
142
+ if (task == null ) {
143
+ task = newTask (context .getMessage ());
144
+ eventQueue .enqueueEvent (task );
145
+ }
146
+ TaskUpdater taskUpdater = new TaskUpdater (context , eventQueue );
147
+ boolean taskComplete = true ;
148
+ boolean requireUserInput = false ;
149
+ if (!taskComplete && !requireUserInput ) {
150
+ taskUpdater .startWork (taskUpdater .newAgentMessage (List .of (new TextPart (outputText )), Map .of ()));
151
+ }
152
+ else if (requireUserInput ) {
153
+ taskUpdater .startWork (taskUpdater .newAgentMessage (List .of (new TextPart (outputText )), Map .of ()));
154
+ }
155
+ else {
156
+ taskUpdater .addArtifact (List .of (new TextPart (outputText )), UUID .randomUUID ().toString (),
157
+ "conversation_result" , Map .of ("output" , outputText ));
158
+ taskUpdater .complete ();
159
+ }
160
+ }
161
+
162
+ private void waitTaskCompleted (Task task ) {
163
+ while (!task .getStatus ().state ().equals (TaskState .COMPLETED )
164
+ && !task .getStatus ().state ().equals (TaskState .CANCELED )) {
165
+ try {
166
+ TimeUnit .SECONDS .sleep (1 );
167
+ }
168
+ catch (InterruptedException ignored ) {
169
+ }
170
+ }
171
+ }
172
+
173
+ private static class ReactAgentNodeOutputConsumer implements Consumer <NodeOutput > {
174
+
175
+ private final TaskUpdater taskUpdater ;
176
+
177
+ private final AtomicInteger artifactNum ;
178
+
179
+ private ReactAgentNodeOutputConsumer (TaskUpdater taskUpdater ) {
180
+ this .taskUpdater = taskUpdater ;
181
+ this .artifactNum = new AtomicInteger ();
182
+ }
183
+
184
+ @ Override
185
+ public void accept (NodeOutput nodeOutput ) {
186
+ if (nodeOutput .isSTART () || nodeOutput .isEND () || IGNORE_NODE_TYPE .contains (nodeOutput .node ())) {
187
+ if (LOGGER .isDebugEnabled ()) {
188
+ LOGGER .debug ("Agent parts output: {}" , buildDebugDetailInfo (nodeOutput ));
189
+ }
190
+ return ;
191
+ }
192
+
193
+ String content = "" ;
194
+ if (nodeOutput instanceof StreamingOutput ) {
195
+ content = ((StreamingOutput ) nodeOutput ).chunk ();
196
+ }
197
+
198
+ if (!StringUtils .hasLength (content )) {
199
+ return ;
200
+ }
201
+
202
+ taskUpdater .addArtifact (Collections .singletonList (new TextPart (content )), null ,
203
+ String .valueOf (artifactNum .incrementAndGet ()), Map .of ());
204
+ }
205
+
206
+ private String buildDebugDetailInfo (NodeOutput nodeOutput ) {
207
+ JSONObject outputJson = new JSONObject ();
208
+ outputJson .put ("data" , nodeOutput .state ().data ());
209
+ outputJson .put ("node" , nodeOutput .node ());
210
+ return JSON .toJSONString (outputJson );
211
+ }
212
+
213
+ }
214
+
109
215
}
0 commit comments