Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import com.alibaba.cloud.ai.graph.state.strategy.AppendStrategy;
import org.apache.commons.collections4.CollectionUtils;

import com.alibaba.cloud.ai.graph.state.strategy.ReplaceStrategy;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
Expand Down Expand Up @@ -82,6 +83,8 @@ public class ReactAgent extends BaseAgent {

private Function<OverAllState, Boolean> shouldContinueFunc;

private String inputKey;

protected ReactAgent(LlmNode llmNode, ToolNode toolNode, Builder builder) throws GraphStateException {
this.name = builder.name;
this.description = builder.description;
Expand All @@ -96,6 +99,9 @@ protected ReactAgent(LlmNode llmNode, ToolNode toolNode, Builder builder) throws
this.postLlmHook = builder.postLlmHook;
this.preToolHook = builder.preToolHook;
this.postToolHook = builder.postToolHook;
this.inputKey = builder.inputKey;

// 初始化graph
this.graph = initGraph();
}

Expand Down Expand Up @@ -154,35 +160,52 @@ private StateGraph initGraph() throws GraphStateException {
if (keyStrategyFactory == null) {
this.keyStrategyFactory = () -> {
HashMap<String, KeyStrategy> keyStrategyHashMap = new HashMap<>();
if (inputKey != null) {
keyStrategyHashMap.put(inputKey, new ReplaceStrategy());
}
keyStrategyHashMap.put("messages", new AppendStrategy());
return keyStrategyHashMap;
};
}
else {
KeyStrategyFactory originalFactory = this.keyStrategyFactory;
this.keyStrategyFactory = () -> {
HashMap<String, KeyStrategy> keyStrategyHashMap = new HashMap<>(originalFactory.apply());
keyStrategyHashMap.put("messages", new AppendStrategy());
return keyStrategyHashMap;
};
}

NodeAction effectivePreLlmHook = this.preLlmHook;
if (effectivePreLlmHook == null) {
effectivePreLlmHook = state -> {
if (state.value("messages").isPresent()) {
List<Message> messages = (List<Message>) state.value("messages").orElseThrow();
state.updateState(Map.of(this.inputKey, messages));
}
return Map.of();
};
}

StateGraph graph = new StateGraph(name, this.keyStrategyFactory);

if (preLlmHook != null) {
graph.addNode("preLlm", node_async(preLlmHook));
}
graph.addNode("preLlm", node_async(effectivePreLlmHook));
graph.addNode("llm", node_async(this.llmNode));
if (postLlmHook != null) {
graph.addNode("postLlm", node_async(postLlmHook));
graph.addNode("postLlm", node_async(this.postLlmHook));
}

if (preToolHook != null) {
graph.addNode("preTool", node_async(preToolHook));
graph.addNode("preTool", node_async(this.preToolHook));
}

graph.addNode("tool", node_async(this.toolNode));

if (postToolHook != null) {
graph.addNode("postTool", node_async(postToolHook));
graph.addNode("postTool", node_async(this.postToolHook));
}

if (preLlmHook != null) {
graph.addEdge(START, "preLlm").addEdge("preLlm", "llm");
}
else {
graph.addEdge(START, "llm");
}
graph.addEdge(START, "preLlm").addEdge("preLlm", "llm");

if (postLlmHook != null) {
graph.addEdge("llm", "postLlm")
Expand All @@ -199,10 +222,10 @@ private StateGraph initGraph() throws GraphStateException {
graph.addEdge("preTool", "tool");
}
if (postToolHook != null) {
graph.addEdge("tool", "postTool").addEdge("postTool", preLlmHook != null ? "preLlm" : "llm");
graph.addEdge("tool", "postTool").addEdge("postTool", "preLlm");
}
else {
graph.addEdge("tool", preLlmHook != null ? "preLlm" : "llm");
graph.addEdge("tool", "preLlm");
}

return graph;
Expand Down Expand Up @@ -334,6 +357,8 @@ public static class Builder {

private NodeAction postToolHook;

private String inputKey = "messages";

public Builder name(String name) {
this.name = name;
return this;
Expand Down Expand Up @@ -419,6 +444,11 @@ public Builder postToolHook(NodeAction postToolHook) {
return this;
}

public Builder inputKey(String inputKey) {
this.inputKey = inputKey;
return this;
}

public ReactAgent build() throws GraphStateException {
if (chatClient == null) {
if (model == null) {
Expand All @@ -434,7 +464,7 @@ public ReactAgent build() throws GraphStateException {
chatClient = clientBuilder.build();
}

LlmNode.Builder llmNodeBuilder = LlmNode.builder().chatClient(chatClient).messagesKey("messages");
LlmNode.Builder llmNodeBuilder = LlmNode.builder().chatClient(chatClient).messagesKey(this.inputKey);
if (CollectionUtils.isNotEmpty(tools)) {
llmNodeBuilder.toolCallbacks(tools);
}
Expand Down
Loading