From 74df6a9f58cfe15a651a5372311298dce9dfdc80 Mon Sep 17 00:00:00 2001 From: zhouyou9505 Date: Mon, 9 Jun 2025 22:24:42 +0800 Subject: [PATCH 1/4] reactagent add preLlmHook postLlmHook preToolHook postToolHook --- .../cloud/ai/graph/agent/ReactAgent.java | 132 ++++++++- .../cloud/ai/graph/agent/ReactAgentTest.java | 274 ++++++++++++++++++ 2 files changed, 393 insertions(+), 13 deletions(-) create mode 100644 spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/test/java/com/alibaba/cloud/ai/graph/agent/ReactAgentTest.java diff --git a/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/agent/ReactAgent.java b/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/agent/ReactAgent.java index 5f24d0284b..44852107c7 100644 --- a/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/agent/ReactAgent.java +++ b/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/agent/ReactAgent.java @@ -18,6 +18,7 @@ import java.util.List; import java.util.Map; import java.util.function.Function; +import java.util.ArrayList; import com.alibaba.cloud.ai.graph.*; import com.alibaba.cloud.ai.graph.exception.GraphStateException; @@ -31,6 +32,7 @@ import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.resolution.ToolCallbackResolver; @@ -51,6 +53,14 @@ public class ReactAgent { private CompiledGraph compiledGraph; + private NodeAction preLlmHook; + + private NodeAction postLlmHook; + + private NodeAction preToolHook; + + private NodeAction postToolHook; + private List tools; private int max_iterations = 10; @@ -63,8 +73,25 @@ public class ReactAgent { private Function shouldContinueFunc; + private ReactAgent(String name,LlmNode llmNode, ToolNode toolNode, int maxIterations, OverAllStateFactory overAllStateFactory, + CompileConfig compileConfig, Function shouldContinueFunc, + NodeAction preLlmHook, NodeAction postLlmHook, NodeAction preToolHook, NodeAction postToolHook) throws GraphStateException { + this.name = name; + this.llmNode = llmNode; + this.toolNode = toolNode; + this.max_iterations = maxIterations; + this.overAllStateFactory = overAllStateFactory; + this.compileConfig = compileConfig; + this.shouldContinueFunc = shouldContinueFunc; + this.preLlmHook = preLlmHook; + this.postLlmHook = postLlmHook; + this.preToolHook = preToolHook; + this.postToolHook = postToolHook; + this.graph = initGraph(); + } + public ReactAgent(LlmNode llmNode, ToolNode toolNode, int maxIterations, OverAllStateFactory overAllStateFactory, - CompileConfig compileConfig, Function shouldContinueFunc) + CompileConfig compileConfig, Function shouldContinueFunc) throws GraphStateException { this.llmNode = llmNode; this.toolNode = toolNode; @@ -165,11 +192,56 @@ private StateGraph initGraph() throws GraphStateException { }; } - return new StateGraph(this.overAllStateFactory).addNode("agent", node_async(this.llmNode)) - .addNode("tool", node_async(this.toolNode)) - .addEdge(START, "agent") - .addConditionalEdges("agent", edge_async(this::think), Map.of("continue", "tool", "end", END)) - .addEdge("tool", "agent"); + StateGraph graph = new StateGraph(this.overAllStateFactory); + + if (preLlmHook != null) { + graph.addNode("preLlm", node_async(preLlmHook)); + } + graph.addNode("llm", node_async(this.llmNode)); + if (postLlmHook != null) { + graph.addNode("postLlm", node_async(postLlmHook)); + } + + if (preToolHook != null) { + graph.addNode("preTool", node_async(preToolHook)); + } + graph.addNode("tool", node_async(this.toolNode)); + if (postToolHook != null) { + graph.addNode("postTool", node_async(postToolHook)); + } + + if (preLlmHook != null) { + graph.addEdge(START, "preLlm") + .addEdge("preLlm", "llm"); + } else { + graph.addEdge(START, "llm"); + } + + if (postLlmHook != null) { + graph.addEdge("llm", "postLlm") + .addConditionalEdges("postLlm", edge_async(this::think), Map.of( + "continue", preToolHook != null ? "preTool" : "tool", + "end", END + )); + } else { + graph.addConditionalEdges("llm", edge_async(this::think), Map.of( + "continue", preToolHook != null ? "preTool" : "tool", + "end", END + )); + } + + // 添加工具相关边 + if (preToolHook != null) { + graph.addEdge("preTool", "tool"); + } + if (postToolHook != null) { + graph.addEdge("tool", "postTool") + .addEdge("postTool", preLlmHook != null ? "preLlm" : "llm"); + } else { + graph.addEdge("tool", preLlmHook != null ? "preLlm" : "llm"); + } + + return graph; } private String think(OverAllState state) { @@ -260,6 +332,14 @@ public static class Builder { private Function shouldContinueFunc; + private NodeAction preLlmHook; + + private NodeAction postLlmHook; + + private NodeAction preToolHook; + + private NodeAction postToolHook; + public Builder name(String name) { this.name = name; return this; @@ -300,16 +380,42 @@ public Builder shouldContinueFunction(Function shouldCont return this; } + public Builder preLlmHook(NodeAction preLlmHook) { + this.preLlmHook = preLlmHook; + return this; + } + + public Builder postLlmHook(NodeAction postLlmHook) { + this.postLlmHook = postLlmHook; + return this; + } + + public Builder preToolHook(NodeAction preToolHook) { + this.preToolHook = preToolHook; + return this; + } + + public Builder postToolHook(NodeAction postToolHook) { + this.postToolHook = postToolHook; + return this; + } + public ReactAgent build() throws GraphStateException { + LlmNode llmNode = LlmNode.builder() + .chatClient(chatClient) + .messagesKey("messages") + .build(); + ToolNode toolNode = null; if (resolver != null) { - return new ReactAgent(name, chatClient, resolver, maxIterations, allStateFactory, compileConfig, - shouldContinueFunc); + toolNode = ToolNode.builder().toolCallbackResolver(resolver).build(); + } else if (tools != null) { + toolNode = ToolNode.builder().toolCallbacks(tools).build(); + } else { + throw new IllegalArgumentException("Either tools or resolver must be provided"); } - else if (tools != null) { - return new ReactAgent(name, chatClient, tools, maxIterations, allStateFactory, compileConfig, - shouldContinueFunc); - } - throw new IllegalArgumentException("Either tools or resolver must be provided"); + + return new ReactAgent(name, llmNode, toolNode, maxIterations, allStateFactory, compileConfig, + shouldContinueFunc, preLlmHook, postLlmHook, preToolHook, postToolHook); } } diff --git a/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/test/java/com/alibaba/cloud/ai/graph/agent/ReactAgentTest.java b/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/test/java/com/alibaba/cloud/ai/graph/agent/ReactAgentTest.java new file mode 100644 index 0000000000..ee87075215 --- /dev/null +++ b/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/test/java/com/alibaba/cloud/ai/graph/agent/ReactAgentTest.java @@ -0,0 +1,274 @@ +package com.alibaba.cloud.ai.graph.agent; + +import com.alibaba.cloud.ai.graph.CompiledGraph; +import com.alibaba.cloud.ai.graph.OverAllState; +import com.alibaba.cloud.ai.graph.OverAllStateFactory; +import com.alibaba.cloud.ai.graph.action.NodeAction; +import com.alibaba.cloud.ai.graph.state.strategy.AppendStrategy; +import com.alibaba.cloud.ai.graph.state.strategy.ReplaceStrategy; +import com.alibaba.cloud.ai.graph.node.ToolNode; +import com.alibaba.cloud.ai.graph.action.AsyncNodeAction; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.definition.DefaultToolDefinition; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.metadata.DefaultUsage; +import org.springframework.ai.chat.messages.AssistantMessage.ToolCall; +import org.springframework.ai.content.Media; +import org.springframework.ai.tool.resolution.ToolCallbackResolver; +import org.springframework.ai.chat.model.ToolContext; + +import java.util.*; +import java.util.concurrent.CompletableFuture; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyList; +import static org.mockito.Mockito.when; +import static org.mockito.ArgumentMatchers.anyString; + +class ReactAgentTest { + + @Mock + private ChatClient chatClient; + + @Mock + private ChatClient.ChatClientRequestSpec requestSpec; + + @Mock + private ChatClient.CallResponseSpec responseSpec; + + @Mock + private ChatResponse chatResponse; + + @Mock + private ToolCallbackResolver toolCallbackResolver; + + @Mock + private ToolCallback toolCallback; + + @BeforeEach + void setUp() { + MockitoAnnotations.openMocks(this); + + // Configure mock ChatClient with complete call chain + when(chatClient.prompt()).thenReturn(requestSpec); + when(requestSpec.messages(anyList())).thenReturn(requestSpec); + when(requestSpec.advisors(anyList())).thenReturn(requestSpec); + when(requestSpec.toolCallbacks(anyList())).thenReturn(requestSpec); + when(requestSpec.call()).thenReturn(responseSpec); + + // Configure mock ToolCallbackResolver + when(toolCallbackResolver.resolve(anyString())).thenReturn(toolCallback); + when(toolCallback.call(anyString(), any(ToolContext.class))).thenReturn("test tool response"); + when(toolCallback.getToolDefinition()).thenReturn(DefaultToolDefinition.builder() + .name("test_function") + .description("A test function") + .inputSchema("{\"type\": \"object\", \"properties\": {\"arg1\": {\"type\": \"string\"}}}") + .build()); + + // Configure mock ChatResponse with ToolCalls + Map metadata = new HashMap<>(); + metadata.put("finishReason", "stop"); + List toolCalls = List.of( + new ToolCall("call_1", "function", "test_function", "{\"arg1\": \"value1\"}") + ); + AssistantMessage assistantMessage = new AssistantMessage( + "test response", + metadata, + toolCalls, + Collections.emptyList() + ); + ChatGenerationMetadata generationMetadata = ChatGenerationMetadata.builder() + .finishReason("stop") + .build(); + Generation generation = new Generation(assistantMessage, generationMetadata); + ChatResponseMetadata responseMetadata = ChatResponseMetadata.builder() + .id("test-id") + .usage(new DefaultUsage(10, 20, 30)) + .build(); + ChatResponse response = ChatResponse.builder() + .generations(List.of(generation)) + .metadata(responseMetadata) + .build(); + when(responseSpec.chatResponse()).thenReturn(response); + } + + /** + * Tests ReactAgent with preLlmHook that modifies system prompt before LLM call. + */ + @Test + public void testReactAgentWithPreLlmHook() throws Exception { + Map prellmStore = new HashMap<>(); + + ReactAgent agent = ReactAgent.builder() + .name("testAgent") + .chatClient(chatClient) + .state(() -> new OverAllState().registerKeyAndStrategy("messages", new AppendStrategy())) + .resolver(toolCallbackResolver) + .preLlmHook(state -> { + prellmStore.put("timestamp", String.valueOf(System.currentTimeMillis())); + return Map.of(); + }) + .build(); + + CompiledGraph graph = agent.getAndCompileGraph(); + try { + Optional invoke = graph.invoke(Map.of("messages", List.of(new UserMessage("test")))); + }catch (java.util.concurrent.CompletionException e){ + + } + assertNotNull(prellmStore.get("timestamp")); + + } + + /** + * Tests ReactAgent with postLlmHook that processes LLM response. + */ + @Test + public void testReactAgentWithPostLlmHook() throws Exception { + // Create a map to store processed responses + Map responseStore = new HashMap<>(); + + ReactAgent agent = ReactAgent.builder() + .name("testAgent") + .chatClient(chatClient) + .state(() -> new OverAllState().registerKeyAndStrategy("messages", new AppendStrategy())) + .resolver(toolCallbackResolver) + + .postLlmHook(state -> { + responseStore.put("response", "Processed: " + state.value("messages")); + return Map.of(); + }) + .build(); + + CompiledGraph graph = agent.getAndCompileGraph(); + try { + Optional invoke = graph.invoke(Map.of("messages", List.of(new UserMessage("test")))); + }catch (java.util.concurrent.CompletionException e){ + + } + assertNotNull(responseStore.get("response")); + } + + /** + * Tests ReactAgent with preToolHook that prepares tool parameters. + */ + @Test + public void testReactAgentWithPreToolHook() throws Exception { + // Create a map to store tool parameters + Map toolParams = new HashMap<>(); + + ReactAgent agent = ReactAgent.builder() + .name("testAgent") + .chatClient(chatClient) + .state(() -> new OverAllState() + .registerKeyAndStrategy("toolParams",new ReplaceStrategy()) + .registerKeyAndStrategy("messages", new AppendStrategy())) + .resolver(toolCallbackResolver) + .preToolHook(state -> { + toolParams.put("timestamp", System.currentTimeMillis()); + return Map.of(); + }) + .build(); + + CompiledGraph graph = agent.getAndCompileGraph(); + try { + Optional invoke = graph.invoke(Map.of("messages", List.of(new UserMessage("test")))); + }catch (java.util.concurrent.CompletionException e){ + + } + assertNotNull(toolParams.get("timestamp")); + } + + /** + * Tests ReactAgent with postToolHook that collects tool results. + */ + @Test + public void testReactAgentWithPostToolHook() throws Exception { + // Create a map to store tool results + Map toolResults = new HashMap<>(); + + ReactAgent agent = ReactAgent.builder() + .name("testAgent") + .chatClient(chatClient) + .resolver(toolCallbackResolver) + .state(() -> new OverAllState() + .registerKeyAndStrategy("messages", new AppendStrategy()) + .registerKeyAndStrategy("toolOutput", new ReplaceStrategy())) + .postToolHook(state -> { + toolResults.put("result", "collected: " + "tool output"); + return Map.of(); + }) + .build(); + + CompiledGraph graph = agent.getAndCompileGraph(); + try { + Optional invoke = graph.invoke(Map.of("messages", List.of(new UserMessage("test")))); + }catch (java.util.concurrent.CompletionException e){ + + } + assertNotNull(toolResults.get("result")); + } + + @Test + public void testReactAgentWithAllHooks() throws Exception { + // Create maps to store results from each hook + Map prellmStore = new HashMap<>(); + Map responseStore = new HashMap<>(); + Map toolParams = new HashMap<>(); + Map toolResults = new HashMap<>(); + + ReactAgent agent = ReactAgent.builder() + .name("testAgent") + .chatClient(chatClient) + .state(() -> new OverAllState() + .registerKeyAndStrategy("messages", new AppendStrategy()) + .registerKeyAndStrategy("toolParams", new ReplaceStrategy()) + .registerKeyAndStrategy("toolOutput", new ReplaceStrategy())) + .resolver(toolCallbackResolver) + .preLlmHook(state -> { + prellmStore.put("timestamp", String.valueOf(System.currentTimeMillis())); + return Map.of(); + }) + .postLlmHook(state -> { + responseStore.put("response", "Processed: " + state.value("messages")); + return Map.of(); + }) + .preToolHook(state -> { + toolParams.put("timestamp", System.currentTimeMillis()); + return Map.of(); + }) + .postToolHook(state -> { + toolResults.put("result", "collected: " + "tool output"); + return Map.of(); + }) + .build(); + + CompiledGraph graph = agent.getAndCompileGraph(); + try { + Optional invoke = graph.invoke(Map.of("messages", List.of(new UserMessage("test")))); + } catch (java.util.concurrent.CompletionException e) { + // Ignore max iterations exception + } + + // Verify all hooks were executed + assertNotNull(prellmStore.get("timestamp"), "PreLLM hook should store timestamp"); + assertNotNull(responseStore.get("response"), "PostLLM hook should store response"); + assertNotNull(toolParams.get("timestamp"), "PreTool hook should store timestamp"); + assertNotNull(toolResults.get("result"), "PostTool hook should store result"); + } + +} \ No newline at end of file From 36120505abc89a7a084995ee7ca6a9e53263a415 Mon Sep 17 00:00:00 2001 From: zhouyou9505 Date: Tue, 10 Jun 2025 09:56:01 +0800 Subject: [PATCH 2/4] feat jmanus graph --- .../cloud/ai/graph/agent/ReactAgentTest.java | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/test/java/com/alibaba/cloud/ai/graph/agent/ReactAgentTest.java b/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/test/java/com/alibaba/cloud/ai/graph/agent/ReactAgentTest.java index ee87075215..13663f49de 100644 --- a/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/test/java/com/alibaba/cloud/ai/graph/agent/ReactAgentTest.java +++ b/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/test/java/com/alibaba/cloud/ai/graph/agent/ReactAgentTest.java @@ -1,3 +1,18 @@ +/* + * Copyright 2024-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package com.alibaba.cloud.ai.graph.agent; import com.alibaba.cloud.ai.graph.CompiledGraph; From f1fd11c073e818b30bfac743fd9dc48a87f2afed Mon Sep 17 00:00:00 2001 From: zhouyou9505 Date: Tue, 10 Jun 2025 10:45:32 +0800 Subject: [PATCH 3/4] feat jmanus graph --- .../cloud/ai/graph/agent/ReactAgentTest.java | 30 ++++++------------- 1 file changed, 9 insertions(+), 21 deletions(-) diff --git a/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/test/java/com/alibaba/cloud/ai/graph/agent/ReactAgentTest.java b/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/test/java/com/alibaba/cloud/ai/graph/agent/ReactAgentTest.java index 13663f49de..b7cb0ba3cc 100644 --- a/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/test/java/com/alibaba/cloud/ai/graph/agent/ReactAgentTest.java +++ b/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/test/java/com/alibaba/cloud/ai/graph/agent/ReactAgentTest.java @@ -17,42 +17,31 @@ import com.alibaba.cloud.ai.graph.CompiledGraph; import com.alibaba.cloud.ai.graph.OverAllState; -import com.alibaba.cloud.ai.graph.OverAllStateFactory; -import com.alibaba.cloud.ai.graph.action.NodeAction; import com.alibaba.cloud.ai.graph.state.strategy.AppendStrategy; import com.alibaba.cloud.ai.graph.state.strategy.ReplaceStrategy; -import com.alibaba.cloud.ai.graph.node.ToolNode; -import com.alibaba.cloud.ai.graph.action.AsyncNodeAction; -import org.springframework.ai.chat.messages.SystemMessage; -import org.springframework.ai.tool.definition.ToolDefinition; -import org.springframework.ai.tool.definition.DefaultToolDefinition; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.AssistantMessage.ToolCall; import org.springframework.ai.chat.messages.UserMessage; -import org.springframework.ai.tool.ToolCallback; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.chat.metadata.DefaultUsage; -import org.springframework.ai.chat.messages.AssistantMessage.ToolCall; -import org.springframework.ai.content.Media; -import org.springframework.ai.tool.resolution.ToolCallbackResolver; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.model.ToolContext; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.definition.DefaultToolDefinition; +import org.springframework.ai.tool.resolution.ToolCallbackResolver; import java.util.*; -import java.util.concurrent.CompletableFuture; -import static org.junit.jupiter.api.Assertions.*; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyList; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.mockito.ArgumentMatchers.*; import static org.mockito.Mockito.when; -import static org.mockito.ArgumentMatchers.anyString; class ReactAgentTest { @@ -285,5 +274,4 @@ public void testReactAgentWithAllHooks() throws Exception { assertNotNull(toolParams.get("timestamp"), "PreTool hook should store timestamp"); assertNotNull(toolResults.get("result"), "PostTool hook should store result"); } - -} \ No newline at end of file +} From 28407f43191dad6ed24cd4d5c857dc7d1e4ad010 Mon Sep 17 00:00:00 2001 From: zhouyou9505 Date: Tue, 10 Jun 2025 10:59:46 +0800 Subject: [PATCH 4/4] feat jmanus graph --- .../cloud/ai/graph/agent/ReactAgent.java | 47 +- .../cloud/ai/graph/agent/ReactAgentTest.java | 454 +++++++++--------- 2 files changed, 247 insertions(+), 254 deletions(-) diff --git a/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/agent/ReactAgent.java b/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/agent/ReactAgent.java index 44852107c7..ac65a6d1cf 100644 --- a/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/agent/ReactAgent.java +++ b/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/agent/ReactAgent.java @@ -73,9 +73,10 @@ public class ReactAgent { private Function shouldContinueFunc; - private ReactAgent(String name,LlmNode llmNode, ToolNode toolNode, int maxIterations, OverAllStateFactory overAllStateFactory, - CompileConfig compileConfig, Function shouldContinueFunc, - NodeAction preLlmHook, NodeAction postLlmHook, NodeAction preToolHook, NodeAction postToolHook) throws GraphStateException { + private ReactAgent(String name, LlmNode llmNode, ToolNode toolNode, int maxIterations, + OverAllStateFactory overAllStateFactory, CompileConfig compileConfig, + Function shouldContinueFunc, NodeAction preLlmHook, NodeAction postLlmHook, + NodeAction preToolHook, NodeAction postToolHook) throws GraphStateException { this.name = name; this.llmNode = llmNode; this.toolNode = toolNode; @@ -91,7 +92,7 @@ private ReactAgent(String name,LlmNode llmNode, ToolNode toolNode, int maxIterat } public ReactAgent(LlmNode llmNode, ToolNode toolNode, int maxIterations, OverAllStateFactory overAllStateFactory, - CompileConfig compileConfig, Function shouldContinueFunc) + CompileConfig compileConfig, Function shouldContinueFunc) throws GraphStateException { this.llmNode = llmNode; this.toolNode = toolNode; @@ -211,23 +212,20 @@ private StateGraph initGraph() throws GraphStateException { } if (preLlmHook != null) { - graph.addEdge(START, "preLlm") - .addEdge("preLlm", "llm"); - } else { + graph.addEdge(START, "preLlm").addEdge("preLlm", "llm"); + } + else { graph.addEdge(START, "llm"); } if (postLlmHook != null) { graph.addEdge("llm", "postLlm") - .addConditionalEdges("postLlm", edge_async(this::think), Map.of( - "continue", preToolHook != null ? "preTool" : "tool", - "end", END - )); - } else { - graph.addConditionalEdges("llm", edge_async(this::think), Map.of( - "continue", preToolHook != null ? "preTool" : "tool", - "end", END - )); + .addConditionalEdges("postLlm", edge_async(this::think), + Map.of("continue", preToolHook != null ? "preTool" : "tool", "end", END)); + } + else { + graph.addConditionalEdges("llm", edge_async(this::think), + Map.of("continue", preToolHook != null ? "preTool" : "tool", "end", END)); } // 添加工具相关边 @@ -235,9 +233,9 @@ private StateGraph initGraph() throws GraphStateException { graph.addEdge("preTool", "tool"); } if (postToolHook != null) { - graph.addEdge("tool", "postTool") - .addEdge("postTool", preLlmHook != null ? "preLlm" : "llm"); - } else { + graph.addEdge("tool", "postTool").addEdge("postTool", preLlmHook != null ? "preLlm" : "llm"); + } + else { graph.addEdge("tool", preLlmHook != null ? "preLlm" : "llm"); } @@ -401,16 +399,15 @@ public Builder postToolHook(NodeAction postToolHook) { } public ReactAgent build() throws GraphStateException { - LlmNode llmNode = LlmNode.builder() - .chatClient(chatClient) - .messagesKey("messages") - .build(); + LlmNode llmNode = LlmNode.builder().chatClient(chatClient).messagesKey("messages").build(); ToolNode toolNode = null; if (resolver != null) { toolNode = ToolNode.builder().toolCallbackResolver(resolver).build(); - } else if (tools != null) { + } + else if (tools != null) { toolNode = ToolNode.builder().toolCallbacks(tools).build(); - } else { + } + else { throw new IllegalArgumentException("Either tools or resolver must be provided"); } diff --git a/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/test/java/com/alibaba/cloud/ai/graph/agent/ReactAgentTest.java b/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/test/java/com/alibaba/cloud/ai/graph/agent/ReactAgentTest.java index b7cb0ba3cc..b4e8d03c2f 100644 --- a/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/test/java/com/alibaba/cloud/ai/graph/agent/ReactAgentTest.java +++ b/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/test/java/com/alibaba/cloud/ai/graph/agent/ReactAgentTest.java @@ -45,233 +45,229 @@ class ReactAgentTest { - @Mock - private ChatClient chatClient; - - @Mock - private ChatClient.ChatClientRequestSpec requestSpec; - - @Mock - private ChatClient.CallResponseSpec responseSpec; - - @Mock - private ChatResponse chatResponse; - - @Mock - private ToolCallbackResolver toolCallbackResolver; - - @Mock - private ToolCallback toolCallback; - - @BeforeEach - void setUp() { - MockitoAnnotations.openMocks(this); - - // Configure mock ChatClient with complete call chain - when(chatClient.prompt()).thenReturn(requestSpec); - when(requestSpec.messages(anyList())).thenReturn(requestSpec); - when(requestSpec.advisors(anyList())).thenReturn(requestSpec); - when(requestSpec.toolCallbacks(anyList())).thenReturn(requestSpec); - when(requestSpec.call()).thenReturn(responseSpec); - - // Configure mock ToolCallbackResolver - when(toolCallbackResolver.resolve(anyString())).thenReturn(toolCallback); - when(toolCallback.call(anyString(), any(ToolContext.class))).thenReturn("test tool response"); - when(toolCallback.getToolDefinition()).thenReturn(DefaultToolDefinition.builder() - .name("test_function") - .description("A test function") - .inputSchema("{\"type\": \"object\", \"properties\": {\"arg1\": {\"type\": \"string\"}}}") - .build()); - - // Configure mock ChatResponse with ToolCalls - Map metadata = new HashMap<>(); - metadata.put("finishReason", "stop"); - List toolCalls = List.of( - new ToolCall("call_1", "function", "test_function", "{\"arg1\": \"value1\"}") - ); - AssistantMessage assistantMessage = new AssistantMessage( - "test response", - metadata, - toolCalls, - Collections.emptyList() - ); - ChatGenerationMetadata generationMetadata = ChatGenerationMetadata.builder() - .finishReason("stop") - .build(); - Generation generation = new Generation(assistantMessage, generationMetadata); - ChatResponseMetadata responseMetadata = ChatResponseMetadata.builder() - .id("test-id") - .usage(new DefaultUsage(10, 20, 30)) - .build(); - ChatResponse response = ChatResponse.builder() - .generations(List.of(generation)) - .metadata(responseMetadata) - .build(); - when(responseSpec.chatResponse()).thenReturn(response); - } - - /** - * Tests ReactAgent with preLlmHook that modifies system prompt before LLM call. - */ - @Test - public void testReactAgentWithPreLlmHook() throws Exception { - Map prellmStore = new HashMap<>(); - - ReactAgent agent = ReactAgent.builder() - .name("testAgent") - .chatClient(chatClient) - .state(() -> new OverAllState().registerKeyAndStrategy("messages", new AppendStrategy())) - .resolver(toolCallbackResolver) - .preLlmHook(state -> { - prellmStore.put("timestamp", String.valueOf(System.currentTimeMillis())); - return Map.of(); - }) - .build(); - - CompiledGraph graph = agent.getAndCompileGraph(); - try { - Optional invoke = graph.invoke(Map.of("messages", List.of(new UserMessage("test")))); - }catch (java.util.concurrent.CompletionException e){ - - } - assertNotNull(prellmStore.get("timestamp")); - - } - - /** - * Tests ReactAgent with postLlmHook that processes LLM response. - */ - @Test - public void testReactAgentWithPostLlmHook() throws Exception { - // Create a map to store processed responses - Map responseStore = new HashMap<>(); - - ReactAgent agent = ReactAgent.builder() - .name("testAgent") - .chatClient(chatClient) - .state(() -> new OverAllState().registerKeyAndStrategy("messages", new AppendStrategy())) - .resolver(toolCallbackResolver) - - .postLlmHook(state -> { - responseStore.put("response", "Processed: " + state.value("messages")); - return Map.of(); - }) - .build(); - - CompiledGraph graph = agent.getAndCompileGraph(); - try { - Optional invoke = graph.invoke(Map.of("messages", List.of(new UserMessage("test")))); - }catch (java.util.concurrent.CompletionException e){ - - } - assertNotNull(responseStore.get("response")); - } - - /** - * Tests ReactAgent with preToolHook that prepares tool parameters. - */ - @Test - public void testReactAgentWithPreToolHook() throws Exception { - // Create a map to store tool parameters - Map toolParams = new HashMap<>(); - - ReactAgent agent = ReactAgent.builder() - .name("testAgent") - .chatClient(chatClient) - .state(() -> new OverAllState() - .registerKeyAndStrategy("toolParams",new ReplaceStrategy()) - .registerKeyAndStrategy("messages", new AppendStrategy())) - .resolver(toolCallbackResolver) - .preToolHook(state -> { - toolParams.put("timestamp", System.currentTimeMillis()); - return Map.of(); - }) - .build(); - - CompiledGraph graph = agent.getAndCompileGraph(); - try { - Optional invoke = graph.invoke(Map.of("messages", List.of(new UserMessage("test")))); - }catch (java.util.concurrent.CompletionException e){ - - } - assertNotNull(toolParams.get("timestamp")); - } - - /** - * Tests ReactAgent with postToolHook that collects tool results. - */ - @Test - public void testReactAgentWithPostToolHook() throws Exception { - // Create a map to store tool results - Map toolResults = new HashMap<>(); - - ReactAgent agent = ReactAgent.builder() - .name("testAgent") - .chatClient(chatClient) - .resolver(toolCallbackResolver) - .state(() -> new OverAllState() - .registerKeyAndStrategy("messages", new AppendStrategy()) - .registerKeyAndStrategy("toolOutput", new ReplaceStrategy())) - .postToolHook(state -> { - toolResults.put("result", "collected: " + "tool output"); - return Map.of(); - }) - .build(); - - CompiledGraph graph = agent.getAndCompileGraph(); - try { - Optional invoke = graph.invoke(Map.of("messages", List.of(new UserMessage("test")))); - }catch (java.util.concurrent.CompletionException e){ - - } - assertNotNull(toolResults.get("result")); - } - - @Test - public void testReactAgentWithAllHooks() throws Exception { - // Create maps to store results from each hook - Map prellmStore = new HashMap<>(); - Map responseStore = new HashMap<>(); - Map toolParams = new HashMap<>(); - Map toolResults = new HashMap<>(); - - ReactAgent agent = ReactAgent.builder() - .name("testAgent") - .chatClient(chatClient) - .state(() -> new OverAllState() - .registerKeyAndStrategy("messages", new AppendStrategy()) - .registerKeyAndStrategy("toolParams", new ReplaceStrategy()) - .registerKeyAndStrategy("toolOutput", new ReplaceStrategy())) - .resolver(toolCallbackResolver) - .preLlmHook(state -> { - prellmStore.put("timestamp", String.valueOf(System.currentTimeMillis())); - return Map.of(); - }) - .postLlmHook(state -> { - responseStore.put("response", "Processed: " + state.value("messages")); - return Map.of(); - }) - .preToolHook(state -> { - toolParams.put("timestamp", System.currentTimeMillis()); - return Map.of(); - }) - .postToolHook(state -> { - toolResults.put("result", "collected: " + "tool output"); - return Map.of(); - }) - .build(); - - CompiledGraph graph = agent.getAndCompileGraph(); - try { - Optional invoke = graph.invoke(Map.of("messages", List.of(new UserMessage("test")))); - } catch (java.util.concurrent.CompletionException e) { - // Ignore max iterations exception - } - - // Verify all hooks were executed - assertNotNull(prellmStore.get("timestamp"), "PreLLM hook should store timestamp"); - assertNotNull(responseStore.get("response"), "PostLLM hook should store response"); - assertNotNull(toolParams.get("timestamp"), "PreTool hook should store timestamp"); - assertNotNull(toolResults.get("result"), "PostTool hook should store result"); - } + @Mock + private ChatClient chatClient; + + @Mock + private ChatClient.ChatClientRequestSpec requestSpec; + + @Mock + private ChatClient.CallResponseSpec responseSpec; + + @Mock + private ChatResponse chatResponse; + + @Mock + private ToolCallbackResolver toolCallbackResolver; + + @Mock + private ToolCallback toolCallback; + + @BeforeEach + void setUp() { + MockitoAnnotations.openMocks(this); + + // Configure mock ChatClient with complete call chain + when(chatClient.prompt()).thenReturn(requestSpec); + when(requestSpec.messages(anyList())).thenReturn(requestSpec); + when(requestSpec.advisors(anyList())).thenReturn(requestSpec); + when(requestSpec.toolCallbacks(anyList())).thenReturn(requestSpec); + when(requestSpec.call()).thenReturn(responseSpec); + + // Configure mock ToolCallbackResolver + when(toolCallbackResolver.resolve(anyString())).thenReturn(toolCallback); + when(toolCallback.call(anyString(), any(ToolContext.class))).thenReturn("test tool response"); + when(toolCallback.getToolDefinition()).thenReturn(DefaultToolDefinition.builder() + .name("test_function") + .description("A test function") + .inputSchema("{\"type\": \"object\", \"properties\": {\"arg1\": {\"type\": \"string\"}}}") + .build()); + + // Configure mock ChatResponse with ToolCalls + Map metadata = new HashMap<>(); + metadata.put("finishReason", "stop"); + List toolCalls = List + .of(new ToolCall("call_1", "function", "test_function", "{\"arg1\": \"value1\"}")); + AssistantMessage assistantMessage = new AssistantMessage("test response", metadata, toolCalls, + Collections.emptyList()); + ChatGenerationMetadata generationMetadata = ChatGenerationMetadata.builder().finishReason("stop").build(); + Generation generation = new Generation(assistantMessage, generationMetadata); + ChatResponseMetadata responseMetadata = ChatResponseMetadata.builder() + .id("test-id") + .usage(new DefaultUsage(10, 20, 30)) + .build(); + ChatResponse response = ChatResponse.builder() + .generations(List.of(generation)) + .metadata(responseMetadata) + .build(); + when(responseSpec.chatResponse()).thenReturn(response); + } + + /** + * Tests ReactAgent with preLlmHook that modifies system prompt before LLM call. + */ + @Test + public void testReactAgentWithPreLlmHook() throws Exception { + Map prellmStore = new HashMap<>(); + + ReactAgent agent = ReactAgent.builder() + .name("testAgent") + .chatClient(chatClient) + .state(() -> new OverAllState().registerKeyAndStrategy("messages", new AppendStrategy())) + .resolver(toolCallbackResolver) + .preLlmHook(state -> { + prellmStore.put("timestamp", String.valueOf(System.currentTimeMillis())); + return Map.of(); + }) + .build(); + + CompiledGraph graph = agent.getAndCompileGraph(); + try { + Optional invoke = graph.invoke(Map.of("messages", List.of(new UserMessage("test")))); + } + catch (java.util.concurrent.CompletionException e) { + + } + assertNotNull(prellmStore.get("timestamp")); + + } + + /** + * Tests ReactAgent with postLlmHook that processes LLM response. + */ + @Test + public void testReactAgentWithPostLlmHook() throws Exception { + // Create a map to store processed responses + Map responseStore = new HashMap<>(); + + ReactAgent agent = ReactAgent.builder() + .name("testAgent") + .chatClient(chatClient) + .state(() -> new OverAllState().registerKeyAndStrategy("messages", new AppendStrategy())) + .resolver(toolCallbackResolver) + + .postLlmHook(state -> { + responseStore.put("response", "Processed: " + state.value("messages")); + return Map.of(); + }) + .build(); + + CompiledGraph graph = agent.getAndCompileGraph(); + try { + Optional invoke = graph.invoke(Map.of("messages", List.of(new UserMessage("test")))); + } + catch (java.util.concurrent.CompletionException e) { + + } + assertNotNull(responseStore.get("response")); + } + + /** + * Tests ReactAgent with preToolHook that prepares tool parameters. + */ + @Test + public void testReactAgentWithPreToolHook() throws Exception { + // Create a map to store tool parameters + Map toolParams = new HashMap<>(); + + ReactAgent agent = ReactAgent.builder() + .name("testAgent") + .chatClient(chatClient) + .state(() -> new OverAllState().registerKeyAndStrategy("toolParams", new ReplaceStrategy()) + .registerKeyAndStrategy("messages", new AppendStrategy())) + .resolver(toolCallbackResolver) + .preToolHook(state -> { + toolParams.put("timestamp", System.currentTimeMillis()); + return Map.of(); + }) + .build(); + + CompiledGraph graph = agent.getAndCompileGraph(); + try { + Optional invoke = graph.invoke(Map.of("messages", List.of(new UserMessage("test")))); + } + catch (java.util.concurrent.CompletionException e) { + + } + assertNotNull(toolParams.get("timestamp")); + } + + /** + * Tests ReactAgent with postToolHook that collects tool results. + */ + @Test + public void testReactAgentWithPostToolHook() throws Exception { + // Create a map to store tool results + Map toolResults = new HashMap<>(); + + ReactAgent agent = ReactAgent.builder() + .name("testAgent") + .chatClient(chatClient) + .resolver(toolCallbackResolver) + .state(() -> new OverAllState().registerKeyAndStrategy("messages", new AppendStrategy()) + .registerKeyAndStrategy("toolOutput", new ReplaceStrategy())) + .postToolHook(state -> { + toolResults.put("result", "collected: " + "tool output"); + return Map.of(); + }) + .build(); + + CompiledGraph graph = agent.getAndCompileGraph(); + try { + Optional invoke = graph.invoke(Map.of("messages", List.of(new UserMessage("test")))); + } + catch (java.util.concurrent.CompletionException e) { + + } + assertNotNull(toolResults.get("result")); + } + + @Test + public void testReactAgentWithAllHooks() throws Exception { + // Create maps to store results from each hook + Map prellmStore = new HashMap<>(); + Map responseStore = new HashMap<>(); + Map toolParams = new HashMap<>(); + Map toolResults = new HashMap<>(); + + ReactAgent agent = ReactAgent.builder() + .name("testAgent") + .chatClient(chatClient) + .state(() -> new OverAllState().registerKeyAndStrategy("messages", new AppendStrategy()) + .registerKeyAndStrategy("toolParams", new ReplaceStrategy()) + .registerKeyAndStrategy("toolOutput", new ReplaceStrategy())) + .resolver(toolCallbackResolver) + .preLlmHook(state -> { + prellmStore.put("timestamp", String.valueOf(System.currentTimeMillis())); + return Map.of(); + }) + .postLlmHook(state -> { + responseStore.put("response", "Processed: " + state.value("messages")); + return Map.of(); + }) + .preToolHook(state -> { + toolParams.put("timestamp", System.currentTimeMillis()); + return Map.of(); + }) + .postToolHook(state -> { + toolResults.put("result", "collected: " + "tool output"); + return Map.of(); + }) + .build(); + + CompiledGraph graph = agent.getAndCompileGraph(); + try { + Optional invoke = graph.invoke(Map.of("messages", List.of(new UserMessage("test")))); + } + catch (java.util.concurrent.CompletionException e) { + // Ignore max iterations exception + } + + // Verify all hooks were executed + assertNotNull(prellmStore.get("timestamp"), "PreLLM hook should store timestamp"); + assertNotNull(responseStore.get("response"), "PostLLM hook should store response"); + assertNotNull(toolParams.get("timestamp"), "PreTool hook should store timestamp"); + assertNotNull(toolResults.get("result"), "PostTool hook should store result"); + } + }