Skip to content

Commit 8d1bd22

Browse files
Adding inference id to error parser log message
1 parent be84291 commit 8d1bd22

File tree

8 files changed: 60 additions and 28 deletions

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/StreamingHttpResult.java

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@ public void onSubscribe(Flow.Subscription subscription) {
3131

3232
@Override
3333
public void onNext(byte[] item) {
34-
subscriber.onNext(new HttpResult(response(), item, httpRequest));
34+
subscriber.onNext(new HttpResult(response(), item));
3535
}
3636

3737
@Override
@@ -70,7 +70,7 @@ public void onError(Throwable throwable) {
7070

7171
@Override
7272
public void onComplete() {
73-
fullResponse.onResponse(new HttpResult(response, stream.toByteArray(), httpRequest));
73+
fullResponse.onResponse(new HttpResult(response, stream.toByteArray()));
7474
}
7575
});
7676
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,7 @@ public abstract class BaseResponseHandler implements ResponseHandler {
3636
public static final String METHOD_NOT_ALLOWED = "Received a method not allowed status code";
3737

3838
protected final String requestType;
39-
private final ResponseParser parseFunction;
39+
protected final ResponseParser parseFunction;
4040
private final BiFunction<Request, HttpResult, ErrorResponse> errorParseFunction;
4141
private final boolean canHandleStreamingResponses;
4242

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java

Lines changed: 12 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -151,12 +151,20 @@ public record TextEmbeddingSettings(
151151
@Nullable Integer maxInputTokens,
152152
@Nullable DenseVectorFieldMapper.ElementType elementType
153153
) implements ToXContentFragment, Writeable {
154+
// This specifies float for the element type but null for all other settings
155+
public static final TextEmbeddingSettings DEFAULT_FLOAT = new TextEmbeddingSettings(
156+
null,
157+
null,
158+
null,
159+
DenseVectorFieldMapper.ElementType.FLOAT
160+
);
154161

155-
public static final TextEmbeddingSettings EMPTY = new TextEmbeddingSettings(null, null, null, null);
162+
// This refers to settings that are not related to the text embedding task type (all the settings should be null)
163+
public static final TextEmbeddingSettings NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS = new TextEmbeddingSettings(null, null, null, null);
156164

157165
public static TextEmbeddingSettings fromMap(Map<String, Object> map, TaskType taskType, ValidationException validationException) {
158166
if (taskType != TaskType.TEXT_EMBEDDING) {
159-
return EMPTY;
167+
return NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS;
160168
}
161169

162170
SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException);
@@ -207,7 +215,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
207215
private final ErrorResponseParser errorParser;
208216

209217
public CustomServiceSettings(
210-
@Nullable TextEmbeddingSettings textEmbeddingSettings,
218+
TextEmbeddingSettings textEmbeddingSettings,
211219
String url,
212220
@Nullable Map<String, String> headers,
213221
@Nullable QueryParameters queryParameters,
@@ -216,7 +224,7 @@ public CustomServiceSettings(
216224
@Nullable RateLimitSettings rateLimitSettings,
217225
ErrorResponseParser errorParser
218226
) {
219-
this.textEmbeddingSettings = textEmbeddingSettings == null ? TextEmbeddingSettings.EMPTY : textEmbeddingSettings;
227+
this.textEmbeddingSettings = Objects.requireNonNull(textEmbeddingSettings);
220228
this.url = Objects.requireNonNull(url);
221229
this.headers = Collections.unmodifiableMap(Objects.requireNonNullElse(headers, Map.of()));
222230
this.queryParameters = Objects.requireNonNullElse(queryParameters, QueryParameters.EMPTY);

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManagerTests.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -56,7 +56,7 @@ public void testCreateRequest_ThrowsException_ForInvalidUrl() {
5656
""";
5757

5858
var serviceSettings = new CustomServiceSettings(
59-
CustomServiceSettings.TextEmbeddingSettings.EMPTY,
59+
CustomServiceSettings.TextEmbeddingSettings.NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS,
6060
"${url}",
6161
null,
6262
null,

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -200,7 +200,7 @@ public void testFromMap_WithOptionalsNotSpecified() {
200200
settings,
201201
is(
202202
new CustomServiceSettings(
203-
CustomServiceSettings.TextEmbeddingSettings.EMPTY,
203+
CustomServiceSettings.TextEmbeddingSettings.DEFAULT_FLOAT,
204204
url,
205205
Map.of(),
206206
null,
@@ -664,7 +664,7 @@ public void testFromMap_ReturnsError_IfTaskTypeIsInvalid() {
664664

665665
public void testXContent() throws IOException {
666666
var entity = new CustomServiceSettings(
667-
CustomServiceSettings.TextEmbeddingSettings.EMPTY,
667+
CustomServiceSettings.TextEmbeddingSettings.NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS,
668668
"http://www.abc.com",
669669
Map.of("key", "value"),
670670
null,

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java

Lines changed: 20 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -141,14 +141,8 @@ private static SenderService createService(ThreadPool threadPool, HttpClientMana
141141
}
142142

143143
private static Map<String, Object> createServiceSettingsMap(TaskType taskType) {
144-
return new HashMap<>(
144+
var settingsMap = new HashMap<>(
145145
Map.of(
146-
ServiceFields.SIMILARITY,
147-
SimilarityMeasure.DOT_PRODUCT.toString(),
148-
ServiceFields.DIMENSIONS,
149-
1536,
150-
ServiceFields.MAX_INPUT_TOKENS,
151-
512,
152146
CustomServiceSettings.URL,
153147
"http://www.abc.com",
154148
CustomServiceSettings.HEADERS,
@@ -168,6 +162,17 @@ private static Map<String, Object> createServiceSettingsMap(TaskType taskType) {
168162
)
169163
)
170164
);
165+
166+
if (taskType == TaskType.TEXT_EMBEDDING) {
167+
settingsMap.putAll(Map.of(ServiceFields.SIMILARITY,
168+
SimilarityMeasure.DOT_PRODUCT.toString(),
169+
ServiceFields.DIMENSIONS,
170+
1536,
171+
ServiceFields.MAX_INPUT_TOKENS,
172+
512));
173+
}
174+
175+
return settingsMap;
171176
}
172177

173178
private static Map<String, Object> createResponseParserMap(TaskType taskType) {
@@ -229,7 +234,7 @@ private static CustomModel createInternalEmbeddingModel(
229234
CustomService.NAME,
230235
new CustomServiceSettings(
231236
new CustomServiceSettings.TextEmbeddingSettings(
232-
SimilarityMeasure.DOT_PRODUCT,
237+
similarityMeasure,
233238
123,
234239
456,
235240
DenseVectorFieldMapper.ElementType.FLOAT
@@ -253,7 +258,7 @@ private static CustomModel createCustomModel(TaskType taskType, CustomResponsePa
253258
taskType,
254259
CustomService.NAME,
255260
new CustomServiceSettings(
256-
CustomServiceSettings.TextEmbeddingSettings.EMPTY,
261+
getDefaultTextEmbeddingSettings(taskType),
257262
url,
258263
Map.of("key", "value"),
259264
QueryParameters.EMPTY,
@@ -267,6 +272,12 @@ private static CustomModel createCustomModel(TaskType taskType, CustomResponsePa
267272
);
268273
}
269274

275+
private static CustomServiceSettings.TextEmbeddingSettings getDefaultTextEmbeddingSettings(TaskType taskType) {
276+
return taskType == TaskType.TEXT_EMBEDDING
277+
? CustomServiceSettings.TextEmbeddingSettings.DEFAULT_FLOAT
278+
: CustomServiceSettings.TextEmbeddingSettings.NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS;
279+
}
280+
270281
public void testInfer_HandlesTextEmbeddingRequest_OpenAI_Format() throws IOException {
271282
try (var service = createService(threadPool, clientManager)) {
272283
String responseJson = """

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -100,7 +100,7 @@ public void testCreateRequest_QueryParametersAreEscaped_AndEncoded() {
100100
""";
101101

102102
var serviceSettings = new CustomServiceSettings(
103-
CustomServiceSettings.TextEmbeddingSettings.EMPTY,
103+
CustomServiceSettings.TextEmbeddingSettings.NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS,
104104
"http://www.elastic.co",
105105
null,
106106
// escaped characters retrieved from here: https://docs.microfocus.com/OMi/10.62/Content/OMi/ExtGuide/ExtApps/URL_encoding.htm
@@ -203,7 +203,7 @@ public void testCreateRequest_HandlesQuery() throws IOException {
203203
""";
204204

205205
var serviceSettings = new CustomServiceSettings(
206-
CustomServiceSettings.TextEmbeddingSettings.EMPTY,
206+
CustomServiceSettings.TextEmbeddingSettings.NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS,
207207
"http://www.elastic.co",
208208
null,
209209
null,
@@ -245,7 +245,7 @@ public void testCreateRequest_IgnoresNonStringFields_ForStringParams() throws IO
245245
""";
246246

247247
var serviceSettings = new CustomServiceSettings(
248-
CustomServiceSettings.TextEmbeddingSettings.EMPTY,
248+
CustomServiceSettings.TextEmbeddingSettings.NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS,
249249
"http://www.elastic.co",
250250
Map.of(HttpHeaders.ACCEPT, Strings.format("${task.key}")),
251251
null,
@@ -276,7 +276,7 @@ public void testCreateRequest_ThrowsException_ForInvalidUrl() {
276276
""";
277277

278278
var serviceSettings = new CustomServiceSettings(
279-
CustomServiceSettings.TextEmbeddingSettings.EMPTY,
279+
CustomServiceSettings.TextEmbeddingSettings.NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS,
280280
"${url}",
281281
Map.of(HttpHeaders.ACCEPT, Strings.format("${task.key}")),
282282
null,

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/ErrorResponseParserTests.java

Lines changed: 18 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717
import org.elasticsearch.xcontent.XContentType;
1818
import org.elasticsearch.xpack.inference.external.http.HttpResult;
1919
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
20+
import org.elasticsearch.xpack.inference.external.request.Request;
2021

2122
import java.io.IOException;
2223
import java.util.HashMap;
@@ -26,6 +27,7 @@
2627
import static org.hamcrest.Matchers.is;
2728
import static org.hamcrest.Matchers.sameInstance;
2829
import static org.mockito.Mockito.mock;
30+
import static org.mockito.Mockito.when;
2931

3032
public class ErrorResponseParserTests extends ESTestCase {
3133

@@ -81,7 +83,7 @@ public void testErrorResponse_ExtractsError() throws IOException {
8183
}""");
8284

8385
var parser = new ErrorResponseParser("$.error.message");
84-
var error = parser.apply(result);
86+
var error = parser.apply(getMockRequest(), result);
8587
assertThat(error, is(new ErrorResponse("test_error_message")));
8688
}
8789

@@ -98,7 +100,7 @@ public void testFromResponse_WithOtherFieldsPresent() throws IOException {
98100
""";
99101

100102
var parser = new ErrorResponseParser("$.error.message");
101-
var error = parser.apply(getMockResult(responseJson));
103+
var error = parser.apply(getMockRequest(), getMockResult(responseJson));
102104

103105
assertThat(error, is(new ErrorResponse("You didn't provide an API key")));
104106
}
@@ -113,7 +115,7 @@ public void testFromResponse_noMessage() throws IOException {
113115
""";
114116

115117
var parser = new ErrorResponseParser("$.error.message");
116-
var error = parser.apply(getMockResult(responseJson));
118+
var error = parser.apply(getMockRequest(), getMockResult(responseJson));
117119

118120
assertThat(error, sameInstance(ErrorResponse.UNDEFINED_ERROR));
119121
assertThat(error.getErrorMessage(), is(""));
@@ -125,7 +127,7 @@ public void testErrorResponse_ReturnsUndefinedObjectIfNoError() throws IOExcepti
125127
{"noerror":true}""");
126128

127129
var parser = new ErrorResponseParser("$.error.message");
128-
var error = parser.apply(mockResult);
130+
var error = parser.apply(getMockRequest(), mockResult);
129131

130132
assertThat(error, sameInstance(ErrorResponse.UNDEFINED_ERROR));
131133
}
@@ -134,12 +136,23 @@ public void testErrorResponse_ReturnsUndefinedObjectIfNotJson() {
134136
var result = new HttpResult(mock(HttpResponse.class), Strings.toUTF8Bytes("not a json string"));
135137

136138
var parser = new ErrorResponseParser("$.error.message");
137-
var error = parser.apply(result);
139+
var error = parser.apply(getMockRequest(), result);
138140
assertThat(error, sameInstance(ErrorResponse.UNDEFINED_ERROR));
139141
}
140142

141143
private static HttpResult getMockResult(String jsonString) throws IOException {
142144
var response = mock(HttpResponse.class);
143145
return new HttpResult(response, Strings.toUTF8Bytes(XContentHelper.stripWhitespace(jsonString)));
144146
}
147+
148+
private static Request getMockRequest() {
149+
return getMockRequest("id");
150+
}
151+
152+
private static Request getMockRequest(String inferenceId) {
153+
var mockRequest = mock(Request.class);
154+
when(mockRequest.getInferenceEntityId()).thenReturn(inferenceId);
155+
156+
return mockRequest;
157+
}
145158
}

0 commit comments

Comments
 (0)