Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.ArrayList;

import com.alibaba.cloud.ai.graph.*;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
Expand All @@ -31,6 +32,7 @@
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.resolution.ToolCallbackResolver;

Expand All @@ -51,6 +53,14 @@ public class ReactAgent {

private CompiledGraph compiledGraph;

private NodeAction preLlmHook;

private NodeAction postLlmHook;

private NodeAction preToolHook;

private NodeAction postToolHook;

private List<String> tools;

private int max_iterations = 10;
Expand All @@ -63,6 +73,24 @@ public class ReactAgent {

private Function<OverAllState, Boolean> shouldContinueFunc;

private ReactAgent(String name, LlmNode llmNode, ToolNode toolNode, int maxIterations,
OverAllStateFactory overAllStateFactory, CompileConfig compileConfig,
Function<OverAllState, Boolean> shouldContinueFunc, NodeAction preLlmHook, NodeAction postLlmHook,
NodeAction preToolHook, NodeAction postToolHook) throws GraphStateException {
this.name = name;
this.llmNode = llmNode;
this.toolNode = toolNode;
this.max_iterations = maxIterations;
this.overAllStateFactory = overAllStateFactory;
this.compileConfig = compileConfig;
this.shouldContinueFunc = shouldContinueFunc;
this.preLlmHook = preLlmHook;
this.postLlmHook = postLlmHook;
this.preToolHook = preToolHook;
this.postToolHook = postToolHook;
this.graph = initGraph();
}

public ReactAgent(LlmNode llmNode, ToolNode toolNode, int maxIterations, OverAllStateFactory overAllStateFactory,
CompileConfig compileConfig, Function<OverAllState, Boolean> shouldContinueFunc)
throws GraphStateException {
Expand Down Expand Up @@ -165,11 +193,53 @@ private StateGraph initGraph() throws GraphStateException {
};
}

return new StateGraph(this.overAllStateFactory).addNode("agent", node_async(this.llmNode))
.addNode("tool", node_async(this.toolNode))
.addEdge(START, "agent")
.addConditionalEdges("agent", edge_async(this::think), Map.of("continue", "tool", "end", END))
.addEdge("tool", "agent");
StateGraph graph = new StateGraph(this.overAllStateFactory);

if (preLlmHook != null) {
graph.addNode("preLlm", node_async(preLlmHook));
}
graph.addNode("llm", node_async(this.llmNode));
if (postLlmHook != null) {
graph.addNode("postLlm", node_async(postLlmHook));
}

if (preToolHook != null) {
graph.addNode("preTool", node_async(preToolHook));
}
graph.addNode("tool", node_async(this.toolNode));
if (postToolHook != null) {
graph.addNode("postTool", node_async(postToolHook));
}

if (preLlmHook != null) {
graph.addEdge(START, "preLlm").addEdge("preLlm", "llm");
}
else {
graph.addEdge(START, "llm");
}

if (postLlmHook != null) {
graph.addEdge("llm", "postLlm")
.addConditionalEdges("postLlm", edge_async(this::think),
Map.of("continue", preToolHook != null ? "preTool" : "tool", "end", END));
}
else {
graph.addConditionalEdges("llm", edge_async(this::think),
Map.of("continue", preToolHook != null ? "preTool" : "tool", "end", END));
}

// 添加工具相关边
if (preToolHook != null) {
graph.addEdge("preTool", "tool");
}
if (postToolHook != null) {
graph.addEdge("tool", "postTool").addEdge("postTool", preLlmHook != null ? "preLlm" : "llm");
}
else {
graph.addEdge("tool", preLlmHook != null ? "preLlm" : "llm");
}

return graph;
}

private String think(OverAllState state) {
Expand Down Expand Up @@ -260,6 +330,14 @@ public static class Builder {

private Function<OverAllState, Boolean> shouldContinueFunc;

private NodeAction preLlmHook;

private NodeAction postLlmHook;

private NodeAction preToolHook;

private NodeAction postToolHook;

public Builder name(String name) {
this.name = name;
return this;
Expand Down Expand Up @@ -300,16 +378,41 @@ public Builder shouldContinueFunction(Function<OverAllState, Boolean> shouldCont
return this;
}

public Builder preLlmHook(NodeAction preLlmHook) {
this.preLlmHook = preLlmHook;
return this;
}

public Builder postLlmHook(NodeAction postLlmHook) {
this.postLlmHook = postLlmHook;
return this;
}

public Builder preToolHook(NodeAction preToolHook) {
this.preToolHook = preToolHook;
return this;
}

public Builder postToolHook(NodeAction postToolHook) {
this.postToolHook = postToolHook;
return this;
}

public ReactAgent build() throws GraphStateException {
LlmNode llmNode = LlmNode.builder().chatClient(chatClient).messagesKey("messages").build();
ToolNode toolNode = null;
if (resolver != null) {
return new ReactAgent(name, chatClient, resolver, maxIterations, allStateFactory, compileConfig,
shouldContinueFunc);
toolNode = ToolNode.builder().toolCallbackResolver(resolver).build();
}
else if (tools != null) {
return new ReactAgent(name, chatClient, tools, maxIterations, allStateFactory, compileConfig,
shouldContinueFunc);
toolNode = ToolNode.builder().toolCallbacks(tools).build();
}
else {
throw new IllegalArgumentException("Either tools or resolver must be provided");
}
throw new IllegalArgumentException("Either tools or resolver must be provided");

return new ReactAgent(name, llmNode, toolNode, maxIterations, allStateFactory, compileConfig,
shouldContinueFunc, preLlmHook, postLlmHook, preToolHook, postToolHook);
}

}
Expand Down
Loading
Loading