Skip to content

[ML] Prevent the trained model deployment memory estimation from double-counting allocations. #131918

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 10 commits into from
Closed
5 changes: 5 additions & 0 deletions docs/changelog/131918.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 131918
summary: Prevent the trained model deployment memory estimation from double-counting allocations.
area: Machine Learning
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,7 @@ private static void copyAssignments(
Map<AssignmentPlan.Node, Integer> nodeAssignments = source.assignments(m).orElse(Map.of());
for (Map.Entry<AssignmentPlan.Node, Integer> assignment : nodeAssignments.entrySet()) {
AssignmentPlan.Node originalNode = originalNodeById.get(assignment.getKey().id());
dest.assignModelToNode(m, originalNode, assignment.getValue());
// As the node has all its available memory we need to manually account memory of models with
// current allocations.
dest.accountMemory(m, originalNode);
dest.assignModelToNodeAndAccountForCurrentAllocations(m, originalNode, assignment.getValue());
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,30 @@ public class AssignmentPlan implements Comparable<AssignmentPlan> {
* @param perDeploymentMemoryBytes
* @param perAllocationMemoryBytes
*/
/**
 * Strategy interface for estimating the memory required by a {@link Deployment}.
 *
 * <p>Extracted so the estimation logic can be injected — production code wires in
 * {@code StartTrainedModelDeploymentAction::estimateMemoryUsageBytes}, while tests may
 * supply a tracking or stub implementation to observe how the planner queries memory.
 */
@FunctionalInterface
public interface MemoryEstimator {
    /**
     * Estimates the total memory usage for a deployment running the given number of allocations.
     *
     * @param modelId the model ID
     * @param modelBytes the model size in bytes
     * @param perDeploymentMemoryBytes the fixed per-deployment memory overhead
     * @param perAllocationMemoryBytes the memory per allocation
     * @param allocations the number of allocations
     * @return estimated memory usage in bytes
     */
    long estimateMemoryUsageBytes(
        String modelId,
        long modelBytes,
        long perDeploymentMemoryBytes,
        long perAllocationMemoryBytes,
        int allocations
    );
}

public record Deployment(
String deploymentId,
String modelId,
Expand All @@ -56,8 +80,42 @@ public record Deployment(
AdaptiveAllocationsSettings adaptiveAllocationsSettings,
Priority priority,
long perDeploymentMemoryBytes,
long perAllocationMemoryBytes
long perAllocationMemoryBytes,
MemoryEstimator memoryEstimator
) {

/**
 * Convenience constructor that wires in the production memory estimator,
 * {@code StartTrainedModelDeploymentAction::estimateMemoryUsageBytes}.
 *
 * <p>The canonical constructor additionally takes a {@link MemoryEstimator}; this overload
 * preserves the pre-existing call sites that do not need to customize estimation.
 */
public Deployment(
    String deploymentId,
    String modelId,
    long memoryBytes,
    int allocations,
    int threadsPerAllocation,
    Map<String, Integer> currentAllocationsByNodeId,
    int maxAssignedAllocations,
    AdaptiveAllocationsSettings adaptiveAllocationsSettings,
    Priority priority,
    long perDeploymentMemoryBytes,
    long perAllocationMemoryBytes
) {
    this(
        deploymentId,
        modelId,
        memoryBytes,
        allocations,
        threadsPerAllocation,
        currentAllocationsByNodeId,
        maxAssignedAllocations,
        adaptiveAllocationsSettings,
        priority,
        perDeploymentMemoryBytes,
        perAllocationMemoryBytes,
        // Default estimator used by production code paths.
        StartTrainedModelDeploymentAction::estimateMemoryUsageBytes
    );
}

public Deployment(
String deploymentId,
String modelId,
Expand All @@ -81,7 +139,8 @@ public Deployment(
adaptiveAllocationsSettings,
Priority.NORMAL,
perDeploymentMemoryBytes,
perAllocationMemoryBytes
perAllocationMemoryBytes,
StartTrainedModelDeploymentAction::estimateMemoryUsageBytes
);
}

Expand All @@ -98,7 +157,7 @@ boolean hasEverBeenAllocated() {
}

public long estimateMemoryUsageBytes(int allocations) {
return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(
return memoryEstimator.estimateMemoryUsageBytes(
modelId,
memoryBytes,
perDeploymentMemoryBytes,
Expand All @@ -108,29 +167,11 @@ public long estimateMemoryUsageBytes(int allocations) {
}

long estimateAdditionalMemoryUsageBytes(int allocationsOld, int allocationsNew) {
return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(
modelId,
memoryBytes,
perDeploymentMemoryBytes,
perAllocationMemoryBytes,
allocationsNew
) - StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(
modelId,
memoryBytes,
perDeploymentMemoryBytes,
perAllocationMemoryBytes,
allocationsOld
);
return estimateMemoryUsageBytes(allocationsNew) - estimateMemoryUsageBytes(allocationsOld);
}

long minimumMemoryRequiredBytes() {
return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(
modelId,
memoryBytes,
perDeploymentMemoryBytes,
perAllocationMemoryBytes,
1
);
return estimateMemoryUsageBytes(1);
}

int findOptimalAllocations(int maxAllocations, long availableMemoryBytes) {
Expand Down Expand Up @@ -459,22 +500,40 @@ public Builder assignModelToNode(Deployment deployment, Node node, int allocatio
return this;
}

/**
* Assigns a model to a node, and if the node has existing allocations,
* also accounts for memory usage of those existing allocations.
* This handles the common case of transferring assignments between plans.
*
* @param deployment The deployment to assign
* @param node The target node
* @param newAllocations The number of new allocations to assign
*/
public void assignModelToNodeAndAccountForCurrentAllocations(Deployment deployment, Node node, int newAllocations) {
// First, handle the assignment of new allocations
assignModelToNode(deployment, node, newAllocations);

// Then, account for the memory of any current allocations, if needed
int currentAllocations = getCurrentAllocations(deployment, node);
if (currentAllocations > 0) {
long memoryForCurrentAllocations = deployment.estimateMemoryUsageBytes(currentAllocations);
accountMemory(deployment, node, memoryForCurrentAllocations);
Copy link
Contributor

@jan-elastic jan-elastic Jul 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand this. Why call:

  • assignModelToNode for the new allocations; and
  • accountMemory for the old ones?

I guess I also don't really understand what the state of AssignmentPlan exactly contains.

Isn't the old already accounted for? And what about the cores?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Furthermore, it feels like this shouldn't be doing that much, so I don't get why it's 500+ lines of similar confusing methods...

}
}

/**
 * Returns the total number of allocations a deployment has on a node: the allocations it
 * already had before planning plus the allocations this plan has assigned to it.
 */
private int getAssignedAllocations(Deployment deployment, Node node) {
    int planned = assignments.get(deployment).get(node);
    return getCurrentAllocations(deployment, node) + planned;
}

private static int getCurrentAllocations(Deployment m, Node n) {
return m.currentAllocationsByNodeId.containsKey(n.id()) ? m.currentAllocationsByNodeId.get(n.id()) : 0;
private static int getCurrentAllocations(Deployment deployment, Node node) {
return deployment.currentAllocationsByNodeId.getOrDefault(node.id(), 0);
}

public void accountMemory(Deployment m, Node n) {
if (m.currentAllocationsByNodeId().containsKey(n.id())) {
int allocations = m.currentAllocationsByNodeId().get(n.id());
long requiredMemory = m.estimateMemoryUsageBytes(allocations);
accountMemory(m, n, requiredMemory);
}
long requiredMemory = getDeploymentMemoryRequirement(m, n, getCurrentAllocations(m, n));
accountMemory(m, n, requiredMemory);
}

private void accountMemory(Deployment m, Node n, long requiredMemory) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,10 +199,7 @@ private AssignmentPlan swapOriginalModelsInPlan(
Map<Node, Integer> nodeAssignments = plan.assignments(m).orElse(Map.of());
for (Map.Entry<Node, Integer> assignment : nodeAssignments.entrySet()) {
Node originalNode = originalNodeById.get(assignment.getKey().id());
planBuilder.assignModelToNode(originalDeployment, originalNode, assignment.getValue());
// As the node has all its available memory we need to manually account memory of models with
// current allocations.
planBuilder.accountMemory(originalDeployment, originalNode);
planBuilder.assignModelToNodeAndAccountForCurrentAllocations(originalDeployment, originalNode, assignment.getValue());
}
}
return planBuilder.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@

import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan.Deployment;
import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan.MemoryEstimator;
import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan.Node;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

Expand All @@ -20,6 +23,7 @@
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.lessThan;
import static org.hamcrest.Matchers.lessThanOrEqualTo;

public class AssignmentPlanTests extends ESTestCase {

Expand Down Expand Up @@ -740,4 +744,75 @@ public void testCountPreviouslyAssignedThatAreStillAssigned() {
.build();
assertThat(plan.countPreviouslyAssignedModelsThatAreStillAssigned(), equalTo(2L));
}

public void testAssignModelToNodeAndAccountForCurrentAllocations_GivenScalingFromThreeToFourAllocations() {
    // Regression check: scaling a deployment from 3 to 4 allocations must not
    // double-count the memory of the 3 allocations that are already running.

    final int currentAllocations = 3;
    final int targetAllocations = 4;
    final int addedAllocations = targetAllocations - currentAllocations;

    // One node with plenty of memory and cores.
    Node node = new Node("node-1", ByteSizeValue.ofGb(1).getBytes(), 8);

    // The deployment already runs 3 allocations on node-1 and targets 4 in total.
    Map<String, Integer> existingAllocationsByNode = Map.of("node-1", currentAllocations);

    long modelBytes = ByteSizeValue.ofMb(10).getBytes();
    long perDeploymentMemoryBytes = ByteSizeValue.ofMb(100).getBytes();
    long perAllocationMemoryBytes = ByteSizeValue.ofMb(10).getBytes();

    // Record every allocation count the planner asks the estimator about, then delegate
    // to the real estimation so the plan behaves as in production.
    final List<Integer> observedAllocationCounts = new ArrayList<>();
    MemoryEstimator trackingEstimator = (modelId, bytes, deploymentMemory, allocationMemory, allocations) -> {
        observedAllocationCounts.add(allocations);
        return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(
            modelId,
            bytes,
            deploymentMemory,
            allocationMemory,
            allocations
        );
    };

    // Deployment built with the tracking estimator injected.
    Deployment deployment = new Deployment(
        "test-deployment",
        "test-model",
        modelBytes,
        targetAllocations,
        1,
        existingAllocationsByNode,
        currentAllocations,
        null,
        null,
        perDeploymentMemoryBytes,
        perAllocationMemoryBytes,
        trackingEstimator
    );

    // Add the single new allocation on top of the 3 current ones.
    AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(node), List.of(deployment));
    builder.assignModelToNodeAndAccountForCurrentAllocations(deployment, node, addedAllocations);

    AssignmentPlan plan = builder.build();
    assertThat(plan.assignments(deployment).isPresent(), is(true));
    // Only the one new allocation should appear as a fresh assignment in the plan.
    assertThat(plan.assignments(deployment).get().get(node), equalTo(addedAllocations));

    // Without double-counting, no single estimate should cover more than the 4 target allocations.
    int largestObservedCount = observedAllocationCounts.stream().mapToInt(Integer::intValue).max().orElse(0);
    assertThat(
        "Should never calculate memory for more than 4 allocations at once",
        largestObservedCount,
        lessThanOrEqualTo(targetAllocations)
    );
}
}
Loading