Skip to content

Commit de83271

Browse files
More tests
1 parent c9ff298 commit de83271

19 files changed

+505
-447
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import org.elasticsearch.rest.RestStatus;
2222
import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings;
2323
import org.elasticsearch.xpack.inference.services.settings.ApiKeySecrets;
24+
import org.elasticsearch.xpack.inference.services.settings.SerializableSecureString;
2425

2526
import java.net.URI;
2627
import java.net.URISyntaxException;
@@ -557,17 +558,18 @@ public static void validateMapValues(
557558
);
558559
} else {
559560
return Strings.format(
560-
"Map field [%s] has an entry that is not valid, [%s => %s]. Value type is not one of [%s].",
561+
"Map field [%s] has an entry that is not valid, [%s => %s]. Value type of [%s] is not one of [%s].",
561562
settingName,
562563
entry.getKey(),
563564
entry.getValue(),
565+
entry.getValue(),
564566
String.join(", ", validTypesAsStrings)
565567
);
566568
}
567569
};
568570

569571
if (isAllowed == false) {
570-
var validTypesAsStrings = allowedTypes.stream().map(Class::toString).toArray(String[]::new);
572+
var validTypesAsStrings = allowedTypes.stream().map(Class::getSimpleName).toArray(String[]::new);
571573
Arrays.sort(validTypesAsStrings);
572574

573575
validationException.addValidationError(errorMessage.apply(validTypesAsStrings));
@@ -576,7 +578,7 @@ public static void validateMapValues(
576578
}
577579
}
578580

579-
public static Map<String, SecureString> convertMapStringsToSecureString(
581+
public static Map<String, SerializableSecureString> convertMapStringsToSecureString(
580582
Map<String, ?> map,
581583
String settingName,
582584
ValidationException validationException
@@ -585,22 +587,24 @@ public static Map<String, SecureString> convertMapStringsToSecureString(
585587
return Map.of();
586588
}
587589

588-
validateMapValues(map, List.of(String.class), settingName, validationException, true);
590+
validateMapStringValues(map, settingName, validationException, true);
589591

590592
return map.entrySet()
591593
.stream()
592-
.collect(Collectors.toMap(Map.Entry::getKey, e -> new SecureString(((String) e.getValue()).toCharArray())));
594+
.collect(Collectors.toMap(Map.Entry::getKey, e -> new SerializableSecureString((String) e.getValue())));
593595
}
594596

595597
/**
596598
* Removes null values.
597599
*/
598-
public static void removeNullValues(Map<String, Object> map) {
600+
public static Map<String, Object> removeNullValues(Map<String, Object> map) {
599601
if (map == null) {
600-
return;
602+
return map;
601603
}
602604

603605
map.values().removeIf(Objects::isNull);
606+
607+
return map;
604608
}
605609

606610
public static Integer extractRequiredPositiveIntegerLessThanOrEqualToMax(

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomSecretSettings.java

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@
1212
import org.elasticsearch.common.ValidationException;
1313
import org.elasticsearch.common.io.stream.StreamInput;
1414
import org.elasticsearch.common.io.stream.StreamOutput;
15-
import org.elasticsearch.common.settings.SecureString;
1615
import org.elasticsearch.core.Nullable;
1716
import org.elasticsearch.inference.SecretSettings;
1817
import org.elasticsearch.xcontent.XContentBuilder;
18+
import org.elasticsearch.xpack.inference.services.settings.SerializableSecureString;
1919

2020
import java.io.IOException;
2121
import java.util.HashMap;
@@ -48,22 +48,22 @@ public static CustomSecretSettings fromMap(@Nullable Map<String, Object> map) {
4848
return new CustomSecretSettings(secureStringMap);
4949
}
5050

51-
private final Map<String, SecureString> secretParameters;
51+
private final Map<String, SerializableSecureString> secretParameters;
5252

5353
@Override
5454
public SecretSettings newSecretSettings(Map<String, Object> newSecrets) {
5555
return fromMap(new HashMap<>(newSecrets));
5656
}
5757

58-
public CustomSecretSettings(@Nullable Map<String, SecureString> secretParameters) {
58+
public CustomSecretSettings(@Nullable Map<String, SerializableSecureString> secretParameters) {
5959
this.secretParameters = Objects.requireNonNullElse(secretParameters, Map.of());
6060
}
6161

6262
public CustomSecretSettings(StreamInput in) throws IOException {
63-
secretParameters = in.readImmutableMap(StreamInput::readString, StreamInput::readSecureString);
63+
secretParameters = in.readImmutableMap(SerializableSecureString::new);
6464
}
6565

66-
public Map<String, SecureString> getSecretParameters() {
66+
public Map<String, SerializableSecureString> getSecretParameters() {
6767
return secretParameters;
6868
}
6969

@@ -74,7 +74,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
7474
builder.startObject(SECRET_PARAMETERS);
7575
{
7676
for (var entry : secretParameters.entrySet()) {
77-
builder.field(entry.getKey(), entry.getValue().toString());
77+
builder.field(entry.getKey(), entry.getValue());
7878
}
7979
}
8080
builder.endObject();
@@ -95,7 +95,9 @@ public TransportVersion getMinimalSupportedVersion() {
9595

9696
@Override
9797
public void writeTo(StreamOutput out) throws IOException {
98-
out.writeMap(secretParameters, StreamOutput::writeString, StreamOutput::writeSecureString);
98+
out.writeMap(secretParameters, (streamOutput, v) -> {
99+
v.writeTo(streamOutput);
100+
});
99101
}
100102

101103
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import org.elasticsearch.xcontent.ToXContentObject;
2323
import org.elasticsearch.xcontent.XContentBuilder;
2424
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
25-
import org.elasticsearch.xpack.inference.services.custom.response.BaseCustomResponseParser;
2625
import org.elasticsearch.xpack.inference.services.custom.response.CompletionResponseParser;
2726
import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseParser;
2827
import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser;
@@ -183,7 +182,7 @@ public CustomServiceSettings(StreamInput in) throws IOException {
183182
url = in.readString();
184183
headers = in.readImmutableMap(StreamInput::readString);
185184
requestContentString = in.readString();
186-
responseJsonParser = in.readNamedWriteable(BaseCustomResponseParser.class);
185+
responseJsonParser = in.readNamedWriteable(CustomResponseParser.class);
187186
rateLimitSettings = new RateLimitSettings(in);
188187
errorParser = new ErrorResponseParser(in);
189188
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,8 @@ public CustomRequest(String query, List<String> input, CustomModel model) {
6464
uri = buildUri();
6565
}
6666

67-
private static <T> String toJson(T value, String field) {
67+
// default for testing
68+
static <T> String toJson(T value, String field) {
6869
try {
6970
XContentBuilder builder = JsonXContent.contentBuilder();
7071
// TODO test this, I think it'll write the quotes for us so we don't need to include them in the content string

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/CompletionResponseParser.java

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
package org.elasticsearch.xpack.inference.services.custom.response;
99

10+
import org.elasticsearch.common.Strings;
1011
import org.elasticsearch.common.ValidationException;
1112
import org.elasticsearch.common.io.stream.StreamInput;
1213
import org.elasticsearch.common.io.stream.StreamOutput;
@@ -15,6 +16,7 @@
1516
import org.elasticsearch.xpack.inference.common.MapPathExtractor;
1617

1718
import java.io.IOException;
19+
import java.util.List;
1820
import java.util.Map;
1921
import java.util.Objects;
2022

@@ -79,11 +81,22 @@ public String getWriteableName() {
7981

8082
@Override
8183
public ChatCompletionResults transform(Map<String, Object> map) {
82-
var completionList = validateAndCastList(
83-
validateList(MapPathExtractor.extract(map, completionResultPath)),
84-
(obj) -> toType(obj, String.class)
85-
);
86-
87-
return new ChatCompletionResults(completionList.stream().map(ChatCompletionResults.Result::new).toList());
84+
var extractedField = MapPathExtractor.extract(map, completionResultPath);
85+
86+
validateNonNull(extractedField);
87+
88+
if (extractedField instanceof List<?> extractedList) {
89+
var completionList = validateAndCastList(extractedList, (obj) -> toType(obj, String.class));
90+
return new ChatCompletionResults(completionList.stream().map(ChatCompletionResults.Result::new).toList());
91+
} else if (extractedField instanceof String extractedString) {
92+
return new ChatCompletionResults(List.of(new ChatCompletionResults.Result(extractedString)));
93+
} else {
94+
throw new IllegalArgumentException(
95+
Strings.format(
96+
"Extracted field is an invalid type, expected a list or a string but received [%s]",
97+
extractedField.getClass().getSimpleName()
98+
)
99+
);
100+
}
88101
}
89102
}

0 commit comments

Comments
 (0)