Skip to content

Commit 34df922

Browse files
Need to address quoted strings
1 parent de83271 commit 34df922

File tree

12 files changed

+327
-81
lines changed

12 files changed

+327
-81
lines changed

x-pack/plugin/inference/build.gradle

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -448,4 +448,3 @@ tasks.named("thirdPartyAudit").configure {
448448
tasks.named('yamlRestTest') {
449449
usesDefaultDistribution("to be triaged")
450450
}
451-
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.common;
9+
10+
import org.elasticsearch.common.Strings;
11+
import org.elasticsearch.xcontent.XContentBuilder;
12+
import org.elasticsearch.xcontent.json.JsonXContent;
13+
14+
public class JsonUtils {
15+
16+
public static <T> String toJson(T value, String field) {
17+
try {
18+
XContentBuilder builder = JsonXContent.contentBuilder();
19+
builder.value(value);
20+
return Strings.toString(builder);
21+
} catch (Exception e) {
22+
throw new IllegalStateException(
23+
Strings.format("Failed to serialize custom request value as JSON, field: %s, error: %s", field, e.getMessage()),
24+
e
25+
);
26+
}
27+
}
28+
29+
private JsonUtils() {}
30+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.common;
9+
10+
import org.apache.commons.text.StringSubstitutor;
11+
12+
import java.util.Map;
13+
import java.util.regex.Matcher;
14+
import java.util.regex.Pattern;
15+
16+
/**
17+
* Substitutes placeholder values in a string that match the keys in a provided map with the map's corresponding values.
18+
*/
19+
public class ValidatingSubstitutor {
20+
/**
21+
* This regex pattern matches on the string {@code ${<any characters>}} excluding newlines.
22+
*/
23+
private static final Pattern VARIABLE_PLACEHOLDER_PATTERN = Pattern.compile("\\$\\{.*?\\}");
24+
25+
private final StringSubstitutor substitutor;
26+
27+
/**
28+
* @param params a map containing the placeholders as the keys, the values will be used to replace the placeholders
29+
* @param prefix a string indicating the start of a placeholder
30+
* @param suffix a string indicating the end of a placeholder
31+
*/
32+
public ValidatingSubstitutor(Map<String, String> params, String prefix, String suffix) {
33+
substitutor = new StringSubstitutor(params, prefix, suffix);
34+
}
35+
36+
/**
37+
* Substitutes placeholder values in a string that match the keys in a provided map with the map's corresponding values.
38+
* After replacement, if the source still contains a placeholder an {@link IllegalStateException} is thrown.
39+
* @param source the string that will be searched for placeholders to be replaced
40+
* @param settingName a description of the source string
41+
* @return a string with the placeholders replaced by string values
42+
*/
43+
public String replace(String source, String settingName) {
44+
var replacedString = substitutor.replace(source);
45+
ensureNoMorePlaceholdersExist(replacedString, settingName);
46+
return replacedString;
47+
}
48+
49+
private static void ensureNoMorePlaceholdersExist(String substitutedString, String settingName) {
50+
Matcher matcher = VARIABLE_PLACEHOLDER_PATTERN.matcher(substitutedString);
51+
if (matcher.find()) {
52+
throw new IllegalStateException(
53+
String.format("Found placeholder [%s] in setting [%s] after replacement call", matcher.group(), settingName)
54+
);
55+
}
56+
}
57+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ public class CustomServiceSettings extends FilteredXContentObject implements Ser
5454
public static final String NAME = "custom_service_settings";
5555
public static final String URL = "url";
5656
public static final String HEADERS = "headers";
57+
public static final String QUERY_PARAMETERS = "query_parameters";
5758
public static final String REQUEST = "request";
5859
public static final String REQUEST_CONTENT = "content";
5960
public static final String RESPONSE = "response";
@@ -71,6 +72,15 @@ public static CustomServiceSettings fromMap(Map<String, Object> map, Configurati
7172

7273
String url = extractRequiredString(map, URL, ModelConfigurations.SERVICE_SETTINGS, validationException);
7374

75+
Map<String, Object> queryParameters = extractOptionalMap(
76+
map,
77+
QUERY_PARAMETERS,
78+
ModelConfigurations.SERVICE_SETTINGS,
79+
validationException
80+
);
81+
removeNullValues(queryParameters);
82+
var stringQueryParameters = validateMapStringValues(queryParameters, QUERY_PARAMETERS, validationException, false);
83+
7484
Map<String, Object> headers = extractOptionalMap(map, HEADERS, ModelConfigurations.SERVICE_SETTINGS, validationException);
7585
removeNullValues(headers);
7686
var stringHeaders = validateMapStringValues(headers, HEADERS, validationException, false);
@@ -136,6 +146,7 @@ public static CustomServiceSettings fromMap(Map<String, Object> map, Configurati
136146
maxInputTokens,
137147
url,
138148
stringHeaders,
149+
stringQueryParameters,
139150
requestContentString,
140151
responseJsonParser,
141152
rateLimitSettings,
@@ -148,6 +159,7 @@ public static CustomServiceSettings fromMap(Map<String, Object> map, Configurati
148159
private final Integer maxInputTokens;
149160
private final String url;
150161
private final Map<String, String> headers;
162+
private final Map<String, String> queryParameters;
151163
private final String requestContentString;
152164
private final CustomResponseParser responseJsonParser;
153165
private final RateLimitSettings rateLimitSettings;
@@ -159,6 +171,7 @@ public CustomServiceSettings(
159171
@Nullable Integer maxInputTokens,
160172
String url,
161173
@Nullable Map<String, String> headers,
174+
@Nullable Map<String, String> queryParameters,
162175
String requestContentString,
163176
CustomResponseParser responseJsonParser,
164177
@Nullable RateLimitSettings rateLimitSettings,
@@ -169,6 +182,7 @@ public CustomServiceSettings(
169182
this.maxInputTokens = maxInputTokens;
170183
this.url = Objects.requireNonNull(url);
171184
this.headers = Collections.unmodifiableMap(Objects.requireNonNullElse(headers, Map.of()));
185+
this.queryParameters = Collections.unmodifiableMap(Objects.requireNonNullElse(queryParameters, Map.of()));
172186
this.requestContentString = Objects.requireNonNull(requestContentString);
173187
this.responseJsonParser = Objects.requireNonNull(responseJsonParser);
174188
this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
@@ -181,6 +195,7 @@ public CustomServiceSettings(StreamInput in) throws IOException {
181195
maxInputTokens = in.readOptionalVInt();
182196
url = in.readString();
183197
headers = in.readImmutableMap(StreamInput::readString);
198+
queryParameters = in.readImmutableMap(StreamInput::readString);
184199
requestContentString = in.readString();
185200
responseJsonParser = in.readNamedWriteable(CustomResponseParser.class);
186201
rateLimitSettings = new RateLimitSettings(in);
@@ -214,6 +229,10 @@ public Map<String, String> getHeaders() {
214229
return headers;
215230
}
216231

232+
public Map<String, String> getQueryParameters() {
233+
return queryParameters;
234+
}
235+
217236
public String getRequestContentString() {
218237
return requestContentString;
219238
}
@@ -267,6 +286,10 @@ public XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder
267286
builder.field(HEADERS, headers);
268287
}
269288

289+
if (queryParameters.isEmpty() == false) {
290+
builder.field(QUERY_PARAMETERS, queryParameters);
291+
}
292+
270293
builder.startObject(REQUEST);
271294
{
272295
builder.field(REQUEST_CONTENT, requestContentString);
@@ -302,6 +325,7 @@ public void writeTo(StreamOutput out) throws IOException {
302325
out.writeOptionalVInt(maxInputTokens);
303326
out.writeString(url);
304327
out.writeMap(headers, StreamOutput::writeString, StreamOutput::writeString);
328+
out.writeMap(queryParameters, StreamOutput::writeString, StreamOutput::writeString);
305329
out.writeString(requestContentString);
306330
out.writeNamedWriteable(responseJsonParser);
307331
rateLimitSettings.writeTo(out);
@@ -318,6 +342,7 @@ public boolean equals(Object o) {
318342
&& Objects.equals(maxInputTokens, that.maxInputTokens)
319343
&& Objects.equals(url, that.url)
320344
&& Objects.equals(headers, that.headers)
345+
&& Objects.equals(queryParameters, that.queryParameters)
321346
&& Objects.equals(requestContentString, that.requestContentString)
322347
&& Objects.equals(responseJsonParser, that.responseJsonParser)
323348
&& Objects.equals(rateLimitSettings, that.rateLimitSettings)
@@ -332,6 +357,7 @@ public int hashCode() {
332357
maxInputTokens,
333358
url,
334359
headers,
360+
queryParameters,
335361
requestContentString,
336362
responseJsonParser,
337363
rateLimitSettings,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomTaskSettings.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
public class CustomTaskSettings implements TaskSettings {
3030
public static final String NAME = "custom_task_settings";
3131

32+
public static final String PARAMETERS = "json_parameters";
3233
public static final String PARAMETERS = "parameters";
3334

3435
static final CustomTaskSettings EMPTY_SETTINGS = new CustomTaskSettings(new HashMap<>());
@@ -41,7 +42,13 @@ public static CustomTaskSettings fromMap(Map<String, Object> map) {
4142

4243
Map<String, Object> parameters = extractOptionalMap(map, PARAMETERS, ModelConfigurations.TASK_SETTINGS, validationException);
4344
removeNullValues(parameters);
44-
validateMapValues(parameters, List.of(String.class, Integer.class, Double.class, Float.class, Boolean.class), PARAMETERS, validationException, false);
45+
validateMapValues(
46+
parameters,
47+
List.of(String.class, Integer.class, Double.class, Float.class, Boolean.class),
48+
PARAMETERS,
49+
validationException,
50+
false
51+
);
4552

4653
if (validationException.validationErrors().isEmpty() == false) {
4754
throw validationException;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequest.java

Lines changed: 11 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -7,44 +7,35 @@
77

88
package org.elasticsearch.xpack.inference.services.custom.request;
99

10-
import org.apache.commons.text.StringSubstitutor;
1110
import org.apache.http.HttpHeaders;
1211
import org.apache.http.client.methods.HttpPost;
1312
import org.apache.http.client.methods.HttpRequestBase;
1413
import org.apache.http.entity.StringEntity;
1514
import org.elasticsearch.common.Strings;
16-
import org.elasticsearch.xcontent.XContentBuilder;
1715
import org.elasticsearch.xcontent.XContentType;
18-
import org.elasticsearch.xcontent.json.JsonXContent;
16+
import org.elasticsearch.xpack.inference.common.ValidatingSubstitutor;
1917
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
2018
import org.elasticsearch.xpack.inference.external.request.Request;
2119
import org.elasticsearch.xpack.inference.services.custom.CustomModel;
2220
import org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings;
2321

24-
import java.io.IOException;
2522
import java.net.URI;
2623
import java.nio.charset.StandardCharsets;
2724
import java.util.HashMap;
2825
import java.util.List;
2926
import java.util.Map;
3027
import java.util.Objects;
31-
import java.util.regex.Matcher;
32-
import java.util.regex.Pattern;
3328

29+
import static org.elasticsearch.xpack.inference.common.JsonUtils.toJson;
3430
import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.REQUEST_CONTENT;
3531
import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.URL;
3632

3733
public class CustomRequest implements Request {
38-
/**
39-
* This regex pattern matches on the string {@code "${<any characters>}"}
40-
*/
41-
private static final Pattern VARIABLE_PLACEHOLDER_PATTERN = Pattern.compile("\\$\\{.*?\\}");
42-
4334
private static final String QUERY = "query";
4435
private static final String INPUT = "input";
4536

4637
private final URI uri;
47-
private final StringSubstitutor substitutor;
38+
private final ValidatingSubstitutor substitutor;
4839
private final CustomModel model;
4940

5041
public CustomRequest(String query, List<String> input, CustomModel model) {
@@ -60,28 +51,21 @@ public CustomRequest(String query, List<String> input, CustomModel model) {
6051

6152
jsonParams.put(INPUT, toJson(input, INPUT));
6253

63-
substitutor = new StringSubstitutor(jsonParams, "${", "}");
54+
substitutor = new ValidatingSubstitutor(jsonParams, "${", "}");
6455
uri = buildUri();
6556
}
6657

67-
// default for testing
68-
static <T> String toJson(T value, String field) {
69-
try {
70-
XContentBuilder builder = JsonXContent.contentBuilder();
71-
// TODO test this, I think it'll write the quotes for us so we don't need to include them in the content string
72-
builder.value(value);
73-
return Strings.toString(builder);
74-
} catch (IOException e) {
75-
throw new IllegalStateException(Strings.format("Failed to serialize custom request value as json, field: %s", field), e);
76-
}
77-
}
78-
7958
private static void addJsonStringParams(Map<String, String> jsonStringParams, Map<String, ?> params) {
8059
for (var entry : params.entrySet()) {
8160
jsonStringParams.put(entry.getKey(), toJson(entry.getValue(), entry.getKey()));
8261
}
8362
}
8463

64+
private URI buildUri() {
65+
String replacedUrl = substitutor.replace(model.getServiceSettings().getUrl(), URL);
66+
return URI.create(replacedUrl);
67+
}
68+
8569
@Override
8670
public HttpRequest createHttpRequest() {
8771
HttpPost httpRequest = new HttpPost(uri);
@@ -97,15 +81,13 @@ private void setHeaders(HttpRequestBase httpRequest) {
9781
httpRequest.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());
9882

9983
for (var entry : model.getServiceSettings().getHeaders().entrySet()) {
100-
String replacedHeadersValue = substitutor.replace(entry.getValue());
101-
placeholderValidation(replacedHeadersValue, Strings.format("header.%s", entry.getKey()));
84+
String replacedHeadersValue = substitutor.replace(entry.getValue(), Strings.format("header.%s", entry.getKey()));
10285
httpRequest.setHeader(entry.getKey(), replacedHeadersValue);
10386
}
10487
}
10588

10689
private void setRequestContent(HttpPost httpRequest) {
107-
String replacedRequestContentString = substitutor.replace(model.getServiceSettings().getRequestContentString());
108-
placeholderValidation(replacedRequestContentString, REQUEST_CONTENT);
90+
String replacedRequestContentString = substitutor.replace(model.getServiceSettings().getRequestContentString(), REQUEST_CONTENT);
10991
StringEntity stringEntity = new StringEntity(replacedRequestContentString, StandardCharsets.UTF_8);
11092
httpRequest.setEntity(stringEntity);
11193
}
@@ -133,19 +115,4 @@ public Request truncate() {
133115
public boolean[] getTruncationInfo() {
134116
return null;
135117
}
136-
137-
// default for testing
138-
URI buildUri() {
139-
String replacedUrl = substitutor.replace(model.getServiceSettings().getUrl());
140-
placeholderValidation(replacedUrl, URL);
141-
return URI.create(replacedUrl);
142-
}
143-
144-
// default for testing
145-
static void placeholderValidation(String substitutedString, String settingName) {
146-
Matcher matcher = VARIABLE_PLACEHOLDER_PATTERN.matcher(substitutedString);
147-
if (matcher.find()) {
148-
throw new IllegalStateException(String.format("Found placeholder in [%s] after replacement call", settingName));
149-
}
150-
}
151118
}

0 commit comments

Comments
 (0)