Skip to content

Commit 1762714

Browse files
authored
Merge pull request #1662 from tomas1885/streaming_response_chatevent
Allow returning a ChatResponse Multi in streaming mode
2 parents 95492e7 + bf7751a commit 1762714

File tree

6 files changed

+332
-26
lines changed

6 files changed

+332
-26
lines changed

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -693,14 +693,15 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
693693
if (!DotNames.MULTI.equals(method.returnType().name())) {
694694
continue;
695695
}
696-
boolean isMultiString = false;
696+
boolean isSupportedResponseType = false;
697697
if (method.returnType().kind() == Type.Kind.PARAMETERIZED_TYPE) {
698698
Type multiType = method.returnType().asParameterizedType().arguments().get(0);
699-
if (DotNames.STRING.equals(multiType.name())) {
700-
isMultiString = true;
699+
if (DotNames.STRING.equals(multiType.name())
700+
|| DotNames.CHAT_EVENT.equals(multiType.name())) {
701+
isSupportedResponseType = true;
701702
}
702703
}
703-
if (!isMultiString) {
704+
if (!isSupportedResponseType) {
704705
throw illegalConfiguration("Only Multi<String> is supported as a Multi return type. Offending method is '"
705706
+ method.declaringClass().name().toString() + "#" + method.name() + "'");
706707
}

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DotNames.java

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailAccumulator;
1818
import io.quarkiverse.langchain4j.response.AiResponseAugmenter;
1919
import io.quarkiverse.langchain4j.response.ResponseAugmenter;
20+
import io.quarkiverse.langchain4j.runtime.aiservice.ChatEvent;
2021
import io.smallrye.common.annotation.Blocking;
2122
import io.smallrye.common.annotation.NonBlocking;
2223
import io.smallrye.common.annotation.RunOnVirtualThread;
@@ -62,6 +63,7 @@ public class DotNames {
6263
public static final DotName CHAT_MODEL_LISTENER = DotName.createSimple(ChatModelListener.class);
6364
public static final DotName MODEL_AUTH_PROVIDER = DotName.createSimple(ModelAuthProvider.class);
6465
public static final DotName TOOL = DotName.createSimple(Tool.class);
66+
public static final DotName CHAT_EVENT = DotName.createSimple(ChatEvent.class);
6567

6668
public static final DotName REGISTER_REST_CLIENT = DotName.createSimple(RegisterRestClient.class);
6769

core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java

Lines changed: 47 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -177,6 +177,10 @@ private static Object doImplement(AiServiceMethodCreateInfo methodCreateInfo, Ob
177177
Map<String, Object> templateVariables = getTemplateVariables(methodArgs, methodCreateInfo.getUserMessageInfo());
178178

179179
Type returnType = methodCreateInfo.getReturnType();
180+
boolean isMulti = TypeUtil.isMulti(returnType);
181+
182+
final boolean isStringMulti = (isMulti && returnType instanceof ParameterizedType
183+
&& TypeUtil.isTypeOf(((ParameterizedType) returnType).getActualTypeArguments()[0], String.class));
180184
if (TypeUtil.isImage(returnType) || TypeUtil.isResultImage(returnType)) {
181185
return doImplementGenerateImage(methodCreateInfo, context, systemMessage, userMessage, memoryId, returnType,
182186
templateVariables, auditSourceInfo);
@@ -217,7 +221,7 @@ private static Object doImplement(AiServiceMethodCreateInfo methodCreateInfo, Ob
217221
Metadata metadata = Metadata.from(userMessage, memoryId, chatMemory);
218222
AugmentationRequest augmentationRequest = new AugmentationRequest(userMessage, metadata);
219223

220-
if (!TypeUtil.isMulti(returnType)) {
224+
if (!isMulti) {
221225
augmentationResult = context.retrievalAugmentor.augment(augmentationRequest);
222226
userMessage = (UserMessage) augmentationResult.chatMessage();
223227
} else {
@@ -244,10 +248,18 @@ public Flow.Publisher<?> apply(AugmentationResult ar) {
244248
var stream = new TokenStreamMulti(messagesToSend, effectiveToolSpecifications,
245249
finalToolExecutors, ar.contents(), context, memoryId,
246250
methodCreateInfo.isSwitchToWorkerThreadForToolExecution(), isRunningOnWorkerThread);
247-
return stream.plug(m -> ResponseAugmenterSupport.apply(m, methodCreateInfo,
248-
new ResponseAugmenterParams((UserMessage) augmentedUserMessage,
249-
memory, ar, methodCreateInfo.getUserMessageTemplate(),
250-
templateVariables)));
251+
return stream
252+
.filter(event -> {
253+
return !isStringMulti || event instanceof ChatEvent.PartialResponseEvent;
254+
}).map(event -> {
255+
if (isStringMulti && event instanceof ChatEvent.PartialResponseEvent) {
256+
return ((ChatEvent.PartialResponseEvent) event).getChunk();
257+
}
258+
return event;
259+
}).plug(m -> ResponseAugmenterSupport.apply(m, methodCreateInfo,
260+
new ResponseAugmenterParams((UserMessage) augmentedUserMessage,
261+
memory, ar, methodCreateInfo.getUserMessageTemplate(),
262+
templateVariables)));
251263
}
252264

253265
private List<ChatMessage> messagesToSend(UserMessage augmentedUserMessage,
@@ -297,13 +309,20 @@ private List<ChatMessage> messagesToSend(UserMessage augmentedUserMessage,
297309

298310
var actualAugmentationResult = augmentationResult;
299311
var actualUserMessage = userMessage;
300-
if (TypeUtil.isMulti(returnType)) {
312+
if (isMulti) {
301313
chatMemory.commit(); // for streaming cases, we really have to commit because all alternatives are worse
302314
if (methodCreateInfo.getOutputGuardrailsClassNames().isEmpty()) {
303315
var stream = new TokenStreamMulti(messagesToSend, toolSpecifications, toolExecutors,
304316
(augmentationResult != null ? augmentationResult.contents() : null), context, memoryId,
305317
methodCreateInfo.isSwitchToWorkerThreadForToolExecution(), isRunningOnWorkerThread);
306-
return stream.plug(m -> ResponseAugmenterSupport.apply(m, methodCreateInfo,
318+
return stream.filter(event -> {
319+
return !isStringMulti || event instanceof ChatEvent.PartialResponseEvent;
320+
}).map(event -> {
321+
if (isStringMulti && event instanceof ChatEvent.PartialResponseEvent) {
322+
return ((ChatEvent.PartialResponseEvent) event).getChunk();
323+
}
324+
return event;
325+
}).plug(m -> ResponseAugmenterSupport.apply(m, methodCreateInfo,
307326
new ResponseAugmenterParams(actualUserMessage,
308327
chatMemory, actualAugmentationResult, methodCreateInfo.getUserMessageTemplate(),
309328
Collections.unmodifiableMap(templateVariables))));
@@ -317,7 +336,8 @@ private List<ChatMessage> messagesToSend(UserMessage augmentedUserMessage,
317336
OutputGuardrailResult result;
318337
try {
319338
result = GuardrailsSupport.invokeOutputGuardrailsForStream(methodCreateInfo,
320-
new OutputGuardrailParams(AiMessage.from(chunk), chatMemory, actualAugmentationResult,
339+
new OutputGuardrailParams(AiMessage.from(chunk.getMessage()), chatMemory,
340+
actualAugmentationResult,
321341
methodCreateInfo.getUserMessageTemplate(),
322342
Collections.unmodifiableMap(templateVariables)),
323343
beanManager, auditSourceInfo);
@@ -340,6 +360,9 @@ private List<ChatMessage> messagesToSend(UserMessage augmentedUserMessage,
340360
throw new GuardrailException(
341361
"Attempting to rewrite the LLM output while streaming is not allowed");
342362
}
363+
if (isStringMulti) {
364+
return chunk.getMessage();
365+
}
343366
return chunk;
344367
}
345368
})
@@ -915,7 +938,7 @@ public interface Wrapper {
915938
Object wrap(Input input, Function<Input, Object> fun);
916939
}
917940

918-
private static class TokenStreamMulti extends AbstractMulti<String> implements Multi<String> {
941+
private static class TokenStreamMulti extends AbstractMulti<ChatEvent> implements Multi<ChatEvent> {
919942
private final List<ChatMessage> messagesToSend;
920943
private final List<ToolSpecification> toolSpecifications;
921944
private final Map<String, ToolExecutor> toolsExecutors;
@@ -941,14 +964,14 @@ public TokenStreamMulti(List<ChatMessage> messagesToSend, List<ToolSpecification
941964
}
942965

943966
@Override
944-
public void subscribe(MultiSubscriber<? super String> subscriber) {
945-
UnicastProcessor<String> processor = UnicastProcessor.create();
967+
public void subscribe(MultiSubscriber<? super ChatEvent> subscriber) {
968+
UnicastProcessor<ChatEvent> processor = UnicastProcessor.create();
946969
processor.subscribe(subscriber);
947970

948971
createTokenStream(processor);
949972
}
950973

951-
private void createTokenStream(UnicastProcessor<String> processor) {
974+
private void createTokenStream(UnicastProcessor<ChatEvent> processor) {
952975
Context ctxt = null;
953976
if (switchToWorkerThreadForToolExecution || isCallerRunningOnWorkerThread) {
954977
// we create or retrieve the current context, to use `executeBlocking` when required.
@@ -959,8 +982,18 @@ private void createTokenStream(UnicastProcessor<String> processor) {
959982
toolsExecutors, contents, context, memoryId, ctxt, switchToWorkerThreadForToolExecution,
960983
isCallerRunningOnWorkerThread);
961984
TokenStream tokenStream = stream
962-
.onPartialResponse(processor::onNext)
963-
.onCompleteResponse(message -> processor.onComplete())
985+
.onPartialResponse(chunk -> processor
986+
.onNext(new ChatEvent.PartialResponseEvent(chunk)))
987+
.onCompleteResponse(message -> {
988+
processor.onNext(new ChatEvent.ChatCompletedEvent(message));
989+
processor.onComplete();
990+
})
991+
.onRetrieved(content -> {
992+
processor.onNext(new ChatEvent.ContentFetchedEvent(content));
993+
})
994+
.onToolExecuted(execution -> {
995+
processor.onNext(new ChatEvent.ToolExecutedEvent(execution));
996+
})
964997
.onError(processor::onError);
965998
// This is equivalent to "run subscription on worker thread"
966999
if (switchToWorkerThreadForToolExecution && Context.isOnEventLoopThread()) {

0 commit comments

Comments
 (0)