Skip to content

Commit ac44378

Browse files
authored
Merge branch 'main' into hyq_update
2 parents be80124 + d084448 commit ac44378

File tree

4 files changed

+52
-3
lines changed

4 files changed

+52
-3
lines changed

spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/CompiledGraph.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import com.alibaba.cloud.ai.graph.internal.edge.EdgeValue;
2727
import com.alibaba.cloud.ai.graph.internal.node.ParallelNode;
2828
import com.alibaba.cloud.ai.graph.state.StateSnapshot;
29+
import org.apache.commons.lang3.StringUtils;
2930
import org.bsc.async.AsyncGenerator;
3031
import org.slf4j.Logger;
3132
import org.slf4j.LoggerFactory;
@@ -668,6 +669,10 @@ private Optional<Data<Output>> getEmbedGenerator(Map<String, Object> partialStat
668669
if (data != null) {
669670

670671
if (data instanceof Map<?, ?>) {
672+
// FIX #102
673+
// Assume that the whatever used appender channel doesn't
674+
// accept duplicates
675+
// FIX #104: remove generator
671676
var partialStateWithoutGenerator = partialState.entrySet()
672677
.stream()
673678
.filter(e -> !Objects.equals(e.getKey(), generatorEntry.getKey()))
@@ -678,7 +683,8 @@ private Optional<Data<Output>> getEmbedGenerator(Map<String, Object> partialStat
678683

679684
currentState = OverAllState.updateState(intermediateState, (Map<String, Object>) data,
680685
keyStrategyMap);
681-
overAllState.updateState(intermediateState);
686+
687+
overAllState.updateState((Map<String, Object>) data);
682688
}
683689
else {
684690
throw new IllegalArgumentException("Embedded generator must return a Map");
@@ -707,7 +713,7 @@ private CompletableFuture<Data<Output>> evaluateAction(AsyncNodeActionWithConfig
707713
}
708714

709715
this.currentState = OverAllState.updateState(currentState, updateState, keyStrategyMap);
710-
overAllState.updateState(updateState);
716+
this.overAllState.updateState(updateState);
711717
var nextNodeCommand = nextNodeId(currentNodeId, overAllState, currentState, config);
712718
nextNodeId = nextNodeCommand.gotoNode();
713719
this.currentState = nextNodeCommand.update();

spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/OverAllState.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import java.util.Objects;
2424
import java.util.Optional;
2525
import java.util.Set;
26+
import java.util.function.BiConsumer;
2627
import java.util.function.BinaryOperator;
2728
import java.util.function.Function;
2829
import java.util.stream.Collector;
@@ -325,6 +326,23 @@ public Map<String, Object> updateState(Map<String, Object> partialState) {
325326
return data();
326327
}
327328

329+
/**
330+
* Updates the internal state based on a schema-defined strategy.
331+
* <p>
332+
* This method first validates the input state, then updates the partial state
333+
* according to the provided key strategies. The updated state is formed by merging
334+
* the original state and the modified partial state, removing any null values in the
335+
* process. The resulting entries are then used to update the internal data map.
336+
* @param state the base state to update; must not be null
337+
* @param partialState the partial state containing updates; may be null or empty
338+
* @param keyStrategies the mapping of keys to update strategies; used to transform
339+
* values
340+
*/
341+
public void updateStateBySchema(Map<String, Object> state, Map<String, Object> partialState,
342+
Map<String, KeyStrategy> keyStrategies) {
343+
updateState(updateState(state, partialState, keyStrategies));
344+
}
345+
328346
/**
329347
* Key verify boolean.
330348
* @return the boolean

spring-ai-alibaba-graph/spring-ai-alibaba-graph-example/src/main/java/com/alibaba/cloud/ai/example/graph/stream/LLmSearchStreamController.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import com.alibaba.cloud.ai.example.graph.stream.node.BaiduSearchNode;
1919
import com.alibaba.cloud.ai.example.graph.stream.node.LLmNode;
20+
import com.alibaba.cloud.ai.example.graph.stream.node.ResultNode;
2021
import com.alibaba.cloud.ai.example.graph.stream.node.TavilySearchNode;
2122
import com.alibaba.cloud.ai.graph.*;
2223
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
@@ -64,19 +65,24 @@ public class LLmSearchStreamController {
6465
@Autowired
6566
private TavilySearchNode tavilySearchNode;
6667

68+
@Autowired
69+
private ResultNode resultNode;
70+
6771
@PostConstruct
6872
public void init() throws GraphStateException {
6973
workflow = new StateGraph(
7074
() -> new OverAllState().registerKeyAndStrategy("parallel_result", new AppendStrategy())
7175
.registerKeyAndStrategy("messages", new AppendStrategy()))
7276
.addNode("baiduSearchNode", node_async(baiduSearchNode))
7377
.addNode("tavilySearchNode", node_async(tavilySearchNode))
78+
.addNode("resultNode", node_async(resultNode))
7479
.addNode("llmNode", node_async(lLmNode))
7580
.addEdge(START, "baiduSearchNode")
7681
.addEdge(START, "tavilySearchNode")
7782
.addEdge("baiduSearchNode", "llmNode")
7883
.addEdge("tavilySearchNode", "llmNode")
79-
.addEdge("llmNode", END);
84+
.addEdge("llmNode", "resultNode")
85+
.addEdge("resultNode", END);
8086

8187
}
8288

@@ -98,6 +104,7 @@ public void searchChat(HttpServletRequest request, HttpServletResponse response,
98104
CompletableFuture.runAsync(() -> {
99105
try (PrintWriter writer = response.getWriter()) {
100106
generator.forEachAsync(output -> {
107+
System.out.println("output = " + output);
101108
try {
102109
if (output instanceof StreamingOutput) {
103110
writer.write("data: " + ((StreamingOutput) output).chunk() + "\n\n");
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
package com.alibaba.cloud.ai.example.graph.stream.node;
2+
3+
import com.alibaba.cloud.ai.graph.OverAllState;
4+
import com.alibaba.cloud.ai.graph.action.NodeAction;
5+
import org.springframework.stereotype.Component;
6+
7+
import java.util.Map;
8+
9+
@Component
10+
public class ResultNode implements NodeAction {
11+
12+
@Override
13+
public Map<String, Object> apply(OverAllState state) throws Exception {
14+
System.out.println("messages = " + state.value("messages").get());
15+
return Map.of();
16+
}
17+
18+
}

0 commit comments

Comments
 (0)