Skip to content
Merged

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,19 @@

import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;

import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import com.alibaba.cloud.ai.graph.node.code.entity.CodeBlock;
import com.alibaba.cloud.ai.graph.node.code.entity.CodeExecutionConfig;
import com.alibaba.cloud.ai.graph.node.code.entity.CodeExecutionResult;
import com.alibaba.cloud.ai.graph.node.code.entity.CodeLanguage;
import com.alibaba.cloud.ai.graph.node.code.entity.CodeParam;
import com.alibaba.cloud.ai.graph.node.code.entity.CodeStyle;
import com.alibaba.cloud.ai.graph.node.code.entity.RunnerAndPreload;
import com.alibaba.cloud.ai.graph.node.code.javascript.NodeJsTemplateTransformer;
import com.alibaba.cloud.ai.graph.node.code.python3.Python3TemplateTransformer;
Expand All @@ -47,10 +50,12 @@ public class CodeExecutorNodeAction implements NodeAction {

private final CodeExecutionConfig codeExecutionConfig;

private Map<String, Object> params;
private final List<CodeParam> params;

private final String outputKey;

private final CodeStyle style;

private static final Map<CodeLanguage, TemplateTransformer> CODE_TEMPLATE_TRANSFORMERS = Map.of(
CodeLanguage.PYTHON3, new Python3TemplateTransformer(), CodeLanguage.PYTHON,
new Python3TemplateTransformer(), CodeLanguage.JAVASCRIPT, new NodeJsTemplateTransformer(),
Expand All @@ -61,24 +66,25 @@ CodeLanguage.PYTHON3, new Python3TemplateTransformer(), CodeLanguage.PYTHON,
CodeLanguage.PYTHON3.getValue(), CodeLanguage.PYTHON, CodeLanguage.PYTHON.getValue(), CodeLanguage.JAVA,
CodeLanguage.JAVA.getValue());

public CodeExecutorNodeAction(CodeExecutor codeExecutor, String codeLanguage, String code,
CodeExecutionConfig config, Map<String, Object> params, String outputKey) {
public CodeExecutorNodeAction(CodeExecutor codeExecutor, String codeLanguage, String code, CodeStyle style,
CodeExecutionConfig config, List<CodeParam> params, String outputKey) {
this.codeExecutor = codeExecutor;
this.codeLanguage = codeLanguage;
this.style = style;
this.code = code;
this.codeExecutionConfig = config;
this.params = params;
this.outputKey = outputKey;
}

private Map<String, Object> executeWorkflowCodeTemplate(CodeLanguage language, String code, List<Object> inputs)
throws Exception {
private Map<String, Object> executeWorkflowCodeTemplate(CodeLanguage language, String code,
Map<String, Object> inputs) throws Exception {
TemplateTransformer templateTransformer = CODE_TEMPLATE_TRANSFORMERS.get(language);
if (templateTransformer == null) {
throw new RuntimeException("Unsupported language: " + language);
}

RunnerAndPreload runnerAndPreload = templateTransformer.transformCaller(code, inputs);
RunnerAndPreload runnerAndPreload = templateTransformer.transformCaller(code, inputs, style);
String response = executeCode(language, runnerAndPreload.preloadScript(), runnerAndPreload.runnerScript());

return templateTransformer.transformResponse(response);
Expand All @@ -100,12 +106,12 @@ private String executeCode(CodeLanguage language, String preloadScript, String c

@Override
public Map<String, Object> apply(OverAllState state) throws Exception {
List<Object> inputs = new ArrayList<>(10);
if (params != null && !params.isEmpty()) {
for (String key : params.keySet()) {
inputs.add(state.data().get((String) params.get(key)));
}
}
Map<String, Object> inputs = Optional.ofNullable(params)
.orElse(List.of())
.stream()
.collect(Collectors.toUnmodifiableMap(CodeParam::argName, param -> Optional.ofNullable(param.value())
.or(() -> StringUtils.hasText(param.stateKey()) ? state.value(param.stateKey()) : Optional.empty())
.orElseThrow(() -> new IllegalStateException("param has no value and legal key!"))));
Map<String, Object> resultObjectMap = executeWorkflowCodeTemplate(CodeLanguage.fromValue(codeLanguage), code,
inputs);
Map<String, Object> updatedState = new HashMap<>();
Expand All @@ -127,13 +133,16 @@ public static class Builder {

private String code;

private CodeStyle style;

private CodeExecutionConfig config;

private Map<String, Object> params;
private List<CodeParam> params;

private String outputKey;

public Builder() {
style = CodeStyle.EXPLICIT_PARAMETERS;
}

public Builder codeExecutor(CodeExecutor codeExecutor) {
Expand All @@ -151,13 +160,18 @@ public Builder code(String code) {
return this;
}

public Builder codeStyle(CodeStyle style) {
this.style = style;
return this;
}

public Builder config(CodeExecutionConfig config) {
this.config = config;
return this;
}

public Builder params(Map<String, String> params) {
this.params = new LinkedHashMap<>(params);
public Builder params(List<CodeParam> params) {
this.params = List.copyOf(params);
return this;
}

Expand All @@ -167,7 +181,7 @@ public Builder outputKey(String outputKey) {
}

public CodeExecutorNodeAction build() {
return new CodeExecutorNodeAction(codeExecutor, codeLanguage, code, config, params, outputKey);
return new CodeExecutorNodeAction(codeExecutor, codeLanguage, code, style, config, params, outputKey);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@

package com.alibaba.cloud.ai.graph.node.code;

import com.alibaba.cloud.ai.graph.node.code.entity.CodeStyle;
import com.alibaba.cloud.ai.graph.node.code.entity.RunnerAndPreload;
import com.fasterxml.jackson.databind.ObjectMapper;

import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
Expand All @@ -38,8 +38,8 @@ public abstract class TemplateTransformer {

protected static final String RESULT_TAG = "<<RESULT>>";

public RunnerAndPreload transformCaller(String code, List<Object> inputs) throws Exception {
String runnerScript = assembleRunnerScript(code, inputs);
public RunnerAndPreload transformCaller(String code, Map<String, Object> inputs, CodeStyle style) throws Exception {
String runnerScript = assembleRunnerScript(code, inputs, style);
String preloadScript = getPreloadScript();

return new RunnerAndPreload(runnerScript, preloadScript);
Expand All @@ -52,7 +52,7 @@ public Map<String, Object> transformResponse(String response) throws Exception {
mapper.getTypeFactory().constructMapType(Map.class, String.class, Object.class));
}

public abstract String getRunnerScript();
public abstract String getRunnerScript(CodeStyle style);

private String extractResultStrFromResponse(String response) {
Pattern pattern = Pattern.compile(RESULT_TAG + "(.*?)" + RESULT_TAG, Pattern.DOTALL);
Expand All @@ -66,14 +66,14 @@ private String extractResultStrFromResponse(String response) {
}
}

private String serializeInputs(List<Object> inputs) throws Exception {
private String serializeInputs(Map<String, Object> inputs) throws Exception {
ObjectMapper mapper = new ObjectMapper();
String inputsJsonStr = mapper.writeValueAsString(inputs);
return Base64.getEncoder().encodeToString(inputsJsonStr.getBytes(StandardCharsets.UTF_8));
}

private String assembleRunnerScript(String code, List<Object> inputs) throws Exception {
String script = getRunnerScript();
private String assembleRunnerScript(String code, Map<String, Object> inputs, CodeStyle style) throws Exception {
String script = getRunnerScript(style);
script = script.replace(CODE_PLACEHOLDER, code);
script = script.replace(INPUTS_PLACEHOLDER, serializeInputs(inputs));
return script;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright 2024-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.alibaba.cloud.ai.graph.node.code.entity;

/**
* @author vlsmb
* @since 2025/9/11
* @param argName 参数在代码中对应的名称
* @param value 参数值,如果为null,则从OverallState中获取
* @param stateKey 参数在OverallState中的key,如果value不为null,则忽略stateKey
*/
public record CodeParam(String argName, Object value, String stateKey) {
public CodeParam(String argName, String stateKey) {
this(argName, null, stateKey);
}

public static CodeParam withValue(String argName, Object value) {
return new CodeParam(argName, value, null);
}

public static CodeParam withKey(String argName, String stateKey) {
return new CodeParam(argName, null, stateKey);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright 2024-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.alibaba.cloud.ai.graph.node.code.entity;

/**
* @author vlsmb
* @since 2025/9/11
*/
public enum CodeStyle {

/**
* 参数直接作为函数形参的风格 示例: def main(x: int, y: int) -> dict:
*/
EXPLICIT_PARAMETERS,

/**
* 参数通过全局字典访问的风格 示例: def main(): x = params['x']
*/
GLOBAL_DICTIONARY

}
Loading
Loading