Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
/*
* Copyright 2024-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.alibaba.cloud.ai.graph.node;

import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.tool.metadata.ToolMetadata;
import org.springframework.lang.Nullable;

import java.util.Arrays;
import java.util.List;
import java.util.Map;

// Agent节点
public class AgentNode implements NodeAction {

private static final Logger logger = LoggerFactory.getLogger(AgentNode.class);

private final ChatClient chatClient;

private final ToolCallback[] toolCallbacks;

private final Strategy strategy;

// Prompt内可以有变量,格式为{varName},将在正式调用Client前替换占位变量
private final String systemPrompt;

private final String userPrompt;

private final Integer maxIterations;

private final String outputKey;

public enum Strategy {

REACT, TOOL_CALLING

}

public AgentNode(ChatClient chatClient, ToolCallback[] toolCallbacks, Strategy strategy, String systemPrompt,
String userPrompt, Integer maxIterations, String outputKey) {
this.chatClient = chatClient;
this.strategy = strategy == null ? Strategy.REACT : strategy;
this.systemPrompt = systemPrompt == null ? "" : systemPrompt;
this.userPrompt = userPrompt == null ? "" : userPrompt;
this.maxIterations = maxIterations == null ? 1 : maxIterations;
this.outputKey = outputKey == null ? "agent_output" : outputKey;
if (this.chatClient == null) {
throw new IllegalArgumentException("ChatClient is required");
}
this.toolCallbacks = Arrays.stream(toolCallbacks).map(toolCallback -> {
// ToolCalling策略调用完工具后直接返回,需要包装一层使得returnDirect为true
if (this.strategy == Strategy.TOOL_CALLING && !toolCallback.getToolMetadata().returnDirect()) {
final ToolMetadata toolMetadata = ToolMetadata.builder().returnDirect(true).build();
return new ToolCallback() {
@Override
public ToolDefinition getToolDefinition() {
return toolCallback.getToolDefinition();
}

@Override
public ToolMetadata getToolMetadata() {
// returnDirect为true的ToolMetadata
return toolMetadata;
}

@Override
public String call(String toolInput) {
return toolCallback.call(toolInput);
}

@Override
public String call(String toolInput, @Nullable ToolContext tooContext) {
return toolCallback.call(toolInput, tooContext);
}

};
}
else {
return toolCallback;
}
}).toArray(ToolCallback[]::new);
}

@Override
public Map<String, Object> apply(OverAllState state) throws Exception {
String userPrompt = new PromptTemplate(this.userPrompt).render(state.data());
String systemPrompt = new PromptTemplate(this.systemPrompt).render(state.data());
String output = switch (this.strategy) {
case TOOL_CALLING, REACT -> {
int count = this.maxIterations;
// 重试机制
while (count-- > 0) {
try {
String content = this.chatClient.prompt(systemPrompt)
.toolCallbacks(this.toolCallbacks)
.user(userPrompt)
.call()
.content();
if (content != null) {
logger.warn("ChatClient Call Return Null...");
yield content;
}
}
catch (Exception e) {
logger.warn("ChatClient Call Fail: {}", e.getMessage());
}
}
yield null;
}
};
return Map.of(this.outputKey, output == null ? "" : output);
}

public static Builder builder() {
return new Builder();
}

public static class Builder {

private ChatClient chatClient;

private ToolCallback[] toolCallbacks;

private Strategy strategy;

private String systemPrompt;

private String userPrompt;

private Integer maxIterations;

private String outputKey;

public Builder chatClient(ChatClient chatClient) {
this.chatClient = chatClient;
return this;
}

public Builder toolCallbacks(ToolCallback[] toolCallbacks) {
this.toolCallbacks = toolCallbacks;
return this;
}

public Builder toolCallBacks(List<ToolCallback> toolCallbacks) {
this.toolCallbacks = toolCallbacks.toArray(ToolCallback[]::new);
return this;
}

public Builder strategy(Strategy strategy) {
this.strategy = strategy;
return this;
}

public Builder systemPrompt(String systemPrompt) {
this.systemPrompt = systemPrompt;
return this;
}

public Builder userPrompt(String userPrompt) {
this.userPrompt = userPrompt;
return this;
}

public Builder maxIterations(Integer maxIterations) {
this.maxIterations = maxIterations;
return this;
}

public Builder outputKey(String outputKey) {
this.outputKey = outputKey;
return this;
}

public AgentNode build() {
return new AgentNode(chatClient, toolCallbacks, strategy, systemPrompt, userPrompt, maxIterations,
outputKey);
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,4 @@ public void setVarName(String varName) {
this.varName = varName;
}

public static String defaultOutputKey(String nodeId) {
return nodeId + "_output";
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ public enum NodeType {

ANSWER("answer", "answer"),

AGENT("agent", "agent"),

LLM("llm", "llm"),

CODE("code", "code"),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*
* Copyright 2024-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.alibaba.cloud.ai.studio.admin.generator.model.workflow.nodedata;

import com.alibaba.cloud.ai.studio.admin.generator.model.Variable;
import com.alibaba.cloud.ai.studio.admin.generator.model.VariableType;
import com.alibaba.cloud.ai.studio.admin.generator.model.workflow.NodeData;

import java.util.List;
import java.util.Map;

public class AgentNodeData extends NodeData {

public static Variable getDefaultOutputSchema() {
return new Variable("text", VariableType.STRING.value());
}

private Map<String, Object> agentParameterMap;

private String agentStrategyName;

private String instructionPrompt;

private String queryPrompt;

private List<ToolData> toolList;

private Integer maxIterations;

public record ToolData(Map<String, Object> parameters, String toolName, String toolDescription) {
}

public Map<String, Object> getAgentParameterMap() {
return agentParameterMap;
}

public void setAgentParameterMap(Map<String, Object> agentParameterMap) {
this.agentParameterMap = agentParameterMap;
}

public String getAgentStrategyName() {
return agentStrategyName;
}

public void setAgentStrategyName(String agentStrategyName) {
// todo: 支持更多的策略,如MCP,多轮对话
this.agentStrategyName = switch (agentStrategyName) {
case "function_calling" -> "TOOL_CALLING";
default -> "REACT";
};
}

public String getInstructionPrompt() {
return instructionPrompt;
}

public void setInstructionPrompt(String instructionPrompt) {
this.instructionPrompt = instructionPrompt;
}

public String getQueryPrompt() {
return queryPrompt;
}

public void setQueryPrompt(String queryPrompt) {
this.queryPrompt = queryPrompt;
}

public List<ToolData> getToolList() {
return toolList;
}

public void setToolList(List<ToolData> toolList) {
this.toolList = toolList;
}

public Integer getMaxIterations() {
return maxIterations;
}

public void setMaxIterations(Integer maxIterations) {
this.maxIterations = maxIterations;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ public enum DSLDialectType {

CUSTOM("custom", ".yml");

private String value;
private final String value;

private String fileExtension;
private final String fileExtension;

public String value() {
return value;
Expand Down
Loading
Loading