diff --git a/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/BranchNode.java b/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/BranchNode.java deleted file mode 100644 index 71e9ea124f..0000000000 --- a/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/BranchNode.java +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Copyright 2024-2025 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.alibaba.cloud.ai.graph.node; - -import com.alibaba.cloud.ai.graph.OverAllState; -import com.alibaba.cloud.ai.graph.action.NodeAction; -import java.util.HashMap; -import java.util.Map; -import org.springframework.util.StringUtils; - -public class BranchNode implements NodeAction { - - private final String inputKey; - - private final String outputKey; - - public BranchNode(String outputKey, String inputKey) { - if (!StringUtils.hasLength(inputKey) || !StringUtils.hasLength(outputKey)) { - throw new IllegalArgumentException("inputKey and outputKey must not be null or empty."); - } - this.inputKey = inputKey; - this.outputKey = outputKey; - } - - @Override - public Map apply(OverAllState state) throws Exception { - Map updatedState = new HashMap<>(); - String value = state.value(inputKey).map(Object::toString).orElse(null); - updatedState.put(this.outputKey, value); - return updatedState; - } - - public static Builder builder() { - return new Builder(); - } - - public static class Builder { - - private String outputKey; - - private String inputKey; - - public BranchNode build() { - return new BranchNode(outputKey, inputKey); - } - - public Builder outputKey(String outputKey) { - this.outputKey = outputKey; - return this; - } - - public Builder inputKey(String inputKey) { - this.inputKey = inputKey; - return this; - } - - } - -} diff --git a/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/CodeExecutorNodeAction.java b/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/CodeExecutorNodeAction.java index 6490a9cc3b..d6452e986a 100644 --- a/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/CodeExecutorNodeAction.java +++ b/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/CodeExecutorNodeAction.java @@ -17,9 +17,10 @@ import java.util.ArrayList; import java.util.HashMap; -import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; import com.alibaba.cloud.ai.graph.OverAllState; import com.alibaba.cloud.ai.graph.action.NodeAction; @@ -27,6 +28,8 @@ import com.alibaba.cloud.ai.graph.node.code.entity.CodeExecutionConfig; import com.alibaba.cloud.ai.graph.node.code.entity.CodeExecutionResult; import com.alibaba.cloud.ai.graph.node.code.entity.CodeLanguage; +import com.alibaba.cloud.ai.graph.node.code.entity.CodeParam; +import com.alibaba.cloud.ai.graph.node.code.entity.CodeStyle; import com.alibaba.cloud.ai.graph.node.code.entity.RunnerAndPreload; import com.alibaba.cloud.ai.graph.node.code.javascript.NodeJsTemplateTransformer; import com.alibaba.cloud.ai.graph.node.code.python3.Python3TemplateTransformer; @@ -47,10 +50,12 @@ public class CodeExecutorNodeAction implements NodeAction { private final CodeExecutionConfig codeExecutionConfig; - private Map params; + private final List params; private final String outputKey; + private final CodeStyle style; + private static final Map CODE_TEMPLATE_TRANSFORMERS = Map.of( CodeLanguage.PYTHON3, new Python3TemplateTransformer(), CodeLanguage.PYTHON, new Python3TemplateTransformer(), CodeLanguage.JAVASCRIPT, new NodeJsTemplateTransformer(), @@ -61,24 +66,25 @@ CodeLanguage.PYTHON3, new Python3TemplateTransformer(), CodeLanguage.PYTHON, CodeLanguage.PYTHON3.getValue(), CodeLanguage.PYTHON, CodeLanguage.PYTHON.getValue(), CodeLanguage.JAVA, CodeLanguage.JAVA.getValue()); - public CodeExecutorNodeAction(CodeExecutor codeExecutor, String codeLanguage, String code, - CodeExecutionConfig config, Map params, String outputKey) { + public CodeExecutorNodeAction(CodeExecutor codeExecutor, String codeLanguage, String code, CodeStyle style, + CodeExecutionConfig config, List params, String outputKey) { this.codeExecutor = codeExecutor; this.codeLanguage = codeLanguage; + this.style = style; this.code = code; this.codeExecutionConfig = config; this.params = params; this.outputKey = outputKey; } - private Map executeWorkflowCodeTemplate(CodeLanguage language, String code, List inputs) - throws Exception { + private Map executeWorkflowCodeTemplate(CodeLanguage language, String code, + Map inputs) throws Exception { TemplateTransformer templateTransformer = CODE_TEMPLATE_TRANSFORMERS.get(language); if (templateTransformer == null) { throw new RuntimeException("Unsupported language: " + language); } - RunnerAndPreload runnerAndPreload = templateTransformer.transformCaller(code, inputs); + RunnerAndPreload runnerAndPreload = templateTransformer.transformCaller(code, inputs, style); String response = executeCode(language, runnerAndPreload.preloadScript(), runnerAndPreload.runnerScript()); return templateTransformer.transformResponse(response); @@ -100,12 +106,12 @@ private String executeCode(CodeLanguage language, String preloadScript, String c @Override public Map apply(OverAllState state) throws Exception { - List inputs = new ArrayList<>(10); - if (params != null && !params.isEmpty()) { - for (String key : params.keySet()) { - inputs.add(state.data().get((String) params.get(key))); - } - } + Map inputs = Optional.ofNullable(params) + .orElse(List.of()) + .stream() + .collect(Collectors.toUnmodifiableMap(CodeParam::argName, param -> Optional.ofNullable(param.value()) + .or(() -> StringUtils.hasText(param.stateKey()) ? state.value(param.stateKey()) : Optional.empty()) + .orElseThrow(() -> new IllegalStateException("param has no value and legal key!")))); Map resultObjectMap = executeWorkflowCodeTemplate(CodeLanguage.fromValue(codeLanguage), code, inputs); Map updatedState = new HashMap<>(); @@ -127,13 +133,16 @@ public static class Builder { private String code; + private CodeStyle style; + private CodeExecutionConfig config; - private Map params; + private List params; private String outputKey; public Builder() { + style = CodeStyle.EXPLICIT_PARAMETERS; } public Builder codeExecutor(CodeExecutor codeExecutor) { @@ -151,13 +160,18 @@ public Builder code(String code) { return this; } + public Builder codeStyle(CodeStyle style) { + this.style = style; + return this; + } + public Builder config(CodeExecutionConfig config) { this.config = config; return this; } - public Builder params(Map params) { - this.params = new LinkedHashMap<>(params); + public Builder params(List params) { + this.params = List.copyOf(params); return this; } @@ -167,7 +181,7 @@ public Builder outputKey(String outputKey) { } public CodeExecutorNodeAction build() { - return new CodeExecutorNodeAction(codeExecutor, codeLanguage, code, config, params, outputKey); + return new CodeExecutorNodeAction(codeExecutor, codeLanguage, code, style, config, params, outputKey); } } diff --git a/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/TemplateTransformer.java b/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/TemplateTransformer.java index 2d6e6c562e..5b2048dc00 100644 --- a/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/TemplateTransformer.java +++ b/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/TemplateTransformer.java @@ -16,12 +16,12 @@ package com.alibaba.cloud.ai.graph.node.code; +import com.alibaba.cloud.ai.graph.node.code.entity.CodeStyle; import com.alibaba.cloud.ai.graph.node.code.entity.RunnerAndPreload; import com.fasterxml.jackson.databind.ObjectMapper; import java.nio.charset.StandardCharsets; import java.util.Base64; -import java.util.List; import java.util.Map; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -38,8 +38,8 @@ public abstract class TemplateTransformer { protected static final String RESULT_TAG = "<>"; - public RunnerAndPreload transformCaller(String code, List inputs) throws Exception { - String runnerScript = assembleRunnerScript(code, inputs); + public RunnerAndPreload transformCaller(String code, Map inputs, CodeStyle style) throws Exception { + String runnerScript = assembleRunnerScript(code, inputs, style); String preloadScript = getPreloadScript(); return new RunnerAndPreload(runnerScript, preloadScript); @@ -52,7 +52,7 @@ public Map transformResponse(String response) throws Exception { mapper.getTypeFactory().constructMapType(Map.class, String.class, Object.class)); } - public abstract String getRunnerScript(); + public abstract String getRunnerScript(CodeStyle style); private String extractResultStrFromResponse(String response) { Pattern pattern = Pattern.compile(RESULT_TAG + "(.*?)" + RESULT_TAG, Pattern.DOTALL); @@ -66,14 +66,14 @@ private String extractResultStrFromResponse(String response) { } } - private String serializeInputs(List inputs) throws Exception { + private String serializeInputs(Map inputs) throws Exception { ObjectMapper mapper = new ObjectMapper(); String inputsJsonStr = mapper.writeValueAsString(inputs); return Base64.getEncoder().encodeToString(inputsJsonStr.getBytes(StandardCharsets.UTF_8)); } - private String assembleRunnerScript(String code, List inputs) throws Exception { - String script = getRunnerScript(); + private String assembleRunnerScript(String code, Map inputs, CodeStyle style) throws Exception { + String script = getRunnerScript(style); script = script.replace(CODE_PLACEHOLDER, code); script = script.replace(INPUTS_PLACEHOLDER, serializeInputs(inputs)); return script; diff --git a/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/entity/CodeParam.java b/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/entity/CodeParam.java new file mode 100644 index 0000000000..bfe2812fa2 --- /dev/null +++ b/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/entity/CodeParam.java @@ -0,0 +1,38 @@ +/* + * Copyright 2024-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.cloud.ai.graph.node.code.entity; + +/** + * @author vlsmb + * @since 2025/9/11 + * @param argName 参数在代码中对应的名称 + * @param value 参数值,如果为null,则从OverallState中获取 + * @param stateKey 参数在OverallState中的key,如果value不为null,则忽略stateKey + */ +public record CodeParam(String argName, Object value, String stateKey) { + public CodeParam(String argName, String stateKey) { + this(argName, null, stateKey); + } + + public static CodeParam withValue(String argName, Object value) { + return new CodeParam(argName, value, null); + } + + public static CodeParam withKey(String argName, String stateKey) { + return new CodeParam(argName, null, stateKey); + } +} diff --git a/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/entity/CodeStyle.java b/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/entity/CodeStyle.java new file mode 100644 index 0000000000..2cb61ab181 --- /dev/null +++ b/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/entity/CodeStyle.java @@ -0,0 +1,35 @@ +/* + * Copyright 2024-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.cloud.ai.graph.node.code.entity; + +/** + * @author vlsmb + * @since 2025/9/11 + */ +public enum CodeStyle { + + /** + * 参数直接作为函数形参的风格 示例: def main(x: int, y: int) -> dict: + */ + EXPLICIT_PARAMETERS, + + /** + * 参数通过全局字典访问的风格 示例: def main(): x = params['x'] + */ + GLOBAL_DICTIONARY + +} diff --git a/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/java/JavaTemplateTransformer.java b/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/java/JavaTemplateTransformer.java index 78ecb0af61..c81cbbcf51 100644 --- a/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/java/JavaTemplateTransformer.java +++ b/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/java/JavaTemplateTransformer.java @@ -15,10 +15,12 @@ */ package com.alibaba.cloud.ai.graph.node.code.java; +import com.alibaba.cloud.ai.graph.node.code.entity.CodeStyle; import com.alibaba.cloud.ai.graph.node.code.TemplateTransformer; /** * Java code template transformer Used to convert user code into executable Java programs + * user method should be {@code public static Object run([..args])} * * @author HeYQ * @since 2024-06-01 @@ -26,30 +28,87 @@ public class JavaTemplateTransformer extends TemplateTransformer { @Override - public String getRunnerScript() { - return """ - import java.util.*; - import com.fasterxml.jackson.databind.ObjectMapper; - import com.fasterxml.jackson.databind.node.ObjectNode; - - class Main { - public static void main(String[] args) throws Exception { - // Parse input parameters - String inputsBase64 = "%s"; - String inputsJson = new String(Base64.getDecoder().decode(inputsBase64)); - ObjectMapper mapper = new ObjectMapper(); - Object[] inputs = mapper.readValue(inputsJson, Object[].class); - - // Execute user code - Object result = main(inputs); - - // Output results - String output = mapper.writeValueAsString(result); - System.out.println("<>" + output + "<>"); - } - - %s - }""".formatted(INPUTS_PLACEHOLDER, CODE_PLACEHOLDER); + public String getRunnerScript(CodeStyle style) { + return switch (style) { + case EXPLICIT_PARAMETERS -> String.format( + """ + import java.util.*; + import com.fasterxml.jackson.databind.ObjectMapper; + import com.fasterxml.jackson.databind.node.ObjectNode; + import com.fasterxml.jackson.core.type.TypeReference; + import java.lang.reflect.Method; + + class Main { + public static void main(String[] args) throws Exception { + // Parse input parameters + String inputsBase64 = "%s"; + String inputsJson = new String(Base64.getDecoder().decode(inputsBase64)); + ObjectMapper mapper = new ObjectMapper(); + Map inputs = mapper.readValue(inputsJson, new TypeReference<>(){}); + + // Execute user code + Object result = invokeMethod(Main.class, "run", inputs); + + // Output results + String output = mapper.writeValueAsString(result); + System.out.println("%s" + output + "%s"); + } + + public static Object invokeMethod(Class clazz, String methodName, Map args) + throws Exception { + Method func = Arrays.stream(clazz.getMethods()).filter(method -> method.getName().equals(methodName)) + .filter(method -> method.getParameterCount() == args.size()) + .findFirst() + .orElseThrow(); + Object[] params = Arrays.stream(func.getParameters()).map( + parameter -> args.get(parameter.getName()) + ).toArray(); + return func.invoke(null, params); + } + + // user code + %s + } + """, + INPUTS_PLACEHOLDER, RESULT_TAG, RESULT_TAG, CODE_PLACEHOLDER); + case GLOBAL_DICTIONARY -> String.format(""" + import java.util.*; + import com.fasterxml.jackson.databind.ObjectMapper; + import com.fasterxml.jackson.databind.node.ObjectNode; + import com.fasterxml.jackson.core.type.TypeReference; + + class Main { + + private static final Map params; + + private static final ObjectMapper mapper; + + static { + try { + // Parse input parameters + String inputsBase64 = "%s"; + String inputsJson = new String(Base64.getDecoder().decode(inputsBase64)); + mapper = new ObjectMapper(); + params = mapper.readValue(inputsJson, new TypeReference<>(){}); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + public static void main(String[] args) throws Exception { + // Execute user code + Object result = run(); + + // Output results + String output = mapper.writeValueAsString(result); + System.out.println("%s" + output + "%s"); + } + + // user code + %s + } + """, INPUTS_PLACEHOLDER, RESULT_TAG, RESULT_TAG, CODE_PLACEHOLDER); + }; } } diff --git a/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/javascript/NodeJsTemplateTransformer.java b/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/javascript/NodeJsTemplateTransformer.java index ee8f25acf8..3909615e68 100644 --- a/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/javascript/NodeJsTemplateTransformer.java +++ b/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/javascript/NodeJsTemplateTransformer.java @@ -15,6 +15,7 @@ */ package com.alibaba.cloud.ai.graph.node.code.javascript; +import com.alibaba.cloud.ai.graph.node.code.entity.CodeStyle; import com.alibaba.cloud.ai.graph.node.code.TemplateTransformer; /** @@ -24,22 +25,39 @@ public class NodeJsTemplateTransformer extends TemplateTransformer { @Override - public String getRunnerScript() { - return """ - // declare main function - %s - - // decode and prepare input object - var inputs_obj = JSON.parse(Buffer.from('%s', 'base64').toString('utf-8')) - - // execute main function - var output_obj = main(inputs_obj) - - // convert output to json and print - var output_json = JSON.stringify(output_obj) - var result = `<>${output_json}<>` - console.log(result) - """.formatted(CODE_PLACEHOLDER, INPUTS_PLACEHOLDER); + public String getRunnerScript(CodeStyle style) { + return switch (style) { + case EXPLICIT_PARAMETERS -> String.format(""" + // declare main function + %s + + // decode and prepare input object + var inputs_obj = JSON.parse(Buffer.from('%s', 'base64').toString('utf-8')) + + // execute main function + var output_obj = main(inputs_obj) + + // convert output to json and print + var output_json = JSON.stringify(output_obj) + var result = `%s${output_json}%s` + console.log(result) + """, CODE_PLACEHOLDER, INPUTS_PLACEHOLDER, RESULT_TAG, RESULT_TAG); + case GLOBAL_DICTIONARY -> String.format(""" + // decode and prepare input object + let params = JSON.parse(Buffer.from('%s', 'base64').toString('utf-8')) + + // declare main function + %s + + // execute main function + var output_obj = main() + + // convert output to json and print + var output_json = JSON.stringify(output_obj) + var result = `%s${output_json}%s` + console.log(result) + """, INPUTS_PLACEHOLDER, CODE_PLACEHOLDER, RESULT_TAG, RESULT_TAG); + }; } } diff --git a/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/python3/Python3TemplateTransformer.java b/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/python3/Python3TemplateTransformer.java index c33900bc91..057f01fc56 100644 --- a/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/python3/Python3TemplateTransformer.java +++ b/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/python3/Python3TemplateTransformer.java @@ -15,6 +15,7 @@ */ package com.alibaba.cloud.ai.graph.node.code.python3; +import com.alibaba.cloud.ai.graph.node.code.entity.CodeStyle; import com.alibaba.cloud.ai.graph.node.code.TemplateTransformer; /** @@ -24,13 +25,34 @@ public class Python3TemplateTransformer extends TemplateTransformer { @Override - public String getRunnerScript() { - return String.join("\n", - new String[] { "# declare main function", CODE_PLACEHOLDER, "import json", - "from base64 import b64decode", - "inputs_obj = json.loads(b64decode('" + INPUTS_PLACEHOLDER + "').decode('utf-8'))", - "output_obj = main(*inputs_obj)", "output_json = json.dumps(output_obj, indent=4)", - "result = f'''<>{output_json}<>'''", "print(result)" }); + public String getRunnerScript(CodeStyle style) { + return switch (style) { + case EXPLICIT_PARAMETERS -> String.format(""" + # declare main function + %s + + import json + from base64 import b64decode + inputs_obj = json.loads(b64decode('%s').decode('utf-8')) + output_obj = main(**inputs_obj) + output_json = json.dumps(output_obj, indent=4) + result = f'''%s{output_json}%s''' + print(result) + """, CODE_PLACEHOLDER, INPUTS_PLACEHOLDER, RESULT_TAG, RESULT_TAG); + case GLOBAL_DICTIONARY -> String.format(""" + import json + from base64 import b64decode + params = json.loads(b64decode('%s').decode('utf-8')) + + # declare main function + %s + + output_obj = main() + output_json = json.dumps(output_obj, indent=4) + result = f'''%s{output_json}%s''' + print(result) + """, INPUTS_PLACEHOLDER, CODE_PLACEHOLDER, RESULT_TAG, RESULT_TAG); + }; } } diff --git a/spring-ai-alibaba-graph-core/src/test/java/com/alibaba/cloud/ai/graph/node/CodeActionTest.java b/spring-ai-alibaba-graph-core/src/test/java/com/alibaba/cloud/ai/graph/node/CodeActionTest.java index cad3c8041d..8d13d594d9 100644 --- a/spring-ai-alibaba-graph-core/src/test/java/com/alibaba/cloud/ai/graph/node/CodeActionTest.java +++ b/spring-ai-alibaba-graph-core/src/test/java/com/alibaba/cloud/ai/graph/node/CodeActionTest.java @@ -18,6 +18,8 @@ import com.alibaba.cloud.ai.graph.OverAllState; import com.alibaba.cloud.ai.graph.action.NodeAction; import com.alibaba.cloud.ai.graph.node.code.CodeExecutorNodeAction; +import com.alibaba.cloud.ai.graph.node.code.entity.CodeParam; +import com.alibaba.cloud.ai.graph.node.code.entity.CodeStyle; import com.alibaba.cloud.ai.graph.node.code.LocalCommandlineCodeExecutor; import com.alibaba.cloud.ai.graph.node.code.entity.CodeExecutionConfig; import org.junit.jupiter.api.BeforeEach; @@ -25,8 +27,8 @@ import org.junit.jupiter.api.io.TempDir; import java.nio.file.Path; -import java.util.HashMap; import java.util.LinkedHashMap; +import java.util.List; import java.util.Map; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -57,33 +59,107 @@ def main(arg1: str, arg2: str) -> dict: "result": arg1 + arg2, } """; - Map params = new HashMap<>(16); - params.put("key1", "arg1"); - params.put("key2", "arg2"); NodeAction codeNode = CodeExecutorNodeAction.builder() .codeExecutor(new LocalCommandlineCodeExecutor()) .code(code) - .codeLanguage("python") + .codeStyle(CodeStyle.EXPLICIT_PARAMETERS) + .codeLanguage("python3") .config(config) - .params(params) + .params(List.of(new CodeParam("arg1", "data1"), new CodeParam("arg2", "data2"))) + .outputKey("output") .build(); - Map initData = new HashMap<>(16); - initData.put("arg1", "1"); - initData.put("arg2", "2"); - OverAllState mockState = new OverAllState(initData); + OverAllState mockState = new OverAllState(Map.of("data1", "1", "data2", "2")); + Map stateData = codeNode.apply(mockState); + System.out.println(stateData); + assertNotNull(stateData); + assertEquals(Map.of("output", Map.of("result", "12")), stateData); + } + + @Test + void testExecutePythonGlobalDictStyleSuccessfully() throws Exception { + String code = """ + def main(): + ret = { + "output": params['arg1'] + params['arg2'] + params['arg3'] + } + return ret + """; + NodeAction codeNode = CodeExecutorNodeAction.builder() + .codeExecutor(new LocalCommandlineCodeExecutor()) + .code(code) + .codeStyle(CodeStyle.GLOBAL_DICTIONARY) + .codeLanguage("python3") + .config(config) + .params(List.of(CodeParam.withKey("arg1", "arg1"), CodeParam.withKey("arg2", "arg2"), + CodeParam.withValue("arg3", "3"))) + .outputKey("output") + .build(); + OverAllState mockState = new OverAllState(Map.of("arg1", "1", "arg2", "2")); + Map stateData = codeNode.apply(mockState); + assertNotNull(stateData); + assertEquals(Map.of("output", Map.of("output", "123")), stateData); + System.out.println(stateData); + } + + @Test + void testExecuteJavascriptSuccessfully() throws Exception { + String code = """ + function main({arg1, arg2}) { + return { + result: arg1 + arg2 + } + } + """; + NodeAction codeNode = CodeExecutorNodeAction.builder() + .codeExecutor(new LocalCommandlineCodeExecutor()) + .code(code) + .codeStyle(CodeStyle.EXPLICIT_PARAMETERS) + .codeLanguage("javascript") + .config(config) + .params(List.of(new CodeParam("arg1", "data1"), new CodeParam("arg2", "data2"))) + .outputKey("output") + .build(); + OverAllState mockState = new OverAllState(Map.of("data1", "1", "data2", "2")); Map stateData = codeNode.apply(mockState); System.out.println(stateData); + assertNotNull(stateData); + assertEquals(Map.of("output", Map.of("result", "12")), stateData); + } + + @Test + void testExecuteJavascriptGlobalDictStyleSuccessfully() throws Exception { + String code = """ + function main() { + const ret = { + "output": params.arg1 + params.arg2 + params.arg3 + }; + return ret; + } + """; + NodeAction codeNode = CodeExecutorNodeAction.builder() + .codeExecutor(new LocalCommandlineCodeExecutor()) + .code(code) + .codeStyle(CodeStyle.GLOBAL_DICTIONARY) + .codeLanguage("javascript") + .config(config) + .params(List.of(CodeParam.withKey("arg1", "arg1"), CodeParam.withKey("arg2", "arg2"), + CodeParam.withValue("arg3", "3"))) + .outputKey("output") + .build(); + OverAllState mockState = new OverAllState(Map.of("arg1", "1", "arg2", "2")); + Map stateData = codeNode.apply(mockState); + assertNotNull(stateData); + assertEquals(Map.of("output", Map.of("output", "123")), stateData); + System.out.println(stateData); } @Test void testExecuteJavaWithLocalExecutor() throws Exception { // Prepare test data String javaCode = """ - public static Object main(Object[] inputs) { - // Process input parameters - String text = (String) inputs[0]; - Integer count = (Integer) inputs[1]; - + public static Object run(String arg0, Integer arg1) { + String text = arg0; + int count = arg1; // Execute business logic StringBuilder result = new StringBuilder(); for (int i = 0; i < count; i++) { @@ -99,17 +175,14 @@ public static Object main(Object[] inputs) { } """; - // Create parameter mapping - Map params = new LinkedHashMap<>(); - params.put("text", "text"); - params.put("count", "count"); // Create code execution node action NodeAction codeNode = CodeExecutorNodeAction.builder() .codeExecutor(new LocalCommandlineCodeExecutor()) .code(javaCode) + .codeStyle(CodeStyle.EXPLICIT_PARAMETERS) .codeLanguage("java") .config(config) - .params(params) + .params(List.of(new CodeParam("arg0", "text"), new CodeParam("arg1", "count"))) .outputKey("codeNode1_output") .build(); @@ -121,13 +194,57 @@ public static Object main(Object[] inputs) { // Execute code Map result = codeNode.apply(mockState); - Map codeNode1Output = (Map) result.get("codeNode1_output"); + assertNotNull(result); + System.out.println(result); + assertEquals(Map.of("codeNode1_output", Map.of("length", 18, "count", 3, "repeated_text", "Hello Hello Hello")), + result); + } + + @Test + void testExecuteJavaGlobalDictStyleWithLocalExecutor() throws Exception { + // Prepare test data + String javaCode = """ + public static Object run() { + String text = (String) params.get("arg0"); + int count = (Integer) params.get("arg1"); + // Execute business logic + StringBuilder result = new StringBuilder(); + for (int i = 0; i < count; i++) { + result.append(text).append(" "); + } + + Map response = new HashMap<>(); + response.put("repeated_text", result.toString().trim()); + response.put("length", result.length()); + response.put("count", count); - // Verify results + return response; + } + """; + + // Create code execution node action + NodeAction codeNode = CodeExecutorNodeAction.builder() + .codeExecutor(new LocalCommandlineCodeExecutor()) + .code(javaCode) + .codeLanguage("java") + .config(config) + .codeStyle(CodeStyle.GLOBAL_DICTIONARY) + .params(List.of(new CodeParam("arg0", "text"), new CodeParam("arg1", "count"))) + .outputKey("codeNode1_output") + .build(); + + // Prepare input data + Map initData = new LinkedHashMap<>(); + initData.put("text", "Hello"); + initData.put("count", 3); + OverAllState mockState = new OverAllState(initData); + + // Execute code + Map result = codeNode.apply(mockState); assertNotNull(result); - assertEquals("Hello Hello Hello", codeNode1Output.get("repeated_text")); - assertEquals(18, codeNode1Output.get("length")); - assertEquals(3, codeNode1Output.get("count")); + System.out.println(result); + assertEquals(Map.of("codeNode1_output", Map.of("length", 18, "count", 3, "repeated_text", "Hello Hello Hello")), + result); } } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/Edge.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/Edge.java index b40bf275e6..7b1ca24a5c 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/Edge.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/Edge.java @@ -31,11 +31,6 @@ public class Edge { private Map data; - private Integer zIndex = 0; - - // Temp field - private boolean isDify = true; - public String getId() { return id; } @@ -90,22 +85,4 @@ public Edge setData(Map data) { return this; } - public Integer getzIndex() { - return zIndex; - } - - public Edge setzIndex(Integer zIndex) { - this.zIndex = zIndex; - return this; - } - - public boolean isDify() { - return isDify; - } - - public Edge setDify(boolean isDify) { - this.isDify = isDify; - return this; - } - } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/Node.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/Node.java index 62ca949a3f..c855591fc5 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/Node.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/Node.java @@ -21,6 +21,9 @@ public class Node implements RunnableModel { private String id; + // 如果在循环节点里,则有父节点ID + private String parentId; + private NodeType type; private String title; @@ -54,6 +57,15 @@ public Node setId(String id) { return this; } + public String getParentId() { + return parentId; + } + + public Node setParentId(String parentId) { + this.parentId = parentId; + return this; + } + public NodeType getType() { return type; } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/NodeType.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/NodeType.java index e637e58390..6fefead060 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/NodeType.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/NodeType.java @@ -26,11 +26,13 @@ public enum NodeType { ANSWER("answer", "answer", "UNSUPPORTED"), + MIDDLE_OUTPUT("middle-output", "UNSUPPORTED", "Output"), + AGENT("agent", "agent", "UNSUPPORTED"), LLM("llm", "llm", "LLM"), - CODE("code", "code", "UNSUPPORTED"), + CODE("code", "code", "Script"), RETRIEVER("retriever", "knowledge-retrieval", "Retrieval"), @@ -56,9 +58,13 @@ public enum NodeType { TEMPLATE_TRANSFORM("template-transform", "template-transform", "UNSUPPORTED"), - ITERATION("iteration", "iteration", "UNSUPPORTED"), + ITERATION("iteration", "iteration", "Parallel"), + + EMPTY("empty", "UNSUPPORTED", "UNSUPPORTED"), + + ITERATION_START("iteration-start", "iteration-start", "ParallelStart"), - DIFY_ITERATION_START("__empty__", "iteration-start", "UNSUPPORTED"), + ITERATION_END("iteration-end", "iteration-end", "ParallelEnd"), ASSIGNER("assigner", "assigner", "UNSUPPORTED"); diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/CodeNodeData.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/CodeNodeData.java index d235c57cd0..f805db583a 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/CodeNodeData.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/CodeNodeData.java @@ -16,30 +16,35 @@ package com.alibaba.cloud.ai.studio.admin.generator.model.workflow.nodedata; import java.util.List; +import java.util.Map; import com.alibaba.cloud.ai.studio.admin.generator.model.Variable; -import com.alibaba.cloud.ai.studio.admin.generator.model.VariableSelector; import com.alibaba.cloud.ai.studio.admin.generator.model.VariableType; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.NodeData; +import com.alibaba.cloud.ai.studio.admin.generator.utils.ObjectToCodeUtil; public class CodeNodeData extends NodeData { public static Variable getDefaultOutputSchema() { - return new Variable("output", VariableType.STRING); + return new Variable("output", VariableType.OBJECT); } private String code; private String codeLanguage; + private List inputParams; + private String outputKey; - public CodeNodeData() { - } + private int maxRetryCount = 1; - public CodeNodeData(List inputs, List outputs) { - super(inputs, outputs); - } + private int retryIntervalMs = 1000; + + // 运行失败时的默认值 + private Map defaultValue; + + private CodeStyle codeStyle = CodeStyle.EXPLICIT_PARAMETERS; public String getCode() { return code; @@ -59,6 +64,15 @@ public CodeNodeData setCodeLanguage(String codeLanguage) { return this; } + public List getInputParams() { + return inputParams; + } + + public CodeNodeData setInputParams(List inputParams) { + this.inputParams = inputParams; + return this; + } + public String getOutputKey() { return outputKey; } @@ -68,4 +82,86 @@ public CodeNodeData setOutputKey(String outputKey) { return this; } + public int getMaxRetryCount() { + return maxRetryCount; + } + + public CodeNodeData setMaxRetryCount(int maxRetryCount) { + this.maxRetryCount = maxRetryCount; + return this; + } + + public int getRetryIntervalMs() { + return retryIntervalMs; + } + + public CodeNodeData setRetryIntervalMs(int retryIntervalMs) { + this.retryIntervalMs = retryIntervalMs; + return this; + } + + public Map getDefaultValue() { + return defaultValue; + } + + public CodeNodeData setDefaultValue(Map defaultValue) { + this.defaultValue = defaultValue; + return this; + } + + public CodeStyle getCodeStyle() { + return codeStyle; + } + + public CodeNodeData setCodeStyle(CodeStyle codeStyle) { + this.codeStyle = codeStyle; + return this; + } + + public enum CodeStyle { + + /** + * Dify代码样式 + */ + EXPLICIT_PARAMETERS, + + /** + * Studio代码样式 + */ + GLOBAL_DICTIONARY + + ; + + public String toString() { + return "CodeStyle." + name(); + } + + } + + public record CodeParam(String argName, Object value, String stateKey) { + public static CodeParam withValue(String argName, Object value) { + return new CodeParam(argName, value, null); + } + + public static CodeParam withKey(String argName, String stateKey) { + return new CodeParam(argName, null, stateKey); + } + + @Override + public String toString() { + if (argName == null) { + throw new IllegalArgumentException("argName cannot be null"); + } + if (value == null && stateKey != null) { + return String.format("CodeParam.withKey(%s, %s)", ObjectToCodeUtil.toCode(argName()), + ObjectToCodeUtil.toCode(stateKey())); + } + if (value != null && stateKey == null) { + return String.format("CodeParam.withValue(%s, %s)", ObjectToCodeUtil.toCode(argName()), + ObjectToCodeUtil.toCode(value())); + } + throw new IllegalArgumentException("value and stateKey must only one."); + } + } + } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/EmptyNodeData.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/EmptyNodeData.java index 4efa671acd..6f5fb13ee5 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/EmptyNodeData.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/EmptyNodeData.java @@ -26,18 +26,4 @@ */ public class EmptyNodeData extends NodeData { - private String id; - - public String getId() { - return id; - } - - public void setId(String id) { - this.id = id; - } - - public EmptyNodeData(String id) { - this.id = id; - } - } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/IterationNodeData.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/IterationNodeData.java index 21551bf922..9c952b1aef 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/IterationNodeData.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/IterationNodeData.java @@ -16,14 +16,12 @@ package com.alibaba.cloud.ai.studio.admin.generator.model.workflow.nodedata; -import java.util.List; - import com.alibaba.cloud.ai.studio.admin.generator.model.Variable; import com.alibaba.cloud.ai.studio.admin.generator.model.VariableSelector; import com.alibaba.cloud.ai.studio.admin.generator.model.VariableType; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.NodeData; -import org.springframework.util.StringUtils; +import java.util.List; /** * @author vlsmb @@ -31,137 +29,85 @@ */ public class IterationNodeData extends NodeData { - public static Variable getDefaultOutputSchema() { - return new Variable("output", VariableType.ARRAY_STRING); + public static List getDefaultOutputSchemas() { + return List.of(new Variable("state", VariableType.ARRAY_NUMBER), // 剩余未处理的元素索引 + new Variable("index", VariableType.NUMBER), // 迭代索引 + new Variable("isFinished", VariableType.BOOLEAN) // 迭代是否结束 + ); } - private String id; - - private VariableType inputType; - - private VariableType outputType; - - private VariableSelector inputSelector; - - private VariableSelector outputSelector; - - private String startNodeId; - - private String endNodeId; - - private String outputKey; - - private String inputKey; + // NodeData的来源节点名称 + private final String sourceVarName; - private String startNodeName; - - private String endNodeName; - - // 内部临时变量名 - private String innerArrayKey; - - private String innerStartFlagKey; - - private String innerEndFlagKey; - - private String innerItemKey; - - private String innerItemResultKey; - - private String innerIndexKey; - - private Variable output; - - public IterationNodeData(String id, VariableType inputType, VariableType outputType, VariableSelector inputSelector, - VariableSelector outputSelector, String startNodeId, String endNodeId, String inputKey, String outputKey) { - this.id = id; - this.inputType = inputType; - this.outputType = outputType; - this.inputSelector = inputSelector; - this.outputSelector = outputSelector; - this.startNodeId = startNodeId; - this.endNodeId = endNodeId; - this.inputKey = inputKey; - this.outputKey = outputKey; - this.setVarName(id); - this.output = new Variable(outputKey, outputType); - this.setInputs(List.of(inputSelector)); - this.setOutputs(List.of(output)); - } + private int parallelCount = 1; - @Override - public void setVarName(String varName) { - this.varName = varName; - this.innerArrayKey = this.varName + "_array"; - this.innerStartFlagKey = this.varName + "_start_flag"; - this.innerEndFlagKey = this.varName + "_end_flag"; - this.innerItemKey = this.varName + "_item"; - this.innerItemResultKey = this.varName + "_item_result"; - this.innerIndexKey = this.varName + "_index"; - } + private int maxIterationCount = Integer.MAX_VALUE; - public String getId() { - return id; - } + // Dify的迭代索引从0开始,而Studio的从1开始,故需要设置这个值 + private int indexOffset = 0; - public void setId(String id) { - this.id = id; - } + // itemKey和outputKey的后缀在Dify中固定,但在Studio中用户可以自定义 + private String itemKey = "item"; - public VariableType getInputType() { - return inputType; - } + private String outputKey = "output"; - public void setInputType(VariableType inputType) { - this.inputType = inputType; - } + // 迭代输入的Selector + private VariableSelector inputSelector; - public VariableType getOutputType() { - return outputType; - } + // 迭代结果元素的Selector + private VariableSelector resultSelector; - public void setOutputType(VariableType outputType) { - this.outputType = outputType; + public IterationNodeData(IterationNodeData other) { + parallelCount = other.parallelCount; + maxIterationCount = other.maxIterationCount; + indexOffset = other.indexOffset; + itemKey = other.itemKey; + outputKey = other.outputKey; + inputSelector = other.inputSelector; + resultSelector = other.resultSelector; + sourceVarName = other.getVarName(); + setVarName(other.getVarName()); } - public VariableSelector getInputSelector() { - return inputSelector; + public IterationNodeData() { + super(); + sourceVarName = null; } - public void setInputSelector(VariableSelector inputSelector) { - this.inputSelector = inputSelector; + public String getSourceVarName() { + return sourceVarName; } - public VariableSelector getOutputSelector() { - return outputSelector; + public int getParallelCount() { + return parallelCount; } - public void setOutputSelector(VariableSelector outputSelector) { - this.outputSelector = outputSelector; + public void setParallelCount(int parallelCount) { + this.parallelCount = parallelCount; } - public String getStartNodeId() { - return startNodeId; + public int getMaxIterationCount() { + return maxIterationCount; } - public void setStartNodeId(String startNodeId) { - this.startNodeId = startNodeId; + public void setMaxIterationCount(int maxIterationCount) { + this.maxIterationCount = maxIterationCount; } - public String getEndNodeId() { - return endNodeId; + public int getIndexOffset() { + return indexOffset; } - public void setEndNodeId(String endNodeId) { - this.endNodeId = endNodeId; + public void setIndexOffset(int indexOffset) { + this.indexOffset = indexOffset; } - public String getInputKey() { - return inputKey; + public String getItemKey() { + return itemKey; } - public void setInputKey(String inputKey) { - this.inputKey = inputKey; + public void setItemKey(String itemKey) { + this.itemKey = itemKey; } public String getOutputKey() { @@ -172,175 +118,20 @@ public void setOutputKey(String outputKey) { this.outputKey = outputKey; } - public String getInnerArrayKey() { - return innerArrayKey; - } - - public void setInnerArrayKey(String innerArrayKey) { - this.innerArrayKey = innerArrayKey; - } - - public String getInnerStartFlagKey() { - return innerStartFlagKey; - } - - public void setInnerStartFlagKey(String innerStartFlagKey) { - this.innerStartFlagKey = innerStartFlagKey; - } - - public String getInnerEndFlagKey() { - return innerEndFlagKey; - } - - public void setInnerEndFlagKey(String innerEndFlagKey) { - this.innerEndFlagKey = innerEndFlagKey; - } - - public String getInnerItemKey() { - return innerItemKey; - } - - public void setInnerItemKey(String innerItemKey) { - this.innerItemKey = innerItemKey; - } - - public String getInnerItemResultKey() { - return innerItemResultKey; - } - - public void setInnerItemResultKey(String innerItemResultKey) { - this.innerItemResultKey = innerItemResultKey; - } - - public String getInnerIndexKey() { - return innerIndexKey; - } - - public void setInnerIndexKey(String innerIndexKey) { - this.innerIndexKey = innerIndexKey; - } - - public Variable getOutput() { - return output; - } - - public void setOutput(Variable output) { - this.output = output; - } - - public String getStartNodeName() { - return startNodeName; - } - - public void setStartNodeName(String startNodeName) { - this.startNodeName = startNodeName; - } - - public String getEndNodeName() { - return endNodeName; + public VariableSelector getInputSelector() { + return inputSelector; } - public void setEndNodeName(String endNodeName) { - this.endNodeName = endNodeName; + public void setInputSelector(VariableSelector inputSelector) { + this.inputSelector = inputSelector; } - public static class Builder { - - private String id; - - private VariableType inputType; - - private VariableType outputType; - - private VariableSelector inputSelector; - - private VariableSelector outputSelector; - - private String startNodeId; - - private String endNodeId; - - private String inputKey; - - private String outputKey; - - // 可以不设置,使用默认值"_item" - private String itemKey; - - // 可以不设置,使用默认值"_index" - private String indexKey; - - public Builder id(String id) { - this.id = id; - return this; - } - - public Builder inputType(VariableType inputType) { - this.inputType = inputType; - return this; - } - - public Builder outputType(VariableType outputType) { - this.outputType = outputType; - return this; - } - - public Builder inputSelector(VariableSelector inputSelector) { - this.inputSelector = inputSelector; - return this; - } - - public Builder outputSelector(VariableSelector outputSelector) { - this.outputSelector = outputSelector; - return this; - } - - public Builder startNodeId(String startNodeId) { - this.startNodeId = startNodeId; - return this; - } - - public Builder endNodeId(String endNodeId) { - this.endNodeId = endNodeId; - return this; - } - - public Builder inputKey(String inputKey) { - this.inputKey = inputKey; - return this; - } - - public Builder outputKey(String outputKey) { - this.outputKey = outputKey; - return this; - } - - public Builder itemKey(String itemKey) { - this.itemKey = itemKey; - return this; - } - - public Builder indexKey(String indexKey) { - this.indexKey = indexKey; - return this; - } - - public IterationNodeData build() { - IterationNodeData data = new IterationNodeData(id, inputType, outputType, inputSelector, outputSelector, - startNodeId, endNodeId, inputKey, outputKey); - if (StringUtils.hasText(this.itemKey)) { - data.setInnerItemKey(this.itemKey); - } - if (StringUtils.hasText(this.indexKey)) { - data.setInnerIndexKey(this.indexKey); - } - return data; - } - + public VariableSelector getResultSelector() { + return resultSelector; } - public static Builder builder() { - return new Builder(); + public void setResultSelector(VariableSelector resultSelector) { + this.resultSelector = resultSelector; } } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/LLMNodeData.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/LLMNodeData.java index 38c4a015be..bf74735dd4 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/LLMNodeData.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/LLMNodeData.java @@ -131,8 +131,9 @@ public record MessageTemplate(String template, List keys, MessageType ty @Override public String toString() { - return String.format("new MessageTemplate(\"%s\", %s, MessageType.%s)", this.template(), - ObjectToCodeUtil.toCode(this.keys()), this.type().getValue().toUpperCase()); + return String.format("new MessageTemplate(%s, %s, MessageType.%s)", + ObjectToCodeUtil.toCode(this.template()), ObjectToCodeUtil.toCode(this.keys()), + this.type().getValue().toUpperCase()); } } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/MiddleOutputNodeData.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/MiddleOutputNodeData.java new file mode 100644 index 0000000000..550876dec2 --- /dev/null +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/MiddleOutputNodeData.java @@ -0,0 +1,65 @@ +/* + * Copyright 2024-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.cloud.ai.studio.admin.generator.model.workflow.nodedata; + +import com.alibaba.cloud.ai.studio.admin.generator.model.Variable; +import com.alibaba.cloud.ai.studio.admin.generator.model.VariableType; +import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.NodeData; +import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.DSLDialectType; + +import java.util.List; + +public class MiddleOutputNodeData extends NodeData { + + public static List getDefaultOutputSchemas(DSLDialectType dialectType) { + return switch (dialectType) { + case STUDIO -> List.of(new Variable("output", VariableType.STRING)); + default -> List.of(); + }; + } + + private String outputTemplate; + + private List varKeys; + + private String outputKey; + + public String getOutputTemplate() { + return outputTemplate; + } + + public void setOutputTemplate(String outputTemplate) { + this.outputTemplate = outputTemplate; + } + + public List getVarKeys() { + return varKeys; + } + + public void setVarKeys(List varKeys) { + this.varKeys = varKeys; + } + + public String getOutputKey() { + return outputKey; + } + + public void setOutputKey(String outputKey) { + this.outputKey = outputKey; + } + +} diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/AbstractNodeDataConverter.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/AbstractNodeDataConverter.java index 2fd8b931ea..55f0b01c2f 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/AbstractNodeDataConverter.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/AbstractNodeDataConverter.java @@ -19,9 +19,11 @@ import java.util.Map; import java.util.function.BiConsumer; import java.util.function.BiFunction; +import java.util.regex.MatchResult; import java.util.regex.Matcher; import java.util.regex.Pattern; +import com.alibaba.cloud.ai.studio.admin.generator.model.VariableSelector; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.NodeData; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; @@ -72,6 +74,23 @@ public interface DialectConverter { Map dump(T nodeData); + /** + * 将模板字符串转换为变量选择器 + * @param dialectType dsl语言 + * @param template 模板字符串 + * @return 变量选择器 + */ + default VariableSelector varTemplateToSelector(DSLDialectType dialectType, String template) { + Pattern pattern = switch (dialectType) { + case DIFY -> DIFY_VAR_TEMPLATE_PATTERN; + case STUDIO -> STUDIO_VAR_TEMPLATE_PATTERN; + default -> throw new UnsupportedOperationException(); + }; + Matcher matcher = pattern.matcher(template); + MatchResult result = matcher.results().findFirst().orElseThrow(); + return new VariableSelector(result.group(1), result.group(2)); + } + } public static DialectConverter defaultCustomDialectConverter(Class clazz) { @@ -100,6 +119,12 @@ public Map dump(R nodeData) { protected abstract List> getDialectConverters(); + private static final Pattern DIFY_VAR_TEMPLATE_PATTERN = Pattern.compile("\\{\\{#(\\w+)\\.(\\w+)#}}"); + + private static final Pattern STUDIO_VAR_TEMPLATE_PATTERN = Pattern.compile("\\$\\{(\\w+)\\.\\[?(\\w+)]?}"); + + private static final Pattern VAR_TEMPLATE_PATTERN = Pattern.compile("\\{(\\w+)}"); + /** * 将文本中变量占位符进行转化,比如Dify DSL的"你好,{{#123.query#}}"转化为"你好,{nodeName1_query}" * @param dialectType dsl语言 @@ -116,8 +141,7 @@ protected String convertVarTemplate(DSLDialectType dialectType, String templateS return str; } StringBuilder result = new StringBuilder(); - Pattern pattern = Pattern.compile("\\{\\{#(\\w+)\\.(\\w+)#}}"); - Matcher matcher = pattern.matcher(str); + Matcher matcher = DIFY_VAR_TEMPLATE_PATTERN.matcher(str); while (matcher.find()) { String nodeId = matcher.group(1); String varName = matcher.group(2); @@ -133,8 +157,8 @@ protected String convertVarTemplate(DSLDialectType dialectType, String templateS return str; } StringBuilder result = new StringBuilder(); - Pattern pattern = Pattern.compile("\\$\\{(\\w+)\\.(\\w+)}"); - Matcher matcher = pattern.matcher(str); + + Matcher matcher = STUDIO_VAR_TEMPLATE_PATTERN.matcher(str); while (matcher.find()) { String nodeId = matcher.group(1); String varName = matcher.group(2); @@ -150,6 +174,16 @@ protected String convertVarTemplate(DSLDialectType dialectType, String templateS return func.apply(templateString, idToVarName); } + /** + * 获取模板中的变量占位符,比如"你好{var1},{var2}"返回"[var1, var2]" + * @param template 模板字符串 + * @return 变量占位符列表 + */ + protected List getVarTemplateKeys(String template) { + Matcher matcher = VAR_TEMPLATE_PATTERN.matcher(template); + return matcher.results().map(m -> m.group(1)).toList(); + } + /** * 创建一个空处理Consumer,便于使用.andThen编程 * @return BiConsumer diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/adapters/DifyDSLAdapter.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/adapters/DifyDSLAdapter.java index d6d00f2417..75690d4bcb 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/adapters/DifyDSLAdapter.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/adapters/DifyDSLAdapter.java @@ -20,6 +20,9 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; import java.util.UUID; import java.util.function.BiConsumer; import java.util.stream.Collectors; @@ -36,10 +39,12 @@ import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.NodeData; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.NodeType; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.Workflow; +import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.nodedata.IterationNodeData; import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.AbstractDSLAdapter; import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.DSLDialectType; import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.NodeDataConverter; import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.Serializer; +import com.alibaba.cloud.ai.studio.admin.generator.utils.MapReadUtil; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.DeserializationFeature; import com.fasterxml.jackson.databind.ObjectMapper; @@ -161,18 +166,87 @@ public Workflow mapToWorkflow(Map data) { private Graph constructGraph(Map data) { Graph graph = new Graph(); - List nodes = new ArrayList<>(); - List edges = new ArrayList<>(); + List nodes; + List edges; + // convert nodes if (data.containsKey("nodes")) { List> nodeMaps = (List>) data.get("nodes"); - nodes = constructNodes(nodeMaps); + nodes = new ArrayList<>(constructNodes(nodeMaps)); } + else { + nodes = new ArrayList<>(); + } + // convert edges if (data.containsKey("edges")) { List> edgeMaps = (List>) data.get("edges"); - edges = constructEdges(edgeMaps); + edges = new ArrayList<>(constructEdges(edgeMaps)); } + else { + edges = new ArrayList<>(); + } + + Map varNames = nodes.stream() + .collect(Collectors.toMap(Node::getId, n -> n.getData().getVarName())); + Map nodeIdMap = nodes.stream().collect(Collectors.toMap(Node::getId, n -> n)); + + // 根据parnetId进行分组,为了给迭代节点的起始节点传递迭代数据 + Map> groupByParentId = nodes.stream() + .filter(node -> Objects.nonNull(node.getParentId())) + .collect(Collectors.groupingBy(Node::getParentId)); + + // 统计具有出度的节点 + Set nodeIdHasOut = edges.stream().map(Edge::getSource).collect(Collectors.toSet()); + + groupByParentId.forEach((parentId, subNodes) -> { + subNodes.forEach(node -> { + if (NodeType.ITERATION_START.equals(node.getType())) { + IterationNodeData nodeData = new IterationNodeData( + (IterationNodeData) nodeIdMap.get(parentId).getData()); + nodeData.setVarName(nodeIdMap.get(parentId).getData().getVarName() + "_start"); + varNames.put(node.getId(), nodeData.getVarName()); + node.setData(nodeData); + } + else if (NodeType.ITERATION_END.equals(node.getType())) { + IterationNodeData nodeData = new IterationNodeData( + (IterationNodeData) nodeIdMap.get(parentId).getData()); + nodeData.setVarName(nodeIdMap.get(parentId).getData().getVarName() + "_end"); + varNames.put(node.getId(), nodeData.getVarName()); + node.setData(nodeData); + } + }); + + // 添加迭代节点的终止节点(Dify的DSL没有提供但为了后续正常转换,这里需要添加) + NodeData nodeData = new IterationNodeData((IterationNodeData) nodeIdMap.get(parentId).getData()); + nodeData.setVarName(nodeData.getVarName() + "_end"); + Node endNode = new Node(); + endNode.setData(nodeData).setType(NodeType.ITERATION_END).setParentId(parentId); + nodes.add(endNode); + + // 计算每个节点的出度,出度为0的点将与迭代终止节点相连接 + subNodes.stream().map(Node::getId).filter(id -> !nodeIdHasOut.contains(id)).forEach(id -> { + Edge newEdge = new Edge().setSource(id).setTarget(nodeData.getVarName()); + edges.add(newEdge); + }); + }); + + // 将Edge里的source和target都转换成varName + edges.forEach(edge -> { + edge.setSource(varNames.getOrDefault(edge.getSource(), edge.getSource())); + edge.setTarget(varNames.getOrDefault(edge.getTarget(), edge.getTarget())); + }); + + // 将Iteration节点起始改为iteration_start,并将Iteration节点结束改为iteration_end + Map nodeVarMap = nodes.stream().collect(Collectors.toMap(n -> n.getData().getVarName(), n -> n)); + edges.forEach(edge -> { + if (NodeType.ITERATION.equals(nodeVarMap.get(edge.getSource()).getType())) { + edge.setSource(edge.getSource() + "_end"); + } + if (NodeType.ITERATION.equals(nodeVarMap.get(edge.getTarget()).getType())) { + edge.setTarget(edge.getTarget() + "_start"); + } + }); graph.setNodes(nodes); graph.setEdges(edges); @@ -208,7 +282,12 @@ private List constructNodes(List> nodeMaps) { nodeMap.remove("type"); Node node = objectMapper.convertValue(nodeMap, Node.class); // set title and desc - node.setTitle((String) nodeDataMap.get("title")).setDesc((String) nodeDataMap.get("desc")); + String parentId = Optional.ofNullable(MapReadUtil.getMapDeepValue(nodeMap, String.class, "parentId")) + .or(() -> Optional.ofNullable(MapReadUtil.getMapDeepValue(nodeDataMap, String.class, "iteration_id"))) + .orElse(null); + node.setTitle((String) nodeDataMap.get("title")) + .setDesc((String) nodeDataMap.get("desc")) + .setParentId(parentId); // convert node data using specific WorkflowNodeDataConverter @SuppressWarnings("unchecked") diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/adapters/StudioDSLAdapter.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/adapters/StudioDSLAdapter.java index 9af8c883e1..4040058649 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/adapters/StudioDSLAdapter.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/adapters/StudioDSLAdapter.java @@ -26,6 +26,7 @@ import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.NodeData; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.NodeType; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.Workflow; +import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.nodedata.IterationNodeData; import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.AbstractDSLAdapter; import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.DSLDialectType; import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.NodeDataConverter; @@ -42,6 +43,8 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Objects; +import java.util.Optional; import java.util.function.BiConsumer; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -50,6 +53,7 @@ * @author vlsmb * @since 2025/8/27 */ +// TODO: 与DifyDSLAdapter合并一些重复代码 @Component public class StudioDSLAdapter extends AbstractDSLAdapter { @@ -118,13 +122,86 @@ public Workflow mapToWorkflow(Map data) { private Graph constructGraph(Map data) { Graph graph = new Graph(); - List> nodeMap = MapReadUtil - .safeCastToListWithMap(MapReadUtil.getMapDeepValue(data, List.class, "config", "nodes")); - List> edgeMap = MapReadUtil - .safeCastToListWithMap(MapReadUtil.getMapDeepValue(data, List.class, "config", "edges")); + List> nodeMap = new ArrayList<>(Optional + .ofNullable( + MapReadUtil.safeCastToListWithMap(MapReadUtil.getMapDeepValue(data, List.class, "config", "nodes"))) + .orElse(List.of())); + List> edgeMap = new ArrayList<>(Optional + .ofNullable( + MapReadUtil.safeCastToListWithMap(MapReadUtil.getMapDeepValue(data, List.class, "config", "edges"))) + .orElse(List.of())); + + List> innerNodeMaps = new ArrayList<>(); + List> innerEdgeMaps = new ArrayList<>(); + + // 展开迭代节点内部的Node和Edge + nodeMap.forEach(map -> { + NodeType type = NodeType.fromStudioValue(MapReadUtil.getMapDeepValue(map, String.class, "type")) + .orElseThrow(); + if (NodeType.ITERATION.equals(type)) { + List> innerNode = MapReadUtil.safeCastToListWithMap( + MapReadUtil.getMapDeepValue(map, List.class, "config", "node_param", "block", "nodes")); + if (innerNode != null) { + innerNodeMaps.addAll(innerNode); + } + List> innerEdge = MapReadUtil.safeCastToListWithMap( + MapReadUtil.getMapDeepValue(map, List.class, "config", "node_param", "block", "edges")); + if (innerEdge != null) { + innerEdgeMaps.addAll(innerEdge); + } + } + }); + nodeMap.addAll(innerNodeMaps); + edgeMap.addAll(innerEdgeMaps); List nodes = this.constructNodes(nodeMap); List edges = this.constructEdges(edgeMap); + + Map varNames = nodes.stream() + .collect(Collectors.toMap(Node::getId, n -> n.getData().getVarName())); + Map nodeIdMap = nodes.stream().collect(Collectors.toMap(Node::getId, n -> n)); + + // 根据parnetId进行分组,为了给迭代节点的起始节点传递迭代数据 + Map> groupByParentId = nodes.stream() + .filter(node -> Objects.nonNull(node.getParentId())) + .collect(Collectors.groupingBy(Node::getParentId)); + + groupByParentId.forEach((parentId, subNodes) -> { + subNodes.forEach(node -> { + if (NodeType.ITERATION_START.equals(node.getType())) { + IterationNodeData nodeData = new IterationNodeData( + (IterationNodeData) nodeIdMap.get(parentId).getData()); + nodeData.setVarName(nodeIdMap.get(parentId).getData().getVarName() + "_start"); + varNames.put(node.getId(), nodeData.getVarName()); + node.setData(nodeData); + } + else if (NodeType.ITERATION_END.equals(node.getType())) { + IterationNodeData nodeData = new IterationNodeData( + (IterationNodeData) nodeIdMap.get(parentId).getData()); + nodeData.setVarName(nodeIdMap.get(parentId).getData().getVarName() + "_end"); + varNames.put(node.getId(), nodeData.getVarName()); + node.setData(nodeData); + } + }); + }); + + // 将Edge里的source和target都转换成varName + edges.forEach(edge -> { + edge.setSource(varNames.getOrDefault(edge.getSource(), edge.getSource())); + edge.setTarget(varNames.getOrDefault(edge.getTarget(), edge.getTarget())); + }); + + // 将Iteration节点起始改为iteration_start,并将Iteration节点结束改为iteration_end + Map nodeVarMap = nodes.stream().collect(Collectors.toMap(n -> n.getData().getVarName(), n -> n)); + edges.forEach(edge -> { + if (NodeType.ITERATION.equals(nodeVarMap.get(edge.getSource()).getType())) { + edge.setSource(edge.getSource() + "_end"); + } + if (NodeType.ITERATION.equals(nodeVarMap.get(edge.getTarget()).getType())) { + edge.setTarget(edge.getTarget() + "_start"); + } + }); + graph.setNodes(nodes); graph.setEdges(edges); return graph; @@ -153,7 +230,10 @@ private List constructNodes(List> nodeMaps) { // 构造Node Node node = new Node(); - node.setId(nodeId).setType(nodeType).setTitle(nodeTitle); + node.setId(nodeId) + .setType(nodeType) + .setTitle(nodeTitle) + .setParentId(MapReadUtil.getMapDeepValue(nodeMap, String.class, "parent_id")); // convert node data using specific WorkflowNodeDataConverter @SuppressWarnings("unchecked") @@ -202,8 +282,7 @@ private List constructEdges(List> edgeMaps) { .setSource(source) .setTarget(target) .setSourceHandle(sourceHandle) - .setTargetHandle(targetHandle) - .setDify(false); + .setTargetHandle(targetHandle); return edge; }).toList(); } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/BranchNodeDataConverter.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/BranchNodeDataConverter.java index 7d3d3c6bc9..ad87859c12 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/BranchNodeDataConverter.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/BranchNodeDataConverter.java @@ -19,6 +19,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.function.BiConsumer; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -135,4 +136,21 @@ public Stream extractWorkflowVars(BranchNodeData data) { return Stream.empty(); } + @Override + public BiConsumer> postProcessConsumer(DSLDialectType dialectType) { + return switch (dialectType) { + case DIFY -> super.postProcessConsumer(dialectType).andThen((nodeData, idToVarName) -> { + // 处理条件里的VariableSelector + nodeData.getCases().forEach(c -> { + c.getConditions().forEach(condition -> { + VariableSelector selector = condition.getVariableSelector(); + selector.setNameInCode(idToVarName.getOrDefault(selector.getNamespace(), "unknown") + "_" + + selector.getName()); + }); + }); + }); + default -> super.postProcessConsumer(dialectType); + }; + } + } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/CodeNodeDataConverter.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/CodeNodeDataConverter.java index cd2c2c25b3..8b7bd9c167 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/CodeNodeDataConverter.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/CodeNodeDataConverter.java @@ -15,11 +15,12 @@ */ package com.alibaba.cloud.ai.studio.admin.generator.service.dsl.converter; -import java.util.ArrayList; -import java.util.HashMap; +import java.util.Collection; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.function.BiConsumer; +import java.util.stream.Collectors; import java.util.stream.Stream; import com.alibaba.cloud.ai.studio.admin.generator.model.Variable; @@ -30,6 +31,8 @@ import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.AbstractNodeDataConverter; import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.DSLDialectType; +import com.alibaba.cloud.ai.studio.admin.generator.utils.MapReadUtil; +import com.fasterxml.jackson.core.JsonProcessingException; import org.springframework.stereotype.Component; @Component @@ -55,43 +58,172 @@ public Boolean supportDialect(DSLDialectType dialectType) { @Override public CodeNodeData parse(Map data) { - List> variables = (List>) data.get("variables"); - List inputs = variables.stream().map(variable -> { - List selector = (List) variable.get("value_selector"); - return new VariableSelector(selector.get(0), selector.get(1), (String) variable.get("variable")); - }).toList(); - Map> outputsMap = (Map>) data.get("outputs"); - List outputs = outputsMap.entrySet().stream().map(entry -> { - String varName = entry.getKey(); - String difyType = (String) entry.getValue().get("type"); - VariableType varType = VariableType.fromDifyValue(difyType) - .orElseThrow(() -> new IllegalArgumentException("Unsupported dify variable type: " + difyType)); - return new Variable(varName, varType); - }).toList(); - - return new CodeNodeData(inputs, outputs).setCode((String) data.get("code")) - .setCodeLanguage((String) data.get("code_language")); + CodeNodeData nodeData = new CodeNodeData(); + + // 提取必要信息 + String code = MapReadUtil.getMapDeepValue(data, String.class, "code"); + String lang = MapReadUtil.getMapDeepValue(data, String.class, "code_language"); + Boolean isRetry = Optional + .ofNullable(MapReadUtil.getMapDeepValue(data, Boolean.class, "retry_config", "retry_enabled")) + .orElse(false); + int maxRetryCount = isRetry ? Optional + .ofNullable(MapReadUtil.getMapDeepValue(data, Integer.class, "retry_config", "max_retries")) + .orElse(1) : 1; + int retryIntervalMs = isRetry ? Optional + .ofNullable(MapReadUtil.getMapDeepValue(data, Integer.class, "retry_config", "retry_interval")) + .orElse(1000) : 1000; + + List outputParams = Optional + .ofNullable( + MapReadUtil.safeCastToListWithMap(MapReadUtil.getMapDeepValue(data, List.class, "outputs"))) + .orElse(List.of()) + .stream() + .map(Map::entrySet) + .flatMap(Collection::stream) + .map(Map.Entry::getKey) + .map(k -> new Variable(k, VariableType.OBJECT)) + .toList(); + + List inputParams = Optional + .ofNullable(MapReadUtil + .safeCastToListWithMap(MapReadUtil.getMapDeepValue(data, List.class, "variables"))) + .orElse(List.of()) + .stream() + .filter(map -> map.containsKey("value_selector") && map.containsKey("variable")) + .map(map -> { + List list = MapReadUtil.safeCastToList(map.get("value_selector"), String.class); + // 先以Value的形式存储selector,在post阶段转换为正确的stateKey + return new CodeNodeData.CodeParam(map.get("variable").toString(), list, list.get(0)); + }) + .toList(); + + // 设置必要信息 + nodeData.setCodeStyle(CodeNodeData.CodeStyle.EXPLICIT_PARAMETERS); + nodeData.setCode(code); + nodeData.setCodeLanguage(lang); + nodeData.setMaxRetryCount(maxRetryCount); + nodeData.setRetryIntervalMs(retryIntervalMs); + nodeData.setInputParams(inputParams); + nodeData.setOutputs(outputParams); + + // 错误处理策略 + String errorStrategy = MapReadUtil.getMapDeepValue(data, String.class, "error_strategy"); + + if (errorStrategy != null) { + // 暂仅支持默认值 + List> defaultValueList = MapReadUtil + .safeCastToListWithMap(MapReadUtil.getMapDeepValue(data, List.class, "default_value")); + if (defaultValueList != null) { + Map defaultValue = defaultValueList.stream() + .filter(map -> map.containsKey("key") && map.containsKey("value")) + .collect(Collectors.toUnmodifiableMap(map -> map.get("key").toString(), + map -> map.get("value"), (a, b) -> b)); + nodeData.setDefaultValue(defaultValue); + } + } + + return nodeData; + } + + @Override + public Map dump(CodeNodeData nodeData) { + throw new UnsupportedOperationException(); + } + }), + + STUDIO(new DialectConverter<>() { + @Override + public Boolean supportDialect(DSLDialectType dialectType) { + return DSLDialectType.STUDIO.equals(dialectType); + } + + @Override + public CodeNodeData parse(Map data) throws JsonProcessingException { + CodeNodeData nodeData = new CodeNodeData(); + + // 获取基本信息 + String code = MapReadUtil.getMapDeepValue(data, String.class, "config", "node_param", "script_content"); + String lang = MapReadUtil.getMapDeepValue(data, String.class, "config", "node_param", "script_type"); + Boolean isRetry = Optional + .ofNullable(MapReadUtil.getMapDeepValue(data, Boolean.class, "config", "node_param", "retry_config", + "retry_enabled")) + .orElse(false); + int maxRetryCount = isRetry ? Optional + .ofNullable(MapReadUtil.getMapDeepValue(data, Integer.class, "config", "node_param", "retry_config", + "max_retries")) + .orElse(1) : 1; + int retryIntervalMs = isRetry ? Optional + .ofNullable(MapReadUtil.getMapDeepValue(data, Integer.class, "config", "node_param", "retry_config", + "retry_interval")) + .orElse(1000) : 1000; + + List outputParams = Optional + .ofNullable(MapReadUtil.safeCastToListWithMap( + MapReadUtil.getMapDeepValue(data, List.class, "config", "output_params"))) + .orElse(List.of()) + .stream() + .filter(map -> map.containsKey("key")) + .map(map -> new Variable(map.get("key").toString(), VariableType.OBJECT)) + .toList(); + List inputParams = Optional + .ofNullable(MapReadUtil + .safeCastToListWithMap(MapReadUtil.getMapDeepValue(data, List.class, "config", "input_params"))) + .orElse(List.of()) + .stream() + .filter(map -> map.containsKey("key") && map.containsKey("value") && map.containsKey("value_from")) + .map(map -> { + String key = map.get("key").toString(); + Object value = map.get("value"); + String valueFrom = map.get("value_from").toString(); + if ("input".equalsIgnoreCase(valueFrom)) { + return CodeNodeData.CodeParam.withValue(key, value); + } + else { + // 先以Value的形式存储selector,在post阶段转换为正确的stateKey + VariableSelector selector = this.varTemplateToSelector(DSLDialectType.STUDIO, + value.toString()); + List list = List.of(selector.getNamespace(), selector.getName()); + return new CodeNodeData.CodeParam(key, list, value.toString()); + } + }) + .toList(); + + // 设置基本信息 + nodeData.setCodeStyle(CodeNodeData.CodeStyle.GLOBAL_DICTIONARY); + nodeData.setCode(code); + nodeData.setCodeLanguage(lang); + nodeData.setMaxRetryCount(maxRetryCount); + nodeData.setRetryIntervalMs(retryIntervalMs); + nodeData.setInputParams(inputParams); + nodeData.setOutputs(outputParams); + + // 设置错误策略 + String errorStrategy = MapReadUtil.getMapDeepValue(data, String.class, "config", "node_param", + "try_catch_config", "strategy"); + if (errorStrategy != null) { + // 暂仅支持默认值 + List> defaultValueList = MapReadUtil + .safeCastToListWithMap(MapReadUtil.getMapDeepValue(data, List.class, "config", "node_param", + "try_catch_config", "default_values")); + if (defaultValueList != null) { + Map defaultValue = defaultValueList.stream() + .filter(map -> map.containsKey("key") && map.containsKey("value")) + .collect(Collectors.toUnmodifiableMap(map -> map.get("key").toString(), + map -> map.get("value"), (a, b) -> b)); + nodeData.setDefaultValue(defaultValue); + } + } + + return nodeData; } @Override public Map dump(CodeNodeData nodeData) { - Map data = new HashMap<>(); - data.put("code", nodeData.getCode()); - data.put("code_language", nodeData.getCodeLanguage()); - List> inputVars = new ArrayList<>(); - nodeData.getInputs().forEach(v -> { - inputVars.add( - Map.of("variable", v.getLabel(), "value_selector", List.of(v.getNamespace(), v.getName()))); - }); - data.put("variables", inputVars); - Map outputVars = new HashMap<>(); - nodeData.getOutputs().forEach(variable -> { - outputVars.put(variable.getName(), Map.of("type", variable.getValueType().difyValue())); - }); - data.put("outputs", outputVars); - return data; + throw new UnsupportedOperationException(); } - }), CUSTOM(AbstractNodeDataConverter.defaultCustomDialectConverter(CodeNodeData.class)); + }) + + , CUSTOM(AbstractNodeDataConverter.defaultCustomDialectConverter(CodeNodeData.class)); private final DialectConverter dialectConverter; @@ -113,9 +245,19 @@ public String generateVarName(int count) { @Override public BiConsumer> postProcessConsumer(DSLDialectType dialectType) { return switch (dialectType) { - case DIFY -> this.emptyProcessConsumer().andThen((nodeData, idToVarName) -> { + case DIFY, STUDIO -> this.emptyProcessConsumer().andThen((nodeData, idToVarName) -> { // code节点将返回{"varName.output": {...}}的数据,之后拆包成若干输出数据 nodeData.setOutputKey(nodeData.getVarName() + "_" + CodeNodeData.getDefaultOutputSchema().getName()); + // 输入Param的Key都格式化为varName_key + nodeData.setInputParams(nodeData.getInputParams().stream().map(param -> { + if (param.stateKey() == null) { + return param; + } + @SuppressWarnings("unchecked") + List selector = (List) param.value(); + return CodeNodeData.CodeParam.withKey(param.argName(), + idToVarName.getOrDefault(selector.get(0), selector.get(0)) + "_" + selector.get(1)); + }).toList()); }).andThen(super.postProcessConsumer(dialectType)); default -> super.postProcessConsumer(dialectType); }; diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/EmptyNodeDataConverter.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/EmptyNodeDataConverter.java index eb35e2a241..be7103331a 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/EmptyNodeDataConverter.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/EmptyNodeDataConverter.java @@ -37,23 +37,22 @@ public class EmptyNodeDataConverter extends AbstractNodeDataConverter() { + ALL(new DialectConverter<>() { @Override public Boolean supportDialect(DSLDialectType dialectType) { - return dialectType.equals(DSLDialectType.DIFY); + return true; } @Override public EmptyNodeData parse(Map data) throws JsonProcessingException { - String id = (String) data.get("id"); - return new EmptyNodeData(id); + return new EmptyNodeData(); } @Override public Map dump(EmptyNodeData nodeData) { - return Map.of(); + throw new UnsupportedOperationException(); } - }), CUSTOM(AbstractNodeDataConverter.defaultCustomDialectConverter(EmptyNodeData.class)); + }); private final DialectConverter dialectConverter; @@ -76,12 +75,14 @@ protected List> getDialectConverters() { @Override public Boolean supportNodeType(NodeType nodeType) { - return nodeType.equals(NodeType.DIFY_ITERATION_START); + // 迭代节点的起始节点与迭代节点共享一个data,故转换时不需要提取数据 + return NodeType.EMPTY.equals(nodeType) || NodeType.ITERATION_START.equals(nodeType) + || NodeType.ITERATION_END.equals(nodeType); } @Override public String generateVarName(int count) { - return "__empty__node_" + count; + return "emptyNode" + count; } } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/IterationNodeDataConverter.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/IterationNodeDataConverter.java index 95786ac661..8e8fd22c09 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/IterationNodeDataConverter.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/IterationNodeDataConverter.java @@ -20,8 +20,6 @@ import java.util.Map; import java.util.Optional; import java.util.function.BiConsumer; -import java.util.regex.Matcher; -import java.util.regex.Pattern; import java.util.stream.Stream; import com.alibaba.cloud.ai.studio.admin.generator.model.Variable; @@ -31,6 +29,7 @@ import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.nodedata.IterationNodeData; import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.AbstractNodeDataConverter; import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.DSLDialectType; +import com.alibaba.cloud.ai.studio.admin.generator.utils.MapReadUtil; import com.fasterxml.jackson.core.JsonProcessingException; import org.springframework.stereotype.Component; @@ -44,7 +43,7 @@ public class IterationNodeDataConverter extends AbstractNodeDataConverter() { + DIFY(new DialectConverter<>() { @Override public Boolean supportDialect(DSLDialectType dialectType) { return DSLDialectType.DIFY.equals(dialectType); @@ -52,44 +51,86 @@ public Boolean supportDialect(DSLDialectType dialectType) { @Override public IterationNodeData parse(Map data) throws JsonProcessingException { - // 获取输入输出的类型,从 array[xxx] 中提取xxx - Pattern typePattern = Pattern.compile("array\\[(.*?)]"); - VariableType inputType = VariableType.OBJECT; - VariableType outputType = VariableType.OBJECT; - Matcher inputTypeMatcher = typePattern - .matcher((String) data.getOrDefault("iterator_input_type", "object")); - Matcher outputTypeMatcher = typePattern.matcher((String) data.getOrDefault("output_type", "object")); - if (inputTypeMatcher.find()) { - inputType = VariableType.fromDifyValue(inputTypeMatcher.group(1)).orElse(VariableType.OBJECT); - } - if (outputTypeMatcher.find()) { - outputType = VariableType.fromDifyValue(outputTypeMatcher.group(1)).orElse(VariableType.OBJECT); - } - List inputSelector = (List) data.get("iterator_selector"); - List outputSelector = (List) data.get("output_selector"); - String startNodeId = (String) data.get("start_node_id"); - String id = (String) data.get("id"); - // 规定输出结果的节点为最后一个节点 - String endNodeId = outputSelector.get(0); - // 返回 - return IterationNodeData.builder() - .id(id) - .inputType(inputType) - .outputType(outputType) - .inputSelector(new VariableSelector(inputSelector.get(0), inputSelector.get(1), "")) - .outputSelector(new VariableSelector(outputSelector.get(0), outputSelector.get(1), "")) - .startNodeId(startNodeId) - .endNodeId(endNodeId) - .inputKey(id + "_input") - .outputKey(id + "_output") - .build(); + IterationNodeData nodeData = new IterationNodeData(); + int parallelCount = Optional + .ofNullable(MapReadUtil.getMapDeepValue(data, Integer.class, "parallel_nums")) + .orElse(1); + nodeData.setParallelCount(parallelCount); + + List inputSelectorList = Optional + .ofNullable(MapReadUtil.safeCastToList( + MapReadUtil.getMapDeepValue(data, List.class, "iterator_selector"), String.class)) + .orElse(List.of("unknown", "unknown")); + nodeData.setInputSelector(new VariableSelector(inputSelectorList.get(0), inputSelectorList.get(1))); + + List outputSelectorList = Optional + .ofNullable(MapReadUtil + .safeCastToList(MapReadUtil.getMapDeepValue(data, List.class, "output_selector"), String.class)) + .orElse(List.of("unknown", "unknown")); + nodeData.setResultSelector(new VariableSelector(outputSelectorList.get(0), outputSelectorList.get(1))); + + return nodeData; } @Override public Map dump(IterationNodeData nodeData) { - return Map.of(); + throw new UnsupportedOperationException(); } - }), CUSTOM(AbstractNodeDataConverter.defaultCustomDialectConverter(IterationNodeData.class)); + }) + + , STUDIO(new DialectConverter<>() { + @Override + public Boolean supportDialect(DSLDialectType dialectType) { + return DSLDialectType.STUDIO.equals(dialectType); + } + + @Override + public IterationNodeData parse(Map data) throws JsonProcessingException { + IterationNodeData nodeData = new IterationNodeData(); + + // 获取必要信息 + int parallelCount = Optional + .ofNullable( + MapReadUtil.getMapDeepValue(data, Integer.class, "config", "node_param", "concurrent_size")) + .orElse(1); + int maxIterationCount = Optional + .ofNullable(MapReadUtil.getMapDeepValue(data, Integer.class, "config", "node_param", "batch_size")) + .orElse(1); + int indexOffset = 1; + + List> inputParamsList = Optional + .ofNullable(MapReadUtil + .safeCastToListWithMap(MapReadUtil.getMapDeepValue(data, List.class, "config", "input_params"))) + .orElse(List.of()); + List> outputParamsList = Optional + .ofNullable(MapReadUtil.safeCastToListWithMap( + MapReadUtil.getMapDeepValue(data, List.class, "config", "output_params"))) + .orElse(List.of()); + String itemKey = Optional.ofNullable(inputParamsList.get(0).get("key").toString()).orElse("item"); + String outputKey = Optional.ofNullable(outputParamsList.get(0).get("key").toString()).orElse("output"); + VariableSelector inputSelector = this.varTemplateToSelector(DSLDialectType.STUDIO, + inputParamsList.get(0).get("value").toString()); + VariableSelector resultSelector = this.varTemplateToSelector(DSLDialectType.STUDIO, + outputParamsList.get(0).get("value").toString()); + + // 设置必要信息 + nodeData.setParallelCount(parallelCount); + nodeData.setMaxIterationCount(maxIterationCount); + nodeData.setIndexOffset(indexOffset); + nodeData.setItemKey(itemKey); + nodeData.setOutputKey(outputKey); + nodeData.setInputSelector(inputSelector); + nodeData.setResultSelector(resultSelector); + return nodeData; + } + + @Override + public Map dump(IterationNodeData nodeData) { + throw new UnsupportedOperationException(); + } + }) + + , CUSTOM(AbstractNodeDataConverter.defaultCustomDialectConverter(IterationNodeData.class)); private final DialectConverter dialectConverter; @@ -123,40 +164,24 @@ public String generateVarName(int count) { @Override public BiConsumer> postProcessConsumer(DSLDialectType dialectType) { return switch (dialectType) { - case DIFY -> emptyProcessConsumer().andThen((nodeData, idToVarName) -> { - nodeData - .setOutputKey(nodeData.getVarName() + "_" + IterationNodeData.getDefaultOutputSchema().getName()); - nodeData.setOutputs(List.of(IterationNodeData.getDefaultOutputSchema())); - nodeData.setOutput(IterationNodeData.getDefaultOutputSchema()); - }).andThen(super.postProcessConsumer(dialectType)).andThen((iterationNodeData, varNames) -> { - // 等待所有的节点都生成了变量名后,补充迭代节点的起始名称 - iterationNodeData - .setStartNodeName(varNames.getOrDefault(iterationNodeData.getStartNodeId(), "unknown")); - iterationNodeData.setEndNodeName(varNames.getOrDefault(iterationNodeData.getEndNodeId(), "unknown")); - - // 更新迭代节点的输入Key - VariableSelector inputSelector = iterationNodeData.getInputs().get(0); - iterationNodeData.setInputKey(inputSelector.getNameInCode()); - - // 更新迭代节点的ResultKey - VariableSelector outputSelector = iterationNodeData.getOutputSelector(); - iterationNodeData.setInnerItemResultKey( - Optional.ofNullable(varNames.get(outputSelector.getNamespace())).orElse("unknown") + "_" - + outputSelector.getName()); + case DIFY, STUDIO -> emptyProcessConsumer().andThen((nodeData, idToVarName) -> { + nodeData.setInputs(List.of(nodeData.getInputSelector(), nodeData.getResultSelector())); + + nodeData.setOutputs(Stream + .of(IterationNodeData.getDefaultOutputSchemas(), + List.of(new Variable(nodeData.getItemKey(), VariableType.OBJECT), + new Variable(nodeData.getOutputKey(), VariableType.ARRAY_OBJECT))) + .flatMap(List::stream) + .toList()); + }).andThen(super.postProcessConsumer(dialectType)).andThen((nodeData, idToVarName) -> { + nodeData.setInputSelector(nodeData.getInputs().get(0)); + nodeData.setResultSelector(nodeData.getInputs().get(1)); + nodeData.setInputs(null); + nodeData.setItemKey(nodeData.getVarName() + "_" + nodeData.getItemKey()); + nodeData.setOutputKey(nodeData.getVarName() + "_" + nodeData.getOutputKey()); }); default -> super.postProcessConsumer(dialectType); }; } - @Override - public Stream extractWorkflowVars(IterationNodeData nodeData) { - return Stream.concat(nodeData.getOutputs().stream(), - Stream.of(new Variable(nodeData.getInnerArrayKey(), VariableType.STRING), - new Variable(nodeData.getInnerStartFlagKey(), VariableType.STRING), - new Variable(nodeData.getInnerEndFlagKey(), VariableType.STRING), - new Variable(nodeData.getInnerItemKey(), nodeData.getInputType()), - new Variable(nodeData.getInnerIndexKey(), VariableType.NUMBER), - new Variable(nodeData.getInnerItemResultKey(), nodeData.getOutputType()))); - } - } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/LLMNodeDataConverter.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/LLMNodeDataConverter.java index fb8123523f..643f359220 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/LLMNodeDataConverter.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/LLMNodeDataConverter.java @@ -20,8 +20,6 @@ import java.util.Map; import java.util.Optional; import java.util.function.BiConsumer; -import java.util.regex.Matcher; -import java.util.regex.Pattern; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -206,8 +204,6 @@ public String generateVarName(int count) { return "LLMNode" + count; } - private static final Pattern VAR_TEMPLATE_PATTERN = Pattern.compile("\\{(\\w+)}"); - @Override public BiConsumer> postProcessConsumer(DSLDialectType dialectType) { return switch (dialectType) { @@ -222,8 +218,7 @@ public BiConsumer> postProcessConsumer(DSLDiale .stream() .map(template -> { String newText = this.convertVarTemplate(dialectType, template.template(), idToVarName); - Matcher matcher = VAR_TEMPLATE_PATTERN.matcher(newText); - List keys = matcher.results().map(m -> m.group(1)).toList(); + List keys = this.getVarTemplateKeys(newText); return new LLMNodeData.MessageTemplate(newText, keys, template.type()); }) .toList(); diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/MiddleOutputNodeDataConverter.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/MiddleOutputNodeDataConverter.java new file mode 100644 index 0000000000..677bf75ae5 --- /dev/null +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/MiddleOutputNodeDataConverter.java @@ -0,0 +1,102 @@ +/* + * Copyright 2024-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.cloud.ai.studio.admin.generator.service.dsl.converter; + +import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.NodeType; +import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.nodedata.MiddleOutputNodeData; +import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.AbstractNodeDataConverter; +import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.DSLDialectType; +import com.alibaba.cloud.ai.studio.admin.generator.utils.MapReadUtil; +import com.fasterxml.jackson.core.JsonProcessingException; +import org.springframework.stereotype.Component; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.BiConsumer; +import java.util.stream.Stream; + +@Component +public class MiddleOutputNodeDataConverter extends AbstractNodeDataConverter { + + @Override + public Boolean supportNodeType(NodeType nodeType) { + return NodeType.MIDDLE_OUTPUT.equals(nodeType); + } + + @Override + public String generateVarName(int count) { + return "middleOutput" + count; + } + + @Override + public BiConsumer> postProcessConsumer(DSLDialectType dialectType) { + return switch (dialectType) { + case STUDIO -> emptyProcessConsumer().andThen((nodeData, idToVarName) -> { + // 设置输出键 + nodeData.setOutputs(MiddleOutputNodeData.getDefaultOutputSchemas(dialectType)); + nodeData.setOutputKey(nodeData.getVarName() + "_" + nodeData.getOutputs().get(0).getName()); + // 将输出模板进行处理 + nodeData + .setOutputTemplate(this.convertVarTemplate(dialectType, nodeData.getOutputTemplate(), idToVarName)); + nodeData.setVarKeys(this.getVarTemplateKeys(nodeData.getOutputTemplate())); + }).andThen(super.postProcessConsumer(dialectType)); + default -> super.postProcessConsumer(dialectType); + }; + } + + private enum MiddleOutputNodeConverter { + + STUDIO(new DialectConverter<>() { + @Override + public Boolean supportDialect(DSLDialectType dialectType) { + return DSLDialectType.STUDIO.equals(dialectType); + } + + @Override + public MiddleOutputNodeData parse(Map data) throws JsonProcessingException { + MiddleOutputNodeData nodeData = new MiddleOutputNodeData(); + String outputTemplate = MapReadUtil.getMapDeepValue(data, String.class, "config", "node_param", + "output"); + nodeData.setOutputTemplate(Optional.ofNullable(outputTemplate).orElse("")); + return nodeData; + } + + @Override + public Map dump(MiddleOutputNodeData nodeData) { + throw new UnsupportedOperationException(); + } + }), CUSTOM(defaultCustomDialectConverter(MiddleOutputNodeData.class)); + + private final DialectConverter converter; + + MiddleOutputNodeConverter(DialectConverter converter) { + this.converter = converter; + } + + public DialectConverter dialectConverter() { + return converter; + } + + } + + @Override + protected List> getDialectConverters() { + return Stream.of(MiddleOutputNodeConverter.values()).map(MiddleOutputNodeConverter::dialectConverter).toList(); + } + +} diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/NodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/NodeSection.java index b8a5a2dba9..c847ed4485 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/NodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/NodeSection.java @@ -17,7 +17,6 @@ import java.io.InputStream; import java.util.List; -import java.util.Map; import java.util.function.Supplier; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.Edge; @@ -39,6 +38,12 @@ public interface NodeSection { String render(Node node, String varName); + /** + * 返回当前节点需要导入的类列表 + * @return 类列表 + */ + List getImports(); + default String escape(String input) { if (input == null) { return ""; @@ -50,19 +55,23 @@ default String escape(String input) { .replace("\t", "\\t"); } - // todo: 完善条件边的EdgeAction /** - * 生成条件边的SAA代码 - * @param nodeData 节点数据 - * @param nodeMap nodeId与node的映射 - * @param entry 包含当前节点ID与当前节点出发的条件边List - * @param varNames nodeId与nodeVarName的映射 - * @return 条件边代码 + * 生成stateGraph边的代码。edge列表为从当前节点出发的边。 如果当前节点有条件边,则应重写本方法。本方法默认为无条件的边。 + * @param nodeData 当前节点(边起始节点)的数据 + * @param edges 边列表,且边的source和handle应格式化为varName + * @return 生成的代码 */ - default String renderConditionalEdges(T nodeData, Map nodeMap, Map.Entry> entry, - Map varNames) { - System.err.println("Unsupported Conditional Edges!"); - return ""; + default String renderEdges(T nodeData, List edges) { + StringBuilder sb = new StringBuilder(); + sb.append(String.format("// Edges For [%s]%n", nodeData.getVarName())); + if (edges.isEmpty()) { + return ""; + } + sb.append(String.format("stateGraph%n")); + edges.forEach( + edge -> sb.append(String.format(".addEdge(\"%s\", \"%s\")%n", edge.getSource(), edge.getTarget()))); + sb.append(String.format(";%n%n")); + return sb.toString(); } /** diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/WorkflowProjectGenerator.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/WorkflowProjectGenerator.java index 4d54f8de09..59be8b4a96 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/WorkflowProjectGenerator.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/WorkflowProjectGenerator.java @@ -21,12 +21,12 @@ import java.nio.file.Files; import java.nio.file.Path; import java.util.Arrays; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.Set; +import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -38,7 +38,6 @@ import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.NodeData; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.NodeType; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.Workflow; -import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.nodedata.CodeNodeData; import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.DSLAdapter; import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.DSLDialectType; import com.alibaba.cloud.ai.studio.admin.generator.service.generator.GraphProjectDescription; @@ -74,14 +73,10 @@ public class WorkflowProjectGenerator implements ProjectGenerator { private final String PACKAGE_NAME = "packageName"; - private final String HAS_CODE = "hasCode"; - private final List dslAdapters; private final TemplateRenderer templateRenderer; - private final List> nodeNodeSections; - private final Map> nodeSectionMap; public WorkflowProjectGenerator(List dslAdapters, @@ -90,7 +85,6 @@ public WorkflowProjectGenerator(List dslAdapters, this.dslAdapters = dslAdapters; this.templateRenderer = templateRenderer .getIfAvailable(() -> new MustacheTemplateRenderer("classpath:/templates")); - this.nodeNodeSections = nodeNodeSections; this.nodeSectionMap = nodeNodeSections.stream().map(nodeSection -> { List nodeTypeList = Arrays.stream(NodeType.values()).filter(nodeSection::support).toList(); if (nodeTypeList.isEmpty()) { @@ -121,8 +115,6 @@ public void generate(GraphProjectDescription projectDescription, Path projectRoo Map varNames = nodes.stream() .collect(Collectors.toMap(Node::getId, n -> n.getData().getVarName())); - boolean hasCode = nodes.stream().map(Node::getData).anyMatch(nd -> nd instanceof CodeNodeData); - String assistMethodCode = renderAssistMethodCode(nodes, projectDescription.getDslDialectType()); String stateSectionStr = renderStateSections( Stream.of(workflow.getWorkflowVars(), workflow.getEnvVars()).flatMap(List::stream).toList()); @@ -132,7 +124,7 @@ public void generate(GraphProjectDescription projectDescription, Path projectRoo Map graphBuilderModel = Map.of(PACKAGE_NAME, projectDescription.getPackageName(), GRAPH_BUILDER_STATE_SECTION, stateSectionStr, GRAPH_BUILDER_NODE_SECTION, nodeSectionStr, GRAPH_BUILDER_EDGE_SECTION, edgeSectionStr, GRAPH_BUILDER_IMPORT_SECTION, renderImportSection(workflow), - HAS_CODE, hasCode, GRAPH_BUILDER_ASSIST_METHOD_CODE, assistMethodCode); + GRAPH_BUILDER_ASSIST_METHOD_CODE, assistMethodCode); Map graphRunControllerModel = Map.of(PACKAGE_NAME, projectDescription.getPackageName()); renderAndWriteTemplates(List.of(GRAPH_BUILDER_TEMPLATE_NAME, GRAPH_RUN_TEMPLATE_NAME), List.of(graphBuilderModel, graphRunControllerModel), projectRoot, projectDescription); @@ -198,67 +190,23 @@ private String renderNodeSections(List nodes, Map varNames return sb.toString(); } - // TODO: 目前这里渲染edge的逻辑与Dify转换高度耦合,需要优化 private String renderEdgeSections(List edges, List nodes, Map varNames) { - StringBuilder sb = new StringBuilder(); - Map nodeMap = nodes.stream().collect(Collectors.toMap(Node::getId, n -> n)); - - // conditional edge set: sourceId -> List - Map> conditionalEdgesMap = edges.stream() - .filter(e -> e.getSourceHandle() != null && !"source".equals(e.getSourceHandle())) - .collect(Collectors.groupingBy(Edge::getSource)); - - // Set to track rendered edges to avoid duplicates - Set renderedEdges = new HashSet<>(); + // nodeVarName -> node的映射 + Map nodeMap = nodes.stream() + .collect(Collectors.toMap(node -> node.getData().getVarName(), Function.identity())); - // common edge - for (Edge edge : edges) { - String sourceId = edge.getSource(); - String targetId = edge.getTarget(); - String srcVar = varNames.get(sourceId); - String tgtVar = varNames.get(targetId); + // 根据source进行分组 + Map> edgeGroup = edges.stream().collect(Collectors.groupingBy(Edge::getSource)); - Node sourceNode = nodeMap.get(sourceId); - NodeType sourceType = sourceNode != null ? sourceNode.getType() : null; - - // Skip if already rendered as conditional - if (edge.getSourceHandle() != null && !"source".equals(edge.getSourceHandle()) && edge.isDify()) { - continue; - } - - // 迭代节点作为边的终止点时直接使用节点ID,作为边的起始点时使用ID_out - // todo: 修改迭代节点终止ID,防止与变量冲突(Dify不冲突) - if (sourceType != null && NodeType.ITERATION.equals(sourceType) && edge.isDify()) { - srcVar += "_out"; - } - - String key = srcVar + "->" + tgtVar; - if (renderedEdges.contains(key)) { - continue; - } - renderedEdges.add(key); - - // START and END special handling - if (NodeType.START.equals(sourceType)) { - sb.append(String.format("stateGraph.addEdge(START, \"%s\");%n", tgtVar)); - } - else { - sb.append(String.format("stateGraph.addEdge(\"%s\", \"%s\");%n", srcVar, tgtVar)); - } - } + StringBuilder sb = new StringBuilder(); - // conditional edge(aggregate by sourceId) - for (Map.Entry> entry : conditionalEdgesMap.entrySet()) { - String nodeId = entry.getKey(); - Node node = nodeMap.get(nodeId); - NodeType nodeType = node.getType(); - for (NodeSection section : nodeNodeSections) { - if (section.support(nodeType)) { - String edgeCode = section.renderConditionalEdges(node.getData(), nodeMap, entry, varNames); - sb.append(edgeCode); - } - } - } + // 调用每一个source节点的renderEdges方法 + edgeGroup.forEach((varName, edgeList) -> { + NodeType nodeType = nodeMap.get(varName).getType(); + @SuppressWarnings("unchecked") + NodeSection section = (NodeSection) nodeSectionMap.get(nodeType); + sb.append(section.renderEdges(nodeMap.get(varName).getData(), edgeList)); + }); // 统一生成end节点到StateGraph.END的边(避免边重复) List endNodeList = nodes.stream() @@ -268,6 +216,7 @@ private String renderEdgeSections(List edges, List nodes, Map sb.append(String.format("%n.addEdge(\"%s\", END)", endName))); sb.append(String.format(";%n")); @@ -277,65 +226,11 @@ private String renderEdgeSections(List edges, List nodes, Map> nodeTypeToClass = Map.ofEntries( - Map.entry(NodeType.ANSWER, List.of("com.alibaba.cloud.ai.graph.node.AnswerNode")), - Map.entry(NodeType.CODE, List.of("com.alibaba.cloud.ai.graph.node.code.CodeExecutorNodeAction", - "com.alibaba.cloud.ai.graph.node.code.entity.CodeExecutionConfig", - "com.alibaba.cloud.ai.graph.node.code.CodeExecutor", - "com.alibaba.cloud.ai.graph.node.code.LocalCommandlineCodeExecutor", "java.io.IOException", - "java.nio.file.Files", "java.nio.file.Path", "java.util.stream.Collectors")), - Map.entry(NodeType.AGENT, - List.of("com.alibaba.cloud.ai.graph.node.AgentNode", - "org.springframework.ai.tool.ToolCallback")), - Map.entry(NodeType.LLM, - List.of("org.springframework.ai.chat.messages.Message", - "org.springframework.ai.chat.messages.AssistantMessage", - "org.springframework.ai.chat.messages.MessageType", - "org.springframework.ai.chat.messages.SystemMessage", - "org.springframework.ai.chat.messages.UserMessage", - "com.alibaba.cloud.ai.dashscope.chat.DashScopeChatOptions", - "org.springframework.beans.factory.annotation.Autowired", "java.util.Optional")), - Map.entry(NodeType.BRANCH, - List.of("com.alibaba.cloud.ai.graph.node.BranchNode", - "static com.alibaba.cloud.ai.graph.action.AsyncEdgeAction.edge_async")), - Map.entry(NodeType.DOC_EXTRACTOR, List.of("com.alibaba.cloud.ai.graph.node.DocumentExtractorNode")), - Map.entry(NodeType.HTTP, - List.of("com.alibaba.cloud.ai.graph.node.HttpNode", "org.springframework.http.HttpMethod")), - Map.entry(NodeType.LIST_OPERATOR, - List.of("com.alibaba.cloud.ai.graph.node.ListOperatorNode", "java.util.Comparator")), - Map.entry(NodeType.QUESTION_CLASSIFIER, - List.of("com.alibaba.cloud.ai.graph.node.QuestionClassifierNode", - "static com.alibaba.cloud.ai.graph.action.AsyncEdgeAction.edge_async")), - Map.entry(NodeType.PARAMETER_PARSING, - List.of("com.alibaba.cloud.ai.graph.node.ParameterParsingNode", "java.util.stream.Collectors")), - Map.entry(NodeType.TEMPLATE_TRANSFORM, - List.of("com.alibaba.cloud.ai.graph.node.TemplateTransformNode")), - Map.entry(NodeType.TOOL, - List.of("com.alibaba.cloud.ai.graph.node.ToolNode", "java.util.function.Function", - "org.springframework.ai.tool.function.FunctionToolCallback")), - Map.entry(NodeType.RETRIEVER, List.of("com.alibaba.cloud.ai.graph.node.KnowledgeRetrievalNode", - "org.springframework.ai.embedding.EmbeddingModel", "org.springframework.ai.reader.TextReader", - "org.springframework.ai.transformer.splitter.TokenTextSplitter", - "org.springframework.ai.vectorstore.SimpleVectorStore", - "org.springframework.ai.vectorstore.VectorStore", - "org.springframework.beans.factory.annotation.Value", "org.springframework.core.io.Resource", - "org.springframework.ai.document.Document", - "org.springframework.beans.factory.annotation.Autowired", - "org.springframework.core.io.ResourceLoader", "java.util.Optional")), - Map.entry(NodeType.AGGREGATOR, - List.of("com.alibaba.cloud.ai.graph.node.VariableAggregatorNode", - "java.util.stream.Collectors")), - Map.entry(NodeType.ASSIGNER, List.of("com.alibaba.cloud.ai.graph.node.AssignerNode")), - Map.entry(NodeType.ITERATION, List.of("com.alibaba.cloud.ai.graph.node.IterationNode")), - Map.entry(NodeType.END, List.of("java.util.stream.Stream", "java.util.stream.Collectors", - "org.springframework.ai.chat.prompt.PromptTemplate"))); - + // construct a set of node types Set uniqueTypes = workflow.getGraph() .getNodes() .stream() .map(Node::getType) - .filter(nodeTypeToClass::containsKey) .collect(Collectors.toSet()); if (uniqueTypes.isEmpty()) { @@ -344,7 +239,8 @@ private String renderImportSection(Workflow workflow) { StringBuilder sb = new StringBuilder(); uniqueTypes.stream() - .map(nodeTypeToClass::get) + .map(nodeSectionMap::get) + .map(NodeSection::getImports) .flatMap(List::stream) .distinct() .forEach(className -> sb.append("import ").append(className).append(";\n")); diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/AgentNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/AgentNodeSection.java index 82fda33b1f..81780c31c8 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/AgentNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/AgentNodeSection.java @@ -22,6 +22,7 @@ import com.alibaba.cloud.ai.studio.admin.generator.service.generator.workflow.NodeSection; import org.springframework.stereotype.Component; +import java.util.List; import java.util.stream.Collectors; @Component @@ -67,4 +68,9 @@ public String render(Node node, String varName) { return sb.toString(); } + @Override + public List getImports() { + return List.of("com.alibaba.cloud.ai.graph.node.AgentNode", "org.springframework.ai.tool.ToolCallback"); + } + } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/AnswerNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/AnswerNodeSection.java index 0bf1e80f52..e1ffb99b29 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/AnswerNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/AnswerNodeSection.java @@ -23,6 +23,8 @@ import org.springframework.stereotype.Component; +import java.util.List; + @Component public class AnswerNodeSection implements NodeSection { @@ -59,4 +61,9 @@ public String render(Node node, String varName) { return sb.toString(); } + @Override + public List getImports() { + return List.of("com.alibaba.cloud.ai.graph.node.AnswerNode"); + } + } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/AssignerNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/AssignerNodeSection.java index 9df2d0f59d..1bfdb1e242 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/AssignerNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/AssignerNodeSection.java @@ -22,6 +22,8 @@ import org.springframework.stereotype.Component; +import java.util.List; + @Component public class AssignerNodeSection implements NodeSection { @@ -52,4 +54,9 @@ public String render(Node node, String varName) { return sb.toString(); } + @Override + public List getImports() { + return List.of("com.alibaba.cloud.ai.graph.node.AssignerNode"); + } + } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/BranchNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/BranchNodeSection.java index 61c5ee6920..384fe9d2c9 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/BranchNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/BranchNodeSection.java @@ -17,6 +17,7 @@ import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -30,6 +31,7 @@ import org.springframework.stereotype.Component; +// TODO: 对于Dify的条件渲染,将CaseID格式化为比较易懂的格式 @Component public class BranchNodeSection implements NodeSection { @@ -52,9 +54,8 @@ public String render(Node node, String varName) { } @Override - public String renderConditionalEdges(BranchNodeData branchNodeData, Map nodeMap, - Map.Entry> entry, Map varNames) { - String srcVar = varNames.get(entry.getKey()); + public String renderEdges(BranchNodeData branchNodeData, List edges) { + String srcVar = branchNodeData.getVarName(); StringBuilder sb = new StringBuilder(); List cases = branchNodeData.getCases(); @@ -70,7 +71,7 @@ public String renderConditionalEdges(BranchNodeData branchNodeData, Map edgeCaseMap = entry.getValue() - .stream() + Map edgeCaseMap = edges.stream() .collect(Collectors.toMap(Edge::getSourceHandle, Edge::getTarget)); String edgeCaseMapStr = "Map.of(" + edgeCaseMap.entrySet() .stream() - .flatMap(e -> Stream.of(e.getKey(), varNames.getOrDefault(e.getValue(), "unknown"))) + .flatMap(e -> Stream.of(e.getKey(), e.getValue())) .map(v -> String.format("\"%s\"", v)) .collect(Collectors.joining(", ")) + ")"; @@ -105,9 +105,9 @@ public String renderConditionalEdges(BranchNodeData branchNodeData, Map nodeMap) { + private String generateSafeVariableAccess(Case.Condition condition) { String varType = condition.getVarType(); - String variablePath = buildVariablePath(condition, nodeMap); + String variablePath = buildVariablePath(condition); switch (varType.toLowerCase()) { case "file": @@ -148,29 +148,17 @@ private String generateSafeVariableAccess(Case.Condition condition, Map nodeMap) { + private String buildVariablePath(Case.Condition condition) { VariableSelector variableSelector = condition.getVariableSelector(); if (variableSelector == null) { return "unknown"; } + return Optional.ofNullable(variableSelector.getNameInCode()).orElse("unknown"); + } - // 其中第一个是节点ID,第二个是变量名,第三个是属性 - - String nodeId = variableSelector.getNamespace(); - String variableName = variableSelector.getName(); - - // 如果有节点映射,尝试获取正确的变量名 - if (nodeMap.containsKey(nodeId)) { - Node inputNode = nodeMap.get(nodeId); - if (inputNode.getData().getOutputs() != null && !inputNode.getData().getOutputs().isEmpty()) { - // 使用输出定义中的变量名 - String outputName = inputNode.getData().getOutputs().get(0).getName(); - return outputName != null ? outputName : variableName; - } - } - - // 如果无法从节点映射获取,直接使用变量名 - return variableName != null ? variableName : "unknown"; + @Override + public List getImports() { + return List.of("static com.alibaba.cloud.ai.graph.action.AsyncEdgeAction.edge_async"); } } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/CodeNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/CodeNodeSection.java index fe6aa0c34d..44d3e350f2 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/CodeNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/CodeNodeSection.java @@ -15,16 +15,18 @@ */ package com.alibaba.cloud.ai.studio.admin.generator.service.generator.workflow.sections; -import java.util.stream.Collectors; - import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.Node; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.NodeType; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.nodedata.CodeNodeData; import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.DSLDialectType; import com.alibaba.cloud.ai.studio.admin.generator.service.generator.workflow.NodeSection; +import com.alibaba.cloud.ai.studio.admin.generator.utils.ObjectToCodeUtil; import org.springframework.stereotype.Component; +import java.util.List; + +// TODO: 支持异常分支 @Component public class CodeNodeSection implements NodeSection { @@ -35,68 +37,91 @@ public boolean support(NodeType nodeType) { @Override public String render(Node node, String varName) { - CodeNodeData data = (CodeNodeData) node.getData(); - StringBuilder sb = new StringBuilder(); - String id = node.getId(); - - sb.append(String.format("// —— CodeNode [%s] ——%n", id)); - - sb.append(String.format("CodeExecutorNodeAction %s = CodeExecutorNodeAction.builder()%n", varName)); - - sb.append(" .codeExecutor(codeExecutor) // 注入的 CodeExecutor Bean\n"); - - sb.append(String.format(" .codeLanguage(\"%s\")%n", data.getCodeLanguage())); - - String escaped = data.getCode().replace("\\", "\\\\").replace("\"\"\"", "\\\"\\\"\\\""); - sb.append(" .code(\"\"\"\n").append(escaped).append("\n\"\"\")\n"); - - sb.append(" .config(codeExecutionConfig) // 注入的 CodeExecutionConfig Bean\n"); - - if (!data.getInputs().isEmpty()) { - String params = data.getInputs() - .stream() - .map(sel -> String.format("\"%s\", \"%s\"", sel.getLabel(), sel.getNameInCode())) - .collect(Collectors.joining(", ")); - sb.append(String.format(" .params(Map.of(%s))%n", params)); - } - if (data.getOutputKey() != null) { - sb.append(String.format(".outputKey(\"%s\")%n", escape(data.getOutputKey()))); - } - - sb.append(" .build();\n"); - - // 辅助节点代码,包装codeNode,将他的返回值变量解包 - String assistantNodeCode = String.format("wrapperCodeNodeAction(%s, \"%s\", \"%s\")", varName, - data.getOutputKey(), varName); - - sb.append(String.format("stateGraph.addNode(\"%s\", AsyncNodeAction.node_async(%s));%n%n", varName, - assistantNodeCode)); - - return sb.toString(); + CodeNodeData nodeData = ((CodeNodeData) node.getData()); + return String.format(""" + // —— CodeNode [%s] —— + CodeExecutorNodeAction %s = CodeExecutorNodeAction.builder() + .codeExecutor(codeExecutor) + .codeLanguage("%s") + .code(%s) + .codeStyle(%s) + .config(codeExecutionConfig) + .params(%s) + .outputKey("%s") + .build(); + stateGraph.addNode("%s", AsyncNodeAction.node_async( + wrapperCodeNodeAction(%s, "%s", "%s", %d, %d, %s) + )); + + """, node.getId(), varName, nodeData.getCodeLanguage(), ObjectToCodeUtil.toCode(nodeData.getCode()), + ObjectToCodeUtil.toCode(nodeData.getCodeStyle()), ObjectToCodeUtil.toCode(nodeData.getInputParams()), + nodeData.getOutputKey(), varName, varName, nodeData.getOutputKey(), varName, + nodeData.getMaxRetryCount(), nodeData.getRetryIntervalMs(), + ObjectToCodeUtil.toCode(nodeData.getDefaultValue())); } @Override public String assistMethodCode(DSLDialectType dialectType) { return switch (dialectType) { - case DIFY -> """ - private NodeAction wrapperCodeNodeAction(NodeAction codeNodeAction, String key, String nodeName) { - return state -> { - // 将代码运行的结果拆包 - Map result = codeNodeAction.apply(state); - Object object = result.get(key); - if(!(object instanceof Map)) { - return Map.of(); - } - return ((Map) object).entrySet().stream() - .collect(Collectors.toMap( - entry -> nodeName + "_" + entry.getKey(), - Map.Entry::getValue - )); - }; - } - """; + case DIFY, STUDIO -> + """ + private static final CodeExecutionConfig codeExecutionConfig; + private static final CodeExecutor codeExecutor; + + static { + // todo: configure your own code execution configuration + try { + Path tempDir = Files.createTempDirectory("code-execution-workdir-"); + tempDir.toFile().deleteOnExit(); + codeExecutionConfig = new CodeExecutionConfig().setWorkDir(tempDir.toString()); + codeExecutor = new LocalCommandlineCodeExecutor(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private NodeAction wrapperCodeNodeAction(NodeAction codeNodeAction, String key, + String nodeName, int maxRetryCount, int retryIntervalMs, Map defaultValue) { + return state -> { + int count = maxRetryCount; + while (count-- > 0) { + try { + // 将代码运行的结果拆包 + Map result = codeNodeAction.apply(state); + Object object = result.get(key); + if(!(object instanceof Map)) { + throw new RuntimeException("unexcepted result"); + } + return ((Map) object).entrySet().stream() + .collect(Collectors.toMap( + entry -> nodeName + "_" + entry.getKey(), + Map.Entry::getValue + )); + } catch (Exception e) { + Thread.sleep(retryIntervalMs); + } + } + if(defaultValue != null) { + return defaultValue; + } else { + throw new RuntimeException("code execution failed!"); + } + }; + } + """; default -> ""; }; } + @Override + public List getImports() { + return List.of("com.alibaba.cloud.ai.graph.node.code.CodeExecutorNodeAction", + "com.alibaba.cloud.ai.graph.node.code.entity.CodeExecutionConfig", + "com.alibaba.cloud.ai.graph.node.code.CodeExecutor", + "com.alibaba.cloud.ai.graph.node.code.LocalCommandlineCodeExecutor", "java.io.IOException", + "java.nio.file.Files", "java.nio.file.Path", "java.util.stream.Collectors", + "com.alibaba.cloud.ai.graph.node.code.entity.CodeParam", + "com.alibaba.cloud.ai.graph.node.code.entity.CodeStyle"); + } + } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/DocumentExtractorNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/DocumentExtractorNodeSection.java index 00572120ec..f076fa3b6d 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/DocumentExtractorNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/DocumentExtractorNodeSection.java @@ -60,4 +60,9 @@ public String render(Node node, String varName) { return sb.toString(); } + @Override + public List getImports() { + return List.of("com.alibaba.cloud.ai.graph.node.DocumentExtractorNode"); + } + } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/EmptyNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/EmptyNodeSection.java index 126b988846..32364f20a1 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/EmptyNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/EmptyNodeSection.java @@ -23,6 +23,8 @@ import org.springframework.stereotype.Component; +import java.util.List; + /** * @author vlsmb * @since 2025/7/23 @@ -32,14 +34,13 @@ public class EmptyNodeSection implements NodeSection { @Override public boolean support(NodeType nodeType) { - return nodeType.equals(NodeType.DIFY_ITERATION_START); + return nodeType.equals(NodeType.EMPTY); } @Override public String render(Node node, String varName) { - EmptyNodeData data = (EmptyNodeData) node.getData(); StringBuilder sb = new StringBuilder(); - String id = data.getId(); + String id = node.getId(); sb.append("// —— Empty Node [").append(id).append("] ——\n"); sb.append("stateGraph.addNode(\"") .append(varName) @@ -47,4 +48,9 @@ public String render(Node node, String varName) { return sb.toString(); } + @Override + public List getImports() { + return List.of(); + } + } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/EndNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/EndNodeSection.java index 6bbb1809d0..461833a1ee 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/EndNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/EndNodeSection.java @@ -26,6 +26,7 @@ import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.nodedata.EndNodeData; import com.alibaba.cloud.ai.studio.admin.generator.service.generator.workflow.NodeSection; +import com.alibaba.cloud.ai.studio.admin.generator.utils.ObjectToCodeUtil; import org.springframework.stereotype.Component; @Component @@ -49,13 +50,14 @@ public String render(Node node, String varName) { if ("text".equalsIgnoreCase(data.getOutputType())) { // 如果输出类型为text,则使用对应的输出模板输出最终结果 if (data.getTextTemplateVars().isEmpty()) { - codeStr = String.format("state -> Map.of(\"output\", \"%s\")", data.getTextTemplate()); + codeStr = String.format("state -> Map.of(\"output\", %s)", + ObjectToCodeUtil.toCode(data.getTextTemplate())); } else { codeStr = String.format(""" state -> { - String template = "%s"; - Map params = Stream.of(%s) + String template = %s; + Map params = %s.stream() .collect(Collectors.toMap( key -> key, key -> state.value(key).orElse(""), @@ -63,11 +65,9 @@ public String render(Node node, String varName) { template = new PromptTemplate(template).render(params); return Map.of("output", template); } - """, data.getTextTemplate(), - data.getTextTemplateVars() - .stream() - .map(s -> String.format("\"%s\"", s)) - .collect(Collectors.joining(", "))); + + """, ObjectToCodeUtil.toCode(data.getTextTemplate()), + ObjectToCodeUtil.toCode(data.getTextTemplateVars())); } } else { @@ -86,7 +86,14 @@ public String render(Node node, String varName) { .append("\", AsyncNodeAction.node_async(") .append(codeStr) .append("));"); + sb.append(String.format("%n")); return sb.toString(); } + @Override + public List getImports() { + return List.of("java.util.stream.Stream", "java.util.stream.Collectors", + "org.springframework.ai.chat.prompt.PromptTemplate"); + } + } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/HttpNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/HttpNodeSection.java index 23be5108e6..ac40e57073 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/HttpNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/HttpNodeSection.java @@ -16,6 +16,7 @@ package com.alibaba.cloud.ai.studio.admin.generator.service.generator.workflow.sections; +import java.util.List; import java.util.Map; import com.alibaba.cloud.ai.graph.node.HttpNode; @@ -126,4 +127,9 @@ private NodeAction wrapperHttpNodeAction(NodeAction httpNodeAction, String varNa }; } + @Override + public List getImports() { + return List.of("com.alibaba.cloud.ai.graph.node.HttpNode", "org.springframework.http.HttpMethod"); + } + } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/HumanNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/HumanNodeSection.java index 243eac3824..8bcf06f607 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/HumanNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/HumanNodeSection.java @@ -82,4 +82,9 @@ public String render(Node node, String varName) { return sb.toString(); } + @Override + public List getImports() { + return List.of(); + } + } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/IterationNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/IterationNodeSection.java index 05d858989a..b9ce707318 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/IterationNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/IterationNodeSection.java @@ -16,90 +16,200 @@ package com.alibaba.cloud.ai.studio.admin.generator.service.generator.workflow.sections; -import com.alibaba.cloud.ai.studio.admin.generator.model.VariableType; +import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.Edge; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.Node; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.NodeType; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.nodedata.IterationNodeData; +import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.DSLDialectType; import com.alibaba.cloud.ai.studio.admin.generator.service.generator.workflow.NodeSection; import org.springframework.stereotype.Component; +import java.util.List; + /** * @author vlsmb * @since 2025/7/23 */ +// TODO: 支持并行模式、错误处理,支持Studio的默认输入值,支持Studio的多输入/多输出 @Component public class IterationNodeSection implements NodeSection { @Override public boolean support(NodeType nodeType) { - return nodeType.equals(NodeType.ITERATION); + return NodeType.ITERATION.equals(nodeType); } @Override public String render(Node node, String varName) { - // 构建Iteration.Start -> Iteration -> Iteration.End节点 - IterationNodeData data = (IterationNodeData) node.getData(); - StringBuilder sb = new StringBuilder(); - - // 获取输入输出的泛型 - String inputType = "Map"; - String outputType = "Map"; - if (VariableType.STRING.equals(data.getInputType())) { - inputType = "String"; + // 迭代节点在转换为Workflow的时候已经拆分为多个节点,故本方法返回空 + return ""; + } + + @Override + public String renderEdges(IterationNodeData nodeData, List edges) { + return ""; + } + + @Override + public List getImports() { + return List.of("java.util.ArrayList", "java.util.Arrays"); + } + + // 规定迭代节点的start为iterationVarName_start,end为iterationVarName_end + + @Component + public static class IterationStartNodeSection implements NodeSection { + + @Override + public boolean support(NodeType nodeType) { + return NodeType.ITERATION_START.equals(nodeType); } - else if (VariableType.NUMBER.equals(data.getInputType())) { - inputType = "Number"; + + @Override + public String render(Node node, String varName) { + IterationNodeData nodeData = ((IterationNodeData) node.getData()); + return String.format(""" + // Iteration [%s] Start Node + stateGraph.addNode("%s", AsyncNodeAction.node_async( + createIterationStartAction("%s", "%s", "%s", "%s", "%s", %d) + )); + + """, node.getId(), varName, nodeData.getInputSelector().getNameInCode(), + nodeData.getSourceVarName() + "_state", nodeData.getItemKey(), + nodeData.getSourceVarName() + "_index", nodeData.getSourceVarName() + "_isFinished", + nodeData.getIndexOffset()); } - if (VariableType.STRING.equals(data.getOutputType())) { - outputType = "String"; + + // TODO: 添加辅助节点以支持迭代起始节点并行 + @Override + public String renderEdges(IterationNodeData nodeData, List edges) { + Edge edge = edges.get(0); + return String.format(""" + // Iteration [%s] Start Edge + stateGraph.addConditionalEdges("%s", AsyncEdgeAction.edge_async( + state -> { + Boolean b = state.value("%s", false); + return b ? "end" : "iteration"; + } + ), Map.of("end", "%s", "iteration", "%s")); + + """, nodeData.getSourceVarName(), nodeData.getSourceVarName() + "_start", + nodeData.getSourceVarName() + "_isFinished", nodeData.getSourceVarName() + "_end", + edge.getTarget()); } - if (VariableType.NUMBER.equals(data.getOutputType())) { - outputType = "Number"; + + @Override + public String assistMethodCode(DSLDialectType dialectType) { + return """ + private NodeAction createIterationStartAction( + String arrayKey, String stateKey, + String itemKey, String indexKey, String flagKey, + int indexOffset) { + return state -> { + Object arrayObj = state.value(arrayKey).orElse(List.of()); + List stateList = state.value(stateKey, List.class).orElse(null); + + List arrayList; + if (stateList == null) { + // the first time in iteration + if (arrayObj instanceof List) { + arrayList = new ArrayList<>((List) arrayObj); + } else if (arrayObj.getClass().isArray()) { + arrayList = new ArrayList<>(Arrays.stream((Object[])arrayObj).toList()); + } else { + throw new IllegalStateException("value {" + arrayKey + "} is not an array!"); + } + int len = arrayList.size(); + stateList = new ArrayList<>(); + for (int i = 0; i < len; i++) { + stateList.add(i); + } + } else { + arrayList = (List) arrayObj; + } + + if(stateList.isEmpty()) { + return Map.of(flagKey, true); + } + int index = stateList.get(0); + Object item = arrayList.get(index); + stateList.remove(0); + return Map.of(arrayKey, arrayList, stateKey, stateList, itemKey, item, + indexKey, index + indexOffset, flagKey, false); + }; + } + """; + } + + @Override + public List getImports() { + return List.of("java.util.ArrayList", "java.util.Arrays"); + } + + } + + @Component + public static class IterationEndNodeSection implements NodeSection { + + @Override + public boolean support(NodeType nodeType) { + return NodeType.ITERATION_END.equals(nodeType); + } + + @Override + public String render(Node node, String varName) { + IterationNodeData nodeData = ((IterationNodeData) node.getData()); + return String.format(""" + // Iteration [%s] End Node + stateGraph.addNode("%s", AsyncNodeAction.node_async( + createIterationEndAction("%s", "%s", "%s") + )); + + """, nodeData.getSourceVarName(), varName, nodeData.getSourceVarName() + "_isFinished", + nodeData.getResultSelector().getNameInCode(), nodeData.getOutputKey()); + } + + // TODO: 添加辅助节点以支持迭代终止节点并行 + @Override + public String renderEdges(IterationNodeData nodeData, List edges) { + Edge edge = edges.get(0); + return String.format(""" + // Iteration [%s] End Edge + stateGraph.addConditionalEdges("%s", AsyncEdgeAction.edge_async( + state -> { + Boolean b = state.value("%s", false); + return b ? "finish" : "start"; + } + ), Map.of("finish", "%s", "start", "%s")); + + """, nodeData.getSourceVarName(), nodeData.getVarName(), + nodeData.getSourceVarName() + "_isFinished", edge.getTarget(), + nodeData.getSourceVarName() + "_start"); + } + + @Override + public String assistMethodCode(DSLDialectType dialectType) { + return """ + private NodeAction createIterationEndAction(String flagKey, String resultKey, String outputKey) { + return state -> { + boolean flag = state.value(flagKey, Boolean.class).orElse(true); + List outputList = state.value(outputKey, List.class).orElse(new ArrayList<>()); + if(flag) { + return Map.of(outputKey, outputList); + } + outputList.add(state.value(resultKey).orElseThrow()); + return Map.of(outputKey, outputList); + }; + } + """; + } + + @Override + public List getImports() { + return List.of("java.util.ArrayList", "java.util.Arrays"); } - sb.append("// —— IterationNode [").append(data.getId()).append("] ——\n"); - sb.append("IterationNode.<") - .append(inputType) - .append(", ") - .append(outputType) - .append(">converter()\n") - .append(".subGraphStartNodeName(\"") - .append(data.getStartNodeName()) - .append("\")\n") - .append(".subGraphEndNodeName(\"") - .append(data.getEndNodeName()) - .append("\")\n") - .append(".tempArrayKey(\"") - .append(data.getInnerArrayKey()) - .append("\")\n") - .append(".tempStartFlagKey(\"") - .append(data.getInnerStartFlagKey()) - .append("\")\n") - .append(".tempEndFlagKey(\"") - .append(data.getInnerEndFlagKey()) - .append("\")\n") - .append(".tempIndexKey(\"") - .append(data.getInnerIndexKey()) - .append("\")\n") - .append(".iteratorItemKey(\"") - .append(data.getInnerItemKey()) - .append("\")\n") - .append(".iteratorResultKey(\"") - .append(data.getInnerItemResultKey()) - .append("\")\n") - .append(".inputArrayJsonKey(\"") - .append(data.getInputKey()) - .append("\")\n") - .append(".outputArrayJsonKey(\"") - .append(data.getOutputKey()) - .append("\")\n") - .append(".appendToStateGraph(stateGraph, \"") - .append(varName) - .append("\", \"") - .append(varName) - .append("_out\");\n\n"); - return sb.toString(); } } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/KnowledgeRetrievalNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/KnowledgeRetrievalNodeSection.java index ced9a04108..a9911beb47 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/KnowledgeRetrievalNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/KnowledgeRetrievalNodeSection.java @@ -131,7 +131,7 @@ public String render(Node node, String varName) { } return String.format(""" - // —— KnowledgeRetrievalNode [%s] ——%n + // —— KnowledgeRetrievalNode [%s] —— KnowledgeRetrievalNode %s = KnowledgeRetrievalNode.builder() .topK(%s) .similarityThreshold(%s) @@ -140,6 +140,7 @@ public String render(Node node, String varName) { .vectorStore(createVectorStore(%s)) .build(); stateGraph.addNode("%s", AsyncNodeAction.node_async(wrapperRetrievalNodeAction(%s, "%s"))); + """, node.getId(), varName, ObjectToCodeUtil.toCode(nodeData.getTopK()), ObjectToCodeUtil.toCode(nodeData.getThreshold()), ObjectToCodeUtil.toCode(nodeData.getInputKey()), ObjectToCodeUtil.toCode(nodeData.getOutputKey()), @@ -233,4 +234,16 @@ public List resourceFiles(DSLDialectType dialectType, KnowledgeRet .orElse(NodeSection.super.resourceFiles(dialectType, nodeData)); } + @Override + public List getImports() { + return List.of("com.alibaba.cloud.ai.graph.node.KnowledgeRetrievalNode", + "org.springframework.ai.embedding.EmbeddingModel", "org.springframework.ai.reader.TextReader", + "org.springframework.ai.transformer.splitter.TokenTextSplitter", + "org.springframework.ai.vectorstore.SimpleVectorStore", + "org.springframework.ai.vectorstore.VectorStore", "org.springframework.beans.factory.annotation.Value", + "org.springframework.core.io.Resource", "org.springframework.ai.document.Document", + "org.springframework.beans.factory.annotation.Autowired", "org.springframework.core.io.ResourceLoader", + "java.util.Optional"); + } + } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/LLMNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/LLMNodeSection.java index 6258b04559..2e61c9229d 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/LLMNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/LLMNodeSection.java @@ -25,6 +25,8 @@ import com.alibaba.cloud.ai.studio.admin.generator.utils.ObjectToCodeUtil; import org.springframework.stereotype.Component; +import java.util.List; + // TODO:支持异常分支、支持DashScope平台以外其他模型、Dify的结构化输出 @Component public class LLMNodeSection implements NodeSection { @@ -38,10 +40,11 @@ public boolean support(NodeType nodeType) { public String render(Node node, String varName) { LLMNodeData nodeData = ((LLMNodeData) node.getData()); return String.format(""" - // —— LLMNode [%s] ——%n + // —— LLMNode [%s] —— stateGraph.addNode("%s", AsyncNodeAction.node_async( createLLMNodeAction(%s, %s, %s, %s, %s, %s, %s, %s, %s) )); + """, node.getId(), varName, ObjectToCodeUtil.toCode(nodeData.getChatModeName()), ObjectToCodeUtil.toCode(nodeData.getModeParams()), ObjectToCodeUtil.toCode(nodeData.getMessageTemplates()), @@ -147,4 +150,15 @@ else if (errorNextNode != null) { : "Map.of(outputKeyPrefix + \"output\", defaultOutput, outputKeyPrefix + \"reasoning_content\", defaultOutput)"); } + @Override + public List getImports() { + return List.of("org.springframework.ai.chat.messages.Message", + "org.springframework.ai.chat.messages.AssistantMessage", + "org.springframework.ai.chat.messages.MessageType", + "org.springframework.ai.chat.messages.SystemMessage", + "org.springframework.ai.chat.messages.UserMessage", + "com.alibaba.cloud.ai.dashscope.chat.DashScopeChatOptions", + "org.springframework.beans.factory.annotation.Autowired", "java.util.Optional"); + } + } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/ListOperatorNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/ListOperatorNodeSection.java index 0d81e7978c..0215c3d8dd 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/ListOperatorNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/ListOperatorNodeSection.java @@ -23,6 +23,8 @@ import org.springframework.stereotype.Component; +import java.util.List; + @Component public class ListOperatorNodeSection implements NodeSection { @@ -107,4 +109,9 @@ public String render(Node node, String varName) { return sb.toString(); } + @Override + public List getImports() { + return List.of("com.alibaba.cloud.ai.graph.node.ListOperatorNode", "java.util.Comparator"); + } + } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/MCPNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/MCPNodeSection.java index 8f0b332fa7..a8c58e6b1e 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/MCPNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/MCPNodeSection.java @@ -90,4 +90,9 @@ public String render(Node node, String varName) { return sb.toString(); } + @Override + public List getImports() { + return List.of(); + } + } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/MiddleOutputSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/MiddleOutputSection.java new file mode 100644 index 0000000000..63a18484ce --- /dev/null +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/MiddleOutputSection.java @@ -0,0 +1,77 @@ +/* + * Copyright 2024-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.cloud.ai.studio.admin.generator.service.generator.workflow.sections; + +import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.Node; +import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.NodeType; +import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.nodedata.MiddleOutputNodeData; +import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.DSLDialectType; +import com.alibaba.cloud.ai.studio.admin.generator.service.generator.workflow.NodeSection; +import com.alibaba.cloud.ai.studio.admin.generator.utils.ObjectToCodeUtil; +import org.springframework.stereotype.Component; + +import java.util.List; + +@Component +public class MiddleOutputSection implements NodeSection { + + @Override + public boolean support(NodeType nodeType) { + return NodeType.MIDDLE_OUTPUT.equals(nodeType); + } + + @Override + public String render(Node node, String varName) { + MiddleOutputNodeData nodeData = (MiddleOutputNodeData) node.getData(); + return String.format(""" + // -- MiddleOutputNode [%s] -- + stateGraph.addNode("%s", AsyncNodeAction.node_async( + createMiddleOutputNodeAction(%s, %s, %s)) + ); + + """, varName, varName, ObjectToCodeUtil.toCode(nodeData.getOutputTemplate()), + ObjectToCodeUtil.toCode(nodeData.getVarKeys()), ObjectToCodeUtil.toCode(nodeData.getOutputKey())); + } + + @Override + public String assistMethodCode(DSLDialectType dialectType) { + return switch (dialectType) { + case STUDIO -> + """ + private NodeAction createMiddleOutputNodeAction(String outputTemplate, List keys, String outputKey) { + return state -> { + Map params = keys.stream() + .collect(Collectors.toUnmodifiableMap( + key -> key, + key -> state.value(key).orElse(""), + (a, b) -> b + )); + String output = new PromptTemplate(outputTemplate).render(params); + return Map.of(outputKey, output); + }; + } + """; + default -> NodeSection.super.assistMethodCode(dialectType); + }; + } + + @Override + public List getImports() { + return List.of("java.util.stream.Collectors", "org.springframework.ai.chat.prompt.PromptTemplate"); + } + +} diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/ParameterParsingNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/ParameterParsingNodeSection.java index 25b4e58639..00f0f81809 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/ParameterParsingNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/ParameterParsingNodeSection.java @@ -103,4 +103,9 @@ private NodeAction wrapperParameterNodeAction(NodeAction nodeAction, String node }; } + @Override + public List getImports() { + return List.of("com.alibaba.cloud.ai.graph.node.ParameterParsingNode", "java.util.stream.Collectors"); + } + } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/QuestionClassifierNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/QuestionClassifierNodeSection.java index 5738401e93..1fd035bc69 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/QuestionClassifierNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/QuestionClassifierNodeSection.java @@ -18,7 +18,6 @@ import java.util.ArrayList; import java.util.List; -import java.util.Map; import java.util.stream.Collectors; import com.alibaba.cloud.ai.studio.admin.generator.model.VariableSelector; @@ -100,21 +99,18 @@ private String resolveConditionKey(QuestionClassifierNodeData classifier, String } @Override - public String renderConditionalEdges(QuestionClassifierNodeData nodeData, Map nodeMap, - Map.Entry> entry, Map varNames) { - String sourceId = entry.getKey(); - List condEdges = entry.getValue(); + public String renderEdges(QuestionClassifierNodeData nodeData, List edges) { List conditions = new ArrayList<>(); List mappings = new ArrayList<>(); - String srcVar = varNames.get(sourceId); + String srcVar = nodeData.getVarName(); StringBuilder sb = new StringBuilder(); // 如果输出的都不是预定分类,则使用最后一个分类 String lastConditionKey = "unknown"; - for (Edge e : condEdges) { + for (Edge e : edges) { String conditionKey = resolveConditionKey(nodeData, e.getSourceHandle()); - String tgtVar2 = varNames.get(e.getTarget()); + String tgtVar2 = e.getTarget(); lastConditionKey = conditionKey; conditions.add(String.format("if (value.contains(\"%s\")) return \"%s\";", conditionKey, conditionKey)); mappings.add(String.format("\"%s\", \"%s\"", conditionKey, tgtVar2)); @@ -132,4 +128,10 @@ public String renderConditionalEdges(QuestionClassifierNodeData nodeData, Map getImports() { + return List.of("com.alibaba.cloud.ai.graph.node.QuestionClassifierNode", + "static com.alibaba.cloud.ai.graph.action.AsyncEdgeAction.edge_async"); + } + } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/StartNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/StartNodeSection.java index 76a7f06a25..4649c1a3d3 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/StartNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/StartNodeSection.java @@ -15,6 +15,7 @@ */ package com.alibaba.cloud.ai.studio.admin.generator.service.generator.workflow.sections; +import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.Edge; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.Node; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.NodeType; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.nodedata.StartNodeData; @@ -22,6 +23,8 @@ import org.springframework.stereotype.Component; +import java.util.List; + @Component public class StartNodeSection implements NodeSection { @@ -36,4 +39,23 @@ public String render(Node node, String varName) { return ""; } + @Override + public String renderEdges(StartNodeData nodeData, List edges) { + // 开始节点的Source应为StateGraph.START + StringBuilder sb = new StringBuilder(); + sb.append(String.format("// Edges For [START]%n")); + if (edges.isEmpty()) { + return ""; + } + sb.append(String.format("stateGraph%n")); + edges.forEach(edge -> sb.append(String.format(".addEdge(START, \"%s\")%n", edge.getTarget()))); + sb.append(String.format(";%n%n")); + return sb.toString(); + } + + @Override + public List getImports() { + return List.of(); + } + } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/TemplateTransformNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/TemplateTransformNodeSection.java index 94b3023545..e11dc6c224 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/TemplateTransformNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/TemplateTransformNodeSection.java @@ -23,6 +23,8 @@ import org.springframework.stereotype.Component; +import java.util.List; + @Component public class TemplateTransformNodeSection implements NodeSection { @@ -60,4 +62,9 @@ public String render(Node node, String varName) { return sb.toString(); } + @Override + public List getImports() { + return List.of("com.alibaba.cloud.ai.graph.node.TemplateTransformNode"); + } + } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/ToolNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/ToolNodeSection.java index 74a085b6ff..6f54f65053 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/ToolNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/ToolNodeSection.java @@ -82,4 +82,10 @@ public String render(Node node, String varName) { return sb.toString(); } + @Override + public List getImports() { + return List.of("com.alibaba.cloud.ai.graph.node.ToolNode", "java.util.function.Function", + "org.springframework.ai.tool.function.FunctionToolCallback"); + } + } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/VariableAggregatorNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/VariableAggregatorNodeSection.java index bcc72efe60..a4d04572dc 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/VariableAggregatorNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/VariableAggregatorNodeSection.java @@ -168,4 +168,9 @@ else if (object instanceof List list) { }; } + @Override + public List getImports() { + return List.of("com.alibaba.cloud.ai.graph.node.VariableAggregatorNode", "java.util.stream.Collectors"); + } + } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/utils/ObjectToCodeUtil.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/utils/ObjectToCodeUtil.java index c58d139202..d25534ce0d 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/utils/ObjectToCodeUtil.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/utils/ObjectToCodeUtil.java @@ -16,6 +16,8 @@ package com.alibaba.cloud.ai.studio.admin.generator.utils; +import com.fasterxml.jackson.databind.ObjectMapper; + import java.util.List; import java.util.Map; import java.util.stream.Collectors; @@ -29,6 +31,8 @@ */ public final class ObjectToCodeUtil { + private static final ObjectMapper objectMapper = new ObjectMapper(); + private ObjectToCodeUtil() { } @@ -52,7 +56,18 @@ public static String toCode(Object object) { return "null"; } else if (object instanceof String) { - return "\"" + object + "\""; + try { + // 尝试使用Jackson打印字符串,以便转义特殊字符,如果失败则进行简单处理 + return objectMapper.writeValueAsString(object.toString()); + } + catch (Exception e) { + return "\"" + object.toString() + .replace("\"", "\\") + .replace("\n", "\\n") + .replace("\r", "\\r") + .replace("\t", "\\t") + .replace("\b", "\\b") + "\""; + } } else if (object instanceof List) { return listToCode((List) object); diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/resources/templates/GraphBuilder.java.mustache b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/resources/templates/GraphBuilder.java.mustache index 821a909284..fa67caf3d2 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/resources/templates/GraphBuilder.java.mustache +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/resources/templates/GraphBuilder.java.mustache @@ -28,10 +28,7 @@ public class GraphBuilder { {{assistMethodCode}} @Bean - public CompiledGraph buildGraph( - ChatModel chatModel - {{#hasCode}}, CodeExecutionConfig codeExecutionConfig, CodeExecutor codeExecutor{{/hasCode}} - ) throws Exception { + public CompiledGraph buildGraph(ChatModel chatModel) throws Exception { ChatClient chatClient = ChatClient.builder(chatModel).defaultAdvisors(new SimpleLoggerAdvisor()).build(); // new stateGraph @@ -43,23 +40,4 @@ public class GraphBuilder { return stateGraph.compile(); } - {{#hasCode}} - @Bean - public Path tempDir() throws IOException { - // todo: set your work dir - Path tempDir = Files.createTempDirectory("code-execution-workdir-"); - tempDir.toFile().deleteOnExit(); - return tempDir; - } - - @Bean - public CodeExecutionConfig codeExecutionConfig(Path tempDir) { - return new CodeExecutionConfig().setWorkDir(tempDir.toString()); - } - - @Bean - public CodeExecutor codeGenerator() { - return new LocalCommandlineCodeExecutor(); - } - {{/hasCode}} }