Skip to content

Commit bf7751a

Browse files
committed
Allow returning a ChatResponse Multi in streaming mode, otherwise we can't get the metadata and token usage and associate it with a specific request.
Prevent duplicate final message when using Multi<String> return value Allow using a new ChatEvent producing multi so consumer can get a stream of all chat events. Replace start imports. sort imports Start fixing failing tests Don't use multi with context, as it prevents an exception to propagate when accumulator throws Formatting sort imports Fix more tests Add javadoc to ChatEvent and update the guide on streaming response to indicate possible usage of ChatEvent formatting
1 parent ac78650 commit bf7751a

File tree

6 files changed

+332
-26
lines changed

6 files changed

+332
-26
lines changed

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -693,14 +693,15 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
693693
if (!DotNames.MULTI.equals(method.returnType().name())) {
694694
continue;
695695
}
696-
boolean isMultiString = false;
696+
boolean isSupportedResponseType = false;
697697
if (method.returnType().kind() == Type.Kind.PARAMETERIZED_TYPE) {
698698
Type multiType = method.returnType().asParameterizedType().arguments().get(0);
699-
if (DotNames.STRING.equals(multiType.name())) {
700-
isMultiString = true;
699+
if (DotNames.STRING.equals(multiType.name())
700+
|| DotNames.CHAT_EVENT.equals(multiType.name())) {
701+
isSupportedResponseType = true;
701702
}
702703
}
703-
if (!isMultiString) {
704+
if (!isSupportedResponseType) {
704705
throw illegalConfiguration("Only Multi<String> is supported as a Multi return type. Offending method is '"
705706
+ method.declaringClass().name().toString() + "#" + method.name() + "'");
706707
}

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DotNames.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailAccumulator;
1818
import io.quarkiverse.langchain4j.response.AiResponseAugmenter;
1919
import io.quarkiverse.langchain4j.response.ResponseAugmenter;
20+
import io.quarkiverse.langchain4j.runtime.aiservice.ChatEvent;
2021
import io.smallrye.common.annotation.Blocking;
2122
import io.smallrye.common.annotation.NonBlocking;
2223
import io.smallrye.common.annotation.RunOnVirtualThread;
@@ -62,6 +63,7 @@ public class DotNames {
6263
public static final DotName CHAT_MODEL_LISTENER = DotName.createSimple(ChatModelListener.class);
6364
public static final DotName MODEL_AUTH_PROVIDER = DotName.createSimple(ModelAuthProvider.class);
6465
public static final DotName TOOL = DotName.createSimple(Tool.class);
66+
public static final DotName CHAT_EVENT = DotName.createSimple(ChatEvent.class);
6567

6668
public static final DotName REGISTER_REST_CLIENT = DotName.createSimple(RegisterRestClient.class);
6769

core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java

Lines changed: 47 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,10 @@ private static Object doImplement(AiServiceMethodCreateInfo methodCreateInfo, Ob
177177
Map<String, Object> templateVariables = getTemplateVariables(methodArgs, methodCreateInfo.getUserMessageInfo());
178178

179179
Type returnType = methodCreateInfo.getReturnType();
180+
boolean isMulti = TypeUtil.isMulti(returnType);
181+
182+
final boolean isStringMulti = (isMulti && returnType instanceof ParameterizedType
183+
&& TypeUtil.isTypeOf(((ParameterizedType) returnType).getActualTypeArguments()[0], String.class));
180184
if (TypeUtil.isImage(returnType) || TypeUtil.isResultImage(returnType)) {
181185
return doImplementGenerateImage(methodCreateInfo, context, systemMessage, userMessage, memoryId, returnType,
182186
templateVariables, auditSourceInfo);
@@ -217,7 +221,7 @@ private static Object doImplement(AiServiceMethodCreateInfo methodCreateInfo, Ob
217221
Metadata metadata = Metadata.from(userMessage, memoryId, chatMemory);
218222
AugmentationRequest augmentationRequest = new AugmentationRequest(userMessage, metadata);
219223

220-
if (!TypeUtil.isMulti(returnType)) {
224+
if (!isMulti) {
221225
augmentationResult = context.retrievalAugmentor.augment(augmentationRequest);
222226
userMessage = (UserMessage) augmentationResult.chatMessage();
223227
} else {
@@ -244,10 +248,18 @@ public Flow.Publisher<?> apply(AugmentationResult ar) {
244248
var stream = new TokenStreamMulti(messagesToSend, effectiveToolSpecifications,
245249
finalToolExecutors, ar.contents(), context, memoryId,
246250
methodCreateInfo.isSwitchToWorkerThreadForToolExecution(), isRunningOnWorkerThread);
247-
return stream.plug(m -> ResponseAugmenterSupport.apply(m, methodCreateInfo,
248-
new ResponseAugmenterParams((UserMessage) augmentedUserMessage,
249-
memory, ar, methodCreateInfo.getUserMessageTemplate(),
250-
templateVariables)));
251+
return stream
252+
.filter(event -> {
253+
return !isStringMulti || event instanceof ChatEvent.PartialResponseEvent;
254+
}).map(event -> {
255+
if (isStringMulti && event instanceof ChatEvent.PartialResponseEvent) {
256+
return ((ChatEvent.PartialResponseEvent) event).getChunk();
257+
}
258+
return event;
259+
}).plug(m -> ResponseAugmenterSupport.apply(m, methodCreateInfo,
260+
new ResponseAugmenterParams((UserMessage) augmentedUserMessage,
261+
memory, ar, methodCreateInfo.getUserMessageTemplate(),
262+
templateVariables)));
251263
}
252264

253265
private List<ChatMessage> messagesToSend(UserMessage augmentedUserMessage,
@@ -297,13 +309,20 @@ private List<ChatMessage> messagesToSend(UserMessage augmentedUserMessage,
297309

298310
var actualAugmentationResult = augmentationResult;
299311
var actualUserMessage = userMessage;
300-
if (TypeUtil.isMulti(returnType)) {
312+
if (isMulti) {
301313
chatMemory.commit(); // for streaming cases, we really have to commit because all alternatives are worse
302314
if (methodCreateInfo.getOutputGuardrailsClassNames().isEmpty()) {
303315
var stream = new TokenStreamMulti(messagesToSend, toolSpecifications, toolExecutors,
304316
(augmentationResult != null ? augmentationResult.contents() : null), context, memoryId,
305317
methodCreateInfo.isSwitchToWorkerThreadForToolExecution(), isRunningOnWorkerThread);
306-
return stream.plug(m -> ResponseAugmenterSupport.apply(m, methodCreateInfo,
318+
return stream.filter(event -> {
319+
return !isStringMulti || event instanceof ChatEvent.PartialResponseEvent;
320+
}).map(event -> {
321+
if (isStringMulti && event instanceof ChatEvent.PartialResponseEvent) {
322+
return ((ChatEvent.PartialResponseEvent) event).getChunk();
323+
}
324+
return event;
325+
}).plug(m -> ResponseAugmenterSupport.apply(m, methodCreateInfo,
307326
new ResponseAugmenterParams(actualUserMessage,
308327
chatMemory, actualAugmentationResult, methodCreateInfo.getUserMessageTemplate(),
309328
Collections.unmodifiableMap(templateVariables))));
@@ -317,7 +336,8 @@ private List<ChatMessage> messagesToSend(UserMessage augmentedUserMessage,
317336
OutputGuardrailResult result;
318337
try {
319338
result = GuardrailsSupport.invokeOutputGuardrailsForStream(methodCreateInfo,
320-
new OutputGuardrailParams(AiMessage.from(chunk), chatMemory, actualAugmentationResult,
339+
new OutputGuardrailParams(AiMessage.from(chunk.getMessage()), chatMemory,
340+
actualAugmentationResult,
321341
methodCreateInfo.getUserMessageTemplate(),
322342
Collections.unmodifiableMap(templateVariables)),
323343
beanManager, auditSourceInfo);
@@ -340,6 +360,9 @@ private List<ChatMessage> messagesToSend(UserMessage augmentedUserMessage,
340360
throw new GuardrailException(
341361
"Attempting to rewrite the LLM output while streaming is not allowed");
342362
}
363+
if (isStringMulti) {
364+
return chunk.getMessage();
365+
}
343366
return chunk;
344367
}
345368
})
@@ -915,7 +938,7 @@ public interface Wrapper {
915938
Object wrap(Input input, Function<Input, Object> fun);
916939
}
917940

918-
private static class TokenStreamMulti extends AbstractMulti<String> implements Multi<String> {
941+
private static class TokenStreamMulti extends AbstractMulti<ChatEvent> implements Multi<ChatEvent> {
919942
private final List<ChatMessage> messagesToSend;
920943
private final List<ToolSpecification> toolSpecifications;
921944
private final Map<String, ToolExecutor> toolsExecutors;
@@ -941,14 +964,14 @@ public TokenStreamMulti(List<ChatMessage> messagesToSend, List<ToolSpecification
941964
}
942965

943966
@Override
944-
public void subscribe(MultiSubscriber<? super String> subscriber) {
945-
UnicastProcessor<String> processor = UnicastProcessor.create();
967+
public void subscribe(MultiSubscriber<? super ChatEvent> subscriber) {
968+
UnicastProcessor<ChatEvent> processor = UnicastProcessor.create();
946969
processor.subscribe(subscriber);
947970

948971
createTokenStream(processor);
949972
}
950973

951-
private void createTokenStream(UnicastProcessor<String> processor) {
974+
private void createTokenStream(UnicastProcessor<ChatEvent> processor) {
952975
Context ctxt = null;
953976
if (switchToWorkerThreadForToolExecution || isCallerRunningOnWorkerThread) {
954977
// we create or retrieve the current context, to use `executeBlocking` when required.
@@ -959,8 +982,18 @@ private void createTokenStream(UnicastProcessor<String> processor) {
959982
toolsExecutors, contents, context, memoryId, ctxt, switchToWorkerThreadForToolExecution,
960983
isCallerRunningOnWorkerThread);
961984
TokenStream tokenStream = stream
962-
.onPartialResponse(processor::onNext)
963-
.onCompleteResponse(message -> processor.onComplete())
985+
.onPartialResponse(chunk -> processor
986+
.onNext(new ChatEvent.PartialResponseEvent(chunk)))
987+
.onCompleteResponse(message -> {
988+
processor.onNext(new ChatEvent.ChatCompletedEvent(message));
989+
processor.onComplete();
990+
})
991+
.onRetrieved(content -> {
992+
processor.onNext(new ChatEvent.ContentFetchedEvent(content));
993+
})
994+
.onToolExecuted(execution -> {
995+
processor.onNext(new ChatEvent.ToolExecutedEvent(execution));
996+
})
964997
.onError(processor::onError);
965998
// This is equivalent to "run subscription on worker thread"
966999
if (switchToWorkerThreadForToolExecution && Context.isOnEventLoopThread()) {

0 commit comments

Comments
 (0)