Skip to content

[ML] Inference request count telemetry per node #110947

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
Draft
3 changes: 2 additions & 1 deletion docs/reference/rest-api/usage.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,8 @@ GET /_xpack/usage
"inference": {
"available" : true,
"enabled" : true,
"models" : []
"models" : [],
"requests" : []
},
"logstash" : {
"available" : true,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ML_INFERENCE_COHERE_UNUSED_RERANK_SETTINGS_REMOVED = def(8_706_00_0);
public static final TransportVersion ENRICH_CACHE_STATS_SIZE_ADDED = def(8_707_00_0);
public static final TransportVersion ENTERPRISE_GEOIP_DOWNLOADER = def(8_708_00_0);
public static final TransportVersion ML_INFERENCE_REQUEST_TELEMETRY_ADDED = def(8_709_00_0);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.elasticsearch.xpack.core.XPackField;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Objects;

Expand Down Expand Up @@ -102,27 +103,38 @@ public int hashCode() {
}

private final Collection<ModelStats> modelStats;
private final Collection<InferenceRequestStats> requestStats;

public InferenceFeatureSetUsage(Collection<ModelStats> modelStats) {
public InferenceFeatureSetUsage(Collection<ModelStats> modelStats, Collection<InferenceRequestStats> requestStats) {
super(XPackField.INFERENCE, true, true);
this.modelStats = modelStats;
this.modelStats = Objects.requireNonNull(modelStats);
this.requestStats = Objects.requireNonNull(requestStats);
}

public InferenceFeatureSetUsage(StreamInput in) throws IOException {
super(in);
this.modelStats = in.readCollectionAsList(ModelStats::new);
if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_REQUEST_TELEMETRY_ADDED)) {
this.requestStats = in.readCollectionAsList(InferenceRequestStats::new);
} else {
this.requestStats = new ArrayList<>();
}
}

/**
 * Renders the body of this usage object: whatever the superclass writes, followed by the
 * {@code "models"} list and then the {@code "requests"} list.
 *
 * @param builder the builder to write into
 * @param params  rendering parameters, forwarded to the superclass
 * @throws IOException if writing to the builder fails
 */
@Override
protected void innerXContent(XContentBuilder builder, Params params) throws IOException {
super.innerXContent(builder, params);
builder.xContentList("models", modelStats);
builder.xContentList("requests", requestStats);
}

/**
 * Writes this instance to the wire.
 * <p>
 * The model stats are always written. The request stats are written only when the destination's
 * transport version is on or after {@code ML_INFERENCE_REQUEST_TELEMETRY_ADDED}, mirroring the
 * check performed by the stream constructor.
 *
 * @param out the stream to write to
 * @throws IOException if writing to the stream fails
 */
@Override
public void writeTo(StreamOutput out) throws IOException {
    super.writeTo(out);
    out.writeCollection(modelStats);

    // Older receivers do not expect the request stats section, so it is omitted for them.
    final boolean receiverReadsRequestStats = out.getTransportVersion()
        .onOrAfter(TransportVersions.ML_INFERENCE_REQUEST_TELEMETRY_ADDED);
    if (receiverReadsRequestStats) {
        out.writeCollection(requestStats);
    }
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,20 @@
import java.util.Objects;

public class InferenceRequestStats implements SerializableStats {

/**
 * Combines two stats entries that describe the same service, task type and model id into a single
 * entry whose count is the sum of both counts.
 * <p>
 * The identifying fields are only verified through {@code assert} statements, so the checks are
 * active only when assertions are enabled. The model id is optional and may be {@code null} (it is
 * read from the wire as an optional string), so it is compared null-safely: two entries without a
 * model id are considered to match.
 *
 * @param stats1 the first entry; its service, task type and model id are used for the result
 * @param stats2 the second entry, expected to have the same identifying fields as {@code stats1}
 * @return a new entry carrying the summed count
 */
public static InferenceRequestStats merge(InferenceRequestStats stats1, InferenceRequestStats stats2) {
    assert stats1.modelStats().service().equals(stats2.modelStats().service()) : "services do not match";
    assert stats1.modelStats().taskType().equals(stats2.modelStats().taskType()) : "task types do not match";
    // modelId may be null, so calling equals() on it directly would throw a NullPointerException
    assert Objects.equals(stats1.modelId(), stats2.modelId()) : "model ids do not match";

    return new InferenceRequestStats(
        stats1.modelStats().service(),
        stats1.modelStats().taskType(),
        stats1.modelId(),
        stats1.modelStats().count() + stats2.modelStats().count()
    );
}

private final InferenceFeatureSetUsage.ModelStats modelStats;
private final String modelId;

Expand All @@ -35,6 +49,15 @@ public InferenceRequestStats(StreamInput in) throws IOException {
this.modelId = in.readOptionalString();
}

/**
 * @return the underlying model stats, which hold the service, task type and count of this entry
 */
public InferenceFeatureSetUsage.ModelStats modelStats() {
return modelStats;
}

/**
 * @return the model id of this entry; may be {@code null} because it is read from the wire as an
 *         optional string
 */
public String modelId() {
return modelId;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
builder.startObject();
builder.field("service", modelStats.service());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.inference.action;

import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.FailedNodeException;
import org.elasticsearch.action.support.nodes.BaseNodeResponse;
import org.elasticsearch.action.support.nodes.BaseNodesRequest;
import org.elasticsearch.action.support.nodes.BaseNodesResponse;
import org.elasticsearch.cluster.ClusterName;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.transport.TransportRequest;
import org.elasticsearch.xpack.core.inference.InferenceRequestStats;

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;

public class GetInternalInferenceUsageAction extends ActionType<GetInternalInferenceUsageAction.Response> {

/** Shared instance of this action type. */
public static final GetInternalInferenceUsageAction INSTANCE = new GetInternalInferenceUsageAction();
/** Name under which this action is registered; passed to the {@code ActionType} constructor. */
public static final String NAME = "cluster:monitor/xpack/inference/internal_usage/get";

/** Creates the action type using {@link #NAME} as its name. */
public GetInternalInferenceUsageAction() {
super(NAME);
}

/**
 * Nodes-level request for the internal inference usage action. It carries no state of its own.
 */
public static class Request extends BaseNodesRequest<Request> {

    /**
     * Creates a request with a {@code null} array of node ids.
     * NOTE(review): presumably a null node id array addresses every node - confirm against BaseNodesRequest.
     */
    public Request() {
        super((String[]) null);
    }

    /**
     * Two requests are equal whenever they are of exactly the same class, since there is no state to compare.
     */
    @Override
    public boolean equals(Object o) {
        return this == o || (o != null && getClass() == o.getClass());
    }

    @Override
    public int hashCode() {
        // The class doesn't have any members at the moment so return the same hash code
        return Objects.hash(NAME);
    }
}

/**
 * Transport-level request sent to an individual node. It declares no fields of its own.
 */
public static class NodeRequest extends TransportRequest {
/**
 * Reads the request from the wire; only the superclass state is read.
 */
public NodeRequest(StreamInput in) throws IOException {
super(in);
}

/** Creates an empty node request. */
public NodeRequest() {}
}

/**
 * Aggregated response of the internal inference usage action, holding the per-node responses
 * together with any node failures.
 */
public static class Response extends BaseNodesResponse<NodeResponse> implements Writeable {

    /**
     * Reads the response from the wire.
     */
    public Response(StreamInput in) throws IOException {
        super(in);
    }

    /**
     * @param clusterName the name of the cluster the response originates from
     * @param nodes       the successful per-node responses
     * @param failures    the failures reported for nodes that did not respond successfully
     */
    public Response(ClusterName clusterName, List<NodeResponse> nodes, List<FailedNodeException> failures) {
        super(clusterName, nodes, failures);
    }

    /** Reads the list of per-node responses from the stream. */
    @Override
    protected List<NodeResponse> readNodesFrom(StreamInput in) throws IOException {
        return in.readCollectionAsList(NodeResponse::new);
    }

    /** Writes the list of per-node responses to the stream. */
    @Override
    protected void writeNodesTo(StreamOutput out, List<NodeResponse> nodes) throws IOException {
        out.writeCollection(nodes);
    }

    /** Equality is based on the node responses and the failures. */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        final Response other = (Response) o;
        return Objects.equals(getNodes(), other.getNodes()) && Objects.equals(failures(), other.failures());
    }

    @Override
    public int hashCode() {
        return Objects.hash(getNodes(), failures());
    }
}

/**
 * Response of a single node: a map of string keys to the inference request stats collected on that node.
 */
public static class NodeResponse extends BaseNodeResponse {
    private final Map<String, InferenceRequestStats> inferenceRequestStats;

    /**
     * @param node                  the node this response belongs to
     * @param inferenceRequestStats the stats gathered on the node; must not be {@code null}
     */
    public NodeResponse(DiscoveryNode node, Map<String, InferenceRequestStats> inferenceRequestStats) {
        super(node);
        this.inferenceRequestStats = Objects.requireNonNull(inferenceRequestStats);
    }

    /**
     * Reads the response from the wire; the stats are read as an immutable map.
     */
    public NodeResponse(StreamInput in) throws IOException {
        super(in);
        this.inferenceRequestStats = in.readImmutableMap(StreamInput::readString, InferenceRequestStats::new);
    }

    /**
     * @return an unmodifiable view of the stats map
     */
    public Map<String, InferenceRequestStats> getInferenceRequestStats() {
        return Collections.unmodifiableMap(this.inferenceRequestStats);
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        super.writeTo(out);
        out.writeMap(this.inferenceRequestStats, StreamOutput::writeString, StreamOutput::writeWriteable);
    }

    /** Equality considers only the stats map. */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        final NodeResponse other = (NodeResponse) o;
        return Objects.equals(this.inferenceRequestStats, other.inferenceRequestStats);
    }

    @Override
    public int hashCode() {
        return Objects.hash(this.inferenceRequestStats);
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,46 @@ public void testToXContent_WritesModelId_WhenItIsDefined() throws IOException {
{"service":"service","task_type":"text_embedding","count":2,"model_id":"model_id"}"""));
}

/** Merging two entries with identical identifying fields yields one entry whose count is the sum. */
public void testMerge_SumsCounts() {
    var first = new InferenceRequestStats("service", TaskType.TEXT_EMBEDDING, "model_id", 1);
    var second = new InferenceRequestStats("service", TaskType.TEXT_EMBEDDING, "model_id", 2);

    var merged = InferenceRequestStats.merge(first, second);

    var expected = new InferenceRequestStats("service", TaskType.TEXT_EMBEDDING, "model_id", 3);
    assertThat(merged, is(expected));
}

/**
 * Merging entries whose service, task type or model id differ trips the corresponding assertion,
 * each with its own message.
 */
public void testMerge_ThrowsAssertionExceptionWhenFieldsAreDifferent() {
    // service names don't match
    var serviceA = new InferenceRequestStats("service1", TaskType.TEXT_EMBEDDING, "model_id", 1);
    var serviceB = new InferenceRequestStats("service2", TaskType.TEXT_EMBEDDING, "model_id", 2);
    var serviceError = expectThrows(AssertionError.class, () -> InferenceRequestStats.merge(serviceA, serviceB));
    assertThat(serviceError.getMessage(), is("services do not match"));

    // task types don't match
    var embeddingStats = new InferenceRequestStats("service", TaskType.TEXT_EMBEDDING, "model_id", 1);
    var rerankStats = new InferenceRequestStats("service", TaskType.RERANK, "model_id", 2);
    var taskTypeError = expectThrows(AssertionError.class, () -> InferenceRequestStats.merge(embeddingStats, rerankStats));
    assertThat(taskTypeError.getMessage(), is("task types do not match"));

    // model ids don't match
    var modelA = new InferenceRequestStats("service", TaskType.TEXT_EMBEDDING, "model_id1", 1);
    var modelB = new InferenceRequestStats("service", TaskType.TEXT_EMBEDDING, "model_id2", 2);
    var modelIdError = expectThrows(AssertionError.class, () -> InferenceRequestStats.merge(modelA, modelB));
    assertThat(modelIdError.getMessage(), is("model ids do not match"));
}

@Override
protected InferenceRequestStats mutateInstanceForVersion(InferenceRequestStats instance, TransportVersion version) {
return instance;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.inference.action;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodeUtils;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.xpack.core.inference.InferenceRequestStats;
import org.elasticsearch.xpack.core.inference.InferenceRequestStatsTests;
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;

import java.io.IOException;
import java.util.HashMap;

/**
 * Wire serialization tests for {@link GetInternalInferenceUsageAction.NodeResponse}.
 */
public class GetInternalInferenceUsageActionNodeResponseTests extends AbstractBWCWireSerializationTestCase<
    GetInternalInferenceUsageAction.NodeResponse> {

    /**
     * Builds a node response for a node with id {@code "id"} holding random stats entries under
     * random 10-character keys.
     *
     * @return a randomly populated node response
     */
    public static GetInternalInferenceUsageAction.NodeResponse createRandom() {
        DiscoveryNode node = DiscoveryNodeUtils.create("id");
        var stats = new HashMap<String, InferenceRequestStats>();

        // Draw the number of entries once; calling randomIntBetween in the loop condition would
        // re-roll the bound on every iteration and skew the resulting sizes.
        int numEntries = randomIntBetween(1, 10);
        for (int i = 0; i < numEntries; i++) {
            stats.put(randomAlphaOfLength(10), InferenceRequestStatsTests.createRandom());
        }

        return new GetInternalInferenceUsageAction.NodeResponse(node, stats);
    }

    @Override
    protected Writeable.Reader<GetInternalInferenceUsageAction.NodeResponse> instanceReader() {
        return GetInternalInferenceUsageAction.NodeResponse::new;
    }

    @Override
    protected GetInternalInferenceUsageAction.NodeResponse createTestInstance() {
        return createRandom();
    }

    /** Produces a different instance by generating fresh random instances until one is not equal to the input. */
    @Override
    protected GetInternalInferenceUsageAction.NodeResponse mutateInstance(GetInternalInferenceUsageAction.NodeResponse instance)
        throws IOException {
        return randomValueOtherThan(instance, this::createTestInstance);
    }

    /** Returns the instance unchanged for every transport version. */
    @Override
    protected GetInternalInferenceUsageAction.NodeResponse mutateInstanceForVersion(
        GetInternalInferenceUsageAction.NodeResponse instance,
        TransportVersion version
    ) {
        return instance;
    }
}
Loading