Skip to content

Commit cdbce16

Browse files
committed
Create new ToolChatOptions if input request has tool configuration without options
Signed-off-by: Filip Hrisafov <filip.hrisafov@gmail.com>
1 parent 49e5c63 commit cdbce16

File tree

3 files changed

+101
-21
lines changed

3 files changed

+101
-21
lines changed

spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,14 @@
2222
import java.nio.charset.Charset;
2323
import java.util.ArrayList;
2424
import java.util.Arrays;
25+
import java.util.Collection;
26+
import java.util.Collections;
2527
import java.util.HashMap;
28+
import java.util.LinkedHashSet;
2629
import java.util.List;
2730
import java.util.Map;
2831
import java.util.Optional;
32+
import java.util.Set;
2933
import java.util.function.Consumer;
3034

3135
import io.micrometer.observation.Observation;
@@ -571,7 +575,7 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe
571575

572576
private final List<Media> media = new ArrayList<>();
573577

574-
private final List<String> toolNames = new ArrayList<>();
578+
private final Set<String> toolNames = new LinkedHashSet<>();
575579

576580
private final List<ToolCallback> toolCallbacks = new ArrayList<>();
577581

@@ -607,9 +611,9 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe
607611

608612
public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userText,
609613
Map<String, Object> userParams, @Nullable String systemText, Map<String, Object> systemParams,
610-
List<ToolCallback> toolCallbacks, List<Message> messages, List<String> toolNames, List<Media> media,
611-
@Nullable ChatOptions chatOptions, List<Advisor> advisors, Map<String, Object> advisorParams,
612-
ObservationRegistry observationRegistry,
614+
List<ToolCallback> toolCallbacks, List<Message> messages, Collection<String> toolNames,
615+
List<Media> media, @Nullable ChatOptions chatOptions, List<Advisor> advisors,
616+
Map<String, Object> advisorParams, ObservationRegistry observationRegistry,
613617
@Nullable ChatClientObservationConvention observationConvention, Map<String, Object> toolContext,
614618
@Nullable TemplateRenderer templateRenderer) {
615619

@@ -686,7 +690,7 @@ public List<Media> getMedia() {
686690
return this.media;
687691
}
688692

689-
public List<String> getToolNames() {
693+
public Set<String> getToolNames() {
690694
return this.toolNames;
691695
}
692696

@@ -702,6 +706,10 @@ public TemplateRenderer getTemplateRenderer() {
702706
return this.templateRenderer;
703707
}
704708

709+
public boolean hasToolConfiguration() {
710+
return !this.toolNames.isEmpty() || !this.toolCallbacks.isEmpty() || !this.toolContext.isEmpty();
711+
}
712+
705713
/**
706714
* Return a {@link ChatClient.Builder} to create a new {@link ChatClient} whose
707715
* settings are replicated from this {@link ChatClientRequest}.
@@ -784,7 +792,7 @@ public <T extends ChatOptions> ChatClientRequestSpec options(T options) {
784792
public ChatClientRequestSpec toolNames(String... toolNames) {
785793
Assert.notNull(toolNames, "toolNames cannot be null");
786794
Assert.noNullElements(toolNames, "toolNames cannot contain null elements");
787-
this.toolNames.addAll(List.of(toolNames));
795+
Collections.addAll(this.toolNames, toolNames);
788796
return this;
789797
}
790798

spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientUtils.java

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -94,22 +94,40 @@ static ChatClientRequest toChatClientRequest(DefaultChatClient.DefaultChatClient
9494
*/
9595

9696
ChatOptions processedChatOptions = inputRequest.getChatOptions();
97-
if (processedChatOptions instanceof ToolCallingChatOptions toolCallingChatOptions) {
98-
if (!inputRequest.getToolNames().isEmpty()) {
99-
Set<String> toolNames = ToolCallingChatOptions
100-
.mergeToolNames(new HashSet<>(inputRequest.getToolNames()), toolCallingChatOptions.getToolNames());
101-
toolCallingChatOptions.setToolNames(toolNames);
97+
if (inputRequest.hasToolConfiguration()) {
98+
if (processedChatOptions == null) {
99+
ToolCallingChatOptions.Builder builder = ToolCallingChatOptions.builder();
100+
if (!inputRequest.getToolNames().isEmpty()) {
101+
builder.toolNames(inputRequest.getToolNames());
102+
}
103+
if (!inputRequest.getToolCallbacks().isEmpty()) {
104+
List<ToolCallback> toolCallbacks = inputRequest.getToolCallbacks();
105+
ToolCallingChatOptions.validateToolCallbacks(toolCallbacks);
106+
builder.toolCallbacks(inputRequest.getToolCallbacks());
107+
}
108+
if (!CollectionUtils.isEmpty(inputRequest.getToolContext())) {
109+
builder.toolContext(inputRequest.getToolContext());
110+
}
111+
112+
processedChatOptions = builder.build();
102113
}
103-
if (!inputRequest.getToolCallbacks().isEmpty()) {
104-
List<ToolCallback> toolCallbacks = ToolCallingChatOptions
105-
.mergeToolCallbacks(inputRequest.getToolCallbacks(), toolCallingChatOptions.getToolCallbacks());
106-
ToolCallingChatOptions.validateToolCallbacks(toolCallbacks);
107-
toolCallingChatOptions.setToolCallbacks(toolCallbacks);
108-
}
109-
if (!CollectionUtils.isEmpty(inputRequest.getToolContext())) {
110-
Map<String, Object> toolContext = ToolCallingChatOptions.mergeToolContext(inputRequest.getToolContext(),
111-
toolCallingChatOptions.getToolContext());
112-
toolCallingChatOptions.setToolContext(toolContext);
114+
else if (processedChatOptions instanceof ToolCallingChatOptions toolCallingChatOptions) {
115+
if (!inputRequest.getToolNames().isEmpty()) {
116+
Set<String> toolNames = ToolCallingChatOptions.mergeToolNames(
117+
new HashSet<>(inputRequest.getToolNames()), toolCallingChatOptions.getToolNames());
118+
toolCallingChatOptions.setToolNames(toolNames);
119+
}
120+
if (!inputRequest.getToolCallbacks().isEmpty()) {
121+
List<ToolCallback> toolCallbacks = ToolCallingChatOptions
122+
.mergeToolCallbacks(inputRequest.getToolCallbacks(), toolCallingChatOptions.getToolCallbacks());
123+
ToolCallingChatOptions.validateToolCallbacks(toolCallbacks);
124+
toolCallingChatOptions.setToolCallbacks(toolCallbacks);
125+
}
126+
if (!CollectionUtils.isEmpty(inputRequest.getToolContext())) {
127+
Map<String, Object> toolContext = ToolCallingChatOptions
128+
.mergeToolContext(inputRequest.getToolContext(), toolCallingChatOptions.getToolContext());
129+
toolCallingChatOptions.setToolContext(toolContext);
130+
}
113131
}
114132
}
115133

spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientUtilsTests.java

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,60 @@ void whenToolContextAndChatOptionsAreProvidedThenTheValuesAreMerged() {
322322
.containsAllEntriesOf(toolContext2);
323323
}
324324

325+
@Test
326+
void whenToolNamesWithoutChatOptionsAreProvidedThenToolCallingChatOptionsAreSet() {
327+
List<String> toolNames = List.of("tool1", "tool2");
328+
ChatModel chatModel = mock(ChatModel.class);
329+
DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient
330+
.create(chatModel)
331+
.prompt()
332+
.toolNames(toolNames.toArray(new String[0]));
333+
334+
ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest);
335+
336+
assertThat(result).isNotNull();
337+
assertThat(result.prompt().getOptions()).isInstanceOf(ToolCallingChatOptions.class);
338+
ToolCallingChatOptions resultOptions = (ToolCallingChatOptions) result.prompt().getOptions();
339+
assertThat(resultOptions).isNotNull();
340+
assertThat(resultOptions.getToolNames()).containsExactlyInAnyOrderElementsOf(toolNames);
341+
}
342+
343+
@Test
344+
void whenToolCallbacksWithoutChatOptionsAreProvidedThenToolCallingChatOptionsAreSet() {
345+
ToolCallback toolCallback = new TestToolCallback("tool1");
346+
ChatModel chatModel = mock(ChatModel.class);
347+
DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient
348+
.create(chatModel)
349+
.prompt()
350+
.toolCallbacks(toolCallback);
351+
352+
ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest);
353+
354+
assertThat(result).isNotNull();
355+
assertThat(result.prompt().getOptions()).isInstanceOf(ToolCallingChatOptions.class);
356+
ToolCallingChatOptions resultOptions = (ToolCallingChatOptions) result.prompt().getOptions();
357+
assertThat(resultOptions).isNotNull();
358+
assertThat(resultOptions.getToolCallbacks()).contains(toolCallback);
359+
}
360+
361+
@Test
362+
void whenToolContextWithoutChatOptionsIsProvidedThenToolCallingChatOptionsAreSet() {
363+
Map<String, Object> toolContext = Map.of("key", "value");
364+
ChatModel chatModel = mock(ChatModel.class);
365+
DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient
366+
.create(chatModel)
367+
.prompt()
368+
.toolContext(toolContext);
369+
370+
ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest);
371+
372+
assertThat(result).isNotNull();
373+
assertThat(result.prompt().getOptions()).isInstanceOf(ToolCallingChatOptions.class);
374+
ToolCallingChatOptions resultOptions = (ToolCallingChatOptions) result.prompt().getOptions();
375+
assertThat(resultOptions).isNotNull();
376+
assertThat(resultOptions.getToolContext()).containsAllEntriesOf(toolContext);
377+
}
378+
325379
@Test
326380
void whenAdvisorParamsAreProvidedThenTheyAreAddedToContext() {
327381
Map<String, Object> advisorParams = Map.of("key1", "value1", "key2", "value2");

0 commit comments

Comments
 (0)