Skip to content

Commit 898cfdd

Browse files
committed
upgrade test
1 parent 97ce2b4 commit 898cfdd

File tree

1 file changed

+162
-0
lines changed

1 file changed

+162
-0
lines changed

x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/CohereServiceUpgradeIT.java

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,12 @@ public class CohereServiceUpgradeIT extends InferenceUpgradeTestCase {
3939
// TODO: replace with proper test features
4040
private static final String COHERE_EMBEDDINGS_ADDED_TEST_FEATURE = "gte_v8.13.0";
4141
private static final String COHERE_RERANK_ADDED_TEST_FEATURE = "gte_v8.14.0";
42+
private static final String COHERE_COMPLETIONS_ADDED_TEST_FEATURE = "gte_v8.15.0";
4243
private static final String COHERE_V2_API_ADDED_TEST_FEATURE = "inference.cohere.v2";
4344

4445
private static MockWebServer cohereEmbeddingsServer;
4546
private static MockWebServer cohereRerankServer;
47+
private static MockWebServer cohereCompletionsServer;
4648

4749
private enum ApiVersion {
4850
V1,
@@ -60,12 +62,16 @@ public static void startWebServer() throws IOException {
6062

6163
cohereRerankServer = new MockWebServer();
6264
cohereRerankServer.start();
65+
66+
cohereCompletionsServer = new MockWebServer();
67+
cohereCompletionsServer.start();
6368
}
6469

6570
@AfterClass
6671
public static void shutdown() {
6772
cohereEmbeddingsServer.close();
6873
cohereRerankServer.close();
74+
cohereCompletionsServer.close();
6975
}
7076

7177
@SuppressWarnings("unchecked")
@@ -326,6 +332,80 @@ private void assertRerank(String inferenceId) throws IOException {
326332
assertThat(inferenceMap.entrySet(), not(empty()));
327333
}
328334

335+
@SuppressWarnings("unchecked")
336+
public void testCohereCompletions() throws IOException {
337+
var completionsSupported = oldClusterHasFeature(COHERE_COMPLETIONS_ADDED_TEST_FEATURE);
338+
assumeTrue("Cohere completions not supported", completionsSupported);
339+
340+
ApiVersion oldClusterApiVersion = oldClusterHasFeature(COHERE_V2_API_ADDED_TEST_FEATURE) ? ApiVersion.V2 : ApiVersion.V1;
341+
342+
final String oldClusterId = "old-cluster-completions";
343+
344+
if (isOldCluster()) {
345+
// queue a response as PUT will call the service
346+
cohereCompletionsServer.enqueue(new MockResponse().setResponseCode(200).setBody(completionsResponse(oldClusterApiVersion)));
347+
put(oldClusterId, completionsConfig(getUrl(cohereCompletionsServer)), TaskType.COMPLETION);
348+
349+
var configs = (List<Map<String, Object>>) get(TaskType.COMPLETION, oldClusterId).get("endpoints");
350+
assertThat(configs, hasSize(1));
351+
assertEquals("cohere", configs.get(0).get("service"));
352+
var serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
353+
assertThat(serviceSettings, hasEntry("model_id", "command"));
354+
} else if (isMixedCluster()) {
355+
var configs = (List<Map<String, Object>>) get(TaskType.COMPLETION, oldClusterId).get("endpoints");
356+
assertThat(configs, hasSize(1));
357+
assertEquals("cohere", configs.get(0).get("service"));
358+
var serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
359+
assertThat(serviceSettings, hasEntry("model_id", "command"));
360+
} else if (isUpgradedCluster()) {
361+
// check old cluster model
362+
var configs = (List<Map<String, Object>>) get(TaskType.COMPLETION, oldClusterId).get("endpoints");
363+
var serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
364+
assertThat(serviceSettings, hasEntry("model_id", "command"));
365+
366+
final String newClusterId = "new-cluster-completions";
367+
{
368+
cohereCompletionsServer.enqueue(new MockResponse().setResponseCode(200).setBody(completionsResponse(oldClusterApiVersion)));
369+
var inferenceMap = inference(oldClusterId, TaskType.COMPLETION, "some text");
370+
assertThat(inferenceMap.entrySet(), not(empty()));
371+
assertVersionInPath(cohereCompletionsServer.requests().getLast(), "chat", oldClusterApiVersion);
372+
}
373+
{
374+
// new cluster uses the V2 API
375+
cohereCompletionsServer.enqueue(new MockResponse().setResponseCode(200).setBody(completionsResponse(ApiVersion.V2)));
376+
put(newClusterId, completionsConfig(getUrl(cohereCompletionsServer)), TaskType.COMPLETION);
377+
378+
cohereCompletionsServer.enqueue(new MockResponse().setResponseCode(200).setBody(completionsResponse(ApiVersion.V2)));
379+
var inferenceMap = inference(newClusterId, TaskType.COMPLETION, "some text");
380+
assertThat(inferenceMap.entrySet(), not(empty()));
381+
assertVersionInPath(cohereCompletionsServer.requests().getLast(), "chat", ApiVersion.V2);
382+
}
383+
384+
{
385+
// new endpoints use the V2 API which require the model to be set
386+
final String upgradedClusterNoModel = "upgraded-cluster-missing-model-id";
387+
var jsonBody = Strings.format("""
388+
{
389+
"service": "cohere",
390+
"service_settings": {
391+
"url": "%s",
392+
"api_key": "XXXX"
393+
}
394+
}
395+
""", getUrl(cohereEmbeddingsServer));
396+
397+
var e = expectThrows(ResponseException.class, () -> put(upgradedClusterNoModel, jsonBody, TaskType.COMPLETION));
398+
assertThat(
399+
e.getMessage(),
400+
containsString("Validation Failed: 1: The [service_settings.model_id] field is required for the Cohere V2 API.")
401+
);
402+
}
403+
404+
delete(oldClusterId);
405+
delete(newClusterId);
406+
}
407+
}
408+
329409
private String embeddingConfigByte(String url) {
330410
return embeddingConfigTemplate(url, "byte");
331411
}
@@ -451,4 +531,86 @@ private String rerankResponse() {
451531
""";
452532
}
453533

534+
private String completionsConfig(String url) {
535+
return Strings.format("""
536+
{
537+
"service": "cohere",
538+
"service_settings": {
539+
"api_key": "XXXX",
540+
"model_id": "command",
541+
"url": "%s"
542+
}
543+
}
544+
""", url);
545+
}
546+
547+
private String completionsResponse(ApiVersion version) {
548+
return switch (version) {
549+
case V1 -> v1CompletionsResponse();
550+
case V2 -> v2CompletionsResponse();
551+
};
552+
}
553+
554+
private String v1CompletionsResponse() {
555+
return """
556+
{
557+
"response_id": "some id",
558+
"text": "result",
559+
"generation_id": "some id",
560+
"chat_history": [
561+
{
562+
"role": "USER",
563+
"message": "some input"
564+
},
565+
{
566+
"role": "CHATBOT",
567+
"message": "v1 response from the llm"
568+
}
569+
],
570+
"finish_reason": "COMPLETE",
571+
"meta": {
572+
"api_version": {
573+
"version": "1"
574+
},
575+
"billed_units": {
576+
"input_tokens": 4,
577+
"output_tokens": 191
578+
},
579+
"tokens": {
580+
"input_tokens": 70,
581+
"output_tokens": 191
582+
}
583+
}
584+
}
585+
""";
586+
}
587+
588+
private String v2CompletionsResponse() {
589+
return """
590+
{
591+
"id": "c14c80c3-18eb-4519-9460-6c92edd8cfb4",
592+
"finish_reason": "COMPLETE",
593+
"message": {
594+
"role": "assistant",
595+
"content": [
596+
{
597+
"type": "text",
598+
"text": "v2 response from the LLM"
599+
}
600+
]
601+
},
602+
"usage": {
603+
"billed_units": {
604+
"input_tokens": 1,
605+
"output_tokens": 2
606+
},
607+
"tokens": {
608+
"input_tokens": 3,
609+
"output_tokens": 4
610+
}
611+
}
612+
}
613+
""";
614+
}
615+
454616
}

0 commit comments

Comments
 (0)