From 0759f72e226a0381ff551a8662aeba9ea3f4eede Mon Sep 17 00:00:00 2001 From: VLSMB <2047857654@qq.com> Date: Sat, 13 Sep 2025 21:59:02 +0800 Subject: [PATCH 01/12] refactor: prepare for BranchNodeDataConverter for Studio DSL --- .../admin/generator/model/VariableType.java | 2 + .../admin/generator/model/workflow/Case.java | 50 ++- .../workflow/ComparisonOperatorType.java | 155 +++++---- .../model/workflow/LogicalOperatorType.java | 18 +- .../generator/model/workflow/NodeType.java | 1 + .../workflow/nodedata/BranchNodeData.java | 16 +- .../nodedata/ListOperatorNodeData.java | 3 +- .../generator/service/dsl/DSLDialectType.java | 1 + .../converter/BranchNodeDataConverter.java | 41 +-- .../workflow/sections/BranchNodeSection.java | 56 ++-- .../sections/SaaBranchNodeDecisionTest.java | 294 ------------------ 11 files changed, 216 insertions(+), 421 deletions(-) delete mode 100644 spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/test/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/SaaBranchNodeDecisionTest.java diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/VariableType.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/VariableType.java index 43d1c943da..ad0bb2f94f 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/VariableType.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/VariableType.java @@ -31,6 +31,8 @@ public enum VariableType { // TODO:定义文件类型对象,以实现工作流直接使用文件 FILE("File", Object.class, "file", "File"), + ARRAY("Array", Object.class, "array", "Array"), + ARRAY_STRING("String[]", String[].class, "array[string]", "String[]"), ARRAY_NUMBER("Number[]", Number[].class, "array[number]", "Number[]"), diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/Case.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/Case.java index 69d193bea8..6e339f5c8b 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/Case.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/Case.java @@ -18,6 +18,7 @@ import java.util.List; import com.alibaba.cloud.ai.studio.admin.generator.model.VariableSelector; +import com.alibaba.cloud.ai.studio.admin.generator.model.VariableType; public class Case { @@ -56,28 +57,51 @@ public Case setConditions(List conditions) { public static class Condition { - private String value; - - private String varType; + private VariableType varType; private ComparisonOperatorType comparisonOperator; - private VariableSelector variableSelector; + private VariableSelector targetSelector; + + private String referenceValue; + + private VariableSelector referenceSelector; + // 参考值可能来自stateKey,也有可能直接是常量值 public String getValue() { - return value; + if (referenceValue != null) { + return referenceValue; + } + else if (referenceSelector != null) { + return String.format("((%s) state.value(\"%s\").orElse(null))", varType.value(), + referenceSelector.getNameInCode()); + } + throw new IllegalStateException("referenceValue or referenceSelector must be set"); + } + + public String getReferenceValue() { + return referenceValue; + } + + public Condition setReferenceValue(String referenceValue) { + this.referenceValue = referenceValue; + return this; + } + + public VariableSelector getReferenceSelector() { + return referenceSelector; } - public Condition setValue(String value) { - this.value = value; + public Condition setReferenceSelector(VariableSelector referenceSelector) { + this.referenceSelector = referenceSelector; return this; } - public String getVarType() { + public VariableType getVarType() { return varType; } - public Condition setVarType(String varType) { + public Condition setVarType(VariableType varType) { this.varType = varType; return this; } @@ -91,12 +115,12 @@ public Condition setComparisonOperator(ComparisonOperatorType comparisonOperator return this; } - public VariableSelector getVariableSelector() { - return variableSelector; + public VariableSelector getTargetSelector() { + return targetSelector; } - public Condition setVariableSelector(VariableSelector variableSelector) { - this.variableSelector = variableSelector; + public Condition setTargetSelector(VariableSelector targetSelector) { + this.targetSelector = targetSelector; return this; } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/ComparisonOperatorType.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/ComparisonOperatorType.java index 277cb17a54..eef73c7ddb 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/ComparisonOperatorType.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/ComparisonOperatorType.java @@ -16,73 +16,132 @@ package com.alibaba.cloud.ai.studio.admin.generator.model.workflow; +import com.alibaba.cloud.ai.studio.admin.generator.model.VariableType; +import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.DSLDialectType; + import java.util.List; import java.util.function.BiFunction; +import java.util.function.Function; public enum ComparisonOperatorType { - CONTAINS("contains", "contains", List.of(String.class, List.class), - (objName, constVal) -> String.format("(%s.contains(%s))", objName, constVal)), - NOT_CONTAINS("not_contains", "not contains", List.of(String.class, List.class), - (objName, constVal) -> String.format("!(%s.contains(%s))", objName, constVal)), - START_WITH("start_with", "start with", List.of(String.class), - (objName, constVal) -> String.format("(%s.startsWith(%s))", objName, constVal)), - END_WITH("end_with", "end with", List.of(String.class), - (objName, constVal) -> String.format("(%s.endsWith(%s))", objName, constVal)), - IS("is", "is", List.of(String.class, List.class), - (objName, constVal) -> String.format("(%s.equals(%s))", objName, constVal)), - IS_NOT("is_not", "is not", List.of(String.class, List.class), - (objName, constVal) -> String.format("!(%s.equals(%s))", objName, constVal)), - EMPTY("empty", "empty", List.of(String.class, List.class), - (objName, constVal) -> String.format("(%s.isEmpty())", objName)), - NOT_EMPTY("not empty", "not empty", List.of(String.class, List.class), - (objName, constVal) -> String.format("!(%s.isEmpty())", objName)), - IN("in", "in", List.of(String.class, List.class), - (objName, constVal) -> String.format("(%s.contains(%s))", constVal, objName)), - NOT_IN("not_in", "not in", List.of(String.class, List.class), - (objName, constVal) -> String.format("!(%s.contains(%s))", constVal, objName)), - ALL_OF("all_of", "all of", List.of(List.class), + CONTAINS("contains", type -> switch (type) { + case DIFY -> "contains"; + default -> "unknown"; + }, List.of(VariableType.STRING, VariableType.ARRAY, VariableType.ARRAY_OBJECT, VariableType.ARRAY_STRING, + VariableType.ARRAY_NUMBER), (objName, constVal) -> String.format("(%s.contains(%s))", objName, constVal)), + NOT_CONTAINS("not_contains", type -> switch (type) { + case DIFY -> "not contains"; + default -> "unknown"; + }, List.of(VariableType.STRING, VariableType.ARRAY, VariableType.ARRAY_OBJECT, VariableType.ARRAY_STRING, + VariableType.ARRAY_NUMBER), (objName, constVal) -> String.format("!(%s.contains(%s))", objName, constVal)), + START_WITH("start_with", type -> switch (type) { + case DIFY -> "start with"; + default -> "unknown"; + }, List.of(VariableType.STRING), (objName, constVal) -> String.format("(%s.startsWith(%s))", objName, constVal)), + END_WITH("end_with", type -> switch (type) { + case DIFY -> "end with"; + default -> "unknown"; + }, List.of(VariableType.STRING), (objName, constVal) -> String.format("(%s.endsWith(%s))", objName, constVal)), + IS("is", type -> switch (type) { + case DIFY -> "is"; + default -> "unknown"; + }, List.of(VariableType.STRING, VariableType.ARRAY, VariableType.ARRAY_OBJECT, VariableType.ARRAY_STRING, + VariableType.ARRAY_NUMBER), (objName, constVal) -> String.format("(%s.equals(%s))", objName, constVal)), + IS_NOT("is_not", type -> switch (type) { + case DIFY -> "is not"; + default -> "unknown"; + }, List.of(VariableType.STRING, VariableType.ARRAY, VariableType.ARRAY_OBJECT, VariableType.ARRAY_STRING, + VariableType.ARRAY_NUMBER), (objName, constVal) -> String.format("!(%s.equals(%s))", objName, constVal)), + EMPTY("empty", type -> switch (type) { + case DIFY -> "empty"; + default -> "unknown"; + }, List.of(VariableType.STRING, VariableType.ARRAY, VariableType.ARRAY_OBJECT, VariableType.ARRAY_STRING, + VariableType.ARRAY_NUMBER), (objName, constVal) -> String.format("(%s.isEmpty())", objName)), + NOT_EMPTY("not_empty", type -> switch (type) { + case DIFY -> "not empty"; + default -> "unknown"; + }, List.of(VariableType.STRING, VariableType.ARRAY, VariableType.ARRAY_OBJECT, VariableType.ARRAY_STRING, + VariableType.ARRAY_NUMBER), (objName, constVal) -> String.format("!(%s.isEmpty())", objName)), + IN("in", type -> switch (type) { + case DIFY -> "in"; + default -> "unknown"; + }, List.of(VariableType.STRING, VariableType.ARRAY, VariableType.ARRAY_OBJECT, VariableType.ARRAY_STRING, + VariableType.ARRAY_NUMBER), (objName, constVal) -> String.format("(%s.contains(%s))", constVal, objName)), + NOT_IN("not_in", type -> switch (type) { + case DIFY -> "not in"; + default -> "unknown"; + }, List.of(VariableType.STRING, VariableType.ARRAY, VariableType.ARRAY_OBJECT, VariableType.ARRAY_STRING, + VariableType.ARRAY_NUMBER), (objName, constVal) -> String.format("!(%s.contains(%s))", constVal, objName)), + ALL_OF("all_of", type -> switch (type) { + case DIFY -> "all of"; + default -> "unknown"; + }, List.of(VariableType.ARRAY, VariableType.ARRAY_OBJECT, VariableType.ARRAY_STRING, VariableType.ARRAY_NUMBER), (objName, constVal) -> String.format("(%s.containsAll(%s))", objName, constVal)), - EQUAL("equal", "=", List.of(Number.class), + EQUAL("equal", type -> switch (type) { + case DIFY -> "="; + default -> "unknown"; + }, List.of(VariableType.NUMBER), (objName, constVal) -> String.format("(%s.doubleValue() == %s)", objName, constVal)), - NOT_EQUAL("not_equal", "≠", List.of(Number.class), + NOT_EQUAL("not_equal", type -> switch (type) { + case DIFY -> "≠"; + default -> "unknown"; + }, List.of(VariableType.NUMBER), (objName, constVal) -> String.format("(%s.doubleValue() != %s)", objName, constVal)), - GREATER_THAN("greater_than", ">", List.of(Number.class), + GREATER_THAN("greater_than", type -> switch (type) { + case DIFY -> ">"; + default -> "unknown"; + }, List.of(VariableType.NUMBER), (objName, constVal) -> String.format("(%s.doubleValue() > %s)", objName, constVal)), - LESS_THAN("less_than", "<", List.of(Number.class), + LESS_THAN("less_than", type -> switch (type) { + case DIFY -> "<"; + default -> "unknown"; + }, List.of(VariableType.NUMBER), (objName, constVal) -> String.format("(%s.doubleValue() < %s)", objName, constVal)), - NOT_LESS_THAN("not_less_than", "≥", List.of(Number.class), + NOT_LESS_THAN("not_less_than", type -> switch (type) { + case DIFY -> "≥"; + default -> "unknown"; + }, List.of(VariableType.NUMBER), (objName, constVal) -> String.format("(%s.doubleValue() >= %s)", objName, constVal)), - NOT_GREATER_THAN("not_greater_than", "≤", List.of(Number.class), + NOT_GREATER_THAN("not_greater_than", type -> switch (type) { + case DIFY -> "≤"; + default -> "unknown"; + }, List.of(VariableType.NUMBER), (objName, constVal) -> String.format("(%s.doubleValue() <= %s)", objName, constVal)), - NULL("null", "null", List.of(String.class, List.class, Number.class, Object.class), - (objName, constVal) -> String.format("(%s == null)", objName)), - NOT_NULL("not_null", "not null", List.of(String.class, List.class, Number.class, Object.class), - (objName, constVal) -> String.format("(%s != null)", objName)),; + NULL("null", type -> switch (type) { + case DIFY -> "null"; + default -> "unknown"; + }, List.of(VariableType.values()), (objName, constVal) -> String.format("(%s == null)", objName)), + NOT_NULL("not_null", type -> switch (type) { + case DIFY -> "not null"; + default -> "unknown"; + }, List.of(VariableType.values()), (objName, constVal) -> String.format("(%s != null)", objName)); private final String value; - private final String difyValue; + private final Function dslValueFunc; - private final List> supportedClassList; + private final List supportedTypes; private final BiFunction toJavaExpression; - ComparisonOperatorType(String value, String difyValue, List> supportedClassList, - BiFunction toJavaExpression) { + ComparisonOperatorType(String value, Function dslValueFunc, + List supportedTypes, BiFunction toJavaExpression) { this.value = value; - this.difyValue = difyValue; - this.supportedClassList = supportedClassList; + this.dslValueFunc = dslValueFunc; + this.supportedTypes = supportedTypes; this.toJavaExpression = toJavaExpression; } - public static ComparisonOperatorType fromDifyValue(String DifyValue) { + public static ComparisonOperatorType fromDslValue(DSLDialectType dialectType, String dslValue, + VariableType variableType) { for (ComparisonOperatorType comparisonOperatorType : ComparisonOperatorType.values()) { - if (comparisonOperatorType.difyValue.equals(DifyValue)) { + if (comparisonOperatorType.dslValueFunc.apply(dialectType).equals(dslValue) + && comparisonOperatorType.supportedTypes.contains(variableType)) { return comparisonOperatorType; } } - throw new IllegalArgumentException("Not support difyValue:" + DifyValue); + throw new IllegalArgumentException("Not support dslValue:" + dslValue); } public String convert(String objName, String constValue) { @@ -93,20 +152,8 @@ public String getValue() { return value; } - public String getDifyValue() { - return difyValue; - } - - public List> getSupportedClassList() { - return supportedClassList; - } - - public BiFunction getToJavaExpression() { - return toJavaExpression; - } - - public boolean isSupported(Class clazz) { - return this.supportedClassList.contains(clazz); + public String getDslValue(DSLDialectType dialectType) { + return dslValueFunc.apply(dialectType); } } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/LogicalOperatorType.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/LogicalOperatorType.java index 5a2ba1c2ee..917c1a269a 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/LogicalOperatorType.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/LogicalOperatorType.java @@ -18,32 +18,32 @@ public enum LogicalOperatorType { - AND("&&", "and"), OR("||", "or"); + AND("and", "&&"), OR("or", "||"); private final String value; - private final String difyValue; + private final String codeValue; - LogicalOperatorType(String value, String difyValue) { + LogicalOperatorType(String value, String codeValue) { this.value = value; - this.difyValue = difyValue; + this.codeValue = codeValue; } - public static LogicalOperatorType fromDifyValue(String difyValue) { + public static LogicalOperatorType fromValue(String value) { for (LogicalOperatorType logicalOperatorType : LogicalOperatorType.values()) { - if (logicalOperatorType.difyValue.equals(difyValue)) { + if (logicalOperatorType.value.equals(value)) { return logicalOperatorType; } } - throw new IllegalArgumentException("Unsupported logical operator type: " + difyValue); + throw new IllegalArgumentException("Unsupported logical operator type: " + value); } public String getValue() { return this.value; } - public String getDifyValue() { - return this.difyValue; + public String getCodeValue() { + return this.codeValue; } } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/NodeType.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/NodeType.java index 6fefead060..a287481342 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/NodeType.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/NodeType.java @@ -18,6 +18,7 @@ import java.util.Arrays; import java.util.Optional; +// TODO: 将枚举类的DSL Value字段改为Function public enum NodeType { START("start", "start", "Start"), diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/BranchNodeData.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/BranchNodeData.java index 39e151343d..c94a6f0ba5 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/BranchNodeData.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/BranchNodeData.java @@ -26,12 +26,7 @@ public class BranchNodeData extends NodeData { private List cases; - public BranchNodeData() { - } - - public BranchNodeData(List inputs, List outputs) { - super(inputs, outputs); - } + private String defaultCase; public List getCases() { return cases; @@ -42,4 +37,13 @@ public BranchNodeData setCases(List cases) { return this; } + public String getDefaultCase() { + return defaultCase; + } + + public BranchNodeData setDefaultCase(String defaultCase) { + this.defaultCase = defaultCase; + return this; + } + } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/ListOperatorNodeData.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/ListOperatorNodeData.java index 7a9593000d..bb4e78a204 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/ListOperatorNodeData.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/ListOperatorNodeData.java @@ -23,6 +23,7 @@ import com.alibaba.cloud.ai.studio.admin.generator.model.VariableType; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.ComparisonOperatorType; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.NodeData; +import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.DSLDialectType; /** * NodeData for ListOperatorNode, which contains all the configurable properties in the @@ -56,7 +57,7 @@ public static FilterCondition ofDify(String condition, String value) { return null; } for (ComparisonOperatorType e : ComparisonOperatorType.values()) { - if (e.getDifyValue().equalsIgnoreCase(condition)) { + if (e.getDslValue(DSLDialectType.DIFY).equalsIgnoreCase(condition)) { return new FilterCondition(e, value); } } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/DSLDialectType.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/DSLDialectType.java index 3e036c0d5a..822f6a3ebe 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/DSLDialectType.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/DSLDialectType.java @@ -18,6 +18,7 @@ import java.util.Arrays; import java.util.Optional; +// TODO: 移动到model包中 public enum DSLDialectType { DIFY("dify", ".yml"), diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/BranchNodeDataConverter.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/BranchNodeDataConverter.java index ad87859c12..37860a5e35 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/BranchNodeDataConverter.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/BranchNodeDataConverter.java @@ -16,7 +16,6 @@ package com.alibaba.cloud.ai.studio.admin.generator.service.dsl.converter; import java.util.ArrayList; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.function.BiConsumer; @@ -76,39 +75,25 @@ public BranchNodeData parse(Map data) { String difyVarType = (String) conditionMap.get("varType"); VariableType variableType = VariableType.fromDifyValue(difyVarType) .orElse(VariableType.OBJECT); - return new Case.Condition().setValue((String) conditionMap.get("value")) - .setVarType(variableType.value()) - .setComparisonOperator(ComparisonOperatorType - .fromDifyValue((String) conditionMap.get("comparison_operator"))) - .setVariableSelector(new VariableSelector(selectors.get(0), selectors.get(1))); + return new Case.Condition().setReferenceValue((String) conditionMap.get("value")) + .setVarType(variableType) + .setComparisonOperator(ComparisonOperatorType.fromDslValue(DSLDialectType.DIFY, + (String) conditionMap.get("comparison_operator"), variableType)) + .setTargetSelector(new VariableSelector(selectors.get(0), selectors.get(1))); }).collect(Collectors.toList()); cases.add(new Case().setId((String) caseData.get("id")) .setLogicalOperator( - LogicalOperatorType.fromDifyValue((String) caseData.get("logical_operator"))) + LogicalOperatorType.fromValue((String) caseData.get("logical_operator"))) .setConditions(conditions)); } } - return new BranchNodeData(List.of(), List.of()).setCases(cases); + return new BranchNodeData().setCases(cases).setDefaultCase("false"); } @Override public Map dump(BranchNodeData nodeData) { - Map data = new HashMap<>(); - List> caseMaps = nodeData.getCases().stream().map(c -> { - List> conditions = c.getConditions() - .stream() - .map(condition -> Map.of("comparison_operator", - condition.getComparisonOperator().getDifyValue(), "value", condition.getValue(), - "varType", condition.getVarType(), "variable_selector", - List.of(condition.getVariableSelector().getNamespace(), - condition.getVariableSelector().getName()))) - .toList(); - return Map.of("id", c.getId(), "case_id", c.getId(), "conditions", conditions, "logical_operator", - c.getLogicalOperator().getDifyValue()); - }).toList(); - data.put("cases", caseMaps); - return data; + throw new UnsupportedOperationException(); } }), @@ -139,13 +124,19 @@ public Stream extractWorkflowVars(BranchNodeData data) { @Override public BiConsumer> postProcessConsumer(DSLDialectType dialectType) { return switch (dialectType) { - case DIFY -> super.postProcessConsumer(dialectType).andThen((nodeData, idToVarName) -> { + case DIFY, STUDIO -> super.postProcessConsumer(dialectType).andThen((nodeData, idToVarName) -> { // 处理条件里的VariableSelector nodeData.getCases().forEach(c -> { c.getConditions().forEach(condition -> { - VariableSelector selector = condition.getVariableSelector(); + VariableSelector selector = condition.getTargetSelector(); selector.setNameInCode(idToVarName.getOrDefault(selector.getNamespace(), "unknown") + "_" + selector.getName()); + VariableSelector referenceSelector = condition.getReferenceSelector(); + if (referenceSelector != null) { + referenceSelector + .setNameInCode(idToVarName.getOrDefault(referenceSelector.getNamespace(), "unknown") + + "_" + referenceSelector.getName()); + } }); }); }); diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/BranchNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/BranchNodeSection.java index 384fe9d2c9..21700f47b1 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/BranchNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/BranchNodeSection.java @@ -18,9 +18,11 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; import java.util.stream.Stream; +import com.alibaba.cloud.ai.studio.admin.generator.model.VariableType; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.Case; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.Edge; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.Node; @@ -31,7 +33,6 @@ import org.springframework.stereotype.Component; -// TODO: 对于Dify的条件渲染,将CaseID格式化为比较易懂的格式 @Component public class BranchNodeSection implements NodeSection { @@ -55,18 +56,31 @@ public String render(Node node, String varName) { @Override public String renderEdges(BranchNodeData branchNodeData, List edges) { + // 此处规定Edge的sourceHandle为caseId,前面的转化需要符合这条规则 String srcVar = branchNodeData.getVarName(); StringBuilder sb = new StringBuilder(); List cases = branchNodeData.getCases(); + // 维护一个caseId到caseName的映射 + AtomicInteger count = new AtomicInteger(1); + Map caseIdToName = cases.stream() + .map(Case::getId) + .collect(Collectors.toUnmodifiableMap(id -> id, id -> { + // 如果一些节点的caseId本身就有含义,直接使用 + if (id.equalsIgnoreCase("default") || id.equalsIgnoreCase("true") || id.equalsIgnoreCase("false")) { + return id; + } + return "case_" + (count.getAndIncrement()); + })); + // 构造EdgeAction.apply函数 StringBuilder conditionsBuffer = new StringBuilder(); for (Case c : cases) { - String logicalOperator = " " + c.getLogicalOperator().getValue() + " "; + String logicalOperator = " " + c.getLogicalOperator().getCodeValue() + " "; List expressions = c.getConditions().stream().map(condition -> { String constValue = condition.getValue(); - if (condition.getVarType().equalsIgnoreCase("String") - || condition.getVarType().equalsIgnoreCase("file")) { + if (condition.getReferenceValue() != null && (VariableType.STRING.equals(condition.getVarType()) + || VariableType.FILE.equals(condition.getVarType()))) { constValue = "\"" + constValue + "\""; } @@ -78,15 +92,16 @@ public String renderEdges(BranchNodeData branchNodeData, List edges) { // 组合复合条件 conditionsBuffer.append(String.join(logicalOperator, expressions)); conditionsBuffer.append(") {\n"); - conditionsBuffer.append(String.format("return \"%s\";", c.getId())); + conditionsBuffer.append(String.format("return \"%s\";", caseIdToName.get(c.getId()))); conditionsBuffer.append("}\n"); } // 最后需要加上else的结果 - conditionsBuffer.append("return \"false\";"); + conditionsBuffer.append(String.format("return \"%s\";", branchNodeData.getDefaultCase())); // 构建Map Map edgeCaseMap = edges.stream() - .collect(Collectors.toMap(Edge::getSourceHandle, Edge::getTarget)); + .collect(Collectors.toMap(e -> caseIdToName.getOrDefault(e.getSourceHandle(), e.getSourceHandle()), + Edge::getTarget)); String edgeCaseMapStr = "Map.of(" + edgeCaseMap.entrySet() .stream() .flatMap(e -> Stream.of(e.getKey(), e.getValue())) @@ -100,19 +115,19 @@ public String renderEdges(BranchNodeData branchNodeData, List edges) { .append(conditionsBuffer) .append("}), ") .append(edgeCaseMapStr) - .append(");\n"); + .append(");\n\n"); return sb.toString(); } private String generateSafeVariableAccess(Case.Condition condition) { - String varType = condition.getVarType(); + VariableType varType = condition.getVarType(); String variablePath = buildVariablePath(condition); - switch (varType.toLowerCase()) { - case "file": + switch (varType) { + case FILE: // 支持从 VariableSelector 中获取属性路径 - VariableSelector selector = condition.getVariableSelector(); + VariableSelector selector = condition.getTargetSelector(); boolean accessExtension = selector != null && (selector.getLabel() != null && selector.getLabel().contains("extension") || selector.getName() != null && selector.getName().contains("extension")); @@ -129,18 +144,21 @@ private String generateSafeVariableAccess(Case.Condition condition) { + "return dotIndex > 0 ? name.substring(dotIndex) : \"\"; " + "}).orElse(\"\")", variablePath); } - case "string": + case STRING: return String.format("state.value(\"%s\", String.class).orElse(\"\")", variablePath); - case "number": + case NUMBER: return String.format("state.value(\"%s\", Number.class).orElse(0)", variablePath); - case "boolean": + case BOOLEAN: return String.format("state.value(\"%s\", Boolean.class).orElse(false)", variablePath); - case "list": - case "array": + case ARRAY_FILE: + case ARRAY_NUMBER: + case ARRAY_STRING: + case ARRAY_OBJECT: + case ARRAY: return String.format( "state.value(\"%s\", java.util.List.class).orElse(java.util.Collections.emptyList())", variablePath); - case "object": + case OBJECT: return String.format("state.value(\"%s\", Object.class).orElse(null)", variablePath); default: // 使用默认的类型 @@ -149,7 +167,7 @@ private String generateSafeVariableAccess(Case.Condition condition) { } private String buildVariablePath(Case.Condition condition) { - VariableSelector variableSelector = condition.getVariableSelector(); + VariableSelector variableSelector = condition.getTargetSelector(); if (variableSelector == null) { return "unknown"; } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/test/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/SaaBranchNodeDecisionTest.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/test/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/SaaBranchNodeDecisionTest.java deleted file mode 100644 index c2c143389a..0000000000 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/test/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/SaaBranchNodeDecisionTest.java +++ /dev/null @@ -1,294 +0,0 @@ -/* - * Copyright 2024-2025 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.alibaba.cloud.ai.studio.admin.generator.service.generator.workflow.sections; - -import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.Node; -import com.alibaba.cloud.ai.studio.admin.generator.model.VariableSelector; -import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.ComparisonOperatorType; -import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.Case; -import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.LogicalOperatorType; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.DisplayName; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Nested; - -import java.util.HashMap; -import java.util.Map; -import java.util.List; -import java.util.Arrays; - -import static org.junit.jupiter.api.Assertions.*; - -public class SaaBranchNodeDecisionTest { - - private Map nodeMap; - - @BeforeEach - public void setUp() { - nodeMap = new HashMap<>(); - - Node fileReaderNode = new Node(); - fileReaderNode.setId("node-1725861677895"); - fileReaderNode.setTitle("File Reader"); - nodeMap.put("node-1725861677895", fileReaderNode); - - Node numberNode = new Node(); - numberNode.setId("node-1725861677896"); - numberNode.setTitle("Number Processor"); - nodeMap.put("node-1725861677896", numberNode); - } - - @Nested - @DisplayName("ComparisonOperatorType Tests") - class ComparisonOperatorTypeTests { - - @Test - @DisplayName("Test numeric comparison operators") - public void testAllNumericComparisonOperators() { - ComparisonOperatorType[] operators = { ComparisonOperatorType.EQUAL, ComparisonOperatorType.NOT_EQUAL, - ComparisonOperatorType.GREATER_THAN, ComparisonOperatorType.LESS_THAN, - ComparisonOperatorType.NOT_LESS_THAN, ComparisonOperatorType.NOT_GREATER_THAN }; - - String[] expectedPatterns = { "(%s.doubleValue() == %s)", "(%s.doubleValue() != %s)", - "(%s.doubleValue() > %s)", "(%s.doubleValue() < %s)", "(%s.doubleValue() >= %s)", - "(%s.doubleValue() <= %s)" }; - - for (int i = 0; i < operators.length; i++) { - String result = operators[i].convert("numberVar", "42"); - String expected = String.format(expectedPatterns[i], "numberVar", "42"); - - assertEquals(expected, result); - assertTrue(result.contains(".doubleValue()")); - assertFalse(result.matches(".*numberVar\\s*[><=!]+\\s*42.*")); - } - } - - @Test - @DisplayName("Test string operators") - public void testStringOperatorsUnaffected() { - String containsResult = ComparisonOperatorType.CONTAINS.convert("stringVar", "\"test\""); - assertEquals("(stringVar.contains(\"test\"))", containsResult); - - String equalsResult = ComparisonOperatorType.IS.convert("stringVar", "\"value\""); - assertEquals("(stringVar.equals(\"value\"))", equalsResult); - - String startsWithResult = ComparisonOperatorType.START_WITH.convert("stringVar", "\"prefix\""); - assertEquals("(stringVar.startsWith(\"prefix\"))", startsWithResult); - } - - @Test - @DisplayName("Test null check operators") - public void testNullCheckOperators() { - String nullResult = ComparisonOperatorType.NULL.convert("var", "null"); - assertEquals("(var == null)", nullResult); - - String notNullResult = ComparisonOperatorType.NOT_NULL.convert("var", "null"); - assertEquals("(var != null)", notNullResult); - } - - @Test - @DisplayName("Test edge cases") - public void testEdgeCasesAndSpecialValues() { - String negativeResult = ComparisonOperatorType.LESS_THAN.convert("num", "-5.5"); - assertEquals("(num.doubleValue() < -5.5)", negativeResult); - - String zeroResult = ComparisonOperatorType.EQUAL.convert("num", "0"); - assertEquals("(num.doubleValue() == 0)", zeroResult); - } - - } - - @Nested - @DisplayName("BranchNodeSection Tests") - class BranchNodeSectionTests { - - @Test - @DisplayName("Test file type variable access") - public void testFileTypeVariableAccess() { - Case.Condition condition = new Case.Condition(); - condition.setVarType("file"); - - VariableSelector selector = new VariableSelector(); - selector.setNamespace("node-1725861677895"); - selector.setName("uploaded_file"); - condition.setVariableSelector(selector); - - assertEquals("file", condition.getVarType()); - assertEquals("node-1725861677895", condition.getVariableSelector().getNamespace()); - assertEquals("uploaded_file", condition.getVariableSelector().getName()); - } - - @Test - @DisplayName("Test variable type identification") - public void testVariableTypeIdentification() { - String[] varTypes = { "string", "number", "boolean", "list", "object", "file" }; - - for (String varType : varTypes) { - Case.Condition condition = new Case.Condition(); - condition.setVarType(varType); - assertEquals(varType, condition.getVarType()); - } - } - - @Test - @DisplayName("Test multi-condition logic") - public void testMultiConditionLogicalCombination() { - Case testCase = new Case(); - testCase.setLogicalOperator(LogicalOperatorType.AND); - - Case.Condition condition1 = new Case.Condition(); - condition1.setVarType("file"); - condition1.setComparisonOperator(ComparisonOperatorType.IS); - - Case.Condition condition2 = new Case.Condition(); - condition2.setVarType("number"); - condition2.setComparisonOperator(ComparisonOperatorType.LESS_THAN); - - testCase.setConditions(Arrays.asList(condition1, condition2)); - - assertEquals(LogicalOperatorType.AND, testCase.getLogicalOperator()); - assertEquals(2, testCase.getConditions().size()); - } - - } - - @Test - @DisplayName("Test file extension comparison") - public void testOriginalErrorCase_FileExtensionComparison() { - Case.Condition condition = new Case.Condition(); - condition.setVarType("file"); - - VariableSelector selector = new VariableSelector(); - selector.setNamespace("node-1725861677895"); - selector.setName("extension"); - selector.setLabel("extension"); - condition.setVariableSelector(selector); - - condition.setComparisonOperator(ComparisonOperatorType.IS); - - String generatedExpression = ComparisonOperatorType.IS.convert("fileVar", "\".pdf\""); - assertEquals("(fileVar.equals(\".pdf\"))", generatedExpression); - assertFalse(generatedExpression.contains("==")); - } - - @Test - @DisplayName("Test number comparison fix") - public void testOriginalErrorCase_NumberComparison() { - String[] comparisonTests = { ComparisonOperatorType.EQUAL.convert("numberValue", "42"), - ComparisonOperatorType.NOT_EQUAL.convert("numberValue", "42"), - ComparisonOperatorType.GREATER_THAN.convert("numberValue", "42"), - ComparisonOperatorType.LESS_THAN.convert("numberValue", "42"), - ComparisonOperatorType.NOT_LESS_THAN.convert("numberValue", "42"), - ComparisonOperatorType.NOT_GREATER_THAN.convert("numberValue", "42") }; - - String[] expectedResults = { "(numberValue.doubleValue() == 42)", "(numberValue.doubleValue() != 42)", - "(numberValue.doubleValue() > 42)", "(numberValue.doubleValue() < 42)", - "(numberValue.doubleValue() >= 42)", "(numberValue.doubleValue() <= 42)" }; - - for (int i = 0; i < comparisonTests.length; i++) { - assertEquals(expectedResults[i], comparisonTests[i]); - assertTrue(comparisonTests[i].contains(".doubleValue()")); - } - } - - @Test - @DisplayName("Test fix comparison") - public void testBeforeAfterFixComparison() { - String numberComparison = ComparisonOperatorType.EQUAL.convert("numberValue", "42"); - assertFalse(numberComparison.equals("(numberValue == 42)")); - assertTrue(numberComparison.contains(".doubleValue()")); - - String stringComparison = ComparisonOperatorType.IS.convert("stringVar", "\"test\""); - assertEquals("(stringVar.equals(\"test\"))", stringComparison); - } - - @Test - @DisplayName("Test DSL variable selector parsing") - public void testDSLVariableSelectorParsing() { - VariableSelector selector = new VariableSelector(); - selector.setNamespace("node-1725861677895"); - selector.setName("extension"); - - assertEquals("node-1725861677895", selector.getNamespace()); - assertEquals("extension", selector.getName()); - - assertTrue(nodeMap.containsKey("node-1725861677895")); - assertEquals("File Reader", nodeMap.get("node-1725861677895").getTitle()); - } - - @Test - @DisplayName("Test multi-condition logical expression") - public void testMultiConditionLogicalExpression() { - Case testCase = new Case(); - testCase.setLogicalOperator(LogicalOperatorType.AND); - - Case.Condition fileCondition = new Case.Condition(); - fileCondition.setVarType("file"); - VariableSelector fileSelector = new VariableSelector(); - fileSelector.setNamespace("node-1725861677895"); - fileSelector.setName("uploaded_file"); - fileSelector.setLabel("extension"); - fileCondition.setVariableSelector(fileSelector); - fileCondition.setComparisonOperator(ComparisonOperatorType.IS); - - Case.Condition numberCondition = new Case.Condition(); - numberCondition.setVarType("number"); - VariableSelector numberSelector = new VariableSelector(); - numberSelector.setNamespace("node-1725861677896"); - numberSelector.setName("file_size"); - numberCondition.setVariableSelector(numberSelector); - numberCondition.setComparisonOperator(ComparisonOperatorType.LESS_THAN); - - testCase.setConditions(List.of(fileCondition, numberCondition)); - - assertEquals(2, testCase.getConditions().size()); - assertEquals(LogicalOperatorType.AND, testCase.getLogicalOperator()); - - String fileExpression = fileCondition.getComparisonOperator().convert("fileExtension", "\".pdf\""); - String numberExpression = numberCondition.getComparisonOperator().convert("fileSize", "1048576"); - - assertEquals("(fileExtension.equals(\".pdf\"))", fileExpression); - assertEquals("(fileSize.doubleValue() < 1048576)", numberExpression); - assertTrue(numberExpression.contains(".doubleValue()")); - } - - @Test - @DisplayName("Test fix resolves compilation errors") - public void testFixResolvesCompilationErrors() { - ComparisonOperatorType[] numberOperators = { ComparisonOperatorType.EQUAL, ComparisonOperatorType.NOT_EQUAL, - ComparisonOperatorType.GREATER_THAN, ComparisonOperatorType.LESS_THAN, - ComparisonOperatorType.NOT_LESS_THAN, ComparisonOperatorType.NOT_GREATER_THAN }; - - for (ComparisonOperatorType operator : numberOperators) { - String expression = operator.convert("numVar", "123"); - - assertTrue(expression.startsWith("(") && expression.endsWith(")")); - assertTrue(expression.contains(".doubleValue()")); - assertFalse(expression.matches(".*numVar\\s*[><=!]+\\s*123.*")); - } - - String stringExpression = ComparisonOperatorType.CONTAINS.convert("strVar", "\"test\""); - assertEquals("(strVar.contains(\"test\"))", stringExpression); - - String nullExpression = ComparisonOperatorType.NULL.convert("var", "null"); - assertEquals("(var == null)", nullExpression); - - String notNullExpression = ComparisonOperatorType.NOT_NULL.convert("var", "null"); - assertEquals("(var != null)", notNullExpression); - } - -} From 9f84983035b5a215f6f3cb148be52d96d8b22b0c Mon Sep 17 00:00:00 2001 From: VLSMB <2047857654@qq.com> Date: Sun, 14 Sep 2025 23:43:12 +0800 Subject: [PATCH 02/12] feat: support JudgeNode for Studio DSL --- .../admin/generator/model/VariableType.java | 28 +++- .../admin/generator/model/workflow/Case.java | 25 +++- .../workflow/ComparisonOperatorType.java | 105 +++++++++++---- .../generator/model/workflow/NodeType.java | 2 +- .../workflow/nodedata/BranchNodeData.java | 2 - .../converter/BranchNodeDataConverter.java | 122 +++++++++++++++++- .../workflow/sections/BranchNodeSection.java | 16 +-- 7 files changed, 253 insertions(+), 47 deletions(-) diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/VariableType.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/VariableType.java index ad0bb2f94f..2e01bf0110 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/VariableType.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/VariableType.java @@ -16,7 +16,9 @@ package com.alibaba.cloud.ai.studio.admin.generator.model; import java.util.Arrays; +import java.util.List; import java.util.Optional; +import java.util.stream.Stream; public enum VariableType { @@ -33,13 +35,15 @@ public enum VariableType { ARRAY("Array", Object.class, "array", "Array"), - ARRAY_STRING("String[]", String[].class, "array[string]", "String[]"), + ARRAY_BOOLEAN("Boolean[]", Boolean[].class, "array[boolean]", "Array"), - ARRAY_NUMBER("Number[]", Number[].class, "array[number]", "Number[]"), + ARRAY_STRING("String[]", String[].class, "array[string]", "Array"), - ARRAY_OBJECT("Object[]", Object[].class, "array[object]", "Object[]"), + ARRAY_NUMBER("Number[]", Number[].class, "array[number]", "Array"), - ARRAY_FILE("File[]", Object[].class, "file-list", "File[]"); + ARRAY_OBJECT("Object[]", Object[].class, "array[object]", "Array"), + + ARRAY_FILE("File[]", Object[].class, "file-list", "Array"); private final String value; @@ -56,6 +60,22 @@ public enum VariableType { this.studioValue = studioValue; } + public static List all() { + return List.of(values()); + } + + public static List arrays() { + return List.of(ARRAY_BOOLEAN, ARRAY_STRING, ARRAY_NUMBER, ARRAY_OBJECT, ARRAY_FILE, ARRAY); + } + + public static List arraysWithOther(VariableType... other) { + return Stream.concat(Stream.of(other), arrays().stream()).toList(); + } + + public static List except(VariableType... excepted) { + return Stream.of(VariableType.values()).filter(type -> !Arrays.asList(excepted).contains(type)).toList(); + } + public String value() { return value; } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/Case.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/Case.java index 6e339f5c8b..bc6d6b04bb 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/Case.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/Case.java @@ -57,26 +57,36 @@ public Case setConditions(List conditions) { public static class Condition { + // 左值数据类型 private VariableType varType; + // 右值数据类型 + private VariableType referenceType; + private ComparisonOperatorType comparisonOperator; + // 左值 private VariableSelector targetSelector; + // 右值 private String referenceValue; private VariableSelector referenceSelector; - // 参考值可能来自stateKey,也有可能直接是常量值 + // 参考值可能来自stateKey,也有可能直接是常量值,也有可能没有参考值 public String getValue() { if (referenceValue != null) { return referenceValue; } else if (referenceSelector != null) { - return String.format("((%s) state.value(\"%s\").orElse(null))", varType.value(), + if (VariableType.NUMBER.equals(referenceType)) { + return String.format("((%s) state.value(\"%s\").orElse(null)).doubleValue()", referenceType.value(), + referenceSelector.getNameInCode()); + } + return String.format("((%s) state.value(\"%s\").orElse(null))", referenceType.value(), referenceSelector.getNameInCode()); } - throw new IllegalStateException("referenceValue or referenceSelector must be set"); + return null; } public String getReferenceValue() { @@ -106,6 +116,15 @@ public Condition setVarType(VariableType varType) { return this; } + public VariableType getReferenceType() { + return referenceType; + } + + public Condition setReferenceType(VariableType referenceType) { + this.referenceType = referenceType; + return this; + } + public ComparisonOperatorType getComparisonOperator() { return comparisonOperator; } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/ComparisonOperatorType.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/ComparisonOperatorType.java index eef73c7ddb..9298a6afe0 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/ComparisonOperatorType.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/ComparisonOperatorType.java @@ -26,15 +26,16 @@ public enum ComparisonOperatorType { CONTAINS("contains", type -> switch (type) { - case DIFY -> "contains"; + case DIFY, STUDIO -> "contains"; default -> "unknown"; - }, List.of(VariableType.STRING, VariableType.ARRAY, VariableType.ARRAY_OBJECT, VariableType.ARRAY_STRING, - VariableType.ARRAY_NUMBER), (objName, constVal) -> String.format("(%s.contains(%s))", objName, constVal)), + }, VariableType.arraysWithOther(VariableType.STRING), + (objName, constVal) -> String.format("(%s.contains(%s))", objName, constVal)), NOT_CONTAINS("not_contains", type -> switch (type) { case DIFY -> "not contains"; + case STUDIO -> "notContains"; default -> "unknown"; - }, List.of(VariableType.STRING, VariableType.ARRAY, VariableType.ARRAY_OBJECT, VariableType.ARRAY_STRING, - VariableType.ARRAY_NUMBER), (objName, constVal) -> String.format("!(%s.contains(%s))", objName, constVal)), + }, VariableType.arraysWithOther(VariableType.STRING), + (objName, constVal) -> String.format("!(%s.contains(%s))", objName, constVal)), START_WITH("start_with", type -> switch (type) { case DIFY -> "start with"; default -> "unknown"; @@ -45,77 +46,136 @@ public enum ComparisonOperatorType { }, List.of(VariableType.STRING), (objName, constVal) -> String.format("(%s.endsWith(%s))", objName, constVal)), IS("is", type -> switch (type) { case DIFY -> "is"; + case STUDIO -> "equals"; default -> "unknown"; - }, List.of(VariableType.STRING, VariableType.ARRAY, VariableType.ARRAY_OBJECT, VariableType.ARRAY_STRING, - VariableType.ARRAY_NUMBER), (objName, constVal) -> String.format("(%s.equals(%s))", objName, constVal)), + }, VariableType.except(VariableType.NUMBER), + (objName, constVal) -> String.format("(%s.equals(%s))", objName, constVal)), IS_NOT("is_not", type -> switch (type) { case DIFY -> "is not"; + case STUDIO -> "notEquals"; default -> "unknown"; - }, List.of(VariableType.STRING, VariableType.ARRAY, VariableType.ARRAY_OBJECT, VariableType.ARRAY_STRING, - VariableType.ARRAY_NUMBER), (objName, constVal) -> String.format("!(%s.equals(%s))", objName, constVal)), + }, VariableType.except(VariableType.NUMBER), + (objName, constVal) -> String.format("!(%s.equals(%s))", objName, constVal)), EMPTY("empty", type -> switch (type) { case DIFY -> "empty"; default -> "unknown"; - }, List.of(VariableType.STRING, VariableType.ARRAY, VariableType.ARRAY_OBJECT, VariableType.ARRAY_STRING, - VariableType.ARRAY_NUMBER), (objName, constVal) -> String.format("(%s.isEmpty())", objName)), + }, VariableType.arraysWithOther(VariableType.STRING), + (objName, constVal) -> String.format("(%s.isEmpty())", objName)), NOT_EMPTY("not_empty", type -> switch (type) { case DIFY -> "not empty"; default -> "unknown"; - }, List.of(VariableType.STRING, VariableType.ARRAY, VariableType.ARRAY_OBJECT, VariableType.ARRAY_STRING, - VariableType.ARRAY_NUMBER), (objName, constVal) -> String.format("!(%s.isEmpty())", objName)), + }, VariableType.arraysWithOther(VariableType.STRING), + (objName, constVal) -> String.format("!(%s.isEmpty())", objName)), IN("in", type -> switch (type) { case DIFY -> "in"; default -> "unknown"; - }, List.of(VariableType.STRING, VariableType.ARRAY, VariableType.ARRAY_OBJECT, VariableType.ARRAY_STRING, - VariableType.ARRAY_NUMBER), (objName, constVal) -> String.format("(%s.contains(%s))", constVal, objName)), + }, VariableType.arraysWithOther(VariableType.STRING), + (objName, constVal) -> String.format("(%s.contains(%s))", constVal, objName)), NOT_IN("not_in", type -> switch (type) { case DIFY -> "not in"; default -> "unknown"; - }, List.of(VariableType.STRING, VariableType.ARRAY, VariableType.ARRAY_OBJECT, VariableType.ARRAY_STRING, - VariableType.ARRAY_NUMBER), (objName, constVal) -> String.format("!(%s.contains(%s))", constVal, objName)), + }, VariableType.arraysWithOther(VariableType.STRING), + (objName, constVal) -> String.format("!(%s.contains(%s))", constVal, objName)), ALL_OF("all_of", type -> switch (type) { case DIFY -> "all of"; default -> "unknown"; - }, List.of(VariableType.ARRAY, VariableType.ARRAY_OBJECT, VariableType.ARRAY_STRING, VariableType.ARRAY_NUMBER), - (objName, constVal) -> String.format("(%s.containsAll(%s))", objName, constVal)), + }, VariableType.arrays(), (objName, constVal) -> String.format("(%s.containsAll(%s))", objName, constVal)), EQUAL("equal", type -> switch (type) { case DIFY -> "="; + case STUDIO -> "equals"; default -> "unknown"; }, List.of(VariableType.NUMBER), (objName, constVal) -> String.format("(%s.doubleValue() == %s)", objName, constVal)), NOT_EQUAL("not_equal", type -> switch (type) { case DIFY -> "≠"; + case STUDIO -> "notEquals"; default -> "unknown"; }, List.of(VariableType.NUMBER), (objName, constVal) -> String.format("(%s.doubleValue() != %s)", objName, constVal)), GREATER_THAN("greater_than", type -> switch (type) { case DIFY -> ">"; + case STUDIO -> "greater"; default -> "unknown"; }, List.of(VariableType.NUMBER), (objName, constVal) -> String.format("(%s.doubleValue() > %s)", objName, constVal)), LESS_THAN("less_than", type -> switch (type) { case DIFY -> "<"; + case STUDIO -> "less"; default -> "unknown"; }, List.of(VariableType.NUMBER), (objName, constVal) -> String.format("(%s.doubleValue() < %s)", objName, constVal)), NOT_LESS_THAN("not_less_than", type -> switch (type) { case DIFY -> "≥"; + case STUDIO -> "greaterAndEqual"; default -> "unknown"; }, List.of(VariableType.NUMBER), (objName, constVal) -> String.format("(%s.doubleValue() >= %s)", objName, constVal)), NOT_GREATER_THAN("not_greater_than", type -> switch (type) { case DIFY -> "≤"; + case STUDIO -> "lessAndEqual"; default -> "unknown"; }, List.of(VariableType.NUMBER), (objName, constVal) -> String.format("(%s.doubleValue() <= %s)", objName, constVal)), NULL("null", type -> switch (type) { case DIFY -> "null"; + case STUDIO -> "isNull"; default -> "unknown"; - }, List.of(VariableType.values()), (objName, constVal) -> String.format("(%s == null)", objName)), + }, VariableType.all(), (objName, constVal) -> String.format("(%s == null)", objName)), NOT_NULL("not_null", type -> switch (type) { case DIFY -> "not null"; + case STUDIO -> "isNotNull"; default -> "unknown"; - }, List.of(VariableType.values()), (objName, constVal) -> String.format("(%s != null)", objName)); + }, VariableType.all(), (objName, constVal) -> String.format("(%s != null)", objName)), + LENGTH_EQUAL("length_equal", type -> switch (type) { + case STUDIO -> "lengthEquals"; + default -> "unknown"; + }, VariableType.arrays(), (objName, constVal) -> String.format("(%s.size() == %s)", objName, constVal)), + LENGTH_GREATER_THAN("length_greater_than", type -> switch (type) { + case STUDIO -> "lengthGreater"; + default -> "unknown"; + }, VariableType.arrays(), (objName, constVal) -> String.format("(%s.size() > %s)", objName, constVal)), + LENGTH_NOT_LESS_THAN("length_not_less_than", type -> switch (type) { + case STUDIO -> "lengthGreaterAndEqual"; + default -> "unknown"; + }, VariableType.arrays(), (objName, constVal) -> String.format("(%s.size() >= %s)", objName, constVal)), + LENGTH_LESS_THAN("length_less_than", type -> switch (type) { + case STUDIO -> "lengthLess"; + default -> "unknown"; + }, VariableType.arrays(), (objName, constVal) -> String.format("(%s.size() < %s)", objName, constVal)), + LENGTH_NOT_GREATER_THAN("length_not_greater_than", type -> switch (type) { + case STUDIO -> "lengthLessAndEqual"; + default -> "unknown"; + }, VariableType.arrays(), (objName, constVal) -> String.format("(%s.size() <= %s)", objName, constVal)), + + STR_LENGTH_EQUAL("str_length_equal", type -> switch (type) { + case STUDIO -> "lengthEquals"; + default -> "unknown"; + }, List.of(VariableType.STRING), (objName, constVal) -> String.format("(%s.length() == %s)", objName, constVal)), + STR_LENGTH_GREATER_THAN("str_length_greater_than", type -> switch (type) { + case STUDIO -> "lengthGreater"; + default -> "unknown"; + }, List.of(VariableType.STRING), (objName, constVal) -> String.format("(%s.length() > %s)", objName, constVal)), + STR_LENGTH_NOT_LESS_THAN("str_length_not_less_than", type -> switch (type) { + case STUDIO -> "lengthGreaterAndEqual"; + default -> "unknown"; + }, List.of(VariableType.STRING), (objName, constVal) -> String.format("(%s.length() >= %s)", objName, constVal)), + STR_LENGTH_LESS_THAN("str_length_less_than", type -> switch (type) { + case STUDIO -> "lengthLess"; + default -> "unknown"; + }, List.of(VariableType.STRING), (objName, constVal) -> String.format("(%s.length() < %s)", objName, constVal)), + STR_LENGTH_NOT_GREATER_THAN("str_length_not_greater_than", type -> switch (type) { + case STUDIO -> "lengthLessAndEqual"; + default -> "unknown"; + }, List.of(VariableType.STRING), (objName, constVal) -> String.format("(%s.length() <= %s)", objName, constVal)), + + IS_TRUE("is_true", type -> switch (type) { + case STUDIO -> "isTrue"; + default -> "unknown"; + }, List.of(VariableType.BOOLEAN), (objName, constVal) -> String.format("(%s)", objName)), + IS_FALSE("is_false", type -> switch (type) { + case STUDIO -> "isFalse"; + default -> "unknown"; + }, List.of(VariableType.BOOLEAN), (objName, constVal) -> String.format("!(%s)", objName)); private final String value; @@ -141,7 +201,8 @@ public static ComparisonOperatorType fromDslValue(DSLDialectType dialectType, St return comparisonOperatorType; } } - throw new IllegalArgumentException("Not support dslValue:" + dslValue); + throw new IllegalArgumentException( + "Not support dslValue: [" + dslValue + "] for type: [" + variableType.value() + "]"); } public String convert(String objName, String constValue) { diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/NodeType.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/NodeType.java index a287481342..d73680fd1e 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/NodeType.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/NodeType.java @@ -41,7 +41,7 @@ public enum NodeType { HUMAN("human", "unsupported", "UNSUPPORTED"), - BRANCH("branch", "if-else", "UNSUPPORTED"), + BRANCH("branch", "if-else", "Judge"), DOC_EXTRACTOR("document-extractor", "document-extractor", "UNSUPPORTED"), diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/BranchNodeData.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/BranchNodeData.java index c94a6f0ba5..0f5e4f3aae 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/BranchNodeData.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/BranchNodeData.java @@ -17,8 +17,6 @@ import java.util.List; -import com.alibaba.cloud.ai.studio.admin.generator.model.Variable; -import com.alibaba.cloud.ai.studio.admin.generator.model.VariableSelector; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.Case; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.NodeData; diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/BranchNodeDataConverter.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/BranchNodeDataConverter.java index 37860a5e35..ffdf5b988c 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/BranchNodeDataConverter.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/BranchNodeDataConverter.java @@ -18,6 +18,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.function.BiConsumer; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -33,6 +34,8 @@ import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.AbstractNodeDataConverter; import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.DSLDialectType; +import com.alibaba.cloud.ai.studio.admin.generator.utils.MapReadUtil; +import com.fasterxml.jackson.core.JsonProcessingException; import org.springframework.stereotype.Component; @Component @@ -97,6 +100,96 @@ public Map dump(BranchNodeData nodeData) { } }), + STUDIO(new DialectConverter<>() { + @Override + public Boolean supportDialect(DSLDialectType dialectType) { + return DSLDialectType.STUDIO.equals(dialectType); + } + + @Override + public BranchNodeData parse(Map data) throws JsonProcessingException { + BranchNodeData nodeData = new BranchNodeData(); + + // 获取条件信息 + List> caseList = Optional + .ofNullable(MapReadUtil.safeCastToListWithMap( + MapReadUtil.getMapDeepValue(data, List.class, "config", "node_param", "branches"))) + .orElse(List.of()) + .stream() + .filter(caseMap -> caseMap.containsKey("id")) + .toList(); + String defaultCase = caseList.stream() + .filter(map -> !map.containsKey("conditions")) + .map(map -> map.get("id").toString()) + .findFirst() + .orElse("default"); + List cases = caseList.stream().filter(map -> map.containsKey("conditions")).map(map -> { + String id = MapReadUtil.getMapDeepValue(map, String.class, "id"); + LogicalOperatorType logicalOperatorType = LogicalOperatorType + .fromValue(Optional.ofNullable(MapReadUtil.getMapDeepValue(map, String.class, "logic")) + .orElse(LogicalOperatorType.AND.getValue())); + + // 提取Conditions + List> conditionMap = Optional + .ofNullable(MapReadUtil + .safeCastToListWithMap(MapReadUtil.getMapDeepValue(map, List.class, "conditions"))) + .orElse(List.of()) + .stream() + .filter(mp -> mp.containsKey("left") && mp.containsKey("right") && mp.containsKey("operator")) + .toList(); + List conditions = conditionMap.stream().map(mp -> { + String rightFrom = MapReadUtil.getMapDeepValue(mp, String.class, "right", "value_from"); + String leftValue = MapReadUtil.getMapDeepValue(mp, String.class, "left", "value"); + String rightValue = MapReadUtil.getMapDeepValue(mp, String.class, "right", "value"); + + VariableType variableType = VariableType + .fromStudioValue( + Optional.ofNullable(MapReadUtil.getMapDeepValue(mp, String.class, "left", "type")) + .orElse(VariableType.OBJECT.studioValue())) + .orElseThrow(); + VariableType referenceType = VariableType + .fromStudioValue( + Optional.ofNullable(MapReadUtil.getMapDeepValue(mp, String.class, "right", "type")) + .orElse(VariableType.OBJECT.studioValue())) + .orElseThrow(); + + ComparisonOperatorType comparisonOperatorType = ComparisonOperatorType.fromDslValue( + DSLDialectType.STUDIO, MapReadUtil.getMapDeepValue(mp, String.class, "operator"), + variableType); + + VariableSelector targetSelector = this.varTemplateToSelector(DSLDialectType.STUDIO, leftValue); + Case.Condition condition = new Case.Condition().setVarType(variableType) + .setReferenceType(referenceType) + .setTargetSelector(targetSelector) + .setComparisonOperator(comparisonOperatorType); + + if ("refer".equalsIgnoreCase(rightFrom)) { + VariableSelector referenceSelector = this.varTemplateToSelector(DSLDialectType.STUDIO, + rightValue); + condition.setReferenceSelector(referenceSelector); + } + else { + condition.setReferenceValue(rightValue); + } + + return condition; + }).toList(); + + return new Case().setId(id).setLogicalOperator(logicalOperatorType).setConditions(conditions); + }).toList(); + + // 设置基本信息 + nodeData.setCases(cases); + nodeData.setDefaultCase(defaultCase); + return nodeData; + } + + @Override + public Map dump(BranchNodeData nodeData) { + throw new UnsupportedOperationException(); + } + }), + CUSTOM(AbstractNodeDataConverter.defaultCustomDialectConverter(BranchNodeData.class)); private final DialectConverter dialectConverter; @@ -123,23 +216,38 @@ public Stream extractWorkflowVars(BranchNodeData data) { @Override public BiConsumer> postProcessConsumer(DSLDialectType dialectType) { - return switch (dialectType) { - case DIFY, STUDIO -> super.postProcessConsumer(dialectType).andThen((nodeData, idToVarName) -> { + BiConsumer> consumer = super.postProcessConsumer(dialectType) + .andThen((nodeData, idToVarName) -> { // 处理条件里的VariableSelector nodeData.getCases().forEach(c -> { c.getConditions().forEach(condition -> { VariableSelector selector = condition.getTargetSelector(); - selector.setNameInCode(idToVarName.getOrDefault(selector.getNamespace(), "unknown") + "_" - + selector.getName()); + selector + .setNameInCode(idToVarName.getOrDefault(selector.getNamespace(), selector.getNamespace()) + + "_" + selector.getName()); VariableSelector referenceSelector = condition.getReferenceSelector(); if (referenceSelector != null) { - referenceSelector - .setNameInCode(idToVarName.getOrDefault(referenceSelector.getNamespace(), "unknown") - + "_" + referenceSelector.getName()); + referenceSelector.setNameInCode(idToVarName.getOrDefault(referenceSelector.getNamespace(), + referenceSelector.getNamespace()) + "_" + referenceSelector.getName()); } }); }); }); + + return switch (dialectType) { + case DIFY -> consumer; + case STUDIO -> consumer.andThen((nodeData, idToVarName) -> { + // 将Case的caseId里添加nodeId(为了与Edge里的sourceHandle保持一致) + String varName = nodeData.getVarName(); + String prefix = idToVarName.entrySet() + .stream() + .filter(entry -> entry.getValue().equals(varName)) + .map(Map.Entry::getKey) + .findFirst() + .orElseThrow() + "_"; + nodeData.getCases().forEach(c -> c.setId(prefix + c.getId())); + nodeData.setDefaultCase(prefix + nodeData.getDefaultCase()); + }); default -> super.postProcessConsumer(dialectType); }; } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/BranchNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/BranchNodeSection.java index 21700f47b1..e6c8512c52 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/BranchNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/BranchNodeSection.java @@ -134,30 +134,30 @@ private String generateSafeVariableAccess(Case.Condition condition) { if (accessExtension) { // 如果是访问扩展名属性,直接访问扩展名字段 - return String.format("state.value(\"%s\", String.class).orElse(\"\")", variablePath); + return String.format("state.value(\"%s\", String.class).orElse(null)", variablePath); } else { // 从文件对象中提取扩展名 return String.format( "state.value(\"%s\", java.io.File.class).map(file -> { " + "String name = file.getName(); " + "int dotIndex = name.lastIndexOf('.'); " - + "return dotIndex > 0 ? name.substring(dotIndex) : \"\"; " + "}).orElse(\"\")", + + "return dotIndex > 0 ? name.substring(dotIndex) : \"\"; " + "}).orElse(null)", variablePath); } + // 默认返回null,避免isNull判断恒为false case STRING: - return String.format("state.value(\"%s\", String.class).orElse(\"\")", variablePath); + return String.format("state.value(\"%s\", String.class).orElse(null)", variablePath); case NUMBER: - return String.format("state.value(\"%s\", Number.class).orElse(0)", variablePath); + return String.format("state.value(\"%s\", Number.class).orElse(null)", variablePath); case BOOLEAN: - return String.format("state.value(\"%s\", Boolean.class).orElse(false)", variablePath); + return String.format("state.value(\"%s\", Boolean.class).orElse(null)", variablePath); case ARRAY_FILE: case ARRAY_NUMBER: case ARRAY_STRING: case ARRAY_OBJECT: + case ARRAY_BOOLEAN: case ARRAY: - return String.format( - "state.value(\"%s\", java.util.List.class).orElse(java.util.Collections.emptyList())", - variablePath); + return String.format("state.value(\"%s\", List.class).orElse(null)", variablePath); case OBJECT: return String.format("state.value(\"%s\", Object.class).orElse(null)", variablePath); default: From bf86cf18c187474f555ac1ed83f69cb684c474ee Mon Sep 17 00:00:00 2001 From: VLSMB <2047857654@qq.com> Date: Mon, 15 Sep 2025 15:12:14 +0800 Subject: [PATCH 03/12] feat: enhance QuestionClassifierNode to support variableCategories --- .../ai/graph/node/QuestionClassifierNode.java | 86 ++++++++++++++--- .../node/QuestionClassifierNodeTest.java | 96 +++++++++++++++++++ 2 files changed, 170 insertions(+), 12 deletions(-) create mode 100644 spring-ai-alibaba-graph-core/src/test/java/com/alibaba/cloud/ai/graph/node/QuestionClassifierNodeTest.java diff --git a/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/QuestionClassifierNode.java b/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/QuestionClassifierNode.java index 8cab4d87bb..6faae82384 100644 --- a/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/QuestionClassifierNode.java +++ b/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/QuestionClassifierNode.java @@ -22,13 +22,20 @@ import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.chat.prompt.SystemPromptTemplate; import org.springframework.util.StringUtils; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; +import java.util.stream.Stream; public class QuestionClassifierNode implements NodeAction { @@ -69,21 +76,25 @@ public class QuestionClassifierNode implements NodeAction { ``` """; - private SystemPromptTemplate systemPromptTemplate; + private static final Pattern VAR_TEMPLATE_PATTERN = Pattern.compile("\\{(\\w+)}"); - private ChatClient chatClient; + private final SystemPromptTemplate systemPromptTemplate; + + private final ChatClient chatClient; private String inputText; - private List categories; + // 类别里可能存在{}这样的占位变量需要处理 + private final Map categories; - private List classificationInstructions; + // 分类指导里可能存在{}这样的占位变量需要处理 + private final List classificationInstructions; - private String inputTextKey; + private final String inputTextKey; - private String outputKey; + private final String outputKey; - public QuestionClassifierNode(ChatClient chatClient, String inputTextKey, List categories, + public QuestionClassifierNode(ChatClient chatClient, String inputTextKey, Map categories, List classificationInstructions, String outputKey) { this.chatClient = chatClient; this.inputTextKey = inputTextKey; @@ -93,6 +104,21 @@ public QuestionClassifierNode(ChatClient chatClient, String inputTextKey, List params = Stream.of(template) + .map(VAR_TEMPLATE_PATTERN::matcher) + .map(Matcher::results) + .map(results -> results.collect(Collectors.toUnmodifiableMap(r -> r.group(1), + r -> state.value(r.group(1)).orElse(""), (a, b) -> b))) + .findFirst() + .orElseThrow(); + return new PromptTemplate(template).render(params); + } + + private List renderTemplates(OverAllState state, List templates) { + return templates.stream().map(template -> renderTemplate(state, template)).toList(); + } + @Override public Map apply(OverAllState state) throws Exception { if (StringUtils.hasLength(inputTextKey)) { @@ -109,16 +135,38 @@ public Map apply(OverAllState state) throws Exception { messages.add(userMessage2); messages.add(assistantMessage2); + Map renderedCategories = categories.entrySet() + .stream() + .map(e -> Map.entry(e.getKey(), renderTemplate(state, e.getValue()))) + .collect(Collectors.toUnmodifiableMap(Map.Entry::getKey, Map.Entry::getValue)); + + List categoriesList = renderedCategories.values().stream().toList(); + ChatResponse response = chatClient.prompt() - .system(systemPromptTemplate.render(Map.of("inputText", inputText, "categories", categories, - "classificationInstructions", classificationInstructions))) + .system(systemPromptTemplate.render(Map.of("inputText", inputText, "categories", categoriesList, + "classificationInstructions", renderTemplates(state, classificationInstructions)))) .user(inputText) .messages(messages) .call() .chatResponse(); Map updatedState = new HashMap<>(); - updatedState.put(outputKey, response.getResult().getOutput().getText()); + String output = Optional + .ofNullable(Optional.ofNullable(response) + .orElseThrow(() -> new RuntimeException("chat response is null")) + .getResult() + .getOutput() + .getText()) + .orElseThrow(() -> new RuntimeException("chat response text is null")); + String result = renderedCategories.entrySet() + .stream() + .filter(entry -> output.contains(entry.getValue())) + .map(Map.Entry::getKey) + .findFirst() + .orElseThrow(() -> new RuntimeException( + "chatClient returns [" + output + "], but it does not belong to the given category.")); + + updatedState.put(outputKey, result); if (state.value("messages").isPresent()) { updatedState.put("messages", response.getResult().getOutput()); } @@ -136,7 +184,7 @@ public static class Builder { private ChatClient chatClient; - private List categories; + private Map categories; private List classificationInstructions; @@ -152,11 +200,25 @@ public Builder chatClient(ChatClient chatClient) { return this; } - public Builder categories(List categories) { + /** + * 需要一个Map对象,key为类别ID,value为具体的类别名称。 + * 类别名称里可以存在{@code {var_name}}这样的变量占位符,节点将从{@code OverAllState}里获取变量值进行替换。 + */ + public Builder categories(Map categories) { this.categories = categories; return this; } + // 兼容旧版本 + @Deprecated + public Builder categories(List categories) { + this.categories = categories.stream().collect(Collectors.toMap(k -> k, k -> k, (a, b) -> b)); + return this; + } + + /** + * 指导说明列表。 指导说明里可以存在{@code {var_name}}这样的变量占位符,节点将从{@code OverAllState}里获取变量值进行替换。 + */ public Builder classificationInstructions(List classificationInstructions) { this.classificationInstructions = classificationInstructions; return this; diff --git a/spring-ai-alibaba-graph-core/src/test/java/com/alibaba/cloud/ai/graph/node/QuestionClassifierNodeTest.java b/spring-ai-alibaba-graph-core/src/test/java/com/alibaba/cloud/ai/graph/node/QuestionClassifierNodeTest.java new file mode 100644 index 0000000000..52e02e1202 --- /dev/null +++ b/spring-ai-alibaba-graph-core/src/test/java/com/alibaba/cloud/ai/graph/node/QuestionClassifierNodeTest.java @@ -0,0 +1,96 @@ +/* + * Copyright 2024-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.cloud.ai.graph.node; + +import com.alibaba.cloud.ai.dashscope.api.DashScopeApi; +import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatModel; +import com.alibaba.cloud.ai.graph.OverAllState; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.model.ChatModel; + +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +@EnabledIfEnvironmentVariable(named = "AI_DASHSCOPE_API_KEY", matches = ".+") +public class QuestionClassifierNodeTest { + + private ChatClient chatClient; + + @BeforeEach + public void setUp() { + DashScopeApi dashScopeApi = DashScopeApi.builder().apiKey(System.getenv("AI_DASHSCOPE_API_KEY")).build(); + ChatModel chatModel = DashScopeChatModel.builder().dashScopeApi(dashScopeApi).build(); + chatClient = ChatClient.builder(chatModel).build(); + } + + private QuestionClassifierNode createNode(Map categories, List instructions) { + return QuestionClassifierNode.builder() + .chatClient(chatClient) + .inputTextKey("input") + .categories(categories) + .outputKey("output") + .classificationInstructions(instructions) + .build(); + } + + private OverAllState createState(Map map) { + OverAllState state = new OverAllState(); + state.updateState(map); + return state; + } + + @Test + public void testBase() throws Exception { + QuestionClassifierNode node = createNode(Map.of("1", "正面评价", "2", "负面评价", "3", "中立评价"), + List.of("请根据输入的评价内容,给出评价的分类结果。")); + Map apply = node.apply(createState(Map.of("input", "你们的服务做的真好!"))); + System.out.println(apply); + assertEquals("1", apply.get("output")); + Map apply1 = node.apply(createState(Map.of("input", "你们服务做的真差!"))); + System.out.println(apply1); + assertEquals("2", apply1.get("output")); + } + + @Test + public void testVariableCategories() throws Exception { + QuestionClassifierNode node = createNode(Map.of("1", "{category1}评价", "2", "{category2}评价"), + List.of("请根据输入的评价内容,给出评价的分类结果。")); + Map apply = node + .apply(createState(Map.of("input", "你们的服务做的真好!", "category1", "正面", "category2", "负面"))); + System.out.println(apply); + assertEquals("1", apply.get("output")); + Map apply1 = node + .apply(createState(Map.of("input", "你们服务做的真差!", "category2", "正面", "category1", "负面"))); + System.out.println(apply1); + assertEquals("1", apply1.get("output")); + } + + @Test + public void testVariableInstructions() throws Exception { + QuestionClassifierNode node = createNode(Map.of("1", "正面评价", "2", "负面评价"), List.of("{instruction}")); + Map apply = node + .apply(createState(Map.of("input", "你们的服务做的真差!", "instruction", "请根据输入的评价内容,给出评价的分类结果。"))); + System.out.println(apply); + assertEquals("2", apply.get("output")); + } + +} From 3755cc19f550cfaacd38906ece538ad82f57374c Mon Sep 17 00:00:00 2001 From: VLSMB <2047857654@qq.com> Date: Mon, 15 Sep 2025 19:57:29 +0800 Subject: [PATCH 04/12] feat: support QuestionClassifierNode for Studio DSL and sorted import in generated code --- .../ai/graph/node/QuestionClassifierNode.java | 1 - .../generator/model/workflow/NodeType.java | 6 +- .../nodedata/QuestionClassifierNodeData.java | 316 ++---------------- .../dsl/AbstractNodeDataConverter.java | 7 +- .../QuestionClassifyNodeDataConverter.java | 247 +++++++------- .../generator/workflow/NodeSection.java | 1 + .../workflow/WorkflowProjectGenerator.java | 40 ++- .../QuestionClassifierNodeSection.java | 169 +++++----- .../templates/GraphBuilder.java.mustache | 24 +- 9 files changed, 289 insertions(+), 522 deletions(-) diff --git a/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/QuestionClassifierNode.java b/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/QuestionClassifierNode.java index 6faae82384..c28c12ca06 100644 --- a/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/QuestionClassifierNode.java +++ b/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/QuestionClassifierNode.java @@ -27,7 +27,6 @@ import org.springframework.util.StringUtils; import java.util.ArrayList; -import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/NodeType.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/NodeType.java index d73680fd1e..41a9e879b5 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/NodeType.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/NodeType.java @@ -39,13 +39,13 @@ public enum NodeType { AGGREGATOR("aggregator", "variable-aggregator", "UNSUPPORTED"), - HUMAN("human", "unsupported", "UNSUPPORTED"), + HUMAN("human", "UNSUPPORTED", "UNSUPPORTED"), BRANCH("branch", "if-else", "Judge"), DOC_EXTRACTOR("document-extractor", "document-extractor", "UNSUPPORTED"), - QUESTION_CLASSIFIER("question-classifier", "question-classifier", "UNSUPPORTED"), + QUESTION_CLASSIFIER("question-classifier", "question-classifier", "Classifier"), HTTP("http", "http-request", "UNSUPPORTED"), @@ -55,7 +55,7 @@ public enum NodeType { TOOL("tool", "tool", "UNSUPPORTED"), - MCP("mcp", "unsupported", "UNSUPPORTED"), + MCP("mcp", "UNSUPPORTED", "UNSUPPORTED"), TEMPLATE_TRANSFORM("template-transform", "template-transform", "UNSUPPORTED"), diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/QuestionClassifierNodeData.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/QuestionClassifierNodeData.java index 08c05e541d..3b86b45cd0 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/QuestionClassifierNodeData.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/QuestionClassifierNodeData.java @@ -15,12 +15,14 @@ */ package com.alibaba.cloud.ai.studio.admin.generator.model.workflow.nodedata; -import java.util.List; - import com.alibaba.cloud.ai.studio.admin.generator.model.Variable; import com.alibaba.cloud.ai.studio.admin.generator.model.VariableSelector; import com.alibaba.cloud.ai.studio.admin.generator.model.VariableType; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.NodeData; +import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.DSLDialectType; + +import java.util.List; +import java.util.Map; /** * @author HeYQ @@ -28,53 +30,52 @@ */ public class QuestionClassifierNodeData extends NodeData { - public static Variable getDefaultOutputSchema() { - return new Variable("class_name", VariableType.STRING); + public static Variable getDefaultOutputSchema(DSLDialectType dialectType) { + return switch (dialectType) { + case DIFY -> new Variable("class_name", VariableType.STRING); + case STUDIO -> new Variable("subject", VariableType.STRING); + default -> new Variable("text", VariableType.STRING); + }; } - private ModelConfig model; + private String chatModeName; - private MemoryConfig memoryConfig; + private Map modeParams; - private List promptTemplate; + private VariableSelector inputSelector; - private String instruction; + private String outputKey; private List classes; - private String outputKey; + private String promptTemplate; - private String inputTextKey; + public record ClassConfig(String id, String classTemplate) { - public QuestionClassifierNodeData(List inputs, List outputs) { - super(inputs, outputs); } - public ModelConfig getModel() { - return model; + public String getChatModeName() { + return chatModeName; } - public QuestionClassifierNodeData setModel(ModelConfig model) { - this.model = model; - return this; + public void setChatModeName(String chatModeName) { + this.chatModeName = chatModeName; } - public List getPromptTemplate() { - return promptTemplate; + public Map getModeParams() { + return modeParams; } - public QuestionClassifierNodeData setPromptTemplate(List promptTemplate) { - this.promptTemplate = promptTemplate; - return this; + public void setModeParams(Map modeParams) { + this.modeParams = modeParams; } - public MemoryConfig getMemoryConfig() { - return memoryConfig; + public VariableSelector getInputSelector() { + return inputSelector; } - public QuestionClassifierNodeData setMemoryConfig(MemoryConfig memoryConfig) { - this.memoryConfig = memoryConfig; - return this; + public void setInputSelector(VariableSelector inputSelector) { + this.inputSelector = inputSelector; } public String getOutputKey() { @@ -85,271 +86,20 @@ public void setOutputKey(String outputKey) { this.outputKey = outputKey; } - public String getInputTextKey() { - return inputTextKey; - } - - public QuestionClassifierNodeData setInputTextKey(String inputTextKey) { - this.inputTextKey = inputTextKey; - return this; - } - - public String getInstruction() { - return instruction; - } - - public QuestionClassifierNodeData setInstruction(String instruction) { - this.instruction = instruction; - return this; - } - public List getClasses() { return classes; } - public QuestionClassifierNodeData setClasses(List classes) { + public void setClasses(List classes) { this.classes = classes; - return this; } - public static class ClassConfig { - - private String id; - - private String text; - - public String getId() { - return id; - } - - public ClassConfig setId(String id) { - this.id = id; - return this; - } - - public String getText() { - return text; - } - - public ClassConfig setText(String text) { - this.text = text; - return this; - } - - } - - public static class PromptTemplate { - - private String role; - - private String text; - - public PromptTemplate() { - } - - public PromptTemplate(String role, String text) { - this.role = role; - this.text = text; - } - - public String getText() { - return text; - } - - public PromptTemplate setText(String text) { - this.text = text; - return this; - } - - public String getRole() { - return role; - } - - public PromptTemplate setRole(String role) { - this.role = role; - return this; - } - - } - - public static class ModelConfig { - - public static final String MODE_COMPLETION = "completion"; - - public static final String MODE_CHAT = "chat"; - - private String mode; - - private String name; - - private String provider; - - private CompletionParams completionParams; - - public String getMode() { - return mode; - } - - public ModelConfig setMode(String mode) { - this.mode = mode; - return this; - } - - public String getName() { - return name; - } - - public ModelConfig setName(String name) { - this.name = name; - return this; - } - - public String getProvider() { - return provider; - } - - public ModelConfig setProvider(String provider) { - this.provider = provider; - return this; - } - - public CompletionParams getCompletionParams() { - return completionParams; - } - - public ModelConfig setCompletionParams(CompletionParams completionParams) { - this.completionParams = completionParams; - return this; - } - - } - - public static class CompletionParams { - - private Integer maxTokens; - - private Float repetitionPenalty; - - private String responseFormat; - - private Integer seed; - - private List stop; - - private Float temperature; - - private Float topP; - - private Integer topK; - - private Integer frequencyPenalty; - - private Integer presencePenalty; - - public Integer getMaxTokens() { - return maxTokens; - } - - public CompletionParams setMaxTokens(Integer maxTokens) { - this.maxTokens = maxTokens; - return this; - } - - public void setRepetitionPenalty(Float repetitionPenalty) { - this.repetitionPenalty = repetitionPenalty; - } - - public void setResponseFormat(String responseFormat) { - this.responseFormat = responseFormat; - } - - public void setSeed(Integer seed) { - this.seed = seed; - } - - public void setStop(List stop) { - this.stop = stop; - } - - public void setTemperature(Float temperature) { - this.temperature = temperature; - } - - public void setTopP(Float topP) { - this.topP = topP; - } - - public void setTopK(Integer topK) { - this.topK = topK; - } - - public void setFrequencyPenalty(Integer frequencyPenalty) { - this.frequencyPenalty = frequencyPenalty; - } - - public void setPresencePenalty(Integer presencePenalty) { - this.presencePenalty = presencePenalty; - } - + public String getPromptTemplate() { + return promptTemplate; } - public static class MemoryConfig { - - private Boolean enabled = false; - - private Integer windowSize = 20; - - private Boolean windowEnabled = true; - - private Boolean includeLastMessage = false; - - private String lastMessageTemplate; - - public Boolean getEnabled() { - return enabled; - } - - public MemoryConfig setEnabled(Boolean enabled) { - this.enabled = enabled; - return this; - } - - public Integer getWindowSize() { - return windowSize; - } - - public MemoryConfig setWindowSize(Integer windowSize) { - this.windowSize = windowSize; - return this; - } - - public Boolean getWindowEnabled() { - return windowEnabled; - } - - public MemoryConfig setWindowEnabled(Boolean windowEnabled) { - this.windowEnabled = windowEnabled; - return this; - } - - public Boolean getIncludeLastMessage() { - return includeLastMessage; - } - - public MemoryConfig setIncludeLastMessage(Boolean includeLastMessage) { - this.includeLastMessage = includeLastMessage; - return this; - } - - public String getLastMessageTemplate() { - return lastMessageTemplate; - } - - public MemoryConfig setLastMessageTemplate(String lastMessageTemplate) { - this.lastMessageTemplate = lastMessageTemplate; - return this; - } - + public void setPromptTemplate(String promptTemplate) { + this.promptTemplate = promptTemplate; } } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/AbstractNodeDataConverter.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/AbstractNodeDataConverter.java index 55f0b01c2f..4532460de6 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/AbstractNodeDataConverter.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/AbstractNodeDataConverter.java @@ -81,13 +81,18 @@ public interface DialectConverter { * @return 变量选择器 */ default VariableSelector varTemplateToSelector(DSLDialectType dialectType, String template) { + if (template == null) { + throw new NullPointerException("Template string is null"); + } Pattern pattern = switch (dialectType) { case DIFY -> DIFY_VAR_TEMPLATE_PATTERN; case STUDIO -> STUDIO_VAR_TEMPLATE_PATTERN; default -> throw new UnsupportedOperationException(); }; Matcher matcher = pattern.matcher(template); - MatchResult result = matcher.results().findFirst().orElseThrow(); + MatchResult result = matcher.results() + .findFirst() + .orElseThrow(() -> new IllegalArgumentException("Invalid template string")); return new VariableSelector(result.group(1), result.group(2)); } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/QuestionClassifyNodeDataConverter.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/QuestionClassifyNodeDataConverter.java index bbf4242edf..27580211c7 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/QuestionClassifyNodeDataConverter.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/QuestionClassifyNodeDataConverter.java @@ -15,12 +15,11 @@ */ package com.alibaba.cloud.ai.studio.admin.generator.service.dsl.converter; -import java.util.Collections; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.function.BiConsumer; +import java.util.stream.Collectors; import java.util.stream.Stream; import com.alibaba.cloud.ai.studio.admin.generator.model.VariableSelector; @@ -28,12 +27,8 @@ import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.nodedata.QuestionClassifierNodeData; import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.AbstractNodeDataConverter; import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.DSLDialectType; -import com.alibaba.cloud.ai.studio.admin.generator.utils.StringTemplateUtil; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.databind.DeserializationFeature; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.PropertyNamingStrategies; -import org.apache.commons.collections4.CollectionUtils; +import com.alibaba.cloud.ai.studio.admin.generator.utils.MapReadUtil; +import com.fasterxml.jackson.core.JsonProcessingException; import org.springframework.stereotype.Component; @@ -67,117 +62,113 @@ public Boolean supportDialect(DSLDialectType dialectType) { @Override public QuestionClassifierNodeData parse(Map data) { - List inputs = Optional.ofNullable((List) data.get("query_variable_selector")) - .filter(CollectionUtils::isNotEmpty) - .map(variables -> Collections - .singletonList(new VariableSelector(variables.get(0), variables.get(1)))) - .orElse(Collections.emptyList()); - - // convert model config - Map modelData = (Map) data.get("model"); - ObjectMapper objectMapper = new ObjectMapper(); - objectMapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); - objectMapper.setPropertyNamingStrategy(PropertyNamingStrategies.SNAKE_CASE); - QuestionClassifierNodeData.ModelConfig modelConfig = new QuestionClassifierNodeData.ModelConfig() - .setMode((String) modelData.get("mode")) - .setName((String) modelData.get("name")) - .setProvider((String) modelData.get("provider")) - .setCompletionParams(objectMapper.convertValue(modelData.get("completion_params"), - QuestionClassifierNodeData.CompletionParams.class)); - - QuestionClassifierNodeData nodeData = new QuestionClassifierNodeData(inputs, - List.of(QuestionClassifierNodeData.getDefaultOutputSchema())) - .setModel(modelConfig); - - // convert instructions - String instruction = (String) data.get("instructions"); - if (instruction != null && !instruction.isBlank()) { - nodeData.setInstruction(instruction); - } - - // convert classes - if (data.containsKey("classes")) { - List> classes = (List>) data.get("classes"); - nodeData.setClasses(classes.stream() - .map(item -> new QuestionClassifierNodeData.ClassConfig().setId((String) item.get("id")) - .setText((String) item.get("name"))) - .toList()); - } - - // convert memory config - if (data.containsKey("memory")) { - Map memoryData = (Map) data.get("memory"); - String lastMessageTemplate = (String) memoryData.get("query_prompt_template"); - Map window = (Map) memoryData.get("window"); - Boolean windowEnabled = (Boolean) window.get("enabled"); - Integer windowSize = (Integer) window.get("size"); - QuestionClassifierNodeData.MemoryConfig memory = new QuestionClassifierNodeData.MemoryConfig() - .setWindowEnabled(windowEnabled) - .setWindowSize(windowSize) - .setLastMessageTemplate(lastMessageTemplate) - .setIncludeLastMessage(false); - nodeData.setMemoryConfig(memory); - } - - // output_key - String outputKey = (String) data.get("output_key"); + QuestionClassifierNodeData nodeData = new QuestionClassifierNodeData(); + + // 获取必要的信息 + String modeName = MapReadUtil.getMapDeepValue(data, String.class, "model", "name"); + Map modeParams = MapReadUtil.safeCastToMapWithStringKey( + MapReadUtil.getMapDeepValue(data, Map.class, "model", "completion_params")); + List inputSelectorList = Optional + .ofNullable(MapReadUtil.safeCastToList( + MapReadUtil.getMapDeepValue(data, List.class, "query_variable_selector"), String.class)) + .orElseThrow(); + VariableSelector selector = new VariableSelector(inputSelectorList.get(0), inputSelectorList.get(1)); + String outputKey = QuestionClassifierNodeData.getDefaultOutputSchema(DSLDialectType.DIFY).getName(); + List classes = Optional + .ofNullable( + MapReadUtil.safeCastToListWithMap(MapReadUtil.getMapDeepValue(data, List.class, "classes"))) + .orElseThrow() + .stream() + .filter(map -> map.containsKey("id") && map.containsKey("name")) + .map(map -> new QuestionClassifierNodeData.ClassConfig(map.get("id").toString(), + map.get("name").toString())) + .toList(); + String promptTemplate = Optional + .ofNullable(MapReadUtil.getMapDeepValue(data, String.class, "instruction")) + .orElse(""); + + // 设置基本信息 + nodeData.setChatModeName(modeName); + nodeData.setModeParams(modeParams); + nodeData.setInputSelector(selector); nodeData.setOutputKey(outputKey); + nodeData.setClasses(classes); + nodeData.setPromptTemplate(promptTemplate); + return nodeData; + } + + @Override + public Map dump(QuestionClassifierNodeData nodeData) { + throw new UnsupportedOperationException(); + } + }) - // input_text_key + , STUDIO(new DialectConverter<>() { + @Override + public Boolean supportDialect(DSLDialectType dialectType) { + return DSLDialectType.STUDIO.equals(dialectType); + } + @Override + public QuestionClassifierNodeData parse(Map data) throws JsonProcessingException { + QuestionClassifierNodeData nodeData = new QuestionClassifierNodeData(); + // 从data中提取必要信息 + Map modeConfigMap = MapReadUtil.safeCastToMapWithStringKey( + MapReadUtil.getMapDeepValue(data, Map.class, "config", "node_param", "model_config")); + String modeName = MapReadUtil.getMapDeepValue(modeConfigMap, String.class, "model_id"); + Map modeParams = Optional + .ofNullable(MapReadUtil + .safeCastToListWithMap(MapReadUtil.getMapDeepValue(modeConfigMap, List.class, "params"))) + .orElse(List.of()) + .stream() + .filter(map -> Boolean.TRUE.equals(map.get("enable"))) + .filter(map -> map.containsKey("key") && map.containsKey("value")) + .map(map -> Map.entry(map.get("key").toString(), map.get("value"))) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (a, b) -> b)); + VariableSelector selector = this.varTemplateToSelector(DSLDialectType.STUDIO, + MapReadUtil + .safeCastToListWithMap( + MapReadUtil.getMapDeepValue(data, List.class, "config", "input_params")) + .get(0) + .get("value") + .toString()); + String outputKey = QuestionClassifierNodeData.getDefaultOutputSchema(DSLDialectType.STUDIO).getName(); + List classes = Optional + .ofNullable(MapReadUtil.safeCastToListWithMap( + MapReadUtil.getMapDeepValue(data, List.class, "config", "node_param", "conditions"))) + .orElseThrow() + .stream() + .filter(map -> map.containsKey("id") && map.containsKey("subject")) + .map(map -> { + String id = map.get("id").toString(); + String subject = map.get("subject").toString(); + if ("default".equalsIgnoreCase(id)) { + subject = "default"; + } + return new QuestionClassifierNodeData.ClassConfig(id, subject); + }) + .toList(); + String promptTemplate = Optional + .ofNullable(MapReadUtil.getMapDeepValue(data, String.class, "config", "node_param", "instruction")) + .orElse(""); + + // 设置基本信息 + nodeData.setChatModeName(modeName); + nodeData.setModeParams(modeParams); + nodeData.setInputSelector(selector); + nodeData.setOutputKey(outputKey); + nodeData.setClasses(classes); + nodeData.setPromptTemplate(promptTemplate); return nodeData; } @Override public Map dump(QuestionClassifierNodeData nodeData) { - Map data = new HashMap<>(); - ObjectMapper objectMapper = new ObjectMapper(); - objectMapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); - objectMapper.setPropertyNamingStrategy(PropertyNamingStrategies.LOWER_CASE); - objectMapper.setSerializationInclusion(JsonInclude.Include.NON_NULL); - - // put memory - QuestionClassifierNodeData.MemoryConfig memory = nodeData.getMemoryConfig(); - if (memory != null) { - data.put("memory", - Map.of("query_prompt_template", - StringTemplateUtil.toDifyTmpl(memory.getLastMessageTemplate()), "role_prefix", - Map.of("assistant", "", "user", ""), "window", - Map.of("enabled", memory.getWindowEnabled(), "size", memory.getWindowSize()))); - } - - // put model - QuestionClassifierNodeData.ModelConfig model = nodeData.getModel(); - data.put("model", - Map.of("mode", model.getMode(), "name", model.getName(), "provider", model.getProvider(), - "completion_params", - objectMapper.convertValue(model.getCompletionParams(), Map.class))); - - // put query_variable_selector - List inputs = nodeData.getInputs(); - Optional.ofNullable(inputs) - .filter(CollectionUtils::isNotEmpty) - .map(inputList -> inputList.stream() - .findFirst() - .map(input -> List.of(input.getNamespace(), input.getName())) - .orElse(Collections.emptyList())) - .ifPresent(variables -> data.put("query_variable_selector", variables)); - - // put instructions - data.put("instructions", nodeData.getInstruction() != null ? nodeData.getInstruction() : ""); - - // put Classes - if (!CollectionUtils.isEmpty(nodeData.getClasses())) { - data.put("classes", - nodeData.getClasses() - .stream() - .map(item -> Map.of("id", item.getId(), "text", item.getText())) - .toList()); - } - - return data; + throw new UnsupportedOperationException(); } - }), CUSTOM(AbstractNodeDataConverter.defaultCustomDialectConverter(QuestionClassifierNodeData.class)); + }) + + , CUSTOM(AbstractNodeDataConverter.defaultCustomDialectConverter(QuestionClassifierNodeData.class)); private final DialectConverter dialectConverter; @@ -197,12 +188,38 @@ public String generateVarName(int count) { @Override public BiConsumer> postProcessConsumer(DSLDialectType dialectType) { + BiConsumer> consumer = emptyProcessConsumer() + .andThen((nodeData, idToVarName) -> { + nodeData.setOutputs(List.of(QuestionClassifierNodeData.getDefaultOutputSchema(dialectType))); + nodeData.setInputs(List.of(nodeData.getInputSelector())); + }) + .andThen(super.postProcessConsumer(dialectType)) + .andThen((nodeData, idToVarName) -> { + nodeData.setOutputKey(nodeData.getOutputs().get(0).getName()); + nodeData.setInputSelector(nodeData.getInputs().get(0)); + // 替换掉类别和指导中的占位变量 + nodeData + .setPromptTemplate(this.convertVarTemplate(dialectType, nodeData.getPromptTemplate(), idToVarName)); + nodeData.setClasses(nodeData.getClasses() + .stream() + .map(classConfig -> new QuestionClassifierNodeData.ClassConfig(classConfig.id(), + this.convertVarTemplate(dialectType, classConfig.classTemplate(), idToVarName))) + .toList()); + }); return switch (dialectType) { - case DIFY -> emptyProcessConsumer().andThen((data, map) -> { - data.setOutputKey( - data.getVarName() + "_" + QuestionClassifierNodeData.getDefaultOutputSchema().getName()); - data.setOutputs(List.of(QuestionClassifierNodeData.getDefaultOutputSchema())); - }).andThen(super.postProcessConsumer(dialectType)); + case DIFY -> consumer; + case STUDIO -> consumer.andThen((nodeData, idToVarName) -> { + // 将classConfig的id里添加nodeId(为了与Edge里的sourceHandle保持一致) + Map varNameToId = idToVarName.entrySet() + .stream() + .collect(Collectors.toUnmodifiableMap(Map.Entry::getValue, Map.Entry::getKey)); + String nodeId = varNameToId.getOrDefault(nodeData.getVarName(), nodeData.getVarName()); + nodeData.setClasses(nodeData.getClasses() + .stream() + .map(classConfig -> new QuestionClassifierNodeData.ClassConfig(nodeId + "_" + classConfig.id(), + classConfig.classTemplate())) + .toList()); + }); default -> super.postProcessConsumer(dialectType); }; } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/NodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/NodeSection.java index c847ed4485..1438d33de1 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/NodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/NodeSection.java @@ -36,6 +36,7 @@ public interface NodeSection { boolean support(NodeType nodeType); + // TODO: NodeData里有varName字段,去掉varName参数 String render(Node node, String varName); /** diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/WorkflowProjectGenerator.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/WorkflowProjectGenerator.java index 59be8b4a96..bf01fd668b 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/WorkflowProjectGenerator.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/WorkflowProjectGenerator.java @@ -21,6 +21,7 @@ import java.nio.file.Files; import java.nio.file.Path; import java.util.Arrays; +import java.util.Comparator; import java.util.List; import java.util.Map; import java.util.Objects; @@ -57,21 +58,30 @@ public class WorkflowProjectGenerator implements ProjectGenerator { private static final Logger log = LoggerFactory.getLogger(WorkflowProjectGenerator.class); - private final String GRAPH_BUILDER_TEMPLATE_NAME = "GraphBuilder.java"; + private static final String GRAPH_BUILDER_TEMPLATE_NAME = "GraphBuilder.java"; - private final String GRAPH_BUILDER_STATE_SECTION = "stateSection"; + private static final String GRAPH_BUILDER_STATE_SECTION = "stateSection"; - private final String GRAPH_BUILDER_NODE_SECTION = "nodeSection"; + private static final String GRAPH_BUILDER_NODE_SECTION = "nodeSection"; - private final String GRAPH_BUILDER_EDGE_SECTION = "edgeSection"; + private static final String GRAPH_BUILDER_EDGE_SECTION = "edgeSection"; - private final String GRAPH_BUILDER_IMPORT_SECTION = "importSection"; + private static final String GRAPH_BUILDER_IMPORT_SECTION = "importSection"; - private final String GRAPH_BUILDER_ASSIST_METHOD_CODE = "assistMethodCode"; + private static final String GRAPH_BUILDER_ASSIST_METHOD_CODE = "assistMethodCode"; - private final String GRAPH_RUN_TEMPLATE_NAME = "GraphRunController.java"; + private static final String GRAPH_RUN_TEMPLATE_NAME = "GraphRunController.java"; - private final String PACKAGE_NAME = "packageName"; + private static final String PACKAGE_NAME = "packageName"; + + private static final List GRAPH_COMMON_IMPORTS = List.of("com.alibaba.cloud.ai.graph.CompiledGraph", + "com.alibaba.cloud.ai.graph.KeyStrategy", "com.alibaba.cloud.ai.graph.OverAllState", + "com.alibaba.cloud.ai.graph.StateGraph", "com.alibaba.cloud.ai.graph.action.AsyncEdgeAction", + "com.alibaba.cloud.ai.graph.action.AsyncNodeAction", "com.alibaba.cloud.ai.graph.action.NodeAction", + "com.alibaba.cloud.ai.graph.exception.GraphStateException", "org.springframework.ai.chat.client.ChatClient", + "org.springframework.ai.chat.model.ChatModel", "org.springframework.context.annotation.Bean", + "org.springframework.stereotype.Component", "java.util.HashMap", "java.util.Map", "java.util.List", + "static com.alibaba.cloud.ai.graph.StateGraph.END", "static com.alibaba.cloud.ai.graph.StateGraph.START"); private final List dslAdapters; @@ -237,13 +247,21 @@ private String renderImportSection(Workflow workflow) { return ""; } - StringBuilder sb = new StringBuilder(); - uniqueTypes.stream() + List commonImports = uniqueTypes.stream() .map(nodeSectionMap::get) .map(NodeSection::getImports) .flatMap(List::stream) .distinct() - .forEach(className -> sb.append("import ").append(className).append(";\n")); + .toList(); + // 按照字典序升序排序,其中static开头的放在后面 + List allImports = Stream.of(commonImports, GRAPH_COMMON_IMPORTS) + .flatMap(List::stream) + .distinct() + .sorted(Comparator.comparing((String s) -> s.startsWith("static")).thenComparing(String::compareTo)) + .toList(); + + StringBuilder sb = new StringBuilder(); + allImports.forEach(className -> sb.append("import ").append(className).append(";\n")); return sb.toString(); } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/QuestionClassifierNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/QuestionClassifierNodeSection.java index 1fd035bc69..1e9c142d33 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/QuestionClassifierNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/QuestionClassifierNodeSection.java @@ -16,17 +16,17 @@ package com.alibaba.cloud.ai.studio.admin.generator.service.generator.workflow.sections; -import java.util.ArrayList; import java.util.List; +import java.util.Map; import java.util.stream.Collectors; -import com.alibaba.cloud.ai.studio.admin.generator.model.VariableSelector; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.Edge; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.Node; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.NodeType; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.nodedata.QuestionClassifierNodeData; +import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.DSLDialectType; import com.alibaba.cloud.ai.studio.admin.generator.service.generator.workflow.NodeSection; -import com.google.common.base.Strings; +import com.alibaba.cloud.ai.studio.admin.generator.utils.ObjectToCodeUtil; import org.springframework.stereotype.Component; @@ -40,98 +40,97 @@ public boolean support(NodeType nodeType) { @Override public String render(Node node, String varName) { - QuestionClassifierNodeData data = (QuestionClassifierNodeData) node.getData(); - String id = node.getId(); - - StringBuilder sb = new StringBuilder(); - sb.append(String.format("// —— QuestionClassifierNode [%s] ——%n", id)); - sb.append(String.format("QuestionClassifierNode %s = QuestionClassifierNode.builder()%n", varName)); - - sb.append(".chatClient(chatClient)\n"); - - List inputs = data.getInputs(); - if (inputs != null && !inputs.isEmpty()) { - String key = inputs.get(0).getNameInCode(); - sb.append(String.format(".inputTextKey(\"%s\")%n", escape(key))); - } - else { - sb.append(".inputTextKey(\"input\")\n"); - } - - List categoryIds = data.getClasses() - .stream() - .map(QuestionClassifierNodeData.ClassConfig::getText) - .toList(); - if (!categoryIds.isEmpty()) { - String joined = categoryIds.stream() - .map(this::escape) - .map(s -> "\"" + s + "\"") - .collect(Collectors.joining(", ")); - sb.append(String.format(".categories(List.of(%s))%n", joined)); - } - - String outputKey = data.getOutputKey(); - if (!Strings.isNullOrEmpty(outputKey)) { - sb.append(String.format(".outputKey(\"%s\")%n", escape(outputKey))); - } - - String instr = data.getInstruction(); - if (instr != null && !instr.isBlank()) { - sb.append(String.format(".classificationInstructions(List.of(\"%s\"))%n", escape(instr))); - } - else { - sb.append(".classificationInstructions(List.of(\"请根据输入内容选择对应分类\"))\n"); - } - - sb.append(".build();\n"); - sb.append(String.format("stateGraph.addNode(\"%s\", AsyncNodeAction.node_async(%s));%n%n", varName, varName)); - - return sb.toString(); + QuestionClassifierNodeData nodeData = (QuestionClassifierNodeData) node.getData(); + return String.format(""" + // —— QuestionClassifierNode [%s] —— + stateGraph.addNode("%s", AsyncNodeAction.node_async( + createQuestionClassifierAction(%s, %s, "%s", "%s", %s, %s) + )); + + """, node.getId(), varName, ObjectToCodeUtil.toCode(nodeData.getChatModeName()), + ObjectToCodeUtil.toCode(nodeData.getModeParams()), nodeData.getInputSelector().getNameInCode(), + nodeData.getOutputKey(), + ObjectToCodeUtil.toCode(nodeData.getClasses() + .stream() + .collect(Collectors.toUnmodifiableMap(QuestionClassifierNodeData.ClassConfig::id, + QuestionClassifierNodeData.ClassConfig::classTemplate, (a, b) -> b))), + ObjectToCodeUtil.toCode(List.of(nodeData.getPromptTemplate()))); } - private String resolveConditionKey(QuestionClassifierNodeData classifier, String handleId) { - return classifier.getClasses() - .stream() - .filter(c -> c.getId().equals(handleId)) - .map(QuestionClassifierNodeData.ClassConfig::getText) - .findFirst() - .orElse(handleId); + @Override + public String renderEdges(QuestionClassifierNodeData nodeData, List edges) { + // 规定edge的sourceHandle为caseId,前面的转化需要符合这条规则 + String edgeCode = String.format(""" + state -> { + String result = state.value("%s").orElseThrow().toString(); + %s + throw new RuntimeException("invalid output"); + } + """, nodeData.getOutputKey(), + nodeData.getClasses() + .stream() + .map(QuestionClassifierNodeData.ClassConfig::id) + .map(id -> String.format(""" + if("%s".equals(result)) { + return "%s"; + } + """, id, id)) + .collect(Collectors.joining("\n"))); + + Map caseToTarget = edges.stream() + .collect(Collectors.toUnmodifiableMap(Edge::getSourceHandle, Edge::getTarget)); + + return String.format(""" + // render QuestionNode [%s]'s edge + stateGraph.addConditionalEdges("%s", AsyncEdgeAction.edge_async(%s), %s); + + """, nodeData.getVarName(), nodeData.getVarName(), edgeCode, ObjectToCodeUtil.toCode(caseToTarget)); } @Override - public String renderEdges(QuestionClassifierNodeData nodeData, List edges) { - List conditions = new ArrayList<>(); - List mappings = new ArrayList<>(); - String srcVar = nodeData.getVarName(); - StringBuilder sb = new StringBuilder(); - - // 如果输出的都不是预定分类,则使用最后一个分类 - String lastConditionKey = "unknown"; - - for (Edge e : edges) { - String conditionKey = resolveConditionKey(nodeData, e.getSourceHandle()); - String tgtVar2 = e.getTarget(); - lastConditionKey = conditionKey; - conditions.add(String.format("if (value.contains(\"%s\")) return \"%s\";", conditionKey, conditionKey)); - mappings.add(String.format("\"%s\", \"%s\"", conditionKey, tgtVar2)); - } - - String lambdaContent = String.join("\n", conditions); - String mapContent = String.join(", ", mappings); - - sb.append(String.format( - "stateGraph.addConditionalEdges(\"%s\",%n" + " edge_async(state -> {%n" - + "String value = state.value(\"%s_class_name\", String.class).orElse(\"\");%n" + "%s%n" - + "return \"%s\";%n" + " }),%n" + " Map.of(%s)%n" + ");%n", - srcVar, srcVar, lambdaContent, lastConditionKey, mapContent)); - - return sb.toString(); + public String assistMethodCode(DSLDialectType dialectType) { + return switch (dialectType) { + case DIFY, STUDIO -> + """ + @Autowired + private ChatModel chatModelForQuestion; + + private NodeAction createQuestionClassifierAction( + String chatModelName, Map modeParams, + String inputKey, String outputKey, + Map categories, List instructions) { + // build ChatClient + var chatOptionsBuilder = DashScopeChatOptions.builder().withModel(chatModelName); + Optional.ofNullable(modeParams.get("temperature")) + .ifPresent(val -> chatOptionsBuilder.withTemperature(val.doubleValue())); + Optional.ofNullable(modeParams.get("seed")).ifPresent(val -> chatOptionsBuilder.withSeed(val.intValue())); + Optional.ofNullable(modeParams.get("top_p")).ifPresent(val -> chatOptionsBuilder.withTopP(val.doubleValue())); + Optional.ofNullable(modeParams.get("top_k")).ifPresent(val -> chatOptionsBuilder.withTopK(val.intValue())); + Optional.ofNullable(modeParams.get("max_tokens")) + .ifPresent(val -> chatOptionsBuilder.withMaxToken(val.intValue())); + Optional.ofNullable(modeParams.get("repetition_penalty")) + .ifPresent(val -> chatOptionsBuilder.withRepetitionPenalty(val.doubleValue())); + final ChatClient chatClient = ChatClient.builder(chatModelForQuestion).defaultOptions(chatOptionsBuilder.build()).build(); + + // build Node + return QuestionClassifierNode.builder() + .chatClient(chatClient) + .inputTextKey(inputKey) + .outputKey(outputKey) + .categories(categories) + .classificationInstructions(instructions) + .build(); + } + """; + default -> ""; + }; } @Override public List getImports() { return List.of("com.alibaba.cloud.ai.graph.node.QuestionClassifierNode", - "static com.alibaba.cloud.ai.graph.action.AsyncEdgeAction.edge_async"); + "org.springframework.beans.factory.annotation.Autowired", + "com.alibaba.cloud.ai.dashscope.chat.DashScopeChatOptions", "java.util.Optional"); } } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/resources/templates/GraphBuilder.java.mustache b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/resources/templates/GraphBuilder.java.mustache index fa67caf3d2..d6431ab57a 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/resources/templates/GraphBuilder.java.mustache +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/resources/templates/GraphBuilder.java.mustache @@ -1,25 +1,5 @@ package {{packageName}}.graph; -import com.alibaba.cloud.ai.graph.CompiledGraph; -import com.alibaba.cloud.ai.graph.KeyStrategy; -import com.alibaba.cloud.ai.graph.OverAllState; -import com.alibaba.cloud.ai.graph.StateGraph; -import com.alibaba.cloud.ai.graph.action.AsyncEdgeAction; -import com.alibaba.cloud.ai.graph.action.AsyncNodeAction; -import com.alibaba.cloud.ai.graph.action.NodeAction; -import com.alibaba.cloud.ai.graph.exception.GraphStateException; -import org.springframework.ai.chat.client.ChatClient; -import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor; -import org.springframework.ai.chat.model.ChatModel; -import org.springframework.context.annotation.Bean; -import org.springframework.stereotype.Component; - -import java.util.HashMap; -import java.util.Map; -import java.util.List; - -import static com.alibaba.cloud.ai.graph.StateGraph.END; -import static com.alibaba.cloud.ai.graph.StateGraph.START; {{importSection}} @Component @@ -28,9 +8,7 @@ public class GraphBuilder { {{assistMethodCode}} @Bean - public CompiledGraph buildGraph(ChatModel chatModel) throws Exception { - ChatClient chatClient = ChatClient.builder(chatModel).defaultAdvisors(new SimpleLoggerAdvisor()).build(); - + public CompiledGraph buildGraph() throws Exception { // new stateGraph StateGraph stateGraph = new StateGraph({{stateSection}}); // add nodes From 9873e9dae85e025ae423419c237da8988671f7f2 Mon Sep 17 00:00:00 2001 From: VLSMB <2047857654@qq.com> Date: Mon, 15 Sep 2025 20:07:59 +0800 Subject: [PATCH 05/12] refactor: add caseIdToName Map --- .../nodedata/QuestionClassifierNodeData.java | 17 +++++++++++++++++ .../QuestionClassifyNodeDataConverter.java | 12 +++++------- .../sections/QuestionClassifierNodeSection.java | 6 ++++-- 3 files changed, 26 insertions(+), 9 deletions(-) diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/QuestionClassifierNodeData.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/QuestionClassifierNodeData.java index 3b86b45cd0..84adde2eda 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/QuestionClassifierNodeData.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/QuestionClassifierNodeData.java @@ -23,6 +23,8 @@ import java.util.List; import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; /** * @author HeYQ @@ -50,6 +52,8 @@ public static Variable getDefaultOutputSchema(DSLDialectType dialectType) { private String promptTemplate; + private Map classIdToName; + public record ClassConfig(String id, String classTemplate) { } @@ -92,6 +96,7 @@ public List getClasses() { public void setClasses(List classes) { this.classes = classes; + updateClassIdToName(); } public String getPromptTemplate() { @@ -102,4 +107,16 @@ public void setPromptTemplate(String promptTemplate) { this.promptTemplate = promptTemplate; } + public Map getClassIdToName() { + return classIdToName; + } + + private void updateClassIdToName() { + AtomicInteger count = new AtomicInteger(1); + this.classIdToName = this.getClasses() + .stream() + .map(QuestionClassifierNodeData.ClassConfig::id) + .collect(Collectors.toUnmodifiableMap(id -> id, name -> "case_" + (count.getAndIncrement()))); + } + } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/QuestionClassifyNodeDataConverter.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/QuestionClassifyNodeDataConverter.java index 27580211c7..bf37fe5295 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/QuestionClassifyNodeDataConverter.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/QuestionClassifyNodeDataConverter.java @@ -125,13 +125,11 @@ public QuestionClassifierNodeData parse(Map data) throws JsonPro .filter(map -> map.containsKey("key") && map.containsKey("value")) .map(map -> Map.entry(map.get("key").toString(), map.get("value"))) .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (a, b) -> b)); - VariableSelector selector = this.varTemplateToSelector(DSLDialectType.STUDIO, - MapReadUtil - .safeCastToListWithMap( - MapReadUtil.getMapDeepValue(data, List.class, "config", "input_params")) - .get(0) - .get("value") - .toString()); + VariableSelector selector = this.varTemplateToSelector(DSLDialectType.STUDIO, MapReadUtil + .safeCastToListWithMap(MapReadUtil.getMapDeepValue(data, List.class, "config", "input_params")) + .get(0) + .get("value") + .toString()); String outputKey = QuestionClassifierNodeData.getDefaultOutputSchema(DSLDialectType.STUDIO).getName(); List classes = Optional .ofNullable(MapReadUtil.safeCastToListWithMap( diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/QuestionClassifierNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/QuestionClassifierNodeSection.java index 1e9c142d33..c1a498b450 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/QuestionClassifierNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/QuestionClassifierNodeSection.java @@ -59,6 +59,7 @@ public String render(Node node, String varName) { @Override public String renderEdges(QuestionClassifierNodeData nodeData, List edges) { + Map classIdToName = nodeData.getClassIdToName(); // 规定edge的sourceHandle为caseId,前面的转化需要符合这条规则 String edgeCode = String.format(""" state -> { @@ -74,11 +75,12 @@ public String renderEdges(QuestionClassifierNodeData nodeData, List edges) if("%s".equals(result)) { return "%s"; } - """, id, id)) + """, id, classIdToName.getOrDefault(id, id))) .collect(Collectors.joining("\n"))); Map caseToTarget = edges.stream() - .collect(Collectors.toUnmodifiableMap(Edge::getSourceHandle, Edge::getTarget)); + .collect(Collectors.toUnmodifiableMap( + e -> classIdToName.getOrDefault(e.getSourceHandle(), e.getSourceHandle()), Edge::getTarget)); return String.format(""" // render QuestionNode [%s]'s edge From f7afc2b0ce0ddc466c1bd08894d8856ae7ea5b3b Mon Sep 17 00:00:00 2001 From: VLSMB <2047857654@qq.com> Date: Mon, 15 Sep 2025 21:20:45 +0800 Subject: [PATCH 06/12] feat: enhance AssignerNode (prepare for Studio DSL) --- .../cloud/ai/graph/node/AssignerNode.java | 83 +++++++++---------- .../cloud/ai/graph/node/AssignerNodeTest.java | 21 ++++- 2 files changed, 61 insertions(+), 43 deletions(-) diff --git a/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/AssignerNode.java b/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/AssignerNode.java index 4d483aeb21..aa92797f11 100644 --- a/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/AssignerNode.java +++ b/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/AssignerNode.java @@ -33,39 +33,25 @@ public class AssignerNode implements NodeAction { public enum WriteMode { - OVER_WRITE, APPEND, CLEAR + OVER_WRITE, APPEND, CLEAR, INPUT_CONSTANT } /** * description of a single assignment operation */ - public static class AssignItem { - - private final String targetKey; - - private final String inputKey; - - private final WriteMode writeMode; - + public record AssignItem(String targetKey, String inputKey, WriteMode writeMode, Object inputValue) { public AssignItem(String targetKey, String inputKey, WriteMode writeMode) { - this.targetKey = targetKey; - this.inputKey = inputKey; - this.writeMode = writeMode; + this(targetKey, inputKey, writeMode, null); } - public String getTargetKey() { - return targetKey; + public AssignItem(String targetKey, Object inputValue) { + this(targetKey, null, WriteMode.INPUT_CONSTANT, inputValue); } - public String getInputKey() { - return inputKey; + public AssignItem(String targetKey) { + this(targetKey, null, WriteMode.OVER_WRITE); } - - public WriteMode getWriteMode() { - return writeMode; - } - } private final List items; @@ -88,15 +74,12 @@ public AssignerNode(String targetKey, String inputKey, WriteMode writeMode) { public Map apply(OverAllState state) { Map updates = new HashMap<>(); for (AssignItem item : items) { - Object value = state.value(item.inputKey).orElse(null); - Object targetValue = state.value(item.targetKey).orElse(null); - Object result = null; - - switch (item.writeMode) { - case OVER_WRITE: - result = value; - break; - case APPEND: + Object value = state.value(item.inputKey()).orElse(null); + Object targetValue = state.value(item.targetKey()).orElse(null); + + Object result = switch (item.writeMode()) { + case OVER_WRITE -> value; + case APPEND -> { if (targetValue instanceof List && value != null) { List newList = new ArrayList<>((List) targetValue); if (value instanceof Collection col) { @@ -105,35 +88,41 @@ public Map apply(OverAllState state) { else { newList.add(value); } - result = newList; + yield newList; } else if (value != null) { if (value instanceof Collection col) { - result = new ArrayList<>(col); + yield new ArrayList<>(col); } else { - result = new ArrayList<>(List.of(value)); + yield new ArrayList<>(List.of(value)); } } - break; - case CLEAR: + else { + throw new IllegalArgumentException( + "Cannot append to non-list value for key: " + item.targetKey()); + } + } + case CLEAR -> { if (targetValue instanceof List) { - result = new ArrayList<>(); + yield new ArrayList<>(); } else if (targetValue instanceof Map) { - result = new HashMap<>(); + yield new HashMap<>(); } else if (targetValue instanceof String) { - result = ""; + yield ""; } else if (targetValue instanceof Number) { - result = 0; + yield 0; } else { - result = null; + yield null; } - break; - } + } + case INPUT_CONSTANT -> item.inputValue(); + default -> throw new IllegalArgumentException("Invalid write mode: " + item.writeMode()); + }; updates.put(item.targetKey, result); } return updates; @@ -153,6 +142,16 @@ public Builder addItem(String targetKey, String inputKey, WriteMode writeMode) { return this; } + public Builder addConst(String targetKey, Object inputValue) { + items.add(new AssignItem(targetKey, inputValue)); + return this; + } + + public Builder addClear(String targetKey) { + items.add(new AssignItem(targetKey)); + return this; + } + public Builder addItem(AssignItem item) { items.add(item); return this; diff --git a/spring-ai-alibaba-graph-core/src/test/java/com/alibaba/cloud/ai/graph/node/AssignerNodeTest.java b/spring-ai-alibaba-graph-core/src/test/java/com/alibaba/cloud/ai/graph/node/AssignerNodeTest.java index 829987094a..221617ae8c 100644 --- a/spring-ai-alibaba-graph-core/src/test/java/com/alibaba/cloud/ai/graph/node/AssignerNodeTest.java +++ b/spring-ai-alibaba-graph-core/src/test/java/com/alibaba/cloud/ai/graph/node/AssignerNodeTest.java @@ -96,28 +96,47 @@ public void testBatchClear() { assertEquals(0, result.get("z")); } + @Test + public void testBatchInputConstant() { + OverAllState state = new OverAllState(); + state.registerKeyAndStrategy("x", new ReplaceStrategy()); + state.registerKeyAndStrategy("y", new ReplaceStrategy()); + state.registerKeyAndStrategy("z", new ReplaceStrategy()); + state.updateState(Map.of("x", "something", "y", new ArrayList<>(List.of(1, 2, 3)), "z", 42)); + + AssignerNode node = AssignerNode.builder().addConst("x", "x").addConst("y", "y").addConst("z", "z").build(); + + Map result = node.apply(state); + assertEquals("x", result.get("x")); + assertEquals("y", result.get("y")); + assertEquals("z", result.get("z")); + } + @Test public void testMixBatch() { OverAllState state = new OverAllState(); state.registerKeyAndStrategy("a", new ReplaceStrategy()); state.registerKeyAndStrategy("b", new ReplaceStrategy()); state.registerKeyAndStrategy("c", new ReplaceStrategy()); + state.registerKeyAndStrategy("d", new ReplaceStrategy()); state.registerKeyAndStrategy("input1", new ReplaceStrategy()); state.registerKeyAndStrategy("input2", new ReplaceStrategy()); state.registerKeyAndStrategy("input3", new ReplaceStrategy()); state.updateState(Map.of("input1", "A", "input2", "B", "input3", "C", "a", new ArrayList<>(List.of("a0")), "b", - "to be cleared", "c", 999)); + "to be cleared", "c", 999, "d", false)); AssignerNode node = AssignerNode.builder() .addItem("a", "input1", AssignerNode.WriteMode.APPEND) .addItem("b", null, AssignerNode.WriteMode.CLEAR) .addItem("c", "input3", AssignerNode.WriteMode.OVER_WRITE) + .addConst("d", true) .build(); Map result = node.apply(state); assertEquals(List.of("a0", "A"), result.get("a")); assertEquals("", result.get("b")); assertEquals("C", result.get("c")); + assertTrue((Boolean) result.get("d")); } @Test From 23575219c1b24d66417682ae7677dadafa08c0ed Mon Sep 17 00:00:00 2001 From: VLSMB <2047857654@qq.com> Date: Tue, 16 Sep 2025 09:45:36 +0800 Subject: [PATCH 07/12] feat: support AssignerNode for Studio DSL --- .../cloud/ai/graph/node/AssignerNode.java | 7 +- .../generator/model/workflow/NodeType.java | 2 +- .../workflow/nodedata/AssignerNodeData.java | 140 +++++++---------- .../converter/AssignerNodeDataConverter.java | 145 +++++++++++------- .../sections/AssignerNodeSection.java | 29 ++-- 5 files changed, 163 insertions(+), 160 deletions(-) diff --git a/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/AssignerNode.java b/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/AssignerNode.java index aa92797f11..ba1a59e291 100644 --- a/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/AssignerNode.java +++ b/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/AssignerNode.java @@ -135,7 +135,12 @@ public static Builder builder() { public static class Builder { - private final List items = new ArrayList<>(); + private List items = new ArrayList<>(); + + public Builder setItems(List items) { + this.items = new ArrayList<>(items); + return this; + } public Builder addItem(String targetKey, String inputKey, WriteMode writeMode) { items.add(new AssignItem(targetKey, inputKey, writeMode)); diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/NodeType.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/NodeType.java index 41a9e879b5..eb61c9c7b1 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/NodeType.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/NodeType.java @@ -67,7 +67,7 @@ public enum NodeType { ITERATION_END("iteration-end", "iteration-end", "ParallelEnd"), - ASSIGNER("assigner", "assigner", "UNSUPPORTED"); + ASSIGNER("assigner", "assigner", "VariableAssign"); private final String value; diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/AssignerNodeData.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/AssignerNodeData.java index 36ab681e43..1542bbb4a9 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/AssignerNodeData.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/AssignerNodeData.java @@ -13,117 +13,83 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.alibaba.cloud.ai.studio.admin.generator.model.workflow.nodedata; -import java.util.List; +package com.alibaba.cloud.ai.studio.admin.generator.model.workflow.nodedata; import com.alibaba.cloud.ai.studio.admin.generator.model.VariableSelector; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.NodeData; +import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.DSLDialectType; +import com.alibaba.cloud.ai.studio.admin.generator.utils.ObjectToCodeUtil; -public class AssignerNodeData extends NodeData { - - private List items; - - private String outputKey; - - private String title; - - private String desc; - - private String version; - - public static class AssignerItem { - - private String inputType; +import java.util.List; +import java.util.function.Function; - private String operation; +public class AssignerNodeData extends NodeData { - private VariableSelector value; + private List items; - private VariableSelector variableSelector; + public List getItems() { + return items; + } - private String writeMode; + public void setItems(List items) { + this.items = items; + } - public String getInputType() { - return inputType; + public record AssignItem(VariableSelector targetSelector, VariableSelector inputSelector, WriteMode writeMode, + String inputConst) { + @Override + public String toString() { + return String.format("new AssignerNode.AssignItem(%s, %s, %s, %s)", + ObjectToCodeUtil.toCode(this.targetSelector().getNameInCode()), + ObjectToCodeUtil.toCode(this.inputSelector().getNameInCode()), + ObjectToCodeUtil.toCode(this.writeMode()), ObjectToCodeUtil.toCode(this.inputConst())); } + } - public void setInputType(String inputType) { - this.inputType = inputType; - } + private static final String UNSUPPORTED = "UNSUPPORTED"; - public String getOperation() { - return operation; - } + public enum WriteMode { - public void setOperation(String operation) { - this.operation = operation; - } + OVER_WRITE(type -> switch (type) { + case DIFY -> "over-write"; + case STUDIO -> "refer"; + default -> UNSUPPORTED; + }), - public VariableSelector getValue() { - return value; - } + APPEND(type -> UNSUPPORTED), - public void setValue(VariableSelector value) { - this.value = value; - } + CLEAR(type -> switch (type) { + case DIFY, STUDIO -> "clear"; + default -> UNSUPPORTED; + }), - public VariableSelector getVariableSelector() { - return variableSelector; - } + INPUT_CONSTANT(type -> switch (type) { + case DIFY -> "set"; + case STUDIO -> "input"; + default -> UNSUPPORTED; + }); - public void setVariableSelector(VariableSelector variableSelector) { - this.variableSelector = variableSelector; - } + private final Function dslValue; - public String getWriteMode() { - return writeMode; + WriteMode(Function dslValue) { + this.dslValue = dslValue; } - public void setWriteMode(String writeMode) { - this.writeMode = writeMode; + public static WriteMode fromDslValue(DSLDialectType dialectType, String dslValue) { + for (WriteMode mode : WriteMode.values()) { + if (mode.dslValue.apply(dialectType).equals(dslValue)) { + return mode; + } + } + throw new IllegalArgumentException("Invalid write mode: " + dslValue); } - } - - public List getItems() { - return items; - } - - public void setItems(List items) { - this.items = items; - } - - public String getOutputKey() { - return outputKey; - } - - public void setOutputKey(String outputKey) { - this.outputKey = outputKey; - } - - public String getTitle() { - return title; - } - - public void setTitle(String title) { - this.title = title; - } - - public String getDesc() { - return desc; - } - - public void setDesc(String desc) { - this.desc = desc; - } - - public String getVersion() { - return version; - } + @Override + public String toString() { + return "AssignerNode.WriteMode." + this.name(); + } - public void setVersion(String version) { - this.version = version; } } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/AssignerNodeDataConverter.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/AssignerNodeDataConverter.java index 864898ff0c..f42002ecc0 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/AssignerNodeDataConverter.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/AssignerNodeDataConverter.java @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package com.alibaba.cloud.ai.studio.admin.generator.service.dsl.converter; -import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.function.BiConsumer; +import java.util.function.Consumer; import java.util.stream.Stream; import com.alibaba.cloud.ai.studio.admin.generator.model.VariableSelector; @@ -26,10 +28,9 @@ import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.nodedata.AssignerNodeData; import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.AbstractNodeDataConverter; import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.DSLDialectType; -import com.fasterxml.jackson.databind.DeserializationFeature; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import com.alibaba.cloud.ai.studio.admin.generator.utils.MapReadUtil; +import com.fasterxml.jackson.core.JsonProcessingException; import org.springframework.stereotype.Component; @Component @@ -57,57 +58,92 @@ public Boolean supportDialect(DSLDialectType dialectType) { @Override public AssignerNodeData parse(Map data) { - ObjectMapper objectMapper = new ObjectMapper(); - objectMapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); - objectMapper.setPropertyNamingStrategy(PropertyNamingStrategies.SNAKE_CASE); - - List> itemsList = (List>) data.get("items"); - List items = itemsList.stream().map(item -> { - AssignerNodeData.AssignerItem ai = new AssignerNodeData.AssignerItem(); - ai.setInputType((String) item.get("input_type")); - ai.setOperation((String) item.get("operation")); - Object valueObj = item.get("value"); - if (valueObj instanceof List valueList && valueList.size() >= 2) { - ai.setValue(new VariableSelector(valueList.get(0).toString(), valueList.get(1).toString())); - } - Object variableObj = item.get("variable_selector"); - if (variableObj instanceof List variableList && variableList.size() >= 2) { - ai.setVariableSelector( - new VariableSelector(variableList.get(0).toString(), variableList.get(1).toString())); - } - ai.setWriteMode((String) item.get("write_mode")); - return ai; - }).toList(); - AssignerNodeData nodeData = new AssignerNodeData(); + List items = Stream + .ofNullable( + MapReadUtil.safeCastToListWithMap(MapReadUtil.getMapDeepValue(data, List.class, "items"))) + .flatMap(List::stream) + .filter(map -> map.containsKey("operation") && map.containsKey("variable_selector")) + .map(map -> { + List variableSelectorList = Optional + .ofNullable(MapReadUtil.safeCastToList( + MapReadUtil.getMapDeepValue(map, List.class, "variable_selector"), String.class)) + .orElseThrow(); + VariableSelector variableSelector = new VariableSelector(variableSelectorList.get(0), + variableSelectorList.get(1)); + + AssignerNodeData.WriteMode writeMode = AssignerNodeData.WriteMode.fromDslValue( + DSLDialectType.DIFY, MapReadUtil.getMapDeepValue(map, String.class, "operation")); + + VariableSelector inputSelector = null; + String inputConst = null; + if (AssignerNodeData.WriteMode.INPUT_CONSTANT.equals(writeMode)) { + inputConst = map.get("value").toString(); + } + else if (AssignerNodeData.WriteMode.OVER_WRITE.equals(writeMode)) { + List inputSelectorList = Optional + .ofNullable(MapReadUtil.safeCastToList( + MapReadUtil.getMapDeepValue(map, List.class, "value"), String.class)) + .orElseThrow(); + inputSelector = new VariableSelector(inputSelectorList.get(0), inputSelectorList.get(1)); + } + + return new AssignerNodeData.AssignItem(variableSelector, inputSelector, writeMode, inputConst); + }) + .toList(); nodeData.setItems(items); - nodeData.setTitle((String) data.get("title")); - nodeData.setDesc((String) data.get("desc")); - nodeData.setVersion((String) data.get("version")); - return nodeData; } @Override public Map dump(AssignerNodeData nodeData) { - Map dataMap = new HashMap<>(); - dataMap.put("type", "assigner"); - dataMap.put("title", nodeData.getTitle()); - dataMap.put("desc", nodeData.getDesc()); - dataMap.put("version", nodeData.getVersion()); - List> itemsList = nodeData.getItems() - .stream() - .map(item -> Map.of("input_type", item.getInputType(), "operation", item.getOperation(), "value", - item.getValue(), "variable_selector", item.getVariableSelector(), "write_mode", - item.getWriteMode())) + throw new UnsupportedOperationException(); + } + }), STUDIO(new DialectConverter<>() { + @Override + public Boolean supportDialect(DSLDialectType dialectType) { + return DSLDialectType.STUDIO.equals(dialectType); + } + + @Override + public AssignerNodeData parse(Map data) throws JsonProcessingException { + AssignerNodeData nodeData = new AssignerNodeData(); + List items = Stream + .ofNullable(MapReadUtil.safeCastToListWithMap( + MapReadUtil.getMapDeepValue(data, List.class, "config", "node_param", "inputs"))) + .flatMap(List::stream) + .filter(map -> map.containsKey("left") && map.containsKey("right")) + .map(map -> { + VariableSelector targetSelector = this.varTemplateToSelector(DSLDialectType.STUDIO, + MapReadUtil.getMapDeepValue(map, String.class, "left", "value")); + + AssignerNodeData.WriteMode writeMode = AssignerNodeData.WriteMode.fromDslValue( + DSLDialectType.STUDIO, + MapReadUtil.getMapDeepValue(map, String.class, "right", "value_from")); + VariableSelector inputSelector = null; + String inputConst = null; + if (AssignerNodeData.WriteMode.INPUT_CONSTANT.equals(writeMode)) { + inputConst = MapReadUtil.getMapDeepValue(map, String.class, "right", "value"); + } + else if (AssignerNodeData.WriteMode.OVER_WRITE.equals(writeMode)) { + inputSelector = this.varTemplateToSelector(DSLDialectType.STUDIO, + MapReadUtil.getMapDeepValue(map, String.class, "right", "value")); + } + + return new AssignerNodeData.AssignItem(targetSelector, inputSelector, writeMode, inputConst); + }) .toList(); - dataMap.put("items", itemsList); + nodeData.setItems(items); + return nodeData; + } - Map ret = new HashMap<>(); - ret.put("data", dataMap); - return ret; + @Override + public Map dump(AssignerNodeData nodeData) { + throw new UnsupportedOperationException(); } - }), CUSTOM(AbstractNodeDataConverter.defaultCustomDialectConverter(AssignerNodeData.class)); + }) + + , CUSTOM(AbstractNodeDataConverter.defaultCustomDialectConverter(AssignerNodeData.class)); private final DialectConverter dialectConverter; @@ -129,14 +165,17 @@ public String generateVarName(int count) { @Override public BiConsumer> postProcessConsumer(DSLDialectType dialectType) { return switch (dialectType) { - case DIFY -> emptyProcessConsumer().andThen((nodeData, idToVarName) -> { - // 将赋值的多组变量放进Inputs里,方便格式化格式 - List selectors = nodeData.getItems() - .stream() - .flatMap(item -> Stream.of(item.getValue(), item.getVariableSelector())) - .toList(); - nodeData.setInputs(selectors); - }).andThen(super.postProcessConsumer(dialectType)); + case DIFY, STUDIO -> super.postProcessConsumer(dialectType).andThen((nodeData, idToVarName) -> { + nodeData.getItems().forEach(item -> { + Consumer consumer = selector -> { + selector + .setNameInCode(idToVarName.getOrDefault(selector.getNamespace(), selector.getNamespace()) + + "_" + selector.getName()); + }; + Optional.ofNullable(item.targetSelector()).ifPresent(consumer); + Optional.ofNullable(item.inputSelector()).ifPresent(consumer); + }); + }); default -> super.postProcessConsumer(dialectType); }; } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/AssignerNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/AssignerNodeSection.java index 1bfdb1e242..107c3c20e5 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/AssignerNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/AssignerNodeSection.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package com.alibaba.cloud.ai.studio.admin.generator.service.generator.workflow.sections; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.Node; @@ -20,6 +21,7 @@ import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.nodedata.AssignerNodeData; import com.alibaba.cloud.ai.studio.admin.generator.service.generator.workflow.NodeSection; +import com.alibaba.cloud.ai.studio.admin.generator.utils.ObjectToCodeUtil; import org.springframework.stereotype.Component; import java.util.List; @@ -34,24 +36,15 @@ public boolean support(NodeType nodeType) { @Override public String render(Node node, String varName) { - AssignerNodeData data = (AssignerNodeData) node.getData(); - String id = node.getId(); - - StringBuilder sb = new StringBuilder(); - sb.append(String.format("// —— AssignerNode [%s] ——%n", id)); - sb.append(String.format("AssignerNode %s = AssignerNode.builder()%n", varName)); - for (AssignerNodeData.AssignerItem item : data.getItems()) { - String targetKey = item.getVariableSelector() != null ? item.getVariableSelector().getNameInCode() - : "target"; - String inputKey = (item.getValue() != null) ? item.getValue().getNameInCode() : null; - String writeMode = item.getWriteMode() != null ? item.getWriteMode().toUpperCase().replace("-", "_") - : "OVER_WRITE"; - sb.append(String.format(".addItem(\"%s\", %s, AssignerNode.WriteMode.%s)%n", targetKey, - inputKey == null ? "null" : "\"" + inputKey + "\"", writeMode)); - } - sb.append(".build();\n"); - sb.append(String.format("stateGraph.addNode(\"%s\", AsyncNodeAction.node_async(%s));%n%n", varName, varName)); - return sb.toString(); + AssignerNodeData nodeData = ((AssignerNodeData) node.getData()); + return String.format(""" + // —— AssignerNode [%s] —— + AssignerNode %s = AssignerNode.builder() + .setItems(%s) + .build(); + stateGraph.addNode("%s", AsyncNodeAction.node_async(%s)); + + """, node.getId(), varName, ObjectToCodeUtil.toCode(nodeData.getItems()), varName, varName); } @Override From 95a565545cdd5418b1e2a31316821ef1e86c7951 Mon Sep 17 00:00:00 2001 From: VLSMB <2047857654@qq.com> Date: Tue, 16 Sep 2025 10:05:00 +0800 Subject: [PATCH 08/12] fix bugs --- .../workflow/nodedata/AssignerNodeData.java | 5 +- .../service/dsl/adapters/DifyDSLAdapter.java | 46 ++++++++----------- 2 files changed, 22 insertions(+), 29 deletions(-) diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/AssignerNodeData.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/AssignerNodeData.java index 1542bbb4a9..74fbf45238 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/AssignerNodeData.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/AssignerNodeData.java @@ -41,8 +41,9 @@ public record AssignItem(VariableSelector targetSelector, VariableSelector input @Override public String toString() { return String.format("new AssignerNode.AssignItem(%s, %s, %s, %s)", - ObjectToCodeUtil.toCode(this.targetSelector().getNameInCode()), - ObjectToCodeUtil.toCode(this.inputSelector().getNameInCode()), + ObjectToCodeUtil + .toCode(this.targetSelector() != null ? this.targetSelector().getNameInCode() : null), + ObjectToCodeUtil.toCode(this.inputSelector() != null ? this.inputSelector().getNameInCode() : null), ObjectToCodeUtil.toCode(this.writeMode()), ObjectToCodeUtil.toCode(this.inputConst())); } } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/adapters/DifyDSLAdapter.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/adapters/DifyDSLAdapter.java index 75690d4bcb..7038c2c598 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/adapters/DifyDSLAdapter.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/adapters/DifyDSLAdapter.java @@ -111,7 +111,7 @@ public Map metadataToMap(AppMetadata metadata) { @Override public Workflow mapToWorkflow(Map data) { - Map workflowData = (Map) data.get("workflow"); + Map workflowData = MapReadUtil.safeCastToMapWithStringKey(data.get("workflow")); Workflow workflow = new Workflow(); ObjectMapper objectMapper = new ObjectMapper(); objectMapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); @@ -119,21 +119,16 @@ public Workflow mapToWorkflow(Map data) { objectMapper.setPropertyNamingStrategy(PropertyNamingStrategies.SNAKE_CASE); List convVars = new ArrayList<>(); if (workflowData.containsKey("conversation_variables")) { - List> variables = (List>) workflowData - .get("conversation_variables"); - convVars = variables.stream() - .map(variable -> convertToVariable(variable, objectMapper)) - .peek(v -> v.setName("conversation_" + v.getName())) - .toList(); + List> variables = MapReadUtil + .safeCastToListWithMap(workflowData.get("conversation_variables")); + convVars = variables.stream().map(this::convertToVariable).toList(); } List envVars = List.of(); if (workflowData.containsKey("environment_variables")) { - List> variables = (List>) workflowData.get("environment_variables"); - envVars = variables.stream() - .map(variable -> convertToVariable(variable, objectMapper)) - .peek(v -> v.setName("env_" + v.getName())) - .toList(); + List> variables = MapReadUtil + .safeCastToListWithMap(workflowData.get("environment_variables")); + envVars = variables.stream().map(this::convertToVariable).toList(); } List sysVars = List.of(new Variable("sys_query", VariableType.STRING), new Variable("sys_files", VariableType.ARRAY_FILE), @@ -144,7 +139,7 @@ public Workflow mapToWorkflow(Map data) { new Variable("sys_workflow_run_id", VariableType.STRING)); workflow.setEnvVars(Stream.of(envVars, sysVars).flatMap(List::stream).toList()); - Graph graph = constructGraph((Map) workflowData.get("graph")); + Graph graph = constructGraph(MapReadUtil.safeCastToMapWithStringKey(workflowData.get("graph"))); workflow.setGraph(graph); // register overAllState output key @@ -404,20 +399,17 @@ public Boolean supportDialect(DSLDialectType dialectType) { return DSLDialectType.DIFY.equals(dialectType); } - private Variable convertToVariable(Map variableMap, ObjectMapper objectMapper) { - try { - Map processedMap = new HashMap<>(variableMap); - - Object value = processedMap.get("value"); - if (value != null && !(value instanceof String)) { - processedMap.put("value", objectMapper.writeValueAsString(value)); - } - - return objectMapper.convertValue(processedMap, Variable.class); - } - catch (Exception e) { - throw new IllegalArgumentException("Failed to convert variable: " + variableMap, e); - } + private Variable convertToVariable(Map variableMap) { + String name = String.join("_", + Optional.ofNullable(MapReadUtil.safeCastToList(variableMap.get("selector"), String.class)) + .orElseThrow(() -> new IllegalArgumentException("Invalid variable selector"))); + String value = Optional.ofNullable(variableMap.get("value")).map(Object::toString).orElse(null); + VariableType type = VariableType + .fromDifyValue(Optional.ofNullable(variableMap.get("value_type")) + .map(Object::toString) + .orElse(VariableType.OBJECT.difyValue())) + .orElse(VariableType.OBJECT); + return new Variable(name, type).setValue(value); } } From a02c88af2d940e0714e94fa69abae6c167efdf45 Mon Sep 17 00:00:00 2001 From: VLSMB <2047857654@qq.com> Date: Tue, 16 Sep 2025 12:06:38 +0800 Subject: [PATCH 09/12] feat: enhance ParameterParsingNode --- .../ai/graph/node/ParameterParsingNode.java | 267 ++++++++++++------ .../graph/node/ParameterParsingNodeTest.java | 98 +++++++ 2 files changed, 275 insertions(+), 90 deletions(-) create mode 100644 spring-ai-alibaba-graph-core/src/test/java/com/alibaba/cloud/ai/graph/node/ParameterParsingNodeTest.java diff --git a/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/ParameterParsingNode.java b/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/ParameterParsingNode.java index f464b01141..7bb083bc17 100644 --- a/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/ParameterParsingNode.java +++ b/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/ParameterParsingNode.java @@ -17,13 +17,15 @@ import com.alibaba.cloud.ai.graph.OverAllState; import com.alibaba.cloud.ai.graph.action.NodeAction; -import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.chat.prompt.SystemPromptTemplate; import org.springframework.util.StringUtils; @@ -31,140 +33,203 @@ import java.util.HashMap; import java.util.List; import java.util.Map; - +import java.util.Optional; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +/** + * 从自然语言中提取参数。返回三部分:{@code is_success}、{@code data}、{@code reason} + */ public class ParameterParsingNode implements NodeAction { private static final String PARAMETER_PARSING_PROMPT_TEMPLATE = """ ### Role - You are a JSON-based structured data extractor. Your task is to extract parameter values - from user input. + You are a JSON-based structured data extractor. Your task is to extract parameter values from user input and return a structured response. + ### Task - Given user input and a list of expected parameters (with names, types, and descriptions - Type can be "string", "number", "boolean", or "array"), - return a valid JSON object containing those parameter values. + Given user input and a list of expected parameters (with names, types, and descriptions; type can be "String", "Number", "Boolean", or "List"), return a JSON object with the following structure: + \\{ + "data": \\{ ... \\}, // extracted parameter values + "is_success": boolean, // true if extraction successful + "reason": string // set to 'success' when successful, explain the reason for failure + \\} + ### Input Text: {inputText} + ### Parameters: {parameters} + ### Output Constraints - - Return ONLY a valid JSON object containing all defined keys. - - Missing values must be set to null. - - DO NOT include any explanation, markdown, or preamble. - - Output must be directly parsable as JSON. + - ALWAYS return a valid JSON object with exactly three keys: "data", "is_success", and "reason" + - For successful extraction: + - "is_success": true + - "data": contains all defined parameters with extracted values (null for missing values) + - "reason": "success" + - For unsuccessful extraction (e.g., input is empty, completely irrelevant, or missing all required parameters): + - "is_success": false + - "data": null + - "reason": brief explanation of failure (e.g., "input is empty", "no relevant parameters found") + - DO NOT include any explanation, markdown, or preamble + - Output must be directly parsable as JSON + - Ensure the JSON is properly formatted and escaped """; private static final String PARAMETER_PARSING_USER_PROMPT_1 = """ - { "input_text":[" Please help me check the paper, paper number: 2405.10739 ."], - "Parameters":{"name":[paper_num],“type”:[string],"description":["paper number"]}} + { "input_text": "[Instruction: Please help me check the paper] paper number: 2405.10739 .", + "Parameters":{"name": "paper_num", "type": "String", "description": "paper number"}} """; private static final String PARAMETER_PARSING_ASSISTANT_PROMPT_1 = """ - json - {"paper_num": "2405.10739"} + {"is_success": true, "data": {"paper_num": "2405.10739"}, "reason": "success"} """; private static final String PARAMETER_PARSING_USER_PROMPT_2 = """ - { "input_text":[" Chapter 1: Encounters. The sun shines in the forest, and the young man sees the girl for the first time. - Chapter 2: The Storm. The village was attacked, and its fate changed dramatically. - Chapter 3: Departure. They embark on a journey to find the truth."], - "Parameters":{"name":[array_of_story_outlines],“type”:[array],"description":["the story outlines"]}} + { "input_text": null, + "Parameters":{"name": "array_of_story_outlines", "type": "List","description": "the story outlines"}} """; private static final String PARAMETER_PARSING_ASSISTANT_PROMPT_2 = """ - json - { - "array_of_story_outlines": [ - "Chapter 1: Encounters. The sun shines in the forest, and the young man sees the girl for the first time.", - "Chapter 2: The Storm. The village was attacked, and its fate changed dramatically.", - "Chapter 3: Departure. They embark on a journey to find the truth." - ] - } + {"is_success": false, "data": null, "reason": "input_text is null."} """; + private static final SystemPromptTemplate SYSTEM_PROMPT_TEMPLATE = new SystemPromptTemplate( + PARAMETER_PARSING_PROMPT_TEMPLATE); + + private static final Pattern VAR_TEMPLATE_PATTERN = Pattern.compile("\\{(\\w+)}"); + + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + private final ChatClient chatClient; - private String inputText; + private final String inputText; private final String inputTextKey; - private final List> parameters; + private final List parameters; + + private final String successKey; - private final SystemPromptTemplate systemPromptTemplate; + private final String dataKey; - private final String outputKey; + private final String reasonKey; - // 由用户提供的一个指令 + // 由用户提供的一个指令,可以有{}占位符号 private final String instruction; - public ParameterParsingNode(ChatClient chatClient, String inputTextKey, List> parameters, - String outputKey, String instruction) { + public ParameterParsingNode(ChatClient chatClient, String inputText, String inputTextKey, List parameters, + String instruction, String successKey, String dataKey, String reasonKey) { + if (chatClient == null || !StringUtils.hasText(successKey) || !StringUtils.hasText(dataKey) + || !StringUtils.hasText(reasonKey)) { + throw new IllegalArgumentException("There are some empty fields"); + } this.chatClient = chatClient; + this.inputText = inputText; this.inputTextKey = inputTextKey; this.parameters = parameters; - this.outputKey = outputKey; - this.systemPromptTemplate = new SystemPromptTemplate(PARAMETER_PARSING_PROMPT_TEMPLATE); this.instruction = instruction; + this.successKey = successKey; + this.dataKey = dataKey; + this.reasonKey = reasonKey; + } + + private String renderTemplate(OverAllState state, String template) { + Map params = Stream.of(template) + .map(VAR_TEMPLATE_PATTERN::matcher) + .map(Matcher::results) + .map(results -> results.collect(Collectors.toUnmodifiableMap(r -> r.group(1), + r -> state.value(r.group(1)).orElse(""), (a, b) -> b))) + .findFirst() + .orElseThrow(); + return new PromptTemplate(template).render(params); } @Override public Map apply(OverAllState state) throws Exception { + try { + String currentInputText = this.inputText; + if (StringUtils.hasText(inputTextKey)) { + currentInputText = (String) state.value(inputTextKey).orElse(currentInputText); + } + if (!StringUtils.hasText(currentInputText)) { + throw new IllegalArgumentException("inputText is empty."); + } - if (StringUtils.hasLength(inputTextKey)) { - this.inputText = (String) state.value(inputTextKey).orElse(this.inputText); - } + Map promptInput = new HashMap<>(); + promptInput.put("inputText", + String.format("[Instruction: %s] %s", renderTemplate(state, instruction), currentInputText)); + promptInput.put("parameters", OBJECT_MAPPER.writeValueAsString(parameters)); + + List messages = new ArrayList<>(); + UserMessage userMessage1 = new UserMessage(PARAMETER_PARSING_USER_PROMPT_1); + AssistantMessage assistantMessage1 = new AssistantMessage(PARAMETER_PARSING_ASSISTANT_PROMPT_1); + UserMessage userMessage2 = new UserMessage(PARAMETER_PARSING_USER_PROMPT_2); + AssistantMessage assistantMessage2 = new AssistantMessage(PARAMETER_PARSING_ASSISTANT_PROMPT_2); + messages.add(userMessage1); + messages.add(assistantMessage1); + messages.add(userMessage2); + messages.add(assistantMessage2); + + ChatResponse response = chatClient.prompt() + .system(SYSTEM_PROMPT_TEMPLATE.render(promptInput)) + .user(currentInputText) + .messages(messages) + .call() + .chatResponse(); + + String rawJson = Optional.ofNullable(response) + .orElseThrow(() -> new RuntimeException("chat response is null")) + .getResult() + .getOutput() + .getText(); + // 去掉Markdown标记 + if (rawJson != null) { + rawJson = rawJson.replace("```json", "").replace("```", "").trim(); + } - Map promptInput = new HashMap<>(); - promptInput.put("inputText", String.format("[Instruction: %s]%n", instruction) + inputText); - promptInput.put("parameters", formatParameters(parameters)); - - List messages = new ArrayList<>(); - UserMessage userMessage1 = new UserMessage(PARAMETER_PARSING_USER_PROMPT_1); - AssistantMessage assistantMessage1 = new AssistantMessage(PARAMETER_PARSING_ASSISTANT_PROMPT_1); - UserMessage userMessage2 = new UserMessage(PARAMETER_PARSING_USER_PROMPT_2); - AssistantMessage assistantMessage2 = new AssistantMessage(PARAMETER_PARSING_ASSISTANT_PROMPT_2); - messages.add(userMessage1); - messages.add(assistantMessage1); - messages.add(userMessage2); - messages.add(assistantMessage2); - - ChatResponse response = chatClient.prompt() - .system(systemPromptTemplate.render(promptInput)) - .user(inputText) - .messages(messages) - .call() - .chatResponse(); - - String rawJson = response.getResult().getOutput().getText(); - // 去掉Markdown标记 - if (rawJson != null) { - rawJson = rawJson.replace("```json", "").replace("```", "").trim(); - } - ObjectMapper mapper = new ObjectMapper(); - Map updateState = new HashMap<>(); - try { - updateState.put(outputKey, mapper.readValue(rawJson, new TypeReference<>() { - })); + Map result = new HashMap<>(); + Response responseJson; + try { + responseJson = OBJECT_MAPPER.readValue(rawJson, Response.class); + } + catch (JsonProcessingException e) { + throw new RuntimeException("ChatClient successfully returned, but the returned json is invalid."); + } + + if (responseJson.isSuccess()) { + if (responseJson.data() == null) { + throw new RuntimeException("ChatClient successfully returned, but the returned data is invalid."); + } + result.put(successKey, true); + result.put(dataKey, responseJson.data()); + result.put(reasonKey, "success"); + } + else { + result.put(successKey, false); + result.put(reasonKey, Optional.ofNullable(responseJson.reason()).orElse("reason is empty")); + } + + return result; } catch (Exception e) { - throw new RuntimeException("Invalid JSON response from model: " + rawJson, e); + return Map.of(successKey, false, reasonKey, e.getMessage()); } - return updateState; } - private String formatParameters(List> parameters) { - StringBuilder builder = new StringBuilder(); - for (Map param : parameters) { - builder.append("- ") - .append(param.get("name")) - .append(" (") - .append(param.get("type")) - .append("): ") - .append(param.get("description")) - .append("\n"); - } - return builder.toString(); + public record Param(String name, String type, String description) { + + } + + private record Response(@JsonProperty("is_success") boolean isSuccess, Map data, String reason) { + + } + + public static Param param(String name, String type, String description) { + return new Param(name, type, description); } public static Builder builder() { @@ -173,16 +238,27 @@ public static Builder builder() { public static class Builder { + private String inputText = ""; + private String inputTextKey; private ChatClient chatClient; - private List> parameters; + private List parameters; - private String outputKey; + private String successKey = "is_success"; + + private String dataKey = "data"; + + private String reasonKey = "reason"; private String instruction = ""; + public Builder inputText(String inputText) { + this.inputText = inputText; + return this; + } + public Builder inputTextKey(String input) { this.inputTextKey = input; return this; @@ -193,13 +269,23 @@ public Builder chatClient(ChatClient chatClient) { return this; } - public Builder parameters(List> parameters) { + public Builder parameters(List parameters) { this.parameters = parameters; return this; } - public Builder outputKey(String outputKey) { - this.outputKey = outputKey; + public Builder successKey(String successKey) { + this.successKey = successKey; + return this; + } + + public Builder dataKey(String dataKey) { + this.dataKey = dataKey; + return this; + } + + public Builder reasonKey(String reasonKey) { + this.reasonKey = reasonKey; return this; } @@ -209,7 +295,8 @@ public Builder instruction(String instruction) { } public ParameterParsingNode build() { - return new ParameterParsingNode(chatClient, inputTextKey, parameters, outputKey, instruction); + return new ParameterParsingNode(chatClient, inputText, inputTextKey, parameters, instruction, successKey, + dataKey, reasonKey); } } diff --git a/spring-ai-alibaba-graph-core/src/test/java/com/alibaba/cloud/ai/graph/node/ParameterParsingNodeTest.java b/spring-ai-alibaba-graph-core/src/test/java/com/alibaba/cloud/ai/graph/node/ParameterParsingNodeTest.java new file mode 100644 index 0000000000..28fe0305d2 --- /dev/null +++ b/spring-ai-alibaba-graph-core/src/test/java/com/alibaba/cloud/ai/graph/node/ParameterParsingNodeTest.java @@ -0,0 +1,98 @@ +/* + * Copyright 2024-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.cloud.ai.graph.node; + +import com.alibaba.cloud.ai.dashscope.api.DashScopeApi; +import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatModel; +import com.alibaba.cloud.ai.graph.OverAllState; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.model.ChatModel; + +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +@EnabledIfEnvironmentVariable(named = "AI_DASHSCOPE_API_KEY", matches = ".+") +public class ParameterParsingNodeTest { + + private ChatClient chatClient; + + @BeforeEach + public void setUp() { + DashScopeApi dashScopeApi = DashScopeApi.builder().apiKey(System.getenv("AI_DASHSCOPE_API_KEY")).build(); + ChatModel chatModel = DashScopeChatModel.builder().dashScopeApi(dashScopeApi).build(); + chatClient = ChatClient.builder(chatModel).build(); + } + + private OverAllState createState(Map map) { + OverAllState state = new OverAllState(); + state.updateState(map); + return state; + } + + @Test + public void testSuccess() throws Exception { + ParameterParsingNode node = ParameterParsingNode.builder() + .inputText("") + .inputTextKey("input") + .chatClient(chatClient) + .parameters(List.of(ParameterParsingNode.param("name", "String", "The name of the person"), + ParameterParsingNode.param("age", "Number", "The age of the person"))) + .successKey("success") + .dataKey("data") + .reasonKey("reason") + .instruction("Parse the input text into a JSON object with the following keys: name, age") + .build(); + OverAllState state = createState(Map.of("input", "My name is Kanbe Kotori and I am 20 years old.")); + Map result = node.apply(state); + System.out.println(result); + assertNotNull(result); + assertEquals(Map.of("success", true, "data", Map.of("name", "Kanbe Kotori", "age", 20), "reason", "success"), + result); + } + + @Test + public void testFail() throws Exception { + ParameterParsingNode node = ParameterParsingNode.builder() + .inputText("") + .inputTextKey("input") + .chatClient(chatClient) + .parameters(List.of(ParameterParsingNode.param("name", "String", "The name of the person"), + ParameterParsingNode.param("age", "Number", "The age of the person"))) + .successKey("success") + .dataKey("data") + .reasonKey("reason") + .instruction("Parse the input text into a JSON object with the following keys: name, age") + .build(); + OverAllState state = createState(Map.of()); + Map result = node.apply(state); + System.out.println(result); + assertNotNull(result); + assertTrue(result.containsKey("success")); + assertTrue(result.containsKey("reason")); + assertFalse(result.containsKey("data")); + assertFalse((Boolean) result.get("success")); + } + +} From b7528c8c7248243dfadc451d207b60974538a888 Mon Sep 17 00:00:00 2001 From: VLSMB <2047857654@qq.com> Date: Tue, 16 Sep 2025 12:47:02 +0800 Subject: [PATCH 10/12] refactor: exact common code --- .../dsl/AbstractNodeDataConverter.java | 48 +++++++++++++++++++ .../dsl/converter/LLMNodeDataConverter.java | 22 ++------- .../QuestionClassifyNodeDataConverter.java | 20 ++------ .../workflow/sections/LLMNodeSection.java | 8 ++-- .../QuestionClassifierNodeSection.java | 8 ++-- .../templates/GraphBuilder.java.mustache | 2 +- 6 files changed, 65 insertions(+), 43 deletions(-) diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/AbstractNodeDataConverter.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/AbstractNodeDataConverter.java index 4532460de6..57e03cac99 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/AbstractNodeDataConverter.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/AbstractNodeDataConverter.java @@ -17,14 +17,17 @@ import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.function.BiConsumer; import java.util.function.BiFunction; import java.util.regex.MatchResult; import java.util.regex.Matcher; import java.util.regex.Pattern; +import java.util.stream.Collectors; import com.alibaba.cloud.ai.studio.admin.generator.model.VariableSelector; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.NodeData; +import com.alibaba.cloud.ai.studio.admin.generator.utils.MapReadUtil; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; @@ -96,6 +99,51 @@ default VariableSelector varTemplateToSelector(DSLDialectType dialectType, Strin return new VariableSelector(result.group(1), result.group(2)); } + /** + * 从data中获取模型名称(LLMNode、ClassifierNode等共同使用) + * @param dialectType dsl语言 + * @param data 节点数据 + * @return 模型名称 + */ + default String exactChatModelName(DSLDialectType dialectType, Map data) { + return switch (dialectType) { + case DIFY -> MapReadUtil.getMapDeepValue(data, String.class, "model", "name"); + case STUDIO -> { + Map modeConfigMap = MapReadUtil.safeCastToMapWithStringKey( + MapReadUtil.getMapDeepValue(data, Map.class, "config", "node_param", "model_config")); + yield MapReadUtil.getMapDeepValue(modeConfigMap, String.class, "model_id"); + } + default -> throw new UnsupportedOperationException(); + }; + } + + /** + * 从data中获取模型参数(LLMNode、ClassifierNode等共同使用) + * @param dialectType dsl语言 + * @param data 节点数据 + * @return 模型参数 + */ + default Map exactChatModelParam(DSLDialectType dialectType, Map data) { + return switch (dialectType) { + case DIFY -> MapReadUtil.safeCastToMapWithStringKey( + MapReadUtil.getMapDeepValue(data, Map.class, "model", "completion_params")); + case STUDIO -> { + Map modeConfigMap = MapReadUtil.safeCastToMapWithStringKey( + MapReadUtil.getMapDeepValue(data, Map.class, "config", "node_param", "model_config")); + yield Optional + .ofNullable(MapReadUtil + .safeCastToListWithMap(MapReadUtil.getMapDeepValue(modeConfigMap, List.class, "params"))) + .orElse(List.of()) + .stream() + .filter(map -> Boolean.TRUE.equals(map.get("enable"))) + .filter(map -> map.containsKey("key") && map.containsKey("value")) + .map(map -> Map.entry(map.get("key").toString(), map.get("value"))) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (a, b) -> b)); + } + default -> throw new UnsupportedOperationException(); + }; + } + } public static DialectConverter defaultCustomDialectConverter(Class clazz) { diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/LLMNodeDataConverter.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/LLMNodeDataConverter.java index 643f359220..170246187f 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/LLMNodeDataConverter.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/LLMNodeDataConverter.java @@ -59,9 +59,8 @@ public LLMNodeData parse(Map data) { LLMNodeData nodeData = new LLMNodeData(); // 获取必要的信息 - String modeName = MapReadUtil.getMapDeepValue(data, String.class, "model", "name"); - Map modeParams = MapReadUtil.safeCastToMapWithStringKey( - MapReadUtil.getMapDeepValue(data, Map.class, "model", "completion_params")); + String modeName = this.exactChatModelName(DSLDialectType.DIFY, data); + Map modeParams = this.exactChatModelParam(DSLDialectType.DIFY, data); // MessageTemplate的keys字段将在postProcess中确定,所以这里先设置为空 List messageTemplates = Optional @@ -113,7 +112,7 @@ public Boolean supportDialect(DSLDialectType dialect) { } }) - , STUDIO(new DialectConverter() { + , STUDIO(new DialectConverter<>() { @Override public Boolean supportDialect(DSLDialectType dialectType) { return DSLDialectType.STUDIO.equals(dialectType); @@ -124,19 +123,8 @@ public LLMNodeData parse(Map data) throws JsonProcessingExceptio LLMNodeData nodeData = new LLMNodeData(); // 从data中提取必要信息 - Map modeConfigMap = MapReadUtil.safeCastToMapWithStringKey( - MapReadUtil.getMapDeepValue(data, Map.class, "config", "node_param", "model_config")); - String modeName = MapReadUtil.getMapDeepValue(modeConfigMap, String.class, "model_id"); - - Map modeParams = Optional - .ofNullable(MapReadUtil - .safeCastToListWithMap(MapReadUtil.getMapDeepValue(modeConfigMap, List.class, "params"))) - .orElse(List.of()) - .stream() - .filter(map -> Boolean.TRUE.equals(map.get("enable"))) - .filter(map -> map.containsKey("key") && map.containsKey("value")) - .map(map -> Map.entry(map.get("key").toString(), map.get("value"))) - .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (a, b) -> b)); + String modeName = this.exactChatModelName(DSLDialectType.STUDIO, data); + Map modeParams = this.exactChatModelParam(DSLDialectType.STUDIO, data); String systemPrompt = MapReadUtil.getMapDeepValue(data, String.class, "config", "node_param", "sys_prompt_content"); diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/QuestionClassifyNodeDataConverter.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/QuestionClassifyNodeDataConverter.java index bf37fe5295..14ae6f83cd 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/QuestionClassifyNodeDataConverter.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/QuestionClassifyNodeDataConverter.java @@ -65,9 +65,8 @@ public QuestionClassifierNodeData parse(Map data) { QuestionClassifierNodeData nodeData = new QuestionClassifierNodeData(); // 获取必要的信息 - String modeName = MapReadUtil.getMapDeepValue(data, String.class, "model", "name"); - Map modeParams = MapReadUtil.safeCastToMapWithStringKey( - MapReadUtil.getMapDeepValue(data, Map.class, "model", "completion_params")); + String modeName = this.exactChatModelName(DSLDialectType.DIFY, data); + Map modeParams = this.exactChatModelParam(DSLDialectType.DIFY, data); List inputSelectorList = Optional .ofNullable(MapReadUtil.safeCastToList( MapReadUtil.getMapDeepValue(data, List.class, "query_variable_selector"), String.class)) @@ -113,18 +112,9 @@ public Boolean supportDialect(DSLDialectType dialectType) { public QuestionClassifierNodeData parse(Map data) throws JsonProcessingException { QuestionClassifierNodeData nodeData = new QuestionClassifierNodeData(); // 从data中提取必要信息 - Map modeConfigMap = MapReadUtil.safeCastToMapWithStringKey( - MapReadUtil.getMapDeepValue(data, Map.class, "config", "node_param", "model_config")); - String modeName = MapReadUtil.getMapDeepValue(modeConfigMap, String.class, "model_id"); - Map modeParams = Optional - .ofNullable(MapReadUtil - .safeCastToListWithMap(MapReadUtil.getMapDeepValue(modeConfigMap, List.class, "params"))) - .orElse(List.of()) - .stream() - .filter(map -> Boolean.TRUE.equals(map.get("enable"))) - .filter(map -> map.containsKey("key") && map.containsKey("value")) - .map(map -> Map.entry(map.get("key").toString(), map.get("value"))) - .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (a, b) -> b)); + String modeName = this.exactChatModelName(DSLDialectType.STUDIO, data); + Map modeParams = this.exactChatModelParam(DSLDialectType.STUDIO, data); + VariableSelector selector = this.varTemplateToSelector(DSLDialectType.STUDIO, MapReadUtil .safeCastToListWithMap(MapReadUtil.getMapDeepValue(data, List.class, "config", "input_params")) .get(0) diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/LLMNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/LLMNodeSection.java index 2e61c9229d..dbcfeeee66 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/LLMNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/LLMNodeSection.java @@ -42,7 +42,7 @@ public String render(Node node, String varName) { return String.format(""" // —— LLMNode [%s] —— stateGraph.addNode("%s", AsyncNodeAction.node_async( - createLLMNodeAction(%s, %s, %s, %s, %s, %s, %s, %s, %s) + createLLMNodeAction(chatModel, %s, %s, %s, %s, %s, %s, %s, %s, %s) )); """, node.getId(), varName, ObjectToCodeUtil.toCode(nodeData.getChatModeName()), @@ -60,9 +60,6 @@ public String assistMethodCode(DSLDialectType dialectType) { return String.format( """ - @Autowired - private ChatModel chatModel; - private record MessageTemplate(String template, List keys, MessageType type) { public Message render(OverAllState state) { Map params = keys.stream() @@ -77,7 +74,8 @@ public Message render(OverAllState state) { } } - private NodeAction createLLMNodeAction(String chatModelName, Map modeParams, + private NodeAction createLLMNodeAction(ChatModel chatModel, + String chatModelName, Map modeParams, List messageTemplates, String memoryKey, Integer maxRetryCount, Integer retryIntervalMs, String defaultOutput, String errorNextNode, String outputKeyPrefix) { // build chatClient with params diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/QuestionClassifierNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/QuestionClassifierNodeSection.java index c1a498b450..99ef9f8cec 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/QuestionClassifierNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/QuestionClassifierNodeSection.java @@ -44,7 +44,7 @@ public String render(Node node, String varName) { return String.format(""" // —— QuestionClassifierNode [%s] —— stateGraph.addNode("%s", AsyncNodeAction.node_async( - createQuestionClassifierAction(%s, %s, "%s", "%s", %s, %s) + createQuestionClassifierAction(chatModel, %s, %s, "%s", "%s", %s, %s) )); """, node.getId(), varName, ObjectToCodeUtil.toCode(nodeData.getChatModeName()), @@ -94,10 +94,8 @@ public String assistMethodCode(DSLDialectType dialectType) { return switch (dialectType) { case DIFY, STUDIO -> """ - @Autowired - private ChatModel chatModelForQuestion; - private NodeAction createQuestionClassifierAction( + ChatModel chatModel, String chatModelName, Map modeParams, String inputKey, String outputKey, Map categories, List instructions) { @@ -112,7 +110,7 @@ private NodeAction createQuestionClassifierAction( .ifPresent(val -> chatOptionsBuilder.withMaxToken(val.intValue())); Optional.ofNullable(modeParams.get("repetition_penalty")) .ifPresent(val -> chatOptionsBuilder.withRepetitionPenalty(val.doubleValue())); - final ChatClient chatClient = ChatClient.builder(chatModelForQuestion).defaultOptions(chatOptionsBuilder.build()).build(); + final ChatClient chatClient = ChatClient.builder(chatModel).defaultOptions(chatOptionsBuilder.build()).build(); // build Node return QuestionClassifierNode.builder() diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/resources/templates/GraphBuilder.java.mustache b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/resources/templates/GraphBuilder.java.mustache index d6431ab57a..07f4dc12ed 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/resources/templates/GraphBuilder.java.mustache +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/resources/templates/GraphBuilder.java.mustache @@ -8,7 +8,7 @@ public class GraphBuilder { {{assistMethodCode}} @Bean - public CompiledGraph buildGraph() throws Exception { + public CompiledGraph buildGraph(ChatModel chatModel) throws Exception { // new stateGraph StateGraph stateGraph = new StateGraph({{stateSection}}); // add nodes From 7e9b39b34b0f14e0bf106c61a5bce26361663460 Mon Sep 17 00:00:00 2001 From: VLSMB <2047857654@qq.com> Date: Tue, 16 Sep 2025 18:16:05 +0800 Subject: [PATCH 11/12] feat: support ParameterParsingNode for Studio DSL --- .../generator/model/workflow/NodeType.java | 2 +- .../nodedata/ParameterParsingNodeData.java | 105 +++++++++--- .../ParameterParsingNodeDataConverter.java | 155 ++++++++++++------ .../sections/ParameterParsingNodeSection.java | 113 ++++++------- 4 files changed, 245 insertions(+), 130 deletions(-) diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/NodeType.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/NodeType.java index eb61c9c7b1..34820c2e90 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/NodeType.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/NodeType.java @@ -51,7 +51,7 @@ public enum NodeType { LIST_OPERATOR("list-operator", "list-operator", "UNSUPPORTED"), - PARAMETER_PARSING("parameter-parsing", "parameter-extractor", "UNSUPPORTED"), + PARAMETER_PARSING("parameter-parsing", "parameter-extractor", "ParameterExtractor"), TOOL("tool", "tool", "UNSUPPORTED"), diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/ParameterParsingNodeData.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/ParameterParsingNodeData.java index 0fcf6391fc..3a1c93be3c 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/ParameterParsingNodeData.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/model/workflow/nodedata/ParameterParsingNodeData.java @@ -16,48 +16,83 @@ package com.alibaba.cloud.ai.studio.admin.generator.model.workflow.nodedata; -import java.util.List; -import java.util.Map; - +import com.alibaba.cloud.ai.studio.admin.generator.model.Variable; import com.alibaba.cloud.ai.studio.admin.generator.model.VariableSelector; +import com.alibaba.cloud.ai.studio.admin.generator.model.VariableType; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.NodeData; +import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.DSLDialectType; +import com.alibaba.cloud.ai.studio.admin.generator.utils.ObjectToCodeUtil; + +import java.util.List; +import java.util.Map; +import java.util.Optional; -/** - * NodeData for ParameterParsingNode, which contains three fields: inputTextKey, - * parameters, and outputKey. - */ public class ParameterParsingNodeData extends NodeData { - private String inputTextKey; + public static List getDefaultOutputSchema(DSLDialectType dialectType) { + return switch (dialectType) { + case DIFY -> List.of(new Variable("__is_success", VariableType.BOOLEAN), + new Variable("__reason", VariableType.STRING)); + case STUDIO -> List.of(new Variable("_is_completed", VariableType.BOOLEAN), + new Variable("_reason", VariableType.STRING)); + default -> List.of(); + }; + } + + private VariableSelector inputSelector; + + private String chatModeName; - private List> parameters; + private Map modeParams; + + private List parameters; private String instruction; - private String outputKey; + private String successKey = "success"; - public ParameterParsingNodeData(String inputTextKey, List> parameters, String instruction, - String outputKey, VariableSelector input) { - super(List.of(input), List.of()); - this.inputTextKey = inputTextKey; - this.parameters = parameters; - this.instruction = instruction; - this.outputKey = outputKey; + private String dataKey = "data"; + + private String reasonKey = "reason"; + + public record Param(String name, VariableType type, String description) { + @Override + public String toString() { + return String.format("ParameterParsingNode.param(%s, %s, %s)", ObjectToCodeUtil.toCode(name()), + ObjectToCodeUtil.toCode(Optional.ofNullable(type()).orElse(VariableType.STRING).value()), + ObjectToCodeUtil.toCode(description())); + } } - public String getInputTextKey() { - return inputTextKey; + public VariableSelector getInputSelector() { + return inputSelector; } - public void setInputTextKey(String inputTextKey) { - this.inputTextKey = inputTextKey; + public void setInputSelector(VariableSelector inputSelector) { + this.inputSelector = inputSelector; } - public List> getParameters() { + public String getChatModeName() { + return chatModeName; + } + + public void setChatModeName(String chatModeName) { + this.chatModeName = chatModeName; + } + + public Map getModeParams() { + return modeParams; + } + + public void setModeParams(Map modeParams) { + this.modeParams = modeParams; + } + + public List getParameters() { return parameters; } - public void setParameters(List> parameters) { + public void setParameters(List parameters) { this.parameters = parameters; } @@ -69,12 +104,28 @@ public void setInstruction(String instruction) { this.instruction = instruction; } - public String getOutputKey() { - return outputKey; + public String getSuccessKey() { + return successKey; + } + + public void setSuccessKey(String successKey) { + this.successKey = successKey; + } + + public String getDataKey() { + return dataKey; + } + + public void setDataKey(String dataKey) { + this.dataKey = dataKey; + } + + public String getReasonKey() { + return reasonKey; } - public void setOutputKey(String outputKey) { - this.outputKey = outputKey; + public void setReasonKey(String reasonKey) { + this.reasonKey = reasonKey; } } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/ParameterParsingNodeDataConverter.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/ParameterParsingNodeDataConverter.java index eaa5626418..e7d0f0b3cb 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/ParameterParsingNodeDataConverter.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/ParameterParsingNodeDataConverter.java @@ -16,9 +16,9 @@ package com.alibaba.cloud.ai.studio.admin.generator.service.dsl.converter; -import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.function.BiConsumer; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -31,6 +31,8 @@ import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.AbstractNodeDataConverter; import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.DSLDialectType; +import com.alibaba.cloud.ai.studio.admin.generator.utils.MapReadUtil; +import com.fasterxml.jackson.core.JsonProcessingException; import org.springframework.stereotype.Component; /** @@ -55,44 +57,43 @@ protected List> getDialectConverters( private enum ParameterParsingNodeConverter { DIFY(new DialectConverter<>() { - @SuppressWarnings("unchecked") @Override public ParameterParsingNodeData parse(Map data) { - // 获取指令 - String instruction = data.getOrDefault("instruction", "unknown").toString(); - Object query = data.get("query"); - // 获取输入 - VariableSelector input; - if (query instanceof List queryList && queryList.size() >= 2) { - input = new VariableSelector(queryList.get(0).toString(), queryList.get(1).toString()); - } - else { - input = new VariableSelector("unknown", "unknown"); - } - // 获取输出参数 - List> parametersList = (List>) data.getOrDefault("parameters", - List.of()); - return new ParameterParsingNodeData("input", parametersList, instruction, "output", input); + ParameterParsingNodeData nodeData = new ParameterParsingNodeData(); + + // 获取必要信息 + List selectorList = MapReadUtil.safeCastToList(data.get("query"), String.class); + VariableSelector selector = new VariableSelector(selectorList.get(0), selectorList.get(1)); + String chatModelName = this.exactChatModelName(DSLDialectType.DIFY, data); + Map modelParams = this.exactChatModelParam(DSLDialectType.DIFY, data); + List params = MapReadUtil.safeCastToListWithMap(data.get("parameters")) + .stream() + .filter(map -> map.containsKey("name")) + .map(map -> { + String name = map.get("name").toString(); + String description = map.getOrDefault("description", "").toString(); + VariableType type = VariableType + .fromDifyValue(map.getOrDefault("type", VariableType.OBJECT.difyValue()).toString()) + .orElse(VariableType.OBJECT); + return new ParameterParsingNodeData.Param(name, type, description); + }) + .toList(); + String instruction = Optional.ofNullable(data.get("instruction")).map(Object::toString).orElse(""); + + // 设置信息 + nodeData.setInputSelector(selector); + nodeData.setChatModeName(chatModelName); + nodeData.setModeParams(modelParams); + nodeData.setParameters(params); + nodeData.setInstruction(instruction); + nodeData.setSuccessKey("__is_success"); + nodeData.setReasonKey("__reason"); + return nodeData; } @Override public Map dump(ParameterParsingNodeData nd) { - Map m = new LinkedHashMap<>(); - - if (nd.getInstruction() != null) { - m.put("instruction", nd.getInstruction()); - } - - if (nd.getParameters() != null) { - m.put("parameters", nd.getParameters()); - } - - if (nd.getInputs() != null && !nd.getInputs().isEmpty()) { - VariableSelector selector = nd.getInputs().get(0); - m.put("query", List.of(selector.getNamespace(), selector.getName())); - } - - return m; + throw new UnsupportedOperationException(); } @Override @@ -101,7 +102,62 @@ public Boolean supportDialect(DSLDialectType dialect) { } }), - CUSTOM(defaultCustomDialectConverter(ParameterParsingNodeData.class)); + STUDIO(new DialectConverter<>() { + @Override + public Boolean supportDialect(DSLDialectType dialectType) { + return DSLDialectType.STUDIO.equals(dialectType); + } + + @Override + public ParameterParsingNodeData parse(Map data) throws JsonProcessingException { + ParameterParsingNodeData nodeData = new ParameterParsingNodeData(); + + // 获取必要信息 + VariableSelector selector = this.varTemplateToSelector(DSLDialectType.STUDIO, MapReadUtil + .safeCastToListWithMap(MapReadUtil.getMapDeepValue(data, String.class, "config", "input_params")) + .get(0) + .get("value") + .toString()); + String chatModelName = this.exactChatModelName(DSLDialectType.STUDIO, data); + Map modelParams = this.exactChatModelParam(DSLDialectType.STUDIO, data); + List params = Optional + .ofNullable(MapReadUtil.safeCastToListWithMap( + MapReadUtil.getMapDeepValue(data, String.class, "config", "node_param", "extract_params"))) + .orElse(List.of()) + .stream() + .filter(map -> map.containsKey("key")) + .map(map -> { + String name = map.get("key").toString(); + String description = map.getOrDefault("desc", "").toString(); + VariableType type = VariableType + .fromStudioValue(map.getOrDefault("type", VariableType.OBJECT.studioValue()).toString()) + .orElse(VariableType.OBJECT); + return new ParameterParsingNodeData.Param(name, type, description); + }) + .toList(); + String instruction = Optional + .ofNullable(MapReadUtil.getMapDeepValue(data, String.class, "config", "node_param", "instruction")) + .map(Object::toString) + .orElse(""); + + // 设置必要信息 + nodeData.setInputSelector(selector); + nodeData.setChatModeName(chatModelName); + nodeData.setModeParams(modelParams); + nodeData.setParameters(params); + nodeData.setInstruction(instruction); + nodeData.setSuccessKey("_is_completed"); + nodeData.setReasonKey("_reason"); + return nodeData; + } + + @Override + public Map dump(ParameterParsingNodeData nodeData) { + throw new UnsupportedOperationException(); + } + }) + + , CUSTOM(defaultCustomDialectConverter(ParameterParsingNodeData.class)); private final DialectConverter converter; @@ -123,19 +179,26 @@ public String generateVarName(int count) { @Override public BiConsumer> postProcessConsumer(DSLDialectType dialectType) { return switch (dialectType) { - case DIFY -> emptyProcessConsumer().andThen((nodeData, map) -> { - nodeData.setOutputKey(nodeData.getVarName() + "_output"); - List variableList = nodeData.getParameters() - .stream() - .map(mp -> new Variable(mp.getOrDefault("name", "unknown").toString(), - VariableType.fromDifyValue(mp.getOrDefault("type", "string").toString()) - .orElse(VariableType.OBJECT)) - .setDescription(mp.getOrDefault("description", "").toString())) + case DIFY, STUDIO -> emptyProcessConsumer().andThen((nodeData, idToVarName) -> { + // 设置输出 + List outputs = Stream + .concat(nodeData.getParameters().stream().map(p -> new Variable(p.name(), p.type())), + ParameterParsingNodeData.getDefaultOutputSchema(dialectType).stream()) .toList(); - nodeData.setOutputs(variableList); - }).andThen(super.postProcessConsumer(dialectType)).andThen((nodeData, varName) -> { - nodeData.setInputTextKey(nodeData.getInputs().get(0).getNameInCode()); - }); + nodeData.setOutputs(outputs); + + // 设置输入以及key + Optional.ofNullable(nodeData.getInputSelector()) + .ifPresent(selector -> selector + .setNameInCode(idToVarName.getOrDefault(selector.getNamespace(), selector.getNamespace()) + "_" + + selector.getName())); + nodeData.setSuccessKey(nodeData.getVarName() + "_" + nodeData.getSuccessKey()); + nodeData.setReasonKey(nodeData.getVarName() + "_" + nodeData.getReasonKey()); + nodeData.setDataKey(nodeData.getVarName() + "_" + nodeData.getDataKey()); + + // 格式化instruction + nodeData.setInstruction(this.convertVarTemplate(dialectType, nodeData.getInstruction(), idToVarName)); + }).andThen(super.postProcessConsumer(dialectType)); default -> super.postProcessConsumer(dialectType); }; } diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/ParameterParsingNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/ParameterParsingNodeSection.java index 00f0f81809..0d0dedb1b7 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/ParameterParsingNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/ParameterParsingNodeSection.java @@ -17,9 +17,6 @@ package com.alibaba.cloud.ai.studio.admin.generator.service.generator.workflow.sections; import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; -import java.util.stream.Stream; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.Node; import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.NodeType; @@ -27,8 +24,8 @@ import com.alibaba.cloud.ai.studio.admin.generator.service.dsl.DSLDialectType; import com.alibaba.cloud.ai.studio.admin.generator.service.generator.workflow.NodeSection; +import com.alibaba.cloud.ai.studio.admin.generator.utils.ObjectToCodeUtil; import org.springframework.stereotype.Component; -import org.springframework.util.CollectionUtils; @Component public class ParameterParsingNodeSection implements NodeSection { @@ -40,63 +37,67 @@ public boolean support(NodeType nodeType) { @Override public String render(Node node, String varName) { - ParameterParsingNodeData d = (ParameterParsingNodeData) node.getData(); - String id = node.getId(); - StringBuilder sb = new StringBuilder(); - - sb.append(String.format("// —— ParameterParsingNode [%s] ——%n", id)); - sb.append(String.format("ParameterParsingNode %s = ParameterParsingNode.builder()%n", varName)); - - if (d.getInputTextKey() != null) { - sb.append(String.format(".inputTextKey(\"%s\")%n", escape(d.getInputTextKey()))); - } - - sb.append(".chatClient(chatClient)\n"); - - List> params = d.getParameters(); - if (!CollectionUtils.isEmpty(params)) { - String joined = params.stream().map(m -> { - String mapCode = Stream - .of("name", m.getOrDefault("name", "unknown").toString(), "type", - m.getOrDefault("type", "string").toString(), "description", - m.getOrDefault("description", "").toString()) - .map(s -> "\"" + s + "\"") - .collect(Collectors.joining(", ")); - return String.format("Map.of(%s)", mapCode); - }).collect(Collectors.joining(", ")); - sb.append(String.format(".parameters(List.of(%s))%n", joined)); - } - - if (d.getOutputKey() != null) { - sb.append(String.format(".outputKey(\"%s\")%n", escape(d.getOutputKey()))); - } - - sb.append(".build();\n"); - - // 辅助节点 - String assistNodeCode = String.format("wrapperParameterNodeAction(%s, \"%s\", \"%s\")", varName, varName, - d.getOutputKey()); - - sb.append(String.format("stateGraph.addNode(\"%s\", AsyncNodeAction.node_async(%s));%n%n", varName, - assistNodeCode)); - - return sb.toString(); + ParameterParsingNodeData nodeData = ((ParameterParsingNodeData) node.getData()); + return String.format(""" + // -- ParameterParsingNode [%s] -- + stateGraph.addNode("%s", AsyncNodeAction.node_async( + createParameterParsingAction(chatModel, %s, %s, %s, %s, %s, %s, %s, %s) + )); + + """, node.getId(), varName, ObjectToCodeUtil.toCode(nodeData.getChatModeName()), + ObjectToCodeUtil.toCode(nodeData.getModeParams()), + ObjectToCodeUtil.toCode(nodeData.getInputSelector().getNameInCode()), + ObjectToCodeUtil.toCode(nodeData.getParameters()), ObjectToCodeUtil.toCode(nodeData.getSuccessKey()), + ObjectToCodeUtil.toCode(nodeData.getDataKey()), ObjectToCodeUtil.toCode(nodeData.getReasonKey()), + ObjectToCodeUtil.toCode(nodeData.getInstruction())); } @Override public String assistMethodCode(DSLDialectType dialectType) { return switch (dialectType) { - case DIFY -> + case DIFY, STUDIO -> """ - private NodeAction wrapperParameterNodeAction(NodeAction nodeAction, String nodeName, String key) { - return (state) -> { - Map result = nodeAction.apply(state); - Object object = result.get(key); - if(!(object instanceof Map map)) { - return Map.of(); - } - return map.entrySet().stream().collect(Collectors.toMap(e -> nodeName + "_" + e.getKey(), Map.Entry::getValue)); - }; + private NodeAction createParameterParsingAction( + ChatModel chatModel, + String chatModelName, Map modeParams, + String inputKey, List parameters, + String successKey, String dataKey, String reasonKey, String instruction) { + // build ChatClient + var chatOptionsBuilder = DashScopeChatOptions.builder().withModel(chatModelName); + Optional.ofNullable(modeParams.get("temperature")) + .ifPresent(val -> chatOptionsBuilder.withTemperature(val.doubleValue())); + Optional.ofNullable(modeParams.get("seed")).ifPresent(val -> chatOptionsBuilder.withSeed(val.intValue())); + Optional.ofNullable(modeParams.get("top_p")).ifPresent(val -> chatOptionsBuilder.withTopP(val.doubleValue())); + Optional.ofNullable(modeParams.get("top_k")).ifPresent(val -> chatOptionsBuilder.withTopK(val.intValue())); + Optional.ofNullable(modeParams.get("max_tokens")) + .ifPresent(val -> chatOptionsBuilder.withMaxToken(val.intValue())); + Optional.ofNullable(modeParams.get("repetition_penalty")) + .ifPresent(val -> chatOptionsBuilder.withRepetitionPenalty(val.doubleValue())); + final ChatClient chatClient = ChatClient.builder(chatModel).defaultOptions(chatOptionsBuilder.build()).build(); + + // build Node + ParameterParsingNode node = ParameterParsingNode.builder() + .inputText("") + .inputTextKey(inputKey) + .chatClient(chatClient) + .parameters(parameters) + .successKey(successKey) + .dataKey(dataKey) + .reasonKey(reasonKey) + .instruction(instruction) + .build(); + + // unpack answer + return state -> { + Map res = node.apply(state); + if(!(Boolean) res.get(successKey)) { + return res; + } + Map finalRes = new HashMap<>(res); + Map data = (Map) finalRes.remove(dataKey); + finalRes.putAll(data); + return finalRes; + }; } """; default -> ""; @@ -105,7 +106,7 @@ private NodeAction wrapperParameterNodeAction(NodeAction nodeAction, String node @Override public List getImports() { - return List.of("com.alibaba.cloud.ai.graph.node.ParameterParsingNode", "java.util.stream.Collectors"); + return List.of("com.alibaba.cloud.ai.graph.node.ParameterParsingNode"); } } From 78a173039341b1bc762d4ba4c06fcfeec0831a0c Mon Sep 17 00:00:00 2001 From: VLSMB <2047857654@qq.com> Date: Tue, 16 Sep 2025 18:34:44 +0800 Subject: [PATCH 12/12] fix bugs --- .../ParameterParsingNodeDataConverter.java | 4 ++-- .../sections/ParameterParsingNodeSection.java | 17 ++++++++++++----- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/ParameterParsingNodeDataConverter.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/ParameterParsingNodeDataConverter.java index e7d0f0b3cb..ffc95ea456 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/ParameterParsingNodeDataConverter.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/ParameterParsingNodeDataConverter.java @@ -114,7 +114,7 @@ public ParameterParsingNodeData parse(Map data) throws JsonProce // 获取必要信息 VariableSelector selector = this.varTemplateToSelector(DSLDialectType.STUDIO, MapReadUtil - .safeCastToListWithMap(MapReadUtil.getMapDeepValue(data, String.class, "config", "input_params")) + .safeCastToListWithMap(MapReadUtil.getMapDeepValue(data, List.class, "config", "input_params")) .get(0) .get("value") .toString()); @@ -122,7 +122,7 @@ public ParameterParsingNodeData parse(Map data) throws JsonProce Map modelParams = this.exactChatModelParam(DSLDialectType.STUDIO, data); List params = Optional .ofNullable(MapReadUtil.safeCastToListWithMap( - MapReadUtil.getMapDeepValue(data, String.class, "config", "node_param", "extract_params"))) + MapReadUtil.getMapDeepValue(data, List.class, "config", "node_param", "extract_params"))) .orElse(List.of()) .stream() .filter(map -> map.containsKey("key")) diff --git a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/ParameterParsingNodeSection.java b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/ParameterParsingNodeSection.java index 0d0dedb1b7..851d065339 100644 --- a/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/ParameterParsingNodeSection.java +++ b/spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/ParameterParsingNodeSection.java @@ -41,7 +41,7 @@ public String render(Node node, String varName) { return String.format(""" // -- ParameterParsingNode [%s] -- stateGraph.addNode("%s", AsyncNodeAction.node_async( - createParameterParsingAction(chatModel, %s, %s, %s, %s, %s, %s, %s, %s) + createParameterParsingAction(chatModel, %s, %s, %s, %s, %s, %s, %s, %s, "%s") )); """, node.getId(), varName, ObjectToCodeUtil.toCode(nodeData.getChatModeName()), @@ -49,7 +49,7 @@ public String render(Node node, String varName) { ObjectToCodeUtil.toCode(nodeData.getInputSelector().getNameInCode()), ObjectToCodeUtil.toCode(nodeData.getParameters()), ObjectToCodeUtil.toCode(nodeData.getSuccessKey()), ObjectToCodeUtil.toCode(nodeData.getDataKey()), ObjectToCodeUtil.toCode(nodeData.getReasonKey()), - ObjectToCodeUtil.toCode(nodeData.getInstruction())); + ObjectToCodeUtil.toCode(nodeData.getInstruction()), varName); } @Override @@ -61,7 +61,7 @@ private NodeAction createParameterParsingAction( ChatModel chatModel, String chatModelName, Map modeParams, String inputKey, List parameters, - String successKey, String dataKey, String reasonKey, String instruction) { + String successKey, String dataKey, String reasonKey, String instruction, String outputKeyPrefix) { // build ChatClient var chatOptionsBuilder = DashScopeChatOptions.builder().withModel(chatModelName); Optional.ofNullable(modeParams.get("temperature")) @@ -95,7 +95,12 @@ private NodeAction createParameterParsingAction( } Map finalRes = new HashMap<>(res); Map data = (Map) finalRes.remove(dataKey); - finalRes.putAll(data); + finalRes.putAll(data.entrySet() + .stream() + .map(e -> + Map.entry(outputKeyPrefix + "_" + e.getKey(), e.getValue())) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)) + ); return finalRes; }; } @@ -106,7 +111,9 @@ private NodeAction createParameterParsingAction( @Override public List getImports() { - return List.of("com.alibaba.cloud.ai.graph.node.ParameterParsingNode"); + return List.of("com.alibaba.cloud.ai.graph.node.ParameterParsingNode", + "com.alibaba.cloud.ai.dashscope.chat.DashScopeChatOptions", "java.util.Optional", + "java.util.stream.Collectors"); } }