Skip to content

Commit 18da21e

Browse files
committed
feat: support CodeNodeSection for Studio DSL
1 parent a774fe2 commit 18da21e

File tree

4 files changed

+118
-16
lines changed

4 files changed

+118
-16
lines changed

spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/AbstractNodeDataConverter.java

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,11 @@
1919
import java.util.Map;
2020
import java.util.function.BiConsumer;
2121
import java.util.function.BiFunction;
22+
import java.util.regex.MatchResult;
2223
import java.util.regex.Matcher;
2324
import java.util.regex.Pattern;
2425

26+
import com.alibaba.cloud.ai.studio.admin.generator.model.VariableSelector;
2527
import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.NodeData;
2628
import com.fasterxml.jackson.core.JsonProcessingException;
2729
import com.fasterxml.jackson.core.type.TypeReference;
@@ -72,6 +74,23 @@ public interface DialectConverter<T> {
7274

7375
Map<String, Object> dump(T nodeData);
7476

77+
/**
78+
* 将模板字符串转换为变量选择器
79+
* @param dialectType dsl语言
80+
* @param template 模板字符串
81+
* @return 变量选择器
82+
*/
83+
default VariableSelector varTemplateToSelector(DSLDialectType dialectType, String template) {
84+
Pattern pattern = switch (dialectType) {
85+
case DIFY -> DIFY_VAR_TEMPLATE_PATTERN;
86+
case STUDIO -> STUDIO_VAR_TEMPLATE_PATTERN;
87+
default -> throw new UnsupportedOperationException();
88+
};
89+
Matcher matcher = pattern.matcher(template);
90+
MatchResult result = matcher.results().findFirst().orElseThrow();
91+
return new VariableSelector(result.group(1), result.group(2));
92+
}
93+
7594
}
7695

7796
public static <R> DialectConverter<R> defaultCustomDialectConverter(Class<R> clazz) {
@@ -100,6 +119,12 @@ public Map<String, Object> dump(R nodeData) {
100119

101120
protected abstract List<DialectConverter<T>> getDialectConverters();
102121

122+
private static final Pattern DIFY_VAR_TEMPLATE_PATTERN = Pattern.compile("\\{\\{#(\\w+)\\.(\\w+)#}}");
123+
124+
private static final Pattern STUDIO_VAR_TEMPLATE_PATTERN = Pattern.compile("\\$\\{(\\w+)\\.(\\w+)}");
125+
126+
private static final Pattern VAR_TEMPLATE_PATTERN = Pattern.compile("\\{(\\w+)}");
127+
103128
/**
104129
* 将文本中变量占位符进行转化,比如Dify DSL的"你好,{{#123.query#}}"转化为"你好,{nodeName1_query}"
105130
* @param dialectType dsl语言
@@ -116,8 +141,7 @@ protected String convertVarTemplate(DSLDialectType dialectType, String templateS
116141
return str;
117142
}
118143
StringBuilder result = new StringBuilder();
119-
Pattern pattern = Pattern.compile("\\{\\{#(\\w+)\\.(\\w+)#}}");
120-
Matcher matcher = pattern.matcher(str);
144+
Matcher matcher = DIFY_VAR_TEMPLATE_PATTERN.matcher(str);
121145
while (matcher.find()) {
122146
String nodeId = matcher.group(1);
123147
String varName = matcher.group(2);
@@ -133,8 +157,8 @@ protected String convertVarTemplate(DSLDialectType dialectType, String templateS
133157
return str;
134158
}
135159
StringBuilder result = new StringBuilder();
136-
Pattern pattern = Pattern.compile("\\$\\{(\\w+)\\.(\\w+)}");
137-
Matcher matcher = pattern.matcher(str);
160+
161+
Matcher matcher = STUDIO_VAR_TEMPLATE_PATTERN.matcher(str);
138162
while (matcher.find()) {
139163
String nodeId = matcher.group(1);
140164
String varName = matcher.group(2);
@@ -150,8 +174,6 @@ protected String convertVarTemplate(DSLDialectType dialectType, String templateS
150174
return func.apply(templateString, idToVarName);
151175
}
152176

153-
private static final Pattern VAR_TEMPLATE_PATTERN = Pattern.compile("\\{(\\w+)}");
154-
155177
/**
156178
* 获取模板中的变量占位符,比如"你好{var1},{var2}"返回"[var1, var2]"
157179
* @param template 模板字符串

spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/dsl/converter/CodeNodeDataConverter.java

Lines changed: 78 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import java.util.stream.Stream;
2525

2626
import com.alibaba.cloud.ai.studio.admin.generator.model.Variable;
27+
import com.alibaba.cloud.ai.studio.admin.generator.model.VariableSelector;
2728
import com.alibaba.cloud.ai.studio.admin.generator.model.VariableType;
2829
import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.NodeType;
2930
import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.nodedata.CodeNodeData;
@@ -138,7 +139,82 @@ public Boolean supportDialect(DSLDialectType dialectType) {
138139

139140
@Override
140141
public CodeNodeData parse(Map<String, Object> data) throws JsonProcessingException {
141-
return null;
142+
CodeNodeData nodeData = new CodeNodeData();
143+
144+
// 获取基本信息
145+
String code = MapReadUtil.getMapDeepValue(data, String.class, "config", "node_param", "script_content");
146+
String lang = MapReadUtil.getMapDeepValue(data, String.class, "config", "node_param", "script_type");
147+
Boolean isRetry = Optional
148+
.ofNullable(MapReadUtil.getMapDeepValue(data, Boolean.class, "config", "node_param", "retry_config",
149+
"retry_enabled"))
150+
.orElse(false);
151+
int maxRetryCount = isRetry ? Optional
152+
.ofNullable(MapReadUtil.getMapDeepValue(data, Integer.class, "config", "node_param", "retry_config",
153+
"max_retries"))
154+
.orElse(1) : 1;
155+
int retryIntervalMs = isRetry ? Optional
156+
.ofNullable(MapReadUtil.getMapDeepValue(data, Integer.class, "config", "node_param", "retry_config",
157+
"retry_interval"))
158+
.orElse(1000) : 1000;
159+
160+
List<Variable> outputParams = Optional
161+
.ofNullable(MapReadUtil.safeCastToListWithMap(
162+
MapReadUtil.getMapDeepValue(data, List.class, "config", "output_params")))
163+
.orElse(List.of())
164+
.stream()
165+
.filter(map -> map.containsKey("key"))
166+
.map(map -> new Variable(map.get("key").toString(), VariableType.OBJECT))
167+
.toList();
168+
List<CodeNodeData.CodeParam> inputParams = Optional
169+
.ofNullable(MapReadUtil
170+
.safeCastToListWithMap(MapReadUtil.getMapDeepValue(data, List.class, "config", "input_params")))
171+
.orElse(List.of())
172+
.stream()
173+
.filter(map -> map.containsKey("key") && map.containsKey("value") && map.containsKey("value_from"))
174+
.map(map -> {
175+
String key = map.get("key").toString();
176+
Object value = map.get("value");
177+
String valueFrom = map.get("value_from").toString();
178+
if ("input".equalsIgnoreCase(valueFrom)) {
179+
return CodeNodeData.CodeParam.withValue(key, value);
180+
}
181+
else {
182+
// 先以Value的形式存储selector,在post阶段转换为正确的stateKey
183+
VariableSelector selector = this.varTemplateToSelector(DSLDialectType.STUDIO,
184+
value.toString());
185+
List<String> list = List.of(selector.getNamespace(), selector.getName());
186+
return new CodeNodeData.CodeParam(key, list, value.toString());
187+
}
188+
})
189+
.toList();
190+
191+
// 设置基本信息
192+
nodeData.setCodeStyle(CodeNodeData.CodeStyle.GLOBAL_DICTIONARY);
193+
nodeData.setCode(code);
194+
nodeData.setCodeLanguage(lang);
195+
nodeData.setMaxRetryCount(maxRetryCount);
196+
nodeData.setRetryIntervalMs(retryIntervalMs);
197+
nodeData.setInputParams(inputParams);
198+
nodeData.setOutputs(outputParams);
199+
200+
// 设置错误策略
201+
String errorStrategy = MapReadUtil.getMapDeepValue(data, String.class, "config", "node_param",
202+
"try_catch_config", "strategy");
203+
if (errorStrategy != null) {
204+
// 暂仅支持默认值
205+
List<Map<String, Object>> defaultValueList = MapReadUtil
206+
.safeCastToListWithMap(MapReadUtil.getMapDeepValue(data, List.class, "config", "node_param",
207+
"try_catch_config", "default_values"));
208+
if (defaultValueList != null) {
209+
Map<String, Object> defaultValue = defaultValueList.stream()
210+
.filter(map -> map.containsKey("key") && map.containsKey("value"))
211+
.collect(Collectors.toUnmodifiableMap(map -> map.get("key").toString(),
212+
map -> map.get("value"), (a, b) -> b));
213+
nodeData.setDefaultValue(defaultValue);
214+
}
215+
}
216+
217+
return nodeData;
142218
}
143219

144220
@Override
@@ -180,7 +256,7 @@ public BiConsumer<CodeNodeData, Map<String, String>> postProcessConsumer(DSLDial
180256
@SuppressWarnings("unchecked")
181257
List<String> selector = (List<String>) param.value();
182258
return CodeNodeData.CodeParam.withKey(param.argName(),
183-
idToVarName.get(selector.get(0)) + "_" + selector.get(1));
259+
idToVarName.getOrDefault(selector.get(0), selector.get(0)) + "_" + selector.get(1));
184260
}).toList());
185261
}).andThen(super.postProcessConsumer(dialectType));
186262
default -> super.postProcessConsumer(dialectType);

spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/WorkflowProjectGenerator.java

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -237,12 +237,14 @@ private String renderImportSection(Workflow workflow) {
237237
Map.entry(NodeType.ANSWER, List.of("com.alibaba.cloud.ai.graph.node.AnswerNode")),
238238
Map.entry(NodeType.MIDDLE_OUTPUT,
239239
List.of("java.util.stream.Collectors", "org.springframework.ai.chat.prompt.PromptTemplate")),
240-
Map.entry(NodeType.CODE, List.of("com.alibaba.cloud.ai.graph.node.code.CodeExecutorNodeAction",
241-
"com.alibaba.cloud.ai.graph.node.code.entity.CodeExecutionConfig",
242-
"com.alibaba.cloud.ai.graph.node.code.CodeExecutor",
243-
"com.alibaba.cloud.ai.graph.node.code.LocalCommandlineCodeExecutor", "java.io.IOException",
244-
"java.nio.file.Files", "java.nio.file.Path", "java.util.stream.Collectors",
245-
"com.alibaba.cloud.ai.graph.node.code.entity.CodeParam")),
240+
Map.entry(NodeType.CODE,
241+
List.of("com.alibaba.cloud.ai.graph.node.code.CodeExecutorNodeAction",
242+
"com.alibaba.cloud.ai.graph.node.code.entity.CodeExecutionConfig",
243+
"com.alibaba.cloud.ai.graph.node.code.CodeExecutor",
244+
"com.alibaba.cloud.ai.graph.node.code.LocalCommandlineCodeExecutor",
245+
"java.io.IOException", "java.nio.file.Files", "java.nio.file.Path",
246+
"java.util.stream.Collectors", "com.alibaba.cloud.ai.graph.node.code.entity.CodeParam",
247+
"com.alibaba.cloud.ai.graph.node.code.entity.CodeStyle")),
246248
Map.entry(NodeType.AGENT,
247249
List.of("com.alibaba.cloud.ai.graph.node.AgentNode",
248250
"org.springframework.ai.tool.ToolCallback")),

spring-ai-alibaba-studio/spring-ai-alibaba-studio-server/spring-ai-alibaba-studio-server-admin/src/main/java/com/alibaba/cloud/ai/studio/admin/generator/service/generator/workflow/sections/CodeNodeSection.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ public String render(Node node, String varName) {
4242
.codeExecutor(codeExecutor)
4343
.codeLanguage("%s")
4444
.code(%s)
45+
.codeStyle(%s)
4546
.config(codeExecutionConfig)
4647
.params(%s)
4748
.outputKey("%s")
@@ -51,8 +52,9 @@ public String render(Node node, String varName) {
5152
));
5253
5354
""", node.getId(), varName, nodeData.getCodeLanguage(), ObjectToCodeUtil.toCode(nodeData.getCode()),
54-
ObjectToCodeUtil.toCode(nodeData.getInputParams()), nodeData.getOutputKey(), varName, varName,
55-
nodeData.getOutputKey(), varName, nodeData.getMaxRetryCount(), nodeData.getRetryIntervalMs(),
55+
ObjectToCodeUtil.toCode(nodeData.getCodeStyle()), ObjectToCodeUtil.toCode(nodeData.getInputParams()),
56+
nodeData.getOutputKey(), varName, varName, nodeData.getOutputKey(), varName,
57+
nodeData.getMaxRetryCount(), nodeData.getRetryIntervalMs(),
5658
ObjectToCodeUtil.toCode(nodeData.getDefaultValue()));
5759
}
5860

0 commit comments

Comments
 (0)