diff --git a/docs/changelog/131918.yaml b/docs/changelog/131918.yaml new file mode 100644 index 0000000000000..f77133afde3db --- /dev/null +++ b/docs/changelog/131918.yaml @@ -0,0 +1,5 @@ +pr: 131918 +summary: Prevent the trained model deployment memory estimation from double-counting allocations. +area: Machine Learning +type: bug +issues: [] diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancer.java index 7d3cb71d4dc03..817d15f20688e 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancer.java @@ -139,10 +139,7 @@ private static void copyAssignments( Map nodeAssignments = source.assignments(m).orElse(Map.of()); for (Map.Entry assignment : nodeAssignments.entrySet()) { AssignmentPlan.Node originalNode = originalNodeById.get(assignment.getKey().id()); - dest.assignModelToNode(m, originalNode, assignment.getValue()); - // As the node has all its available memory we need to manually account memory of models with - // current allocations. - dest.accountMemory(m, originalNode); + dest.assignModelToNodeAndAccountForCurrentAllocations(m, originalNode, assignment.getValue()); } } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlan.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlan.java index 21e4a926a3247..437694b7427ab 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlan.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlan.java @@ -45,6 +45,30 @@ public class AssignmentPlan implements Comparable { * @param perDeploymentMemoryBytes * @param perAllocationMemoryBytes */ + /** + * Interface for memory estimation function used by Deployment + */ + @FunctionalInterface + public interface MemoryEstimator { + /** + * Estimates memory usage for a given number of allocations + * + * @param modelId the model ID + * @param modelBytes the model size in bytes + * @param perDeploymentMemoryBytes the fixed per-deployment memory overhead + * @param perAllocationMemoryBytes the memory per allocation + * @param allocations the number of allocations + * @return estimated memory usage in bytes + */ + long estimateMemoryUsageBytes( + String modelId, + long modelBytes, + long perDeploymentMemoryBytes, + long perAllocationMemoryBytes, + int allocations + ); + } + public record Deployment( String deploymentId, String modelId, @@ -56,8 +80,42 @@ public record Deployment( AdaptiveAllocationsSettings adaptiveAllocationsSettings, Priority priority, long perDeploymentMemoryBytes, - long perAllocationMemoryBytes + long perAllocationMemoryBytes, + MemoryEstimator memoryEstimator ) { + + /** + * Default constructor that uses the standard memory estimator + */ + public Deployment( + String deploymentId, + String modelId, + long memoryBytes, + int allocations, + int threadsPerAllocation, + Map currentAllocationsByNodeId, + int maxAssignedAllocations, + AdaptiveAllocationsSettings adaptiveAllocationsSettings, + Priority priority, + long perDeploymentMemoryBytes, + long perAllocationMemoryBytes + ) { + this( + deploymentId, + modelId, + memoryBytes, + allocations, + threadsPerAllocation, + currentAllocationsByNodeId, + maxAssignedAllocations, + adaptiveAllocationsSettings, + priority, + perDeploymentMemoryBytes, + perAllocationMemoryBytes, + StartTrainedModelDeploymentAction::estimateMemoryUsageBytes + ); + } + public Deployment( String deploymentId, String modelId, @@ -81,7 +139,8 @@ public Deployment( adaptiveAllocationsSettings, Priority.NORMAL, perDeploymentMemoryBytes, - perAllocationMemoryBytes + perAllocationMemoryBytes, + StartTrainedModelDeploymentAction::estimateMemoryUsageBytes ); } @@ -98,7 +157,7 @@ boolean hasEverBeenAllocated() { } public long estimateMemoryUsageBytes(int allocations) { - return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes( + return memoryEstimator.estimateMemoryUsageBytes( modelId, memoryBytes, perDeploymentMemoryBytes, @@ -108,29 +167,11 @@ public long estimateMemoryUsageBytes(int allocations) { } long estimateAdditionalMemoryUsageBytes(int allocationsOld, int allocationsNew) { - return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes( - modelId, - memoryBytes, - perDeploymentMemoryBytes, - perAllocationMemoryBytes, - allocationsNew - ) - StartTrainedModelDeploymentAction.estimateMemoryUsageBytes( - modelId, - memoryBytes, - perDeploymentMemoryBytes, - perAllocationMemoryBytes, - allocationsOld - ); + return estimateMemoryUsageBytes(allocationsNew) - estimateMemoryUsageBytes(allocationsOld); } long minimumMemoryRequiredBytes() { - return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes( - modelId, - memoryBytes, - perDeploymentMemoryBytes, - perAllocationMemoryBytes, - 1 - ); + return estimateMemoryUsageBytes(1); } int findOptimalAllocations(int maxAllocations, long availableMemoryBytes) { @@ -459,22 +500,40 @@ public Builder assignModelToNode(Deployment deployment, Node node, int allocatio return this; } + /** + * Assigns a model to a node, and if the node has existing allocations, + * also accounts for memory usage of those existing allocations. + * This handles the common case of transferring assignments between plans. + * + * @param deployment The deployment to assign + * @param node The target node + * @param newAllocations The number of new allocations to assign + */ + public void assignModelToNodeAndAccountForCurrentAllocations(Deployment deployment, Node node, int newAllocations) { + // First, handle the assignment of new allocations + assignModelToNode(deployment, node, newAllocations); + + // Then, account for memory for current allocations that if needed + int currentAllocations = getCurrentAllocations(deployment, node); + if (currentAllocations > 0) { + long memoryForCurrentAllocations = deployment.estimateMemoryUsageBytes(currentAllocations); + accountMemory(deployment, node, memoryForCurrentAllocations); + } + } + private int getAssignedAllocations(Deployment deployment, Node node) { int currentAllocations = getCurrentAllocations(deployment, node); int assignmentAllocations = assignments.get(deployment).get(node); return currentAllocations + assignmentAllocations; } - private static int getCurrentAllocations(Deployment m, Node n) { - return m.currentAllocationsByNodeId.containsKey(n.id()) ? m.currentAllocationsByNodeId.get(n.id()) : 0; + private static int getCurrentAllocations(Deployment deployment, Node node) { + return deployment.currentAllocationsByNodeId.getOrDefault(node.id(), 0); } public void accountMemory(Deployment m, Node n) { - if (m.currentAllocationsByNodeId().containsKey(n.id())) { - int allocations = m.currentAllocationsByNodeId().get(n.id()); - long requiredMemory = m.estimateMemoryUsageBytes(allocations); - accountMemory(m, n, requiredMemory); - } + long requiredMemory = getDeploymentMemoryRequirement(m, n, getCurrentAllocations(m, n)); + accountMemory(m, n, requiredMemory); } private void accountMemory(Deployment m, Node n, long requiredMemory) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlanner.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlanner.java index ff564e16becc0..d9073d0839b5c 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlanner.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlanner.java @@ -199,10 +199,7 @@ private AssignmentPlan swapOriginalModelsInPlan( Map nodeAssignments = plan.assignments(m).orElse(Map.of()); for (Map.Entry assignment : nodeAssignments.entrySet()) { Node originalNode = originalNodeById.get(assignment.getKey().id()); - planBuilder.assignModelToNode(originalDeployment, originalNode, assignment.getValue()); - // As the node has all its available memory we need to manually account memory of models with - // current allocations. - planBuilder.accountMemory(originalDeployment, originalNode); + planBuilder.assignModelToNodeAndAccountForCurrentAllocations(originalDeployment, originalNode, assignment.getValue()); } } return planBuilder.build(); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanTests.java index 93e247dc898c4..85cf632bec3ba 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanTests.java @@ -9,9 +9,12 @@ import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan.Deployment; +import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan.MemoryEstimator; import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan.Node; +import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -20,6 +23,7 @@ import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.lessThan; +import static org.hamcrest.Matchers.lessThanOrEqualTo; public class AssignmentPlanTests extends ESTestCase { @@ -740,4 +744,75 @@ public void testCountPreviouslyAssignedThatAreStillAssigned() { .build(); assertThat(plan.countPreviouslyAssignedModelsThatAreStillAssigned(), equalTo(2L)); } + + public void testAssignModelToNodeAndAccountForCurrentAllocations_GivenScalingFromThreeToFourAllocations() { + // Test that we are not double-counting memory when scaling from 3 to 4 allocations + + int targetAllocations = 4; + int currentAllocations = 3; + int newAllocations = targetAllocations - currentAllocations; + + // Create a node with sufficient memory + long nodeMemoryBytes = ByteSizeValue.ofGb(1).getBytes(); + Node node = new Node("node-1", nodeMemoryBytes, 8); + + // Create a deployment with 3 current allocations on node-1, target 4 allocations + Map currentAllocationsMap = Map.of("node-1", currentAllocations); + + // Create a deployment that's being scaled from 3 to 4 allocations + long modelBytes = ByteSizeValue.ofMb(10).getBytes(); + long perDeploymentMemoryBytes = ByteSizeValue.ofMb(100).getBytes(); + long perAllocationMemoryBytes = ByteSizeValue.ofMb(10).getBytes(); + + // List to track allocation counts used in memory estimation + final List allocationCounts = new ArrayList<>(); + + // Create tracking memory estimator that logs allocation parameters + MemoryEstimator trackingEstimator = (modelId, bytes, deploymentMemory, allocationMemory, allocations) -> { + // Record the allocations parameter + allocationCounts.add(allocations); + + // Return a simple calculation for memory estimation + return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes( + modelId, + bytes, + deploymentMemory, + allocationMemory, + allocations + ); + }; + + // Create a deployment with our tracking estimator + Deployment deployment = new Deployment( + "test-deployment", + "test-model", + modelBytes, + targetAllocations, + 1, + currentAllocationsMap, + currentAllocations, + null, + null, + perDeploymentMemoryBytes, + perAllocationMemoryBytes, + trackingEstimator // inject our tracking estimator + ); + + // Create a builder and use our method to add 1 allocation to the 3 current ones + AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(node), List.of(deployment)); + builder.assignModelToNodeAndAccountForCurrentAllocations(deployment, node, newAllocations); + + // Build the plan and verify assignments + AssignmentPlan plan = builder.build(); + assertThat(plan.assignments(deployment).isPresent(), is(true)); + assertThat(plan.assignments(deployment).get().get(node), equalTo(newAllocations)); // Verifies 1 new allocation assigned + + // If we don't have double-counting, the memory estimation should be called no more than 4 allocations + int maxAllocationCount = allocationCounts.stream().max(Integer::compare).orElse(0); + assertThat( + "Should never calculate memory for more than 4 allocations at once", + maxAllocationCount, + lessThanOrEqualTo(targetAllocations) + ); + } }