@@ -35,18 +35,10 @@ class KMeansLocal {
35
35
36
36
final int sampleSize ;
37
37
final int maxIterations ;
38
- final int clustersPerNeighborhood ;
39
- final float soarLambda ;
40
38
41
- KMeansLocal (int sampleSize , int maxIterations , int clustersPerNeighborhood , float soarLambda ) {
39
+ KMeansLocal (int sampleSize , int maxIterations ) {
42
40
this .sampleSize = sampleSize ;
43
41
this .maxIterations = maxIterations ;
44
- this .clustersPerNeighborhood = clustersPerNeighborhood ;
45
- this .soarLambda = soarLambda ;
46
- }
47
-
48
- KMeansLocal (int sampleSize , int maxIterations ) {
49
- this (sampleSize , maxIterations , -1 , -1f );
50
42
}
51
43
52
44
/**
@@ -198,8 +190,13 @@ private void computeNeighborhoods(float[][] centers, List<NeighborHood> neighbor
198
190
}
199
191
}
200
192
201
- private int [] assignSpilled (FloatVectorValues vectors , List <NeighborHood > neighborhoods , float [][] centroids , int [] assignments )
202
- throws IOException {
193
+ private int [] assignSpilled (
194
+ FloatVectorValues vectors ,
195
+ List <NeighborHood > neighborhoods ,
196
+ float [][] centroids ,
197
+ int [] assignments ,
198
+ float soarLambda
199
+ ) throws IOException {
203
200
// SOAR uses an adjusted distance for assigning spilled documents which is
204
201
// given by:
205
202
//
@@ -264,6 +261,10 @@ private int[] assignSpilled(FloatVectorValues vectors, List<NeighborHood> neighb
264
261
return spilledAssignments ;
265
262
}
266
263
264
+ record NeighborHood (int [] neighbors , float maxIntraDistance ) {
265
+ static final NeighborHood EMPTY = new NeighborHood (new int [0 ], Float .POSITIVE_INFINITY );
266
+ }
267
+
267
268
/**
268
269
* cluster using a Lloyd k-means algorithm that is not neighbor aware
269
270
*
@@ -274,11 +275,7 @@ private int[] assignSpilled(FloatVectorValues vectors, List<NeighborHood> neighb
274
275
* @throws IOException is thrown if vectors is inaccessible
275
276
*/
276
277
void cluster (FloatVectorValues vectors , KMeansIntermediate kMeansIntermediate ) throws IOException {
277
- cluster (vectors , kMeansIntermediate , false );
278
- }
279
-
280
- record NeighborHood (int [] neighbors , float maxIntraDistance ) {
281
- static final NeighborHood EMPTY = new NeighborHood (new int [0 ], Float .POSITIVE_INFINITY );
278
+ doCluster (vectors , kMeansIntermediate , -1 , -1 );
282
279
}
283
280
284
281
/**
@@ -290,12 +287,23 @@ record NeighborHood(int[] neighbors, float maxIntraDistance) {
290
287
* the prior assignments of the given vectors; care should be taken in
291
288
* passing in a valid output object with a centroids array that is the size of centroids expected
292
289
* and assignments that are the same size as the vectors. The SOAR assignments are overwritten by this operation.
293
- * @param neighborAware whether nearby neighboring centroids and their vectors should be used to update the centroid positions,
294
- * implies SOAR assignments
295
- * @throws IOException is thrown if vectors is inaccessible
290
+ * @param clustersPerNeighborhood number of nearby neighboring centroids to be used to update the centroid positions.
291
+ * @param soarLambda lambda used for SOAR assignments
292
+ *
293
+ * @throws IOException is thrown if vectors is inaccessible; an IllegalArgumentException (not an IOException) is thrown if clustersPerNeighborhood is less than 2
296
294
*/
297
- void cluster (FloatVectorValues vectors , KMeansIntermediate kMeansIntermediate , boolean neighborAware ) throws IOException {
295
+ void cluster (FloatVectorValues vectors , KMeansIntermediate kMeansIntermediate , int clustersPerNeighborhood , float soarLambda )
296
+ throws IOException {
297
+ if (clustersPerNeighborhood < 2 ) {
298
+ throw new IllegalArgumentException ("clustersPerNeighborhood must be at least 2, got [" + clustersPerNeighborhood + "]" );
299
+ }
300
+ doCluster (vectors , kMeansIntermediate , clustersPerNeighborhood , soarLambda );
301
+ }
302
+
303
+ private void doCluster (FloatVectorValues vectors , KMeansIntermediate kMeansIntermediate , int clustersPerNeighborhood , float soarLambda )
304
+ throws IOException {
298
305
float [][] centroids = kMeansIntermediate .centroids ();
306
+ boolean neighborAware = clustersPerNeighborhood != -1 && centroids .length > 1 ;
299
307
300
308
List <NeighborHood > neighborhoods = null ;
301
309
// if there are very few centroids, don't bother with neighborhoods or neighbor aware clustering
@@ -308,11 +316,11 @@ void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, b
308
316
computeNeighborhoods (centroids , neighborhoods , clustersPerNeighborhood );
309
317
}
310
318
cluster (vectors , kMeansIntermediate , neighborhoods );
311
- if (neighborAware && clustersPerNeighborhood > 0 ) {
319
+ if (neighborAware ) {
312
320
int [] assignments = kMeansIntermediate .assignments ();
313
321
assert assignments != null ;
314
322
assert assignments .length == vectors .size ();
315
- kMeansIntermediate .setSoarAssignments (assignSpilled (vectors , neighborhoods , centroids , assignments ));
323
+ kMeansIntermediate .setSoarAssignments (assignSpilled (vectors , neighborhoods , centroids , assignments , soarLambda ));
316
324
}
317
325
}
318
326
0 commit comments