@@ -177,6 +177,10 @@ private static Object doImplement(AiServiceMethodCreateInfo methodCreateInfo, Ob
177177 Map <String , Object > templateVariables = getTemplateVariables (methodArgs , methodCreateInfo .getUserMessageInfo ());
178178
179179 Type returnType = methodCreateInfo .getReturnType ();
180+ boolean isMulti = TypeUtil .isMulti (returnType );
181+
182+ final boolean isStringMulti = (isMulti && returnType instanceof ParameterizedType
183+ && TypeUtil .isTypeOf (((ParameterizedType ) returnType ).getActualTypeArguments ()[0 ], String .class ));
180184 if (TypeUtil .isImage (returnType ) || TypeUtil .isResultImage (returnType )) {
181185 return doImplementGenerateImage (methodCreateInfo , context , systemMessage , userMessage , memoryId , returnType ,
182186 templateVariables , auditSourceInfo );
@@ -217,7 +221,7 @@ private static Object doImplement(AiServiceMethodCreateInfo methodCreateInfo, Ob
217221 Metadata metadata = Metadata .from (userMessage , memoryId , chatMemory );
218222 AugmentationRequest augmentationRequest = new AugmentationRequest (userMessage , metadata );
219223
220- if (!TypeUtil . isMulti ( returnType ) ) {
224+ if (!isMulti ) {
221225 augmentationResult = context .retrievalAugmentor .augment (augmentationRequest );
222226 userMessage = (UserMessage ) augmentationResult .chatMessage ();
223227 } else {
@@ -244,10 +248,18 @@ public Flow.Publisher<?> apply(AugmentationResult ar) {
244248 var stream = new TokenStreamMulti (messagesToSend , effectiveToolSpecifications ,
245249 finalToolExecutors , ar .contents (), context , memoryId ,
246250 methodCreateInfo .isSwitchToWorkerThreadForToolExecution (), isRunningOnWorkerThread );
247- return stream .plug (m -> ResponseAugmenterSupport .apply (m , methodCreateInfo ,
248- new ResponseAugmenterParams ((UserMessage ) augmentedUserMessage ,
249- memory , ar , methodCreateInfo .getUserMessageTemplate (),
250- templateVariables )));
251+ return stream
252+ .filter (event -> {
253+ return !isStringMulti || event instanceof ChatEvent .PartialResponseEvent ;
254+ }).map (event -> {
255+ if (isStringMulti && event instanceof ChatEvent .PartialResponseEvent ) {
256+ return ((ChatEvent .PartialResponseEvent ) event ).getChunk ();
257+ }
258+ return event ;
259+ }).plug (m -> ResponseAugmenterSupport .apply (m , methodCreateInfo ,
260+ new ResponseAugmenterParams ((UserMessage ) augmentedUserMessage ,
261+ memory , ar , methodCreateInfo .getUserMessageTemplate (),
262+ templateVariables )));
251263 }
252264
253265 private List <ChatMessage > messagesToSend (UserMessage augmentedUserMessage ,
@@ -297,13 +309,20 @@ private List<ChatMessage> messagesToSend(UserMessage augmentedUserMessage,
297309
298310 var actualAugmentationResult = augmentationResult ;
299311 var actualUserMessage = userMessage ;
300- if (TypeUtil . isMulti ( returnType ) ) {
312+ if (isMulti ) {
301313 chatMemory .commit (); // for streaming cases, we really have to commit because all alternatives are worse
302314 if (methodCreateInfo .getOutputGuardrailsClassNames ().isEmpty ()) {
303315 var stream = new TokenStreamMulti (messagesToSend , toolSpecifications , toolExecutors ,
304316 (augmentationResult != null ? augmentationResult .contents () : null ), context , memoryId ,
305317 methodCreateInfo .isSwitchToWorkerThreadForToolExecution (), isRunningOnWorkerThread );
306- return stream .plug (m -> ResponseAugmenterSupport .apply (m , methodCreateInfo ,
318+ return stream .filter (event -> {
319+ return !isStringMulti || event instanceof ChatEvent .PartialResponseEvent ;
320+ }).map (event -> {
321+ if (isStringMulti && event instanceof ChatEvent .PartialResponseEvent ) {
322+ return ((ChatEvent .PartialResponseEvent ) event ).getChunk ();
323+ }
324+ return event ;
325+ }).plug (m -> ResponseAugmenterSupport .apply (m , methodCreateInfo ,
307326 new ResponseAugmenterParams (actualUserMessage ,
308327 chatMemory , actualAugmentationResult , methodCreateInfo .getUserMessageTemplate (),
309328 Collections .unmodifiableMap (templateVariables ))));
@@ -317,7 +336,8 @@ private List<ChatMessage> messagesToSend(UserMessage augmentedUserMessage,
317336 OutputGuardrailResult result ;
318337 try {
319338 result = GuardrailsSupport .invokeOutputGuardrailsForStream (methodCreateInfo ,
320- new OutputGuardrailParams (AiMessage .from (chunk ), chatMemory , actualAugmentationResult ,
339+ new OutputGuardrailParams (AiMessage .from (chunk .getMessage ()), chatMemory ,
340+ actualAugmentationResult ,
321341 methodCreateInfo .getUserMessageTemplate (),
322342 Collections .unmodifiableMap (templateVariables )),
323343 beanManager , auditSourceInfo );
@@ -340,6 +360,9 @@ private List<ChatMessage> messagesToSend(UserMessage augmentedUserMessage,
340360 throw new GuardrailException (
341361 "Attempting to rewrite the LLM output while streaming is not allowed" );
342362 }
363+ if (isStringMulti ) {
364+ return chunk .getMessage ();
365+ }
343366 return chunk ;
344367 }
345368 })
@@ -915,7 +938,7 @@ public interface Wrapper {
915938 Object wrap (Input input , Function <Input , Object > fun );
916939 }
917940
918- private static class TokenStreamMulti extends AbstractMulti <String > implements Multi <String > {
941+ private static class TokenStreamMulti extends AbstractMulti <ChatEvent > implements Multi <ChatEvent > {
919942 private final List <ChatMessage > messagesToSend ;
920943 private final List <ToolSpecification > toolSpecifications ;
921944 private final Map <String , ToolExecutor > toolsExecutors ;
@@ -941,14 +964,14 @@ public TokenStreamMulti(List<ChatMessage> messagesToSend, List<ToolSpecification
941964 }
942965
// Subscription entry point: creates a new UnicastProcessor for this call, subscribes the given
// subscriber to that processor, and then passes the processor to createTokenStream.
// In this diff view the removed/added line pairs below change the element type of both the
// subscriber bound and the processor from String to ChatEvent.
943966 @ Override
944- public void subscribe (MultiSubscriber <? super String > subscriber ) {
945- UnicastProcessor <String > processor = UnicastProcessor .create ();
967+ public void subscribe (MultiSubscriber <? super ChatEvent > subscriber ) {
968+ UnicastProcessor <ChatEvent > processor = UnicastProcessor .create ();
946969 processor .subscribe (subscriber );
947970
// NOTE(review): the subscriber is attached before createTokenStream runs; presumably this
// ordering ensures no event pushed into the processor is missed — confirm.
948971 createTokenStream (processor );
949972 }
950973
951- private void createTokenStream (UnicastProcessor <String > processor ) {
974+ private void createTokenStream (UnicastProcessor <ChatEvent > processor ) {
952975 Context ctxt = null ;
953976 if (switchToWorkerThreadForToolExecution || isCallerRunningOnWorkerThread ) {
954977 // we create or retrieve the current context, to use `executeBlocking` when required.
@@ -959,8 +982,18 @@ private void createTokenStream(UnicastProcessor<String> processor) {
959982 toolsExecutors , contents , context , memoryId , ctxt , switchToWorkerThreadForToolExecution ,
960983 isCallerRunningOnWorkerThread );
961984 TokenStream tokenStream = stream
962- .onPartialResponse (processor ::onNext )
963- .onCompleteResponse (message -> processor .onComplete ())
985+ .onPartialResponse (chunk -> processor
986+ .onNext (new ChatEvent .PartialResponseEvent (chunk )))
987+ .onCompleteResponse (message -> {
988+ processor .onNext (new ChatEvent .ChatCompletedEvent (message ));
989+ processor .onComplete ();
990+ })
991+ .onRetrieved (content -> {
992+ processor .onNext (new ChatEvent .ContentFetchedEvent (content ));
993+ })
994+ .onToolExecuted (execution -> {
995+ processor .onNext (new ChatEvent .ToolExecutedEvent (execution ));
996+ })
964997 .onError (processor ::onError );
965998 // This is equivalent to "run subscription on worker thread"
966999 if (switchToWorkerThreadForToolExecution && Context .isOnEventLoopThread ()) {
0 commit comments