diff --git a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/config/CodeExecutorProperties.java b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/config/CodeExecutorProperties.java index 72be1e6501..365d4f2a43 100644 --- a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/config/CodeExecutorProperties.java +++ b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/config/CodeExecutorProperties.java @@ -112,7 +112,7 @@ public class CodeExecutorProperties { /** * Container network mode */ - String networkMode = "bridge"; + String networkMode = "none"; public CodePoolExecutorEnum getCodePoolExecutor() { return codePoolExecutor; diff --git a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/node/PythonExecuteNode.java b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/node/PythonExecuteNode.java index d80ea77eb5..be5448d577 100644 --- a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/node/PythonExecuteNode.java +++ b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/node/PythonExecuteNode.java @@ -76,13 +76,25 @@ public Map apply(OverAllState state) throws Exception { log.error(errorMsg); throw new RuntimeException(errorMsg); } - log.info("Python Execute Success! StdOut: {}", taskResponse.stdOut()); + + // Python输出的JSON字符串可能有Unicode转义形式,需要解析回汉字 + String stdout = taskResponse.stdOut(); + try { + Object value = objectMapper.readValue(stdout, Object.class); + stdout = objectMapper.writeValueAsString(value); + } + catch (Exception e) { + stdout = taskResponse.stdOut(); + } + String finalStdout = stdout; + + log.info("Python Execute Success! StdOut: {}", finalStdout); // Create display flux for user experience only Flux displayFlux = Flux.create(emitter -> { emitter.next(ChatResponseUtil.createStatusResponse("开始执行Python代码...")); emitter.next(ChatResponseUtil.createStatusResponse("标准输出:\n```")); - emitter.next(ChatResponseUtil.createStatusResponse(taskResponse.stdOut())); + emitter.next(ChatResponseUtil.createStatusResponse(finalStdout)); emitter.next(ChatResponseUtil.createStatusResponse("\n```")); emitter.next(ChatResponseUtil.createStatusResponse("Python代码执行成功!")); emitter.complete(); @@ -91,14 +103,14 @@ public Map apply(OverAllState state) throws Exception { // Create generator using utility class, returning pre-computed business logic // result var generator = StreamingChatGeneratorUtil.createStreamingGeneratorWithMessages(this.getClass(), state, - v -> Map.of(PYTHON_EXECUTE_NODE_OUTPUT, taskResponse.stdOut(), PYTHON_IS_SUCCESS, true), - displayFlux, StreamResponseType.PYTHON_EXECUTE); + v -> Map.of(PYTHON_EXECUTE_NODE_OUTPUT, finalStdout, PYTHON_IS_SUCCESS, true), displayFlux, + StreamResponseType.PYTHON_EXECUTE); return Map.of(PYTHON_EXECUTE_NODE_OUTPUT, generator); } catch (Exception e) { String errorMessage = e.getMessage(); - log.error("Python Execute Exception: {}", errorMessage, e); + log.error("Python Execute Exception: {}", errorMessage); // Prepare error result Map errorResult = Map.of(PYTHON_EXECUTE_NODE_OUTPUT, errorMessage, PYTHON_IS_SUCCESS, diff --git a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/node/PythonGenerateNode.java b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/node/PythonGenerateNode.java index 661be6f338..d788ce51e4 100644 --- a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/node/PythonGenerateNode.java +++ b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/node/PythonGenerateNode.java @@ -119,8 +119,8 @@ public Map apply(OverAllState state) throws Exception { .stream() .chatResponse(); - var generator = StreamingChatGeneratorUtil.createStreamingGeneratorWithMessages(this.getClass(), state, - "正在生成Python代码...", "Python代码生成完成。", aiResponse -> { + var generator = StreamingChatGeneratorUtil.createStreamingGeneratorWithMessages(this.getClass(), state, "", "", + aiResponse -> { // Some AI models still output Markdown markup (even though Prompt has // emphasized this) aiResponse = MarkdownParser.extractRawText(aiResponse); diff --git a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/service/code/CodePoolExecutorEnum.java b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/service/code/CodePoolExecutorEnum.java index aeb2be25ac..39fdd383a2 100644 --- a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/service/code/CodePoolExecutorEnum.java +++ b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/service/code/CodePoolExecutorEnum.java @@ -24,6 +24,6 @@ */ public enum CodePoolExecutorEnum { - DOCKER, CONTAINERD, KATA, AI_SIMULATION; + DOCKER, CONTAINERD, KATA, AI_SIMULATION, LOCAL; } diff --git a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/service/code/CodePoolExecutorService.java b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/service/code/CodePoolExecutorService.java index 0f19df246e..62b27e213b 100644 --- a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/service/code/CodePoolExecutorService.java +++ b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/service/code/CodePoolExecutorService.java @@ -30,9 +30,22 @@ record TaskRequest(String code, String input, String requirement) { } - record TaskResponse(boolean isSuccess, String stdOut, String stdErr, String exceptionMsg) { - public static TaskResponse error(String msg) { - return new TaskResponse(false, null, null, "An exception occurred while executing the task: " + msg); + record TaskResponse(boolean isSuccess, boolean executionSuccessButResultFailed, String stdOut, String stdErr, + String exceptionMsg) { + + // 执行运行代码任务时发生异常 + public static TaskResponse exception(String msg) { + return new TaskResponse(false, false, null, null, "An exception occurred while executing the task: " + msg); + } + + // 执行运行代码任务成功,并且代码正常返回 + public static TaskResponse success(String stdOut) { + return new TaskResponse(true, false, stdOut, null, null); + } + + // 执行运行代码任务成功,但是代码异常返回 + public static TaskResponse failure(String stdOut, String stdErr) { + return new TaskResponse(false, true, stdOut, stdErr, "StdErr: " + stdErr); } @Override @@ -44,7 +57,7 @@ public String toString() { enum State { - READY, RUNNING + READY, RUNNING, REMOVING } diff --git a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/service/code/CodePoolExecutorServiceFactory.java b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/service/code/CodePoolExecutorServiceFactory.java index 70b09b1b8c..f02defd419 100644 --- a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/service/code/CodePoolExecutorServiceFactory.java +++ b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/service/code/CodePoolExecutorServiceFactory.java @@ -19,6 +19,7 @@ import com.alibaba.cloud.ai.config.CodeExecutorProperties; import com.alibaba.cloud.ai.service.code.impl.AiSimulationCodeExecutorService; import com.alibaba.cloud.ai.service.code.impl.DockerCodePoolExecutorService; +import com.alibaba.cloud.ai.service.code.impl.LocalCodePoolExecutorService; import org.springframework.ai.chat.client.ChatClient; /** @@ -35,15 +36,13 @@ private CodePoolExecutorServiceFactory() { public static CodePoolExecutorService newInstance(CodeExecutorProperties properties, ChatClient.Builder chatClientBuilder) { - if (properties.getCodePoolExecutor().equals(CodePoolExecutorEnum.DOCKER)) { - return new DockerCodePoolExecutorService(properties); - } - else if (properties.getCodePoolExecutor().equals(CodePoolExecutorEnum.AI_SIMULATION)) { - return new AiSimulationCodeExecutorService(chatClientBuilder); - } - else { - throw new IllegalArgumentException("Unknown container impl: " + properties.getCodePoolExecutor()); - } + return switch (properties.getCodePoolExecutor()) { + case DOCKER -> new DockerCodePoolExecutorService(properties); + case LOCAL -> new LocalCodePoolExecutorService(properties); + case AI_SIMULATION -> new AiSimulationCodeExecutorService(chatClientBuilder); + default -> throw new UnsupportedOperationException( + "This option does not have a corresponding implementation class yet."); + }; } } diff --git a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/service/code/impl/AbstractCodePoolExecutorService.java b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/service/code/impl/AbstractCodePoolExecutorService.java index 4aa9ba3be2..7cd627db27 100644 --- a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/service/code/impl/AbstractCodePoolExecutorService.java +++ b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/service/code/impl/AbstractCodePoolExecutorService.java @@ -104,36 +104,38 @@ public AbstractCodePoolExecutorService(CodeExecutorProperties properties) { })); } + /** + * 创建新的容器 + * @return 容器ID + */ protected abstract String createNewContainer() throws Exception; - protected abstract TaskResponse execTaskInContainer(TaskRequest request, String containerId) throws Exception; + /** + * 在指定容器ID的容器运行任务 + * @param request 任务请求对象 + * @param containerId 容器ID + * @return 运行结果对象 + */ + protected abstract TaskResponse execTaskInContainer(TaskRequest request, String containerId); + /** + * 停止指定容器 + * @param containerId 容器ID + */ protected abstract void stopContainer(String containerId) throws Exception; + /** + * 删除指定容器 + * @param containerId 容器ID + */ protected abstract void removeContainer(String containerId) throws Exception; protected void shutdownPool() throws Exception { // Shutdown thread pool this.consumerThreadPool.shutdownNow(); // Stop and delete all containers - for (String containerId : this.tempContainerState.keySet()) { - try { - this.stopContainer(containerId); - this.removeContainer(containerId); - } - catch (Exception ignored) { - - } - } - for (String containerId : this.coreContainerState.keySet()) { - try { - this.stopContainer(containerId); - this.removeContainer(containerId); - } - catch (Exception ignored) { - - } - } + this.tempContainerState.keySet().forEach(id -> this.removeContainerAndState(id, false, true)); + this.coreContainerState.keySet().forEach(id -> this.removeContainerAndState(id, true, true)); this.tempContainerState.clear(); this.coreContainerState.clear(); this.tempContainerRemoveFuture.clear(); @@ -142,6 +144,48 @@ protected void shutdownPool() throws Exception { this.taskQueue.clear(); } + private void removeContainerAndState(String containerId, boolean isCore, boolean isForce) { + try { + if (isCore) { + // Remove core container + State state = this.coreContainerState.replace(containerId, State.REMOVING); + if (state == State.RUNNING) { + if (isForce) { + this.stopContainer(containerId); + } + else { + throw new RuntimeException("Container is still Running!"); + } + } + this.removeContainer(containerId); + this.coreContainerState.remove(containerId); + this.currentCoreContainerSize.decrementAndGet(); + log.info("Core Container {} has been removed successfully", containerId); + } + else { + // Remove temporary container + State state = this.tempContainerState.replace(containerId, State.REMOVING); + if (state == State.RUNNING) { + if (isForce) { + this.stopContainer(containerId); + } + else { + throw new RuntimeException("Container is still Running!"); + } + } + this.removeContainer(containerId); + this.tempContainerState.remove(containerId); + this.tempContainerRemoveFuture.remove(containerId); + this.currentTempContainerSize.decrementAndGet(); + log.info("Temp Container {} has been removed successfully", containerId); + } + } + catch (Exception e) { + log.error("Error when trying to remove a container, containerId: {}, info: {}", containerId, e.getMessage(), + e); + } + } + // Create thread to delete temporary containers private Future registerRemoveTempContainer(String containerId) { return consumerThreadPool.submit(() -> { @@ -156,17 +200,7 @@ private Future registerRemoveTempContainer(String containerId) { log.debug("Interrupted while waiting for temp container to be removed, info: {}", e.getMessage()); return; } - try { - // Remove temporary container - this.tempContainerState.remove(containerId); - this.tempContainerRemoveFuture.remove(containerId); - this.removeContainer(containerId); - log.debug("Container {} has been removed successfully", containerId); - } - catch (Exception e) { - log.error("Error when trying to register temp container to be removed, containerId: {}, info: {}", - containerId, e.getMessage(), e); - } + this.removeContainerAndState(containerId, false, false); }); } @@ -176,6 +210,13 @@ private TaskResponse useCoreContainer(String containerId, TaskRequest request) { // Execute task this.coreContainerState.replace(containerId, State.RUNNING); TaskResponse resp = this.execTaskInContainer(request, containerId); + // 如果运行代码任务时出现了异常,认为容器损坏,执行容器清除,并将当前任务放进队列里重新执行 + if (!resp.isSuccess() && !resp.executionSuccessButResultFailed()) { + log.error("use core container failed, {}", resp.exceptionMsg()); + this.coreContainerState.replace(containerId, State.REMOVING); + this.removeContainerAndState(containerId, true, true); + return this.pushTaskQueue(request); + } this.coreContainerState.replace(containerId, State.READY); // Put back into blocking queue this.readyCoreContainer.add(containerId); @@ -185,7 +226,7 @@ private TaskResponse useCoreContainer(String containerId, TaskRequest request) { } catch (Exception e) { log.error("use core container failed, {}", e.getMessage(), e); - return TaskResponse.error(e.getMessage()); + return TaskResponse.exception(e.getMessage()); } } @@ -205,6 +246,13 @@ private TaskResponse useTempContainer(String containerId, TaskRequest request) { // Execute task this.tempContainerState.replace(containerId, State.RUNNING); TaskResponse resp = this.execTaskInContainer(request, containerId); + // 如果运行代码任务时出现了异常,认为容器损坏,执行容器清除,并将当前任务放进队列里重新执行 + if (!resp.isSuccess() && !resp.executionSuccessButResultFailed()) { + log.error("use temp container failed, {}", resp.exceptionMsg()); + this.tempContainerState.replace(containerId, State.REMOVING); + this.removeContainerAndState(containerId, false, true); + return this.pushTaskQueue(request); + } this.tempContainerState.replace(containerId, State.READY); // Put back into blocking queue this.readyTempContainer.add(containerId); @@ -216,7 +264,7 @@ private TaskResponse useTempContainer(String containerId, TaskRequest request) { } catch (Exception e) { log.error("use temp container failed, {}", e.getMessage(), e); - return TaskResponse.error(e.getMessage()); + return TaskResponse.exception(e.getMessage()); } } @@ -228,7 +276,7 @@ private TaskResponse createAndUseCoreContainer(TaskRequest request) { } catch (Exception e) { log.error("create new container failed, {}", e.getMessage(), e); - return TaskResponse.error(e.getMessage()); + return TaskResponse.exception(e.getMessage()); } // Record newly added container this.coreContainerState.put(containerId, State.READY); @@ -244,7 +292,7 @@ private TaskResponse createAndUseTempContainer(TaskRequest request) { } catch (Exception e) { log.error("create new container failed, {}", e.getMessage(), e); - return TaskResponse.error(e.getMessage()); + return TaskResponse.exception(e.getMessage()); } // Record newly added container this.tempContainerState.put(containerId, State.READY); @@ -326,7 +374,7 @@ public TaskResponse runTask(TaskRequest request) { } catch (Exception e) { log.error("An exception occurred while executing the task: {}", e.getMessage(), e); - return TaskResponse.error(e.getMessage()); + return TaskResponse.exception(e.getMessage()); } } diff --git a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/service/code/impl/AiSimulationCodeExecutorService.java b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/service/code/impl/AiSimulationCodeExecutorService.java index f5bf5655c3..5755a29e58 100644 --- a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/service/code/impl/AiSimulationCodeExecutorService.java +++ b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/service/code/impl/AiSimulationCodeExecutorService.java @@ -58,7 +58,7 @@ public TaskResponse runTask(TaskRequest request) { ``` """, request.code(), request.input()); String output = chatClient.prompt().user(userPrompt).call().content(); - return new TaskResponse(true, output, null, null); + return TaskResponse.success(output); } } diff --git a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/service/code/impl/DockerCodePoolExecutorService.java b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/service/code/impl/DockerCodePoolExecutorService.java index 81e813e587..8baf4d7dc6 100644 --- a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/service/code/impl/DockerCodePoolExecutorService.java +++ b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/service/code/impl/DockerCodePoolExecutorService.java @@ -273,21 +273,27 @@ protected String createNewContainer() throws Exception { } @Override - protected TaskResponse execTaskInContainer(TaskRequest request, String containerId) throws Exception { + protected TaskResponse execTaskInContainer(TaskRequest request, String containerId) { // Get temporary directory object, write data to temporary directory Path tempDir = this.containerTempPath.get(containerId); if (tempDir == null) { log.error("Container '{}' does not exist work dir", containerId); - return TaskResponse.error("Container '" + containerId + "' does not exist work dir"); + return TaskResponse.exception("Container '" + containerId + "' does not exist work dir"); + } + + try { + Files.write(tempDir.resolve("script.py"), + StringUtils.hasText(request.code()) ? request.code().getBytes() : "".getBytes()); + Files.write(tempDir.resolve("requirements.txt"), + StringUtils.hasText(request.requirement()) ? request.requirement().getBytes() : "".getBytes()); + Files.write(tempDir.resolve("input_data.txt"), + StringUtils.hasText(request.input()) ? request.input().getBytes() : "".getBytes()); + Files.write(tempDir.resolve("stdout.txt"), "".getBytes()); + Files.write(tempDir.resolve("stderr.txt"), "".getBytes()); + } + catch (Exception e) { + return TaskResponse.exception(e.getMessage()); } - Files.write(tempDir.resolve("script.py"), - StringUtils.hasText(request.code()) ? request.code().getBytes() : "".getBytes()); - Files.write(tempDir.resolve("requirements.txt"), - StringUtils.hasText(request.requirement()) ? request.requirement().getBytes() : "".getBytes()); - Files.write(tempDir.resolve("input_data.txt"), - StringUtils.hasText(request.input()) ? request.input().getBytes() : "".getBytes()); - Files.write(tempDir.resolve("stdout.txt"), "".getBytes()); - Files.write(tempDir.resolve("stderr.txt"), "".getBytes()); try { // start docker @@ -306,13 +312,13 @@ protected TaskResponse execTaskInContainer(TaskRequest request, String container if (exitCode != 0) { String errorMessage = "Docker exit code " + exitCode + ". Stderr: " + stderr + ". Stdout: " + stdout; log.error("Error executing Docker container {}: {}", containerId, errorMessage); - return TaskResponse.error(errorMessage); + return TaskResponse.failure(stdout, stderr); } - return new TaskResponse(true, stdout, stderr, null); + return TaskResponse.success(stdout); } catch (Exception e) { log.error("Error when creating container in docker: {}", e.getMessage()); - return TaskResponse.error(e.getMessage()); + return TaskResponse.exception(e.getMessage()); } } diff --git a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/service/code/impl/LocalCodePoolExecutorService.java b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/service/code/impl/LocalCodePoolExecutorService.java new file mode 100644 index 0000000000..5c6b4ec007 --- /dev/null +++ b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/service/code/impl/LocalCodePoolExecutorService.java @@ -0,0 +1,263 @@ +/* + * Copyright 2024-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.cloud.ai.service.code.impl; + +import com.alibaba.cloud.ai.config.CodeExecutorProperties; +import com.alibaba.cloud.ai.service.code.CodePoolExecutorService; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.util.StringUtils; + +import java.io.BufferedReader; +import java.io.File; +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.StringWriter; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * 使用本地Python3环境运行代码的实现类,要求本地的Python3需要有pandas等数据分析库。 + * + * @author vlsmb + * @since 2025/8/23 + */ +public class LocalCodePoolExecutorService extends AbstractCodePoolExecutorService implements CodePoolExecutorService { + + private static final Logger logger = LoggerFactory.getLogger(LocalCodePoolExecutorService.class); + + private final ConcurrentHashMap containers; + + private static final String[] pythonNames = new String[] { "python3", "pypy3", "py3", "python", "pypy", "py" }; + + private static final String[] pipNames = new String[] { "pip3", "pip" }; + + // 对于本地运行这个实现类,“容器”为临时文件夹 + public LocalCodePoolExecutorService(CodeExecutorProperties properties) { + super(properties); + this.containers = new ConcurrentHashMap<>(); + if (this.checkProgramExists(pythonNames) == null) { + throw new IllegalStateException( + "No valid Python interpreter was found for the current system environment variables. Please install Python3 into the system environment variables first."); + } + } + + @Override + protected String createNewContainer() throws Exception { + Path container = Files.createTempDirectory(this.properties.getContainerNamePrefix()); + String containerId = container.toString(); + this.containers.put(containerId, container); + return containerId; + } + + @Override + protected TaskResponse execTaskInContainer(TaskRequest request, String containerId) { + Path container = this.containers.get(containerId); + + // 写入Py代码和标准输入 + Path scriptFile = container.resolve("script.py"); + Path stdinFile = container.resolve("stdin.txt"); + Path requirementFile = container.resolve("requirements.txt"); + try { + Files.write(scriptFile, Optional.ofNullable(request.code()).orElse("").getBytes()); + Files.write(stdinFile, Optional.ofNullable(request.input()).orElse("").getBytes()); + Files.write(requirementFile, Optional.ofNullable(request.requirement()).orElse("").getBytes()); + } + catch (Exception e) { + logger.error("Create temp file failed: {}", e.getMessage(), e); + return TaskResponse.exception(e.getMessage()); + } + + // 如果有requirements,则先安装依赖 + if (this.checkProgramExists(pipNames) != null && StringUtils.hasText(request.requirement())) { + ProcessBuilder pip = new ProcessBuilder(this.checkProgramExists(pipNames), "install", "--no-cache-dir", + "-r", requirementFile.toAbsolutePath().toString(), ">", "/dev/null"); + Process process = null; + + try { + process = pip.start(); + boolean completed = process.waitFor(this.properties.getContainerTimeout(), TimeUnit.MINUTES); + if (!completed) { + process.destroy(); + if (process.isAlive()) { + process.destroyForcibly(); + } + throw new RuntimeException("Pip command timed out."); + } + } + catch (Exception e) { + // 即使PIP安装失败,仍然尝试运行Python代码 + logger.warn("Pip install failed: {}", e.getMessage(), e); + } + finally { + if (process != null && process.isAlive()) { + process.destroyForcibly(); + } + } + } + + // 运行Python代码 + Process process = null; + try { + ProcessBuilder pb = new ProcessBuilder(this.checkProgramExists(pythonNames), + scriptFile.toAbsolutePath().toString()); + pb.directory(container.toFile()); + pb.redirectInput(stdinFile.toFile()); + process = pb.start(); + + // 读取stdout和stderr + StringWriter stdoutWriter = new StringWriter(); + StringWriter stderrWriter = new StringWriter(); + try (BufferedReader stdoutReader = new BufferedReader(new InputStreamReader(process.getInputStream())); + BufferedReader stderrReader = new BufferedReader(new InputStreamReader(process.getErrorStream()))) { + CompletableFuture stdoutFuture = CompletableFuture.runAsync(() -> { + try { + stdoutReader.transferTo(stdoutWriter); + } + catch (IOException e) { + stderrWriter.write("Error reading stdout: " + e.getMessage()); + } + }); + CompletableFuture stderrFuture = CompletableFuture.runAsync(() -> { + try { + stderrReader.transferTo(stderrWriter); + } + catch (IOException e) { + stderrWriter.write("Error reading stderr: " + e.getMessage()); + } + }); + + // 等待进程完成,带超时限制 + boolean completed = process.waitFor(this.parseToMilliseconds(this.properties.getCodeTimeout()), + TimeUnit.MILLISECONDS); + if (!completed) { + process.destroy(); + if (process.isAlive()) { + process.destroyForcibly(); + } + return TaskResponse.failure("", "python code timeout, Killed."); + } + + // 等待输出读取完成,给输出读取额外2秒时间 + CompletableFuture.allOf(stdoutFuture, stderrFuture).get(2, TimeUnit.SECONDS); + } + + // 返回结果 + int exitCode = process.exitValue(); + String stdout = stdoutWriter.toString(); + String stderr = stderrWriter.toString(); + if (exitCode != 0) { + return TaskResponse.failure(stdout, stderr); + } + else { + return TaskResponse.success(stdout); + } + + } + catch (Exception e) { + logger.error("Python execution failed: {}", e.getMessage(), e); + return TaskResponse.exception(e.getMessage()); + } + finally { + if (process != null && process.isAlive()) { + process.destroyForcibly(); + } + } + } + + @Override + protected void stopContainer(String containerId) throws Exception { + // 临时文件夹没有停止方法 + } + + @Override + protected void removeContainer(String containerId) throws Exception { + Path container = this.containers.remove(containerId); + this.clearTempDir(container); + } + + /** + * 按顺序检查多个程序是否存在 + * @param programNames 程序名称,按优先级顺序 + * @return 第一个找到的程序名称,如果都没找到返回null + */ + private String checkProgramExists(String... programNames) { + if (programNames == null) + return null; + + String pathEnv = System.getenv("PATH"); + if (pathEnv == null) + return null; + + String[] pathDirs = pathEnv.split(File.pathSeparator); + boolean isWindows = System.getProperty("os.name").toLowerCase().contains("win"); + + for (String program : programNames) { + for (String dir : pathDirs) { + if (dir == null || dir.trim().isEmpty()) + continue; + + // 检查原始程序名 + Path path = Paths.get(dir, program); + if (Files.exists(path) && Files.isExecutable(path)) { + return program; + } + + // 在Windows上检查.exe后缀 + if (isWindows) { + Path exePath = Paths.get(dir, program + ".exe"); + if (Files.exists(exePath) && Files.isExecutable(exePath)) { + return program; + } + } + } + } + return null; + } + + private long parseToMilliseconds(String timeString) { + Pattern pattern = Pattern.compile("(\\d+)(ms|[smhd])"); + Matcher matcher = pattern.matcher(timeString.toLowerCase()); + + if (matcher.find()) { + long value = Long.parseLong(matcher.group(1)); + String unit = matcher.group(2); + return switch (unit) { + case "ms" -> value; + case "s" -> value * 1000; + case "m" -> value * 60 * 1000; + case "h" -> value * 60 * 60 * 1000; + case "d" -> value * 24 * 60 * 60 * 1000; + default -> { + logger.warn("Unknown time unit: {}", unit); + // 返回默认值60s + yield 60 * 1000; + } + }; + } + logger.warn("Invalid time format: {}", timeString); + return 60 * 1000; + } + +} diff --git a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/resources/prompts/python-analyze.txt b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/resources/prompts/python-analyze.txt index ec8225a818..5a37060786 100644 --- a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/resources/prompts/python-analyze.txt +++ b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/resources/prompts/python-analyze.txt @@ -1,9 +1,9 @@ +# ROLE: 数据分析报告撰写专家 + 你是一位专业的数据分析报告撰写专家,擅长将复杂的数据分析结果转化为清晰、准确、易懂的自然语言总结。 你的任务是根据用户的原始查询需求和Python脚本的分析输出结果,生成一段**结构清晰、语言简洁、内容准确**的总结性描述。 -请遵循以下要求: - --- ### 输入信息 @@ -27,4 +27,82 @@ --- -请根据以上信息,生成符合要求的总结内容: +### 分析与总结逻辑 + +1. **理解用户查询**:仔细分析用户的原始查询,明确其核心需求和关注点。 +2. **解析分析结果**:根据Python输出结果,提取其中的关键信息,并确保总结内容与用户查询高度相关。 +3. **处理特殊情况**: + - 如果Python输出为空或异常,直接说明“未找到相关数据”或“分析过程中出现错误”。 + - 如果结果中包含多个维度或指标,按重要性排序,优先总结最核心的内容。 +4. **语言优化**:使用简练的语言表达,避免冗长或复杂的句式,确保总结易于理解。 + +--- + +### 示例 + +#### 示例 1:正常分析结果 +**用户原始查询**: +“统计各渠道的线索数量和转化率。” + +**Python分析结果**: +```json +\{ + "channel_stats": [ + \{"channel": "线上广告", "lead_count": 500, "conversion_rate": 0.15\}, + \{"channel": "线下活动", "lead_count": 300, "conversion_rate": 0.25\}, + \{"channel": "合作伙伴", "lead_count": 200, "conversion_rate": 0.1\} + ] +\} +``` + +**总结**: +线上广告带来了500条线索,转化率为15%;线下活动带来了300条线索,转化率为25%;合作伙伴带来了200条线索,转化率为10%。 + +--- + +#### 示例 2:空结果 +**用户原始查询**: +“分析过去一年的城市销售数据。” + +**Python分析结果**: +```json +\{\} +``` + +**总结**: +未找到相关数据。 + +--- + +#### 示例 3:异常结果 +**用户原始查询**: +“计算每个省份的平均订单金额。” + +**Python分析结果**: +分析过程中出现错误。 + +**总结**: +分析过程中出现错误。 + +--- + +### 输出模板 + +请根据以下模板生成总结内容: + +```text +[总结内容] +``` + +--- + +### 注意事项 + +1. **准确性**:总结必须完全基于Python分析结果,不得添加任何推测或主观内容。 +2. **简洁性**:尽量用简短的句子表达关键信息,避免冗余。 +3. **一致性**:确保总结内容与用户查询的需求一致,避免偏离主题。 +4. **适应性**:能够处理各种类型的分析结果,包括正常结果、空结果和异常结果。 + +--- + +请根据以上规则生成符合要求的总结内容: diff --git a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/resources/prompts/python-generator.txt b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/resources/prompts/python-generator.txt index f66425c830..337f27388c 100644 --- a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/resources/prompts/python-generator.txt +++ b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/resources/prompts/python-generator.txt @@ -1,11 +1,13 @@ +# ROLE: 专业Python数据分析工程师 + 你是一个专业的Python数据分析工程师,你的任务是根据用户提供的自然语言分析需求、数据库表结构和SQL查询结果样例,编写一段**可直接运行的无状态Python脚本**。 **请严格遵循以下规范生成代码**: -1. 【纯净输出】只输出可执行的Python代码,禁止包含任何额外说明或自然语言,在代码内部需要有适量的注释方便阅读。 - **特别注意**:模型输出的文本直接接入Python解释器运行,因此不要添加任何额外符号,**比如Markdown的代码块标记符号**! -2. 【输入规范】从`sys.stdin`读取JSON数组(List[Dict]),使用`json.load(sys.stdin)` -3. 【输出规范】最终结果必须是JSON对象(Dict),通过`print(json.dumps(result))`输出,JSON字段可以自定义,但要满足用户需求 -4. 【错误处理】使用以下结构捕获所有异常: +1. **纯净输出**:只输出可执行的Python代码,禁止包含任何额外说明或自然语言。在代码内部需要有适量的注释方便阅读。 + - **特别注意**:模型输出的文本直接接入Python解释器运行,因此不要添加任何额外符号,**比如Markdown的代码块标记符号**! +2. **输入规范**:从`sys.stdin`读取JSON数据(List[Dict]),使用`json.load(sys.stdin)`。 +3. **输出规范**:最终结果必须是JSON对象(Dict),通过`print(json.dumps(result, ensure_ascii=False))`输出,JSON字段可以自定义,但要满足用户需求。 +4. **错误处理**:使用以下结构捕获所有异常: ```python import traceback try: @@ -14,18 +16,113 @@ traceback.print_exc(file=sys.stderr) sys.exit(1) ``` -5. 【依赖限制】所有使用的库必须是`continuumio/anaconda3`默认安装的库,如`pandas`, `numpy`, `json`, `sys`等。 -6. 【动态处理】禁止硬编码列名/值,所有逻辑基于输入数据动态构建 -7. 【安全限制】禁止以下操作: - - 任何文件/网络操作(open/requests等) - - 系统调用(os/subprocess) - - 图形/绘图功能 - - 一些危险的库(pickle) -8. 【性能约束】单线程执行,最大内存:{python_memory} MB,超时时间:{python_timeout} +5. **依赖限制**:所有使用的库必须是`continuumio/anaconda3`默认安装的库,如`pandas`, `numpy`, `json`, `sys`等。 +6. **动态处理**:禁止硬编码列名/值,所有逻辑基于输入数据动态构建。 +7. **安全限制**:禁止以下操作: + - 任何文件/网络操作(open/requests等)。 + - 系统调用(os/subprocess)。 + - 图形/绘图功能。 + - 一些危险的库(pickle)。 +8. **性能约束**:单线程执行,最大内存:{python_memory} MB,超时时间:{python_timeout}。 **核心要求**:生成的代码必须满足: -① 输入SQL结果JSON → ② 执行分析 → ③ 输出JSON结果 的完整闭环 -④ 异常时通过stderr提供可调试的完整堆栈信息 +① 输入SQL结果JSON → ② 执行分析 → ③ 输出JSON结果 的完整闭环。 +④ 异常时通过stderr提供可调试的完整堆栈信息。 + +以下是生成代码的模板,请根据具体需求填充逻辑: + +```python +import sys +import json +import traceback +import pandas as pd + +# 错误处理 +try: + # 从stdin读取输入数据 + input_data = json.load(sys.stdin) + + # 将输入数据转换为DataFrame以便于分析 + df = pd.DataFrame(input_data) + + # 动态分析逻辑 + # 示例:计算某些统计指标 + result = \{ + "summary": \{\}, + "details": [] + \} + + # 示例逻辑:计算每列的平均值(可根据需求调整) + for column in df.columns: + if pd.api.types.is_numeric_dtype(df[column]): + result["summary"][column] = \{ + "mean": df[column].mean(), + "min": df[column].min(), + "max": df[column].max() + \} + + # 示例逻辑:将原始数据分组并统计(可根据需求调整) + grouped = df.groupby(list(df.columns[:2])).size().reset_index(name="count") + result["details"] = grouped.to_dict(orient="records") + + # 输出结果为JSON对象 + print(json.dumps(result, ensure_ascii=False)) + +except Exception: + # 捕获异常并输出堆栈信息到stderr + traceback.print_exc(file=sys.stderr) + sys.exit(1) +``` + +--- + +# 注意事项 + +1. **输入验证**:确保代码能够正确处理空输入或格式不正确的输入,并在异常时提供清晰的错误信息。**处理的数据必须来自`json.load(sys.stdin)`**。 +2. **性能优化**:尽量减少不必要的计算和内存占用,确保代码在性能约束内高效运行。 +3. **结果完整性**:输出的JSON对象应全面反映分析结果,且字段命名清晰易懂。 + +--- + +# 示例输出 + +假设用户需求是“统计每个渠道的线索数量和转化率”,生成的代码可能如下: + +```python +import sys +import json +import traceback +import pandas as pd + +try: + # 从stdin读取输入数据 + input_data = json.load(sys.stdin) + + # 转换为DataFrame + df = pd.DataFrame(input_data) + + # 动态分析逻辑 + result = \{ + "channel_stats": [] + \} + + # 计算每个渠道的线索数量和转化率 + if "channel" in df.columns and "conversion" in df.columns: + grouped = df.groupby("channel").agg( + lead_count=("conversion", "size"), + conversion_rate=("conversion", "mean") + ).reset_index() + + result["channel_stats"] = grouped.to_dict(orient="records") + + # 输出结果为JSON对象 + print(json.dumps(result, ensure_ascii=False)) + +except Exception: + # 捕获异常并输出堆栈信息到stderr + traceback.print_exc(file=sys.stderr) + sys.exit(1) +``` === 上下文信息 === @@ -47,6 +144,8 @@ {plan_description} ``` +--- + === 用户输入 === 接下来是用户的需求: diff --git a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/test/java/com/alibaba/cloud/ai/service/code/DockerCodePoolExecutorServiceTest.java b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/test/java/com/alibaba/cloud/ai/service/code/DockerCodePoolExecutorServiceTest.java index 6858703f04..1dce20e9cd 100644 --- a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/test/java/com/alibaba/cloud/ai/service/code/DockerCodePoolExecutorServiceTest.java +++ b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/test/java/com/alibaba/cloud/ai/service/code/DockerCodePoolExecutorServiceTest.java @@ -26,7 +26,6 @@ import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.test.context.SpringBootTest; -import org.springframework.test.context.ActiveProfiles; import org.springframework.util.StringUtils; import java.util.concurrent.CountDownLatch; @@ -38,7 +37,6 @@ @SpringBootTest(classes = { CodeExecutorProperties.class }) @DisplayName("Run Python Code in Docker Test Without Network") -@ActiveProfiles("docker") public class DockerCodePoolExecutorServiceTest { private static final Logger log = LoggerFactory.getLogger(DockerCodePoolExecutorServiceTest.class); @@ -50,6 +48,8 @@ public class DockerCodePoolExecutorServiceTest { @BeforeEach public void init() { + this.properties.setCodeTimeout("5s"); + this.properties.setCodePoolExecutor(CodePoolExecutorEnum.DOCKER); this.codePoolExecutorService = new DockerCodePoolExecutorService(properties); } @@ -81,7 +81,8 @@ private void testTimeoutCode() { .runTask(new CodePoolExecutorService.TaskRequest(CodeTestConstant.TIMEOUT_CODE, "", null)); System.out.println(response); log.info("Run Code with Endless Loop Finished"); - if (response.isSuccess() || !response.toString().contains("Killed")) { + if (response.isSuccess() || !response.toString().contains("Killed") + || !response.executionSuccessButResultFailed()) { throw new RuntimeException("Test Failed"); } } @@ -92,7 +93,8 @@ private void testErrorCode() { .runTask(new CodePoolExecutorService.TaskRequest(CodeTestConstant.ERROR_CODE, "", null)); System.out.println(response); log.info("Run Code with Syntax Error Finished"); - if (response.isSuccess() || !response.toString().contains("SyntaxError")) { + if (response.isSuccess() || !response.toString().contains("SyntaxError") + || !response.executionSuccessButResultFailed()) { throw new RuntimeException("Test Failed"); } } diff --git a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/test/java/com/alibaba/cloud/ai/service/code/LocalCodePoolExecutorServiceTest.java b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/test/java/com/alibaba/cloud/ai/service/code/LocalCodePoolExecutorServiceTest.java new file mode 100644 index 0000000000..f8ca71fbce --- /dev/null +++ b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/test/java/com/alibaba/cloud/ai/service/code/LocalCodePoolExecutorServiceTest.java @@ -0,0 +1,160 @@ +/* + * Copyright 2024-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.cloud.ai.service.code; + +import com.alibaba.cloud.ai.config.CodeExecutorProperties; +import com.alibaba.cloud.ai.service.code.impl.LocalCodePoolExecutorService; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.util.StringUtils; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Consumer; + +@SpringBootTest(classes = { CodeExecutorProperties.class }) +@DisplayName("Run Python Code in Local Command Test") +public class LocalCodePoolExecutorServiceTest { + + private static final Logger logger = LoggerFactory.getLogger(LocalCodePoolExecutorServiceTest.class); + + @Autowired + private CodeExecutorProperties properties; + + private CodePoolExecutorService codePoolExecutorService = null; + + @BeforeEach + public void init() { + this.properties.setCodeTimeout("5s"); + this.properties.setCodePoolExecutor(CodePoolExecutorEnum.LOCAL); + this.codePoolExecutorService = new LocalCodePoolExecutorService(properties); + } + + private void testNormalCode() { + logger.info("Run Normal Code"); + CodePoolExecutorService.TaskResponse response = codePoolExecutorService + .runTask(new CodePoolExecutorService.TaskRequest(CodeTestConstant.NORMAL_CODE, "", null)); + System.out.println(response); + logger.info("Run Normal Code Finished"); + if (!response.isSuccess() || !response.stdOut().contains("3628800")) { + throw new RuntimeException("Test Failed"); + } + } + + private void testTimeoutCode() { + logger.info("Run Code with Endless Loop"); + CodePoolExecutorService.TaskResponse response = codePoolExecutorService + .runTask(new CodePoolExecutorService.TaskRequest(CodeTestConstant.TIMEOUT_CODE, "", null)); + System.out.println(response); + logger.info("Run Code with Endless Loop Finished"); + if (response.isSuccess() || !response.toString().contains("Killed") + || !response.executionSuccessButResultFailed()) { + throw new RuntimeException("Test Failed"); + } + } + + private void testErrorCode() { + logger.info("Run Code with Syntax Error"); + CodePoolExecutorService.TaskResponse response = codePoolExecutorService + .runTask(new CodePoolExecutorService.TaskRequest(CodeTestConstant.ERROR_CODE, "", null)); + System.out.println(response); + logger.info("Run Code with Syntax Error Finished"); + if (response.isSuccess() || !response.toString().contains("SyntaxError") + || !response.executionSuccessButResultFailed()) { + throw new RuntimeException("Test Failed"); + } + } + + private void testNeedInput() { + logger.info("Check Need Input"); + CodePoolExecutorService.TaskResponse response = codePoolExecutorService + .runTask(new CodePoolExecutorService.TaskRequest(CodeTestConstant.NEED_INPUT, "DataFrame Data", null)); + System.out.println(response); + logger.info("Run Need Input Finished"); + if (!response.isSuccess() || !response.stdOut().contains("DataFrame Data")) { + throw new RuntimeException("Test Failed"); + } + } + + private void testStudentScoreAnalysis() { + logger.info("Run Student Score Analysis"); + CodePoolExecutorService.TaskResponse response = codePoolExecutorService + .runTask(new CodePoolExecutorService.TaskRequest(CodeTestConstant.STUDENT_SCORE_ANALYSIS, + CodeTestConstant.STUDENT_SCORE_ANALYSIS_INPUT, null)); + System.out.println(response); + logger.info("Run Student Score Analysis Finished"); + if (!response.isSuccess() || !StringUtils.hasText(response.stdOut())) { + throw new RuntimeException("Test Failed"); + } + } + + @Test + public void testPandasCode() { + logger.info("Run Pandas Code"); + CodePoolExecutorService.TaskResponse response = codePoolExecutorService + .runTask(new CodePoolExecutorService.TaskRequest(CodeTestConstant.ECOMMERCE_SALES_PANDAS_CODE, + CodeTestConstant.ECOMMERCE_SALES_PANDAS_INPUT, null)); + System.out.println(response); + logger.info("Run Pandas Code Finished"); + assert response.isSuccess() + || (response.executionSuccessButResultFailed() && response.toString().contains("ModuleNotFoundError")); + } + + @Test + @DisplayName("Concurrency Testing") + public void testConcurrency() throws InterruptedException { + ExecutorService executorService = Executors.newFixedThreadPool(10); + final int taskNum = 5; + CountDownLatch countDownLatch = new CountDownLatch(taskNum); + AtomicInteger successTask = new AtomicInteger(0); + + Consumer> submitTask = consumer -> { + executorService.submit(() -> { + try { + consumer.accept(this); + successTask.incrementAndGet(); + } + catch (Exception e) { + throw new RuntimeException(e); + } + finally { + countDownLatch.countDown(); + } + }); + }; + + submitTask.accept(LocalCodePoolExecutorServiceTest::testNormalCode); + submitTask.accept(LocalCodePoolExecutorServiceTest::testTimeoutCode); + submitTask.accept(LocalCodePoolExecutorServiceTest::testErrorCode); + submitTask.accept(LocalCodePoolExecutorServiceTest::testNeedInput); + submitTask.accept(LocalCodePoolExecutorServiceTest::testStudentScoreAnalysis); + + assert countDownLatch.await(600L, TimeUnit.SECONDS); + logger.info("Success Task Number: {}", successTask.get()); + Assertions.assertEquals(taskNum, successTask.get()); + } + +} diff --git a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/test/resources/application-docker.yml b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/test/resources/application-docker.yml deleted file mode 100644 index 8ac19132a4..0000000000 --- a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/test/resources/application-docker.yml +++ /dev/null @@ -1,8 +0,0 @@ -spring: - ai: - alibaba: - nl2sql: - code-executor: - code-pool-executor: docker - code-timeout: 5s - enabled: true diff --git a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-management/src/main/resources/application.yml b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-management/src/main/resources/application.yml index 07a2e807e9..c143d3f105 100644 --- a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-management/src/main/resources/application.yml +++ b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-management/src/main/resources/application.yml @@ -52,7 +52,8 @@ spring: alibaba: nl2sql: code-executor: - code-pool-executor: ai_simulation + # 运行Python代码的环境(生产环境建议使用docker,不建议使用local) + code-pool-executor: local # MyBatis Plus 配置 mybatis-plus: diff --git a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-web-ui/src/views/AgentRun.vue b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-web-ui/src/views/AgentRun.vue index c9b449dfb6..ff9320cf1c 100644 --- a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-web-ui/src/views/AgentRun.vue +++ b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-web-ui/src/views/AgentRun.vue @@ -268,7 +268,7 @@ @@ -320,6 +320,15 @@ import { ref, reactive, computed, onMounted, onUnmounted, nextTick, watch } from import { useRoute, useRouter } from 'vue-router' import { presetQuestionApi } from '../utils/api.js' +import hljs from 'highlight.js'; +import 'highlight.js/styles/github.css'; +import python from 'highlight.js/lib/languages/python'; +import sql from 'highlight.js/lib/languages/sql' + +// 注册语言 +hljs.registerLanguage('python', python); +hljs.registerLanguage('sql', sql); + export default { name: 'AgentRun', setup() { @@ -558,6 +567,8 @@ export default { content: message, timestamp: new Date() } + + console.log("userMessage: " + userMessage); currentMessages.value.push(userMessage) @@ -588,8 +599,9 @@ export default { }) const streamState = { - contentByType: {}, - typeOrder: [], + contentByIndex: [], + typeByIndex: [], + lastType: "" } const typeMapping = { @@ -613,9 +625,10 @@ export default { const updateDisplay = () => { let fullContent = '
' - for (const type of streamState.typeOrder) { + for(let i = 0; i < streamState.contentByIndex.length; i++) { + const type = streamState.typeByIndex[i]; const typeInfo = typeMapping[type] || { title: type, icon: 'bi bi-file-text' } - const content = streamState.contentByType[type] || '' + const content = streamState.contentByIndex[i] || '' const formattedSubContent = formatContentByType(type, content) fullContent += `
@@ -679,14 +692,17 @@ export default { if (actualType === 'sql' && typeof processedData === 'string') { processedData = processedData.replace(/^```\s*sql?\s*/i, '').replace(/```\s*$/, '').trim() } - - if (!streamState.contentByType.hasOwnProperty(actualType)) { - streamState.typeOrder.push(actualType) - streamState.contentByType[actualType] = '' + + // 增加状态判断,如果当前节点的type与上一个type不同,则说明应该另外起一个Content + console.log("lastType: " + streamState.lastType + ", actualType: " + actualType); + if (streamState.lastType !== actualType) { + streamState.typeByIndex.push(actualType); + streamState.contentByIndex.push(""); + streamState.lastType = actualType; } if (processedData) { - streamState.contentByType[actualType] += processedData + streamState.contentByIndex[streamState.contentByIndex.length - 1] += processedData; } updateDisplay() @@ -803,6 +819,11 @@ export default { sendMessage() } } + + // 发送按钮不能直接接入sendMessage函数,因为会把event当作参数传递进去,导致message不为字符串 + const handleSendBtnPressed = (event) => { + sendMessage(); + } const adjustTextareaHeight = () => { const textarea = messageInput.value @@ -1115,6 +1136,25 @@ export default { return `
${cleanedData}
`; } + if (type === 'python_generate') { + // 处理可能存在的Markdown标记(正常情况下不会有) + let cleanedData = data.replace(/^```\s*python?\s*/i, '').replace(/```\s*$/, '').trim(); + + // 创建code元素 + const codeElement = document.createElement('code'); + codeElement.className = 'language-python'; + codeElement.textContent = cleanedData; + + // 高亮代码 + hljs.highlightElement(codeElement); + + // 创建pre元素并包装code元素 + const preElement = document.createElement('pre'); + preElement.appendChild(codeElement); + + return preElement.outerHTML; + } + if (type === 'result') { return convertJsonToHTMLTable(data); } @@ -2558,6 +2598,7 @@ export default { sendMessage, sendQuickMessage, handleKeyDown, + handleSendBtnPressed, adjustTextareaHeight, formatMessage, formatTime, diff --git a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-web-ui/src/views/AgentWorkspace.vue b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-web-ui/src/views/AgentWorkspace.vue index e9057b394c..0ebdfc14aa 100644 --- a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-web-ui/src/views/AgentWorkspace.vue +++ b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-web-ui/src/views/AgentWorkspace.vue @@ -195,6 +195,15 @@ import { ref, onMounted, nextTick } from 'vue'; import { useRouter } from 'vue-router'; import { agentApi, presetQuestionApi } from '../utils/api.js'; +import hljs from 'highlight.js'; +import 'highlight.js/styles/github.css'; +import python from 'highlight.js/lib/languages/python'; +import sql from 'highlight.js/lib/languages/sql' + +// 注册语言 +hljs.registerLanguage('python', python); +hljs.registerLanguage('sql', sql); + export default { name: 'AgentWorkspace', setup() { @@ -422,6 +431,9 @@ export default { const formatContentByType = (type, data) => { if (data === null || data === undefined) return ''; + console.log(type); + console.log(data); + if (type === 'sql') { let cleanedData = data.replace(/^```\s*sql?\s*/i, '').replace(/```\s*$/, '').trim(); // 处理SQL中的转义换行符 @@ -433,6 +445,25 @@ export default { return convertJsonToHTMLTable(data); } + if (type === 'python_generate') { + // 处理可能存在的Markdown标记(正常情况下不会有) + let cleanedData = data.replace(/^```\s*python?\s*/i, '').replace(/```\s*$/, '').trim(); + + // 创建code元素 + const codeElement = document.createElement('code'); + codeElement.className = 'language-python'; + codeElement.textContent = cleanedData; + + // 高亮代码 + hljs.highlightElement(codeElement); + + // 创建pre元素并包装code元素 + const preElement = document.createElement('pre'); + preElement.appendChild(codeElement); + + return preElement.outerHTML; + } + // 直接处理数据,简化逻辑 let processedData = data;