@@ -39,10 +39,12 @@ public class CohereServiceUpgradeIT extends InferenceUpgradeTestCase {
39
39
// TODO: replace with proper test features
40
40
private static final String COHERE_EMBEDDINGS_ADDED_TEST_FEATURE = "gte_v8.13.0" ;
41
41
private static final String COHERE_RERANK_ADDED_TEST_FEATURE = "gte_v8.14.0" ;
42
+ private static final String COHERE_COMPLETIONS_ADDED_TEST_FEATURE = "gte_v8.15.0" ;
42
43
private static final String COHERE_V2_API_ADDED_TEST_FEATURE = "inference.cohere.v2" ;
43
44
44
45
private static MockWebServer cohereEmbeddingsServer ;
45
46
private static MockWebServer cohereRerankServer ;
47
+ private static MockWebServer cohereCompletionsServer ;
46
48
47
49
private enum ApiVersion {
48
50
V1 ,
@@ -60,12 +62,16 @@ public static void startWebServer() throws IOException {
60
62
61
63
cohereRerankServer = new MockWebServer ();
62
64
cohereRerankServer .start ();
65
+
66
+ cohereCompletionsServer = new MockWebServer ();
67
+ cohereCompletionsServer .start ();
63
68
}
64
69
65
70
@ AfterClass
66
71
public static void shutdown () {
67
72
cohereEmbeddingsServer .close ();
68
73
cohereRerankServer .close ();
74
+ cohereCompletionsServer .close ();
69
75
}
70
76
71
77
@ SuppressWarnings ("unchecked" )
@@ -326,6 +332,80 @@ private void assertRerank(String inferenceId) throws IOException {
326
332
assertThat (inferenceMap .entrySet (), not (empty ()));
327
333
}
328
334
335
+ @ SuppressWarnings ("unchecked" )
336
+ public void testCohereCompletions () throws IOException {
337
+ var completionsSupported = oldClusterHasFeature (COHERE_COMPLETIONS_ADDED_TEST_FEATURE );
338
+ assumeTrue ("Cohere completions not supported" , completionsSupported );
339
+
340
+ ApiVersion oldClusterApiVersion = oldClusterHasFeature (COHERE_V2_API_ADDED_TEST_FEATURE ) ? ApiVersion .V2 : ApiVersion .V1 ;
341
+
342
+ final String oldClusterId = "old-cluster-completions" ;
343
+
344
+ if (isOldCluster ()) {
345
+ // queue a response as PUT will call the service
346
+ cohereCompletionsServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (completionsResponse (oldClusterApiVersion )));
347
+ put (oldClusterId , completionsConfig (getUrl (cohereCompletionsServer )), TaskType .COMPLETION );
348
+
349
+ var configs = (List <Map <String , Object >>) get (TaskType .COMPLETION , oldClusterId ).get ("endpoints" );
350
+ assertThat (configs , hasSize (1 ));
351
+ assertEquals ("cohere" , configs .get (0 ).get ("service" ));
352
+ var serviceSettings = (Map <String , Object >) configs .get (0 ).get ("service_settings" );
353
+ assertThat (serviceSettings , hasEntry ("model_id" , "command" ));
354
+ } else if (isMixedCluster ()) {
355
+ var configs = (List <Map <String , Object >>) get (TaskType .COMPLETION , oldClusterId ).get ("endpoints" );
356
+ assertThat (configs , hasSize (1 ));
357
+ assertEquals ("cohere" , configs .get (0 ).get ("service" ));
358
+ var serviceSettings = (Map <String , Object >) configs .get (0 ).get ("service_settings" );
359
+ assertThat (serviceSettings , hasEntry ("model_id" , "command" ));
360
+ } else if (isUpgradedCluster ()) {
361
+ // check old cluster model
362
+ var configs = (List <Map <String , Object >>) get (TaskType .COMPLETION , oldClusterId ).get ("endpoints" );
363
+ var serviceSettings = (Map <String , Object >) configs .get (0 ).get ("service_settings" );
364
+ assertThat (serviceSettings , hasEntry ("model_id" , "command" ));
365
+
366
+ final String newClusterId = "new-cluster-completions" ;
367
+ {
368
+ cohereCompletionsServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (completionsResponse (oldClusterApiVersion )));
369
+ var inferenceMap = inference (oldClusterId , TaskType .COMPLETION , "some text" );
370
+ assertThat (inferenceMap .entrySet (), not (empty ()));
371
+ assertVersionInPath (cohereCompletionsServer .requests ().getLast (), "chat" , oldClusterApiVersion );
372
+ }
373
+ {
374
+ // new cluster uses the V2 API
375
+ cohereCompletionsServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (completionsResponse (ApiVersion .V2 )));
376
+ put (newClusterId , completionsConfig (getUrl (cohereCompletionsServer )), TaskType .COMPLETION );
377
+
378
+ cohereCompletionsServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (completionsResponse (ApiVersion .V2 )));
379
+ var inferenceMap = inference (newClusterId , TaskType .COMPLETION , "some text" );
380
+ assertThat (inferenceMap .entrySet (), not (empty ()));
381
+ assertVersionInPath (cohereCompletionsServer .requests ().getLast (), "chat" , ApiVersion .V2 );
382
+ }
383
+
384
+ {
385
+ // new endpoints use the V2 API which require the model to be set
386
+ final String upgradedClusterNoModel = "upgraded-cluster-missing-model-id" ;
387
+ var jsonBody = Strings .format ("""
388
+ {
389
+ "service": "cohere",
390
+ "service_settings": {
391
+ "url": "%s",
392
+ "api_key": "XXXX"
393
+ }
394
+ }
395
+ """ , getUrl (cohereEmbeddingsServer ));
396
+
397
+ var e = expectThrows (ResponseException .class , () -> put (upgradedClusterNoModel , jsonBody , TaskType .COMPLETION ));
398
+ assertThat (
399
+ e .getMessage (),
400
+ containsString ("Validation Failed: 1: The [service_settings.model_id] field is required for the Cohere V2 API." )
401
+ );
402
+ }
403
+
404
+ delete (oldClusterId );
405
+ delete (newClusterId );
406
+ }
407
+ }
408
+
329
409
private String embeddingConfigByte (String url ) {
330
410
return embeddingConfigTemplate (url , "byte" );
331
411
}
@@ -451,4 +531,86 @@ private String rerankResponse() {
451
531
""" ;
452
532
}
453
533
534
+ private String completionsConfig (String url ) {
535
+ return Strings .format ("""
536
+ {
537
+ "service": "cohere",
538
+ "service_settings": {
539
+ "api_key": "XXXX",
540
+ "model_id": "command",
541
+ "url": "%s"
542
+ }
543
+ }
544
+ """ , url );
545
+ }
546
+
547
+ private String completionsResponse (ApiVersion version ) {
548
+ return switch (version ) {
549
+ case V1 -> v1CompletionsResponse ();
550
+ case V2 -> v2CompletionsResponse ();
551
+ };
552
+ }
553
+
554
+ private String v1CompletionsResponse () {
555
+ return """
556
+ {
557
+ "response_id": "some id",
558
+ "text": "result",
559
+ "generation_id": "some id",
560
+ "chat_history": [
561
+ {
562
+ "role": "USER",
563
+ "message": "some input"
564
+ },
565
+ {
566
+ "role": "CHATBOT",
567
+ "message": "v1 response from the llm"
568
+ }
569
+ ],
570
+ "finish_reason": "COMPLETE",
571
+ "meta": {
572
+ "api_version": {
573
+ "version": "1"
574
+ },
575
+ "billed_units": {
576
+ "input_tokens": 4,
577
+ "output_tokens": 191
578
+ },
579
+ "tokens": {
580
+ "input_tokens": 70,
581
+ "output_tokens": 191
582
+ }
583
+ }
584
+ }
585
+ """ ;
586
+ }
587
+
588
+ private String v2CompletionsResponse () {
589
+ return """
590
+ {
591
+ "id": "c14c80c3-18eb-4519-9460-6c92edd8cfb4",
592
+ "finish_reason": "COMPLETE",
593
+ "message": {
594
+ "role": "assistant",
595
+ "content": [
596
+ {
597
+ "type": "text",
598
+ "text": "v2 response from the LLM"
599
+ }
600
+ ]
601
+ },
602
+ "usage": {
603
+ "billed_units": {
604
+ "input_tokens": 1,
605
+ "output_tokens": 2
606
+ },
607
+ "tokens": {
608
+ "input_tokens": 3,
609
+ "output_tokens": 4
610
+ }
611
+ }
612
+ }
613
+ """ ;
614
+ }
615
+
454
616
}
0 commit comments