Skip to content

Commit 168ba07

Browse files
authored
[IVF] Remove unnecessary loop over centroids and some clean up (#130694)
This commit removes an unnecessary loop over centroids when computing neighbourhoods, and replaces the List-based neighbourhood bookkeeping with plain arrays.
1 parent a3b6165 commit 168ba07

File tree

2 files changed

+38
-43
lines changed

2 files changed

+38
-43
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java

Lines changed: 32 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,7 @@
1616
import org.elasticsearch.simdvec.ESVectorUtil;
1717

1818
import java.io.IOException;
19-
import java.util.ArrayList;
2019
import java.util.Arrays;
21-
import java.util.List;
2220
import java.util.Random;
2321

2422
/**
@@ -74,7 +72,7 @@ private static boolean stepLloyd(
7472
float[][] centroids,
7573
float[][] nextCentroids,
7674
int[] assignments,
77-
List<NeighborHood> neighborhoods
75+
NeighborHood[] neighborhoods
7876
) throws IOException {
7977
boolean changed = false;
8078
int dim = vectors.dimension();
@@ -90,7 +88,7 @@ private static boolean stepLloyd(
9088
final int assignment = assignments[vectorOrd];
9189
final int bestCentroidOffset;
9290
if (neighborhoods != null) {
93-
bestCentroidOffset = getBestCentroidFromNeighbours(centroids, vector, assignment, neighborhoods.get(assignment));
91+
bestCentroidOffset = getBestCentroidFromNeighbours(centroids, vector, assignment, neighborhoods[assignment]);
9492
} else {
9593
bestCentroidOffset = getBestCentroid(centroids, vector);
9694
}
@@ -152,30 +150,27 @@ private static int getBestCentroid(float[][] centroids, float[] vector) {
152150
return bestCentroidOffset;
153151
}
154152

155-
private void computeNeighborhoods(float[][] centers, List<NeighborHood> neighborhoods, int clustersPerNeighborhood) {
156-
int k = neighborhoods.size();
157-
158-
if (k == 0 || clustersPerNeighborhood <= 0) {
159-
return;
160-
}
161-
162-
List<NeighborQueue> neighborQueues = new ArrayList<>(k);
153+
private NeighborHood[] computeNeighborhoods(float[][] centers, int clustersPerNeighborhood) {
154+
int k = centers.length;
155+
assert k > clustersPerNeighborhood;
156+
NeighborQueue[] neighborQueues = new NeighborQueue[k];
163157
for (int i = 0; i < k; i++) {
164-
neighborQueues.add(new NeighborQueue(clustersPerNeighborhood, true));
158+
neighborQueues[i] = new NeighborQueue(clustersPerNeighborhood, true);
165159
}
166160
for (int i = 0; i < k - 1; i++) {
167161
for (int j = i + 1; j < k; j++) {
168162
float dsq = VectorUtil.squareDistance(centers[i], centers[j]);
169-
neighborQueues.get(j).insertWithOverflow(i, dsq);
170-
neighborQueues.get(i).insertWithOverflow(j, dsq);
163+
neighborQueues[j].insertWithOverflow(i, dsq);
164+
neighborQueues[i].insertWithOverflow(j, dsq);
171165
}
172166
}
173167

168+
NeighborHood[] neighborhoods = new NeighborHood[k];
174169
for (int i = 0; i < k; i++) {
175-
NeighborQueue queue = neighborQueues.get(i);
170+
NeighborQueue queue = neighborQueues[i];
176171
if (queue.size() == 0) {
177172
// no neighbors, skip
178-
neighborhoods.set(i, NeighborHood.EMPTY);
173+
neighborhoods[i] = NeighborHood.EMPTY;
179174
continue;
180175
}
181176
// consume the queue into the neighbors array and get the maximum intra-cluster distance
@@ -185,16 +180,15 @@ private void computeNeighborhoods(float[][] centers, List<NeighborHood> neighbor
185180
while (queue.size() > 0) {
186181
neighbors[neighbors.length - ++iter] = queue.pop();
187182
}
188-
NeighborHood neighborHood = new NeighborHood(neighbors, maxIntraDistance);
189-
neighborhoods.set(i, neighborHood);
183+
neighborhoods[i] = new NeighborHood(neighbors, maxIntraDistance);
190184
}
185+
return neighborhoods;
191186
}
192187

193-
private int[] assignSpilled(
188+
private void assignSpilled(
194189
FloatVectorValues vectors,
195-
List<NeighborHood> neighborhoods,
196-
float[][] centroids,
197-
int[] assignments,
190+
KMeansIntermediate kmeansIntermediate,
191+
NeighborHood[] neighborhoods,
198192
float soarLambda
199193
) throws IOException {
200194
// SOAR uses an adjusted distance for assigning spilled documents which is
@@ -205,8 +199,13 @@ private int[] assignSpilled(
205199
// Here, x is the document, c is the nearest centroid, and c_1 is the first
206200
// centroid the document was assigned to. The document is assigned to the
207201
// cluster with the smallest soar(x, c).
208-
209-
int[] spilledAssignments = new int[assignments.length];
202+
int[] assignments = kmeansIntermediate.assignments();
203+
assert assignments != null;
204+
assert assignments.length == vectors.size();
205+
int[] spilledAssignments = kmeansIntermediate.soarAssignments();
206+
assert spilledAssignments != null;
207+
assert spilledAssignments.length == vectors.size();
208+
float[][] centroids = kmeansIntermediate.centroids();
210209

211210
float[] diffs = new float[vectors.dimension()];
212211
for (int i = 0; i < vectors.size(); i++) {
@@ -230,8 +229,8 @@ private int[] assignSpilled(
230229
int centroidCount = centroids.length;
231230
IntToIntFunction centroidOrds = c -> c;
232231
if (neighborhoods != null) {
233-
assert neighborhoods.get(currAssignment) != null;
234-
NeighborHood neighborhood = neighborhoods.get(currAssignment);
232+
assert neighborhoods[currAssignment] != null;
233+
NeighborHood neighborhood = neighborhoods[currAssignment];
235234
centroidCount = neighborhood.neighbors.length;
236235
centroidOrds = c -> neighborhood.neighbors[c];
237236
}
@@ -257,8 +256,6 @@ private int[] assignSpilled(
257256
assert bestAssignment != -1 : "Failed to assign soar vector to centroid";
258257
spilledAssignments[i] = bestAssignment;
259258
}
260-
261-
return spilledAssignments;
262259
}
263260

264261
record NeighborHood(int[] neighbors, float maxIntraDistance) {
@@ -304,27 +301,20 @@ private void doCluster(FloatVectorValues vectors, KMeansIntermediate kMeansInter
304301
throws IOException {
305302
float[][] centroids = kMeansIntermediate.centroids();
306303
boolean neighborAware = clustersPerNeighborhood != -1 && centroids.length > 1;
307-
308-
List<NeighborHood> neighborhoods = null;
304+
NeighborHood[] neighborhoods = null;
309305
// if there are very few centroids, don't bother with neighborhoods or neighbor aware clustering
310306
if (neighborAware && centroids.length > clustersPerNeighborhood) {
311-
int k = centroids.length;
312-
neighborhoods = new ArrayList<>(k);
313-
for (int i = 0; i < k; ++i) {
314-
neighborhoods.add(null);
315-
}
316-
computeNeighborhoods(centroids, neighborhoods, clustersPerNeighborhood);
307+
neighborhoods = computeNeighborhoods(centroids, clustersPerNeighborhood);
317308
}
318309
cluster(vectors, kMeansIntermediate, neighborhoods);
319310
if (neighborAware) {
320-
int[] assignments = kMeansIntermediate.assignments();
321-
assert assignments != null;
322-
assert assignments.length == vectors.size();
323-
kMeansIntermediate.setSoarAssignments(assignSpilled(vectors, neighborhoods, centroids, assignments, soarLambda));
311+
assert kMeansIntermediate.soarAssignments().length == 0;
312+
kMeansIntermediate.setSoarAssignments(new int[vectors.size()]);
313+
assignSpilled(vectors, kMeansIntermediate, neighborhoods, soarLambda);
324314
}
325315
}
326316

327-
private void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, List<NeighborHood> neighborhoods)
317+
private void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, NeighborHood[] neighborhoods)
328318
throws IOException {
329319
float[][] centroids = kMeansIntermediate.centroids();
330320
int k = centroids.length;

server/src/test/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeansTests.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,16 @@ public void testHKmeans() throws IOException {
3838

3939
assertEquals(Math.min(nClusters, nVectors), centroids.length, 8);
4040
assertEquals(nVectors, assignments.length);
41+
42+
for (int assignment : assignments) {
43+
assertTrue(assignment >= 0 && assignment < centroids.length);
44+
}
4145
if (centroids.length > 1 && centroids.length < nVectors) {
4246
assertEquals(nVectors, soarAssignments.length);
4347
// verify no duplicates exist
4448
for (int i = 0; i < assignments.length; i++) {
45-
assert assignments[i] != soarAssignments[i];
49+
assertTrue(soarAssignments[i] >= 0 && soarAssignments[i] < centroids.length);
50+
assertNotEquals(assignments[i], soarAssignments[i]);
4651
}
4752
} else {
4853
assertEquals(0, soarAssignments.length);

0 commit comments

Comments
 (0)