Skip to content

Commit 7950536

Browse files
committed
fix the tests
1 parent 11e8aad commit 7950536

File tree

19 files changed

+113
-40
lines changed

19 files changed

+113
-40
lines changed

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterBasicLicenseIT.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ public void setup() throws Exception {
6363
ModelRegistry modelRegistry = internalCluster().getCurrentMasterNodeInstance(ModelRegistry.class);
6464
Utils.storeSparseModel("sparse-endpoint", modelRegistry);
6565
Utils.storeDenseModel(
66+
"dense-endpoint",
6667
modelRegistry,
6768
randomIntBetween(1, 100),
6869
// dot product means that we need normalized vectors; it's not worth doing that in this test

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ public void setup() throws Exception {
9191
);
9292
int dimensions = DenseVectorFieldMapperTestUtils.randomCompatibleDimensions(elementType, 100);
9393
Utils.storeSparseModel("sparse-endpoint", modelRegistry);
94-
Utils.storeDenseModel(modelRegistry, dimensions, similarity, elementType);
94+
Utils.storeDenseModel("dense-endpoint", modelRegistry, dimensions, similarity, elementType);
9595
}
9696

9797
@Override
@@ -122,27 +122,20 @@ public Settings indexSettings() {
122122
}
123123

124124
public void testBulkOperations() throws Exception {
125-
prepareCreate(INDEX_NAME).setMapping(
126-
String.format(
127-
Locale.ROOT,
128-
"""
129-
{
130-
"properties": {
131-
"sparse_field": {
132-
"type": "semantic_text",
133-
"inference_id": "%s"
134-
},
135-
"dense_field": {
136-
"type": "semantic_text",
137-
"inference_id": "%s"
138-
}
139-
}
125+
prepareCreate(INDEX_NAME).setMapping(String.format(Locale.ROOT, """
126+
{
127+
"properties": {
128+
"sparse_field": {
129+
"type": "semantic_text",
130+
"inference_id": "%s"
131+
},
132+
"dense_field": {
133+
"type": "semantic_text",
134+
"inference_id": "%s"
140135
}
141-
""",
142-
TestSparseInferenceServiceExtension.TestInferenceService.NAME,
143-
TestDenseInferenceServiceExtension.TestInferenceService.NAME
144-
)
145-
).get();
136+
}
137+
}
138+
""", "sparse-endpoint", "dense-endpoint")).get();
146139
assertRandomBulkOperations(INDEX_NAME, isIndexRequest -> {
147140
Map<String, Object> map = new HashMap<>();
148141
map.put("sparse_field", isIndexRequest && rarely() ? null : randomSemanticTextInput());

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/RerankWindowSizeIT.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import org.elasticsearch.ElasticsearchStatusException;
1111
import org.elasticsearch.plugins.Plugin;
1212
import org.elasticsearch.test.ESIntegTestCase;
13+
import org.elasticsearch.test.ESTestCase;
1314
import org.elasticsearch.xpack.core.inference.action.GetRerankerAction;
1415
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
1516
import org.elasticsearch.xpack.inference.Utils;
@@ -22,6 +23,7 @@
2223

2324
import static org.hamcrest.Matchers.containsString;
2425

26+
@ESTestCase.WithoutEntitlements // due to dependency issue ES-12435
2527
public class RerankWindowSizeIT extends ESIntegTestCase {
2628

2729
@Before
@@ -41,7 +43,7 @@ public void testRerankWindowSizeAction() {
4143
assertEquals(333, response.getWindowSize());
4244
}
4345

44-
public void testActionNotARerankder() {
46+
public void testActionNotAReranker() {
4547
var e = expectThrows(
4648
ElasticsearchStatusException.class,
4749
() -> client().execute(GetRerankerAction.INSTANCE, new GetRerankerAction.Request("sparse-endpoint")).actionGet()

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextIndexVersionIT.java

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,6 @@
3131
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
3232
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
3333
import org.elasticsearch.xpack.inference.Utils;
34-
import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension;
35-
import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension;
3634
import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder;
3735
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
3836
import org.junit.Before;
@@ -69,7 +67,7 @@ public void setup() throws Exception {
6967
);
7068
int dimensions = DenseVectorFieldMapperTestUtils.randomCompatibleDimensions(elementType, 100);
7169
Utils.storeSparseModel("sparse-endpoint", modelRegistry);
72-
Utils.storeDenseModel(modelRegistry, dimensions, similarity, elementType);
70+
Utils.storeDenseModel("dense-endpoint", modelRegistry, dimensions, similarity, elementType);
7371

7472
Set<IndexVersion> availableVersions = IndexVersionUtils.allReleasedVersions()
7573
.stream()
@@ -113,11 +111,11 @@ public void testSemanticText() throws Exception {
113111
.startObject("properties")
114112
.startObject(SPARSE_SEMANTIC_FIELD)
115113
.field("type", "semantic_text")
116-
.field("inference_id", TestSparseInferenceServiceExtension.TestInferenceService.NAME)
114+
.field("inference_id", "sparse-endpoint")
117115
.endObject()
118116
.startObject(DENSE_SEMANTIC_FIELD)
119117
.field("type", "semantic_text")
120-
.field("inference_id", TestDenseInferenceServiceExtension.TestInferenceService.NAME)
118+
.field("inference_id", "dense-endpoint")
121119
.endObject()
122120
.endObject()
123121
.endObject();

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.elasticsearch.inference.Model;
2727
import org.elasticsearch.inference.ModelConfigurations;
2828
import org.elasticsearch.inference.ModelSecrets;
29+
import org.elasticsearch.inference.RerankingInferenceService;
2930
import org.elasticsearch.inference.SettingsConfiguration;
3031
import org.elasticsearch.inference.SimilarityMeasure;
3132
import org.elasticsearch.inference.TaskType;
@@ -69,7 +70,7 @@
6970
import static org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceSettings.SERVICE_ID;
7071
import static org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceSettings.WORKSPACE_NAME;
7172

72-
public class AlibabaCloudSearchService extends SenderService {
73+
public class AlibabaCloudSearchService extends SenderService implements RerankingInferenceService {
7374
public static final String NAME = AlibabaCloudSearchUtils.SERVICE_NAME;
7475
private static final String SERVICE_NAME = "AlibabaCloud AI Search";
7576

@@ -390,6 +391,11 @@ public TransportVersion getMinimalSupportedVersion() {
390391
return TransportVersions.V_8_16_0;
391392
}
392393

394+
@Override
395+
public int rerankerWindowSize(String modelId) {
396+
return RerankingInferenceService.LARGE_WINDOW_SIZE;
397+
}
398+
393399
public static class Configuration {
394400
public static InferenceServiceConfiguration get() {
395401
return configuration.getOrCompute();

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.elasticsearch.inference.Model;
2727
import org.elasticsearch.inference.ModelConfigurations;
2828
import org.elasticsearch.inference.ModelSecrets;
29+
import org.elasticsearch.inference.RerankingInferenceService;
2930
import org.elasticsearch.inference.SettingsConfiguration;
3031
import org.elasticsearch.inference.SimilarityMeasure;
3132
import org.elasticsearch.inference.TaskType;
@@ -64,7 +65,7 @@
6465
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
6566
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation;
6667

67-
public class CustomService extends SenderService {
68+
public class CustomService extends SenderService implements RerankingInferenceService {
6869

6970
public static final String NAME = "custom";
7071
private static final String SERVICE_NAME = "Custom";
@@ -366,6 +367,11 @@ public boolean hideFromConfigurationApi() {
366367
return true;
367368
}
368369

370+
@Override
371+
public int rerankerWindowSize(String modelId) {
372+
return RerankingInferenceService.CONSERVATIVE_DEFAULT_WINDOW_SIZE;
373+
}
374+
369375
public static class Configuration {
370376
public static InferenceServiceConfiguration get() {
371377
return configuration.getOrCompute();

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,7 @@ private static GoogleVertexAiModel createModel(
386386

387387
@Override
388388
public int rerankerWindowSize(String modelId) {
389-
return 0; // TODO
389+
return RerankingInferenceService.CONSERVATIVE_DEFAULT_WINDOW_SIZE;
390390
}
391391

392392
public static class Configuration {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import org.elasticsearch.inference.InferenceServiceResults;
2121
import org.elasticsearch.inference.InputType;
2222
import org.elasticsearch.inference.Model;
23+
import org.elasticsearch.inference.RerankingInferenceService;
2324
import org.elasticsearch.inference.SettingsConfiguration;
2425
import org.elasticsearch.inference.SimilarityMeasure;
2526
import org.elasticsearch.inference.TaskType;
@@ -57,7 +58,7 @@
5758
* This class is responsible for managing the Hugging Face inference service.
5859
* It manages model creation, as well as chunked, non-chunked, and unified completion inference.
5960
*/
60-
public class HuggingFaceService extends HuggingFaceBaseService {
61+
public class HuggingFaceService extends HuggingFaceBaseService implements RerankingInferenceService {
6162
public static final String NAME = "hugging_face";
6263

6364
private static final String SERVICE_NAME = "Hugging Face";
@@ -228,6 +229,11 @@ public TransportVersion getMinimalSupportedVersion() {
228229
return TransportVersions.V_8_15_0;
229230
}
230231

232+
@Override
233+
public int rerankerWindowSize(String modelId) {
234+
return RerankingInferenceService.CONSERVATIVE_DEFAULT_WINDOW_SIZE;
235+
}
236+
231237
public static class Configuration {
232238
public static InferenceServiceConfiguration get() {
233239
return configuration.getOrCompute();

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.elasticsearch.inference.Model;
2626
import org.elasticsearch.inference.ModelConfigurations;
2727
import org.elasticsearch.inference.ModelSecrets;
28+
import org.elasticsearch.inference.RerankingInferenceService;
2829
import org.elasticsearch.inference.SettingsConfiguration;
2930
import org.elasticsearch.inference.SimilarityMeasure;
3031
import org.elasticsearch.inference.TaskType;
@@ -63,7 +64,7 @@
6364
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation;
6465
import static org.elasticsearch.xpack.inference.services.jinaai.JinaAIServiceFields.EMBEDDING_MAX_BATCH_SIZE;
6566

66-
public class JinaAIService extends SenderService {
67+
public class JinaAIService extends SenderService implements RerankingInferenceService {
6768
public static final String NAME = "jinaai";
6869

6970
private static final String SERVICE_NAME = "Jina AI";
@@ -347,6 +348,11 @@ public TransportVersion getMinimalSupportedVersion() {
347348
return TransportVersions.JINA_AI_INTEGRATION_ADDED;
348349
}
349350

351+
@Override
352+
public int rerankerWindowSize(String modelId) {
353+
return RerankingInferenceService.LARGE_WINDOW_SIZE;
354+
}
355+
350356
public static class Configuration {
351357
public static InferenceServiceConfiguration get() {
352358
return configuration.getOrCompute();

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.elasticsearch.inference.InferenceServiceResults;
2626
import org.elasticsearch.inference.InputType;
2727
import org.elasticsearch.inference.Model;
28+
import org.elasticsearch.inference.RerankingInferenceService;
2829
import org.elasticsearch.inference.SettingsConfiguration;
2930
import org.elasticsearch.inference.TaskType;
3031
import org.elasticsearch.inference.UnifiedCompletionRequest;
@@ -48,7 +49,7 @@
4849
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
4950
import static org.elasticsearch.xpack.inference.services.ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails;
5051

51-
public class SageMakerService implements InferenceService {
52+
public class SageMakerService implements InferenceService, RerankingInferenceService {
5253
public static final String NAME = "amazon_sagemaker";
5354
private static final String DISPLAY_NAME = "Amazon SageMaker";
5455
private static final List<String> ALIASES = List.of("sagemaker", "amazonsagemaker");
@@ -328,4 +329,9 @@ public TransportVersion getMinimalSupportedVersion() {
328329
public void close() throws IOException {
329330
client.close();
330331
}
332+
333+
@Override
334+
public int rerankerWindowSize(String modelId) {
335+
return RerankingInferenceService.CONSERVATIVE_DEFAULT_WINDOW_SIZE;
336+
}
331337
}

0 commit comments

Comments
 (0)