Skip to content

Commit 0c4bf22

Browse files
authored
Don't accept clustersPerNeighborhood lower than 2 (#130526)
1 parent 9fd8bf1 commit 0c4bf22

File tree

5 files changed

+55
-36
lines changed

5 files changed

+55
-36
lines changed

muted-tests.yml

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -545,9 +545,6 @@ tests:
545545
- class: org.elasticsearch.compute.aggregation.TopIntAggregatorFunctionTests
546546
method: testManyInitialManyPartialFinalRunnerThrowing
547547
issue: https://github.com/elastic/elasticsearch/issues/130145
548-
- class: org.elasticsearch.index.codec.vectors.cluster.KMeansLocalTests
549-
method: testKMeansNeighbors
550-
issue: https://github.com/elastic/elasticsearch/issues/130258
551548
- class: org.elasticsearch.xpack.test.rest.XPackRestIT
552549
method: test {p0=esql/10_basic/basic with documents_found}
553550
issue: https://github.com/elastic/elasticsearch/issues/130256
@@ -575,9 +572,6 @@ tests:
575572
- class: org.elasticsearch.test.rest.yaml.RcsCcsCommonYamlTestSuiteIT
576573
method: test {p0=msearch/20_typed_keys/Multisearch test with typed_keys parameter for sampler and significant terms}
577574
issue: https://github.com/elastic/elasticsearch/issues/130472
578-
- class: org.elasticsearch.index.codec.vectors.cluster.HierarchicalKMeansTests
579-
method: testHKmeans
580-
issue: https://github.com/elastic/elasticsearch/issues/130497
581575
- class: org.elasticsearch.xpack.esql.action.EsqlActionBreakerIT
582576
method: testProjectWhere
583577
issue: https://github.com/elastic/elasticsearch/issues/130504

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ public KMeansResult cluster(FloatVectorValues vectors, int targetSize) throws IO
6868
KMeansIntermediate kMeansIntermediate = clusterAndSplit(vectors, targetSize);
6969
if (kMeansIntermediate.centroids().length > 1 && kMeansIntermediate.centroids().length < vectors.size()) {
7070
int localSampleSize = Math.min(kMeansIntermediate.centroids().length * samplesPerCluster / 2, vectors.size());
71-
KMeansLocal kMeansLocal = new KMeansLocal(localSampleSize, maxIterations, clustersPerNeighborhood, DEFAULT_SOAR_LAMBDA);
72-
kMeansLocal.cluster(vectors, kMeansIntermediate, true);
71+
KMeansLocal kMeansLocal = new KMeansLocal(localSampleSize, maxIterations);
72+
kMeansLocal.cluster(vectors, kMeansIntermediate, clustersPerNeighborhood, DEFAULT_SOAR_LAMBDA);
7373
}
7474

7575
return kMeansIntermediate;

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -35,18 +35,10 @@ class KMeansLocal {
3535

3636
final int sampleSize;
3737
final int maxIterations;
38-
final int clustersPerNeighborhood;
39-
final float soarLambda;
4038

41-
KMeansLocal(int sampleSize, int maxIterations, int clustersPerNeighborhood, float soarLambda) {
39+
KMeansLocal(int sampleSize, int maxIterations) {
4240
this.sampleSize = sampleSize;
4341
this.maxIterations = maxIterations;
44-
this.clustersPerNeighborhood = clustersPerNeighborhood;
45-
this.soarLambda = soarLambda;
46-
}
47-
48-
KMeansLocal(int sampleSize, int maxIterations) {
49-
this(sampleSize, maxIterations, -1, -1f);
5042
}
5143

5244
/**
@@ -198,8 +190,13 @@ private void computeNeighborhoods(float[][] centers, List<NeighborHood> neighbor
198190
}
199191
}
200192

201-
private int[] assignSpilled(FloatVectorValues vectors, List<NeighborHood> neighborhoods, float[][] centroids, int[] assignments)
202-
throws IOException {
193+
private int[] assignSpilled(
194+
FloatVectorValues vectors,
195+
List<NeighborHood> neighborhoods,
196+
float[][] centroids,
197+
int[] assignments,
198+
float soarLambda
199+
) throws IOException {
203200
// SOAR uses an adjusted distance for assigning spilled documents which is
204201
// given by:
205202
//
@@ -264,6 +261,10 @@ private int[] assignSpilled(FloatVectorValues vectors, List<NeighborHood> neighb
264261
return spilledAssignments;
265262
}
266263

264+
record NeighborHood(int[] neighbors, float maxIntraDistance) {
265+
static final NeighborHood EMPTY = new NeighborHood(new int[0], Float.POSITIVE_INFINITY);
266+
}
267+
267268
/**
268269
* cluster using a Lloyd k-means algorithm that is not neighbor aware
269270
*
@@ -274,11 +275,7 @@ private int[] assignSpilled(FloatVectorValues vectors, List<NeighborHood> neighb
274275
* @throws IOException is thrown if vectors is inaccessible
275276
*/
276277
void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate) throws IOException {
277-
cluster(vectors, kMeansIntermediate, false);
278-
}
279-
280-
record NeighborHood(int[] neighbors, float maxIntraDistance) {
281-
static final NeighborHood EMPTY = new NeighborHood(new int[0], Float.POSITIVE_INFINITY);
278+
doCluster(vectors, kMeansIntermediate, -1, -1);
282279
}
283280

284281
/**
@@ -290,12 +287,23 @@ record NeighborHood(int[] neighbors, float maxIntraDistance) {
290287
* the prior assignments of the given vectors; care should be taken in
291288
* passing in a valid output object with a centroids array that is the size of centroids expected
292289
* and assignments that are the same size as the vectors. The SOAR assignments are overwritten by this operation.
293-
* @param neighborAware whether nearby neighboring centroids and their vectors should be used to update the centroid positions,
294-
* implies SOAR assignments
295-
* @throws IOException is thrown if vectors is inaccessible
290+
* @param clustersPerNeighborhood number of nearby neighboring centroids to be used to update the centroid positions.
291+
* @param soarLambda lambda used for SOAR assignments
292+
*
293+
* @throws IOException is thrown if vectors is inaccessible
* @throws IllegalArgumentException is thrown if clustersPerNeighborhood is less than 2
296294
*/
297-
void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, boolean neighborAware) throws IOException {
295+
void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, int clustersPerNeighborhood, float soarLambda)
296+
throws IOException {
297+
if (clustersPerNeighborhood < 2) {
298+
throw new IllegalArgumentException("clustersPerNeighborhood must be at least 2, got [" + clustersPerNeighborhood + "]");
299+
}
300+
doCluster(vectors, kMeansIntermediate, clustersPerNeighborhood, soarLambda);
301+
}
302+
303+
private void doCluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, int clustersPerNeighborhood, float soarLambda)
304+
throws IOException {
298305
float[][] centroids = kMeansIntermediate.centroids();
306+
boolean neighborAware = clustersPerNeighborhood != -1 && centroids.length > 1;
299307

300308
List<NeighborHood> neighborhoods = null;
301309
// if there are very few centroids, don't bother with neighborhoods or neighbor aware clustering
@@ -308,11 +316,11 @@ void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, b
308316
computeNeighborhoods(centroids, neighborhoods, clustersPerNeighborhood);
309317
}
310318
cluster(vectors, kMeansIntermediate, neighborhoods);
311-
if (neighborAware && clustersPerNeighborhood > 0) {
319+
if (neighborAware) {
312320
int[] assignments = kMeansIntermediate.assignments();
313321
assert assignments != null;
314322
assert assignments.length == vectors.size();
315-
kMeansIntermediate.setSoarAssignments(assignSpilled(vectors, neighborhoods, centroids, assignments));
323+
kMeansIntermediate.setSoarAssignments(assignSpilled(vectors, neighborhoods, centroids, assignments, soarLambda));
316324
}
317325
}
318326

server/src/test/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeansTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ public void testHKmeans() throws IOException {
2323
int dims = random().nextInt(2, 20);
2424
int sampleSize = random().nextInt(100, nVectors + 1);
2525
int maxIterations = random().nextInt(0, 100);
26-
int clustersPerNeighborhood = random().nextInt(0, 512);
26+
int clustersPerNeighborhood = random().nextInt(2, 512);
2727
float soarLambda = random().nextFloat(0.5f, 1.5f);
2828
FloatVectorValues vectors = generateData(nVectors, dims, nClusters);
2929

server/src/test/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocalTests.java

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,32 @@
1717
import java.util.ArrayList;
1818
import java.util.List;
1919

20+
import static org.hamcrest.Matchers.containsString;
21+
2022
public class KMeansLocalTests extends ESTestCase {
2123

24+
public void testIllegalClustersPerNeighborhood() {
25+
KMeansLocal kMeansLocal = new KMeansLocal(randomInt(), randomInt());
26+
KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(new float[0][], new int[0], i -> i);
27+
IllegalArgumentException ex = expectThrows(
28+
IllegalArgumentException.class,
29+
() -> kMeansLocal.cluster(
30+
FloatVectorValues.fromFloats(List.of(), randomInt(1024)),
31+
kMeansIntermediate,
32+
randomIntBetween(Integer.MIN_VALUE, 1),
33+
randomFloat()
34+
)
35+
);
36+
assertThat(ex.getMessage(), containsString("clustersPerNeighborhood must be at least 2"));
37+
}
38+
2239
public void testKMeansNeighbors() throws IOException {
2340
int nClusters = random().nextInt(1, 10);
2441
int nVectors = random().nextInt(nClusters * 100, nClusters * 200);
2542
int dims = random().nextInt(2, 20);
2643
int sampleSize = random().nextInt(100, nVectors + 1);
2744
int maxIterations = random().nextInt(0, 100);
28-
int clustersPerNeighborhood = random().nextInt(0, 512);
45+
int clustersPerNeighborhood = random().nextInt(2, 512);
2946
float soarLambda = random().nextFloat(0.5f, 1.5f);
3047
FloatVectorValues vectors = generateData(nVectors, dims, nClusters);
3148

@@ -49,8 +66,8 @@ public void testKMeansNeighbors() throws IOException {
4966
}
5067

5168
KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids, assignments, i -> assignmentOrdinals[i]);
52-
KMeansLocal kMeansLocal = new KMeansLocal(sampleSize, maxIterations, clustersPerNeighborhood, soarLambda);
53-
kMeansLocal.cluster(vectors, kMeansIntermediate, true);
69+
KMeansLocal kMeansLocal = new KMeansLocal(sampleSize, maxIterations);
70+
kMeansLocal.cluster(vectors, kMeansIntermediate, clustersPerNeighborhood, soarLambda);
5471

5572
assertEquals(nClusters, centroids.length);
5673
assertNotNull(kMeansIntermediate.soarAssignments());
@@ -90,8 +107,8 @@ public void testKMeansNeighborsAllZero() throws IOException {
90107
}
91108

92109
KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids, assignments, i -> assignmentOrdinals[i]);
93-
KMeansLocal kMeansLocal = new KMeansLocal(sampleSize, maxIterations, clustersPerNeighborhood, soarLambda);
94-
kMeansLocal.cluster(fvv, kMeansIntermediate, true);
110+
KMeansLocal kMeansLocal = new KMeansLocal(sampleSize, maxIterations);
111+
kMeansLocal.cluster(fvv, kMeansIntermediate, clustersPerNeighborhood, soarLambda);
95112

96113
assertEquals(nClusters, centroids.length);
97114
assertNotNull(kMeansIntermediate.soarAssignments());

0 commit comments

Comments
 (0)