From a56c96517b239a4fcb202f152409e1e28c91e7f5 Mon Sep 17 00:00:00 2001 From: VLSMB <2047857654@qq.com> Date: Tue, 9 Sep 2025 21:16:50 +0800 Subject: [PATCH 01/11] feat: add MiddleOutputNode for Studio DSL --- .../generator/model/workflow/NodeType.java | 2 + .../nodedata/MiddleOutputNodeData.java | 65 +++++++++++ .../dsl/AbstractNodeDataConverter.java | 12 +++ .../dsl/converter/LLMNodeDataConverter.java | 7 +- .../MiddleOutputNodeDataConverter.java | 102 ++++++++++++++++++ .../workflow/WorkflowProjectGenerator.java | 3 + .../workflow/sections/EndNodeSection.java | 16 +-- .../KnowledgeRetrievalNodeSection.java | 3 +- .../workflow/sections/LLMNodeSection.java | 3 +- .../sections/MiddleOutputSection.java | 70 ++++++++++++ .../generator/utils/ObjectToCodeUtil.java | 17 ++- 11 files changed, 283 insertions(+), 17 deletions(-) create mode 100644 spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/MiddleOutputNodeData.java create mode 100644 spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/MiddleOutputNodeDataConverter.java create mode 100644 spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/MiddleOutputSection.java diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/NodeType.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/NodeType.java index e637e58390..78bf01c7c0 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/NodeType.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/NodeType.java @@ -26,6 +26,8 @@ public enum NodeType { ANSWER("answer", "answer", "UNSUPPORTED"), + MIDDLE_OUTPUT("middle-output", "UNSUPPORTED", "Output"), + AGENT("agent", "agent", "UNSUPPORTED"), LLM("llm", "llm", "LLM"), diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/MiddleOutputNodeData.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/MiddleOutputNodeData.java new file mode 100644 index 0000000000..550876dec2 --- /dev/null +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/MiddleOutputNodeData.java @@ -0,0 +1,65 @@ +/* + * Copyright 2024-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.cloud.ai.studio.admin.generator.model.workflow.nodedata; + +import com.alibaba.cloud.ai.studio.admin.generator.model.Variable; +import com.alibaba.cloud.ai.studio.admin.generator.model.VariableType; +import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.NodeData; +import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.DSLDialectType; + +import java.util.List; + +public class MiddleOutputNodeData extends NodeData { + + public static List getDefaultOutputSchemas(DSLDialectType dialectType) { + return switch (dialectType) { + case STUDIO -> List.of(new Variable("output", VariableType.STRING)); + default -> List.of(); + }; + } + + private String outputTemplate; + + private List varKeys; + + private String outputKey; + + public String getOutputTemplate() { + return outputTemplate; + } + + public void setOutputTemplate(String outputTemplate) { + this.outputTemplate = outputTemplate; + } + + public List getVarKeys() { + return varKeys; + } + + public void setVarKeys(List varKeys) { + this.varKeys = varKeys; + } + + public String getOutputKey() { + return outputKey; + } + + public void setOutputKey(String outputKey) { + this.outputKey = outputKey; + } + +} diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/AbstractNodeDataConverter.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/AbstractNodeDataConverter.java index 2fd8b931ea..1b1b174543 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/AbstractNodeDataConverter.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/AbstractNodeDataConverter.java @@ -150,6 +150,18 @@ protected String convertVarTemplate(DSLDialectType dialectType, String templateS return func.apply(templateString, idToVarName); } + private static final Pattern VAR_TEMPLATE_PATTERN = Pattern.compile("\\{(\\w+)}"); + + /** + * 获取模板中的变量占位符,比如"你好{var1},{var2}"返回"[var1, var2]" + * @param template 模板字符串 + * @return 变量占位符列表 + */ + protected List getVarTemplateKeys(String template) { + Matcher matcher = VAR_TEMPLATE_PATTERN.matcher(template); + return matcher.results().map(m -> m.group(1)).toList(); + } + /** * 创建一个空处理Consumer,便于使用.andThen编程 * @return BiConsumer diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/LLMNodeDataConverter.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/LLMNodeDataConverter.java index fb8123523f..643f359220 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/LLMNodeDataConverter.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/LLMNodeDataConverter.java @@ -20,8 +20,6 @@ import java.util.Map; import java.util.Optional; import java.util.function.BiConsumer; -import java.util.regex.Matcher; -import java.util.regex.Pattern; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -206,8 +204,6 @@ public String generateVarName(int count) { return "LLMNode" + count; } - private static final Pattern VAR_TEMPLATE_PATTERN = Pattern.compile("\\{(\\w+)}"); - @Override public BiConsumer> postProcessConsumer(DSLDialectType dialectType) { return switch (dialectType) { @@ -222,8 +218,7 @@ public BiConsumer> postProcessConsumer(DSLDiale .stream() .map(template -> { String newText = this.convertVarTemplate(dialectType, template.template(), idToVarName); - Matcher matcher = VAR_TEMPLATE_PATTERN.matcher(newText); - List keys = matcher.results().map(m -> m.group(1)).toList(); + List keys = this.getVarTemplateKeys(newText); return new LLMNodeData.MessageTemplate(newText, keys, template.type()); }) .toList(); diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/MiddleOutputNodeDataConverter.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/MiddleOutputNodeDataConverter.java new file mode 100644 index 0000000000..677bf75ae5 --- /dev/null +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/MiddleOutputNodeDataConverter.java @@ -0,0 +1,102 @@ +/* + * Copyright 2024-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.cloud.ai.studio.admin.generator.service.dsl.converter; + +import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.NodeType; +import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.nodedata.MiddleOutputNodeData; +import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.AbstractNodeDataConverter; +import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.DSLDialectType; +import com.alibaba.cloud.ai.studio.admin.generator.utils.MapReadUtil; +import com.fasterxml.jackson.core.JsonProcessingException; +import org.springframework.stereotype.Component; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.BiConsumer; +import java.util.stream.Stream; + +@Component +public class MiddleOutputNodeDataConverter extends AbstractNodeDataConverter { + + @Override + public Boolean supportNodeType(NodeType nodeType) { + return NodeType.MIDDLE_OUTPUT.equals(nodeType); + } + + @Override + public String generateVarName(int count) { + return "middleOutput" + count; + } + + @Override + public BiConsumer> postProcessConsumer(DSLDialectType dialectType) { + return switch (dialectType) { + case STUDIO -> emptyProcessConsumer().andThen((nodeData, idToVarName) -> { + // 设置输出键 + nodeData.setOutputs(MiddleOutputNodeData.getDefaultOutputSchemas(dialectType)); + nodeData.setOutputKey(nodeData.getVarName() + "_" + nodeData.getOutputs().get(0).getName()); + // 将输出模板进行处理 + nodeData + .setOutputTemplate(this.convertVarTemplate(dialectType, nodeData.getOutputTemplate(), idToVarName)); + nodeData.setVarKeys(this.getVarTemplateKeys(nodeData.getOutputTemplate())); + }).andThen(super.postProcessConsumer(dialectType)); + default -> super.postProcessConsumer(dialectType); + }; + } + + private enum MiddleOutputNodeConverter { + + STUDIO(new DialectConverter<>() { + @Override + public Boolean supportDialect(DSLDialectType dialectType) { + return DSLDialectType.STUDIO.equals(dialectType); + } + + @Override + public MiddleOutputNodeData parse(Map data) throws JsonProcessingException { + MiddleOutputNodeData nodeData = new MiddleOutputNodeData(); + String outputTemplate = MapReadUtil.getMapDeepValue(data, String.class, "config", "node_param", + "output"); + nodeData.setOutputTemplate(Optional.ofNullable(outputTemplate).orElse("")); + return nodeData; + } + + @Override + public Map dump(MiddleOutputNodeData nodeData) { + throw new UnsupportedOperationException(); + } + }), CUSTOM(defaultCustomDialectConverter(MiddleOutputNodeData.class)); + + private final DialectConverter converter; + + MiddleOutputNodeConverter(DialectConverter converter) { + this.converter = converter; + } + + public DialectConverter dialectConverter() { + return converter; + } + + } + + @Override + protected List> getDialectConverters() { + return Stream.of(MiddleOutputNodeConverter.values()).map(MiddleOutputNodeConverter::dialectConverter).toList(); + } + +} diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/WorkflowProjectGenerator.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/WorkflowProjectGenerator.java index 4d54f8de09..11ea27856e 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/WorkflowProjectGenerator.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/WorkflowProjectGenerator.java @@ -276,10 +276,13 @@ private String renderEdgeSections(List edges, List nodes, Map> nodeTypeToClass = Map.ofEntries( Map.entry(NodeType.ANSWER, List.of("com.alibaba.cloud.ai.graph.node.AnswerNode")), + Map.entry(NodeType.MIDDLE_OUTPUT, + List.of("java.util.stream.Collectors", "org.springframework.ai.chat.prompt.PromptTemplate")), Map.entry(NodeType.CODE, List.of("com.alibaba.cloud.ai.graph.node.code.CodeExecutorNodeAction", "com.alibaba.cloud.ai.graph.node.code.entity.CodeExecutionConfig", "com.alibaba.cloud.ai.graph.node.code.CodeExecutor", diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/EndNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/EndNodeSection.java index 6bbb1809d0..dda8dec62d 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/EndNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/EndNodeSection.java @@ -26,6 +26,7 @@ import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.nodedata.EndNodeData; import com.alibaba.cloud.ai.studio.admin.generator.service.generator.workflow.NodeSection; +import com.alibaba.cloud.ai.studio.admin.generator.utils.ObjectToCodeUtil; import org.springframework.stereotype.Component; @Component @@ -49,13 +50,14 @@ public String render(Node node, String varName) { if ("text".equalsIgnoreCase(data.getOutputType())) { // 如果输出类型为text,则使用对应的输出模板输出最终结果 if (data.getTextTemplateVars().isEmpty()) { - codeStr = String.format("state -> Map.of(\"output\", \"%s\")", data.getTextTemplate()); + codeStr = String.format("state -> Map.of(\"output\", %s)", + ObjectToCodeUtil.toCode(data.getTextTemplate())); } else { codeStr = String.format(""" state -> { - String template = "%s"; - Map params = Stream.of(%s) + String template = %s; + Map params = %s.stream() .collect(Collectors.toMap( key -> key, key -> state.value(key).orElse(""), @@ -63,11 +65,9 @@ public String render(Node node, String varName) { template = new PromptTemplate(template).render(params); return Map.of("output", template); } - """, data.getTextTemplate(), - data.getTextTemplateVars() - .stream() - .map(s -> String.format("\"%s\"", s)) - .collect(Collectors.joining(", "))); + + """, ObjectToCodeUtil.toCode(data.getTextTemplate()), + ObjectToCodeUtil.toCode(data.getTextTemplateVars())); } } else { diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/KnowledgeRetrievalNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/KnowledgeRetrievalNodeSection.java index ced9a04108..d65c19be25 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/KnowledgeRetrievalNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/KnowledgeRetrievalNodeSection.java @@ -131,7 +131,7 @@ public String render(Node node, String varName) { } return String.format(""" - // —— KnowledgeRetrievalNode [%s] ——%n + // —— KnowledgeRetrievalNode [%s] —— KnowledgeRetrievalNode %s = KnowledgeRetrievalNode.builder() .topK(%s) .similarityThreshold(%s) @@ -140,6 +140,7 @@ public String render(Node node, String varName) { .vectorStore(createVectorStore(%s)) .build(); stateGraph.addNode("%s", AsyncNodeAction.node_async(wrapperRetrievalNodeAction(%s, "%s"))); + """, node.getId(), varName, ObjectToCodeUtil.toCode(nodeData.getTopK()), ObjectToCodeUtil.toCode(nodeData.getThreshold()), ObjectToCodeUtil.toCode(nodeData.getInputKey()), ObjectToCodeUtil.toCode(nodeData.getOutputKey()), diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/LLMNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/LLMNodeSection.java index 6258b04559..74e1f7ae9e 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/LLMNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/LLMNodeSection.java @@ -38,10 +38,11 @@ public boolean support(NodeType nodeType) { public String render(Node node, String varName) { LLMNodeData nodeData = ((LLMNodeData) node.getData()); return String.format(""" - // —— LLMNode [%s] ——%n + // —— LLMNode [%s] —— stateGraph.addNode("%s", AsyncNodeAction.node_async( createLLMNodeAction(%s, %s, %s, %s, %s, %s, %s, %s, %s) )); + """, node.getId(), varName, ObjectToCodeUtil.toCode(nodeData.getChatModeName()), ObjectToCodeUtil.toCode(nodeData.getModeParams()), ObjectToCodeUtil.toCode(nodeData.getMessageTemplates()), diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/MiddleOutputSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/MiddleOutputSection.java new file mode 100644 index 0000000000..47f4735d3f --- /dev/null +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/MiddleOutputSection.java @@ -0,0 +1,70 @@ +/* + * Copyright 2024-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.cloud.ai.studio.admin.generator.service.generator.workflow.sections; + +import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.Node; +import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.NodeType; +import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.nodedata.MiddleOutputNodeData; +import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.DSLDialectType; +import com.alibaba.cloud.ai.studio.admin.generator.service.generator.workflow.NodeSection; +import com.alibaba.cloud.ai.studio.admin.generator.utils.ObjectToCodeUtil; +import org.springframework.stereotype.Component; + +@Component +public class MiddleOutputSection implements NodeSection { + + @Override + public boolean support(NodeType nodeType) { + return NodeType.MIDDLE_OUTPUT.equals(nodeType); + } + + @Override + public String render(Node node, String varName) { + MiddleOutputNodeData nodeData = (MiddleOutputNodeData) node.getData(); + return String.format(""" + // -- MiddleOutputNode [%s] -- + stateGraph.addNode("%s", AsyncNodeAction.node_async( + createMiddleOutputNodeAction(%s, %s, %s)) + ); + + """, varName, varName, ObjectToCodeUtil.toCode(nodeData.getOutputTemplate()), + ObjectToCodeUtil.toCode(nodeData.getVarKeys()), ObjectToCodeUtil.toCode(nodeData.getOutputKey())); + } + + @Override + public String assistMethodCode(DSLDialectType dialectType) { + return switch (dialectType) { + case STUDIO -> + """ + private NodeAction createMiddleOutputNodeAction(String outputTemplate, List keys, String outputKey) { + return state -> { + Map params = keys.stream() + .collect(Collectors.toUnmodifiableMap( + key -> key, + key -> state.value(key).orElse(""), + (a, b) -> b + )); + String output = new PromptTemplate(outputTemplate).render(params); + return Map.of(outputKey, output); + }; + } + """; + default -> NodeSection.super.assistMethodCode(dialectType); + }; + } + +} diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/utils/ObjectToCodeUtil.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/utils/ObjectToCodeUtil.java index c58d139202..d25534ce0d 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/utils/ObjectToCodeUtil.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/utils/ObjectToCodeUtil.java @@ -16,6 +16,8 @@ package com.alibaba.cloud.ai.studio.admin.generator.utils; +import com.fasterxml.jackson.databind.ObjectMapper; + import java.util.List; import java.util.Map; import java.util.stream.Collectors; @@ -29,6 +31,8 @@ */ public final class ObjectToCodeUtil { + private static final ObjectMapper objectMapper = new ObjectMapper(); + private ObjectToCodeUtil() { } @@ -52,7 +56,18 @@ public static String toCode(Object object) { return "null"; } else if (object instanceof String) { - return "\"" + object + "\""; + try { + // 尝试使用Jackson打印字符串,以便转义特殊字符,如果失败则进行简单处理 + return objectMapper.writeValueAsString(object.toString()); + } + catch (Exception e) { + return "\"" + object.toString() + .replace("\"", "\\") + .replace("\n", "\\n") + .replace("\r", "\\r") + .replace("\t", "\\t") + .replace("\b", "\\b") + "\""; + } } else if (object instanceof List) { return listToCode((List) object); From 3ccc1e342afd3cfd6387c65d34f175df5b938fd2 Mon Sep 17 00:00:00 2001 From: VLSMB <2047857654@qq.com> Date: Tue, 9 Sep 2025 23:52:15 +0800 Subject: [PATCH 02/11] refactor: refactor renderEdges --- .../admin/generator/model/workflow/Edge.java | 23 ------ .../model/workflow/nodedata/LLMNodeData.java | 5 +- .../dsl/adapters/StudioDSLAdapter.java | 3 +- .../converter/BranchNodeDataConverter.java | 18 +++++ .../generator/workflow/NodeSection.java | 26 +++--- .../workflow/WorkflowProjectGenerator.java | 80 +++++-------------- .../workflow/sections/BranchNodeSection.java | 39 +++------ .../workflow/sections/EndNodeSection.java | 1 + .../QuestionClassifierNodeSection.java | 12 +-- .../workflow/sections/StartNodeSection.java | 17 ++++ 10 files changed, 90 insertions(+), 134 deletions(-) diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/Edge.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/Edge.java index b40bf275e6..7b1ca24a5c 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/Edge.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/Edge.java @@ -31,11 +31,6 @@ public class Edge { private Map data; - private Integer zIndex = 0; - - // Temp field - private boolean isDify = true; - public String getId() { return id; } @@ -90,22 +85,4 @@ public Edge setData(Map data) { return this; } - public Integer getzIndex() { - return zIndex; - } - - public Edge setzIndex(Integer zIndex) { - this.zIndex = zIndex; - return this; - } - - public boolean isDify() { - return isDify; - } - - public Edge setDify(boolean isDify) { - this.isDify = isDify; - return this; - } - } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/LLMNodeData.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/LLMNodeData.java index 38c4a015be..bf74735dd4 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/LLMNodeData.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/LLMNodeData.java @@ -131,8 +131,9 @@ public record MessageTemplate(String template, List keys, MessageType ty @Override public String toString() { - return String.format("new MessageTemplate(\"%s\", %s, MessageType.%s)", this.template(), - ObjectToCodeUtil.toCode(this.keys()), this.type().getValue().toUpperCase()); + return String.format("new MessageTemplate(%s, %s, MessageType.%s)", + ObjectToCodeUtil.toCode(this.template()), ObjectToCodeUtil.toCode(this.keys()), + this.type().getValue().toUpperCase()); } } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/adapters/StudioDSLAdapter.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/adapters/StudioDSLAdapter.java index 9af8c883e1..17e8741058 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/adapters/StudioDSLAdapter.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/adapters/StudioDSLAdapter.java @@ -202,8 +202,7 @@ private List constructEdges(List> edgeMaps) { .setSource(source) .setTarget(target) .setSourceHandle(sourceHandle) - .setTargetHandle(targetHandle) - .setDify(false); + .setTargetHandle(targetHandle); return edge; }).toList(); } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/BranchNodeDataConverter.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/BranchNodeDataConverter.java index 7d3d3c6bc9..ad87859c12 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/BranchNodeDataConverter.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/BranchNodeDataConverter.java @@ -19,6 +19,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.function.BiConsumer; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -135,4 +136,21 @@ public Stream extractWorkflowVars(BranchNodeData data) { return Stream.empty(); } + @Override + public BiConsumer> postProcessConsumer(DSLDialectType dialectType) { + return switch (dialectType) { + case DIFY -> super.postProcessConsumer(dialectType).andThen((nodeData, idToVarName) -> { + // 处理条件里的VariableSelector + nodeData.getCases().forEach(c -> { + c.getConditions().forEach(condition -> { + VariableSelector selector = condition.getVariableSelector(); + selector.setNameInCode(idToVarName.getOrDefault(selector.getNamespace(), "unknown") + "_" + + selector.getName()); + }); + }); + }); + default -> super.postProcessConsumer(dialectType); + }; + } + } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/NodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/NodeSection.java index b8a5a2dba9..b0990d50ce 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/NodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/NodeSection.java @@ -50,19 +50,23 @@ default String escape(String input) { .replace("\t", "\\t"); } - // todo: 完善条件边的EdgeAction /** - * 生成条件边的SAA代码 - * @param nodeData 节点数据 - * @param nodeMap nodeId与node的映射 - * @param entry 包含当前节点ID与当前节点出发的条件边List - * @param varNames nodeId与nodeVarName的映射 - * @return 条件边代码 + * 生成stateGraph边的代码。edge列表为从当前节点出发的边。 如果当前节点有条件边,则应重写本方法。本方法默认为无条件的边。 + * @param nodeData 当前节点(边起始节点)的数据 + * @param edges 边列表,且边的source和handle应格式化为varName + * @return 生成的代码 */ - default String renderConditionalEdges(T nodeData, Map nodeMap, Map.Entry> entry, - Map varNames) { - System.err.println("Unsupported Conditional Edges!"); - return ""; + default String renderEdges(T nodeData, List edges) { + StringBuilder sb = new StringBuilder(); + sb.append(String.format("// Edges For [%s]%n", nodeData.getVarName())); + if (edges.isEmpty()) { + return ""; + } + sb.append(String.format("stateGraph%n")); + edges.forEach( + edge -> sb.append(String.format(".addEdge(\"%s\", \"%s\")%n", edge.getSource(), edge.getTarget()))); + sb.append(String.format(";%n%n")); + return sb.toString(); } /** diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/WorkflowProjectGenerator.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/WorkflowProjectGenerator.java index 11ea27856e..fa931edaff 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/WorkflowProjectGenerator.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/WorkflowProjectGenerator.java @@ -21,12 +21,12 @@ import java.nio.file.Files; import java.nio.file.Path; import java.util.Arrays; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.Set; +import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -80,8 +80,6 @@ public class WorkflowProjectGenerator implements ProjectGenerator { private final TemplateRenderer templateRenderer; - private final List> nodeNodeSections; - private final Map> nodeSectionMap; public WorkflowProjectGenerator(List dslAdapters, @@ -90,7 +88,6 @@ public WorkflowProjectGenerator(List dslAdapters, this.dslAdapters = dslAdapters; this.templateRenderer = templateRenderer .getIfAvailable(() -> new MustacheTemplateRenderer("classpath:/templates")); - this.nodeNodeSections = nodeNodeSections; this.nodeSectionMap = nodeNodeSections.stream().map(nodeSection -> { List nodeTypeList = Arrays.stream(NodeType.values()).filter(nodeSection::support).toList(); if (nodeTypeList.isEmpty()) { @@ -198,67 +195,29 @@ private String renderNodeSections(List nodes, Map varNames return sb.toString(); } - // TODO: 目前这里渲染edge的逻辑与Dify转换高度耦合,需要优化 private String renderEdgeSections(List edges, List nodes, Map varNames) { - StringBuilder sb = new StringBuilder(); - Map nodeMap = nodes.stream().collect(Collectors.toMap(Node::getId, n -> n)); - - // conditional edge set: sourceId -> List - Map> conditionalEdgesMap = edges.stream() - .filter(e -> e.getSourceHandle() != null && !"source".equals(e.getSourceHandle())) - .collect(Collectors.groupingBy(Edge::getSource)); - - // Set to track rendered edges to avoid duplicates - Set renderedEdges = new HashSet<>(); - - // common edge - for (Edge edge : edges) { - String sourceId = edge.getSource(); - String targetId = edge.getTarget(); - String srcVar = varNames.get(sourceId); - String tgtVar = varNames.get(targetId); - - Node sourceNode = nodeMap.get(sourceId); - NodeType sourceType = sourceNode != null ? sourceNode.getType() : null; - - // Skip if already rendered as conditional - if (edge.getSourceHandle() != null && !"source".equals(edge.getSourceHandle()) && edge.isDify()) { - continue; - } + // 将Edge里的source和target都转换成varName + edges.forEach(edge -> { + edge.setSource(varNames.getOrDefault(edge.getSource(), edge.getSource())); + edge.setTarget(varNames.getOrDefault(edge.getTarget(), edge.getTarget())); + }); - // 迭代节点作为边的终止点时直接使用节点ID,作为边的起始点时使用ID_out - // todo: 修改迭代节点终止ID,防止与变量冲突(Dify不冲突) - if (sourceType != null && NodeType.ITERATION.equals(sourceType) && edge.isDify()) { - srcVar += "_out"; - } + // nodeVarName -> node的映射 + Map nodeMap = nodes.stream() + .collect(Collectors.toMap(node -> node.getData().getVarName(), Function.identity())); - String key = srcVar + "->" + tgtVar; - if (renderedEdges.contains(key)) { - continue; - } - renderedEdges.add(key); + // 根据source进行分组 + Map> edgeGroup = edges.stream().collect(Collectors.groupingBy(Edge::getSource)); - // START and END special handling - if (NodeType.START.equals(sourceType)) { - sb.append(String.format("stateGraph.addEdge(START, \"%s\");%n", tgtVar)); - } - else { - sb.append(String.format("stateGraph.addEdge(\"%s\", \"%s\");%n", srcVar, tgtVar)); - } - } + StringBuilder sb = new StringBuilder(); - // conditional edge(aggregate by sourceId) - for (Map.Entry> entry : conditionalEdgesMap.entrySet()) { - String nodeId = entry.getKey(); - Node node = nodeMap.get(nodeId); - NodeType nodeType = node.getType(); - for (NodeSection section : nodeNodeSections) { - if (section.support(nodeType)) { - String edgeCode = section.renderConditionalEdges(node.getData(), nodeMap, entry, varNames); - sb.append(edgeCode); - } - } - } + // 调用每一个source节点的renderEdges方法 + edgeGroup.forEach((varName, edgeList) -> { + NodeType nodeType = nodeMap.get(varName).getType(); + @SuppressWarnings("unchecked") + NodeSection section = (NodeSection) nodeSectionMap.get(nodeType); + sb.append(section.renderEdges(nodeMap.get(varName).getData(), edgeList)); + }); // 统一生成end节点到StateGraph.END的边(避免边重复) List endNodeList = nodes.stream() @@ -268,6 +227,7 @@ private String renderEdgeSections(List edges, List nodes, Map sb.append(String.format("%n.addEdge(\"%s\", END)", endName))); sb.append(String.format(";%n")); diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/BranchNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/BranchNodeSection.java index 61c5ee6920..e5707d64fa 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/BranchNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/BranchNodeSection.java @@ -17,6 +17,7 @@ import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -30,6 +31,7 @@ import org.springframework.stereotype.Component; +// TODO: 对于Dify的条件渲染,将CaseID格式化为比较易懂的格式 @Component public class BranchNodeSection implements NodeSection { @@ -52,9 +54,8 @@ public String render(Node node, String varName) { } @Override - public String renderConditionalEdges(BranchNodeData branchNodeData, Map nodeMap, - Map.Entry> entry, Map varNames) { - String srcVar = varNames.get(entry.getKey()); + public String renderEdges(BranchNodeData branchNodeData, List edges) { + String srcVar = branchNodeData.getVarName(); StringBuilder sb = new StringBuilder(); List cases = branchNodeData.getCases(); @@ -70,7 +71,7 @@ public String renderConditionalEdges(BranchNodeData branchNodeData, Map edgeCaseMap = entry.getValue() - .stream() + Map edgeCaseMap = edges.stream() .collect(Collectors.toMap(Edge::getSourceHandle, Edge::getTarget)); String edgeCaseMapStr = "Map.of(" + edgeCaseMap.entrySet() .stream() - .flatMap(e -> Stream.of(e.getKey(), varNames.getOrDefault(e.getValue(), "unknown"))) + .flatMap(e -> Stream.of(e.getKey(), e.getValue())) .map(v -> String.format("\"%s\"", v)) .collect(Collectors.joining(", ")) + ")"; @@ -105,9 +105,9 @@ public String renderConditionalEdges(BranchNodeData branchNodeData, Map nodeMap) { + private String generateSafeVariableAccess(Case.Condition condition) { String varType = condition.getVarType(); - String variablePath = buildVariablePath(condition, nodeMap); + String variablePath = buildVariablePath(condition); switch (varType.toLowerCase()) { case "file": @@ -148,29 +148,12 @@ private String generateSafeVariableAccess(Case.Condition condition, Map nodeMap) { + private String buildVariablePath(Case.Condition condition) { VariableSelector variableSelector = condition.getVariableSelector(); if (variableSelector == null) { return "unknown"; } - - // 其中第一个是节点ID,第二个是变量名,第三个是属性 - - String nodeId = variableSelector.getNamespace(); - String variableName = variableSelector.getName(); - - // 如果有节点映射,尝试获取正确的变量名 - if (nodeMap.containsKey(nodeId)) { - Node inputNode = nodeMap.get(nodeId); - if (inputNode.getData().getOutputs() != null && !inputNode.getData().getOutputs().isEmpty()) { - // 使用输出定义中的变量名 - String outputName = inputNode.getData().getOutputs().get(0).getName(); - return outputName != null ? outputName : variableName; - } - } - - // 如果无法从节点映射获取,直接使用变量名 - return variableName != null ? variableName : "unknown"; + return Optional.ofNullable(variableSelector.getNameInCode()).orElse("unknown"); } } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/EndNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/EndNodeSection.java index dda8dec62d..6d98d2bbcb 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/EndNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/EndNodeSection.java @@ -86,6 +86,7 @@ public String render(Node node, String varName) { .append("\", AsyncNodeAction.node_async(") .append(codeStr) .append("));"); + sb.append(String.format("%n")); return sb.toString(); } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/QuestionClassifierNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/QuestionClassifierNodeSection.java index 5738401e93..4eb5ba1df8 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/QuestionClassifierNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/QuestionClassifierNodeSection.java @@ -18,7 +18,6 @@ import java.util.ArrayList; import java.util.List; -import java.util.Map; import java.util.stream.Collectors; import com.alibaba.cloud.ai.studio.admin.generator.model.VariableSelector; @@ -100,21 +99,18 @@ private String resolveConditionKey(QuestionClassifierNodeData classifier, String } @Override - public String renderConditionalEdges(QuestionClassifierNodeData nodeData, Map nodeMap, - Map.Entry> entry, Map varNames) { - String sourceId = entry.getKey(); - List condEdges = entry.getValue(); + public String renderEdges(QuestionClassifierNodeData nodeData, List edges) { List conditions = new ArrayList<>(); List mappings = new ArrayList<>(); - String srcVar = varNames.get(sourceId); + String srcVar = nodeData.getVarName(); StringBuilder sb = new StringBuilder(); // 如果输出的都不是预定分类,则使用最后一个分类 String lastConditionKey = "unknown"; - for (Edge e : condEdges) { + for (Edge e : edges) { String conditionKey = resolveConditionKey(nodeData, e.getSourceHandle()); - String tgtVar2 = varNames.get(e.getTarget()); + String tgtVar2 = e.getTarget(); lastConditionKey = conditionKey; conditions.add(String.format("if (value.contains(\"%s\")) return \"%s\";", conditionKey, conditionKey)); mappings.add(String.format("\"%s\", \"%s\"", conditionKey, tgtVar2)); diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/StartNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/StartNodeSection.java index 76a7f06a25..80b5ec5cb5 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/StartNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/StartNodeSection.java @@ -15,6 +15,7 @@ */ package com.alibaba.cloud.ai.studio.admin.generator.service.generator.workflow.sections; +import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.Edge; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.Node; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.NodeType; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.nodedata.StartNodeData; @@ -22,6 +23,8 @@ import org.springframework.stereotype.Component; +import java.util.List; + @Component public class StartNodeSection implements NodeSection { @@ -36,4 +39,18 @@ public String render(Node node, String varName) { return ""; } + @Override + public String renderEdges(StartNodeData nodeData, List edges) { + // 开始节点的Source应为StateGraph.START + StringBuilder sb = new StringBuilder(); + sb.append(String.format("// Edges For [START]%n")); + if (edges.isEmpty()) { + return ""; + } + sb.append(String.format("stateGraph%n")); + edges.forEach(edge -> sb.append(String.format(".addEdge(START, \"%s\")%n", edge.getTarget()))); + sb.append(String.format(";%n%n")); + return sb.toString(); + } + } From c28344890c23194b1709cd64089632cbf7dcc8a5 Mon Sep 17 00:00:00 2001 From: VLSMB <2047857654@qq.com> Date: Tue, 9 Sep 2025 23:54:56 +0800 Subject: [PATCH 03/11] refactor: remove unused BranchNode.java --- .../cloud/ai/graph/node/BranchNode.java | 72 ------------------- 1 file changed, 72 deletions(-) delete mode 100644 spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/BranchNode.java diff --git a/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/BranchNode.java b/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/BranchNode.java deleted file mode 100644 index 71e9ea124f..0000000000 --- a/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/BranchNode.java +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Copyright 2024-2025 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.alibaba.cloud.ai.graph.node; - -import com.alibaba.cloud.ai.graph.OverAllState; -import com.alibaba.cloud.ai.graph.action.NodeAction; -import java.util.HashMap; -import java.util.Map; -import org.springframework.util.StringUtils; - -public class BranchNode implements NodeAction { - - private final String inputKey; - - private final String outputKey; - - public BranchNode(String outputKey, String inputKey) { - if (!StringUtils.hasLength(inputKey) || !StringUtils.hasLength(outputKey)) { - throw new IllegalArgumentException("inputKey and outputKey must not be null or empty."); - } - this.inputKey = inputKey; - this.outputKey = outputKey; - } - - @Override - public Map apply(OverAllState state) throws Exception { - Map updatedState = new HashMap<>(); - String value = state.value(inputKey).map(Object::toString).orElse(null); - updatedState.put(this.outputKey, value); - return updatedState; - } - - public static Builder builder() { - return new Builder(); - } - - public static class Builder { - - private String outputKey; - - private String inputKey; - - public BranchNode build() { - return new BranchNode(outputKey, inputKey); - } - - public Builder outputKey(String outputKey) { - this.outputKey = outputKey; - return this; - } - - public Builder inputKey(String inputKey) { - this.inputKey = inputKey; - return this; - } - - } - -} From 34b79745517e6ac08556ce5f50bba9b63ad563e6 Mon Sep 17 00:00:00 2001 From: VLSMB <2047857654@qq.com> Date: Thu, 11 Sep 2025 00:14:46 +0800 Subject: [PATCH 04/11] feat: refactor IterationNodeDataConverter to support Studio DSL primarily --- .../admin/generator/model/workflow/Node.java | 12 + .../generator/model/workflow/NodeType.java | 6 +- .../workflow/nodedata/EmptyNodeData.java | 14 - .../workflow/nodedata/IterationNodeData.java | 315 +++--------------- .../service/dsl/adapters/DifyDSLAdapter.java | 77 ++++- .../dsl/adapters/StudioDSLAdapter.java | 76 ++++- .../dsl/converter/EmptyNodeDataConverter.java | 11 +- .../converter/IterationNodeDataConverter.java | 115 +++---- .../generator/workflow/NodeSection.java | 1 - .../workflow/WorkflowProjectGenerator.java | 6 - .../workflow/sections/EmptyNodeSection.java | 5 +- .../sections/IterationNodeSection.java | 210 ++++++++---- 12 files changed, 408 insertions(+), 440 deletions(-) diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/Node.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/Node.java index 62ca949a3f..c855591fc5 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/Node.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/Node.java @@ -21,6 +21,9 @@ public class Node implements RunnableModel { private String id; + // 如果在循环节点里,则有父节点ID + private String parentId; + private NodeType type; private String title; @@ -54,6 +57,15 @@ public Node setId(String id) { return this; } + public String getParentId() { + return parentId; + } + + public Node setParentId(String parentId) { + this.parentId = parentId; + return this; + } + public NodeType getType() { return type; } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/NodeType.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/NodeType.java index 78bf01c7c0..891b8ee5c5 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/NodeType.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/NodeType.java @@ -60,7 +60,11 @@ public enum NodeType { ITERATION("iteration", "iteration", "UNSUPPORTED"), - DIFY_ITERATION_START("__empty__", "iteration-start", "UNSUPPORTED"), + EMPTY("empty", "UNSUPPORTED", "UNSUPPORTED"), + + ITERATION_START("iteration-start", "iteration-start", "ParallelStart"), + + ITERATION_END("iteration-end", "iteration-end", "ParallelEnd"), ASSIGNER("assigner", "assigner", "UNSUPPORTED"); diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/EmptyNodeData.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/EmptyNodeData.java index 4efa671acd..6f5fb13ee5 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/EmptyNodeData.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/EmptyNodeData.java @@ -26,18 +26,4 @@ */ public class EmptyNodeData extends NodeData { - private String id; - - public String getId() { - return id; - } - - public void setId(String id) { - this.id = id; - } - - public EmptyNodeData(String id) { - this.id = id; - } - } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/IterationNodeData.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/IterationNodeData.java index 21551bf922..a066b68f17 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/IterationNodeData.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/IterationNodeData.java @@ -16,14 +16,12 @@ package com.alibaba.cloud.ai.studio.admin.generator.model.workflow.nodedata; -import java.util.List; - import com.alibaba.cloud.ai.studio.admin.generator.model.Variable; import com.alibaba.cloud.ai.studio.admin.generator.model.VariableSelector; import com.alibaba.cloud.ai.studio.admin.generator.model.VariableType; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.NodeData; -import org.springframework.util.StringUtils; +import java.util.List; /** * @author vlsmb @@ -31,137 +29,61 @@ */ public class IterationNodeData extends NodeData { - public static Variable getDefaultOutputSchema() { - return new Variable("output", VariableType.ARRAY_STRING); + public static List getDefaultOutputSchemas() { + return List.of(new Variable("state", VariableType.ARRAY_NUMBER), // 剩余未处理的元素索引 + new Variable("index", VariableType.NUMBER), // 迭代索引 + new Variable("isFinished", VariableType.BOOLEAN) // 迭代是否结束 + ); } - private String id; - - private VariableType inputType; - - private VariableType outputType; - - private VariableSelector inputSelector; - - private VariableSelector outputSelector; - - private String startNodeId; - - private String endNodeId; - - private String outputKey; - - private String inputKey; - - private String startNodeName; + private int parallelCount = 1; - private String endNodeName; + private int maxIterationCount = Integer.MAX_VALUE; - // 内部临时变量名 - private String innerArrayKey; + // Dify的迭代索引从0开始,而Studio的从1开始,故需要设置这个值 + private int indexOffset = 0; - private String innerStartFlagKey; + // itemKey和outputKey在Dify中固定,但在Studio中用户可以自定义 + private String itemKey = "item"; - private String innerEndFlagKey; + private String outputKey = "output"; - private String innerItemKey; - - private String innerItemResultKey; - - private String innerIndexKey; - - private Variable output; - - public IterationNodeData(String id, VariableType inputType, VariableType outputType, VariableSelector inputSelector, - VariableSelector outputSelector, String startNodeId, String endNodeId, String inputKey, String outputKey) { - this.id = id; - this.inputType = inputType; - this.outputType = outputType; - this.inputSelector = inputSelector; - this.outputSelector = outputSelector; - this.startNodeId = startNodeId; - this.endNodeId = endNodeId; - this.inputKey = inputKey; - this.outputKey = outputKey; - this.setVarName(id); - this.output = new Variable(outputKey, outputType); - this.setInputs(List.of(inputSelector)); - this.setOutputs(List.of(output)); - } - - @Override - public void setVarName(String varName) { - this.varName = varName; - this.innerArrayKey = this.varName + "_array"; - this.innerStartFlagKey = this.varName + "_start_flag"; - this.innerEndFlagKey = this.varName + "_end_flag"; - this.innerItemKey = this.varName + "_item"; - this.innerItemResultKey = this.varName + "_item_result"; - this.innerIndexKey = this.varName + "_index"; - } - - public String getId() { - return id; - } - - public void setId(String id) { - this.id = id; - } + // 迭代输入的Selector + private VariableSelector inputSelector; - public VariableType getInputType() { - return inputType; - } + // 迭代结果元素的Selector + private VariableSelector resultSelector; - public void setInputType(VariableType inputType) { - this.inputType = inputType; + public int getParallelCount() { + return parallelCount; } - public VariableType getOutputType() { - return outputType; + public void setParallelCount(int parallelCount) { + this.parallelCount = parallelCount; } - public void setOutputType(VariableType outputType) { - this.outputType = outputType; + public int getMaxIterationCount() { + return maxIterationCount; } - public VariableSelector getInputSelector() { - return inputSelector; + public void setMaxIterationCount(int maxIterationCount) { + this.maxIterationCount = maxIterationCount; } - public void setInputSelector(VariableSelector inputSelector) { - this.inputSelector = inputSelector; + public int getIndexOffset() { + return indexOffset; } - public VariableSelector getOutputSelector() { - return outputSelector; + public void setIndexOffset(int indexOffset) { + this.indexOffset = indexOffset; } - public void setOutputSelector(VariableSelector outputSelector) { - this.outputSelector = outputSelector; + public String getItemKey() { + return itemKey; } - public String getStartNodeId() { - return startNodeId; - } - - public void setStartNodeId(String startNodeId) { - this.startNodeId = startNodeId; - } - - public String getEndNodeId() { - return endNodeId; - } - - public void setEndNodeId(String endNodeId) { - this.endNodeId = endNodeId; - } - - public String getInputKey() { - return inputKey; - } - - public void setInputKey(String inputKey) { - this.inputKey = inputKey; + public void setItemKey(String itemKey) { + this.itemKey = itemKey; } public String getOutputKey() { @@ -172,175 +94,20 @@ public void setOutputKey(String outputKey) { this.outputKey = outputKey; } - public String getInnerArrayKey() { - return innerArrayKey; - } - - public void setInnerArrayKey(String innerArrayKey) { - this.innerArrayKey = innerArrayKey; - } - - public String getInnerStartFlagKey() { - return innerStartFlagKey; - } - - public void setInnerStartFlagKey(String innerStartFlagKey) { - this.innerStartFlagKey = innerStartFlagKey; - } - - public String getInnerEndFlagKey() { - return innerEndFlagKey; - } - - public void setInnerEndFlagKey(String innerEndFlagKey) { - this.innerEndFlagKey = innerEndFlagKey; - } - - public String getInnerItemKey() { - return innerItemKey; - } - - public void setInnerItemKey(String innerItemKey) { - this.innerItemKey = innerItemKey; - } - - public String getInnerItemResultKey() { - return innerItemResultKey; - } - - public void setInnerItemResultKey(String innerItemResultKey) { - this.innerItemResultKey = innerItemResultKey; - } - - public String getInnerIndexKey() { - return innerIndexKey; - } - - public void setInnerIndexKey(String innerIndexKey) { - this.innerIndexKey = innerIndexKey; - } - - public Variable getOutput() { - return output; - } - - public void setOutput(Variable output) { - this.output = output; - } - - public String getStartNodeName() { - return startNodeName; - } - - public void setStartNodeName(String startNodeName) { - this.startNodeName = startNodeName; - } - - public String getEndNodeName() { - return endNodeName; + public VariableSelector getInputSelector() { + return inputSelector; } - public void setEndNodeName(String endNodeName) { - this.endNodeName = endNodeName; + public void setInputSelector(VariableSelector inputSelector) { + this.inputSelector = inputSelector; } - public static class Builder { - - private String id; - - private VariableType inputType; - - private VariableType outputType; - - private VariableSelector inputSelector; - - private VariableSelector outputSelector; - - private String startNodeId; - - private String endNodeId; - - private String inputKey; - - private String outputKey; - - // 可以不设置,使用默认值"_item" - private String itemKey; - - // 可以不设置,使用默认值"_index" - private String indexKey; - - public Builder id(String id) { - this.id = id; - return this; - } - - public Builder inputType(VariableType inputType) { - this.inputType = inputType; - return this; - } - - public Builder outputType(VariableType outputType) { - this.outputType = outputType; - return this; - } - - public Builder inputSelector(VariableSelector inputSelector) { - this.inputSelector = inputSelector; - return this; - } - - public Builder outputSelector(VariableSelector outputSelector) { - this.outputSelector = outputSelector; - return this; - } - - public Builder startNodeId(String startNodeId) { - this.startNodeId = startNodeId; - return this; - } - - public Builder endNodeId(String endNodeId) { - this.endNodeId = endNodeId; - return this; - } - - public Builder inputKey(String inputKey) { - this.inputKey = inputKey; - return this; - } - - public Builder outputKey(String outputKey) { - this.outputKey = outputKey; - return this; - } - - public Builder itemKey(String itemKey) { - this.itemKey = itemKey; - return this; - } - - public Builder indexKey(String indexKey) { - this.indexKey = indexKey; - return this; - } - - public IterationNodeData build() { - IterationNodeData data = new IterationNodeData(id, inputType, outputType, inputSelector, outputSelector, - startNodeId, endNodeId, inputKey, outputKey); - if (StringUtils.hasText(this.itemKey)) { - data.setInnerItemKey(this.itemKey); - } - if (StringUtils.hasText(this.indexKey)) { - data.setInnerIndexKey(this.indexKey); - } - return data; - } - + public VariableSelector getResultSelector() { + return resultSelector; } - public static Builder builder() { - return new Builder(); + public void setResultSelector(VariableSelector resultSelector) { + this.resultSelector = resultSelector; } } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/adapters/DifyDSLAdapter.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/adapters/DifyDSLAdapter.java index d6d00f2417..e739f9f79e 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/adapters/DifyDSLAdapter.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/adapters/DifyDSLAdapter.java @@ -20,6 +20,8 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Objects; +import java.util.Optional; import java.util.UUID; import java.util.function.BiConsumer; import java.util.stream.Collectors; @@ -40,6 +42,7 @@ import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.DSLDialectType; import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.NodeDataConverter; import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.Serializer; +import com.alibaba.cloud.ai.studio.admin.generator.utils.MapReadUtil; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.DeserializationFeature; import com.fasterxml.jackson.databind.ObjectMapper; @@ -161,19 +164,80 @@ public Workflow mapToWorkflow(Map data) { private Graph constructGraph(Map data) { Graph graph = new Graph(); - List nodes = new ArrayList<>(); + List nodes; List edges = new ArrayList<>(); // convert nodes if (data.containsKey("nodes")) { List> nodeMaps = (List>) data.get("nodes"); - nodes = constructNodes(nodeMaps); + nodes = new ArrayList<>(constructNodes(nodeMaps)); + } + else { + nodes = new ArrayList<>(); } // convert edges if (data.containsKey("edges")) { List> edgeMaps = (List>) data.get("edges"); - edges = constructEdges(edgeMaps); + edges = new ArrayList<>(constructEdges(edgeMaps)); } + Map varNames = nodes.stream() + .collect(Collectors.toMap(Node::getId, n -> n.getData().getVarName())); + Map nodeIdMap = nodes.stream().collect(Collectors.toMap(Node::getId, n -> n)); + Map nodeVarMap = nodes.stream().collect(Collectors.toMap(n -> n.getData().getVarName(), n -> n)); + // 将Edge里的source和target都转换成varName + // 将Iteration节点起始改为iteration_start,并将Iteration节点结束改为iteration_end + edges.forEach(edge -> { + if (NodeType.ITERATION.equals(nodeIdMap.get(edge.getSource()).getType())) { + edge.setSource(varNames.getOrDefault(edge.getSource(), edge.getSource()) + "_start"); + } + else { + edge.setSource(varNames.getOrDefault(edge.getSource(), edge.getSource())); + } + if (NodeType.ITERATION.equals(nodeIdMap.get(edge.getTarget()).getType())) { + edge.setTarget(varNames.getOrDefault(edge.getTarget(), edge.getTarget()) + "_end"); + } + else { + edge.setTarget(varNames.getOrDefault(edge.getTarget(), edge.getTarget())); + } + }); + + // 根据parnetId进行分组,为了给迭代节点的起始节点传递迭代数据 + Map> groupByParentId = nodes.stream() + .filter(node -> Objects.nonNull(node.getParentId())) + .collect(Collectors.groupingBy(Node::getParentId)); + + List finalEdges = edges; + groupByParentId.forEach((parentId, subNodes) -> { + subNodes.forEach(node -> { + if (NodeType.ITERATION_START.equals(node.getType()) || NodeType.ITERATION_END.equals(node.getType())) { + node.setData(nodeIdMap.get(parentId).getData()); + } + }); + // 添加迭代节点的终止节点(Dify的DSL没有提供但为了后续正常转换,这里需要添加) + NodeData nodeData = nodeIdMap.get(parentId).getData(); + Node endNode = new Node(); + endNode.setData(nodeData).setType(NodeType.ITERATION_END); + nodes.add(endNode); + + // 计算每个节点的出度,出度为0的点将与迭代终止节点相连接 + finalEdges.stream().filter(e -> { + Node n = nodeVarMap.get(e.getSource()); + return parentId.equals(n.getParentId()); + }) + .collect(Collectors.groupingBy(Edge::getSource)) + .entrySet() + .stream() + .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().size())) + .entrySet() + .stream() + .filter(entry -> entry.getValue() == 0) + .map(Map.Entry::getKey) + .forEach(nodeName -> { + Edge edge = new Edge().setSource(nodeName).setTarget(nodeData.getVarName() + "_end"); + finalEdges.add(edge); + }); + }); + graph.setNodes(nodes); graph.setEdges(edges); return graph; @@ -208,7 +272,12 @@ private List constructNodes(List> nodeMaps) { nodeMap.remove("type"); Node node = objectMapper.convertValue(nodeMap, Node.class); // set title and desc - node.setTitle((String) nodeDataMap.get("title")).setDesc((String) nodeDataMap.get("desc")); + String parentId = Optional.ofNullable(MapReadUtil.getMapDeepValue(nodeMap, String.class, "parentId")) + .or(() -> Optional.ofNullable(MapReadUtil.getMapDeepValue(nodeDataMap, String.class, "iteration_id"))) + .orElse(null); + node.setTitle((String) nodeDataMap.get("title")) + .setDesc((String) nodeDataMap.get("desc")) + .setParentId(parentId); // convert node data using specific WorkflowNodeDataConverter @SuppressWarnings("unchecked") diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/adapters/StudioDSLAdapter.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/adapters/StudioDSLAdapter.java index 17e8741058..5c9c338743 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/adapters/StudioDSLAdapter.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/adapters/StudioDSLAdapter.java @@ -42,6 +42,8 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Objects; +import java.util.Optional; import java.util.function.BiConsumer; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -118,13 +120,74 @@ public Workflow mapToWorkflow(Map data) { private Graph constructGraph(Map data) { Graph graph = new Graph(); - List> nodeMap = MapReadUtil - .safeCastToListWithMap(MapReadUtil.getMapDeepValue(data, List.class, "config", "nodes")); - List> edgeMap = MapReadUtil - .safeCastToListWithMap(MapReadUtil.getMapDeepValue(data, List.class, "config", "edges")); + List> nodeMap = new ArrayList<>(Optional + .ofNullable( + MapReadUtil.safeCastToListWithMap(MapReadUtil.getMapDeepValue(data, List.class, "config", "nodes"))) + .orElse(List.of())); + List> edgeMap = new ArrayList<>(Optional + .ofNullable( + MapReadUtil.safeCastToListWithMap(MapReadUtil.getMapDeepValue(data, List.class, "config", "edges"))) + .orElse(List.of())); + + List> innerNodeMaps = new ArrayList<>(); + List> innerEdgeMaps = new ArrayList<>(); + + // 展开迭代节点内部的Node和Edge + nodeMap.forEach(map -> { + NodeType type = NodeType.fromStudioValue(MapReadUtil.getMapDeepValue(map, String.class, "type")) + .orElseThrow(); + if (NodeType.ITERATION.equals(type)) { + List> innerNode = MapReadUtil.safeCastToListWithMap( + MapReadUtil.getMapDeepValue(map, List.class, "config", "node_param", "block", "nodes")); + if (innerNode != null) { + innerNodeMaps.addAll(innerNode); + } + List> innerEdge = MapReadUtil.safeCastToListWithMap( + MapReadUtil.getMapDeepValue(map, List.class, "config", "node_param", "block", "edges")); + if (innerEdge != null) { + innerEdgeMaps.addAll(innerEdge); + } + } + }); + nodeMap.addAll(innerNodeMaps); + edgeMap.addAll(innerEdgeMaps); List nodes = this.constructNodes(nodeMap); List edges = this.constructEdges(edgeMap); + + Map varNames = nodes.stream() + .collect(Collectors.toMap(Node::getId, n -> n.getData().getVarName())); + Map nodeIdMap = nodes.stream().collect(Collectors.toMap(Node::getId, n -> n)); + // 将Edge里的source和target都转换成varName + // 将Iteration节点起始改为iteration_start,并将Iteration节点结束改为iteration_end + edges.forEach(edge -> { + if (NodeType.ITERATION.equals(nodeIdMap.get(edge.getSource()).getType())) { + edge.setSource(varNames.getOrDefault(edge.getSource(), edge.getSource()) + "_start"); + } + else { + edge.setSource(varNames.getOrDefault(edge.getSource(), edge.getSource())); + } + if (NodeType.ITERATION.equals(nodeIdMap.get(edge.getTarget()).getType())) { + edge.setTarget(varNames.getOrDefault(edge.getTarget(), edge.getTarget()) + "_end"); + } + else { + edge.setTarget(varNames.getOrDefault(edge.getTarget(), edge.getTarget())); + } + }); + + // 根据parnetId进行分组,为了给迭代节点的起始节点传递迭代数据 + Map> groupByParentId = nodes.stream() + .filter(node -> Objects.nonNull(node.getParentId())) + .collect(Collectors.groupingBy(Node::getParentId)); + + groupByParentId.forEach((parentId, subNodes) -> { + subNodes.forEach(node -> { + if (NodeType.ITERATION_START.equals(node.getType()) || NodeType.ITERATION_END.equals(node.getType())) { + node.setData(nodeIdMap.get(parentId).getData()); + } + }); + }); + graph.setNodes(nodes); graph.setEdges(edges); return graph; @@ -153,7 +216,10 @@ private List constructNodes(List> nodeMaps) { // 构造Node Node node = new Node(); - node.setId(nodeId).setType(nodeType).setTitle(nodeTitle); + node.setId(nodeId) + .setType(nodeType) + .setTitle(nodeTitle) + .setParentId(MapReadUtil.getMapDeepValue(nodeMap, String.class, "parent_id")); // convert node data using specific WorkflowNodeDataConverter @SuppressWarnings("unchecked") diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/EmptyNodeDataConverter.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/EmptyNodeDataConverter.java index eb35e2a241..1588fc70f1 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/EmptyNodeDataConverter.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/EmptyNodeDataConverter.java @@ -37,7 +37,7 @@ public class EmptyNodeDataConverter extends AbstractNodeDataConverter() { + DIFY(new DialectConverter<>() { @Override public Boolean supportDialect(DSLDialectType dialectType) { return dialectType.equals(DSLDialectType.DIFY); @@ -45,8 +45,7 @@ public Boolean supportDialect(DSLDialectType dialectType) { @Override public EmptyNodeData parse(Map data) throws JsonProcessingException { - String id = (String) data.get("id"); - return new EmptyNodeData(id); + return new EmptyNodeData(); } @Override @@ -76,12 +75,14 @@ protected List> getDialectConverters() { @Override public Boolean supportNodeType(NodeType nodeType) { - return nodeType.equals(NodeType.DIFY_ITERATION_START); + // 迭代节点的起始节点与迭代节点共享一个data,故转换时不需要提取数据 + return NodeType.EMPTY.equals(nodeType) || NodeType.ITERATION_START.equals(nodeType) + || NodeType.ITERATION_END.equals(nodeType); } @Override public String generateVarName(int count) { - return "__empty__node_" + count; + return "emptyNode" + count; } } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/IterationNodeDataConverter.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/IterationNodeDataConverter.java index 95786ac661..a304ae30e5 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/IterationNodeDataConverter.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/IterationNodeDataConverter.java @@ -20,17 +20,14 @@ import java.util.Map; import java.util.Optional; import java.util.function.BiConsumer; -import java.util.regex.Matcher; -import java.util.regex.Pattern; import java.util.stream.Stream; -import com.alibaba.cloud.ai.studio.admin.generator.model.Variable; import com.alibaba.cloud.ai.studio.admin.generator.model.VariableSelector; -import com.alibaba.cloud.ai.studio.admin.generator.model.VariableType; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.NodeType; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.nodedata.IterationNodeData; import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.AbstractNodeDataConverter; import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.DSLDialectType; +import com.alibaba.cloud.ai.studio.admin.generator.utils.MapReadUtil; import com.fasterxml.jackson.core.JsonProcessingException; import org.springframework.stereotype.Component; @@ -44,7 +41,7 @@ public class IterationNodeDataConverter extends AbstractNodeDataConverter() { + DIFY(new DialectConverter<>() { @Override public Boolean supportDialect(DSLDialectType dialectType) { return DSLDialectType.DIFY.equals(dialectType); @@ -52,44 +49,51 @@ public Boolean supportDialect(DSLDialectType dialectType) { @Override public IterationNodeData parse(Map data) throws JsonProcessingException { - // 获取输入输出的类型,从 array[xxx] 中提取xxx - Pattern typePattern = Pattern.compile("array\\[(.*?)]"); - VariableType inputType = VariableType.OBJECT; - VariableType outputType = VariableType.OBJECT; - Matcher inputTypeMatcher = typePattern - .matcher((String) data.getOrDefault("iterator_input_type", "object")); - Matcher outputTypeMatcher = typePattern.matcher((String) data.getOrDefault("output_type", "object")); - if (inputTypeMatcher.find()) { - inputType = VariableType.fromDifyValue(inputTypeMatcher.group(1)).orElse(VariableType.OBJECT); - } - if (outputTypeMatcher.find()) { - outputType = VariableType.fromDifyValue(outputTypeMatcher.group(1)).orElse(VariableType.OBJECT); - } - List inputSelector = (List) data.get("iterator_selector"); - List outputSelector = (List) data.get("output_selector"); - String startNodeId = (String) data.get("start_node_id"); - String id = (String) data.get("id"); - // 规定输出结果的节点为最后一个节点 - String endNodeId = outputSelector.get(0); - // 返回 - return IterationNodeData.builder() - .id(id) - .inputType(inputType) - .outputType(outputType) - .inputSelector(new VariableSelector(inputSelector.get(0), inputSelector.get(1), "")) - .outputSelector(new VariableSelector(outputSelector.get(0), outputSelector.get(1), "")) - .startNodeId(startNodeId) - .endNodeId(endNodeId) - .inputKey(id + "_input") - .outputKey(id + "_output") - .build(); + IterationNodeData nodeData = new IterationNodeData(); + int parallelCount = Optional + .ofNullable(MapReadUtil.getMapDeepValue(data, Integer.class, "parallel_nums")) + .orElse(1); + nodeData.setParallelCount(parallelCount); + + List inputSelectorList = Optional + .ofNullable(MapReadUtil.safeCastToList( + MapReadUtil.getMapDeepValue(data, List.class, "iterator_selector"), String.class)) + .orElse(List.of("unknown", "unknown")); + nodeData.setInputSelector(new VariableSelector(inputSelectorList.get(0), inputSelectorList.get(1))); + + List outputSelectorList = Optional + .ofNullable(MapReadUtil + .safeCastToList(MapReadUtil.getMapDeepValue(data, List.class, "output_selector"), String.class)) + .orElse(List.of("unknown", "unknown")); + nodeData.setResultSelector(new VariableSelector(outputSelectorList.get(0), outputSelectorList.get(1))); + + return nodeData; } @Override public Map dump(IterationNodeData nodeData) { - return Map.of(); + throw new UnsupportedOperationException(); } - }), CUSTOM(AbstractNodeDataConverter.defaultCustomDialectConverter(IterationNodeData.class)); + }) + + , STUDIO(new DialectConverter<>() { + @Override + public Boolean supportDialect(DSLDialectType dialectType) { + return DSLDialectType.STUDIO.equals(dialectType); + } + + @Override + public IterationNodeData parse(Map data) throws JsonProcessingException { + throw new UnsupportedOperationException(); + } + + @Override + public Map dump(IterationNodeData nodeData) { + throw new UnsupportedOperationException(); + } + }) + + , CUSTOM(AbstractNodeDataConverter.defaultCustomDialectConverter(IterationNodeData.class)); private final DialectConverter dialectConverter; @@ -124,39 +128,14 @@ public String generateVarName(int count) { public BiConsumer> postProcessConsumer(DSLDialectType dialectType) { return switch (dialectType) { case DIFY -> emptyProcessConsumer().andThen((nodeData, idToVarName) -> { - nodeData - .setOutputKey(nodeData.getVarName() + "_" + IterationNodeData.getDefaultOutputSchema().getName()); - nodeData.setOutputs(List.of(IterationNodeData.getDefaultOutputSchema())); - nodeData.setOutput(IterationNodeData.getDefaultOutputSchema()); - }).andThen(super.postProcessConsumer(dialectType)).andThen((iterationNodeData, varNames) -> { - // 等待所有的节点都生成了变量名后,补充迭代节点的起始名称 - iterationNodeData - .setStartNodeName(varNames.getOrDefault(iterationNodeData.getStartNodeId(), "unknown")); - iterationNodeData.setEndNodeName(varNames.getOrDefault(iterationNodeData.getEndNodeId(), "unknown")); - - // 更新迭代节点的输入Key - VariableSelector inputSelector = iterationNodeData.getInputs().get(0); - iterationNodeData.setInputKey(inputSelector.getNameInCode()); - - // 更新迭代节点的ResultKey - VariableSelector outputSelector = iterationNodeData.getOutputSelector(); - iterationNodeData.setInnerItemResultKey( - Optional.ofNullable(varNames.get(outputSelector.getNamespace())).orElse("unknown") + "_" - + outputSelector.getName()); + nodeData.setInputs(List.of(nodeData.getInputSelector(), nodeData.getResultSelector())); + }).andThen(super.postProcessConsumer(dialectType)).andThen((nodeData, idToVarName) -> { + nodeData.setInputSelector(nodeData.getInputs().get(0)); + nodeData.setResultSelector(nodeData.getInputs().get(1)); + nodeData.setInputs(null); }); default -> super.postProcessConsumer(dialectType); }; } - @Override - public Stream extractWorkflowVars(IterationNodeData nodeData) { - return Stream.concat(nodeData.getOutputs().stream(), - Stream.of(new Variable(nodeData.getInnerArrayKey(), VariableType.STRING), - new Variable(nodeData.getInnerStartFlagKey(), VariableType.STRING), - new Variable(nodeData.getInnerEndFlagKey(), VariableType.STRING), - new Variable(nodeData.getInnerItemKey(), nodeData.getInputType()), - new Variable(nodeData.getInnerIndexKey(), VariableType.NUMBER), - new Variable(nodeData.getInnerItemResultKey(), nodeData.getOutputType()))); - } - } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/NodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/NodeSection.java index b0990d50ce..d0059e33c2 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/NodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/NodeSection.java @@ -17,7 +17,6 @@ import java.io.InputStream; import java.util.List; -import java.util.Map; import java.util.function.Supplier; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.Edge; diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/WorkflowProjectGenerator.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/WorkflowProjectGenerator.java index fa931edaff..8ff840bd41 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/WorkflowProjectGenerator.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/WorkflowProjectGenerator.java @@ -196,12 +196,6 @@ private String renderNodeSections(List nodes, Map varNames } private String renderEdgeSections(List edges, List nodes, Map varNames) { - // 将Edge里的source和target都转换成varName - edges.forEach(edge -> { - edge.setSource(varNames.getOrDefault(edge.getSource(), edge.getSource())); - edge.setTarget(varNames.getOrDefault(edge.getTarget(), edge.getTarget())); - }); - // nodeVarName -> node的映射 Map nodeMap = nodes.stream() .collect(Collectors.toMap(node -> node.getData().getVarName(), Function.identity())); diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/EmptyNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/EmptyNodeSection.java index 126b988846..a116265d39 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/EmptyNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/EmptyNodeSection.java @@ -32,14 +32,13 @@ public class EmptyNodeSection implements NodeSection { @Override public boolean support(NodeType nodeType) { - return nodeType.equals(NodeType.DIFY_ITERATION_START); + return nodeType.equals(NodeType.EMPTY); } @Override public String render(Node node, String varName) { - EmptyNodeData data = (EmptyNodeData) node.getData(); StringBuilder sb = new StringBuilder(); - String id = data.getId(); + String id = node.getId(); sb.append("// —— Empty Node [").append(id).append("] ——\n"); sb.append("stateGraph.addNode(\"") .append(varName) diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/IterationNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/IterationNodeSection.java index 05d858989a..5cd55a4b92 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/IterationNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/IterationNodeSection.java @@ -16,90 +16,182 @@ package com.alibaba.cloud.ai.studio.admin.generator.service.generator.workflow.sections; -import com.alibaba.cloud.ai.studio.admin.generator.model.VariableType; +import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.Edge; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.Node; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.NodeType; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.nodedata.IterationNodeData; +import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.DSLDialectType; import com.alibaba.cloud.ai.studio.admin.generator.service.generator.workflow.NodeSection; import org.springframework.stereotype.Component; +import java.util.List; + /** * @author vlsmb * @since 2025/7/23 */ +// TODO: 支持并行模式、错误处理,支持Studio的默认输入值 @Component public class IterationNodeSection implements NodeSection { @Override public boolean support(NodeType nodeType) { - return nodeType.equals(NodeType.ITERATION); + return NodeType.ITERATION.equals(nodeType); } @Override public String render(Node node, String varName) { - // 构建Iteration.Start -> Iteration -> Iteration.End节点 - IterationNodeData data = (IterationNodeData) node.getData(); - StringBuilder sb = new StringBuilder(); - - // 获取输入输出的泛型 - String inputType = "Map"; - String outputType = "Map"; - if (VariableType.STRING.equals(data.getInputType())) { - inputType = "String"; + // 迭代节点在转换为Workflow的时候已经拆分为多个节点,故本方法返回空 + return ""; + } + + @Override + public String renderEdges(IterationNodeData nodeData, List edges) { + return ""; + } + + // 规定迭代节点的start为iterationVarName_start,end为iterationVarName_end + + @Component + public static class IterationStartNodeSection implements NodeSection { + + @Override + public boolean support(NodeType nodeType) { + return NodeType.ITERATION_START.equals(nodeType); + } + + @Override + public String render(Node node, String varName) { + IterationNodeData nodeData = ((IterationNodeData) node.getData()); + return String.format(""" + // Iteration [%s] Start Node + stateGraph.addNode("%s", AsyncNodeAction.node_async( + createIterationStartAction("%s", "%s", "%s", "%s", "%s", %d) + )); + + """, node.getId(), nodeData.getVarName() + "_start", nodeData.getResultSelector().getNameInCode(), + nodeData.getVarName() + "_state", nodeData.getItemKey(), nodeData.getVarName() + "_index", + nodeData.getVarName() + "_isFinished", nodeData.getIndexOffset()); + } + + // TODO: 添加辅助节点以支持迭代起始节点并行 + @Override + public String renderEdges(IterationNodeData nodeData, List edges) { + Edge edge = edges.get(0); + return String.format(""" + // Iteration [%s] Start Edge + stateGraph.addConditionalEdges("%s", AsyncEdgeAction.edge_async( + state -> { + Boolean b = state.value("%s", false); + return b ? "end" : "iteration"; + } + ), Map.of("end", "%s", "iteration", "%s")); + + """, nodeData.getVarName(), nodeData.getVarName() + "_start", nodeData.getVarName() + "_isFinished", + nodeData.getVarName() + "_end", edge.getTarget()); } - else if (VariableType.NUMBER.equals(data.getInputType())) { - inputType = "Number"; + + @Override + public String assistMethodCode(DSLDialectType dialectType) { + return """ + private NodeAction createIterationStartAction( + String arrayKey, String stateKey, + String itemKey, String indexKey, String flagKey, + int indexOffset) { + return state -> { + Object arrayObj = state.value(arrayKey).orElse(List.of()); + List stateList = state.value(stateKey, List.class).orElse(null); + + List arrayList; + if (stateList == null) { + // the first time in iteration + if (arrayObj instanceof List) { + arrayList = new ArrayList<>((List) arrayObj); + } else if (arrayObj.getClass().isArray()) { + arrayList = new ArrayList<>(Arrays.stream((Object[])arrayObj).toList()); + } else { + throw new IllegalStateException("value {" + arrayKey + "} is not an array!"); + } + int len = arrayList.size(); + stateList = new ArrayList<>(); + for (int i = 0; i < len; i++) { + stateList.add(i); + } + } else { + arrayList = (List) arrayObj; + } + + if(stateList.isEmpty()) { + return Map.of(flagKey, true); + } + int index = stateList.get(0); + Object item = arrayList.get(index); + stateList.remove(0); + return Map.of(arrayKey, arrayList, stateKey, stateList, itemKey, item, + indexKey, index + indexOffset, flagKey, false); + }; + } + """; + } + + } + + @Component + public static class IterationEndNodeSection implements NodeSection { + + @Override + public boolean support(NodeType nodeType) { + return NodeType.ITERATION_END.equals(nodeType); } - if (VariableType.STRING.equals(data.getOutputType())) { - outputType = "String"; + + @Override + public String render(Node node, String varName) { + IterationNodeData nodeData = ((IterationNodeData) node.getData()); + return String.format(""" + // Iteration [%s] End Node + stateGraph.addNode("%s", AsyncNodeAction.node_async( + createIterationEndAction("%s", "%s", "%s") + )); + + """, nodeData.getVarName(), nodeData.getVarName() + "_end", nodeData.getVarName() + "_isFinished", + nodeData.getResultSelector().getNameInCode(), nodeData.getOutputKey()); } - if (VariableType.NUMBER.equals(data.getOutputType())) { - outputType = "Number"; + + // TODO: 添加辅助节点以支持迭代终止节点并行 + @Override + public String renderEdges(IterationNodeData nodeData, List edges) { + Edge edge = edges.get(0); + return String.format(""" + // Iteration [%s] End Edge + stateGraph.addConditionalEdges("%s", AsyncEdgeAction.edge_async( + state -> { + Boolean b = state.value("%s", false); + return b ? "finish" : "start"; + } + ), Map.of("finish", "%s", "start", "%s")); + + """, nodeData.getVarName(), nodeData.getVarName() + "_end", nodeData.getVarName() + "_isFinished", + edge.getTarget(), nodeData.getVarName() + "_start"); + } + + @Override + public String assistMethodCode(DSLDialectType dialectType) { + return """ + private NodeAction createIterationEndAction(String flagKey, String resultKey, String outputKey) { + return state -> { + boolean flag = state.value(flagKey, Boolean.class).orElse(true); + List outputList = state.value(outputKey, List.class).orElse(new ArrayList<>()); + if(flag) { + return Map.of(outputKey, outputList); + } + outputList.add(state.value(resultKey).orElseThrow()); + return Map.of(outputKey, outputList); + }; + } + """; } - sb.append("// —— IterationNode [").append(data.getId()).append("] ——\n"); - sb.append("IterationNode.<") - .append(inputType) - .append(", ") - .append(outputType) - .append(">converter()\n") - .append(".subGraphStartNodeName(\"") - .append(data.getStartNodeName()) - .append("\")\n") - .append(".subGraphEndNodeName(\"") - .append(data.getEndNodeName()) - .append("\")\n") - .append(".tempArrayKey(\"") - .append(data.getInnerArrayKey()) - .append("\")\n") - .append(".tempStartFlagKey(\"") - .append(data.getInnerStartFlagKey()) - .append("\")\n") - .append(".tempEndFlagKey(\"") - .append(data.getInnerEndFlagKey()) - .append("\")\n") - .append(".tempIndexKey(\"") - .append(data.getInnerIndexKey()) - .append("\")\n") - .append(".iteratorItemKey(\"") - .append(data.getInnerItemKey()) - .append("\")\n") - .append(".iteratorResultKey(\"") - .append(data.getInnerItemResultKey()) - .append("\")\n") - .append(".inputArrayJsonKey(\"") - .append(data.getInputKey()) - .append("\")\n") - .append(".outputArrayJsonKey(\"") - .append(data.getOutputKey()) - .append("\")\n") - .append(".appendToStateGraph(stateGraph, \"") - .append(varName) - .append("\", \"") - .append(varName) - .append("_out\");\n\n"); - return sb.toString(); } } From b63d52de09b358821de02cd176bed65b89db5181 Mon Sep 17 00:00:00 2001 From: VLSMB <2047857654@qq.com> Date: Thu, 11 Sep 2025 14:12:37 +0800 Subject: [PATCH 05/11] fix: fix dify iteration generate bugs --- .../workflow/nodedata/IterationNodeData.java | 17 ++++- .../service/dsl/adapters/DifyDSLAdapter.java | 76 +++++++++---------- .../dsl/adapters/StudioDSLAdapter.java | 39 +++++----- .../converter/IterationNodeDataConverter.java | 11 +++ .../workflow/WorkflowProjectGenerator.java | 2 +- 5 files changed, 87 insertions(+), 58 deletions(-) diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/IterationNodeData.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/IterationNodeData.java index a066b68f17..1b7262c7d0 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/IterationNodeData.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/IterationNodeData.java @@ -43,7 +43,7 @@ public static List getDefaultOutputSchemas() { // Dify的迭代索引从0开始,而Studio的从1开始,故需要设置这个值 private int indexOffset = 0; - // itemKey和outputKey在Dify中固定,但在Studio中用户可以自定义 + // itemKey和outputKey的后缀在Dify中固定,但在Studio中用户可以自定义 private String itemKey = "item"; private String outputKey = "output"; @@ -54,6 +54,21 @@ public static List getDefaultOutputSchemas() { // 迭代结果元素的Selector private VariableSelector resultSelector; + public IterationNodeData(IterationNodeData other) { + parallelCount = other.parallelCount; + maxIterationCount = other.maxIterationCount; + indexOffset = other.indexOffset; + itemKey = other.itemKey; + outputKey = other.outputKey; + inputSelector = other.inputSelector; + resultSelector = other.resultSelector; + setVarName(other.getVarName()); + } + + public IterationNodeData() { + super(); + } + public int getParallelCount() { return parallelCount; } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/adapters/DifyDSLAdapter.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/adapters/DifyDSLAdapter.java index e739f9f79e..b67f50deb4 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/adapters/DifyDSLAdapter.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/adapters/DifyDSLAdapter.java @@ -22,6 +22,7 @@ import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.Set; import java.util.UUID; import java.util.function.BiConsumer; import java.util.stream.Collectors; @@ -38,6 +39,7 @@ import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.NodeData; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.NodeType; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.Workflow; +import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.nodedata.IterationNodeData; import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.AbstractDSLAdapter; import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.DSLDialectType; import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.NodeDataConverter; @@ -165,7 +167,8 @@ public Workflow mapToWorkflow(Map data) { private Graph constructGraph(Map data) { Graph graph = new Graph(); List nodes; - List edges = new ArrayList<>(); + List edges; + // convert nodes if (data.containsKey("nodes")) { List> nodeMaps = (List>) data.get("nodes"); @@ -174,68 +177,65 @@ private Graph constructGraph(Map data) { else { nodes = new ArrayList<>(); } + // convert edges if (data.containsKey("edges")) { List> edgeMaps = (List>) data.get("edges"); edges = new ArrayList<>(constructEdges(edgeMaps)); } + else { + edges = new ArrayList<>(); + } Map varNames = nodes.stream() .collect(Collectors.toMap(Node::getId, n -> n.getData().getVarName())); Map nodeIdMap = nodes.stream().collect(Collectors.toMap(Node::getId, n -> n)); Map nodeVarMap = nodes.stream().collect(Collectors.toMap(n -> n.getData().getVarName(), n -> n)); - // 将Edge里的source和target都转换成varName - // 将Iteration节点起始改为iteration_start,并将Iteration节点结束改为iteration_end - edges.forEach(edge -> { - if (NodeType.ITERATION.equals(nodeIdMap.get(edge.getSource()).getType())) { - edge.setSource(varNames.getOrDefault(edge.getSource(), edge.getSource()) + "_start"); - } - else { - edge.setSource(varNames.getOrDefault(edge.getSource(), edge.getSource())); - } - if (NodeType.ITERATION.equals(nodeIdMap.get(edge.getTarget()).getType())) { - edge.setTarget(varNames.getOrDefault(edge.getTarget(), edge.getTarget()) + "_end"); - } - else { - edge.setTarget(varNames.getOrDefault(edge.getTarget(), edge.getTarget())); - } - }); // 根据parnetId进行分组,为了给迭代节点的起始节点传递迭代数据 Map> groupByParentId = nodes.stream() .filter(node -> Objects.nonNull(node.getParentId())) .collect(Collectors.groupingBy(Node::getParentId)); - List finalEdges = edges; + // 统计具有出度的节点 + Set nodeIdHasOut = edges.stream().map(Edge::getSource).collect(Collectors.toSet()); + groupByParentId.forEach((parentId, subNodes) -> { subNodes.forEach(node -> { - if (NodeType.ITERATION_START.equals(node.getType()) || NodeType.ITERATION_END.equals(node.getType())) { - node.setData(nodeIdMap.get(parentId).getData()); + if (NodeType.ITERATION_START.equals(node.getType())) { + IterationNodeData nodeData = new IterationNodeData( + (IterationNodeData) nodeIdMap.get(parentId).getData()); + nodeData.setVarName(nodeIdMap.get(parentId).getData().getVarName() + "_start"); + varNames.put(node.getId(), nodeData.getVarName()); + node.setData(nodeData); + } + else if (NodeType.ITERATION_END.equals(node.getType())) { + IterationNodeData nodeData = new IterationNodeData( + (IterationNodeData) nodeIdMap.get(parentId).getData()); + nodeData.setVarName(nodeIdMap.get(parentId).getData().getVarName() + "_end"); + varNames.put(node.getId(), nodeData.getVarName()); + node.setData(nodeData); } }); + // 添加迭代节点的终止节点(Dify的DSL没有提供但为了后续正常转换,这里需要添加) - NodeData nodeData = nodeIdMap.get(parentId).getData(); + NodeData nodeData = new IterationNodeData((IterationNodeData) nodeIdMap.get(parentId).getData()); + nodeData.setVarName(nodeData.getVarName() + "_end"); Node endNode = new Node(); - endNode.setData(nodeData).setType(NodeType.ITERATION_END); + endNode.setData(nodeData).setType(NodeType.ITERATION_END).setParentId(parentId); nodes.add(endNode); // 计算每个节点的出度,出度为0的点将与迭代终止节点相连接 - finalEdges.stream().filter(e -> { - Node n = nodeVarMap.get(e.getSource()); - return parentId.equals(n.getParentId()); - }) - .collect(Collectors.groupingBy(Edge::getSource)) - .entrySet() - .stream() - .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().size())) - .entrySet() - .stream() - .filter(entry -> entry.getValue() == 0) - .map(Map.Entry::getKey) - .forEach(nodeName -> { - Edge edge = new Edge().setSource(nodeName).setTarget(nodeData.getVarName() + "_end"); - finalEdges.add(edge); - }); + subNodes.stream().map(Node::getId).filter(id -> !nodeIdHasOut.contains(id)).forEach(id -> { + Edge newEdge = new Edge().setSource(id).setTarget(nodeData.getVarName()); + edges.add(newEdge); + }); + }); + + // 将Edge里的source和target都转换成varName + edges.forEach(edge -> { + edge.setSource(varNames.getOrDefault(edge.getSource(), edge.getSource())); + edge.setTarget(varNames.getOrDefault(edge.getTarget(), edge.getTarget())); }); graph.setNodes(nodes); diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/adapters/StudioDSLAdapter.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/adapters/StudioDSLAdapter.java index 5c9c338743..6ded474c0d 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/adapters/StudioDSLAdapter.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/adapters/StudioDSLAdapter.java @@ -26,6 +26,7 @@ import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.NodeData; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.NodeType; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.Workflow; +import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.nodedata.IterationNodeData; import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.AbstractDSLAdapter; import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.DSLDialectType; import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.NodeDataConverter; @@ -52,6 +53,7 @@ * @author vlsmb * @since 2025/8/27 */ +// TODO: 与DifyDSLAdapter合并一些重复代码 @Component public class StudioDSLAdapter extends AbstractDSLAdapter { @@ -158,22 +160,6 @@ private Graph constructGraph(Map data) { Map varNames = nodes.stream() .collect(Collectors.toMap(Node::getId, n -> n.getData().getVarName())); Map nodeIdMap = nodes.stream().collect(Collectors.toMap(Node::getId, n -> n)); - // 将Edge里的source和target都转换成varName - // 将Iteration节点起始改为iteration_start,并将Iteration节点结束改为iteration_end - edges.forEach(edge -> { - if (NodeType.ITERATION.equals(nodeIdMap.get(edge.getSource()).getType())) { - edge.setSource(varNames.getOrDefault(edge.getSource(), edge.getSource()) + "_start"); - } - else { - edge.setSource(varNames.getOrDefault(edge.getSource(), edge.getSource())); - } - if (NodeType.ITERATION.equals(nodeIdMap.get(edge.getTarget()).getType())) { - edge.setTarget(varNames.getOrDefault(edge.getTarget(), edge.getTarget()) + "_end"); - } - else { - edge.setTarget(varNames.getOrDefault(edge.getTarget(), edge.getTarget())); - } - }); // 根据parnetId进行分组,为了给迭代节点的起始节点传递迭代数据 Map> groupByParentId = nodes.stream() @@ -182,12 +168,29 @@ private Graph constructGraph(Map data) { groupByParentId.forEach((parentId, subNodes) -> { subNodes.forEach(node -> { - if (NodeType.ITERATION_START.equals(node.getType()) || NodeType.ITERATION_END.equals(node.getType())) { - node.setData(nodeIdMap.get(parentId).getData()); + if (NodeType.ITERATION_START.equals(node.getType())) { + IterationNodeData nodeData = new IterationNodeData( + (IterationNodeData) nodeIdMap.get(parentId).getData()); + nodeData.setVarName(nodeIdMap.get(parentId).getData().getVarName() + "_start"); + varNames.put(node.getId(), nodeData.getVarName()); + node.setData(nodeData); + } + else if (NodeType.ITERATION_END.equals(node.getType())) { + IterationNodeData nodeData = new IterationNodeData( + (IterationNodeData) nodeIdMap.get(parentId).getData()); + nodeData.setVarName(nodeIdMap.get(parentId).getData().getVarName() + "_end"); + varNames.put(node.getId(), nodeData.getVarName()); + node.setData(nodeData); } }); }); + // 将Edge里的source和target都转换成varName + edges.forEach(edge -> { + edge.setSource(varNames.getOrDefault(edge.getSource(), edge.getSource())); + edge.setTarget(varNames.getOrDefault(edge.getTarget(), edge.getTarget())); + }); + graph.setNodes(nodes); graph.setEdges(edges); return graph; diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/IterationNodeDataConverter.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/IterationNodeDataConverter.java index a304ae30e5..5a0459e17d 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/IterationNodeDataConverter.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/IterationNodeDataConverter.java @@ -22,7 +22,9 @@ import java.util.function.BiConsumer; import java.util.stream.Stream; +import com.alibaba.cloud.ai.studio.admin.generator.model.Variable; import com.alibaba.cloud.ai.studio.admin.generator.model.VariableSelector; +import com.alibaba.cloud.ai.studio.admin.generator.model.VariableType; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.NodeType; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.nodedata.IterationNodeData; import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.AbstractNodeDataConverter; @@ -129,10 +131,19 @@ public BiConsumer> postProcessConsumer(DS return switch (dialectType) { case DIFY -> emptyProcessConsumer().andThen((nodeData, idToVarName) -> { nodeData.setInputs(List.of(nodeData.getInputSelector(), nodeData.getResultSelector())); + + nodeData.setOutputs(Stream + .of(IterationNodeData.getDefaultOutputSchemas(), + List.of(new Variable(nodeData.getItemKey(), VariableType.OBJECT), + new Variable(nodeData.getOutputKey(), VariableType.ARRAY_OBJECT))) + .flatMap(List::stream) + .toList()); }).andThen(super.postProcessConsumer(dialectType)).andThen((nodeData, idToVarName) -> { nodeData.setInputSelector(nodeData.getInputs().get(0)); nodeData.setResultSelector(nodeData.getInputs().get(1)); nodeData.setInputs(null); + nodeData.setItemKey(nodeData.getVarName() + "_" + nodeData.getItemKey()); + nodeData.setOutputKey(nodeData.getVarName() + "_" + nodeData.getOutputKey()); }); default -> super.postProcessConsumer(dialectType); }; diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/WorkflowProjectGenerator.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/WorkflowProjectGenerator.java index 8ff840bd41..2e13c42663 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/WorkflowProjectGenerator.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/WorkflowProjectGenerator.java @@ -284,7 +284,7 @@ private String renderImportSection(Workflow workflow) { List.of("com.alibaba.cloud.ai.graph.node.VariableAggregatorNode", "java.util.stream.Collectors")), Map.entry(NodeType.ASSIGNER, List.of("com.alibaba.cloud.ai.graph.node.AssignerNode")), - Map.entry(NodeType.ITERATION, List.of("com.alibaba.cloud.ai.graph.node.IterationNode")), + Map.entry(NodeType.ITERATION, List.of("java.util.ArrayList", "java.util.Arrays")), Map.entry(NodeType.END, List.of("java.util.stream.Stream", "java.util.stream.Collectors", "org.springframework.ai.chat.prompt.PromptTemplate"))); From 129597e69e10d47cef7c4062b899a9f5655c3e40 Mon Sep 17 00:00:00 2001 From: VLSMB <2047857654@qq.com> Date: Thu, 11 Sep 2025 15:27:23 +0800 Subject: [PATCH 06/11] fix: fix dify iteration bugs --- .../workflow/nodedata/IterationNodeData.java | 9 +++++++++ .../service/dsl/adapters/DifyDSLAdapter.java | 12 +++++++++++- .../dsl/adapters/StudioDSLAdapter.java | 11 +++++++++++ .../sections/IterationNodeSection.java | 19 +++++++++++-------- 4 files changed, 42 insertions(+), 9 deletions(-) diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/IterationNodeData.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/IterationNodeData.java index 1b7262c7d0..9c952b1aef 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/IterationNodeData.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/IterationNodeData.java @@ -36,6 +36,9 @@ public static List getDefaultOutputSchemas() { ); } + // NodeData的来源节点名称 + private final String sourceVarName; + private int parallelCount = 1; private int maxIterationCount = Integer.MAX_VALUE; @@ -62,11 +65,17 @@ public IterationNodeData(IterationNodeData other) { outputKey = other.outputKey; inputSelector = other.inputSelector; resultSelector = other.resultSelector; + sourceVarName = other.getVarName(); setVarName(other.getVarName()); } public IterationNodeData() { super(); + sourceVarName = null; + } + + public String getSourceVarName() { + return sourceVarName; } public int getParallelCount() { diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/adapters/DifyDSLAdapter.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/adapters/DifyDSLAdapter.java index b67f50deb4..75690d4bcb 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/adapters/DifyDSLAdapter.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/adapters/DifyDSLAdapter.java @@ -190,7 +190,6 @@ private Graph constructGraph(Map data) { Map varNames = nodes.stream() .collect(Collectors.toMap(Node::getId, n -> n.getData().getVarName())); Map nodeIdMap = nodes.stream().collect(Collectors.toMap(Node::getId, n -> n)); - Map nodeVarMap = nodes.stream().collect(Collectors.toMap(n -> n.getData().getVarName(), n -> n)); // 根据parnetId进行分组,为了给迭代节点的起始节点传递迭代数据 Map> groupByParentId = nodes.stream() @@ -238,6 +237,17 @@ else if (NodeType.ITERATION_END.equals(node.getType())) { edge.setTarget(varNames.getOrDefault(edge.getTarget(), edge.getTarget())); }); + // 将Iteration节点起始改为iteration_start,并将Iteration节点结束改为iteration_end + Map nodeVarMap = nodes.stream().collect(Collectors.toMap(n -> n.getData().getVarName(), n -> n)); + edges.forEach(edge -> { + if (NodeType.ITERATION.equals(nodeVarMap.get(edge.getSource()).getType())) { + edge.setSource(edge.getSource() + "_end"); + } + if (NodeType.ITERATION.equals(nodeVarMap.get(edge.getTarget()).getType())) { + edge.setTarget(edge.getTarget() + "_start"); + } + }); + graph.setNodes(nodes); graph.setEdges(edges); return graph; diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/adapters/StudioDSLAdapter.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/adapters/StudioDSLAdapter.java index 6ded474c0d..4040058649 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/adapters/StudioDSLAdapter.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/adapters/StudioDSLAdapter.java @@ -191,6 +191,17 @@ else if (NodeType.ITERATION_END.equals(node.getType())) { edge.setTarget(varNames.getOrDefault(edge.getTarget(), edge.getTarget())); }); + // 将Iteration节点起始改为iteration_start,并将Iteration节点结束改为iteration_end + Map nodeVarMap = nodes.stream().collect(Collectors.toMap(n -> n.getData().getVarName(), n -> n)); + edges.forEach(edge -> { + if (NodeType.ITERATION.equals(nodeVarMap.get(edge.getSource()).getType())) { + edge.setSource(edge.getSource() + "_end"); + } + if (NodeType.ITERATION.equals(nodeVarMap.get(edge.getTarget()).getType())) { + edge.setTarget(edge.getTarget() + "_start"); + } + }); + graph.setNodes(nodes); graph.setEdges(edges); return graph; diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/IterationNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/IterationNodeSection.java index 5cd55a4b92..fbdcea67a4 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/IterationNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/IterationNodeSection.java @@ -70,9 +70,10 @@ public String render(Node node, String varName) { createIterationStartAction("%s", "%s", "%s", "%s", "%s", %d) )); - """, node.getId(), nodeData.getVarName() + "_start", nodeData.getResultSelector().getNameInCode(), - nodeData.getVarName() + "_state", nodeData.getItemKey(), nodeData.getVarName() + "_index", - nodeData.getVarName() + "_isFinished", nodeData.getIndexOffset()); + """, node.getId(), varName, nodeData.getInputSelector().getNameInCode(), + nodeData.getSourceVarName() + "_state", nodeData.getItemKey(), + nodeData.getSourceVarName() + "_index", nodeData.getSourceVarName() + "_isFinished", + nodeData.getIndexOffset()); } // TODO: 添加辅助节点以支持迭代起始节点并行 @@ -88,8 +89,9 @@ public String renderEdges(IterationNodeData nodeData, List edges) { } ), Map.of("end", "%s", "iteration", "%s")); - """, nodeData.getVarName(), nodeData.getVarName() + "_start", nodeData.getVarName() + "_isFinished", - nodeData.getVarName() + "_end", edge.getTarget()); + """, nodeData.getSourceVarName(), nodeData.getSourceVarName() + "_start", + nodeData.getSourceVarName() + "_isFinished", nodeData.getSourceVarName() + "_end", + edge.getTarget()); } @Override @@ -154,7 +156,7 @@ public String render(Node node, String varName) { createIterationEndAction("%s", "%s", "%s") )); - """, nodeData.getVarName(), nodeData.getVarName() + "_end", nodeData.getVarName() + "_isFinished", + """, nodeData.getSourceVarName(), varName, nodeData.getSourceVarName() + "_isFinished", nodeData.getResultSelector().getNameInCode(), nodeData.getOutputKey()); } @@ -171,8 +173,9 @@ public String renderEdges(IterationNodeData nodeData, List edges) { } ), Map.of("finish", "%s", "start", "%s")); - """, nodeData.getVarName(), nodeData.getVarName() + "_end", nodeData.getVarName() + "_isFinished", - edge.getTarget(), nodeData.getVarName() + "_start"); + """, nodeData.getSourceVarName(), nodeData.getVarName(), + nodeData.getSourceVarName() + "_isFinished", edge.getTarget(), + nodeData.getSourceVarName() + "_start"); } @Override From 9e932613cf7091a3f23e171483795f76a12a8cc7 Mon Sep 17 00:00:00 2001 From: VLSMB <2047857654@qq.com> Date: Thu, 11 Sep 2025 21:19:14 +0800 Subject: [PATCH 07/11] feat: enhance CodeExecutorNodeAction and fix bugs --- .../node/code/CodeExecutorNodeAction.java | 48 +++-- .../graph/node/code/TemplateTransformer.java | 14 +- .../ai/graph/node/code/entity/CodeParam.java | 38 ++++ .../ai/graph/node/code/entity/CodeStyle.java | 35 ++++ .../code/java/JavaTemplateTransformer.java | 107 ++++++++--- .../javascript/NodeJsTemplateTransformer.java | 50 ++++-- .../python3/Python3TemplateTransformer.java | 36 +++- .../cloud/ai/graph/node/CodeActionTest.java | 167 +++++++++++++++--- 8 files changed, 399 insertions(+), 96 deletions(-) create mode 100644 spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/entity/CodeParam.java create mode 100644 spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/entity/CodeStyle.java diff --git a/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/CodeExecutorNodeAction.java b/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/CodeExecutorNodeAction.java index 6490a9cc3b..d6452e986a 100644 --- a/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/CodeExecutorNodeAction.java +++ b/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/CodeExecutorNodeAction.java @@ -17,9 +17,10 @@ import java.util.ArrayList; import java.util.HashMap; -import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; import com.alibaba.cloud.ai.graph.OverAllState; import com.alibaba.cloud.ai.graph.action.NodeAction; @@ -27,6 +28,8 @@ import com.alibaba.cloud.ai.graph.node.code.entity.CodeExecutionConfig; import com.alibaba.cloud.ai.graph.node.code.entity.CodeExecutionResult; import com.alibaba.cloud.ai.graph.node.code.entity.CodeLanguage; +import com.alibaba.cloud.ai.graph.node.code.entity.CodeParam; +import com.alibaba.cloud.ai.graph.node.code.entity.CodeStyle; import com.alibaba.cloud.ai.graph.node.code.entity.RunnerAndPreload; import com.alibaba.cloud.ai.graph.node.code.javascript.NodeJsTemplateTransformer; import com.alibaba.cloud.ai.graph.node.code.python3.Python3TemplateTransformer; @@ -47,10 +50,12 @@ public class CodeExecutorNodeAction implements NodeAction { private final CodeExecutionConfig codeExecutionConfig; - private Map params; + private final List params; private final String outputKey; + private final CodeStyle style; + private static final Map CODE_TEMPLATE_TRANSFORMERS = Map.of( CodeLanguage.PYTHON3, new Python3TemplateTransformer(), CodeLanguage.PYTHON, new Python3TemplateTransformer(), CodeLanguage.JAVASCRIPT, new NodeJsTemplateTransformer(), @@ -61,24 +66,25 @@ CodeLanguage.PYTHON3, new Python3TemplateTransformer(), CodeLanguage.PYTHON, CodeLanguage.PYTHON3.getValue(), CodeLanguage.PYTHON, CodeLanguage.PYTHON.getValue(), CodeLanguage.JAVA, CodeLanguage.JAVA.getValue()); - public CodeExecutorNodeAction(CodeExecutor codeExecutor, String codeLanguage, String code, - CodeExecutionConfig config, Map params, String outputKey) { + public CodeExecutorNodeAction(CodeExecutor codeExecutor, String codeLanguage, String code, CodeStyle style, + CodeExecutionConfig config, List params, String outputKey) { this.codeExecutor = codeExecutor; this.codeLanguage = codeLanguage; + this.style = style; this.code = code; this.codeExecutionConfig = config; this.params = params; this.outputKey = outputKey; } - private Map executeWorkflowCodeTemplate(CodeLanguage language, String code, List inputs) - throws Exception { + private Map executeWorkflowCodeTemplate(CodeLanguage language, String code, + Map inputs) throws Exception { TemplateTransformer templateTransformer = CODE_TEMPLATE_TRANSFORMERS.get(language); if (templateTransformer == null) { throw new RuntimeException("Unsupported language: " + language); } - RunnerAndPreload runnerAndPreload = templateTransformer.transformCaller(code, inputs); + RunnerAndPreload runnerAndPreload = templateTransformer.transformCaller(code, inputs, style); String response = executeCode(language, runnerAndPreload.preloadScript(), runnerAndPreload.runnerScript()); return templateTransformer.transformResponse(response); @@ -100,12 +106,12 @@ private String executeCode(CodeLanguage language, String preloadScript, String c @Override public Map apply(OverAllState state) throws Exception { - List inputs = new ArrayList<>(10); - if (params != null && !params.isEmpty()) { - for (String key : params.keySet()) { - inputs.add(state.data().get((String) params.get(key))); - } - } + Map inputs = Optional.ofNullable(params) + .orElse(List.of()) + .stream() + .collect(Collectors.toUnmodifiableMap(CodeParam::argName, param -> Optional.ofNullable(param.value()) + .or(() -> StringUtils.hasText(param.stateKey()) ? state.value(param.stateKey()) : Optional.empty()) + .orElseThrow(() -> new IllegalStateException("param has no value and legal key!")))); Map resultObjectMap = executeWorkflowCodeTemplate(CodeLanguage.fromValue(codeLanguage), code, inputs); Map updatedState = new HashMap<>(); @@ -127,13 +133,16 @@ public static class Builder { private String code; + private CodeStyle style; + private CodeExecutionConfig config; - private Map params; + private List params; private String outputKey; public Builder() { + style = CodeStyle.EXPLICIT_PARAMETERS; } public Builder codeExecutor(CodeExecutor codeExecutor) { @@ -151,13 +160,18 @@ public Builder code(String code) { return this; } + public Builder codeStyle(CodeStyle style) { + this.style = style; + return this; + } + public Builder config(CodeExecutionConfig config) { this.config = config; return this; } - public Builder params(Map params) { - this.params = new LinkedHashMap<>(params); + public Builder params(List params) { + this.params = List.copyOf(params); return this; } @@ -167,7 +181,7 @@ public Builder outputKey(String outputKey) { } public CodeExecutorNodeAction build() { - return new CodeExecutorNodeAction(codeExecutor, codeLanguage, code, config, params, outputKey); + return new CodeExecutorNodeAction(codeExecutor, codeLanguage, code, style, config, params, outputKey); } } diff --git a/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/TemplateTransformer.java b/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/TemplateTransformer.java index 2d6e6c562e..5b2048dc00 100644 --- a/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/TemplateTransformer.java +++ b/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/TemplateTransformer.java @@ -16,12 +16,12 @@ package com.alibaba.cloud.ai.graph.node.code; +import com.alibaba.cloud.ai.graph.node.code.entity.CodeStyle; import com.alibaba.cloud.ai.graph.node.code.entity.RunnerAndPreload; import com.fasterxml.jackson.databind.ObjectMapper; import java.nio.charset.StandardCharsets; import java.util.Base64; -import java.util.List; import java.util.Map; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -38,8 +38,8 @@ public abstract class TemplateTransformer { protected static final String RESULT_TAG = "<>"; - public RunnerAndPreload transformCaller(String code, List inputs) throws Exception { - String runnerScript = assembleRunnerScript(code, inputs); + public RunnerAndPreload transformCaller(String code, Map inputs, CodeStyle style) throws Exception { + String runnerScript = assembleRunnerScript(code, inputs, style); String preloadScript = getPreloadScript(); return new RunnerAndPreload(runnerScript, preloadScript); @@ -52,7 +52,7 @@ public Map transformResponse(String response) throws Exception { mapper.getTypeFactory().constructMapType(Map.class, String.class, Object.class)); } - public abstract String getRunnerScript(); + public abstract String getRunnerScript(CodeStyle style); private String extractResultStrFromResponse(String response) { Pattern pattern = Pattern.compile(RESULT_TAG + "(.*?)" + RESULT_TAG, Pattern.DOTALL); @@ -66,14 +66,14 @@ private String extractResultStrFromResponse(String response) { } } - private String serializeInputs(List inputs) throws Exception { + private String serializeInputs(Map inputs) throws Exception { ObjectMapper mapper = new ObjectMapper(); String inputsJsonStr = mapper.writeValueAsString(inputs); return Base64.getEncoder().encodeToString(inputsJsonStr.getBytes(StandardCharsets.UTF_8)); } - private String assembleRunnerScript(String code, List inputs) throws Exception { - String script = getRunnerScript(); + private String assembleRunnerScript(String code, Map inputs, CodeStyle style) throws Exception { + String script = getRunnerScript(style); script = script.replace(CODE_PLACEHOLDER, code); script = script.replace(INPUTS_PLACEHOLDER, serializeInputs(inputs)); return script; diff --git a/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/entity/CodeParam.java b/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/entity/CodeParam.java new file mode 100644 index 0000000000..bfe2812fa2 --- /dev/null +++ b/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/entity/CodeParam.java @@ -0,0 +1,38 @@ +/* + * Copyright 2024-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.cloud.ai.graph.node.code.entity; + +/** + * @author vlsmb + * @since 2025/9/11 + * @param argName 参数在代码中对应的名称 + * @param value 参数值,如果为null,则从OverallState中获取 + * @param stateKey 参数在OverallState中的key,如果value不为null,则忽略stateKey + */ +public record CodeParam(String argName, Object value, String stateKey) { + public CodeParam(String argName, String stateKey) { + this(argName, null, stateKey); + } + + public static CodeParam withValue(String argName, Object value) { + return new CodeParam(argName, value, null); + } + + public static CodeParam withKey(String argName, String stateKey) { + return new CodeParam(argName, null, stateKey); + } +} diff --git a/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/entity/CodeStyle.java b/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/entity/CodeStyle.java new file mode 100644 index 0000000000..2cb61ab181 --- /dev/null +++ b/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/entity/CodeStyle.java @@ -0,0 +1,35 @@ +/* + * Copyright 2024-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.cloud.ai.graph.node.code.entity; + +/** + * @author vlsmb + * @since 2025/9/11 + */ +public enum CodeStyle { + + /** + * 参数直接作为函数形参的风格 示例: def main(x: int, y: int) -> dict: + */ + EXPLICIT_PARAMETERS, + + /** + * 参数通过全局字典访问的风格 示例: def main(): x = params['x'] + */ + GLOBAL_DICTIONARY + +} diff --git a/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/java/JavaTemplateTransformer.java b/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/java/JavaTemplateTransformer.java index 78ecb0af61..c81cbbcf51 100644 --- a/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/java/JavaTemplateTransformer.java +++ b/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/java/JavaTemplateTransformer.java @@ -15,10 +15,12 @@ */ package com.alibaba.cloud.ai.graph.node.code.java; +import com.alibaba.cloud.ai.graph.node.code.entity.CodeStyle; import com.alibaba.cloud.ai.graph.node.code.TemplateTransformer; /** * Java code template transformer Used to convert user code into executable Java programs + * user method should be {@code public static Object run([..args])} * * @author HeYQ * @since 2024-06-01 @@ -26,30 +28,87 @@ public class JavaTemplateTransformer extends TemplateTransformer { @Override - public String getRunnerScript() { - return """ - import java.util.*; - import com.fasterxml.jackson.databind.ObjectMapper; - import com.fasterxml.jackson.databind.node.ObjectNode; - - class Main { - public static void main(String[] args) throws Exception { - // Parse input parameters - String inputsBase64 = "%s"; - String inputsJson = new String(Base64.getDecoder().decode(inputsBase64)); - ObjectMapper mapper = new ObjectMapper(); - Object[] inputs = mapper.readValue(inputsJson, Object[].class); - - // Execute user code - Object result = main(inputs); - - // Output results - String output = mapper.writeValueAsString(result); - System.out.println("<>" + output + "<>"); - } - - %s - }""".formatted(INPUTS_PLACEHOLDER, CODE_PLACEHOLDER); + public String getRunnerScript(CodeStyle style) { + return switch (style) { + case EXPLICIT_PARAMETERS -> String.format( + """ + import java.util.*; + import com.fasterxml.jackson.databind.ObjectMapper; + import com.fasterxml.jackson.databind.node.ObjectNode; + import com.fasterxml.jackson.core.type.TypeReference; + import java.lang.reflect.Method; + + class Main { + public static void main(String[] args) throws Exception { + // Parse input parameters + String inputsBase64 = "%s"; + String inputsJson = new String(Base64.getDecoder().decode(inputsBase64)); + ObjectMapper mapper = new ObjectMapper(); + Map inputs = mapper.readValue(inputsJson, new TypeReference<>(){}); + + // Execute user code + Object result = invokeMethod(Main.class, "run", inputs); + + // Output results + String output = mapper.writeValueAsString(result); + System.out.println("%s" + output + "%s"); + } + + public static Object invokeMethod(Class clazz, String methodName, Map args) + throws Exception { + Method func = Arrays.stream(clazz.getMethods()).filter(method -> method.getName().equals(methodName)) + .filter(method -> method.getParameterCount() == args.size()) + .findFirst() + .orElseThrow(); + Object[] params = Arrays.stream(func.getParameters()).map( + parameter -> args.get(parameter.getName()) + ).toArray(); + return func.invoke(null, params); + } + + // user code + %s + } + """, + INPUTS_PLACEHOLDER, RESULT_TAG, RESULT_TAG, CODE_PLACEHOLDER); + case GLOBAL_DICTIONARY -> String.format(""" + import java.util.*; + import com.fasterxml.jackson.databind.ObjectMapper; + import com.fasterxml.jackson.databind.node.ObjectNode; + import com.fasterxml.jackson.core.type.TypeReference; + + class Main { + + private static final Map params; + + private static final ObjectMapper mapper; + + static { + try { + // Parse input parameters + String inputsBase64 = "%s"; + String inputsJson = new String(Base64.getDecoder().decode(inputsBase64)); + mapper = new ObjectMapper(); + params = mapper.readValue(inputsJson, new TypeReference<>(){}); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + public static void main(String[] args) throws Exception { + // Execute user code + Object result = run(); + + // Output results + String output = mapper.writeValueAsString(result); + System.out.println("%s" + output + "%s"); + } + + // user code + %s + } + """, INPUTS_PLACEHOLDER, RESULT_TAG, RESULT_TAG, CODE_PLACEHOLDER); + }; } } diff --git a/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/javascript/NodeJsTemplateTransformer.java b/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/javascript/NodeJsTemplateTransformer.java index ee8f25acf8..3909615e68 100644 --- a/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/javascript/NodeJsTemplateTransformer.java +++ b/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/javascript/NodeJsTemplateTransformer.java @@ -15,6 +15,7 @@ */ package com.alibaba.cloud.ai.graph.node.code.javascript; +import com.alibaba.cloud.ai.graph.node.code.entity.CodeStyle; import com.alibaba.cloud.ai.graph.node.code.TemplateTransformer; /** @@ -24,22 +25,39 @@ public class NodeJsTemplateTransformer extends TemplateTransformer { @Override - public String getRunnerScript() { - return """ - // declare main function - %s - - // decode and prepare input object - var inputs_obj = JSON.parse(Buffer.from('%s', 'base64').toString('utf-8')) - - // execute main function - var output_obj = main(inputs_obj) - - // convert output to json and print - var output_json = JSON.stringify(output_obj) - var result = `<>${output_json}<>` - console.log(result) - """.formatted(CODE_PLACEHOLDER, INPUTS_PLACEHOLDER); + public String getRunnerScript(CodeStyle style) { + return switch (style) { + case EXPLICIT_PARAMETERS -> String.format(""" + // declare main function + %s + + // decode and prepare input object + var inputs_obj = JSON.parse(Buffer.from('%s', 'base64').toString('utf-8')) + + // execute main function + var output_obj = main(inputs_obj) + + // convert output to json and print + var output_json = JSON.stringify(output_obj) + var result = `%s${output_json}%s` + console.log(result) + """, CODE_PLACEHOLDER, INPUTS_PLACEHOLDER, RESULT_TAG, RESULT_TAG); + case GLOBAL_DICTIONARY -> String.format(""" + // decode and prepare input object + let params = JSON.parse(Buffer.from('%s', 'base64').toString('utf-8')) + + // declare main function + %s + + // execute main function + var output_obj = main() + + // convert output to json and print + var output_json = JSON.stringify(output_obj) + var result = `%s${output_json}%s` + console.log(result) + """, INPUTS_PLACEHOLDER, CODE_PLACEHOLDER, RESULT_TAG, RESULT_TAG); + }; } } diff --git a/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/python3/Python3TemplateTransformer.java b/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/python3/Python3TemplateTransformer.java index c33900bc91..057f01fc56 100644 --- a/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/python3/Python3TemplateTransformer.java +++ b/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/python3/Python3TemplateTransformer.java @@ -15,6 +15,7 @@ */ package com.alibaba.cloud.ai.graph.node.code.python3; +import com.alibaba.cloud.ai.graph.node.code.entity.CodeStyle; import com.alibaba.cloud.ai.graph.node.code.TemplateTransformer; /** @@ -24,13 +25,34 @@ public class Python3TemplateTransformer extends TemplateTransformer { @Override - public String getRunnerScript() { - return String.join("\n", - new String[] { "# declare main function", CODE_PLACEHOLDER, "import json", - "from base64 import b64decode", - "inputs_obj = json.loads(b64decode('" + INPUTS_PLACEHOLDER + "').decode('utf-8'))", - "output_obj = main(*inputs_obj)", "output_json = json.dumps(output_obj, indent=4)", - "result = f'''<>{output_json}<>'''", "print(result)" }); + public String getRunnerScript(CodeStyle style) { + return switch (style) { + case EXPLICIT_PARAMETERS -> String.format(""" + # declare main function + %s + + import json + from base64 import b64decode + inputs_obj = json.loads(b64decode('%s').decode('utf-8')) + output_obj = main(**inputs_obj) + output_json = json.dumps(output_obj, indent=4) + result = f'''%s{output_json}%s''' + print(result) + """, CODE_PLACEHOLDER, INPUTS_PLACEHOLDER, RESULT_TAG, RESULT_TAG); + case GLOBAL_DICTIONARY -> String.format(""" + import json + from base64 import b64decode + params = json.loads(b64decode('%s').decode('utf-8')) + + # declare main function + %s + + output_obj = main() + output_json = json.dumps(output_obj, indent=4) + result = f'''%s{output_json}%s''' + print(result) + """, INPUTS_PLACEHOLDER, CODE_PLACEHOLDER, RESULT_TAG, RESULT_TAG); + }; } } diff --git a/spring-ai-alibaba-graph-core/src/test/java/com/alibaba/cloud/ai/graph/node/CodeActionTest.java b/spring-ai-alibaba-graph-core/src/test/java/com/alibaba/cloud/ai/graph/node/CodeActionTest.java index cad3c8041d..8d13d594d9 100644 --- a/spring-ai-alibaba-graph-core/src/test/java/com/alibaba/cloud/ai/graph/node/CodeActionTest.java +++ b/spring-ai-alibaba-graph-core/src/test/java/com/alibaba/cloud/ai/graph/node/CodeActionTest.java @@ -18,6 +18,8 @@ import com.alibaba.cloud.ai.graph.OverAllState; import com.alibaba.cloud.ai.graph.action.NodeAction; import com.alibaba.cloud.ai.graph.node.code.CodeExecutorNodeAction; +import com.alibaba.cloud.ai.graph.node.code.entity.CodeParam; +import com.alibaba.cloud.ai.graph.node.code.entity.CodeStyle; import com.alibaba.cloud.ai.graph.node.code.LocalCommandlineCodeExecutor; import com.alibaba.cloud.ai.graph.node.code.entity.CodeExecutionConfig; import org.junit.jupiter.api.BeforeEach; @@ -25,8 +27,8 @@ import org.junit.jupiter.api.io.TempDir; import java.nio.file.Path; -import java.util.HashMap; import java.util.LinkedHashMap; +import java.util.List; import java.util.Map; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -57,33 +59,107 @@ def main(arg1: str, arg2: str) -> dict: "result": arg1 + arg2, } """; - Map params = new HashMap<>(16); - params.put("key1", "arg1"); - params.put("key2", "arg2"); NodeAction codeNode = CodeExecutorNodeAction.builder() .codeExecutor(new LocalCommandlineCodeExecutor()) .code(code) - .codeLanguage("python") + .codeStyle(CodeStyle.EXPLICIT_PARAMETERS) + .codeLanguage("python3") .config(config) - .params(params) + .params(List.of(new CodeParam("arg1", "data1"), new CodeParam("arg2", "data2"))) + .outputKey("output") .build(); - Map initData = new HashMap<>(16); - initData.put("arg1", "1"); - initData.put("arg2", "2"); - OverAllState mockState = new OverAllState(initData); + OverAllState mockState = new OverAllState(Map.of("data1", "1", "data2", "2")); + Map stateData = codeNode.apply(mockState); + System.out.println(stateData); + assertNotNull(stateData); + assertEquals(Map.of("output", Map.of("result", "12")), stateData); + } + + @Test + void testExecutePythonGlobalDictStyleSuccessfully() throws Exception { + String code = """ + def main(): + ret = { + "output": params['arg1'] + params['arg2'] + params['arg3'] + } + return ret + """; + NodeAction codeNode = CodeExecutorNodeAction.builder() + .codeExecutor(new LocalCommandlineCodeExecutor()) + .code(code) + .codeStyle(CodeStyle.GLOBAL_DICTIONARY) + .codeLanguage("python3") + .config(config) + .params(List.of(CodeParam.withKey("arg1", "arg1"), CodeParam.withKey("arg2", "arg2"), + CodeParam.withValue("arg3", "3"))) + .outputKey("output") + .build(); + OverAllState mockState = new OverAllState(Map.of("arg1", "1", "arg2", "2")); + Map stateData = codeNode.apply(mockState); + assertNotNull(stateData); + assertEquals(Map.of("output", Map.of("output", "123")), stateData); + System.out.println(stateData); + } + + @Test + void testExecuteJavascriptSuccessfully() throws Exception { + String code = """ + function main({arg1, arg2}) { + return { + result: arg1 + arg2 + } + } + """; + NodeAction codeNode = CodeExecutorNodeAction.builder() + .codeExecutor(new LocalCommandlineCodeExecutor()) + .code(code) + .codeStyle(CodeStyle.EXPLICIT_PARAMETERS) + .codeLanguage("javascript") + .config(config) + .params(List.of(new CodeParam("arg1", "data1"), new CodeParam("arg2", "data2"))) + .outputKey("output") + .build(); + OverAllState mockState = new OverAllState(Map.of("data1", "1", "data2", "2")); Map stateData = codeNode.apply(mockState); System.out.println(stateData); + assertNotNull(stateData); + assertEquals(Map.of("output", Map.of("result", "12")), stateData); + } + + @Test + void testExecuteJavascriptGlobalDictStyleSuccessfully() throws Exception { + String code = """ + function main() { + const ret = { + "output": params.arg1 + params.arg2 + params.arg3 + }; + return ret; + } + """; + NodeAction codeNode = CodeExecutorNodeAction.builder() + .codeExecutor(new LocalCommandlineCodeExecutor()) + .code(code) + .codeStyle(CodeStyle.GLOBAL_DICTIONARY) + .codeLanguage("javascript") + .config(config) + .params(List.of(CodeParam.withKey("arg1", "arg1"), CodeParam.withKey("arg2", "arg2"), + CodeParam.withValue("arg3", "3"))) + .outputKey("output") + .build(); + OverAllState mockState = new OverAllState(Map.of("arg1", "1", "arg2", "2")); + Map stateData = codeNode.apply(mockState); + assertNotNull(stateData); + assertEquals(Map.of("output", Map.of("output", "123")), stateData); + System.out.println(stateData); } @Test void testExecuteJavaWithLocalExecutor() throws Exception { // Prepare test data String javaCode = """ - public static Object main(Object[] inputs) { - // Process input parameters - String text = (String) inputs[0]; - Integer count = (Integer) inputs[1]; - + public static Object run(String arg0, Integer arg1) { + String text = arg0; + int count = arg1; // Execute business logic StringBuilder result = new StringBuilder(); for (int i = 0; i < count; i++) { @@ -99,17 +175,14 @@ public static Object main(Object[] inputs) { } """; - // Create parameter mapping - Map params = new LinkedHashMap<>(); - params.put("text", "text"); - params.put("count", "count"); // Create code execution node action NodeAction codeNode = CodeExecutorNodeAction.builder() .codeExecutor(new LocalCommandlineCodeExecutor()) .code(javaCode) + .codeStyle(CodeStyle.EXPLICIT_PARAMETERS) .codeLanguage("java") .config(config) - .params(params) + .params(List.of(new CodeParam("arg0", "text"), new CodeParam("arg1", "count"))) .outputKey("codeNode1_output") .build(); @@ -121,13 +194,57 @@ public static Object main(Object[] inputs) { // Execute code Map result = codeNode.apply(mockState); - Map codeNode1Output = (Map) result.get("codeNode1_output"); + assertNotNull(result); + System.out.println(result); + assertEquals(Map.of("codeNode1_output", Map.of("length", 18, "count", 3, "repeated_text", "Hello Hello Hello")), + result); + } + + @Test + void testExecuteJavaGlobalDictStyleWithLocalExecutor() throws Exception { + // Prepare test data + String javaCode = """ + public static Object run() { + String text = (String) params.get("arg0"); + int count = (Integer) params.get("arg1"); + // Execute business logic + StringBuilder result = new StringBuilder(); + for (int i = 0; i < count; i++) { + result.append(text).append(" "); + } + + Map response = new HashMap<>(); + response.put("repeated_text", result.toString().trim()); + response.put("length", result.length()); + response.put("count", count); - // Verify results + return response; + } + """; + + // Create code execution node action + NodeAction codeNode = CodeExecutorNodeAction.builder() + .codeExecutor(new LocalCommandlineCodeExecutor()) + .code(javaCode) + .codeLanguage("java") + .config(config) + .codeStyle(CodeStyle.GLOBAL_DICTIONARY) + .params(List.of(new CodeParam("arg0", "text"), new CodeParam("arg1", "count"))) + .outputKey("codeNode1_output") + .build(); + + // Prepare input data + Map initData = new LinkedHashMap<>(); + initData.put("text", "Hello"); + initData.put("count", 3); + OverAllState mockState = new OverAllState(initData); + + // Execute code + Map result = codeNode.apply(mockState); assertNotNull(result); - assertEquals("Hello Hello Hello", codeNode1Output.get("repeated_text")); - assertEquals(18, codeNode1Output.get("length")); - assertEquals(3, codeNode1Output.get("count")); + System.out.println(result); + assertEquals(Map.of("codeNode1_output", Map.of("length", 18, "count", 3, "repeated_text", "Hello Hello Hello")), + result); } } From a774fe215500fe2f5c8d8fa30ad2be81207bfaca Mon Sep 17 00:00:00 2001 From: VLSMB <2047857654@qq.com> Date: Thu, 11 Sep 2025 23:42:35 +0800 Subject: [PATCH 08/11] feat: enhance CodeNodeSection for Dify DSL --- .../generator/model/workflow/NodeType.java | 2 +- .../model/workflow/nodedata/CodeNodeData.java | 110 +++++++++++++- .../dsl/converter/CodeNodeDataConverter.java | 138 +++++++++++++----- .../workflow/WorkflowProjectGenerator.java | 3 +- .../workflow/sections/CodeNodeSection.java | 124 ++++++++-------- .../templates/GraphBuilder.java.mustache | 24 +-- 6 files changed, 276 insertions(+), 125 deletions(-) diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/NodeType.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/NodeType.java index 891b8ee5c5..411d448339 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/NodeType.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/NodeType.java @@ -32,7 +32,7 @@ public enum NodeType { LLM("llm", "llm", "LLM"), - CODE("code", "code", "UNSUPPORTED"), + CODE("code", "code", "Script"), RETRIEVER("retriever", "knowledge-retrieval", "Retrieval"), diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/CodeNodeData.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/CodeNodeData.java index d235c57cd0..f805db583a 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/CodeNodeData.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/CodeNodeData.java @@ -16,30 +16,35 @@ package com.alibaba.cloud.ai.studio.admin.generator.model.workflow.nodedata; import java.util.List; +import java.util.Map; import com.alibaba.cloud.ai.studio.admin.generator.model.Variable; -import com.alibaba.cloud.ai.studio.admin.generator.model.VariableSelector; import com.alibaba.cloud.ai.studio.admin.generator.model.VariableType; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.NodeData; +import com.alibaba.cloud.ai.studio.admin.generator.utils.ObjectToCodeUtil; public class CodeNodeData extends NodeData { public static Variable getDefaultOutputSchema() { - return new Variable("output", VariableType.STRING); + return new Variable("output", VariableType.OBJECT); } private String code; private String codeLanguage; + private List inputParams; + private String outputKey; - public CodeNodeData() { - } + private int maxRetryCount = 1; - public CodeNodeData(List inputs, List outputs) { - super(inputs, outputs); - } + private int retryIntervalMs = 1000; + + // 运行失败时的默认值 + private Map defaultValue; + + private CodeStyle codeStyle = CodeStyle.EXPLICIT_PARAMETERS; public String getCode() { return code; @@ -59,6 +64,15 @@ public CodeNodeData setCodeLanguage(String codeLanguage) { return this; } + public List getInputParams() { + return inputParams; + } + + public CodeNodeData setInputParams(List inputParams) { + this.inputParams = inputParams; + return this; + } + public String getOutputKey() { return outputKey; } @@ -68,4 +82,86 @@ public CodeNodeData setOutputKey(String outputKey) { return this; } + public int getMaxRetryCount() { + return maxRetryCount; + } + + public CodeNodeData setMaxRetryCount(int maxRetryCount) { + this.maxRetryCount = maxRetryCount; + return this; + } + + public int getRetryIntervalMs() { + return retryIntervalMs; + } + + public CodeNodeData setRetryIntervalMs(int retryIntervalMs) { + this.retryIntervalMs = retryIntervalMs; + return this; + } + + public Map getDefaultValue() { + return defaultValue; + } + + public CodeNodeData setDefaultValue(Map defaultValue) { + this.defaultValue = defaultValue; + return this; + } + + public CodeStyle getCodeStyle() { + return codeStyle; + } + + public CodeNodeData setCodeStyle(CodeStyle codeStyle) { + this.codeStyle = codeStyle; + return this; + } + + public enum CodeStyle { + + /** + * Dify代码样式 + */ + EXPLICIT_PARAMETERS, + + /** + * Studio代码样式 + */ + GLOBAL_DICTIONARY + + ; + + public String toString() { + return "CodeStyle." + name(); + } + + } + + public record CodeParam(String argName, Object value, String stateKey) { + public static CodeParam withValue(String argName, Object value) { + return new CodeParam(argName, value, null); + } + + public static CodeParam withKey(String argName, String stateKey) { + return new CodeParam(argName, null, stateKey); + } + + @Override + public String toString() { + if (argName == null) { + throw new IllegalArgumentException("argName cannot be null"); + } + if (value == null && stateKey != null) { + return String.format("CodeParam.withKey(%s, %s)", ObjectToCodeUtil.toCode(argName()), + ObjectToCodeUtil.toCode(stateKey())); + } + if (value != null && stateKey == null) { + return String.format("CodeParam.withValue(%s, %s)", ObjectToCodeUtil.toCode(argName()), + ObjectToCodeUtil.toCode(value())); + } + throw new IllegalArgumentException("value and stateKey must only one."); + } + } + } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/CodeNodeDataConverter.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/CodeNodeDataConverter.java index cd2c2c25b3..cc7bb63207 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/CodeNodeDataConverter.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/CodeNodeDataConverter.java @@ -15,21 +15,23 @@ */ package com.alibaba.cloud.ai.studio.admin.generator.service.dsl.converter; -import java.util.ArrayList; -import java.util.HashMap; +import java.util.Collection; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.function.BiConsumer; +import java.util.stream.Collectors; import java.util.stream.Stream; import com.alibaba.cloud.ai.studio.admin.generator.model.Variable; -import com.alibaba.cloud.ai.studio.admin.generator.model.VariableSelector; import com.alibaba.cloud.ai.studio.admin.generator.model.VariableType; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.NodeType; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.nodedata.CodeNodeData; import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.AbstractNodeDataConverter; import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.DSLDialectType; +import com.alibaba.cloud.ai.studio.admin.generator.utils.MapReadUtil; +import com.fasterxml.jackson.core.JsonProcessingException; import org.springframework.stereotype.Component; @Component @@ -55,43 +57,97 @@ public Boolean supportDialect(DSLDialectType dialectType) { @Override public CodeNodeData parse(Map data) { - List> variables = (List>) data.get("variables"); - List inputs = variables.stream().map(variable -> { - List selector = (List) variable.get("value_selector"); - return new VariableSelector(selector.get(0), selector.get(1), (String) variable.get("variable")); - }).toList(); - Map> outputsMap = (Map>) data.get("outputs"); - List outputs = outputsMap.entrySet().stream().map(entry -> { - String varName = entry.getKey(); - String difyType = (String) entry.getValue().get("type"); - VariableType varType = VariableType.fromDifyValue(difyType) - .orElseThrow(() -> new IllegalArgumentException("Unsupported dify variable type: " + difyType)); - return new Variable(varName, varType); - }).toList(); - - return new CodeNodeData(inputs, outputs).setCode((String) data.get("code")) - .setCodeLanguage((String) data.get("code_language")); + CodeNodeData nodeData = new CodeNodeData(); + + // 提取必要信息 + String code = MapReadUtil.getMapDeepValue(data, String.class, "code"); + String lang = MapReadUtil.getMapDeepValue(data, String.class, "code_language"); + Boolean isRetry = Optional + .ofNullable(MapReadUtil.getMapDeepValue(data, Boolean.class, "retry_config", "retry_enabled")) + .orElse(false); + int maxRetryCount = isRetry ? Optional + .ofNullable(MapReadUtil.getMapDeepValue(data, Integer.class, "retry_config", "max_retries")) + .orElse(1) : 1; + int retryIntervalMs = isRetry ? Optional + .ofNullable(MapReadUtil.getMapDeepValue(data, Integer.class, "retry_config", "retry_interval")) + .orElse(1000) : 1000; + + List outputParams = Optional + .ofNullable( + MapReadUtil.safeCastToListWithMap(MapReadUtil.getMapDeepValue(data, List.class, "outputs"))) + .orElse(List.of()) + .stream() + .map(Map::entrySet) + .flatMap(Collection::stream) + .map(Map.Entry::getKey) + .map(k -> new Variable(k, VariableType.OBJECT)) + .toList(); + + List inputParams = Optional + .ofNullable(MapReadUtil + .safeCastToListWithMap(MapReadUtil.getMapDeepValue(data, List.class, "variables"))) + .orElse(List.of()) + .stream() + .filter(map -> map.containsKey("value_selector") && map.containsKey("variable")) + .map(map -> { + List list = MapReadUtil.safeCastToList(map.get("value_selector"), String.class); + // 先以Value的形式存储selector,在post阶段转换为正确的stateKey + return new CodeNodeData.CodeParam(map.get("variable").toString(), list, list.get(0)); + }) + .toList(); + + // 设置必要信息 + nodeData.setCodeStyle(CodeNodeData.CodeStyle.EXPLICIT_PARAMETERS); + nodeData.setCode(code); + nodeData.setCodeLanguage(lang); + nodeData.setMaxRetryCount(maxRetryCount); + nodeData.setRetryIntervalMs(retryIntervalMs); + nodeData.setInputParams(inputParams); + nodeData.setOutputs(outputParams); + + // 错误处理策略 + String errorStrategy = MapReadUtil.getMapDeepValue(data, String.class, "error_strategy"); + + if (errorStrategy != null) { + // 暂仅支持默认值 + List> defaultValueList = MapReadUtil + .safeCastToListWithMap(MapReadUtil.getMapDeepValue(data, List.class, "default_value")); + if (defaultValueList != null) { + Map defaultValue = defaultValueList.stream() + .filter(map -> map.containsKey("key") && map.containsKey("value")) + .collect(Collectors.toUnmodifiableMap(map -> map.get("key").toString(), + map -> map.get("value"), (a, b) -> b)); + nodeData.setDefaultValue(defaultValue); + } + } + + return nodeData; } @Override public Map dump(CodeNodeData nodeData) { - Map data = new HashMap<>(); - data.put("code", nodeData.getCode()); - data.put("code_language", nodeData.getCodeLanguage()); - List> inputVars = new ArrayList<>(); - nodeData.getInputs().forEach(v -> { - inputVars.add( - Map.of("variable", v.getLabel(), "value_selector", List.of(v.getNamespace(), v.getName()))); - }); - data.put("variables", inputVars); - Map outputVars = new HashMap<>(); - nodeData.getOutputs().forEach(variable -> { - outputVars.put(variable.getName(), Map.of("type", variable.getValueType().difyValue())); - }); - data.put("outputs", outputVars); - return data; + throw new UnsupportedOperationException(); } - }), CUSTOM(AbstractNodeDataConverter.defaultCustomDialectConverter(CodeNodeData.class)); + }), + + STUDIO(new DialectConverter<>() { + @Override + public Boolean supportDialect(DSLDialectType dialectType) { + return DSLDialectType.STUDIO.equals(dialectType); + } + + @Override + public CodeNodeData parse(Map data) throws JsonProcessingException { + return null; + } + + @Override + public Map dump(CodeNodeData nodeData) { + throw new UnsupportedOperationException(); + } + }) + + , CUSTOM(AbstractNodeDataConverter.defaultCustomDialectConverter(CodeNodeData.class)); private final DialectConverter dialectConverter; @@ -113,9 +169,19 @@ public String generateVarName(int count) { @Override public BiConsumer> postProcessConsumer(DSLDialectType dialectType) { return switch (dialectType) { - case DIFY -> this.emptyProcessConsumer().andThen((nodeData, idToVarName) -> { + case DIFY, STUDIO -> this.emptyProcessConsumer().andThen((nodeData, idToVarName) -> { // code节点将返回{"varName.output": {...}}的数据,之后拆包成若干输出数据 nodeData.setOutputKey(nodeData.getVarName() + "_" + CodeNodeData.getDefaultOutputSchema().getName()); + // 输入Param的Key都格式化为varName_key + nodeData.setInputParams(nodeData.getInputParams().stream().map(param -> { + if (param.stateKey() == null) { + return param; + } + @SuppressWarnings("unchecked") + List selector = (List) param.value(); + return CodeNodeData.CodeParam.withKey(param.argName(), + idToVarName.get(selector.get(0)) + "_" + selector.get(1)); + }).toList()); }).andThen(super.postProcessConsumer(dialectType)); default -> super.postProcessConsumer(dialectType); }; diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/WorkflowProjectGenerator.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/WorkflowProjectGenerator.java index 2e13c42663..8794bc655a 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/WorkflowProjectGenerator.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/WorkflowProjectGenerator.java @@ -241,7 +241,8 @@ private String renderImportSection(Workflow workflow) { "com.alibaba.cloud.ai.graph.node.code.entity.CodeExecutionConfig", "com.alibaba.cloud.ai.graph.node.code.CodeExecutor", "com.alibaba.cloud.ai.graph.node.code.LocalCommandlineCodeExecutor", "java.io.IOException", - "java.nio.file.Files", "java.nio.file.Path", "java.util.stream.Collectors")), + "java.nio.file.Files", "java.nio.file.Path", "java.util.stream.Collectors", + "com.alibaba.cloud.ai.graph.node.code.entity.CodeParam")), Map.entry(NodeType.AGENT, List.of("com.alibaba.cloud.ai.graph.node.AgentNode", "org.springframework.ai.tool.ToolCallback")), diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/CodeNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/CodeNodeSection.java index fe6aa0c34d..658dc3a02f 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/CodeNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/CodeNodeSection.java @@ -15,16 +15,16 @@ */ package com.alibaba.cloud.ai.studio.admin.generator.service.generator.workflow.sections; -import java.util.stream.Collectors; - import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.Node; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.NodeType; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.nodedata.CodeNodeData; import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.DSLDialectType; import com.alibaba.cloud.ai.studio.admin.generator.service.generator.workflow.NodeSection; +import com.alibaba.cloud.ai.studio.admin.generator.utils.ObjectToCodeUtil; import org.springframework.stereotype.Component; +// TODO: 支持异常分支 @Component public class CodeNodeSection implements NodeSection { @@ -35,66 +35,76 @@ public boolean support(NodeType nodeType) { @Override public String render(Node node, String varName) { - CodeNodeData data = (CodeNodeData) node.getData(); - StringBuilder sb = new StringBuilder(); - String id = node.getId(); - - sb.append(String.format("// —— CodeNode [%s] ——%n", id)); - - sb.append(String.format("CodeExecutorNodeAction %s = CodeExecutorNodeAction.builder()%n", varName)); - - sb.append(" .codeExecutor(codeExecutor) // 注入的 CodeExecutor Bean\n"); - - sb.append(String.format(" .codeLanguage(\"%s\")%n", data.getCodeLanguage())); - - String escaped = data.getCode().replace("\\", "\\\\").replace("\"\"\"", "\\\"\\\"\\\""); - sb.append(" .code(\"\"\"\n").append(escaped).append("\n\"\"\")\n"); - - sb.append(" .config(codeExecutionConfig) // 注入的 CodeExecutionConfig Bean\n"); - - if (!data.getInputs().isEmpty()) { - String params = data.getInputs() - .stream() - .map(sel -> String.format("\"%s\", \"%s\"", sel.getLabel(), sel.getNameInCode())) - .collect(Collectors.joining(", ")); - sb.append(String.format(" .params(Map.of(%s))%n", params)); - } - if (data.getOutputKey() != null) { - sb.append(String.format(".outputKey(\"%s\")%n", escape(data.getOutputKey()))); - } - - sb.append(" .build();\n"); - - // 辅助节点代码,包装codeNode,将他的返回值变量解包 - String assistantNodeCode = String.format("wrapperCodeNodeAction(%s, \"%s\", \"%s\")", varName, - data.getOutputKey(), varName); - - sb.append(String.format("stateGraph.addNode(\"%s\", AsyncNodeAction.node_async(%s));%n%n", varName, - assistantNodeCode)); - - return sb.toString(); + CodeNodeData nodeData = ((CodeNodeData) node.getData()); + return String.format(""" + // —— CodeNode [%s] —— + CodeExecutorNodeAction %s = CodeExecutorNodeAction.builder() + .codeExecutor(codeExecutor) + .codeLanguage("%s") + .code(%s) + .config(codeExecutionConfig) + .params(%s) + .outputKey("%s") + .build(); + stateGraph.addNode("%s", AsyncNodeAction.node_async( + wrapperCodeNodeAction(%s, "%s", "%s", %d, %d, %s) + )); + + """, node.getId(), varName, nodeData.getCodeLanguage(), ObjectToCodeUtil.toCode(nodeData.getCode()), + ObjectToCodeUtil.toCode(nodeData.getInputParams()), nodeData.getOutputKey(), varName, varName, + nodeData.getOutputKey(), varName, nodeData.getMaxRetryCount(), nodeData.getRetryIntervalMs(), + ObjectToCodeUtil.toCode(nodeData.getDefaultValue())); } @Override public String assistMethodCode(DSLDialectType dialectType) { return switch (dialectType) { - case DIFY -> """ - private NodeAction wrapperCodeNodeAction(NodeAction codeNodeAction, String key, String nodeName) { - return state -> { - // 将代码运行的结果拆包 - Map result = codeNodeAction.apply(state); - Object object = result.get(key); - if(!(object instanceof Map)) { - return Map.of(); - } - return ((Map) object).entrySet().stream() - .collect(Collectors.toMap( - entry -> nodeName + "_" + entry.getKey(), - Map.Entry::getValue - )); - }; - } - """; + case DIFY, STUDIO -> + """ + private static final CodeExecutionConfig codeExecutionConfig; + private static final CodeExecutor codeExecutor; + + static { + // todo: configure your own code execution configuration + try { + Path tempDir = Files.createTempDirectory("code-execution-workdir-"); + tempDir.toFile().deleteOnExit(); + codeExecutionConfig = new CodeExecutionConfig().setWorkDir(tempDir.toString()); + codeExecutor = new LocalCommandlineCodeExecutor(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private NodeAction wrapperCodeNodeAction(NodeAction codeNodeAction, String key, + String nodeName, int maxRetryCount, int retryIntervalMs, Map defaultValue) { + return state -> { + int count = maxRetryCount; + while (count-- > 0) { + try { + // 将代码运行的结果拆包 + Map result = codeNodeAction.apply(state); + Object object = result.get(key); + if(!(object instanceof Map)) { + throw new RuntimeException("unexcepted result"); + } + return ((Map) object).entrySet().stream() + .collect(Collectors.toMap( + entry -> nodeName + "_" + entry.getKey(), + Map.Entry::getValue + )); + } catch (Exception e) { + Thread.sleep(retryIntervalMs); + } + } + if(defaultValue != null) { + return defaultValue; + } else { + throw new RuntimeException("code execution failed!"); + } + }; + } + """; default -> ""; }; } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/resources/templates/GraphBuilder.java.mustache b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/resources/templates/GraphBuilder.java.mustache index 821a909284..fa67caf3d2 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/resources/templates/GraphBuilder.java.mustache +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/resources/templates/GraphBuilder.java.mustache @@ -28,10 +28,7 @@ public class GraphBuilder { {{assistMethodCode}} @Bean - public CompiledGraph buildGraph( - ChatModel chatModel - {{#hasCode}}, CodeExecutionConfig codeExecutionConfig, CodeExecutor codeExecutor{{/hasCode}} - ) throws Exception { + public CompiledGraph buildGraph(ChatModel chatModel) throws Exception { ChatClient chatClient = ChatClient.builder(chatModel).defaultAdvisors(new SimpleLoggerAdvisor()).build(); // new stateGraph @@ -43,23 +40,4 @@ public class GraphBuilder { return stateGraph.compile(); } - {{#hasCode}} - @Bean - public Path tempDir() throws IOException { - // todo: set your work dir - Path tempDir = Files.createTempDirectory("code-execution-workdir-"); - tempDir.toFile().deleteOnExit(); - return tempDir; - } - - @Bean - public CodeExecutionConfig codeExecutionConfig(Path tempDir) { - return new CodeExecutionConfig().setWorkDir(tempDir.toString()); - } - - @Bean - public CodeExecutor codeGenerator() { - return new LocalCommandlineCodeExecutor(); - } - {{/hasCode}} } From 18da21e740b482dc4a797ffca83b5ce5a60aed4c Mon Sep 17 00:00:00 2001 From: VLSMB <2047857654@qq.com> Date: Fri, 12 Sep 2025 00:48:59 +0800 Subject: [PATCH 09/11] feat: support CodeNodeSection for Studio DSL --- .../dsl/AbstractNodeDataConverter.java | 34 ++++++-- .../dsl/converter/CodeNodeDataConverter.java | 80 ++++++++++++++++++- .../workflow/WorkflowProjectGenerator.java | 14 ++-- .../workflow/sections/CodeNodeSection.java | 6 +- 4 files changed, 118 insertions(+), 16 deletions(-) diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/AbstractNodeDataConverter.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/AbstractNodeDataConverter.java index 1b1b174543..f389ce3f8c 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/AbstractNodeDataConverter.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/AbstractNodeDataConverter.java @@ -19,9 +19,11 @@ import java.util.Map; import java.util.function.BiConsumer; import java.util.function.BiFunction; +import java.util.regex.MatchResult; import java.util.regex.Matcher; import java.util.regex.Pattern; +import com.alibaba.cloud.ai.studio.admin.generator.model.VariableSelector; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.NodeData; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; @@ -72,6 +74,23 @@ public interface DialectConverter { Map dump(T nodeData); + /** + * 将模板字符串转换为变量选择器 + * @param dialectType dsl语言 + * @param template 模板字符串 + * @return 变量选择器 + */ + default VariableSelector varTemplateToSelector(DSLDialectType dialectType, String template) { + Pattern pattern = switch (dialectType) { + case DIFY -> DIFY_VAR_TEMPLATE_PATTERN; + case STUDIO -> STUDIO_VAR_TEMPLATE_PATTERN; + default -> throw new UnsupportedOperationException(); + }; + Matcher matcher = pattern.matcher(template); + MatchResult result = matcher.results().findFirst().orElseThrow(); + return new VariableSelector(result.group(1), result.group(2)); + } + } public static DialectConverter defaultCustomDialectConverter(Class clazz) { @@ -100,6 +119,12 @@ public Map dump(R nodeData) { protected abstract List> getDialectConverters(); + private static final Pattern DIFY_VAR_TEMPLATE_PATTERN = Pattern.compile("\\{\\{#(\\w+)\\.(\\w+)#}}"); + + private static final Pattern STUDIO_VAR_TEMPLATE_PATTERN = Pattern.compile("\\$\\{(\\w+)\\.(\\w+)}"); + + private static final Pattern VAR_TEMPLATE_PATTERN = Pattern.compile("\\{(\\w+)}"); + /** * 将文本中变量占位符进行转化,比如Dify DSL的"你好,{{#123.query#}}"转化为"你好,{nodeName1_query}" * @param dialectType dsl语言 @@ -116,8 +141,7 @@ protected String convertVarTemplate(DSLDialectType dialectType, String templateS return str; } StringBuilder result = new StringBuilder(); - Pattern pattern = Pattern.compile("\\{\\{#(\\w+)\\.(\\w+)#}}"); - Matcher matcher = pattern.matcher(str); + Matcher matcher = DIFY_VAR_TEMPLATE_PATTERN.matcher(str); while (matcher.find()) { String nodeId = matcher.group(1); String varName = matcher.group(2); @@ -133,8 +157,8 @@ protected String convertVarTemplate(DSLDialectType dialectType, String templateS return str; } StringBuilder result = new StringBuilder(); - Pattern pattern = Pattern.compile("\\$\\{(\\w+)\\.(\\w+)}"); - Matcher matcher = pattern.matcher(str); + + Matcher matcher = STUDIO_VAR_TEMPLATE_PATTERN.matcher(str); while (matcher.find()) { String nodeId = matcher.group(1); String varName = matcher.group(2); @@ -150,8 +174,6 @@ protected String convertVarTemplate(DSLDialectType dialectType, String templateS return func.apply(templateString, idToVarName); } - private static final Pattern VAR_TEMPLATE_PATTERN = Pattern.compile("\\{(\\w+)}"); - /** * 获取模板中的变量占位符,比如"你好{var1},{var2}"返回"[var1, var2]" * @param template 模板字符串 diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/CodeNodeDataConverter.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/CodeNodeDataConverter.java index cc7bb63207..8b7bd9c167 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/CodeNodeDataConverter.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/CodeNodeDataConverter.java @@ -24,6 +24,7 @@ import java.util.stream.Stream; import com.alibaba.cloud.ai.studio.admin.generator.model.Variable; +import com.alibaba.cloud.ai.studio.admin.generator.model.VariableSelector; import com.alibaba.cloud.ai.studio.admin.generator.model.VariableType; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.NodeType; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.nodedata.CodeNodeData; @@ -138,7 +139,82 @@ public Boolean supportDialect(DSLDialectType dialectType) { @Override public CodeNodeData parse(Map data) throws JsonProcessingException { - return null; + CodeNodeData nodeData = new CodeNodeData(); + + // 获取基本信息 + String code = MapReadUtil.getMapDeepValue(data, String.class, "config", "node_param", "script_content"); + String lang = MapReadUtil.getMapDeepValue(data, String.class, "config", "node_param", "script_type"); + Boolean isRetry = Optional + .ofNullable(MapReadUtil.getMapDeepValue(data, Boolean.class, "config", "node_param", "retry_config", + "retry_enabled")) + .orElse(false); + int maxRetryCount = isRetry ? Optional + .ofNullable(MapReadUtil.getMapDeepValue(data, Integer.class, "config", "node_param", "retry_config", + "max_retries")) + .orElse(1) : 1; + int retryIntervalMs = isRetry ? Optional + .ofNullable(MapReadUtil.getMapDeepValue(data, Integer.class, "config", "node_param", "retry_config", + "retry_interval")) + .orElse(1000) : 1000; + + List outputParams = Optional + .ofNullable(MapReadUtil.safeCastToListWithMap( + MapReadUtil.getMapDeepValue(data, List.class, "config", "output_params"))) + .orElse(List.of()) + .stream() + .filter(map -> map.containsKey("key")) + .map(map -> new Variable(map.get("key").toString(), VariableType.OBJECT)) + .toList(); + List inputParams = Optional + .ofNullable(MapReadUtil + .safeCastToListWithMap(MapReadUtil.getMapDeepValue(data, List.class, "config", "input_params"))) + .orElse(List.of()) + .stream() + .filter(map -> map.containsKey("key") && map.containsKey("value") && map.containsKey("value_from")) + .map(map -> { + String key = map.get("key").toString(); + Object value = map.get("value"); + String valueFrom = map.get("value_from").toString(); + if ("input".equalsIgnoreCase(valueFrom)) { + return CodeNodeData.CodeParam.withValue(key, value); + } + else { + // 先以Value的形式存储selector,在post阶段转换为正确的stateKey + VariableSelector selector = this.varTemplateToSelector(DSLDialectType.STUDIO, + value.toString()); + List list = List.of(selector.getNamespace(), selector.getName()); + return new CodeNodeData.CodeParam(key, list, value.toString()); + } + }) + .toList(); + + // 设置基本信息 + nodeData.setCodeStyle(CodeNodeData.CodeStyle.GLOBAL_DICTIONARY); + nodeData.setCode(code); + nodeData.setCodeLanguage(lang); + nodeData.setMaxRetryCount(maxRetryCount); + nodeData.setRetryIntervalMs(retryIntervalMs); + nodeData.setInputParams(inputParams); + nodeData.setOutputs(outputParams); + + // 设置错误策略 + String errorStrategy = MapReadUtil.getMapDeepValue(data, String.class, "config", "node_param", + "try_catch_config", "strategy"); + if (errorStrategy != null) { + // 暂仅支持默认值 + List> defaultValueList = MapReadUtil + .safeCastToListWithMap(MapReadUtil.getMapDeepValue(data, List.class, "config", "node_param", + "try_catch_config", "default_values")); + if (defaultValueList != null) { + Map defaultValue = defaultValueList.stream() + .filter(map -> map.containsKey("key") && map.containsKey("value")) + .collect(Collectors.toUnmodifiableMap(map -> map.get("key").toString(), + map -> map.get("value"), (a, b) -> b)); + nodeData.setDefaultValue(defaultValue); + } + } + + return nodeData; } @Override @@ -180,7 +256,7 @@ public BiConsumer> postProcessConsumer(DSLDial @SuppressWarnings("unchecked") List selector = (List) param.value(); return CodeNodeData.CodeParam.withKey(param.argName(), - idToVarName.get(selector.get(0)) + "_" + selector.get(1)); + idToVarName.getOrDefault(selector.get(0), selector.get(0)) + "_" + selector.get(1)); }).toList()); }).andThen(super.postProcessConsumer(dialectType)); default -> super.postProcessConsumer(dialectType); diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/WorkflowProjectGenerator.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/WorkflowProjectGenerator.java index 8794bc655a..8f399ecfc7 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/WorkflowProjectGenerator.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/WorkflowProjectGenerator.java @@ -237,12 +237,14 @@ private String renderImportSection(Workflow workflow) { Map.entry(NodeType.ANSWER, List.of("com.alibaba.cloud.ai.graph.node.AnswerNode")), Map.entry(NodeType.MIDDLE_OUTPUT, List.of("java.util.stream.Collectors", "org.springframework.ai.chat.prompt.PromptTemplate")), - Map.entry(NodeType.CODE, List.of("com.alibaba.cloud.ai.graph.node.code.CodeExecutorNodeAction", - "com.alibaba.cloud.ai.graph.node.code.entity.CodeExecutionConfig", - "com.alibaba.cloud.ai.graph.node.code.CodeExecutor", - "com.alibaba.cloud.ai.graph.node.code.LocalCommandlineCodeExecutor", "java.io.IOException", - "java.nio.file.Files", "java.nio.file.Path", "java.util.stream.Collectors", - "com.alibaba.cloud.ai.graph.node.code.entity.CodeParam")), + Map.entry(NodeType.CODE, + List.of("com.alibaba.cloud.ai.graph.node.code.CodeExecutorNodeAction", + "com.alibaba.cloud.ai.graph.node.code.entity.CodeExecutionConfig", + "com.alibaba.cloud.ai.graph.node.code.CodeExecutor", + "com.alibaba.cloud.ai.graph.node.code.LocalCommandlineCodeExecutor", + "java.io.IOException", "java.nio.file.Files", "java.nio.file.Path", + "java.util.stream.Collectors", "com.alibaba.cloud.ai.graph.node.code.entity.CodeParam", + "com.alibaba.cloud.ai.graph.node.code.entity.CodeStyle")), Map.entry(NodeType.AGENT, List.of("com.alibaba.cloud.ai.graph.node.AgentNode", "org.springframework.ai.tool.ToolCallback")), diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/CodeNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/CodeNodeSection.java index 658dc3a02f..a4c7461963 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/CodeNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/CodeNodeSection.java @@ -42,6 +42,7 @@ public String render(Node node, String varName) { .codeExecutor(codeExecutor) .codeLanguage("%s") .code(%s) + .codeStyle(%s) .config(codeExecutionConfig) .params(%s) .outputKey("%s") @@ -51,8 +52,9 @@ public String render(Node node, String varName) { )); """, node.getId(), varName, nodeData.getCodeLanguage(), ObjectToCodeUtil.toCode(nodeData.getCode()), - ObjectToCodeUtil.toCode(nodeData.getInputParams()), nodeData.getOutputKey(), varName, varName, - nodeData.getOutputKey(), varName, nodeData.getMaxRetryCount(), nodeData.getRetryIntervalMs(), + ObjectToCodeUtil.toCode(nodeData.getCodeStyle()), ObjectToCodeUtil.toCode(nodeData.getInputParams()), + nodeData.getOutputKey(), varName, varName, nodeData.getOutputKey(), varName, + nodeData.getMaxRetryCount(), nodeData.getRetryIntervalMs(), ObjectToCodeUtil.toCode(nodeData.getDefaultValue())); } From e11c484d7e3728aa8e8c0c828624bfc6c0be88a4 Mon Sep 17 00:00:00 2001 From: VLSMB <2047857654@qq.com> Date: Fri, 12 Sep 2025 12:06:00 +0800 Subject: [PATCH 10/11] feat: support IterationNode for Studio DSL --- .../generator/model/workflow/NodeType.java | 2 +- .../dsl/AbstractNodeDataConverter.java | 2 +- .../dsl/converter/EmptyNodeDataConverter.java | 8 ++-- .../converter/IterationNodeDataConverter.java | 39 ++++++++++++++++++- .../workflow/WorkflowProjectGenerator.java | 7 +--- .../sections/IterationNodeSection.java | 2 +- 6 files changed, 45 insertions(+), 15 deletions(-) diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/NodeType.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/NodeType.java index 411d448339..6fefead060 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/NodeType.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/NodeType.java @@ -58,7 +58,7 @@ public enum NodeType { TEMPLATE_TRANSFORM("template-transform", "template-transform", "UNSUPPORTED"), - ITERATION("iteration", "iteration", "UNSUPPORTED"), + ITERATION("iteration", "iteration", "Parallel"), EMPTY("empty", "UNSUPPORTED", "UNSUPPORTED"), diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/AbstractNodeDataConverter.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/AbstractNodeDataConverter.java index f389ce3f8c..55f0b01c2f 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/AbstractNodeDataConverter.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/AbstractNodeDataConverter.java @@ -121,7 +121,7 @@ public Map dump(R nodeData) { private static final Pattern DIFY_VAR_TEMPLATE_PATTERN = Pattern.compile("\\{\\{#(\\w+)\\.(\\w+)#}}"); - private static final Pattern STUDIO_VAR_TEMPLATE_PATTERN = Pattern.compile("\\$\\{(\\w+)\\.(\\w+)}"); + private static final Pattern STUDIO_VAR_TEMPLATE_PATTERN = Pattern.compile("\\$\\{(\\w+)\\.\\[?(\\w+)]?}"); private static final Pattern VAR_TEMPLATE_PATTERN = Pattern.compile("\\{(\\w+)}"); diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/EmptyNodeDataConverter.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/EmptyNodeDataConverter.java index 1588fc70f1..be7103331a 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/EmptyNodeDataConverter.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/EmptyNodeDataConverter.java @@ -37,10 +37,10 @@ public class EmptyNodeDataConverter extends AbstractNodeDataConverter() { + ALL(new DialectConverter<>() { @Override public Boolean supportDialect(DSLDialectType dialectType) { - return dialectType.equals(DSLDialectType.DIFY); + return true; } @Override @@ -50,9 +50,9 @@ public EmptyNodeData parse(Map data) throws JsonProcessingExcept @Override public Map dump(EmptyNodeData nodeData) { - return Map.of(); + throw new UnsupportedOperationException(); } - }), CUSTOM(AbstractNodeDataConverter.defaultCustomDialectConverter(EmptyNodeData.class)); + }); private final DialectConverter dialectConverter; diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/IterationNodeDataConverter.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/IterationNodeDataConverter.java index 5a0459e17d..8e8fd22c09 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/IterationNodeDataConverter.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/IterationNodeDataConverter.java @@ -86,7 +86,42 @@ public Boolean supportDialect(DSLDialectType dialectType) { @Override public IterationNodeData parse(Map data) throws JsonProcessingException { - throw new UnsupportedOperationException(); + IterationNodeData nodeData = new IterationNodeData(); + + // 获取必要信息 + int parallelCount = Optional + .ofNullable( + MapReadUtil.getMapDeepValue(data, Integer.class, "config", "node_param", "concurrent_size")) + .orElse(1); + int maxIterationCount = Optional + .ofNullable(MapReadUtil.getMapDeepValue(data, Integer.class, "config", "node_param", "batch_size")) + .orElse(1); + int indexOffset = 1; + + List> inputParamsList = Optional + .ofNullable(MapReadUtil + .safeCastToListWithMap(MapReadUtil.getMapDeepValue(data, List.class, "config", "input_params"))) + .orElse(List.of()); + List> outputParamsList = Optional + .ofNullable(MapReadUtil.safeCastToListWithMap( + MapReadUtil.getMapDeepValue(data, List.class, "config", "output_params"))) + .orElse(List.of()); + String itemKey = Optional.ofNullable(inputParamsList.get(0).get("key").toString()).orElse("item"); + String outputKey = Optional.ofNullable(outputParamsList.get(0).get("key").toString()).orElse("output"); + VariableSelector inputSelector = this.varTemplateToSelector(DSLDialectType.STUDIO, + inputParamsList.get(0).get("value").toString()); + VariableSelector resultSelector = this.varTemplateToSelector(DSLDialectType.STUDIO, + outputParamsList.get(0).get("value").toString()); + + // 设置必要信息 + nodeData.setParallelCount(parallelCount); + nodeData.setMaxIterationCount(maxIterationCount); + nodeData.setIndexOffset(indexOffset); + nodeData.setItemKey(itemKey); + nodeData.setOutputKey(outputKey); + nodeData.setInputSelector(inputSelector); + nodeData.setResultSelector(resultSelector); + return nodeData; } @Override @@ -129,7 +164,7 @@ public String generateVarName(int count) { @Override public BiConsumer> postProcessConsumer(DSLDialectType dialectType) { return switch (dialectType) { - case DIFY -> emptyProcessConsumer().andThen((nodeData, idToVarName) -> { + case DIFY, STUDIO -> emptyProcessConsumer().andThen((nodeData, idToVarName) -> { nodeData.setInputs(List.of(nodeData.getInputSelector(), nodeData.getResultSelector())); nodeData.setOutputs(Stream diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/WorkflowProjectGenerator.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/WorkflowProjectGenerator.java index 8f399ecfc7..bfc806fe1b 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/WorkflowProjectGenerator.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/WorkflowProjectGenerator.java @@ -38,7 +38,6 @@ import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.NodeData; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.NodeType; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.Workflow; -import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.nodedata.CodeNodeData; import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.DSLAdapter; import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.DSLDialectType; import com.alibaba.cloud.ai.studio.admin.generator.service.generator.GraphProjectDescription; @@ -74,8 +73,6 @@ public class WorkflowProjectGenerator implements ProjectGenerator { private final String PACKAGE_NAME = "packageName"; - private final String HAS_CODE = "hasCode"; - private final List dslAdapters; private final TemplateRenderer templateRenderer; @@ -118,8 +115,6 @@ public void generate(GraphProjectDescription projectDescription, Path projectRoo Map varNames = nodes.stream() .collect(Collectors.toMap(Node::getId, n -> n.getData().getVarName())); - boolean hasCode = nodes.stream().map(Node::getData).anyMatch(nd -> nd instanceof CodeNodeData); - String assistMethodCode = renderAssistMethodCode(nodes, projectDescription.getDslDialectType()); String stateSectionStr = renderStateSections( Stream.of(workflow.getWorkflowVars(), workflow.getEnvVars()).flatMap(List::stream).toList()); @@ -129,7 +124,7 @@ public void generate(GraphProjectDescription projectDescription, Path projectRoo Map graphBuilderModel = Map.of(PACKAGE_NAME, projectDescription.getPackageName(), GRAPH_BUILDER_STATE_SECTION, stateSectionStr, GRAPH_BUILDER_NODE_SECTION, nodeSectionStr, GRAPH_BUILDER_EDGE_SECTION, edgeSectionStr, GRAPH_BUILDER_IMPORT_SECTION, renderImportSection(workflow), - HAS_CODE, hasCode, GRAPH_BUILDER_ASSIST_METHOD_CODE, assistMethodCode); + GRAPH_BUILDER_ASSIST_METHOD_CODE, assistMethodCode); Map graphRunControllerModel = Map.of(PACKAGE_NAME, projectDescription.getPackageName()); renderAndWriteTemplates(List.of(GRAPH_BUILDER_TEMPLATE_NAME, GRAPH_RUN_TEMPLATE_NAME), List.of(graphBuilderModel, graphRunControllerModel), projectRoot, projectDescription); diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/IterationNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/IterationNodeSection.java index fbdcea67a4..6cc80124fe 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/IterationNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/IterationNodeSection.java @@ -31,7 +31,7 @@ * @author vlsmb * @since 2025/7/23 */ -// TODO: 支持并行模式、错误处理,支持Studio的默认输入值 +// TODO: 支持并行模式、错误处理,支持Studio的默认输入值,支持Studio的多输入/多输出 @Component public class IterationNodeSection implements NodeSection { From 491b37f1049e3906c8a63cbd0f5153d817c8103b Mon Sep 17 00:00:00 2001 From: VLSMB <2047857654@qq.com> Date: Fri, 12 Sep 2025 15:49:16 +0800 Subject: [PATCH 11/11] refactor: add NodeSection#getImports Method --- .../generator/workflow/NodeSection.java | 6 ++ .../workflow/WorkflowProjectGenerator.java | 65 +------------------ .../workflow/sections/AgentNodeSection.java | 6 ++ .../workflow/sections/AnswerNodeSection.java | 7 ++ .../sections/AssignerNodeSection.java | 7 ++ .../workflow/sections/BranchNodeSection.java | 5 ++ .../workflow/sections/CodeNodeSection.java | 13 ++++ .../DocumentExtractorNodeSection.java | 5 ++ .../workflow/sections/EmptyNodeSection.java | 7 ++ .../workflow/sections/EndNodeSection.java | 6 ++ .../workflow/sections/HttpNodeSection.java | 6 ++ .../workflow/sections/HumanNodeSection.java | 5 ++ .../sections/IterationNodeSection.java | 15 +++++ .../KnowledgeRetrievalNodeSection.java | 12 ++++ .../workflow/sections/LLMNodeSection.java | 13 ++++ .../sections/ListOperatorNodeSection.java | 7 ++ .../workflow/sections/MCPNodeSection.java | 5 ++ .../sections/MiddleOutputSection.java | 7 ++ .../sections/ParameterParsingNodeSection.java | 5 ++ .../QuestionClassifierNodeSection.java | 6 ++ .../workflow/sections/StartNodeSection.java | 5 ++ .../TemplateTransformNodeSection.java | 7 ++ .../workflow/sections/ToolNodeSection.java | 6 ++ .../VariableAggregatorNodeSection.java | 5 ++ 24 files changed, 169 insertions(+), 62 deletions(-) diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/NodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/NodeSection.java index d0059e33c2..c847ed4485 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/NodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/NodeSection.java @@ -38,6 +38,12 @@ public interface NodeSection { String render(Node node, String varName); + /** + * 返回当前节点需要导入的类列表 + * @return 类列表 + */ + List getImports(); + default String escape(String input) { if (input == null) { return ""; diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/WorkflowProjectGenerator.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/WorkflowProjectGenerator.java index bfc806fe1b..59be8b4a96 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/WorkflowProjectGenerator.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/WorkflowProjectGenerator.java @@ -225,72 +225,12 @@ private String renderEdgeSections(List edges, List nodes, Map> nodeTypeToClass = Map.ofEntries( - Map.entry(NodeType.ANSWER, List.of("com.alibaba.cloud.ai.graph.node.AnswerNode")), - Map.entry(NodeType.MIDDLE_OUTPUT, - List.of("java.util.stream.Collectors", "org.springframework.ai.chat.prompt.PromptTemplate")), - Map.entry(NodeType.CODE, - List.of("com.alibaba.cloud.ai.graph.node.code.CodeExecutorNodeAction", - "com.alibaba.cloud.ai.graph.node.code.entity.CodeExecutionConfig", - "com.alibaba.cloud.ai.graph.node.code.CodeExecutor", - "com.alibaba.cloud.ai.graph.node.code.LocalCommandlineCodeExecutor", - "java.io.IOException", "java.nio.file.Files", "java.nio.file.Path", - "java.util.stream.Collectors", "com.alibaba.cloud.ai.graph.node.code.entity.CodeParam", - "com.alibaba.cloud.ai.graph.node.code.entity.CodeStyle")), - Map.entry(NodeType.AGENT, - List.of("com.alibaba.cloud.ai.graph.node.AgentNode", - "org.springframework.ai.tool.ToolCallback")), - Map.entry(NodeType.LLM, - List.of("org.springframework.ai.chat.messages.Message", - "org.springframework.ai.chat.messages.AssistantMessage", - "org.springframework.ai.chat.messages.MessageType", - "org.springframework.ai.chat.messages.SystemMessage", - "org.springframework.ai.chat.messages.UserMessage", - "com.alibaba.cloud.ai.dashscope.chat.DashScopeChatOptions", - "org.springframework.beans.factory.annotation.Autowired", "java.util.Optional")), - Map.entry(NodeType.BRANCH, - List.of("com.alibaba.cloud.ai.graph.node.BranchNode", - "static com.alibaba.cloud.ai.graph.action.AsyncEdgeAction.edge_async")), - Map.entry(NodeType.DOC_EXTRACTOR, List.of("com.alibaba.cloud.ai.graph.node.DocumentExtractorNode")), - Map.entry(NodeType.HTTP, - List.of("com.alibaba.cloud.ai.graph.node.HttpNode", "org.springframework.http.HttpMethod")), - Map.entry(NodeType.LIST_OPERATOR, - List.of("com.alibaba.cloud.ai.graph.node.ListOperatorNode", "java.util.Comparator")), - Map.entry(NodeType.QUESTION_CLASSIFIER, - List.of("com.alibaba.cloud.ai.graph.node.QuestionClassifierNode", - "static com.alibaba.cloud.ai.graph.action.AsyncEdgeAction.edge_async")), - Map.entry(NodeType.PARAMETER_PARSING, - List.of("com.alibaba.cloud.ai.graph.node.ParameterParsingNode", "java.util.stream.Collectors")), - Map.entry(NodeType.TEMPLATE_TRANSFORM, - List.of("com.alibaba.cloud.ai.graph.node.TemplateTransformNode")), - Map.entry(NodeType.TOOL, - List.of("com.alibaba.cloud.ai.graph.node.ToolNode", "java.util.function.Function", - "org.springframework.ai.tool.function.FunctionToolCallback")), - Map.entry(NodeType.RETRIEVER, List.of("com.alibaba.cloud.ai.graph.node.KnowledgeRetrievalNode", - "org.springframework.ai.embedding.EmbeddingModel", "org.springframework.ai.reader.TextReader", - "org.springframework.ai.transformer.splitter.TokenTextSplitter", - "org.springframework.ai.vectorstore.SimpleVectorStore", - "org.springframework.ai.vectorstore.VectorStore", - "org.springframework.beans.factory.annotation.Value", "org.springframework.core.io.Resource", - "org.springframework.ai.document.Document", - "org.springframework.beans.factory.annotation.Autowired", - "org.springframework.core.io.ResourceLoader", "java.util.Optional")), - Map.entry(NodeType.AGGREGATOR, - List.of("com.alibaba.cloud.ai.graph.node.VariableAggregatorNode", - "java.util.stream.Collectors")), - Map.entry(NodeType.ASSIGNER, List.of("com.alibaba.cloud.ai.graph.node.AssignerNode")), - Map.entry(NodeType.ITERATION, List.of("java.util.ArrayList", "java.util.Arrays")), - Map.entry(NodeType.END, List.of("java.util.stream.Stream", "java.util.stream.Collectors", - "org.springframework.ai.chat.prompt.PromptTemplate"))); - + // construct a set of node types Set uniqueTypes = workflow.getGraph() .getNodes() .stream() .map(Node::getType) - .filter(nodeTypeToClass::containsKey) .collect(Collectors.toSet()); if (uniqueTypes.isEmpty()) { @@ -299,7 +239,8 @@ private String renderImportSection(Workflow workflow) { StringBuilder sb = new StringBuilder(); uniqueTypes.stream() - .map(nodeTypeToClass::get) + .map(nodeSectionMap::get) + .map(NodeSection::getImports) .flatMap(List::stream) .distinct() .forEach(className -> sb.append("import ").append(className).append(";\n")); diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/AgentNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/AgentNodeSection.java index 82fda33b1f..81780c31c8 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/AgentNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/AgentNodeSection.java @@ -22,6 +22,7 @@ import com.alibaba.cloud.ai.studio.admin.generator.service.generator.workflow.NodeSection; import org.springframework.stereotype.Component; +import java.util.List; import java.util.stream.Collectors; @Component @@ -67,4 +68,9 @@ public String render(Node node, String varName) { return sb.toString(); } + @Override + public List getImports() { + return List.of("com.alibaba.cloud.ai.graph.node.AgentNode", "org.springframework.ai.tool.ToolCallback"); + } + } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/AnswerNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/AnswerNodeSection.java index 0bf1e80f52..e1ffb99b29 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/AnswerNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/AnswerNodeSection.java @@ -23,6 +23,8 @@ import org.springframework.stereotype.Component; +import java.util.List; + @Component public class AnswerNodeSection implements NodeSection { @@ -59,4 +61,9 @@ public String render(Node node, String varName) { return sb.toString(); } + @Override + public List getImports() { + return List.of("com.alibaba.cloud.ai.graph.node.AnswerNode"); + } + } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/AssignerNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/AssignerNodeSection.java index 9df2d0f59d..1bfdb1e242 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/AssignerNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/AssignerNodeSection.java @@ -22,6 +22,8 @@ import org.springframework.stereotype.Component; +import java.util.List; + @Component public class AssignerNodeSection implements NodeSection { @@ -52,4 +54,9 @@ public String render(Node node, String varName) { return sb.toString(); } + @Override + public List getImports() { + return List.of("com.alibaba.cloud.ai.graph.node.AssignerNode"); + } + } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/BranchNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/BranchNodeSection.java index e5707d64fa..384fe9d2c9 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/BranchNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/BranchNodeSection.java @@ -156,4 +156,9 @@ private String buildVariablePath(Case.Condition condition) { return Optional.ofNullable(variableSelector.getNameInCode()).orElse("unknown"); } + @Override + public List getImports() { + return List.of("static com.alibaba.cloud.ai.graph.action.AsyncEdgeAction.edge_async"); + } + } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/CodeNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/CodeNodeSection.java index a4c7461963..44d3e350f2 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/CodeNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/CodeNodeSection.java @@ -24,6 +24,8 @@ import com.alibaba.cloud.ai.studio.admin.generator.utils.ObjectToCodeUtil; import org.springframework.stereotype.Component; +import java.util.List; + // TODO: 支持异常分支 @Component public class CodeNodeSection implements NodeSection { @@ -111,4 +113,15 @@ private NodeAction wrapperCodeNodeAction(NodeAction codeNodeAction, String key, }; } + @Override + public List getImports() { + return List.of("com.alibaba.cloud.ai.graph.node.code.CodeExecutorNodeAction", + "com.alibaba.cloud.ai.graph.node.code.entity.CodeExecutionConfig", + "com.alibaba.cloud.ai.graph.node.code.CodeExecutor", + "com.alibaba.cloud.ai.graph.node.code.LocalCommandlineCodeExecutor", "java.io.IOException", + "java.nio.file.Files", "java.nio.file.Path", "java.util.stream.Collectors", + "com.alibaba.cloud.ai.graph.node.code.entity.CodeParam", + "com.alibaba.cloud.ai.graph.node.code.entity.CodeStyle"); + } + } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/DocumentExtractorNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/DocumentExtractorNodeSection.java index 00572120ec..f076fa3b6d 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/DocumentExtractorNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/DocumentExtractorNodeSection.java @@ -60,4 +60,9 @@ public String render(Node node, String varName) { return sb.toString(); } + @Override + public List getImports() { + return List.of("com.alibaba.cloud.ai.graph.node.DocumentExtractorNode"); + } + } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/EmptyNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/EmptyNodeSection.java index a116265d39..32364f20a1 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/EmptyNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/EmptyNodeSection.java @@ -23,6 +23,8 @@ import org.springframework.stereotype.Component; +import java.util.List; + /** * @author vlsmb * @since 2025/7/23 @@ -46,4 +48,9 @@ public String render(Node node, String varName) { return sb.toString(); } + @Override + public List getImports() { + return List.of(); + } + } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/EndNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/EndNodeSection.java index 6d98d2bbcb..461833a1ee 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/EndNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/EndNodeSection.java @@ -90,4 +90,10 @@ public String render(Node node, String varName) { return sb.toString(); } + @Override + public List getImports() { + return List.of("java.util.stream.Stream", "java.util.stream.Collectors", + "org.springframework.ai.chat.prompt.PromptTemplate"); + } + } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/HttpNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/HttpNodeSection.java index 23be5108e6..ac40e57073 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/HttpNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/HttpNodeSection.java @@ -16,6 +16,7 @@ package com.alibaba.cloud.ai.studio.admin.generator.service.generator.workflow.sections; +import java.util.List; import java.util.Map; import com.alibaba.cloud.ai.graph.node.HttpNode; @@ -126,4 +127,9 @@ private NodeAction wrapperHttpNodeAction(NodeAction httpNodeAction, String varNa }; } + @Override + public List getImports() { + return List.of("com.alibaba.cloud.ai.graph.node.HttpNode", "org.springframework.http.HttpMethod"); + } + } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/HumanNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/HumanNodeSection.java index 243eac3824..8bcf06f607 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/HumanNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/HumanNodeSection.java @@ -82,4 +82,9 @@ public String render(Node node, String varName) { return sb.toString(); } + @Override + public List getImports() { + return List.of(); + } + } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/IterationNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/IterationNodeSection.java index 6cc80124fe..b9ce707318 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/IterationNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/IterationNodeSection.java @@ -51,6 +51,11 @@ public String renderEdges(IterationNodeData nodeData, List edges) { return ""; } + @Override + public List getImports() { + return List.of("java.util.ArrayList", "java.util.Arrays"); + } + // 规定迭代节点的start为iterationVarName_start,end为iterationVarName_end @Component @@ -137,6 +142,11 @@ private NodeAction createIterationStartAction( """; } + @Override + public List getImports() { + return List.of("java.util.ArrayList", "java.util.Arrays"); + } + } @Component @@ -195,6 +205,11 @@ private NodeAction createIterationEndAction(String flagKey, String resultKey, St """; } + @Override + public List getImports() { + return List.of("java.util.ArrayList", "java.util.Arrays"); + } + } } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/KnowledgeRetrievalNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/KnowledgeRetrievalNodeSection.java index d65c19be25..a9911beb47 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/KnowledgeRetrievalNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/KnowledgeRetrievalNodeSection.java @@ -234,4 +234,16 @@ public List resourceFiles(DSLDialectType dialectType, KnowledgeRet .orElse(NodeSection.super.resourceFiles(dialectType, nodeData)); } + @Override + public List getImports() { + return List.of("com.alibaba.cloud.ai.graph.node.KnowledgeRetrievalNode", + "org.springframework.ai.embedding.EmbeddingModel", "org.springframework.ai.reader.TextReader", + "org.springframework.ai.transformer.splitter.TokenTextSplitter", + "org.springframework.ai.vectorstore.SimpleVectorStore", + "org.springframework.ai.vectorstore.VectorStore", "org.springframework.beans.factory.annotation.Value", + "org.springframework.core.io.Resource", "org.springframework.ai.document.Document", + "org.springframework.beans.factory.annotation.Autowired", "org.springframework.core.io.ResourceLoader", + "java.util.Optional"); + } + } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/LLMNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/LLMNodeSection.java index 74e1f7ae9e..2e61c9229d 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/LLMNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/LLMNodeSection.java @@ -25,6 +25,8 @@ import com.alibaba.cloud.ai.studio.admin.generator.utils.ObjectToCodeUtil; import org.springframework.stereotype.Component; +import java.util.List; + // TODO:支持异常分支、支持DashScope平台以外其他模型、Dify的结构化输出 @Component public class LLMNodeSection implements NodeSection { @@ -148,4 +150,15 @@ else if (errorNextNode != null) { : "Map.of(outputKeyPrefix + \"output\", defaultOutput, outputKeyPrefix + \"reasoning_content\", defaultOutput)"); } + @Override + public List getImports() { + return List.of("org.springframework.ai.chat.messages.Message", + "org.springframework.ai.chat.messages.AssistantMessage", + "org.springframework.ai.chat.messages.MessageType", + "org.springframework.ai.chat.messages.SystemMessage", + "org.springframework.ai.chat.messages.UserMessage", + "com.alibaba.cloud.ai.dashscope.chat.DashScopeChatOptions", + "org.springframework.beans.factory.annotation.Autowired", "java.util.Optional"); + } + } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/ListOperatorNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/ListOperatorNodeSection.java index 0d81e7978c..0215c3d8dd 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/ListOperatorNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/ListOperatorNodeSection.java @@ -23,6 +23,8 @@ import org.springframework.stereotype.Component; +import java.util.List; + @Component public class ListOperatorNodeSection implements NodeSection { @@ -107,4 +109,9 @@ public String render(Node node, String varName) { return sb.toString(); } + @Override + public List getImports() { + return List.of("com.alibaba.cloud.ai.graph.node.ListOperatorNode", "java.util.Comparator"); + } + } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/MCPNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/MCPNodeSection.java index 8f0b332fa7..a8c58e6b1e 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/MCPNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/MCPNodeSection.java @@ -90,4 +90,9 @@ public String render(Node node, String varName) { return sb.toString(); } + @Override + public List getImports() { + return List.of(); + } + } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/MiddleOutputSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/MiddleOutputSection.java index 47f4735d3f..63a18484ce 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/MiddleOutputSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/MiddleOutputSection.java @@ -24,6 +24,8 @@ import com.alibaba.cloud.ai.studio.admin.generator.utils.ObjectToCodeUtil; import org.springframework.stereotype.Component; +import java.util.List; + @Component public class MiddleOutputSection implements NodeSection { @@ -67,4 +69,9 @@ private NodeAction createMiddleOutputNodeAction(String outputTemplate, List getImports() { + return List.of("java.util.stream.Collectors", "org.springframework.ai.chat.prompt.PromptTemplate"); + } + } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/ParameterParsingNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/ParameterParsingNodeSection.java index 25b4e58639..00f0f81809 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/ParameterParsingNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/ParameterParsingNodeSection.java @@ -103,4 +103,9 @@ private NodeAction wrapperParameterNodeAction(NodeAction nodeAction, String node }; } + @Override + public List getImports() { + return List.of("com.alibaba.cloud.ai.graph.node.ParameterParsingNode", "java.util.stream.Collectors"); + } + } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/QuestionClassifierNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/QuestionClassifierNodeSection.java index 4eb5ba1df8..1fd035bc69 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/QuestionClassifierNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/QuestionClassifierNodeSection.java @@ -128,4 +128,10 @@ public String renderEdges(QuestionClassifierNodeData nodeData, List edges) return sb.toString(); } + @Override + public List getImports() { + return List.of("com.alibaba.cloud.ai.graph.node.QuestionClassifierNode", + "static com.alibaba.cloud.ai.graph.action.AsyncEdgeAction.edge_async"); + } + } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/StartNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/StartNodeSection.java index 80b5ec5cb5..4649c1a3d3 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/StartNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/StartNodeSection.java @@ -53,4 +53,9 @@ public String renderEdges(StartNodeData nodeData, List edges) { return sb.toString(); } + @Override + public List getImports() { + return List.of(); + } + } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/TemplateTransformNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/TemplateTransformNodeSection.java index 94b3023545..e11dc6c224 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/TemplateTransformNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/TemplateTransformNodeSection.java @@ -23,6 +23,8 @@ import org.springframework.stereotype.Component; +import java.util.List; + @Component public class TemplateTransformNodeSection implements NodeSection { @@ -60,4 +62,9 @@ public String render(Node node, String varName) { return sb.toString(); } + @Override + public List getImports() { + return List.of("com.alibaba.cloud.ai.graph.node.TemplateTransformNode"); + } + } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/ToolNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/ToolNodeSection.java index 74a085b6ff..6f54f65053 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/ToolNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/ToolNodeSection.java @@ -82,4 +82,10 @@ public String render(Node node, String varName) { return sb.toString(); } + @Override + public List getImports() { + return List.of("com.alibaba.cloud.ai.graph.node.ToolNode", "java.util.function.Function", + "org.springframework.ai.tool.function.FunctionToolCallback"); + } + } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/VariableAggregatorNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/VariableAggregatorNodeSection.java index bcc72efe60..a4d04572dc 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/VariableAggregatorNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/VariableAggregatorNodeSection.java @@ -168,4 +168,9 @@ else if (object instanceof List list) { }; } + @Override + public List getImports() { + return List.of("com.alibaba.cloud.ai.graph.node.VariableAggregatorNode", "java.util.stream.Collectors"); + } + }