Skip to content

Commit 2c25ef0

Browse files
VLSMBzhouyou9505
authored andcommitted
feat(nl2sql): enhance python code generate and display (alibaba#2233)
* feat: enhance CodePoolExecutorService * feat: add LocalCodePoolExecutorService * feat: enhance python prompt * feat: enhance python display * fix: frontend bugs * fix: frontend bugs
1 parent 1c7de88 commit 2c25ef0

File tree

18 files changed

+857
-112
lines changed

18 files changed

+857
-112
lines changed

spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/config/CodeExecutorProperties.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ public class CodeExecutorProperties {
112112
/**
113113
* Container network mode
114114
*/
115-
String networkMode = "bridge";
115+
String networkMode = "none";
116116

117117
public CodePoolExecutorEnum getCodePoolExecutor() {
118118
return codePoolExecutor;

spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/node/PythonExecuteNode.java

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,25 @@ public Map<String, Object> apply(OverAllState state) throws Exception {
7676
log.error(errorMsg);
7777
throw new RuntimeException(errorMsg);
7878
}
79-
log.info("Python Execute Success! StdOut: {}", taskResponse.stdOut());
79+
80+
// Python输出的JSON字符串可能有Unicode转义形式,需要解析回汉字
81+
String stdout = taskResponse.stdOut();
82+
try {
83+
Object value = objectMapper.readValue(stdout, Object.class);
84+
stdout = objectMapper.writeValueAsString(value);
85+
}
86+
catch (Exception e) {
87+
stdout = taskResponse.stdOut();
88+
}
89+
String finalStdout = stdout;
90+
91+
log.info("Python Execute Success! StdOut: {}", finalStdout);
8092

8193
// Create display flux for user experience only
8294
Flux<ChatResponse> displayFlux = Flux.create(emitter -> {
8395
emitter.next(ChatResponseUtil.createStatusResponse("开始执行Python代码..."));
8496
emitter.next(ChatResponseUtil.createStatusResponse("标准输出:\n```"));
85-
emitter.next(ChatResponseUtil.createStatusResponse(taskResponse.stdOut()));
97+
emitter.next(ChatResponseUtil.createStatusResponse(finalStdout));
8698
emitter.next(ChatResponseUtil.createStatusResponse("\n```"));
8799
emitter.next(ChatResponseUtil.createStatusResponse("Python代码执行成功!"));
88100
emitter.complete();
@@ -91,14 +103,14 @@ public Map<String, Object> apply(OverAllState state) throws Exception {
91103
// Create generator using utility class, returning pre-computed business logic
92104
// result
93105
var generator = StreamingChatGeneratorUtil.createStreamingGeneratorWithMessages(this.getClass(), state,
94-
v -> Map.of(PYTHON_EXECUTE_NODE_OUTPUT, taskResponse.stdOut(), PYTHON_IS_SUCCESS, true),
95-
displayFlux, StreamResponseType.PYTHON_EXECUTE);
106+
v -> Map.of(PYTHON_EXECUTE_NODE_OUTPUT, finalStdout, PYTHON_IS_SUCCESS, true), displayFlux,
107+
StreamResponseType.PYTHON_EXECUTE);
96108

97109
return Map.of(PYTHON_EXECUTE_NODE_OUTPUT, generator);
98110
}
99111
catch (Exception e) {
100112
String errorMessage = e.getMessage();
101-
log.error("Python Execute Exception: {}", errorMessage, e);
113+
log.error("Python Execute Exception: {}", errorMessage);
102114

103115
// Prepare error result
104116
Map<String, Object> errorResult = Map.of(PYTHON_EXECUTE_NODE_OUTPUT, errorMessage, PYTHON_IS_SUCCESS,

spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/node/PythonGenerateNode.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,8 @@ public Map<String, Object> apply(OverAllState state) throws Exception {
119119
.stream()
120120
.chatResponse();
121121

122-
var generator = StreamingChatGeneratorUtil.createStreamingGeneratorWithMessages(this.getClass(), state,
123-
"正在生成Python代码...", "Python代码生成完成。", aiResponse -> {
122+
var generator = StreamingChatGeneratorUtil.createStreamingGeneratorWithMessages(this.getClass(), state, "", "",
123+
aiResponse -> {
124124
// Some AI models still output Markdown markup (even though Prompt has
125125
// emphasized this)
126126
aiResponse = MarkdownParser.extractRawText(aiResponse);

spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/service/code/CodePoolExecutorEnum.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,6 @@
2424
*/
2525
public enum CodePoolExecutorEnum {
2626

27-
DOCKER, CONTAINERD, KATA, AI_SIMULATION;
27+
DOCKER, CONTAINERD, KATA, AI_SIMULATION, LOCAL;
2828

2929
}

spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/service/code/CodePoolExecutorService.java

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,22 @@ record TaskRequest(String code, String input, String requirement) {
3030

3131
}
3232

33-
record TaskResponse(boolean isSuccess, String stdOut, String stdErr, String exceptionMsg) {
34-
public static TaskResponse error(String msg) {
35-
return new TaskResponse(false, null, null, "An exception occurred while executing the task: " + msg);
33+
record TaskResponse(boolean isSuccess, boolean executionSuccessButResultFailed, String stdOut, String stdErr,
34+
String exceptionMsg) {
35+
36+
// 执行运行代码任务时发生异常
37+
public static TaskResponse exception(String msg) {
38+
return new TaskResponse(false, false, null, null, "An exception occurred while executing the task: " + msg);
39+
}
40+
41+
// 执行运行代码任务成功,并且代码正常返回
42+
public static TaskResponse success(String stdOut) {
43+
return new TaskResponse(true, false, stdOut, null, null);
44+
}
45+
46+
// 执行运行代码任务成功,但是代码异常返回
47+
public static TaskResponse failure(String stdOut, String stdErr) {
48+
return new TaskResponse(false, true, stdOut, stdErr, "StdErr: " + stdErr);
3649
}
3750

3851
@Override
@@ -44,7 +57,7 @@ public String toString() {
4457

4558
enum State {
4659

47-
READY, RUNNING
60+
READY, RUNNING, REMOVING
4861

4962
}
5063

spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/service/code/CodePoolExecutorServiceFactory.java

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import com.alibaba.cloud.ai.config.CodeExecutorProperties;
2020
import com.alibaba.cloud.ai.service.code.impl.AiSimulationCodeExecutorService;
2121
import com.alibaba.cloud.ai.service.code.impl.DockerCodePoolExecutorService;
22+
import com.alibaba.cloud.ai.service.code.impl.LocalCodePoolExecutorService;
2223
import org.springframework.ai.chat.client.ChatClient;
2324

2425
/**
@@ -35,15 +36,13 @@ private CodePoolExecutorServiceFactory() {
3536

3637
public static CodePoolExecutorService newInstance(CodeExecutorProperties properties,
3738
ChatClient.Builder chatClientBuilder) {
38-
if (properties.getCodePoolExecutor().equals(CodePoolExecutorEnum.DOCKER)) {
39-
return new DockerCodePoolExecutorService(properties);
40-
}
41-
else if (properties.getCodePoolExecutor().equals(CodePoolExecutorEnum.AI_SIMULATION)) {
42-
return new AiSimulationCodeExecutorService(chatClientBuilder);
43-
}
44-
else {
45-
throw new IllegalArgumentException("Unknown container impl: " + properties.getCodePoolExecutor());
46-
}
39+
return switch (properties.getCodePoolExecutor()) {
40+
case DOCKER -> new DockerCodePoolExecutorService(properties);
41+
case LOCAL -> new LocalCodePoolExecutorService(properties);
42+
case AI_SIMULATION -> new AiSimulationCodeExecutorService(chatClientBuilder);
43+
default -> throw new UnsupportedOperationException(
44+
"This option does not have a corresponding implementation class yet.");
45+
};
4746
}
4847

4948
}

spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/service/code/impl/AbstractCodePoolExecutorService.java

Lines changed: 83 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -104,36 +104,38 @@ public AbstractCodePoolExecutorService(CodeExecutorProperties properties) {
104104
}));
105105
}
106106

107+
/**
108+
* 创建新的容器
109+
* @return 容器ID
110+
*/
107111
protected abstract String createNewContainer() throws Exception;
108112

109-
protected abstract TaskResponse execTaskInContainer(TaskRequest request, String containerId) throws Exception;
113+
/**
114+
* 在指定容器ID的容器运行任务
115+
* @param request 任务请求对象
116+
* @param containerId 容器ID
117+
* @return 运行结果对象
118+
*/
119+
protected abstract TaskResponse execTaskInContainer(TaskRequest request, String containerId);
110120

121+
/**
122+
* 停止指定容器
123+
* @param containerId 容器ID
124+
*/
111125
protected abstract void stopContainer(String containerId) throws Exception;
112126

127+
/**
128+
* 删除指定容器
129+
* @param containerId 容器ID
130+
*/
113131
protected abstract void removeContainer(String containerId) throws Exception;
114132

115133
protected void shutdownPool() throws Exception {
116134
// Shutdown thread pool
117135
this.consumerThreadPool.shutdownNow();
118136
// Stop and delete all containers
119-
for (String containerId : this.tempContainerState.keySet()) {
120-
try {
121-
this.stopContainer(containerId);
122-
this.removeContainer(containerId);
123-
}
124-
catch (Exception ignored) {
125-
126-
}
127-
}
128-
for (String containerId : this.coreContainerState.keySet()) {
129-
try {
130-
this.stopContainer(containerId);
131-
this.removeContainer(containerId);
132-
}
133-
catch (Exception ignored) {
134-
135-
}
136-
}
137+
this.tempContainerState.keySet().forEach(id -> this.removeContainerAndState(id, false, true));
138+
this.coreContainerState.keySet().forEach(id -> this.removeContainerAndState(id, true, true));
137139
this.tempContainerState.clear();
138140
this.coreContainerState.clear();
139141
this.tempContainerRemoveFuture.clear();
@@ -142,6 +144,48 @@ protected void shutdownPool() throws Exception {
142144
this.taskQueue.clear();
143145
}
144146

147+
private void removeContainerAndState(String containerId, boolean isCore, boolean isForce) {
148+
try {
149+
if (isCore) {
150+
// Remove core container
151+
State state = this.coreContainerState.replace(containerId, State.REMOVING);
152+
if (state == State.RUNNING) {
153+
if (isForce) {
154+
this.stopContainer(containerId);
155+
}
156+
else {
157+
throw new RuntimeException("Container is still Running!");
158+
}
159+
}
160+
this.removeContainer(containerId);
161+
this.coreContainerState.remove(containerId);
162+
this.currentCoreContainerSize.decrementAndGet();
163+
log.info("Core Container {} has been removed successfully", containerId);
164+
}
165+
else {
166+
// Remove temporary container
167+
State state = this.tempContainerState.replace(containerId, State.REMOVING);
168+
if (state == State.RUNNING) {
169+
if (isForce) {
170+
this.stopContainer(containerId);
171+
}
172+
else {
173+
throw new RuntimeException("Container is still Running!");
174+
}
175+
}
176+
this.removeContainer(containerId);
177+
this.tempContainerState.remove(containerId);
178+
this.tempContainerRemoveFuture.remove(containerId);
179+
this.currentTempContainerSize.decrementAndGet();
180+
log.info("Temp Container {} has been removed successfully", containerId);
181+
}
182+
}
183+
catch (Exception e) {
184+
log.error("Error when trying to remove a container, containerId: {}, info: {}", containerId, e.getMessage(),
185+
e);
186+
}
187+
}
188+
145189
// Create thread to delete temporary containers
146190
private Future<?> registerRemoveTempContainer(String containerId) {
147191
return consumerThreadPool.submit(() -> {
@@ -156,17 +200,7 @@ private Future<?> registerRemoveTempContainer(String containerId) {
156200
log.debug("Interrupted while waiting for temp container to be removed, info: {}", e.getMessage());
157201
return;
158202
}
159-
try {
160-
// Remove temporary container
161-
this.tempContainerState.remove(containerId);
162-
this.tempContainerRemoveFuture.remove(containerId);
163-
this.removeContainer(containerId);
164-
log.debug("Container {} has been removed successfully", containerId);
165-
}
166-
catch (Exception e) {
167-
log.error("Error when trying to register temp container to be removed, containerId: {}, info: {}",
168-
containerId, e.getMessage(), e);
169-
}
203+
this.removeContainerAndState(containerId, false, false);
170204
});
171205
}
172206

@@ -176,6 +210,13 @@ private TaskResponse useCoreContainer(String containerId, TaskRequest request) {
176210
// Execute task
177211
this.coreContainerState.replace(containerId, State.RUNNING);
178212
TaskResponse resp = this.execTaskInContainer(request, containerId);
213+
// 如果运行代码任务时出现了异常,认为容器损坏,执行容器清除,并将当前任务放进队列里重新执行
214+
if (!resp.isSuccess() && !resp.executionSuccessButResultFailed()) {
215+
log.error("use core container failed, {}", resp.exceptionMsg());
216+
this.coreContainerState.replace(containerId, State.REMOVING);
217+
this.removeContainerAndState(containerId, true, true);
218+
return this.pushTaskQueue(request);
219+
}
179220
this.coreContainerState.replace(containerId, State.READY);
180221
// Put back into blocking queue
181222
this.readyCoreContainer.add(containerId);
@@ -185,7 +226,7 @@ private TaskResponse useCoreContainer(String containerId, TaskRequest request) {
185226
}
186227
catch (Exception e) {
187228
log.error("use core container failed, {}", e.getMessage(), e);
188-
return TaskResponse.error(e.getMessage());
229+
return TaskResponse.exception(e.getMessage());
189230
}
190231
}
191232

@@ -205,6 +246,13 @@ private TaskResponse useTempContainer(String containerId, TaskRequest request) {
205246
// Execute task
206247
this.tempContainerState.replace(containerId, State.RUNNING);
207248
TaskResponse resp = this.execTaskInContainer(request, containerId);
249+
// 如果运行代码任务时出现了异常,认为容器损坏,执行容器清除,并将当前任务放进队列里重新执行
250+
if (!resp.isSuccess() && !resp.executionSuccessButResultFailed()) {
251+
log.error("use temp container failed, {}", resp.exceptionMsg());
252+
this.tempContainerState.replace(containerId, State.REMOVING);
253+
this.removeContainerAndState(containerId, false, true);
254+
return this.pushTaskQueue(request);
255+
}
208256
this.tempContainerState.replace(containerId, State.READY);
209257
// Put back into blocking queue
210258
this.readyTempContainer.add(containerId);
@@ -216,7 +264,7 @@ private TaskResponse useTempContainer(String containerId, TaskRequest request) {
216264
}
217265
catch (Exception e) {
218266
log.error("use temp container failed, {}", e.getMessage(), e);
219-
return TaskResponse.error(e.getMessage());
267+
return TaskResponse.exception(e.getMessage());
220268
}
221269
}
222270

@@ -228,7 +276,7 @@ private TaskResponse createAndUseCoreContainer(TaskRequest request) {
228276
}
229277
catch (Exception e) {
230278
log.error("create new container failed, {}", e.getMessage(), e);
231-
return TaskResponse.error(e.getMessage());
279+
return TaskResponse.exception(e.getMessage());
232280
}
233281
// Record newly added container
234282
this.coreContainerState.put(containerId, State.READY);
@@ -244,7 +292,7 @@ private TaskResponse createAndUseTempContainer(TaskRequest request) {
244292
}
245293
catch (Exception e) {
246294
log.error("create new container failed, {}", e.getMessage(), e);
247-
return TaskResponse.error(e.getMessage());
295+
return TaskResponse.exception(e.getMessage());
248296
}
249297
// Record newly added container
250298
this.tempContainerState.put(containerId, State.READY);
@@ -326,7 +374,7 @@ public TaskResponse runTask(TaskRequest request) {
326374
}
327375
catch (Exception e) {
328376
log.error("An exception occurred while executing the task: {}", e.getMessage(), e);
329-
return TaskResponse.error(e.getMessage());
377+
return TaskResponse.exception(e.getMessage());
330378
}
331379
}
332380

spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/service/code/impl/AiSimulationCodeExecutorService.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ public TaskResponse runTask(TaskRequest request) {
5858
```
5959
""", request.code(), request.input());
6060
String output = chatClient.prompt().user(userPrompt).call().content();
61-
return new TaskResponse(true, output, null, null);
61+
return TaskResponse.success(output);
6262
}
6363

6464
}

0 commit comments

Comments
 (0)