Skip to content

Commit 3a28241

Browse files
committed
feat(ollama): add retry template integration to OllamaChatModel
1 parent 50223d2 commit 3a28241

File tree

11 files changed

+58
-16
lines changed

11 files changed

+58
-16
lines changed

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@
6060
import org.springframework.ai.ollama.management.OllamaModelManager;
6161
import org.springframework.ai.ollama.management.PullModelStrategy;
6262
import org.springframework.ai.ollama.metadata.OllamaChatUsage;
63+
import org.springframework.ai.retry.RetryUtils;
64+
import org.springframework.retry.support.RetryTemplate;
6365
import org.springframework.util.Assert;
6466
import org.springframework.util.CollectionUtils;
6567
import org.springframework.util.StringUtils;
@@ -75,6 +77,7 @@
7577
* @author luocongqiu
7678
* @author Thomas Vitale
7779
* @author Jihoon Kim
80+
* @author Alexandros Pappas
7881
* @since 1.0.0
7982
*/
8083
public class OllamaChatModel extends AbstractToolCallSupport implements ChatModel {
@@ -89,20 +92,32 @@ public class OllamaChatModel extends AbstractToolCallSupport implements ChatMode
8992

9093
private final OllamaModelManager modelManager;
9194

95+
private final RetryTemplate retryTemplate;
96+
9297
private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
9398

9499
public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions,
95100
FunctionCallbackResolver functionCallbackResolver, List<FunctionCallback> toolFunctionCallbacks,
96101
ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) {
102+
this(ollamaApi, defaultOptions, functionCallbackResolver, toolFunctionCallbacks, observationRegistry,
103+
modelManagementOptions, RetryUtils.DEFAULT_RETRY_TEMPLATE);
104+
}
105+
106+
public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions,
107+
FunctionCallbackResolver functionCallbackResolver, List<FunctionCallback> toolFunctionCallbacks,
108+
ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions,
109+
RetryTemplate retryTemplate) {
97110
super(functionCallbackResolver, defaultOptions, toolFunctionCallbacks);
98111
Assert.notNull(ollamaApi, "ollamaApi must not be null");
99112
Assert.notNull(defaultOptions, "defaultOptions must not be null");
100113
Assert.notNull(observationRegistry, "observationRegistry must not be null");
101114
Assert.notNull(modelManagementOptions, "modelManagementOptions must not be null");
115+
Assert.notNull(retryTemplate, "retryTemplate must not be null");
102116
this.chatApi = ollamaApi;
103117
this.defaultOptions = defaultOptions;
104118
this.observationRegistry = observationRegistry;
105119
this.modelManager = new OllamaModelManager(this.chatApi, modelManagementOptions);
120+
this.retryTemplate = retryTemplate;
106121
initializeModel(defaultOptions.getModel(), modelManagementOptions.pullModelStrategy());
107122
}
108123

@@ -142,7 +157,7 @@ public ChatResponse call(Prompt prompt) {
142157
this.observationRegistry)
143158
.observe(() -> {
144159

145-
OllamaApi.ChatResponse ollamaResponse = this.chatApi.chat(request);
160+
OllamaApi.ChatResponse ollamaResponse = this.retryTemplate.execute(ctx -> this.chatApi.chat(request));
146161

147162
List<AssistantMessage.ToolCall> toolCalls = ollamaResponse.message().toolCalls() == null ? List.of()
148163
: ollamaResponse.message()
@@ -198,7 +213,8 @@ public Flux<ChatResponse> stream(Prompt prompt) {
198213

199214
observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start();
200215

201-
Flux<OllamaApi.ChatResponse> ollamaResponse = this.chatApi.streamingChat(request);
216+
Flux<OllamaApi.ChatResponse> ollamaResponse = this.retryTemplate
217+
.execute(ctx -> this.chatApi.streamingChat(request));
202218

203219
Flux<ChatResponse> chatResponse = ollamaResponse.map(chunk -> {
204220
String content = (chunk.message() != null) ? chunk.message().content() : "";
@@ -409,6 +425,8 @@ public static final class Builder {
409425

410426
private ModelManagementOptions modelManagementOptions = ModelManagementOptions.defaults();
411427

428+
// Default to the shared retry template so that build() succeeds when withRetryTemplate(...)
// is never called; the OllamaChatModel constructor rejects a null retryTemplate.
private RetryTemplate retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE;
429+
412430
private Builder() {
413431
}
414432

@@ -452,9 +470,15 @@ public Builder withModelManagementOptions(ModelManagementOptions modelManagement
452470
return this;
453471
}
454472

473+
public Builder withRetryTemplate(RetryTemplate retryTemplate) {
474+
this.retryTemplate = retryTemplate;
475+
return this;
476+
}
477+
455478
public OllamaChatModel build() {
456479
return new OllamaChatModel(this.ollamaApi, this.defaultOptions, this.functionCallbackResolver,
457-
this.toolFunctionCallbacks, this.observationRegistry, this.modelManagementOptions);
480+
this.toolFunctionCallbacks, this.observationRegistry, this.modelManagementOptions,
481+
this.retryTemplate);
458482
}
459483

460484
}

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
*
5151
* @author Christian Tzolov
5252
* @author Thomas Vitale
53+
* @author Alexandros Pappas
5354
* @since 0.8.0
5455
*/
5556
// @formatter:off
@@ -63,8 +64,6 @@ public class OllamaApi {
6364

6465
private static final String DEFAULT_BASE_URL = "http://localhost:11434";
6566

66-
private final ResponseErrorHandler responseErrorHandler;
67-
6867
private final RestClient restClient;
6968

7069
private final WebClient webClient;
@@ -92,14 +91,16 @@ public OllamaApi(String baseUrl) {
9291
*/
9392
public OllamaApi(String baseUrl, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder) {
9493

95-
this.responseErrorHandler = new OllamaResponseErrorHandler();
94+
ResponseErrorHandler responseErrorHandler = new OllamaResponseErrorHandler();
9695

9796
Consumer<HttpHeaders> defaultHeaders = headers -> {
9897
headers.setContentType(MediaType.APPLICATION_JSON);
9998
headers.setAccept(List.of(MediaType.APPLICATION_JSON));
10099
};
101100

102-
this.restClient = restClientBuilder.baseUrl(baseUrl).defaultHeaders(defaultHeaders).build();
101+
this.restClient = restClientBuilder.baseUrl(baseUrl)
102+
.defaultStatusHandler(responseErrorHandler)
103+
.defaultHeaders(defaultHeaders).build();
103104

104105
this.webClient = webClientBuilder.baseUrl(baseUrl).defaultHeaders(defaultHeaders).build();
105106
}
@@ -120,7 +121,6 @@ public ChatResponse chat(ChatRequest chatRequest) {
120121
.uri("/api/chat")
121122
.body(chatRequest)
122123
.retrieve()
123-
.onStatus(this.responseErrorHandler)
124124
.body(ChatResponse.class);
125125
}
126126

@@ -158,7 +158,6 @@ public EmbeddingsResponse embed(EmbeddingsRequest embeddingsRequest) {
158158
.uri("/api/embed")
159159
.body(embeddingsRequest)
160160
.retrieve()
161-
.onStatus(this.responseErrorHandler)
162161
.body(EmbeddingsResponse.class);
163162
}
164163

@@ -169,7 +168,6 @@ public ListModelResponse listModels() {
169168
return this.restClient.get()
170169
.uri("/api/tags")
171170
.retrieve()
172-
.onStatus(this.responseErrorHandler)
173171
.body(ListModelResponse.class);
174172
}
175173

@@ -182,7 +180,6 @@ public ShowModelResponse showModel(ShowModelRequest showModelRequest) {
182180
.uri("/api/show")
183181
.body(showModelRequest)
184182
.retrieve()
185-
.onStatus(this.responseErrorHandler)
186183
.body(ShowModelResponse.class);
187184
}
188185

@@ -195,7 +192,6 @@ public ResponseEntity<Void> copyModel(CopyModelRequest copyModelRequest) {
195192
.uri("/api/copy")
196193
.body(copyModelRequest)
197194
.retrieve()
198-
.onStatus(this.responseErrorHandler)
199195
.toBodilessEntity();
200196
}
201197

@@ -208,7 +204,6 @@ public ResponseEntity<Void> deleteModel(DeleteModelRequest deleteModelRequest) {
208204
.uri("/api/delete")
209205
.body(deleteModelRequest)
210206
.retrieve()
211-
.onStatus(this.responseErrorHandler)
212207
.toBodilessEntity();
213208
}
214209

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import org.springframework.ai.ollama.api.OllamaApi;
3838
import org.springframework.ai.ollama.api.OllamaOptions;
3939
import org.springframework.ai.ollama.api.tool.MockWeatherService;
40+
import org.springframework.ai.retry.RetryUtils;
4041
import org.springframework.beans.factory.annotation.Autowired;
4142
import org.springframework.boot.SpringBootConfiguration;
4243
import org.springframework.boot.test.context.SpringBootTest;
@@ -124,6 +125,7 @@ public OllamaChatModel ollamaChat(OllamaApi ollamaApi) {
124125
return OllamaChatModel.builder()
125126
.withOllamaApi(ollamaApi)
126127
.withDefaultOptions(OllamaOptions.create().withModel(MODEL).withTemperature(0.9))
128+
.withRetryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
127129
.build();
128130
}
129131

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
import org.springframework.ai.ollama.management.ModelManagementOptions;
4343
import org.springframework.ai.ollama.management.OllamaModelManager;
4444
import org.springframework.ai.ollama.management.PullModelStrategy;
45+
import org.springframework.ai.retry.RetryUtils;
4546
import org.springframework.beans.factory.annotation.Autowired;
4647
import org.springframework.boot.SpringBootConfiguration;
4748
import org.springframework.boot.test.context.SpringBootTest;
@@ -249,6 +250,7 @@ public OllamaChatModel ollamaChat(OllamaApi ollamaApi) {
249250
.withPullModelStrategy(PullModelStrategy.WHEN_MISSING)
250251
.withAdditionalModels(List.of(ADDITIONAL_MODEL))
251252
.build())
253+
.withRetryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
252254
.build();
253255
}
254256

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMultimodalIT.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.springframework.ai.model.Media;
2828
import org.springframework.ai.ollama.api.OllamaApi;
2929
import org.springframework.ai.ollama.api.OllamaOptions;
30+
import org.springframework.ai.retry.RetryUtils;
3031
import org.springframework.beans.factory.annotation.Autowired;
3132
import org.springframework.boot.SpringBootConfiguration;
3233
import org.springframework.boot.test.context.SpringBootTest;
@@ -84,6 +85,7 @@ public OllamaChatModel ollamaChat(OllamaApi ollamaApi) {
8485
return OllamaChatModel.builder()
8586
.withOllamaApi(ollamaApi)
8687
.withDefaultOptions(OllamaOptions.create().withModel(MODEL).withTemperature(0.9))
88+
.withRetryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
8789
.build();
8890
}
8991

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelObservationIT.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.springframework.ai.ollama.api.OllamaApi;
3535
import org.springframework.ai.ollama.api.OllamaModel;
3636
import org.springframework.ai.ollama.api.OllamaOptions;
37+
import org.springframework.ai.retry.RetryUtils;
3738
import org.springframework.beans.factory.annotation.Autowired;
3839
import org.springframework.boot.SpringBootConfiguration;
3940
import org.springframework.boot.test.context.SpringBootTest;
@@ -47,6 +48,7 @@
4748
* Integration tests for observation instrumentation in {@link OllamaChatModel}.
4849
*
4950
* @author Thomas Vitale
51+
* @author Alexandros Pappas
5052
*/
5153
@SpringBootTest(classes = OllamaChatModelObservationIT.Config.class)
5254
public class OllamaChatModelObservationIT extends BaseOllamaIT {
@@ -172,6 +174,7 @@ public OllamaChatModel openAiChatModel(OllamaApi ollamaApi, TestObservationRegis
172174
return OllamaChatModel.builder()
173175
.withOllamaApi(ollamaApi)
174176
.withObservationRegistry(observationRegistry)
177+
.withRetryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
175178
.build();
176179
}
177180

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelTests.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,15 @@
2727
import org.springframework.ai.ollama.api.OllamaApi;
2828
import org.springframework.ai.ollama.api.OllamaModel;
2929
import org.springframework.ai.ollama.api.OllamaOptions;
30+
import org.springframework.ai.retry.RetryUtils;
3031

3132
import static org.junit.jupiter.api.Assertions.assertEquals;
3233
import static org.junit.jupiter.api.Assertions.assertThrows;
3334
import static org.mockito.BDDMockito.given;
3435

3536
/**
3637
* @author Jihoon Kim
38+
* @author Alexandros Pappas
3739
* @since 1.0.0
3840
*/
3941
@ExtendWith(MockitoExtension.class)
@@ -51,6 +53,7 @@ public void buildOllamaChatModel() {
5153
() -> OllamaChatModel.builder()
5254
.withOllamaApi(this.ollamaApi)
5355
.withDefaultOptions(OllamaOptions.create().withModel(OllamaModel.LLAMA2))
56+
.withRetryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
5457
.withModelManagementOptions(null)
5558
.build());
5659
assertEquals("modelManagementOptions must not be null", exception.getMessage());

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,22 @@
2323
import org.springframework.ai.chat.prompt.Prompt;
2424
import org.springframework.ai.ollama.api.OllamaApi;
2525
import org.springframework.ai.ollama.api.OllamaOptions;
26+
import org.springframework.ai.retry.RetryUtils;
2627

2728
import static org.assertj.core.api.Assertions.assertThat;
2829

2930
/**
3031
* @author Christian Tzolov
3132
* @author Thomas Vitale
33+
* @author Alexandros Pappas
3234
*/
3335
public class OllamaChatRequestTests {
3436

3537
OllamaChatModel chatModel = OllamaChatModel.builder()
3638
.withOllamaApi(new OllamaApi())
3739
.withDefaultOptions(
3840
OllamaOptions.create().withModel("MODEL_NAME").withTopK(99).withTemperature(66.6).withNumGPU(1))
41+
.withRetryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
3942
.build();
4043

4144
@Test
@@ -113,6 +116,7 @@ public void createRequestWithDefaultOptionsModelOverride() {
113116
OllamaChatModel chatModel = OllamaChatModel.builder()
114117
.withOllamaApi(new OllamaApi())
115118
.withDefaultOptions(OllamaOptions.create().withModel("DEFAULT_OPTIONS_MODEL"))
119+
.withRetryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
116120
.build();
117121

118122
var request = chatModel.ollamaChatRequest(new Prompt("Test message content"), true);

spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfiguration.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
import org.springframework.boot.context.properties.EnableConfigurationProperties;
4242
import org.springframework.context.ApplicationContext;
4343
import org.springframework.context.annotation.Bean;
44+
import org.springframework.retry.support.RetryTemplate;
4445
import org.springframework.web.client.RestClient;
4546
import org.springframework.web.reactive.function.client.WebClient;
4647

@@ -82,7 +83,7 @@ public OllamaApi ollamaApi(OllamaConnectionDetails connectionDetails,
8283
public OllamaChatModel ollamaChatModel(OllamaApi ollamaApi, OllamaChatProperties properties,
8384
OllamaInitializationProperties initProperties, List<FunctionCallback> toolFunctionCallbacks,
8485
FunctionCallbackResolver functionCallbackResolver, ObjectProvider<ObservationRegistry> observationRegistry,
85-
ObjectProvider<ChatModelObservationConvention> observationConvention) {
86+
ObjectProvider<ChatModelObservationConvention> observationConvention, RetryTemplate retryTemplate) {
8687
var chatModelPullStrategy = initProperties.getChat().isInclude() ? initProperties.getPullModelStrategy()
8788
: PullModelStrategy.NEVER;
8889

@@ -95,6 +96,7 @@ public OllamaChatModel ollamaChatModel(OllamaApi ollamaApi, OllamaChatProperties
9596
.withModelManagementOptions(
9697
new ModelManagementOptions(chatModelPullStrategy, initProperties.getChat().getAdditionalModels(),
9798
initProperties.getTimeout(), initProperties.getMaxRetries()))
99+
.withRetryTemplate(retryTemplate)
98100
.build();
99101

100102
observationConvention.ifAvailable(chatModel::setObservationConvention);

spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaChatAutoConfigurationTests.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import org.junit.jupiter.api.Test;
2020

21+
import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration;
2122
import org.springframework.boot.autoconfigure.AutoConfigurations;
2223
import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration;
2324
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
@@ -41,7 +42,8 @@ public void propertiesTest() {
4142
"spring.ai.ollama.chat.options.topP=0.56",
4243
"spring.ai.ollama.chat.options.topK=123")
4344
// @formatter:on
44-
.withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OllamaAutoConfiguration.class))
45+
.withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
46+
RestClientAutoConfiguration.class, OllamaAutoConfiguration.class))
4547
.run(context -> {
4648
var chatProperties = context.getBean(OllamaChatProperties.class);
4749
var connectionProperties = context.getBean(OllamaConnectionProperties.class);

0 commit comments

Comments
 (0)