Skip to content

Commit e7f6ac5

Browse files
Adding encoding tests
1 parent 097246b commit e7f6ac5

File tree

4 files changed

+253
-5
lines changed

4 files changed

+253
-5
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/QueryParameters.java

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,6 @@ private static QueryParameters fromTuples(List<Tuple<String, String>> queryParam
5252
}
5353

5454
public record Parameter(String key, String value) implements ToXContentFragment, Writeable {
55-
private static final String KEY = "key";
56-
private static final String VALUE = "value";
57-
5855
public Parameter {
5956
Objects.requireNonNull(key);
6057
Objects.requireNonNull(value);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.custom;
9+
10+
import org.elasticsearch.ElasticsearchStatusException;
11+
import org.elasticsearch.action.support.PlainActionFuture;
12+
import org.elasticsearch.core.TimeValue;
13+
import org.elasticsearch.inference.InferenceServiceResults;
14+
import org.elasticsearch.inference.TaskType;
15+
import org.elasticsearch.test.ESTestCase;
16+
import org.elasticsearch.threadpool.ThreadPool;
17+
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
18+
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
19+
import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser;
20+
import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser;
21+
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
22+
import org.elasticsearch.xpack.inference.services.settings.SerializableSecureString;
23+
import org.junit.After;
24+
import org.junit.Before;
25+
26+
import java.util.List;
27+
import java.util.Map;
28+
29+
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
30+
import static org.hamcrest.Matchers.is;
31+
import static org.mockito.Mockito.mock;
32+
33+
public class CustomRequestManagerTests extends ESTestCase {
34+
35+
private ThreadPool threadPool;
36+
37+
@Before
38+
@Override
39+
public void setUp() throws Exception {
40+
super.setUp();
41+
threadPool = createThreadPool(inferenceUtilityPool());
42+
}
43+
44+
@After
45+
@Override
46+
public void tearDown() throws Exception {
47+
super.tearDown();
48+
terminate(threadPool);
49+
}
50+
51+
public void testCreateRequest_ThrowsException_ForInvalidUrl() {
52+
var requestContentString = """
53+
{
54+
"input": ${input}
55+
}
56+
""";
57+
58+
var serviceSettings = new CustomServiceSettings(
59+
null,
60+
null,
61+
null,
62+
"${url}",
63+
null,
64+
null,
65+
requestContentString,
66+
new RerankResponseParser("$.result.score"),
67+
new RateLimitSettings(10_000),
68+
new ErrorResponseParser("$.error.message")
69+
);
70+
71+
var model = CustomModelTests.createModel(
72+
"service",
73+
TaskType.RERANK,
74+
serviceSettings,
75+
new CustomTaskSettings(Map.of("url", "^")),
76+
new CustomSecretSettings(Map.of("api_key", new SerializableSecureString("my-secret-key")))
77+
);
78+
79+
var listener = new PlainActionFuture<InferenceServiceResults>();
80+
var manager = CustomRequestManager.of(model, threadPool);
81+
manager.execute(new EmbeddingsInput(List.of("abc", "123"), null, null), mock(RequestSender.class), () -> false, listener);
82+
83+
var exception = expectThrows(
84+
ElasticsearchStatusException.class,
85+
() -> listener.actionGet(TimeValue.timeValueSeconds(30))
86+
);
87+
88+
assertThat(exception.getMessage(), is("Failed to construct the custom service request"));
89+
assertThat(exception.getCause().getMessage(), is("Failed to build URI, error: Illegal character in path at index 0: ^"));
90+
}
91+
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ public void testFromMap() {
101101
Integer maxInputTokens = 512;
102102
String url = "http://www.abc.com";
103103
Map<String, String> headers = Map.of("key", "value");
104-
var queryParameters = new QueryParameters(List.of(new QueryParameters.Parameter("key", "value")));
104+
var queryParameters = List.of(List.of("key", "value"));
105105
String requestContentString = "request body";
106106

107107
var responseParser = new TextEmbeddingResponseParser("$.result.embeddings[*].embedding");
@@ -149,7 +149,7 @@ public void testFromMap() {
149149
maxInputTokens,
150150
url,
151151
headers,
152-
queryParameters,
152+
new QueryParameters(List.of(new QueryParameters.Parameter("key", "value"))),
153153
requestContentString,
154154
responseParser,
155155
new RateLimitSettings(10_000),

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.elasticsearch.xpack.inference.services.custom.CustomTaskSettings;
2323
import org.elasticsearch.xpack.inference.services.custom.QueryParameters;
2424
import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser;
25+
import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser;
2526
import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser;
2627
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
2728
import org.elasticsearch.xpack.inference.services.settings.SerializableSecureString;
@@ -87,6 +88,55 @@ public void testCreateRequest() throws IOException {
8788
assertThat(convertToString(httpPost.getEntity().getContent()), is(expectedBody));
8889
}
8990

91+
public void testCreateRequest_QueryParametersAreEscaped_AndEncoded() {
92+
var requestContentString = """
93+
{
94+
"input": ${input}
95+
}
96+
""";
97+
98+
var serviceSettings = new CustomServiceSettings(
99+
null,
100+
null,
101+
null,
102+
"http://www.elastic.co",
103+
null,
104+
// escaped characters retrieved from here: https://docs.microfocus.com/OMi/10.62/Content/OMi/ExtGuide/ExtApps/URL_encoding.htm
105+
new QueryParameters(
106+
List.of(
107+
new QueryParameters.Parameter("key", " <>#%+{}|\\^~[]`;/?:@=&$"),
108+
// unicode is a 😀
109+
// Note: In the current version of the apache library (4.x) being used to do the encoding, spaces are converted to +
110+
// There's a bug fix here explaining that: https://issues.apache.org/jira/browse/HTTPCORE-628
111+
new QueryParameters.Parameter("key", \uD83D\uDE00")
112+
)
113+
),
114+
requestContentString,
115+
new TextEmbeddingResponseParser("$.result.embeddings"),
116+
new RateLimitSettings(10_000),
117+
new ErrorResponseParser("$.error.message")
118+
);
119+
120+
var model = CustomModelTests.createModel(
121+
"service",
122+
TaskType.TEXT_EMBEDDING,
123+
serviceSettings,
124+
new CustomTaskSettings(Map.of("url", "https://www.elastic.com")),
125+
new CustomSecretSettings(Map.of("api_key", new SerializableSecureString("my-secret-key")))
126+
);
127+
128+
var request = new CustomRequest(null, List.of("abc", "123"), model);
129+
var httpRequest = request.createHttpRequest();
130+
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
131+
132+
var httpPost = (HttpPost) httpRequest.httpRequestBase();
133+
assertThat(
134+
httpPost.getURI().toString(),
135+
// To visually verify that this is correct, input the query parameters into here: https://www.urldecoder.org/
136+
is("http://www.elastic.co?key=+%3C%3E%23%25%2B%7B%7D%7C%5C%5E%7E%5B%5D%60%3B%2F%3F%3A%40%3D%26%24&key=%CE%A3+%F0%9F%98%80")
137+
);
138+
}
139+
90140
public void testCreateRequest_SecretsInTheJsonBody_AreEncodedCorrectly() throws IOException {
91141
var dims = 1536;
92142
var maxInputTokens = 512;
@@ -139,6 +189,116 @@ public void testCreateRequest_SecretsInTheJsonBody_AreEncodedCorrectly() throws
139189
assertThat(convertToString(httpPost.getEntity().getContent()), is(expectedBody));
140190
}
141191

192+
public void testCreateRequest_HandlesQuery() throws IOException {
193+
var requestContentString = """
194+
{
195+
"input": ${input},
196+
"query": ${query}
197+
}
198+
""";
199+
200+
var serviceSettings = new CustomServiceSettings(
201+
null,
202+
null,
203+
null,
204+
"http://www.elastic.co",
205+
null,
206+
null,
207+
requestContentString,
208+
new RerankResponseParser("$.result.score"),
209+
new RateLimitSettings(10_000),
210+
new ErrorResponseParser("$.error.message")
211+
);
212+
213+
var model = CustomModelTests.createModel(
214+
"service",
215+
TaskType.RERANK,
216+
serviceSettings,
217+
new CustomTaskSettings(Map.of()),
218+
new CustomSecretSettings(Map.of("api_key", new SerializableSecureString("my-secret-key")))
219+
);
220+
221+
var request = new CustomRequest("query string", List.of("abc", "123"), model);
222+
var httpRequest = request.createHttpRequest();
223+
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
224+
225+
var httpPost = (HttpPost) httpRequest.httpRequestBase();
226+
227+
var expectedBody = XContentHelper.stripWhitespace("""
228+
{
229+
"input": ["abc", "123"],
230+
"query": "query string"
231+
}
232+
""");
233+
234+
assertThat(convertToString(httpPost.getEntity().getContent()), is(expectedBody));
235+
}
236+
237+
public void testCreateRequest_IgnoresNonStringFields_ForStringParams() throws IOException {
238+
var requestContentString = """
239+
{
240+
"input": ${input}
241+
}
242+
""";
243+
244+
var serviceSettings = new CustomServiceSettings(
245+
null,
246+
null,
247+
null,
248+
"http://www.elastic.co",
249+
Map.of(HttpHeaders.ACCEPT, Strings.format("${task.key}")),
250+
null,
251+
requestContentString,
252+
new RerankResponseParser("$.result.score"),
253+
new RateLimitSettings(10_000),
254+
new ErrorResponseParser("$.error.message")
255+
);
256+
257+
var model = CustomModelTests.createModel(
258+
"service",
259+
TaskType.RERANK,
260+
serviceSettings,
261+
new CustomTaskSettings(Map.of("task.key", 100)),
262+
new CustomSecretSettings(Map.of("api_key", new SerializableSecureString("my-secret-key")))
263+
);
264+
265+
var request = new CustomRequest(null, List.of("abc", "123"), model);
266+
var exception = expectThrows(IllegalStateException.class, request::createHttpRequest);
267+
assertThat(exception.getMessage(), is("Found placeholder [${task.key}] in field [header.Accept] after replacement call"));
268+
}
269+
270+
public void testCreateRequest_ThrowsException_ForInvalidUrl() {
271+
var requestContentString = """
272+
{
273+
"input": ${input}
274+
}
275+
""";
276+
277+
var serviceSettings = new CustomServiceSettings(
278+
null,
279+
null,
280+
null,
281+
"${url}",
282+
Map.of(HttpHeaders.ACCEPT, Strings.format("${task.key}")),
283+
null,
284+
requestContentString,
285+
new RerankResponseParser("$.result.score"),
286+
new RateLimitSettings(10_000),
287+
new ErrorResponseParser("$.error.message")
288+
);
289+
290+
var model = CustomModelTests.createModel(
291+
"service",
292+
TaskType.RERANK,
293+
serviceSettings,
294+
new CustomTaskSettings(Map.of("url", "^")),
295+
new CustomSecretSettings(Map.of("api_key", new SerializableSecureString("my-secret-key")))
296+
);
297+
298+
var exception = expectThrows(IllegalStateException.class, () -> new CustomRequest(null, List.of("abc", "123"), model));
299+
assertThat(exception.getMessage(), is("Failed to build URI, error: Illegal character in path at index 0: ^"));
300+
}
301+
142302
private static String convertToString(InputStream inputStream) throws IOException {
143303
return XContentHelper.stripWhitespace(Streams.copyToString(new InputStreamReader(inputStream, StandardCharsets.UTF_8)));
144304
}

0 commit comments

Comments
 (0)