From 8f11aa070450c33277a98f3a746643f8a7b3c005 Mon Sep 17 00:00:00 2001 From: Makoto <2762006003@qq.com> Date: Tue, 16 Sep 2025 21:41:11 +0800 Subject: [PATCH 1/2] refactor(nl2sql): modify the nl2sql graph invocation method and add human feedback scheduling and node test cases --- .../spring-ai-alibaba-nl2sql-chat/README.md | 4 +- .../cloud/ai/service/Nl2SqlService.java | 20 ++--- .../HumanFeedbackDispatcherTest.java | 41 +++++++++ .../cloud/ai/node/HumanFeedbackNodeTest.java | 84 +++++++++++++++++++ 4 files changed, 137 insertions(+), 12 deletions(-) create mode 100644 spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/test/java/com/alibaba/cloud/ai/dispatcher/HumanFeedbackDispatcherTest.java create mode 100644 spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/test/java/com/alibaba/cloud/ai/node/HumanFeedbackNodeTest.java diff --git a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/README.md b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/README.md index 9e613ade03..5eabdd9f7a 100644 --- a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/README.md +++ b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/README.md @@ -510,8 +510,8 @@ public class Nl2sqlController { .setTables(Arrays.asList("categories", "order_items", "orders", "products", "users", "product_categories")); simpleVectorStoreService.schema(schemaInitRequest); - Optional invoke = compiledGraph.invoke(Map.of(INPUT_KEY, query)); - OverAllState overAllState = invoke.get(); + Optional call = compiledGraph.call(Map.of(INPUT_KEY, query)); + OverAllState overAllState = call.get(); return overAllState.value(RESULT).get().toString(); } diff --git a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/service/Nl2SqlService.java b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/service/Nl2SqlService.java index 3619ebc93e..ccb7d24fa3 100644 --- a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/service/Nl2SqlService.java +++ b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/service/Nl2SqlService.java @@ -69,9 +69,9 @@ public String nl2sql(String naturalQuery, String agentId) throws GraphRunnerExce agentId = ""; } Map stateMap = Map.of(IS_ONLY_NL2SQL, true, INPUT_KEY, naturalQuery, AGENT_ID, agentId); - Optional invoke = this.nl2sqlGraph.invoke(stateMap); - OverAllState state = invoke.orElseThrow(() -> { - logger.error("Nl2SqlService invoke fail, stateMap: {}", stateMap); + Optional call = this.nl2sqlGraph.call(stateMap); + OverAllState state = call.orElseThrow(() -> { + logger.error("Nl2SqlService call fail, stateMap: {}", stateMap); return new GraphRunnerException("图运行失败"); }); return state.value(ONLY_NL2SQL_OUTPUT, ""); @@ -96,14 +96,14 @@ public String nl2sql(String naturalQuery) throws GraphRunnerException { * @return CompletableFuture * @throws GraphRunnerException 图运行异常 */ - public CompletableFuture nl2sqlWithProcess(Consumer nl2SqlProcessConsumer, - String naturalQuery, String agentId, RunnableConfig runnableConfig) throws GraphRunnerException { + public CompletableFuture nl2sqlWithProcess(Consumer nl2SqlProcessConsumer, String naturalQuery, + String agentId, RunnableConfig runnableConfig) throws GraphRunnerException { Map stateMap = Map.of(IS_ONLY_NL2SQL, true, INPUT_KEY, naturalQuery, AGENT_ID, agentId); Consumer consumer = (output) -> { Nl2SqlProcess sqlProcess = this.nodeOutputToNl2sqlProcess(output); nl2SqlProcessConsumer.accept(sqlProcess); }; - return this.nl2sqlGraph.stream(stateMap, runnableConfig).forEachAsync(consumer); + return this.nl2sqlGraph.fluxStream(stateMap, runnableConfig).doOnNext(consumer::accept).then().toFuture(); } /** @@ -114,8 +114,8 @@ public CompletableFuture nl2sqlWithProcess(Consumer nl2Sq * @return CompletableFuture * @throws GraphRunnerException 图运行异常 */ - public CompletableFuture nl2sqlWithProcess(Consumer nl2SqlProcessConsumer, - String naturalQuery, String agentId) throws GraphRunnerException { + public CompletableFuture nl2sqlWithProcess(Consumer nl2SqlProcessConsumer, String naturalQuery, + String agentId) throws GraphRunnerException { return this.nl2sqlWithProcess(nl2SqlProcessConsumer, naturalQuery, agentId, RunnableConfig.builder().build()); } @@ -126,8 +126,8 @@ public CompletableFuture nl2sqlWithProcess(Consumer nl2Sq * @return CompletableFuture * @throws GraphRunnerException 图运行异常 */ - public CompletableFuture nl2sqlWithProcess(Consumer nl2SqlProcessConsumer, - String naturalQuery) throws GraphRunnerException { + public CompletableFuture nl2sqlWithProcess(Consumer nl2SqlProcessConsumer, String naturalQuery) + throws GraphRunnerException { return this.nl2sqlWithProcess(nl2SqlProcessConsumer, naturalQuery, ""); } diff --git a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/test/java/com/alibaba/cloud/ai/dispatcher/HumanFeedbackDispatcherTest.java b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/test/java/com/alibaba/cloud/ai/dispatcher/HumanFeedbackDispatcherTest.java new file mode 100644 index 0000000000..ec8d184dec --- /dev/null +++ b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/test/java/com/alibaba/cloud/ai/dispatcher/HumanFeedbackDispatcherTest.java @@ -0,0 +1,41 @@ +package com.alibaba.cloud.ai.dispatcher; + +import com.alibaba.cloud.ai.graph.OverAllState; +import com.alibaba.cloud.ai.graph.StateGraph; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.*; + +class HumanFeedbackDispatcherTest { + + private HumanFeedbackDispatcher dispatcher; + + private OverAllState state; + + @BeforeEach + void setUp() { + dispatcher = new HumanFeedbackDispatcher(); + state = new OverAllState(); + } + + @Test + void testWaitForFeedbackReturnsEND() throws Exception { + state.updateState(java.util.Map.of("human_next_node", "WAIT_FOR_FEEDBACK")); + String next = dispatcher.apply(state); + assertEquals(StateGraph.END, next); + } + + @Test + void testNormalRouting() throws Exception { + state.updateState(java.util.Map.of("human_next_node", "PLANNER_NODE")); + String next = dispatcher.apply(state); + assertEquals("PLANNER_NODE", next); + } + + @Test + void testDefaultToENDWhenMissingKey() throws Exception { + String next = dispatcher.apply(state); + assertEquals(StateGraph.END, next); + } +} diff --git a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/test/java/com/alibaba/cloud/ai/node/HumanFeedbackNodeTest.java b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/test/java/com/alibaba/cloud/ai/node/HumanFeedbackNodeTest.java new file mode 100644 index 0000000000..57151b690d --- /dev/null +++ b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/test/java/com/alibaba/cloud/ai/node/HumanFeedbackNodeTest.java @@ -0,0 +1,84 @@ +package com.alibaba.cloud.ai.node; + +import com.alibaba.cloud.ai.graph.OverAllState; +import com.alibaba.cloud.ai.graph.state.strategy.ReplaceStrategy; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.Map; + +import static com.alibaba.cloud.ai.constant.Constant.*; +import static org.junit.jupiter.api.Assertions.*; + +class HumanFeedbackNodeTest { + + private HumanFeedbackNode node; + + private OverAllState state; + + @BeforeEach + void setUp() { + node = new HumanFeedbackNode(); + state = new OverAllState(); + state.registerKeyAndStrategy(PLAN_REPAIR_COUNT, new ReplaceStrategy()); + state.registerKeyAndStrategy(PLAN_CURRENT_STEP, new ReplaceStrategy()); + state.registerKeyAndStrategy(HUMAN_REVIEW_ENABLED, new ReplaceStrategy()); + state.registerKeyAndStrategy(PLAN_VALIDATION_ERROR, new ReplaceStrategy()); + } + + @Test + void testApproveFlow() throws Exception { + state.withHumanFeedback(new OverAllState.HumanFeedback(Map.of( + "feed_back", true + ), null)); + + Map result = node.apply(state); + assertEquals(PLAN_EXECUTOR_NODE, result.get("human_next_node")); + assertEquals(false, result.get(HUMAN_REVIEW_ENABLED)); + assertFalse(result.containsKey(PLAN_VALIDATION_ERROR)); + } + + @Test + void testRejectFlowWithContent() throws Exception { + state.updateState(Map.of(PLAN_REPAIR_COUNT, 0)); + state.withHumanFeedback(new OverAllState.HumanFeedback(Map.of( + "feed_back", false, + "feed_back_content", "需要补充过滤条件" + ), null)); + + Map result = node.apply(state); + assertEquals(PLANNER_NODE, result.get("human_next_node")); + assertEquals(1, result.get(PLAN_REPAIR_COUNT)); + assertEquals(1, result.get(PLAN_CURRENT_STEP)); + assertEquals(true, result.get(HUMAN_REVIEW_ENABLED)); + assertEquals("需要补充过滤条件", result.get(PLAN_VALIDATION_ERROR)); + } + + @Test + void testRejectFlowWithoutContent() throws Exception { + state.updateState(Map.of(PLAN_REPAIR_COUNT, 2)); + state.withHumanFeedback(new OverAllState.HumanFeedback(Map.of( + "feed_back", false + ), null)); + + Map result = node.apply(state); + assertEquals(PLANNER_NODE, result.get("human_next_node")); + assertEquals(3, result.get(PLAN_REPAIR_COUNT)); + assertEquals(1, result.get(PLAN_CURRENT_STEP)); + assertEquals(true, result.get(HUMAN_REVIEW_ENABLED)); + assertEquals("Plan rejected by user", result.get(PLAN_VALIDATION_ERROR)); + } + + @Test + void testWaitForFeedback() throws Exception { + Map result = node.apply(state); + assertEquals("WAIT_FOR_FEEDBACK", result.get("human_next_node")); + } + + @Test + void testMaxRepairExceeded() throws Exception { + state.updateState(Map.of(PLAN_REPAIR_COUNT, 3)); + Map result = node.apply(state); + assertEquals("END", result.get("human_next_node")); + } +} From 9500a013fb0f51cf859cd4d0241974bad861fd24 Mon Sep 17 00:00:00 2001 From: Makoto <2762006003@qq.com> Date: Tue, 16 Sep 2025 22:08:12 +0800 Subject: [PATCH 2/2] fix format --- .../HumanFeedbackDispatcherTest.java | 17 ++++++++++ .../cloud/ai/node/HumanFeedbackNodeTest.java | 31 +++++++++++++------ 2 files changed, 38 insertions(+), 10 deletions(-) diff --git a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/test/java/com/alibaba/cloud/ai/dispatcher/HumanFeedbackDispatcherTest.java b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/test/java/com/alibaba/cloud/ai/dispatcher/HumanFeedbackDispatcherTest.java index ec8d184dec..c58aed97bf 100644 --- a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/test/java/com/alibaba/cloud/ai/dispatcher/HumanFeedbackDispatcherTest.java +++ b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/test/java/com/alibaba/cloud/ai/dispatcher/HumanFeedbackDispatcherTest.java @@ -1,3 +1,19 @@ +/* + * Copyright 2024-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package com.alibaba.cloud.ai.dispatcher; import com.alibaba.cloud.ai.graph.OverAllState; @@ -38,4 +54,5 @@ void testDefaultToENDWhenMissingKey() throws Exception { String next = dispatcher.apply(state); assertEquals(StateGraph.END, next); } + } diff --git a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/test/java/com/alibaba/cloud/ai/node/HumanFeedbackNodeTest.java b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/test/java/com/alibaba/cloud/ai/node/HumanFeedbackNodeTest.java index 57151b690d..c56faa0bc5 100644 --- a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/test/java/com/alibaba/cloud/ai/node/HumanFeedbackNodeTest.java +++ b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/test/java/com/alibaba/cloud/ai/node/HumanFeedbackNodeTest.java @@ -1,3 +1,19 @@ +/* + * Copyright 2024-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package com.alibaba.cloud.ai.node; import com.alibaba.cloud.ai.graph.OverAllState; @@ -28,9 +44,7 @@ void setUp() { @Test void testApproveFlow() throws Exception { - state.withHumanFeedback(new OverAllState.HumanFeedback(Map.of( - "feed_back", true - ), null)); + state.withHumanFeedback(new OverAllState.HumanFeedback(Map.of("feed_back", true), null)); Map result = node.apply(state); assertEquals(PLAN_EXECUTOR_NODE, result.get("human_next_node")); @@ -41,10 +55,8 @@ void testApproveFlow() throws Exception { @Test void testRejectFlowWithContent() throws Exception { state.updateState(Map.of(PLAN_REPAIR_COUNT, 0)); - state.withHumanFeedback(new OverAllState.HumanFeedback(Map.of( - "feed_back", false, - "feed_back_content", "需要补充过滤条件" - ), null)); + state.withHumanFeedback( + new OverAllState.HumanFeedback(Map.of("feed_back", false, "feed_back_content", "需要补充过滤条件"), null)); Map result = node.apply(state); assertEquals(PLANNER_NODE, result.get("human_next_node")); @@ -57,9 +69,7 @@ void testRejectFlowWithContent() throws Exception { @Test void testRejectFlowWithoutContent() throws Exception { state.updateState(Map.of(PLAN_REPAIR_COUNT, 2)); - state.withHumanFeedback(new OverAllState.HumanFeedback(Map.of( - "feed_back", false - ), null)); + state.withHumanFeedback(new OverAllState.HumanFeedback(Map.of("feed_back", false), null)); Map result = node.apply(state); assertEquals(PLANNER_NODE, result.get("human_next_node")); @@ -81,4 +91,5 @@ void testMaxRepairExceeded() throws Exception { Map result = node.apply(state); assertEquals("END", result.get("human_next_node")); } + }