Skip to content

Commit eb63e8b

Browse files
Adding more tests
1 parent f962d74 commit eb63e8b

16 files changed

+572
-38
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,9 @@
6363
import org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings;
6464
import org.elasticsearch.xpack.inference.services.custom.CustomTaskSettings;
6565
import org.elasticsearch.xpack.inference.services.custom.response.CompletionResponseParser;
66+
import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseParser;
6667
import org.elasticsearch.xpack.inference.services.custom.response.NoopResponseParser;
6768
import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser;
68-
import org.elasticsearch.xpack.inference.services.custom.response.ResponseParser;
6969
import org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser;
7070
import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser;
7171
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;
@@ -198,39 +198,39 @@ private static void addCustomNamedWriteables(List<NamedWriteableRegistry.Entry>
198198

199199
namedWriteables.add(
200200
new NamedWriteableRegistry.Entry(
201-
ResponseParser.class,
201+
CustomResponseParser.class,
202202
TextEmbeddingResponseParser.NAME,
203203
TextEmbeddingResponseParser::new
204204
)
205205
);
206206

207207
namedWriteables.add(
208208
new NamedWriteableRegistry.Entry(
209-
ResponseParser.class,
209+
CustomResponseParser.class,
210210
SparseEmbeddingResponseParser.NAME,
211211
SparseEmbeddingResponseParser::new
212212
)
213213
);
214214

215215
namedWriteables.add(
216216
new NamedWriteableRegistry.Entry(
217-
ResponseParser.class,
217+
CustomResponseParser.class,
218218
RerankResponseParser.NAME,
219219
RerankResponseParser::new
220220
)
221221
);
222222

223223
namedWriteables.add(
224224
new NamedWriteableRegistry.Entry(
225-
ResponseParser.class,
225+
CustomResponseParser.class,
226226
NoopResponseParser.NAME,
227227
NoopResponseParser::new
228228
)
229229
);
230230

231231
namedWriteables.add(
232232
new NamedWriteableRegistry.Entry(
233-
ResponseParser.class,
233+
CustomResponseParser.class,
234234
CompletionResponseParser.NAME,
235235
CompletionResponseParser::new
236236
)
Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.common;
9+
10+
import org.elasticsearch.common.Strings;
11+
12+
import java.util.ArrayList;
13+
import java.util.List;
14+
import java.util.Map;
15+
import java.util.regex.Pattern;
16+
17+
/**
18+
* Extracts fields from a {@link Map}.
19+
*
20+
* Uses a subset of the JSONPath schema to extract fields from a map.
21+
* For more information <a href="https://en.wikipedia.org/wiki/JSONPath">see here</a>.
22+
*
23+
* This implementation differs in out it handles lists in that JSONPath will flatten inner lists. This implementation
24+
* preserves inner lists.
25+
*
26+
* Examples of the schema:
27+
*
28+
* <pre>
29+
* {@code
30+
* $.field1.array[*].field2
31+
* $.field1.field2
32+
* }
33+
* </pre>
34+
*
35+
* Given the map
36+
* <pre>
37+
* {@code
38+
* {
39+
* "request_id": "B4AB89C8-B135-xxxx-A6F8-2BAB801A2CE4",
40+
* "latency": 38,
41+
* "usage": {
42+
* "token_count": 3072
43+
* },
44+
* "result": {
45+
* "embeddings": [
46+
* {
47+
* "index": 0,
48+
* "embedding": [
49+
* 2,
50+
* 4
51+
* ]
52+
* },
53+
* {
54+
* "index": 1,
55+
* "embedding": [
56+
* 1,
57+
* 2
58+
* ]
59+
* }
60+
* ]
61+
* }
62+
* }
63+
* }
64+
* </pre>
65+
*
66+
* <pre>
67+
* {@code
68+
* var embeddings = MapPathExtractor.extract(map, "$.result.embeddings[*].embedding");
69+
* }
70+
* </pre>
71+
*
72+
* Will result in:
73+
*
74+
* <pre>
75+
* {@code
76+
* [
77+
* [2, 4],
78+
* [1, 2]
79+
* ]
80+
* }
81+
* </pre>
82+
*
83+
* This implementation differs from JSONPath when handling a list of maps. JSONPath will flatten the result and return a single array.
84+
* this implementation will preserve each nested list while gathering the results.
85+
*
86+
* For example
87+
*
88+
* <pre>
89+
* {@code
90+
* {
91+
* "result": [
92+
* {
93+
* "key": [
94+
* {
95+
* "a": 1.1
96+
* },
97+
* {
98+
* "a": 2.2
99+
* }
100+
* ]
101+
* },
102+
* {
103+
* "key": [
104+
* {
105+
* "a": 3.3
106+
* },
107+
* {
108+
* "a": 4.4
109+
* }
110+
* ]
111+
* }
112+
* ]
113+
* }
114+
* }
115+
* {@code var embeddings = MapPathExtractor.extract(map, "$.result[*].key[*].a");}
116+
*
117+
* JSONPath: {@code [1.1, 2.2, 3.3, 4.4]}
118+
* This implementation: {@code [[1.1, 2.2], [3.3, 4.4]]}
119+
* </pre>
120+
*/
121+
public class MapPathExtractor {
122+
123+
private static final String DOLLAR = "$";
124+
125+
// default for testing
126+
static final Pattern dotFieldPattern = Pattern.compile("^\\.([^.\\[]+)(.*)");
127+
static final Pattern arrayWildcardPattern = Pattern.compile("^\\[\\*\\](.*)");
128+
129+
public static Object extract(Map<String, Object> data, String path) {
130+
if (data == null || data.isEmpty() || path == null || path.trim().isEmpty()) {
131+
return null;
132+
}
133+
134+
var cleanedPath = path.trim();
135+
136+
// Remove the prefix if it exists
137+
if (cleanedPath.startsWith(DOLLAR)) {
138+
cleanedPath = cleanedPath.substring(DOLLAR.length());
139+
}
140+
141+
return navigate(data, cleanedPath);
142+
}
143+
144+
private static Object navigate(Object current, String remainingPath) {
145+
if (current == null || remainingPath == null || remainingPath.isEmpty()) {
146+
return current;
147+
}
148+
149+
var dotFieldMatcher = dotFieldPattern.matcher(remainingPath);
150+
var arrayWildcardMatcher = arrayWildcardPattern.matcher(remainingPath);
151+
152+
if (dotFieldMatcher.matches()) {
153+
String field = dotFieldMatcher.group(1);
154+
if (field == null || field.isEmpty()) {
155+
throw new IllegalArgumentException(
156+
Strings.format(
157+
"Unable to extract field from remaining path [%s]. Fields must be delimited by a dot character.",
158+
remainingPath
159+
)
160+
);
161+
}
162+
163+
String nextPath = dotFieldMatcher.group(2);
164+
if (current instanceof Map<?, ?> currentMap) {
165+
var fieldFromMap = currentMap.get(field);
166+
if (fieldFromMap == null) {
167+
throw new IllegalArgumentException(Strings.format("Unable to find field [%s] in map", field));
168+
}
169+
170+
return navigate(currentMap.get(field), nextPath);
171+
} else {
172+
throw new IllegalArgumentException(
173+
Strings.format(
174+
"Current path [%s] matched the dot field pattern but the current object is not a map, "
175+
+ "found invalid type [%s] instead.",
176+
remainingPath,
177+
current.getClass().getSimpleName()
178+
)
179+
);
180+
}
181+
} else if (arrayWildcardMatcher.matches()) {
182+
String nextPath = arrayWildcardMatcher.group(1);
183+
if (current instanceof List<?> list) {
184+
List<Object> results = new ArrayList<>();
185+
186+
for (Object item : list) {
187+
Object result = navigate(item, nextPath);
188+
if (result != null) {
189+
results.add(result);
190+
}
191+
}
192+
193+
return results;
194+
} else {
195+
throw new IllegalArgumentException(
196+
Strings.format(
197+
"Current path [%s] matched the array field pattern but the current object is not a list, "
198+
+ "found invalid type [%s] instead.",
199+
remainingPath,
200+
current.getClass().getSimpleName()
201+
)
202+
);
203+
}
204+
}
205+
206+
throw new IllegalArgumentException(Strings.format("Invalid path received [%s], unable to extract a field name.", remainingPath));
207+
}
208+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ public abstract class BaseResponseHandler implements ResponseHandler {
3636
public static final String METHOD_NOT_ALLOWED = "Received a method not allowed status code";
3737

3838
protected final String requestType;
39-
private final ResponseParser parseFunction;
39+
protected final ResponseParser parseFunction;
4040
private final Function<HttpResult, ErrorResponse> errorParseFunction;
4141
private final boolean canHandleStreamingResponses;
4242

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomResponseHandler.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77

88
package org.elasticsearch.xpack.inference.services.custom;
99

10+
import org.elasticsearch.ElasticsearchStatusException;
11+
import org.elasticsearch.inference.InferenceServiceResults;
12+
import org.elasticsearch.rest.RestStatus;
1013
import org.elasticsearch.xpack.inference.external.http.HttpResult;
1114
import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler;
1215
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
@@ -22,6 +25,24 @@ public CustomResponseHandler(String requestType, ResponseParser parseFunction, E
2225
super(requestType, parseFunction, errorParser);
2326
}
2427

28+
@Override
29+
public InferenceServiceResults parseResult(Request request, HttpResult result) throws RetryException {
30+
try {
31+
return parseFunction.apply(request, result);
32+
} catch (Exception e) {
33+
// if we get a parse failure it's probably an incorrect configuration of the service so report the error back to the user
34+
// immediately without retrying
35+
throw new RetryException(
36+
false,
37+
new ElasticsearchStatusException(
38+
"Failed to parse custom model response, please check that the response parser path matches the response format.",
39+
RestStatus.BAD_REQUEST,
40+
e
41+
)
42+
);
43+
}
44+
}
45+
2546
/**
2647
* Validates the status code throws an RetryException if not in the range [200, 300).
2748
*

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,12 @@
2222
import org.elasticsearch.xcontent.ToXContentObject;
2323
import org.elasticsearch.xcontent.XContentBuilder;
2424
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
25+
import org.elasticsearch.xpack.inference.services.custom.response.BaseCustomResponseParser;
2526
import org.elasticsearch.xpack.inference.services.custom.response.CompletionResponseParser;
27+
import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseParser;
2628
import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser;
2729
import org.elasticsearch.xpack.inference.services.custom.response.NoopResponseParser;
2830
import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser;
29-
import org.elasticsearch.xpack.inference.services.custom.response.ResponseParser;
3031
import org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser;
3132
import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser;
3233
import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
@@ -149,7 +150,7 @@ public static CustomServiceSettings fromMap(Map<String, Object> map, Configurati
149150
private final String url;
150151
private final Map<String, String> headers;
151152
private final String requestContentString;
152-
private final ResponseParser responseJsonParser;
153+
private final CustomResponseParser responseJsonParser;
153154
private final RateLimitSettings rateLimitSettings;
154155
private final ErrorResponseParser errorParser;
155156

@@ -160,7 +161,7 @@ public CustomServiceSettings(
160161
String url,
161162
@Nullable Map<String, String> headers,
162163
String requestContentString,
163-
ResponseParser responseJsonParser,
164+
CustomResponseParser responseJsonParser,
164165
@Nullable RateLimitSettings rateLimitSettings,
165166
ErrorResponseParser errorParser
166167
) {
@@ -182,7 +183,7 @@ public CustomServiceSettings(StreamInput in) throws IOException {
182183
url = in.readString();
183184
headers = in.readImmutableMap(StreamInput::readString);
184185
requestContentString = in.readString();
185-
responseJsonParser = in.readNamedWriteable(ResponseParser.class);
186+
responseJsonParser = in.readNamedWriteable(BaseCustomResponseParser.class);
186187
rateLimitSettings = new RateLimitSettings(in);
187188
errorParser = new ErrorResponseParser(in);
188189
}
@@ -218,7 +219,7 @@ public String getRequestContentString() {
218219
return requestContentString;
219220
}
220221

221-
public ResponseParser getResponseJsonParser() {
222+
public CustomResponseParser getResponseJsonParser() {
222223
return responseJsonParser;
223224
}
224225

@@ -345,7 +346,7 @@ public String modelId() {
345346
return null;
346347
}
347348

348-
private static ResponseParser extractResponseParser(
349+
private static CustomResponseParser extractResponseParser(
349350
TaskType taskType,
350351
Map<String, Object> responseParserMap,
351352
ValidationException validationException

0 commit comments

Comments
 (0)