Skip to content

Commit 94b6b4e

Browse files
committed
Merge branch 'feat-2510' of github.com:mengnankkkk/spring-ai-alibaba into feat-2510
2 parents 876901a + a0d003d commit 94b6b4e

File tree

2 files changed

+144
-33
lines changed

2 files changed

+144
-33
lines changed

spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/LlmNode.java

Lines changed: 59 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -87,17 +87,17 @@ public static Builder builder() {
8787

8888
@Override
8989
public Map<String, Object> apply(OverAllState state) throws Exception {
90-
initNodeWithState(state);
90+
ExecutionContext context = initNodeWithState(state);
9191

9292
// add streaming support
9393
if (Boolean.TRUE.equals(stream)) {
94-
Flux<ChatResponse> chatResponseFlux = stream(state);
94+
Flux<ChatResponse> chatResponseFlux = stream(context);
9595
return Map.of(StringUtils.hasLength(this.outputKey) ? this.outputKey : "messages", chatResponseFlux);
9696
}
9797
else {
9898
AssistantMessage responseOutput;
9999
try {
100-
ChatResponse response = call(state);
100+
ChatResponse response = call(context);
101101
responseOutput = response.getResult().getOutput();
102102
}
103103
catch (Exception e) {
@@ -113,19 +113,24 @@ public Map<String, Object> apply(OverAllState state) throws Exception {
113113
}
114114
}
115115

116-
private void initNodeWithState(OverAllState state) {
116+
private ExecutionContext initNodeWithState(OverAllState state) {
117+
String localUserPrompt = this.userPrompt;
118+
String localSystemPrompt = this.systemPrompt;
119+
Map<String, Object> localParams = new HashMap<>(this.params);
120+
List<Message> localMessages = new ArrayList<>(this.messages);
121+
117122
if (StringUtils.hasLength(userPromptKey)) {
118-
this.userPrompt = (String) state.value(userPromptKey).orElse(this.userPrompt);
123+
localUserPrompt = (String) state.value(userPromptKey).orElse(localUserPrompt);
119124
}
120125
if (StringUtils.hasLength(systemPromptKey)) {
121-
this.systemPrompt = (String) state.value(systemPromptKey).orElse(this.systemPrompt);
126+
localSystemPrompt = (String) state.value(systemPromptKey).orElse(localSystemPrompt);
122127
}
123128
if (StringUtils.hasLength(paramsKey)) {
124-
this.params = (Map<String, Object>) state.value(paramsKey).orElse(this.params);
129+
localParams = (Map<String, Object>) state.value(paramsKey).orElse(localParams);
125130
}
126131
// Used for adapting the dify's DSL conversion
127-
if (!this.params.isEmpty()) {
128-
Map<String, Object> rawParams = this.params;
132+
if (!localParams.isEmpty()) {
133+
Map<String, Object> rawParams = localParams;
129134
Map<String, Object> filledParams = new HashMap<>();
130135
for (Map.Entry<String, Object> entry : rawParams.entrySet()) {
131136
if (entry.getValue().equals("null")) {
@@ -136,19 +141,32 @@ private void initNodeWithState(OverAllState state) {
136141
filledParams.put(entry.getKey(), entry.getValue());
137142
}
138143
}
139-
140-
this.params = filledParams;
144+
localParams = filledParams;
141145
}
142146
if (StringUtils.hasLength(messagesKey)) {
143147
Object messagesValue = state.value(messagesKey).orElse(null);
144148
if (messagesValue != null) {
145149
List<Message> convertedMessages = convertToMessages(messagesValue);
146-
this.messages = convertedMessages.isEmpty() ? this.messages : convertedMessages;
150+
localMessages = convertedMessages.isEmpty() ? localMessages : convertedMessages;
147151
}
148152
}
149-
if (StringUtils.hasLength(userPrompt) && !params.isEmpty()) {
150-
this.userPrompt = renderPromptTemplate(userPrompt, params);
153+
154+
String renderedUserPrompt = localUserPrompt;
155+
String renderedSystemPrompt = localSystemPrompt;
156+
157+
if (StringUtils.hasLength(localUserPrompt) && !localParams.isEmpty()) {
158+
renderedUserPrompt = renderPromptTemplate(localUserPrompt, localParams);
159+
}
160+
161+
if (StringUtils.hasLength(localSystemPrompt)) {
162+
if (!localParams.isEmpty()) {
163+
renderedSystemPrompt = renderPromptTemplate(localSystemPrompt, localParams);
164+
} else {
165+
renderedSystemPrompt = renderPromptTemplate(localSystemPrompt, state.data());
166+
}
151167
}
168+
169+
return new ExecutionContext(renderedSystemPrompt, renderedUserPrompt, localParams, localMessages, state);
152170
}
153171

154172
public void setToolCallbacks(List<ToolCallback> toolCallbacks) {
@@ -160,35 +178,26 @@ private String renderPromptTemplate(String prompt, Map<String, Object> params) {
160178
return promptTemplate.render(params);
161179
}
162180

163-
public Flux<ChatResponse> stream(OverAllState state) {
164-
return buildChatClientRequestSpec(state).stream().chatResponse();
181+
public Flux<ChatResponse> stream(ExecutionContext context) {
182+
return buildChatClientRequestSpec(context).stream().chatResponse();
165183
}
166184

167-
public ChatResponse call(OverAllState state) {
168-
return buildChatClientRequestSpec(state).call().chatResponse();
185+
public ChatResponse call(ExecutionContext context) {
186+
return buildChatClientRequestSpec(context).call().chatResponse();
169187
}
170188

171-
private ChatClient.ChatClientRequestSpec buildChatClientRequestSpec(OverAllState state) {
189+
private ChatClient.ChatClientRequestSpec buildChatClientRequestSpec(ExecutionContext context) {
172190
ChatClient.ChatClientRequestSpec chatClientRequestSpec = chatClient.prompt()
173191
.toolCallbacks(toolCallbacks)
174-
.messages(messages)
192+
.messages(context.messages)
175193
.advisors(advisors);
176194

177-
if (StringUtils.hasLength(systemPrompt)) {
178-
if (!params.isEmpty()) {
179-
systemPrompt = renderPromptTemplate(systemPrompt, params);
180-
} else {
181-
// try render with state
182-
systemPrompt = renderPromptTemplate(systemPrompt, state.data());
183-
}
184-
chatClientRequestSpec.system(systemPrompt);
195+
if (StringUtils.hasLength(context.systemPrompt)) {
196+
chatClientRequestSpec.system(context.systemPrompt);
185197
}
186198

187-
if (StringUtils.hasLength(userPrompt)) {
188-
if (!params.isEmpty()) {
189-
userPrompt = renderPromptTemplate(userPrompt, params);
190-
}
191-
chatClientRequestSpec.user(userPrompt);
199+
if (StringUtils.hasLength(context.userPrompt)) {
200+
chatClientRequestSpec.user(context.userPrompt);
192201
}
193202

194203
return chatClientRequestSpec;
@@ -315,4 +324,21 @@ public LlmNode build() {
315324

316325
}
317326

327+
private static class ExecutionContext {
328+
final String systemPrompt;
329+
final String userPrompt;
330+
final Map<String, Object> params;
331+
final List<Message> messages;
332+
final OverAllState state;
333+
334+
ExecutionContext(String systemPrompt, String userPrompt, Map<String, Object> params,
335+
List<Message> messages, OverAllState state) {
336+
this.systemPrompt = systemPrompt;
337+
this.userPrompt = userPrompt;
338+
this.params = params;
339+
this.messages = messages;
340+
this.state = state;
341+
}
342+
}
343+
318344
}
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
/*
2+
* Copyright 2024-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package com.alibaba.cloud.ai.graph.node;
17+
18+
import com.alibaba.cloud.ai.graph.OverAllState;
19+
import org.junit.Test;
20+
import org.springframework.ai.chat.messages.UserMessage;
21+
22+
import java.util.ArrayList;
23+
import java.util.List;
24+
import java.util.Map;
25+
import java.util.concurrent.CountDownLatch;
26+
import java.util.concurrent.ExecutorService;
27+
import java.util.concurrent.Executors;
28+
import java.util.concurrent.atomic.AtomicInteger;
29+
30+
import static org.junit.Assert.assertEquals;
31+
32+
public class LlmNodeThreadSafetyTest {
33+
34+
@Test
35+
public void testConcurrentMessageHandling() throws InterruptedException {
36+
TestLlmNode node = new TestLlmNode();
37+
38+
ExecutorService executor = Executors.newFixedThreadPool(10);
39+
CountDownLatch latch = new CountDownLatch(50);
40+
AtomicInteger conflicts = new AtomicInteger(0);
41+
42+
for (int i = 0; i < 50; i++) {
43+
final int requestId = i;
44+
executor.submit(() -> {
45+
try {
46+
OverAllState state = new OverAllState();
47+
state.updateState(Map.of("messagesKey", List.of(new UserMessage("Request-" + requestId))));
48+
49+
Map<String, Object> result = node.simulateApply(state);
50+
51+
String resultStr = result.toString();
52+
if (!resultStr.contains("Request-" + requestId)) {
53+
conflicts.incrementAndGet();
54+
}
55+
} catch (Exception e) {
56+
conflicts.incrementAndGet();
57+
} finally {
58+
latch.countDown();
59+
}
60+
});
61+
}
62+
63+
latch.await();
64+
executor.shutdown();
65+
assertEquals("Thread safety test failed - found data conflicts", 0, conflicts.get());
66+
}
67+
68+
private static class TestLlmNode {
69+
private List<Object> messages = new ArrayList<>();
70+
private String messagesKey = "messagesKey";
71+
72+
public Map<String, Object> simulateApply(OverAllState state) {
73+
List<Object> localMessages = new ArrayList<>(this.messages);
74+
75+
if (messagesKey != null) {
76+
Object messagesValue = state.value(messagesKey).orElse(null);
77+
if (messagesValue != null) {
78+
localMessages = (List<Object>) messagesValue;
79+
}
80+
}
81+
82+
return Map.of("processedMessages", localMessages);
83+
}
84+
}
85+
}

0 commit comments

Comments
 (0)