2626 * with finalizing nearby pre-established clusters and generate
2727 * <a href="https://research.google/blog/soar-new-algorithms-for-even-faster-vector-search-with-scann/">SOAR</a> assignments
2828 */
29- class KMeansLocal {
29+ abstract class KMeansLocal {
3030
3131 // The minimum distance that is considered to be "far enough" from a centroid in order to compute the SOAR distance.
3232 // Vectors that are closer than this distance to the centroid don't get spilled because they are well represented
3333 // by the centroid itself. In many cases, it indicates a degenerate distribution, e.g. the cluster is composed of
3434 // many equal vectors.
3535 private static final float SOAR_MIN_DISTANCE = 1e-16f ;
3636
37- final int sampleSize ;
38- final int maxIterations ;
37+ private final int sampleSize ;
38+ private final int maxIterations ;
3939
4040 KMeansLocal (int sampleSize , int maxIterations ) {
4141 this .sampleSize = sampleSize ;
4242 this .maxIterations = maxIterations ;
4343 }
4444
45+ /** Number of workers to use for parallelism **/
46+ protected abstract int numWorkers ();
47+
48+ /** assign to each vector the closest centroid **/
49+ protected abstract boolean stepLloyd (
50+ FloatVectorValues vectors ,
51+ IntToIntFunction translateOrd ,
52+ float [][] centroids ,
53+ FixedBitSet [] centroidChangedSlices ,
54+ int [] assignments ,
55+ NeighborHood [] neighborHoods
56+ ) throws IOException ;
57+
58+ /** assign to each vector the soar assignment **/
59+ protected abstract void assignSpilled (
60+ FloatVectorValues vectors ,
61+ KMeansIntermediate kmeansIntermediate ,
62+ NeighborHood [] neighborhoods ,
63+ float soarLambda
64+ ) throws IOException ;
65+
4566 /**
4667 * Uses a reservoir sampling approach to pick the initial centroids, which are subsequently expected
4768 * to be used by a clustering algorithm
@@ -69,20 +90,21 @@ static float[][] pickInitialCentroids(FloatVectorValues vectors, int centroidCou
6990 return centroids ;
7091 }
7192
72- private static boolean stepLloyd (
93+ /** Assign vectors from {@code start} to {@code end} to the closest centroid. */
94+ protected static boolean stepLloydSlice (
7395 FloatVectorValues vectors ,
7496 IntToIntFunction translateOrd ,
7597 float [][] centroids ,
7698 FixedBitSet centroidChanged ,
77- int [] centroidCounts ,
7899 int [] assignments ,
79- NeighborHood [] neighborhoods
100+ NeighborHood [] neighborhoods ,
101+ int start ,
102+ int end
80103 ) throws IOException {
81104 boolean changed = false ;
82- int dim = vectors .dimension ();
83105 centroidChanged .clear ();
84106 final float [] distances = new float [4 ];
85- for (int idx = 0 ; idx < vectors . size () ; idx ++) {
107+ for (int idx = start ; idx < end ; idx ++) {
86108 float [] vector = vectors .vectorValue (idx );
87109 int vectorOrd = translateOrd .apply (idx );
88110 final int assignment = assignments [vectorOrd ];
@@ -101,36 +123,49 @@ private static boolean stepLloyd(
101123 changed = true ;
102124 }
103125 }
104- if (changed ) {
105- Arrays .fill (centroidCounts , 0 );
106- for (int idx = 0 ; idx < vectors .size (); idx ++) {
107- final int assignment = assignments [translateOrd .apply (idx )];
108- if (centroidChanged .get (assignment )) {
109- float [] centroid = centroids [assignment ];
110- if (centroidCounts [assignment ]++ == 0 ) {
111- Arrays .fill (centroid , 0.0f );
112- }
113- float [] vector = vectors .vectorValue (idx );
126+ return changed ;
127+ }
128+
129+ private static void updateCentroids (
130+ FloatVectorValues vectors ,
131+ IntToIntFunction translateOrd ,
132+ float [][] centroids ,
133+ FixedBitSet [] centroidChangedSlices ,
134+ int [] centroidCounts ,
135+ int [] assignments
136+ ) throws IOException {
137+ Arrays .fill (centroidCounts , 0 );
138+ FixedBitSet centroidChanged = centroidChangedSlices [0 ];
139+ for (int j = 1 ; j < centroidChangedSlices .length ; j ++) {
140+ centroidChanged .or (centroidChangedSlices [j ]);
141+ }
142+ int dim = vectors .dimension ();
143+ for (int idx = 0 ; idx < vectors .size (); idx ++) {
144+ final int assignment = assignments [translateOrd .apply (idx )];
145+ if (centroidChanged .get (assignment )) {
146+ float [] centroid = centroids [assignment ];
147+ float [] vector = vectors .vectorValue (idx );
148+ if (centroidCounts [assignment ]++ == 0 ) {
149+ System .arraycopy (vector , 0 , centroid , 0 , dim );
150+ } else {
114151 for (int d = 0 ; d < dim ; d ++) {
115152 centroid [d ] += vector [d ];
116153 }
117154 }
118155 }
156+ }
119157
120- for (int clusterIdx = 0 ; clusterIdx < centroids .length ; clusterIdx ++) {
121- if (centroidChanged .get (clusterIdx )) {
122- float count = (float ) centroidCounts [clusterIdx ];
123- if (count > 0 ) {
124- float [] centroid = centroids [clusterIdx ];
125- for (int d = 0 ; d < dim ; d ++) {
126- centroid [d ] /= count ;
127- }
158+ for (int clusterIdx = 0 ; clusterIdx < centroids .length ; clusterIdx ++) {
159+ if (centroidChanged .get (clusterIdx )) {
160+ float count = (float ) centroidCounts [clusterIdx ];
161+ if (count > 0 ) {
162+ float [] centroid = centroids [clusterIdx ];
163+ for (int d = 0 ; d < dim ; d ++) {
164+ centroid [d ] /= count ;
128165 }
129166 }
130167 }
131168 }
132-
133- return changed ;
134169 }
135170
136171 private static int getBestCentroidFromNeighbours (
@@ -211,11 +246,14 @@ private static int getBestCentroid(float[][] centroids, float[] vector, float[]
211246 return bestCentroidOffset ;
212247 }
213248
214- private void assignSpilled (
249+ /** Assign vectors from {@code start} to {@code end} to the SOAR centroid. */
250+ protected static void assignSpilledSlice (
215251 FloatVectorValues vectors ,
216252 KMeansIntermediate kmeansIntermediate ,
217253 NeighborHood [] neighborhoods ,
218- float soarLambda
254+ float soarLambda ,
255+ int start ,
256+ int end
219257 ) throws IOException {
220258 // SOAR uses an adjusted distance for assigning spilled documents which is
221259 // given by:
@@ -235,7 +273,7 @@ private void assignSpilled(
235273
236274 float [] diffs = new float [vectors .dimension ()];
237275 final float [] distances = new float [4 ];
238- for (int i = 0 ; i < vectors . size () ; i ++) {
276+ for (int i = start ; i < end ; i ++) {
239277 float [] vector = vectors .vectorValue (i );
240278 int currAssignment = assignments [i ];
241279 float [] currentCentroid = centroids [currAssignment ];
@@ -308,7 +346,7 @@ private void assignSpilled(
308346 * passing in a valid output object with a centroids array that is the size of centroids expected
309347 * @throws IOException is thrown if vectors is inaccessible
310348 */
311- void cluster (FloatVectorValues vectors , KMeansIntermediate kMeansIntermediate ) throws IOException {
349+ final void cluster (FloatVectorValues vectors , KMeansIntermediate kMeansIntermediate ) throws IOException {
312350 doCluster (vectors , kMeansIntermediate , -1 , -1 );
313351 }
314352
@@ -326,7 +364,7 @@ void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate) t
326364 *
327365 * @throws IOException is thrown if vectors is inaccessible or if the clustersPerNeighborhood is less than 2
328366 */
329- void cluster (FloatVectorValues vectors , KMeansIntermediate kMeansIntermediate , int clustersPerNeighborhood , float soarLambda )
367+ final void cluster (FloatVectorValues vectors , KMeansIntermediate kMeansIntermediate , int clustersPerNeighborhood , float soarLambda )
330368 throws IOException {
331369 if (clustersPerNeighborhood < 2 ) {
332370 throw new IllegalArgumentException ("clustersPerNeighborhood must be at least 2, got [" + clustersPerNeighborhood + "]" );
@@ -370,18 +408,25 @@ private void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansInterme
370408 }
371409
372410 assert assignments .length == n ;
373- FixedBitSet centroidChanged = new FixedBitSet (centroids .length );
411+ FixedBitSet [] centroidChangedSlices = new FixedBitSet [numWorkers ()];
412+ for (int i = 0 ; i < numWorkers (); i ++) {
413+ centroidChangedSlices [i ] = new FixedBitSet (centroids .length );
414+ }
374415 int [] centroidCounts = new int [centroids .length ];
375416 for (int i = 0 ; i < maxIterations ; i ++) {
376417 // This is potentially sampled, so we need to translate ordinals
377- if (stepLloyd (sampledVectors , translateOrd , centroids , centroidChanged , centroidCounts , assignments , neighborhoods ) == false ) {
418+ if (stepLloyd (sampledVectors , translateOrd , centroids , centroidChangedSlices , assignments , neighborhoods )) {
419+ updateCentroids (sampledVectors , translateOrd , centroids , centroidChangedSlices , centroidCounts , assignments );
420+ } else {
378421 break ;
379422 }
380423 }
381424 // If we were sampled, do a once over the full set of vectors to finalize the centroids
382425 if (sampleSize < n || maxIterations == 0 ) {
383426 // No ordinal translation needed here, we are using the full set of vectors
384- stepLloyd (vectors , i -> i , centroids , centroidChanged , centroidCounts , assignments , neighborhoods );
427+ if (stepLloyd (vectors , i -> i , centroids , centroidChangedSlices , assignments , neighborhoods )) {
428+ updateCentroids (vectors , i -> i , centroids , centroidChangedSlices , centroidCounts , assignments );
429+ }
385430 }
386431 }
387432
@@ -396,8 +441,7 @@ private void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansInterme
396441 */
397442 public static void cluster (FloatVectorValues vectors , float [][] centroids , int sampleSize , int maxIterations ) throws IOException {
398443 KMeansIntermediate kMeansIntermediate = new KMeansIntermediate (centroids , new int [vectors .size ()], vectors ::ordToDoc );
399- KMeansLocal kMeans = new KMeansLocal (sampleSize , maxIterations );
444+ KMeansLocal kMeans = new KMeansLocalSerial (sampleSize , maxIterations );
400445 kMeans .cluster (vectors , kMeansIntermediate );
401446 }
402-
403447}
0 commit comments