Skip to content

Commit c9ff298

Browse files
Adding tests for remaining parsers
1 parent adc3210 commit c9ff298

16 files changed

+1381
-75
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/MapPathExtractor.java

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,7 @@
2020
* Uses a subset of the JSONPath schema to extract fields from a map.
2121
* For more information <a href="https://en.wikipedia.org/wiki/JSONPath">see here</a>.
2222
*
23-
<<<<<<< HEAD
24-
* This implementation differs in out it handles lists in that JSONPath will flatten inner lists. This implementation
25-
=======
2623
* This implementation differs in how it handles lists in that JSONPath will flatten inner lists. This implementation
27-
>>>>>>> 7c3e8507f4f94b1a8c6f926ebd5e5a9d00ab6378
2824
* preserves inner lists.
2925
*
3026
* Examples of the schema:

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ErrorResponse.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,16 @@ public String getErrorMessage() {
3434
public boolean errorStructureFound() {
3535
return errorStructureFound;
3636
}
37+
38+
@Override
39+
public boolean equals(Object o) {
40+
if (o == null || getClass() != o.getClass()) return false;
41+
ErrorResponse that = (ErrorResponse) o;
42+
return errorStructureFound == that.errorStructureFound && Objects.equals(errorMessage, that.errorMessage);
43+
}
44+
45+
@Override
46+
public int hashCode() {
47+
return Objects.hash(errorMessage, errorStructureFound);
48+
}
3749
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/BaseCustomResponseParser.java

Lines changed: 57 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ static List<?> validateList(Object obj) {
4343

4444
if (obj instanceof List<?> == false) {
4545
throw new IllegalArgumentException(
46-
Strings.format("Extracted field was an invalid type, expected a list but received [%s]", obj.getClass().getSimpleName())
46+
Strings.format("Extracted field is an invalid type, expected a list but received [%s]", obj.getClass().getSimpleName())
4747
);
4848
}
4949

@@ -54,26 +54,76 @@ static void validateNonNull(Object obj) {
5454
Objects.requireNonNull(obj, "Failed to parse response, extracted field was null");
5555
}
5656

57-
static List<Float> convertToListOfFloats(List<?> items) {
58-
return validateAndCastList(items, BaseCustomResponseParser::toFloat);
57+
static Map<String, Object> validateMap(Object obj) {
58+
validateNonNull(obj);
59+
60+
if (obj instanceof Map<?, ?> == false) {
61+
throw new IllegalArgumentException(
62+
Strings.format("Extracted field is an invalid type, expected a map but received [%s]", obj.getClass().getSimpleName())
63+
);
64+
}
65+
66+
var keys = ((Map<?, ?>) obj).keySet();
67+
for (var key : keys) {
68+
if (key instanceof String == false) {
69+
throw new IllegalStateException(
70+
Strings.format(
71+
"Extracted map has an invalid key type. Expected a string but received [%s]",
72+
key.getClass().getSimpleName()
73+
)
74+
);
75+
}
76+
}
77+
78+
@SuppressWarnings("unchecked")
79+
var result = (Map<String, Object>) obj;
80+
return result;
81+
}
82+
83+
static List<Float> convertToListOfFloats(Object obj) {
84+
return validateAndCastList(validateList(obj), BaseCustomResponseParser::toFloat);
5985
}
6086

6187
static Float toFloat(Object obj) {
88+
return toNumber(obj).floatValue();
89+
}
90+
91+
private static Number toNumber(Object obj) {
6292
if (obj instanceof Number == false) {
63-
throw new IllegalArgumentException(Strings.format("Unable to convert type [%s] to Float", obj.getClass().getSimpleName()));
93+
throw new IllegalArgumentException(Strings.format("Unable to convert type [%s] to Number", obj.getClass().getSimpleName()));
6494
}
6595

66-
return ((Number) obj).floatValue();
96+
return ((Number) obj);
6797
}
6898

69-
static <ReturnType> List<ReturnType> validateAndCastList(List<?> items, Function<Object, ReturnType> converter) {
99+
static List<Integer> convertToListOfIntegers(Object obj) {
100+
return validateAndCastList(validateList(obj), BaseCustomResponseParser::toInteger);
101+
}
102+
103+
private static Integer toInteger(Object obj) {
104+
return toNumber(obj).intValue();
105+
}
106+
107+
static <T> List<T> validateAndCastList(List<?> items, Function<Object, T> converter) {
70108
validateNonNull(items);
71109

72-
List<ReturnType> resultList = new ArrayList<>();
110+
List<T> resultList = new ArrayList<>();
73111
for (var obj : items) {
74112
resultList.add(converter.apply(obj));
75113
}
76114

77115
return resultList;
78116
}
117+
118+
static <T> T toType(Object obj, Class<T> type) {
119+
validateNonNull(obj);
120+
121+
if (type.isInstance(obj) == false) {
122+
throw new IllegalArgumentException(
123+
Strings.format("Unable to convert object of type [%s] to type [%s]", obj.getClass().getSimpleName(), type.getSimpleName())
124+
);
125+
}
126+
127+
return type.cast(obj);
128+
}
79129
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/CompletionResponseParser.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import org.elasticsearch.common.io.stream.StreamOutput;
1313
import org.elasticsearch.xcontent.XContentBuilder;
1414
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
15+
import org.elasticsearch.xpack.inference.common.MapPathExtractor;
1516

1617
import java.io.IOException;
1718
import java.util.Map;
@@ -78,6 +79,11 @@ public String getWriteableName() {
7879

7980
@Override
8081
public ChatCompletionResults transform(Map<String, Object> map) {
81-
return null;
82+
var completionList = validateAndCastList(
83+
validateList(MapPathExtractor.extract(map, completionResultPath)),
84+
(obj) -> toType(obj, String.class)
85+
);
86+
87+
return new ChatCompletionResults(completionList.stream().map(ChatCompletionResults.Result::new).toList());
8288
}
8389
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/ErrorResponseParser.java

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@
1212
import org.elasticsearch.common.io.stream.StreamOutput;
1313
import org.elasticsearch.xcontent.ToXContentFragment;
1414
import org.elasticsearch.xcontent.XContentBuilder;
15+
import org.elasticsearch.xcontent.XContentFactory;
16+
import org.elasticsearch.xcontent.XContentParser;
17+
import org.elasticsearch.xcontent.XContentParserConfiguration;
18+
import org.elasticsearch.xcontent.XContentType;
19+
import org.elasticsearch.xpack.inference.common.MapPathExtractor;
1520
import org.elasticsearch.xpack.inference.external.http.HttpResult;
1621
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
1722

@@ -23,6 +28,7 @@
2328
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
2429
import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.ERROR_PARSER;
2530
import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.RESPONSE;
31+
import static org.elasticsearch.xpack.inference.services.custom.response.BaseCustomResponseParser.toType;
2632

2733
public class ErrorResponseParser implements ToXContentFragment, Function<HttpResult, ErrorResponse> {
2834

@@ -52,10 +58,6 @@ public void writeTo(StreamOutput out) throws IOException {
5258
out.writeString(messagePath);
5359
}
5460

55-
public String getMessagePath() {
56-
return messagePath;
57-
}
58-
5961
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
6062
builder.startObject(ERROR_PARSER);
6163
{
@@ -80,6 +82,25 @@ public int hashCode() {
8082

8183
@Override
8284
public ErrorResponse apply(HttpResult httpResult) {
83-
return null;
85+
try (
86+
XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON)
87+
.createParser(XContentParserConfiguration.EMPTY, httpResult.body())
88+
) {
89+
var map = jsonParser.map();
90+
91+
// NOTE: This deviates from what we've done in the past. In the ErrorMessageResponseEntity logic
92+
// if we find the top level error field we'll return a response with an empty message but indicate
93+
// that we found the structure of the error object. Here if we're missing the final field we will return
94+
// a ErrorResponse.UNDEFINED_ERROR with will indicate that we did not find the structure even if for example
95+
// the outer error field does exist, but it doesn't contain the nested field we were looking for.
96+
// If in the future we want the previous behavior, we can add a new message_path field or something and have
97+
// the current path field point to the field that indicates whether we found an error object.
98+
var errorText = toType(MapPathExtractor.extract(map, messagePath), String.class);
99+
return new ErrorResponse(errorText);
100+
} catch (Exception e) {
101+
// swallow the error
102+
}
103+
104+
return ErrorResponse.UNDEFINED_ERROR;
84105
}
85106
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/RerankResponseParser.java

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,14 @@
1111
import org.elasticsearch.common.io.stream.StreamInput;
1212
import org.elasticsearch.common.io.stream.StreamOutput;
1313
import org.elasticsearch.core.Nullable;
14+
import org.elasticsearch.core.Strings;
1415
import org.elasticsearch.xcontent.XContentBuilder;
1516
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
17+
import org.elasticsearch.xpack.inference.common.MapPathExtractor;
1618

1719
import java.io.IOException;
20+
import java.util.ArrayList;
21+
import java.util.List;
1822
import java.util.Map;
1923
import java.util.Objects;
2024

@@ -39,7 +43,7 @@ public static RerankResponseParser fromMap(Map<String, Object> responseParserMap
3943
var rerankIndex = extractOptionalString(responseParserMap, RERANK_PARSER_INDEX, JSON_PARSER, validationException);
4044
var documentText = extractOptionalString(responseParserMap, RERANK_PARSER_DOCUMENT_TEXT, JSON_PARSER, validationException);
4145

42-
if (relevanceScore == null || rerankIndex == null || documentText == null) {
46+
if (relevanceScore == null) {
4347
throw validationException;
4448
}
4549

@@ -106,6 +110,49 @@ public String getWriteableName() {
106110

107111
@Override
108112
public RankedDocsResults transform(Map<String, Object> map) {
109-
return null;
113+
var scores = convertToListOfFloats(MapPathExtractor.extract(map, relevanceScorePath));
114+
115+
List<Integer> indices = null;
116+
if (rerankIndexPath != null) {
117+
indices = convertToListOfIntegers(MapPathExtractor.extract(map, rerankIndexPath));
118+
}
119+
120+
List<String> documents = null;
121+
if (documentTextPath != null) {
122+
documents = validateAndCastList(
123+
validateList(MapPathExtractor.extract(map, documentTextPath)),
124+
(obj) -> toType(obj, String.class)
125+
);
126+
}
127+
128+
if (indices != null && indices.size() != scores.size()) {
129+
throw new IllegalStateException(
130+
Strings.format(
131+
"The number of index paths [%d] was not the same as the number of scores [%d]",
132+
indices.size(),
133+
scores.size()
134+
)
135+
);
136+
}
137+
138+
if (documents != null && documents.size() != scores.size()) {
139+
throw new IllegalStateException(
140+
Strings.format(
141+
"The number of document texts [%d] was no the same as the number of scores [%d]",
142+
documents.size(),
143+
scores.size()
144+
)
145+
);
146+
}
147+
148+
var rankedDocs = new ArrayList<RankedDocsResults.RankedDoc>();
149+
for (int i = 0; i < scores.size(); i++) {
150+
var index = indices != null ? indices.get(i) : i;
151+
var score = scores.get(i);
152+
var document = documents != null ? documents.get(i) : null;
153+
rankedDocs.add(new RankedDocsResults.RankedDoc(index, score, document));
154+
}
155+
156+
return new RankedDocsResults(rankedDocs);
110157
}
111158
}

0 commit comments

Comments
 (0)