Skip to content

Commit 097246b

Browse files
Adding query parameter handling and tests
1 parent b496732 commit 097246b

File tree

11 files changed

+588
-61
lines changed

11 files changed

+588
-61
lines changed

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,8 @@ static TransportVersion def(int id) {
161161
public static final TransportVersion REMOTE_EXCEPTION_8_19 = def(8_841_0_16);
162162
public static final TransportVersion AMAZON_BEDROCK_TASK_SETTINGS_8_19 = def(8_841_0_17);
163163
public static final TransportVersion BATCHED_QUERY_PHASE_VERSION_BACKPORT_8_X = def(8_841_0_19);
164-
public static final TransportVersion ADD_INFERENCE_CUSTOM_MODEL_8_X = def(8_841_0_20);
165164
public static final TransportVersion SEARCH_INCREMENTAL_TOP_DOCS_NULL_BACKPORT_8_19 = def(8_841_0_20);
165+
public static final TransportVersion ADD_INFERENCE_CUSTOM_MODEL_8_X = def(8_842_0_20);
166166
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
167167
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
168168
public static final TransportVersion COHERE_BIT_EMBEDDING_TYPE_SUPPORT_ADDED = def(9_001_0_00);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java

Lines changed: 80 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import org.elasticsearch.core.Nullable;
1515
import org.elasticsearch.core.Strings;
1616
import org.elasticsearch.core.TimeValue;
17+
import org.elasticsearch.core.Tuple;
1718
import org.elasticsearch.inference.InputType;
1819
import org.elasticsearch.inference.Model;
1920
import org.elasticsearch.inference.SimilarityMeasure;
@@ -25,6 +26,7 @@
2526

2627
import java.net.URI;
2728
import java.net.URISyntaxException;
29+
import java.util.ArrayList;
2830
import java.util.Arrays;
2931
import java.util.EnumSet;
3032
import java.util.HashMap;
@@ -491,15 +493,88 @@ public static Map<String, Object> extractOptionalMap(
491493
return null;
492494
}
493495

494-
if (optionalField != null && optionalField.isEmpty()) {
495-
validationException.addValidationError(ServiceUtils.mustBeNonEmptyMap(settingName, scope));
496-
}
496+
return optionalField;
497+
}
498+
499+
public static List<Tuple<String, String>> extractOptionalListOfStringTuples(
500+
Map<String, Object> map,
501+
String settingName,
502+
String scope,
503+
ValidationException validationException
504+
) {
505+
int initialValidationErrorCount = validationException.validationErrors().size();
506+
List<?> optionalField = ServiceUtils.removeAsType(map, settingName, List.class, validationException);
497507

498508
if (validationException.validationErrors().size() > initialValidationErrorCount) {
499509
return null;
500510
}
501511

502-
return optionalField;
512+
if (optionalField == null) {
513+
return null;
514+
}
515+
516+
var tuples = new ArrayList<Tuple<String, String>>();
517+
for (int tuplesIndex = 0; tuplesIndex < optionalField.size(); tuplesIndex++) {
518+
519+
var tupleEntry = optionalField.get(tuplesIndex);
520+
if (tupleEntry instanceof List<?> == false) {
521+
validationException.addValidationError(
522+
Strings.format(
523+
"[%s] failed to parse tuple list entry [%d] for setting [%s], expected a list but the entry is [%s]",
524+
scope,
525+
tuplesIndex,
526+
settingName,
527+
tupleEntry.getClass().getSimpleName()
528+
)
529+
);
530+
throw validationException;
531+
}
532+
533+
var listEntry = (List<?>) tupleEntry;
534+
if (listEntry.size() != 2) {
535+
validationException.addValidationError(
536+
Strings.format(
537+
"[%s] failed to parse tuple list entry [%d] for setting [%s], the tuple list size must be two, but was [%d]",
538+
scope,
539+
tuplesIndex,
540+
settingName,
541+
listEntry.size()
542+
)
543+
);
544+
throw validationException;
545+
}
546+
547+
var firstElement = listEntry.get(0);
548+
var secondElement = listEntry.get(1);
549+
validateTuple(firstElement, settingName, scope, "the first element", tuplesIndex, validationException);
550+
validateTuple(secondElement, settingName, scope, "the second element", tuplesIndex, validationException);
551+
tuples.add(new Tuple<>((String) firstElement, (String) secondElement));
552+
}
553+
554+
return tuples;
555+
}
556+
557+
private static void validateTuple(
558+
Object tupleValue,
559+
String settingName,
560+
String scope,
561+
String elementDescription,
562+
int index,
563+
ValidationException validationException
564+
) {
565+
if (tupleValue instanceof String == false) {
566+
validationException.addValidationError(
567+
Strings.format(
568+
"[%s] failed to parse tuple list entry [%d] for setting [%s], %s must be a string but was [%s]",
569+
scope,
570+
index,
571+
settingName,
572+
elementDescription,
573+
tupleValue.getClass().getSimpleName()
574+
)
575+
);
576+
throw validationException;
577+
}
503578
}
504579

505580
/**
@@ -540,7 +615,7 @@ public static void validateMapValues(
540615
}
541616

542617
for (var entry : map.entrySet()) {
543-
boolean isAllowed = false;
618+
var isAllowed = false;
544619

545620
for (Class<?> allowedType : allowedTypes) {
546621
if (allowedType.isInstance(entry.getValue())) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@ public class CustomServiceSettings extends FilteredXContentObject implements Ser
5454
public static final String NAME = "custom_service_settings";
5555
public static final String URL = "url";
5656
public static final String HEADERS = "headers";
57-
public static final String QUERY_PARAMETERS = "query_parameters";
5857
public static final String REQUEST = "request";
5958
public static final String REQUEST_CONTENT = "content";
6059
public static final String RESPONSE = "response";
@@ -72,14 +71,7 @@ public static CustomServiceSettings fromMap(Map<String, Object> map, Configurati
7271

7372
String url = extractRequiredString(map, URL, ModelConfigurations.SERVICE_SETTINGS, validationException);
7473

75-
Map<String, Object> queryParameters = extractOptionalMap(
76-
map,
77-
QUERY_PARAMETERS,
78-
ModelConfigurations.SERVICE_SETTINGS,
79-
validationException
80-
);
81-
removeNullValues(queryParameters);
82-
var stringQueryParameters = validateMapStringValues(queryParameters, QUERY_PARAMETERS, validationException, false);
74+
var queryParams = QueryParameters.fromMap(map, validationException);
8375

8476
Map<String, Object> headers = extractOptionalMap(map, HEADERS, ModelConfigurations.SERVICE_SETTINGS, validationException);
8577
removeNullValues(headers);
@@ -146,7 +138,7 @@ public static CustomServiceSettings fromMap(Map<String, Object> map, Configurati
146138
maxInputTokens,
147139
url,
148140
stringHeaders,
149-
stringQueryParameters,
141+
queryParams,
150142
requestContentString,
151143
responseJsonParser,
152144
rateLimitSettings,
@@ -159,7 +151,7 @@ public static CustomServiceSettings fromMap(Map<String, Object> map, Configurati
159151
private final Integer maxInputTokens;
160152
private final String url;
161153
private final Map<String, String> headers;
162-
private final Map<String, String> queryParameters;
154+
private final QueryParameters queryParameters;
163155
private final String requestContentString;
164156
private final CustomResponseParser responseJsonParser;
165157
private final RateLimitSettings rateLimitSettings;
@@ -171,7 +163,7 @@ public CustomServiceSettings(
171163
@Nullable Integer maxInputTokens,
172164
String url,
173165
@Nullable Map<String, String> headers,
174-
@Nullable Map<String, String> queryParameters,
166+
@Nullable QueryParameters queryParameters,
175167
String requestContentString,
176168
CustomResponseParser responseJsonParser,
177169
@Nullable RateLimitSettings rateLimitSettings,
@@ -182,7 +174,7 @@ public CustomServiceSettings(
182174
this.maxInputTokens = maxInputTokens;
183175
this.url = Objects.requireNonNull(url);
184176
this.headers = Collections.unmodifiableMap(Objects.requireNonNullElse(headers, Map.of()));
185-
this.queryParameters = Collections.unmodifiableMap(Objects.requireNonNullElse(queryParameters, Map.of()));
177+
this.queryParameters = Objects.requireNonNullElse(queryParameters, QueryParameters.EMPTY);
186178
this.requestContentString = Objects.requireNonNull(requestContentString);
187179
this.responseJsonParser = Objects.requireNonNull(responseJsonParser);
188180
this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
@@ -195,7 +187,7 @@ public CustomServiceSettings(StreamInput in) throws IOException {
195187
maxInputTokens = in.readOptionalVInt();
196188
url = in.readString();
197189
headers = in.readImmutableMap(StreamInput::readString);
198-
queryParameters = in.readImmutableMap(StreamInput::readString);
190+
queryParameters = new QueryParameters(in);
199191
requestContentString = in.readString();
200192
responseJsonParser = in.readNamedWriteable(CustomResponseParser.class);
201193
rateLimitSettings = new RateLimitSettings(in);
@@ -229,7 +221,7 @@ public Map<String, String> getHeaders() {
229221
return headers;
230222
}
231223

232-
public Map<String, String> getQueryParameters() {
224+
public QueryParameters getQueryParameters() {
233225
return queryParameters;
234226
}
235227

@@ -286,9 +278,7 @@ public XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder
286278
builder.field(HEADERS, headers);
287279
}
288280

289-
if (queryParameters.isEmpty() == false) {
290-
builder.field(QUERY_PARAMETERS, queryParameters);
291-
}
281+
queryParameters.toXContent(builder, params);
292282

293283
builder.startObject(REQUEST);
294284
{
@@ -325,7 +315,7 @@ public void writeTo(StreamOutput out) throws IOException {
325315
out.writeOptionalVInt(maxInputTokens);
326316
out.writeString(url);
327317
out.writeMap(headers, StreamOutput::writeString, StreamOutput::writeString);
328-
out.writeMap(queryParameters, StreamOutput::writeString, StreamOutput::writeString);
318+
queryParameters.writeTo(out);
329319
out.writeString(requestContentString);
330320
out.writeNamedWriteable(responseJsonParser);
331321
rateLimitSettings.writeTo(out);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomTaskSettings.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
public class CustomTaskSettings implements TaskSettings {
3030
public static final String NAME = "custom_task_settings";
3131

32-
public static final String PARAMETERS = "json_parameters";
3332
public static final String PARAMETERS = "parameters";
3433

3534
static final CustomTaskSettings EMPTY_SETTINGS = new CustomTaskSettings(new HashMap<>());
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.custom;
9+
10+
import org.elasticsearch.common.ValidationException;
11+
import org.elasticsearch.common.io.stream.StreamInput;
12+
import org.elasticsearch.common.io.stream.StreamOutput;
13+
import org.elasticsearch.common.io.stream.Writeable;
14+
import org.elasticsearch.core.Tuple;
15+
import org.elasticsearch.inference.ModelConfigurations;
16+
import org.elasticsearch.xcontent.ToXContentFragment;
17+
import org.elasticsearch.xcontent.XContentBuilder;
18+
19+
import java.io.IOException;
20+
import java.util.List;
21+
import java.util.Map;
22+
import java.util.Objects;
23+
24+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalListOfStringTuples;
25+
26+
public record QueryParameters(List<Parameter> parameters) implements ToXContentFragment, Writeable {
27+
28+
public static final QueryParameters EMPTY = new QueryParameters(List.of());
29+
public static final String QUERY_PARAMETERS = "query_parameters";
30+
31+
public static QueryParameters fromMap(Map<String, Object> map, ValidationException validationException) {
32+
List<Tuple<String, String>> queryParams = extractOptionalListOfStringTuples(
33+
map,
34+
QUERY_PARAMETERS,
35+
ModelConfigurations.SERVICE_SETTINGS,
36+
validationException
37+
);
38+
39+
if (validationException.validationErrors().isEmpty() == false) {
40+
throw validationException;
41+
}
42+
43+
return QueryParameters.fromTuples(queryParams);
44+
}
45+
46+
private static QueryParameters fromTuples(List<Tuple<String, String>> queryParams) {
47+
if (queryParams == null) {
48+
return QueryParameters.EMPTY;
49+
}
50+
51+
return new QueryParameters(queryParams.stream().map((tuple) -> new Parameter(tuple.v1(), tuple.v2())).toList());
52+
}
53+
54+
public record Parameter(String key, String value) implements ToXContentFragment, Writeable {
55+
private static final String KEY = "key";
56+
private static final String VALUE = "value";
57+
58+
public Parameter {
59+
Objects.requireNonNull(key);
60+
Objects.requireNonNull(value);
61+
}
62+
63+
public Parameter(StreamInput in) throws IOException {
64+
this(in.readString(), in.readString());
65+
}
66+
67+
@Override
68+
public void writeTo(StreamOutput out) throws IOException {
69+
out.writeString(key);
70+
out.writeString(value);
71+
}
72+
73+
@Override
74+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
75+
builder.startArray();
76+
builder.value(key);
77+
builder.value(value);
78+
builder.endArray();
79+
return builder;
80+
}
81+
}
82+
83+
public QueryParameters {
84+
Objects.requireNonNull(parameters);
85+
}
86+
87+
public QueryParameters(StreamInput in) throws IOException {
88+
this(in.readCollectionAsImmutableList(Parameter::new));
89+
}
90+
91+
@Override
92+
public void writeTo(StreamOutput out) throws IOException {
93+
out.writeCollection(parameters);
94+
}
95+
96+
@Override
97+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
98+
if (parameters.isEmpty() == false) {
99+
builder.startArray(QUERY_PARAMETERS);
100+
for (var parameter : parameters) {
101+
parameter.toXContent(builder, params);
102+
}
103+
builder.endArray();
104+
}
105+
return builder;
106+
}
107+
}

0 commit comments

Comments
 (0)