16
16
import org .elasticsearch .simdvec .ESVectorUtil ;
17
17
18
18
import java .io .IOException ;
19
- import java .util .ArrayList ;
20
19
import java .util .Arrays ;
21
- import java .util .List ;
22
20
import java .util .Random ;
23
21
24
22
/**
@@ -74,7 +72,7 @@ private static boolean stepLloyd(
74
72
float [][] centroids ,
75
73
float [][] nextCentroids ,
76
74
int [] assignments ,
77
- List < NeighborHood > neighborhoods
75
+ NeighborHood [] neighborhoods
78
76
) throws IOException {
79
77
boolean changed = false ;
80
78
int dim = vectors .dimension ();
@@ -90,7 +88,7 @@ private static boolean stepLloyd(
90
88
final int assignment = assignments [vectorOrd ];
91
89
final int bestCentroidOffset ;
92
90
if (neighborhoods != null ) {
93
- bestCentroidOffset = getBestCentroidFromNeighbours (centroids , vector , assignment , neighborhoods . get ( assignment ) );
91
+ bestCentroidOffset = getBestCentroidFromNeighbours (centroids , vector , assignment , neighborhoods [ assignment ] );
94
92
} else {
95
93
bestCentroidOffset = getBestCentroid (centroids , vector );
96
94
}
@@ -152,30 +150,27 @@ private static int getBestCentroid(float[][] centroids, float[] vector) {
152
150
return bestCentroidOffset ;
153
151
}
154
152
155
- private void computeNeighborhoods (float [][] centers , List <NeighborHood > neighborhoods , int clustersPerNeighborhood ) {
156
- int k = neighborhoods .size ();
157
-
158
- if (k == 0 || clustersPerNeighborhood <= 0 ) {
159
- return ;
160
- }
161
-
162
- List <NeighborQueue > neighborQueues = new ArrayList <>(k );
153
+ private NeighborHood [] computeNeighborhoods (float [][] centers , int clustersPerNeighborhood ) {
154
+ int k = centers .length ;
155
+ assert k > clustersPerNeighborhood ;
156
+ NeighborQueue [] neighborQueues = new NeighborQueue [k ];
163
157
for (int i = 0 ; i < k ; i ++) {
164
- neighborQueues . add ( new NeighborQueue (clustersPerNeighborhood , true ) );
158
+ neighborQueues [ i ] = new NeighborQueue (clustersPerNeighborhood , true );
165
159
}
166
160
for (int i = 0 ; i < k - 1 ; i ++) {
167
161
for (int j = i + 1 ; j < k ; j ++) {
168
162
float dsq = VectorUtil .squareDistance (centers [i ], centers [j ]);
169
- neighborQueues . get ( j ) .insertWithOverflow (i , dsq );
170
- neighborQueues . get ( i ) .insertWithOverflow (j , dsq );
163
+ neighborQueues [ j ] .insertWithOverflow (i , dsq );
164
+ neighborQueues [ i ] .insertWithOverflow (j , dsq );
171
165
}
172
166
}
173
167
168
+ NeighborHood [] neighborhoods = new NeighborHood [k ];
174
169
for (int i = 0 ; i < k ; i ++) {
175
- NeighborQueue queue = neighborQueues . get ( i ) ;
170
+ NeighborQueue queue = neighborQueues [ i ] ;
176
171
if (queue .size () == 0 ) {
177
172
// no neighbors, skip
178
- neighborhoods . set ( i , NeighborHood .EMPTY ) ;
173
+ neighborhoods [ i ] = NeighborHood .EMPTY ;
179
174
continue ;
180
175
}
181
176
// consume the queue into the neighbors array and get the maximum intra-cluster distance
@@ -185,16 +180,15 @@ private void computeNeighborhoods(float[][] centers, List<NeighborHood> neighbor
185
180
while (queue .size () > 0 ) {
186
181
neighbors [neighbors .length - ++iter ] = queue .pop ();
187
182
}
188
- NeighborHood neighborHood = new NeighborHood (neighbors , maxIntraDistance );
189
- neighborhoods .set (i , neighborHood );
183
+ neighborhoods [i ] = new NeighborHood (neighbors , maxIntraDistance );
190
184
}
185
+ return neighborhoods ;
191
186
}
192
187
193
- private int [] assignSpilled (
188
+ private void assignSpilled (
194
189
FloatVectorValues vectors ,
195
- List <NeighborHood > neighborhoods ,
196
- float [][] centroids ,
197
- int [] assignments ,
190
+ KMeansIntermediate kmeansIntermediate ,
191
+ NeighborHood [] neighborhoods ,
198
192
float soarLambda
199
193
) throws IOException {
200
194
// SOAR uses an adjusted distance for assigning spilled documents which is
@@ -205,8 +199,13 @@ private int[] assignSpilled(
205
199
// Here, x is the document, c is the nearest centroid, and c_1 is the first
206
200
// centroid the document was assigned to. The document is assigned to the
207
201
// cluster with the smallest soar(x, c).
208
-
209
- int [] spilledAssignments = new int [assignments .length ];
202
+ int [] assignments = kmeansIntermediate .assignments ();
203
+ assert assignments != null ;
204
+ assert assignments .length == vectors .size ();
205
+ int [] spilledAssignments = kmeansIntermediate .soarAssignments ();
206
+ assert spilledAssignments != null ;
207
+ assert spilledAssignments .length == vectors .size ();
208
+ float [][] centroids = kmeansIntermediate .centroids ();
210
209
211
210
float [] diffs = new float [vectors .dimension ()];
212
211
for (int i = 0 ; i < vectors .size (); i ++) {
@@ -230,8 +229,8 @@ private int[] assignSpilled(
230
229
int centroidCount = centroids .length ;
231
230
IntToIntFunction centroidOrds = c -> c ;
232
231
if (neighborhoods != null ) {
233
- assert neighborhoods . get ( currAssignment ) != null ;
234
- NeighborHood neighborhood = neighborhoods . get ( currAssignment ) ;
232
+ assert neighborhoods [ currAssignment ] != null ;
233
+ NeighborHood neighborhood = neighborhoods [ currAssignment ] ;
235
234
centroidCount = neighborhood .neighbors .length ;
236
235
centroidOrds = c -> neighborhood .neighbors [c ];
237
236
}
@@ -257,8 +256,6 @@ private int[] assignSpilled(
257
256
assert bestAssignment != -1 : "Failed to assign soar vector to centroid" ;
258
257
spilledAssignments [i ] = bestAssignment ;
259
258
}
260
-
261
- return spilledAssignments ;
262
259
}
263
260
264
261
record NeighborHood (int [] neighbors , float maxIntraDistance ) {
@@ -304,27 +301,20 @@ private void doCluster(FloatVectorValues vectors, KMeansIntermediate kMeansInter
304
301
throws IOException {
305
302
float [][] centroids = kMeansIntermediate .centroids ();
306
303
boolean neighborAware = clustersPerNeighborhood != -1 && centroids .length > 1 ;
307
-
308
- List <NeighborHood > neighborhoods = null ;
304
+ NeighborHood [] neighborhoods = null ;
309
305
// if there are very few centroids, don't bother with neighborhoods or neighbor aware clustering
310
306
if (neighborAware && centroids .length > clustersPerNeighborhood ) {
311
- int k = centroids .length ;
312
- neighborhoods = new ArrayList <>(k );
313
- for (int i = 0 ; i < k ; ++i ) {
314
- neighborhoods .add (null );
315
- }
316
- computeNeighborhoods (centroids , neighborhoods , clustersPerNeighborhood );
307
+ neighborhoods = computeNeighborhoods (centroids , clustersPerNeighborhood );
317
308
}
318
309
cluster (vectors , kMeansIntermediate , neighborhoods );
319
310
if (neighborAware ) {
320
- int [] assignments = kMeansIntermediate .assignments ();
321
- assert assignments != null ;
322
- assert assignments .length == vectors .size ();
323
- kMeansIntermediate .setSoarAssignments (assignSpilled (vectors , neighborhoods , centroids , assignments , soarLambda ));
311
+ assert kMeansIntermediate .soarAssignments ().length == 0 ;
312
+ kMeansIntermediate .setSoarAssignments (new int [vectors .size ()]);
313
+ assignSpilled (vectors , kMeansIntermediate , neighborhoods , soarLambda );
324
314
}
325
315
}
326
316
327
- private void cluster (FloatVectorValues vectors , KMeansIntermediate kMeansIntermediate , List < NeighborHood > neighborhoods )
317
+ private void cluster (FloatVectorValues vectors , KMeansIntermediate kMeansIntermediate , NeighborHood [] neighborhoods )
328
318
throws IOException {
329
319
float [][] centroids = kMeansIntermediate .centroids ();
330
320
int k = centroids .length ;
0 commit comments