diff --git a/docs/reference/rest-api/usage.asciidoc b/docs/reference/rest-api/usage.asciidoc index e10240a66fbb9..0268676dee5c6 100644 --- a/docs/reference/rest-api/usage.asciidoc +++ b/docs/reference/rest-api/usage.asciidoc @@ -200,7 +200,8 @@ GET /_xpack/usage "inference": { "available" : true, "enabled" : true, - "models" : [] + "models" : [], + "requests" : [] }, "logstash" : { "available" : true, diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 3e9234db6a87c..b3f39afb506c7 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -215,6 +215,7 @@ static TransportVersion def(int id) { public static final TransportVersion ML_INFERENCE_COHERE_UNUSED_RERANK_SETTINGS_REMOVED = def(8_706_00_0); public static final TransportVersion ENRICH_CACHE_STATS_SIZE_ADDED = def(8_707_00_0); public static final TransportVersion ENTERPRISE_GEOIP_DOWNLOADER = def(8_708_00_0); + public static final TransportVersion ML_INFERENCE_REQUEST_TELEMETRY_ADDED = def(8_709_00_0); /* * STOP! READ THIS FIRST! 
No, really, diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/InferenceFeatureSetUsage.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/InferenceFeatureSetUsage.java index 61409f59f9d85..aaa1ca0449c84 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/InferenceFeatureSetUsage.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/InferenceFeatureSetUsage.java @@ -19,6 +19,7 @@ import org.elasticsearch.xpack.core.XPackField; import java.io.IOException; +import java.util.ArrayList; import java.util.Collection; import java.util.Objects; @@ -102,27 +103,38 @@ public int hashCode() { } private final Collection modelStats; + private final Collection requestStats; - public InferenceFeatureSetUsage(Collection modelStats) { + public InferenceFeatureSetUsage(Collection modelStats, Collection requestStats) { super(XPackField.INFERENCE, true, true); - this.modelStats = modelStats; + this.modelStats = Objects.requireNonNull(modelStats); + this.requestStats = Objects.requireNonNull(requestStats); } public InferenceFeatureSetUsage(StreamInput in) throws IOException { super(in); this.modelStats = in.readCollectionAsList(ModelStats::new); + if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_REQUEST_TELEMETRY_ADDED)) { + this.requestStats = in.readCollectionAsList(InferenceRequestStats::new); + } else { + this.requestStats = new ArrayList<>(); + } } @Override protected void innerXContent(XContentBuilder builder, Params params) throws IOException { super.innerXContent(builder, params); builder.xContentList("models", modelStats); + builder.xContentList("requests", requestStats); } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeCollection(modelStats); + if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_REQUEST_TELEMETRY_ADDED)) { + out.writeCollection(requestStats); + } } 
@Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/InferenceRequestStats.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/InferenceRequestStats.java index 74d44b1a24173..cc9328dcdbd31 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/InferenceRequestStats.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/InferenceRequestStats.java @@ -18,6 +18,20 @@ import java.util.Objects; public class InferenceRequestStats implements SerializableStats { + + public static InferenceRequestStats merge(InferenceRequestStats stats1, InferenceRequestStats stats2) { + assert stats1.modelStats.service().equals(stats2.modelStats.service()) : "services do not match"; + assert stats1.modelStats.taskType().equals(stats2.modelStats.taskType()) : "task types do not match"; + assert Objects.equals(stats1.modelId, stats2.modelId) : "model ids do not match"; + + return new InferenceRequestStats( + stats1.modelStats().service(), + stats1.modelStats().taskType(), + stats1.modelId(), + stats1.modelStats().count() + stats2.modelStats().count() + ); + } + private final InferenceFeatureSetUsage.ModelStats modelStats; private final String modelId; @@ -35,6 +49,15 @@ public InferenceRequestStats(StreamInput in) throws IOException { this.modelId = in.readOptionalString(); } + public InferenceFeatureSetUsage.ModelStats modelStats() { + return modelStats; + } + + public String modelId() { + return modelId; + } + + @Override public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { builder.startObject(); builder.field("service", modelStats.service()); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInternalInferenceUsageAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInternalInferenceUsageAction.java new file mode 100644 index 
0000000000000..38ca69a3fc601 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInternalInferenceUsageAction.java @@ -0,0 +1,137 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.inference.action; + +import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.FailedNodeException; +import org.elasticsearch.action.support.nodes.BaseNodeResponse; +import org.elasticsearch.action.support.nodes.BaseNodesRequest; +import org.elasticsearch.action.support.nodes.BaseNodesResponse; +import org.elasticsearch.cluster.ClusterName; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.transport.TransportRequest; +import org.elasticsearch.xpack.core.inference.InferenceRequestStats; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +public class GetInternalInferenceUsageAction extends ActionType { + + public static final GetInternalInferenceUsageAction INSTANCE = new GetInternalInferenceUsageAction(); + public static final String NAME = "cluster:monitor/xpack/inference/internal_usage/get"; + + public GetInternalInferenceUsageAction() { + super(NAME); + } + + public static class Request extends BaseNodesRequest { + + public Request() { + super((String[]) null); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + return true; + } + + @Override + public int hashCode() { + // The class doesn't have any 
members at the moment so return the same hash code + return Objects.hash(NAME); + } + } + + public static class NodeRequest extends TransportRequest { + public NodeRequest(StreamInput in) throws IOException { + super(in); + } + + public NodeRequest() {} + } + + public static class Response extends BaseNodesResponse implements Writeable { + + public Response(StreamInput in) throws IOException { + super(in); + } + + public Response(ClusterName clusterName, List nodes, List failures) { + super(clusterName, nodes, failures); + } + + @Override + protected List readNodesFrom(StreamInput in) throws IOException { + return in.readCollectionAsList(NodeResponse::new); + } + + @Override + protected void writeNodesTo(StreamOutput out, List nodes) throws IOException { + out.writeCollection(nodes); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Response that = (Response) o; + return Objects.equals(getNodes(), that.getNodes()) && Objects.equals(failures(), that.failures()); + } + + @Override + public int hashCode() { + return Objects.hash(getNodes(), failures()); + } + } + + public static class NodeResponse extends BaseNodeResponse { + private final Map inferenceRequestStats; + + public NodeResponse(DiscoveryNode node, Map inferenceRequestStats) { + super(node); + this.inferenceRequestStats = Objects.requireNonNull(inferenceRequestStats); + } + + public NodeResponse(StreamInput in) throws IOException { + super(in); + + inferenceRequestStats = in.readImmutableMap(StreamInput::readString, InferenceRequestStats::new); + } + + public Map getInferenceRequestStats() { + return Collections.unmodifiableMap(inferenceRequestStats); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeMap(inferenceRequestStats, StreamOutput::writeString, StreamOutput::writeWriteable); + } + + @Override + public boolean equals(Object o) { + if (this == 
o) return true; + if (o == null || getClass() != o.getClass()) return false; + NodeResponse response = (NodeResponse) o; + return Objects.equals(inferenceRequestStats, response.inferenceRequestStats); + } + + @Override + public int hashCode() { + return Objects.hash(inferenceRequestStats); + } + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/InferenceRequestStatsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/InferenceRequestStatsTests.java index 518612b1b0397..da13d644db3eb 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/InferenceRequestStatsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/InferenceRequestStatsTests.java @@ -50,6 +50,46 @@ public void testToXContent_WritesModelId_WhenItIsDefined() throws IOException { {"service":"service","task_type":"text_embedding","count":2,"model_id":"model_id"}""")); } + public void testMerge_SumsCounts() { + var stats1 = new InferenceRequestStats("service", TaskType.TEXT_EMBEDDING, "model_id", 1); + var stats2 = new InferenceRequestStats("service", TaskType.TEXT_EMBEDDING, "model_id", 2); + + assertThat( + InferenceRequestStats.merge(stats1, stats2), + is(new InferenceRequestStats("service", TaskType.TEXT_EMBEDDING, "model_id", 3)) + ); + } + + public void testMerge_ThrowsAssertionExceptionWhenFieldsAreDifferent() { + // service names don't match + { + var stats1 = new InferenceRequestStats("service1", TaskType.TEXT_EMBEDDING, "model_id", 1); + var stats2 = new InferenceRequestStats("service2", TaskType.TEXT_EMBEDDING, "model_id", 2); + + var thrownException = expectThrows(AssertionError.class, () -> InferenceRequestStats.merge(stats1, stats2)); + + assertThat(thrownException.getMessage(), is("services do not match")); + } + // task types don't match + { + var stats1 = new InferenceRequestStats("service", TaskType.TEXT_EMBEDDING, "model_id", 1); + var stats2 = new 
InferenceRequestStats("service", TaskType.RERANK, "model_id", 2); + + var thrownException = expectThrows(AssertionError.class, () -> InferenceRequestStats.merge(stats1, stats2)); + + assertThat(thrownException.getMessage(), is("task types do not match")); + } + // model ids don't match + { + var stats1 = new InferenceRequestStats("service", TaskType.TEXT_EMBEDDING, "model_id1", 1); + var stats2 = new InferenceRequestStats("service", TaskType.TEXT_EMBEDDING, "model_id2", 2); + + var thrownException = expectThrows(AssertionError.class, () -> InferenceRequestStats.merge(stats1, stats2)); + + assertThat(thrownException.getMessage(), is("model ids do not match")); + } + } + @Override protected InferenceRequestStats mutateInstanceForVersion(InferenceRequestStats instance, TransportVersion version) { return instance; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInternalInferenceUsageActionNodeResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInternalInferenceUsageActionNodeResponseTests.java new file mode 100644 index 0000000000000..81e6433afb9f8 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInternalInferenceUsageActionNodeResponseTests.java @@ -0,0 +1,57 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.inference.action; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.cluster.node.DiscoveryNodeUtils; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.xpack.core.inference.InferenceRequestStats; +import org.elasticsearch.xpack.core.inference.InferenceRequestStatsTests; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; + +import java.io.IOException; +import java.util.HashMap; + +public class GetInternalInferenceUsageActionNodeResponseTests extends AbstractBWCWireSerializationTestCase< + GetInternalInferenceUsageAction.NodeResponse> { + public static GetInternalInferenceUsageAction.NodeResponse createRandom() { + DiscoveryNode node = DiscoveryNodeUtils.create("id"); + var stats = new HashMap(); + + for (int i = 0; i < randomIntBetween(1, 10); i++) { + stats.put(randomAlphaOfLength(10), InferenceRequestStatsTests.createRandom()); + } + + return new GetInternalInferenceUsageAction.NodeResponse(node, stats); + } + + @Override + protected Writeable.Reader instanceReader() { + return GetInternalInferenceUsageAction.NodeResponse::new; + } + + @Override + protected GetInternalInferenceUsageAction.NodeResponse createTestInstance() { + return createRandom(); + } + + @Override + protected GetInternalInferenceUsageAction.NodeResponse mutateInstance(GetInternalInferenceUsageAction.NodeResponse instance) + throws IOException { + return randomValueOtherThan(instance, this::createTestInstance); + } + + @Override + protected GetInternalInferenceUsageAction.NodeResponse mutateInstanceForVersion( + GetInternalInferenceUsageAction.NodeResponse instance, + TransportVersion version + ) { + return instance; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInternalInferenceUsageActionResponseTests.java 
b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInternalInferenceUsageActionResponseTests.java new file mode 100644 index 0000000000000..b0a69842f778e --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/GetInternalInferenceUsageActionResponseTests.java @@ -0,0 +1,53 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.inference.action; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.cluster.ClusterName; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; + +import java.io.IOException; +import java.util.List; + +public class GetInternalInferenceUsageActionResponseTests extends AbstractBWCWireSerializationTestCase< + GetInternalInferenceUsageAction.Response> { + public static GetInternalInferenceUsageAction.Response createRandom() { + List responses = randomList( + 2, + 10, + GetInternalInferenceUsageActionNodeResponseTests::createRandom + ); + + return new GetInternalInferenceUsageAction.Response(ClusterName.DEFAULT, responses, List.of()); + } + + @Override + protected Writeable.Reader instanceReader() { + return GetInternalInferenceUsageAction.Response::new; + } + + @Override + protected GetInternalInferenceUsageAction.Response createTestInstance() { + return createRandom(); + } + + @Override + protected GetInternalInferenceUsageAction.Response mutateInstance(GetInternalInferenceUsageAction.Response instance) + throws IOException { + return randomValueOtherThan(instance, this::createTestInstance); + } + + @Override + protected GetInternalInferenceUsageAction.Response mutateInstanceForVersion( + GetInternalInferenceUsageAction.Response 
instance, + TransportVersion version + ) { + return instance; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index fce2c54c535c9..c5cdfde7d21d6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -43,11 +43,13 @@ import org.elasticsearch.xpack.core.inference.action.DeleteInferenceEndpointAction; import org.elasticsearch.xpack.core.inference.action.GetInferenceDiagnosticsAction; import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; +import org.elasticsearch.xpack.core.inference.action.GetInternalInferenceUsageAction; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; import org.elasticsearch.xpack.inference.action.TransportDeleteInferenceEndpointAction; import org.elasticsearch.xpack.inference.action.TransportGetInferenceDiagnosticsAction; import org.elasticsearch.xpack.inference.action.TransportGetInferenceModelAction; +import org.elasticsearch.xpack.inference.action.TransportGetInternalInferenceUsageAction; import org.elasticsearch.xpack.inference.action.TransportInferenceAction; import org.elasticsearch.xpack.inference.action.TransportInferenceUsageAction; import org.elasticsearch.xpack.inference.action.TransportPutInferenceModelAction; @@ -84,8 +86,7 @@ import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserService; import org.elasticsearch.xpack.inference.services.mistral.MistralService; import org.elasticsearch.xpack.inference.services.openai.OpenAiService; -import org.elasticsearch.xpack.inference.telemetry.InferenceAPMStats; -import org.elasticsearch.xpack.inference.telemetry.StatsMap; +import 
org.elasticsearch.xpack.inference.telemetry.InferenceRequestStatsMap; import java.util.ArrayList; import java.util.Collection; @@ -140,7 +141,8 @@ public InferencePlugin(Settings settings) { new ActionHandler<>(PutInferenceModelAction.INSTANCE, TransportPutInferenceModelAction.class), new ActionHandler<>(DeleteInferenceEndpointAction.INSTANCE, TransportDeleteInferenceEndpointAction.class), new ActionHandler<>(XPackUsageFeatureAction.INFERENCE, TransportInferenceUsageAction.class), - new ActionHandler<>(GetInferenceDiagnosticsAction.INSTANCE, TransportGetInferenceDiagnosticsAction.class) + new ActionHandler<>(GetInferenceDiagnosticsAction.INSTANCE, TransportGetInferenceDiagnosticsAction.class), + new ActionHandler<>(GetInternalInferenceUsageAction.INSTANCE, TransportGetInternalInferenceUsageAction.class) ); } @@ -196,9 +198,7 @@ public Collection createComponents(PluginServices services) { var actionFilter = new ShardBulkInferenceActionFilter(registry, modelRegistry); shardBulkInferenceActionFilter.set(actionFilter); - var statsFactory = new InferenceAPMStats.Factory(services.telemetryProvider().getMeterRegistry()); - var statsMap = new StatsMap<>(InferenceAPMStats::key, statsFactory::newInferenceRequestAPMCounter); - + var statsMap = InferenceRequestStatsMap.of(services.telemetryProvider().getMeterRegistry()); return List.of(modelRegistry, registry, httpClientManager, statsMap); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInternalInferenceUsageAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInternalInferenceUsageAction.java new file mode 100644 index 0000000000000..dae2bd6a2d742 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInternalInferenceUsageAction.java @@ -0,0 +1,78 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.action; + +import org.elasticsearch.action.FailedNodeException; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.nodes.TransportNodesAction; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.inference.action.GetInternalInferenceUsageAction; +import org.elasticsearch.xpack.inference.telemetry.InferenceRequestStatsMap; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public class TransportGetInternalInferenceUsageAction extends TransportNodesAction< + GetInternalInferenceUsageAction.Request, + GetInternalInferenceUsageAction.Response, + GetInternalInferenceUsageAction.NodeRequest, + GetInternalInferenceUsageAction.NodeResponse> { + + private final InferenceRequestStatsMap statsMap; + + @Inject + public TransportGetInternalInferenceUsageAction( + ThreadPool threadPool, + ClusterService clusterService, + TransportService transportService, + ActionFilters actionFilters, + InferenceRequestStatsMap statsMap + ) { + super( + GetInternalInferenceUsageAction.NAME, + clusterService, + transportService, + actionFilters, + GetInternalInferenceUsageAction.NodeRequest::new, + threadPool.executor(ThreadPool.Names.MANAGEMENT) + ); + + this.statsMap = Objects.requireNonNull(statsMap); + } + + @Override + protected GetInternalInferenceUsageAction.Response newResponse( + GetInternalInferenceUsageAction.Request request, + List nodeResponses, 
+ List failures + ) { + return new GetInternalInferenceUsageAction.Response(clusterService.getClusterName(), nodeResponses, failures); + } + + @Override + protected GetInternalInferenceUsageAction.NodeRequest newNodeRequest(GetInternalInferenceUsageAction.Request request) { + return new GetInternalInferenceUsageAction.NodeRequest(); + } + + @Override + protected GetInternalInferenceUsageAction.NodeResponse newNodeResponse(StreamInput in, DiscoveryNode node) throws IOException { + return new GetInternalInferenceUsageAction.NodeResponse(in); + } + + @Override + protected GetInternalInferenceUsageAction.NodeResponse nodeOperation(GetInternalInferenceUsageAction.NodeRequest request, Task task) { + return new GetInternalInferenceUsageAction.NodeResponse(transportService.getLocalNode(), statsMap.toSerializableMap()); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java index 575697b5d0d39..f56420a627547 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java @@ -21,22 +21,28 @@ import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.telemetry.InferenceRequestStatsMap; + +import java.util.Objects; public class TransportInferenceAction extends HandledTransportAction { private final ModelRegistry modelRegistry; private final InferenceServiceRegistry serviceRegistry; + private final InferenceRequestStatsMap statsMap; @Inject public TransportInferenceAction( TransportService transportService, ActionFilters actionFilters, ModelRegistry 
modelRegistry, - InferenceServiceRegistry serviceRegistry + InferenceServiceRegistry serviceRegistry, + InferenceRequestStatsMap statsMap ) { super(InferenceAction.NAME, transportService, actionFilters, InferenceAction.Request::new, EsExecutors.DIRECT_EXECUTOR_SERVICE); - this.modelRegistry = modelRegistry; - this.serviceRegistry = serviceRegistry; + this.modelRegistry = Objects.requireNonNull(modelRegistry); + this.serviceRegistry = Objects.requireNonNull(serviceRegistry); + this.statsMap = Objects.requireNonNull(statsMap); } @Override @@ -76,6 +82,7 @@ protected void doExecute(Task task, InferenceAction.Request request, ActionListe unparsedModel.settings(), unparsedModel.secrets() ); + statsMap.increment(model); inferOnService(model, request, service.get(), delegate); }); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceUsageAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceUsageAction.java index 712cb1ebad781..e115a684dac2d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceUsageAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceUsageAction.java @@ -9,6 +9,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.RefCountingListener; import org.elasticsearch.client.internal.Client; import org.elasticsearch.client.internal.OriginSettingClient; import org.elasticsearch.cluster.ClusterState; @@ -25,10 +26,14 @@ import org.elasticsearch.xpack.core.action.XPackUsageFeatureResponse; import org.elasticsearch.xpack.core.action.XPackUsageFeatureTransportAction; import org.elasticsearch.xpack.core.inference.InferenceFeatureSetUsage; +import org.elasticsearch.xpack.core.inference.InferenceRequestStats; import 
org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; +import org.elasticsearch.xpack.core.inference.action.GetInternalInferenceUsageAction; +import java.util.Collection; import java.util.Map; import java.util.TreeMap; +import java.util.concurrent.atomic.AtomicReference; import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; @@ -63,9 +68,32 @@ protected void masterOperation( ClusterState state, ActionListener listener ) { + var modelStatsRef = new AtomicReference>(); + var requestStatsRef = new AtomicReference>(); + + try ( + var listeners = new RefCountingListener( + // buildFeatureResponse will be called when the RefCountingListener is closed + listener.map(ignored -> buildFeatureResponse(modelStatsRef.get(), requestStatsRef.get())) + ) + ) { + getModelStats(listeners.acquire(modelStatsRef::set)); + getRequestStats(listeners.acquire(requestStatsRef::set)); + } + } + + private XPackUsageFeatureResponse buildFeatureResponse( + Collection modelStats, + Collection requestStats + ) { + return new XPackUsageFeatureResponse(new InferenceFeatureSetUsage(modelStats, requestStats)); + } + + private void getModelStats(ActionListener> listener) { GetInferenceModelAction.Request getInferenceModelAction = new GetInferenceModelAction.Request("_all", TaskType.ANY); client.execute(GetInferenceModelAction.INSTANCE, getInferenceModelAction, listener.delegateFailureAndWrap((delegate, response) -> { Map stats = new TreeMap<>(); + for (ModelConfigurations model : response.getEndpoints()) { String statKey = model.getService() + ":" + model.getTaskType().name(); InferenceFeatureSetUsage.ModelStats stat = stats.computeIfAbsent( @@ -74,8 +102,21 @@ protected void masterOperation( ); stat.add(); } - InferenceFeatureSetUsage usage = new InferenceFeatureSetUsage(stats.values()); - delegate.onResponse(new XPackUsageFeatureResponse(usage)); + + delegate.onResponse(stats.values()); + })); + } + + private void getRequestStats(ActionListener> listener) { + var action = 
new GetInternalInferenceUsageAction.Request(); + client.execute(GetInternalInferenceUsageAction.INSTANCE, action, listener.delegateFailureAndWrap((delegate, response) -> { + var accumulatedStats = new TreeMap(); + + for (var node : response.getNodes()) { + node.getInferenceRequestStats().forEach((key, value) -> accumulatedStats.merge(key, value, InferenceRequestStats::merge)); + } + + delegate.onResponse(accumulatedStats.values()); })); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java index 538d88a59ca76..1e4d2d7524230 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java @@ -48,7 +48,7 @@ public CohereEmbeddingsModel( // should only be used for testing CohereEmbeddingsModel( - String modelId, + String inferenceId, TaskType taskType, String service, CohereEmbeddingsServiceSettings serviceSettings, @@ -56,7 +56,7 @@ public CohereEmbeddingsModel( @Nullable DefaultSecretSettings secretSettings ) { super( - new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings), + new ModelConfigurations(inferenceId, taskType, service, serviceSettings, taskSettings), new ModelSecrets(secretSettings), secretSettings, serviceSettings.getCommonSettings() diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettings.java index d474e935fbda7..6ef1f6f0feefe 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettings.java @@ -150,7 +150,7 @@ public OpenAiEmbeddingsServiceSettings( @Nullable RateLimitSettings rateLimitSettings ) { this.uri = uri; - this.modelId = modelId; + this.modelId = Objects.requireNonNull(modelId); this.organizationId = organizationId; this.similarity = similarity; this.dimensions = dimensions; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/InferenceAPMStats.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/InferenceAPMStats.java index 76977fef76045..8016c2206798c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/InferenceAPMStats.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/InferenceAPMStats.java @@ -11,37 +11,48 @@ import org.elasticsearch.telemetry.metric.LongCounter; import org.elasticsearch.telemetry.metric.MeterRegistry; +import java.util.HashMap; import java.util.Map; import java.util.Objects; public class InferenceAPMStats extends InferenceStats { - private final LongCounter inferenceAPMRequestCounter; + private final LongCounter requestCounter; - public InferenceAPMStats(Model model, MeterRegistry meterRegistry) { + public InferenceAPMStats(Model model, LongCounter requestCounter) { super(model); - this.inferenceAPMRequestCounter = meterRegistry.registerLongCounter( - "es.inference.requests.count", - "Inference API request counts for a particular service, task type, model ID", - "operations" - ); + this.requestCounter = Objects.requireNonNull(requestCounter); } @Override public void increment() { super.increment(); - inferenceAPMRequestCounter.incrementBy(1, Map.of("service", service, "task_type", 
taskType.toString(), "model_id", modelId)); + var attributes = new HashMap(Map.of("service", service, "task_type", taskType.toString())); + + if (modelId != null) { + attributes.put("model_id", modelId); + } + + requestCounter.incrementBy(1, attributes); } public static final class Factory { - private final MeterRegistry meterRegistry; + private final LongCounter requestCounter; public Factory(MeterRegistry meterRegistry) { - this.meterRegistry = Objects.requireNonNull(meterRegistry); + Objects.requireNonNull(meterRegistry); + + // A meter with a specific name can only be registered once + this.requestCounter = meterRegistry.registerLongCounter( + // We get an error if the name doesn't end with a specific value, total is a valid option + "es.inference.requests.count.total", + "Inference API request counts for a particular service, task type, model ID", + "operations" + ); } public InferenceAPMStats newInferenceRequestAPMCounter(Model model) { - return new InferenceAPMStats(model, meterRegistry); + return new InferenceAPMStats(model, requestCounter); } } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/InferenceRequestStatsMap.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/InferenceRequestStatsMap.java new file mode 100644 index 0000000000000..fd9bd96aeb305 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/InferenceRequestStatsMap.java @@ -0,0 +1,27 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.telemetry; + +import org.elasticsearch.inference.Model; +import org.elasticsearch.telemetry.metric.MeterRegistry; +import org.elasticsearch.xpack.core.inference.InferenceRequestStats; + +/** + * The purpose of this class is to get around an issue with {@link org.elasticsearch.common.inject.Inject} that doesn't seem to allow + * generics in the constructor. Subclassing it here seems to work. + */ +public class InferenceRequestStatsMap extends StatsMap { + public static InferenceRequestStatsMap of(MeterRegistry meterRegistry) { + var statsFactory = new InferenceAPMStats.Factory(meterRegistry); + return new InferenceRequestStatsMap(statsFactory); + } + + private InferenceRequestStatsMap(InferenceAPMStats.Factory factory) { + super(InferenceAPMStats::key, factory::newInferenceRequestAPMCounter); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/InferenceStats.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/InferenceStats.java index d639f9da71f56..de0a68f0fb10b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/InferenceStats.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/InferenceStats.java @@ -14,7 +14,7 @@ import java.util.Objects; import java.util.concurrent.atomic.LongAdder; -public class InferenceStats implements Stats { +public class InferenceStats implements Stats { protected final String service; protected final TaskType taskType; protected final String modelId; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/Stats.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/Stats.java index bb1e9c98fc2cb..630939226f327 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/Stats.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/Stats.java @@ -9,7 +9,7 @@ import org.elasticsearch.xpack.core.inference.SerializableStats; -public interface Stats { +public interface Stats { /** * Increase the counter by one. @@ -26,5 +26,5 @@ public interface Stats { * Convert the object into a serializable form that can be written across nodes and returned in xcontent format. * @return the serializable format of the object */ - SerializableStats toSerializableForm(); + T toSerializableForm(); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/StatsMap.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/StatsMap.java index 1cfecfb4507d6..da1bfe1fd3994 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/StatsMap.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/StatsMap.java @@ -21,8 +21,9 @@ * * @param The input to derive the keys and values for the map * @param The type of the values stored in the map + * @param The type that {@link Value} will convert to in order to serialize across nodes and in a response */ -public class StatsMap { +public class StatsMap, SerializableType extends SerializableStats> { private final ConcurrentMap stats = new ConcurrentHashMap<>(); private final Function keyCreator; @@ -51,7 +52,7 @@ public void increment(Input input) { * be represented in the resulting serializable map. 
* @return a map that is more easily serializable */ - public Map toSerializableMap() { + public Map toSerializableMap() { return stats.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().toSerializableForm())); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceUsageActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceUsageActionTests.java index b0c59fe160be3..891d61f48b24d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceUsageActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceUsageActionTests.java @@ -11,8 +11,10 @@ import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.client.internal.Client; +import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.node.DiscoveryNodeUtils; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.io.stream.BytesStreamOutput; import org.elasticsearch.inference.ModelConfigurations; @@ -32,12 +34,15 @@ import org.elasticsearch.xpack.core.XPackField; import org.elasticsearch.xpack.core.action.XPackUsageFeatureResponse; import org.elasticsearch.xpack.core.inference.InferenceFeatureSetUsage; +import org.elasticsearch.xpack.core.inference.InferenceRequestStats; import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; +import org.elasticsearch.xpack.core.inference.action.GetInternalInferenceUsageAction; import org.elasticsearch.xpack.core.watcher.support.xcontent.XContentSource; import org.junit.After; import org.junit.Before; import java.util.List; +import java.util.Map; 
import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.core.Is.is; @@ -74,7 +79,7 @@ public void close() { client.threadPool().shutdown(); } - public void test() throws Exception { + public void testModelsResponse() throws Exception { doAnswer(invocation -> { @SuppressWarnings("unchecked") var listener = (ActionListener) invocation.getArguments()[2]; @@ -93,6 +98,15 @@ public void test() throws Exception { return Void.TYPE; }).when(client).execute(any(GetInferenceModelAction.class), any(), any()); + doAnswer(invocation -> { + @SuppressWarnings("unchecked") + var listener = (ActionListener) invocation.getArguments()[2]; + + listener.onResponse(new GetInternalInferenceUsageAction.Response(ClusterName.DEFAULT, List.of(), List.of())); + + return Void.TYPE; + }).when(client).execute(any(GetInternalInferenceUsageAction.class), any(), any()); + PlainActionFuture future = new PlainActionFuture<>(); action.masterOperation(mock(Task.class), mock(XPackUsageRequest.class), mock(ClusterState.class), future); @@ -118,4 +132,85 @@ public void test() throws Exception { assertThat(source.getValue("models.2.task_type"), is("TEXT_EMBEDDING")); assertThat(source.getValue("models.2.count"), is(3)); } + + public void testInferenceRequestStats() throws Exception { + doAnswer(invocation -> { + @SuppressWarnings("unchecked") + var listener = (ActionListener) invocation.getArguments()[2]; + listener.onResponse( + new GetInferenceModelAction.Response( + List.of(new ModelConfigurations("model-001", TaskType.TEXT_EMBEDDING, "openai", mock(ServiceSettings.class))) + ) + ); + return Void.TYPE; + }).when(client).execute(any(GetInferenceModelAction.class), any(), any()); + + doAnswer(invocation -> { + @SuppressWarnings("unchecked") + var listener = (ActionListener) invocation.getArguments()[2]; + + // Count for openai:text_embedding:model1 is 3 total for the whole list + var nodeResponses = List.of( + new GetInternalInferenceUsageAction.NodeResponse( + 
DiscoveryNodeUtils.create("id1"), + Map.of("openai:text_embedding:model1", new InferenceRequestStats("openai", TaskType.TEXT_EMBEDDING, "model1", 1)) + ), + new GetInternalInferenceUsageAction.NodeResponse( + DiscoveryNodeUtils.create("z"), + // this should be ordered last in the resulting list + Map.of("z_service:text_embedding:model1", new InferenceRequestStats("z_service", TaskType.TEXT_EMBEDDING, "model1", 1)) + ), + new GetInternalInferenceUsageAction.NodeResponse( + DiscoveryNodeUtils.create("id2"), + Map.of("openai:text_embedding:model1", new InferenceRequestStats("openai", TaskType.TEXT_EMBEDDING, "model1", 1)) + ), + new GetInternalInferenceUsageAction.NodeResponse( + DiscoveryNodeUtils.create("id3"), + Map.of( + "cohere:text_embedding:model1", + new InferenceRequestStats("cohere", TaskType.TEXT_EMBEDDING, "model1", 1), + "openai:text_embedding:model1", + new InferenceRequestStats("openai", TaskType.TEXT_EMBEDDING, "model1", 1) + ) + ), + new GetInternalInferenceUsageAction.NodeResponse( + DiscoveryNodeUtils.create("id4"), + // this should be ordered first in the resulting list + Map.of("a_service:text_embedding:model1", new InferenceRequestStats("a_service", TaskType.TEXT_EMBEDDING, "model1", 1)) + ) + ); + + listener.onResponse(new GetInternalInferenceUsageAction.Response(ClusterName.DEFAULT, nodeResponses, List.of())); + + return Void.TYPE; + }).when(client).execute(any(GetInternalInferenceUsageAction.class), any(), any()); + + PlainActionFuture future = new PlainActionFuture<>(); + action.masterOperation(mock(Task.class), mock(XPackUsageRequest.class), mock(ClusterState.class), future); + + BytesStreamOutput out = new BytesStreamOutput(); + future.get().getUsage().writeTo(out); + XPackFeatureSet.Usage usage = new InferenceFeatureSetUsage(out.bytes().streamInput()); + + assertThat(usage.name(), is(XPackField.INFERENCE)); + assertTrue(usage.enabled()); + assertTrue(usage.available()); + + XContentBuilder builder = XContentFactory.jsonBuilder(); + 
usage.toXContent(builder, ToXContent.EMPTY_PARAMS); + XContentSource source = new XContentSource(builder); + + assertThat(source.getAsMap().get("models"), is(List.of(Map.of("service", "openai", "task_type", "TEXT_EMBEDDING", "count", 1)))); + assertThat( + source.getAsMap().get("requests"), + is( + List.of( + Map.of("service", "a_service", "task_type", TaskType.TEXT_EMBEDDING.toString(), "model_id", "model1", "count", 1), + Map.of("service", "cohere", "task_type", TaskType.TEXT_EMBEDDING.toString(), "model_id", "model1", "count", 1), + Map.of("service", "openai", "task_type", TaskType.TEXT_EMBEDDING.toString(), "model_id", "model1", "count", 3), + Map.of("service", "z_service", "task_type", TaskType.TEXT_EMBEDDING.toString(), "model_id", "model1", "count", 1) + ) + ) + ); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java index 093283c0b37d6..6942bd9e9c766 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java @@ -213,17 +213,7 @@ public static CohereEmbeddingsModel createModel( @Nullable String model, @Nullable CohereEmbeddingType embeddingType ) { - return new CohereEmbeddingsModel( - "id", - TaskType.TEXT_EMBEDDING, - "service", - new CohereEmbeddingsServiceSettings( - new CohereServiceSettings(url, SimilarityMeasure.DOT_PRODUCT, dimensions, tokenLimit, model, null), - Objects.requireNonNullElse(embeddingType, CohereEmbeddingType.FLOAT) - ), - taskSettings, - new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) - ); + return createModel(url, apiKey, taskSettings, tokenLimit, dimensions, model, 
embeddingType, SimilarityMeasure.DOT_PRODUCT); } public static CohereEmbeddingsModel createModel( @@ -239,7 +229,7 @@ public static CohereEmbeddingsModel createModel( return new CohereEmbeddingsModel( "id", TaskType.TEXT_EMBEDDING, - "service", + "cohere", new CohereEmbeddingsServiceSettings( new CohereServiceSettings(url, similarityMeasure, dimensions, tokenLimit, model, null), Objects.requireNonNullElse(embeddingType, CohereEmbeddingType.FLOAT) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/telemetry/InferenceAPMStatsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/telemetry/InferenceAPMStatsTests.java new file mode 100644 index 0000000000000..81541bd392720 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/telemetry/InferenceAPMStatsTests.java @@ -0,0 +1,41 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.telemetry; + +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.telemetry.metric.LongCounter; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModelTests; + +import java.util.Map; + +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +public class InferenceAPMStatsTests extends ESTestCase { + public void testIncrement_DoesNotThrowWhenModelIdIsNull() { + var counterMock = mock(LongCounter.class); + var stats = new InferenceAPMStats(CohereEmbeddingsModelTests.createModel("url", "api-key", null, null, null), counterMock); + + stats.increment(); + verify(counterMock, times(1)).incrementBy(eq(1L), eq(Map.of("service", "cohere", "task_type", TaskType.TEXT_EMBEDDING.toString()))); + } + + public void testIncrement_DoesNotThrowWhenModelIdIsNotNull() { + var counterMock = mock(LongCounter.class); + var stats = new InferenceAPMStats(CohereEmbeddingsModelTests.createModel("url", "api-key", null, "model_a", null), counterMock); + + stats.increment(); + verify(counterMock, times(1)).incrementBy( + eq(1L), + eq(Map.of("service", "cohere", "task_type", TaskType.TEXT_EMBEDDING.toString(), "model_id", "model_a")) + ); + } +} diff --git a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java index 9eee5b0bd7a6f..83cca44c54cbf 100644 --- a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java +++ b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java @@ 
-384,6 +384,7 @@ public class Constants { "cluster:monitor/xpack/inference", "cluster:monitor/xpack/inference/get", "cluster:monitor/xpack/inference/diagnostics/get", + "cluster:monitor/xpack/inference/internal_usage/get", "cluster:monitor/xpack/info", "cluster:monitor/xpack/info/aggregate_metric", "cluster:monitor/xpack/info/analytics",