Skip to content

Commit 2455b00

Browse files
authored
[DiskBBQ] Add concurrency to KMeansLocal (#139239)
Hierarchical k-means algorithm can be run using several threads providing a significant speed up.
1 parent 829b2e9 commit 2455b00

File tree

9 files changed

+330
-58
lines changed

9 files changed

+330
-58
lines changed

docs/changelog/139239.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 139239
2+
summary: "[DiskBBQ] Add concurrency on KMeansLocal"
3+
area: Vector Search
4+
type: enhancement
5+
issues: []

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java

Lines changed: 75 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
package org.elasticsearch.index.codec.vectors.cluster;
1111

1212
import org.apache.lucene.index.FloatVectorValues;
13+
import org.apache.lucene.search.TaskExecutor;
1314

1415
import java.io.IOException;
1516
import java.util.Arrays;
@@ -25,27 +26,84 @@ public class HierarchicalKMeans {
2526
public static final int SAMPLES_PER_CLUSTER_DEFAULT = 64;
2627
public static final float DEFAULT_SOAR_LAMBDA = 1.0f;
2728
public static final int NO_SOAR_ASSIGNMENT = -1;
29+
private static final int MIN_VECTORS_PRE_THREAD = 64;
2830

2931
final int dimension;
3032
final int maxIterations;
3133
final int samplesPerCluster;
3234
final int clustersPerNeighborhood;
3335
final float soarLambda;
3436

35-
public HierarchicalKMeans(int dimension) {
36-
this(dimension, MAX_ITERATIONS_DEFAULT, SAMPLES_PER_CLUSTER_DEFAULT, MAXK, DEFAULT_SOAR_LAMBDA);
37-
}
37+
private final TaskExecutor executor;
38+
private final int numWorkers;
3839

39-
public HierarchicalKMeans(int dimension, int maxIterations, int samplesPerCluster, int clustersPerNeighborhood, float soarLambda) {
40+
private HierarchicalKMeans(
41+
int dimension,
42+
TaskExecutor executor,
43+
int numWorkers,
44+
int maxIterations,
45+
int samplesPerCluster,
46+
int clustersPerNeighborhood,
47+
float soarLambda
48+
) {
4049
this.dimension = dimension;
50+
this.executor = executor;
51+
this.numWorkers = numWorkers;
4152
this.maxIterations = maxIterations;
4253
this.samplesPerCluster = samplesPerCluster;
4354
this.clustersPerNeighborhood = clustersPerNeighborhood;
4455
this.soarLambda = soarLambda;
4556
}
4657

58+
public static HierarchicalKMeans ofSerial(int dimension) {
59+
return ofSerial(dimension, MAX_ITERATIONS_DEFAULT, SAMPLES_PER_CLUSTER_DEFAULT, MAXK, DEFAULT_SOAR_LAMBDA);
60+
}
61+
62+
public static HierarchicalKMeans ofSerial(
63+
int dimension,
64+
int maxIterations,
65+
int samplesPerCluster,
66+
int clustersPerNeighborhood,
67+
float soarLambda
68+
) {
69+
return new HierarchicalKMeans(dimension, null, 1, maxIterations, samplesPerCluster, clustersPerNeighborhood, soarLambda);
70+
}
71+
72+
public static HierarchicalKMeans ofConcurrent(int dimension, TaskExecutor executor, int numWorkers) {
73+
return ofConcurrent(
74+
dimension,
75+
executor,
76+
numWorkers,
77+
MAX_ITERATIONS_DEFAULT,
78+
SAMPLES_PER_CLUSTER_DEFAULT,
79+
MAXK,
80+
DEFAULT_SOAR_LAMBDA
81+
);
82+
}
83+
84+
public static HierarchicalKMeans ofConcurrent(
85+
int dimension,
86+
TaskExecutor executor,
87+
int numWorkers,
88+
int maxIterations,
89+
int samplesPerCluster,
90+
int clustersPerNeighborhood,
91+
float soarLambda
92+
) {
93+
return new HierarchicalKMeans(
94+
dimension,
95+
executor,
96+
numWorkers,
97+
maxIterations,
98+
samplesPerCluster,
99+
clustersPerNeighborhood,
100+
soarLambda
101+
);
102+
103+
}
104+
47105
/**
48-
* clusters or moreso partitions the set of vectors by starting with a rough number of partitions and then recursively refining those
106+
* clusters the set of vectors by starting with a rough number of partitions and then recursively refining those
49107
* lastly a pass is made to adjust nearby neighborhoods and add an extra assignment per vector to nearby neighborhoods
50108
*
51109
* @param vectors the vectors to cluster
@@ -54,7 +112,6 @@ public HierarchicalKMeans(int dimension, int maxIterations, int samplesPerCluste
54112
* @throws IOException is thrown if vectors is inaccessible
55113
*/
56114
public KMeansResult cluster(FloatVectorValues vectors, int targetSize) throws IOException {
57-
58115
if (vectors.size() == 0) {
59116
return new KMeansIntermediate();
60117
}
@@ -80,14 +137,13 @@ public KMeansResult cluster(FloatVectorValues vectors, int targetSize) throws IO
80137
KMeansIntermediate kMeansIntermediate = clusterAndSplit(vectors, targetSize);
81138
if (kMeansIntermediate.centroids().length > 1 && kMeansIntermediate.centroids().length < vectors.size()) {
82139
int localSampleSize = Math.min(kMeansIntermediate.centroids().length * samplesPerCluster / 2, vectors.size());
83-
KMeansLocal kMeansLocal = new KMeansLocal(localSampleSize, maxIterations);
140+
KMeansLocal kMeansLocal = buildKmeansLocal(vectors.size(), localSampleSize);
84141
kMeansLocal.cluster(vectors, kMeansIntermediate, clustersPerNeighborhood, soarLambda);
85142
}
86-
87143
return kMeansIntermediate;
88144
}
89145

90-
KMeansIntermediate clusterAndSplit(final FloatVectorValues vectors, final int targetSize) throws IOException {
146+
private KMeansIntermediate clusterAndSplit(final FloatVectorValues vectors, final int targetSize) throws IOException {
91147
if (vectors.size() <= targetSize) {
92148
return new KMeansIntermediate();
93149
}
@@ -99,10 +155,10 @@ KMeansIntermediate clusterAndSplit(final FloatVectorValues vectors, final int ta
99155
int[] assignments = new int[vectors.size()];
100156
// ensure we don't over assign to cluster 0 without adjusting it
101157
Arrays.fill(assignments, -1);
102-
KMeansLocal kmeans = new KMeansLocal(m, maxIterations);
103158
float[][] centroids = KMeansLocal.pickInitialCentroids(vectors, k);
104159
KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids, assignments, vectors::ordToDoc);
105-
kmeans.cluster(vectors, kMeansIntermediate);
160+
KMeansLocal kMeansLocal = buildKmeansLocal(vectors.size(), m);
161+
kMeansLocal.cluster(vectors, kMeansIntermediate);
106162

107163
// TODO: consider adding cluster size counts to the kmeans algo
108164
// handle assignment here so we can track distance and cluster size
@@ -164,6 +220,14 @@ KMeansIntermediate clusterAndSplit(final FloatVectorValues vectors, final int ta
164220
return kMeansIntermediate;
165221
}
166222

223+
private KMeansLocal buildKmeansLocal(int numVectors, int localSampleSize) {
224+
int numWorkers = Math.min(this.numWorkers, numVectors / MIN_VECTORS_PRE_THREAD);
225+
// if there is no executor or there is no enough vectors for more than one thread, use the serial version
226+
return executor == null || numWorkers <= 1
227+
? new KMeansLocalSerial(localSampleSize, maxIterations)
228+
: new KMeansLocalConcurrent(executor, numWorkers, localSampleSize, maxIterations);
229+
}
230+
167231
static FloatVectorValues createClusterSlice(int clusterSize, int cluster, FloatVectorValues vectors, int[] assignments) {
168232
int[] slice = new int[clusterSize];
169233
int idx = 0;

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java

Lines changed: 82 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -26,22 +26,43 @@
2626
* with finalizing nearby pre-established clusters and generate
2727
* <a href="https://research.google/blog/soar-new-algorithms-for-even-faster-vector-search-with-scann/">SOAR</a> assignments
2828
*/
29-
class KMeansLocal {
29+
abstract class KMeansLocal {
3030

3131
// the minimum distance that is considered to be "far enough" to a centroid in order to compute the soar distance.
3232
// For vectors that are closer than this distance to the centroid don't get spilled because they are well represented
3333
// by the centroid itself. In many cases, it indicates a degenerated distribution, e.g the cluster is composed of the
3434
// many equal vectors.
3535
private static final float SOAR_MIN_DISTANCE = 1e-16f;
3636

37-
final int sampleSize;
38-
final int maxIterations;
37+
private final int sampleSize;
38+
private final int maxIterations;
3939

4040
KMeansLocal(int sampleSize, int maxIterations) {
4141
this.sampleSize = sampleSize;
4242
this.maxIterations = maxIterations;
4343
}
4444

45+
/** Number of workers to use for parallelism **/
46+
protected abstract int numWorkers();
47+
48+
/** assign to each vector the closest centroid **/
49+
protected abstract boolean stepLloyd(
50+
FloatVectorValues vectors,
51+
IntToIntFunction translateOrd,
52+
float[][] centroids,
53+
FixedBitSet[] centroidChangedSlices,
54+
int[] assignments,
55+
NeighborHood[] neighborHoods
56+
) throws IOException;
57+
58+
/** assign to each vector the soar assignment **/
59+
protected abstract void assignSpilled(
60+
FloatVectorValues vectors,
61+
KMeansIntermediate kmeansIntermediate,
62+
NeighborHood[] neighborhoods,
63+
float soarLambda
64+
) throws IOException;
65+
4566
/**
4667
* uses a Reservoir Sampling approach to picking the initial centroids which are subsequently expected
4768
* to be used by a clustering algorithm
@@ -69,20 +90,21 @@ static float[][] pickInitialCentroids(FloatVectorValues vectors, int centroidCou
6990
return centroids;
7091
}
7192

72-
private static boolean stepLloyd(
93+
/** Assign vectors from {@code start} to {@code end} to the closest centroid. */
94+
protected static boolean stepLloydSlice(
7395
FloatVectorValues vectors,
7496
IntToIntFunction translateOrd,
7597
float[][] centroids,
7698
FixedBitSet centroidChanged,
77-
int[] centroidCounts,
7899
int[] assignments,
79-
NeighborHood[] neighborhoods
100+
NeighborHood[] neighborhoods,
101+
int start,
102+
int end
80103
) throws IOException {
81104
boolean changed = false;
82-
int dim = vectors.dimension();
83105
centroidChanged.clear();
84106
final float[] distances = new float[4];
85-
for (int idx = 0; idx < vectors.size(); idx++) {
107+
for (int idx = start; idx < end; idx++) {
86108
float[] vector = vectors.vectorValue(idx);
87109
int vectorOrd = translateOrd.apply(idx);
88110
final int assignment = assignments[vectorOrd];
@@ -101,36 +123,49 @@ private static boolean stepLloyd(
101123
changed = true;
102124
}
103125
}
104-
if (changed) {
105-
Arrays.fill(centroidCounts, 0);
106-
for (int idx = 0; idx < vectors.size(); idx++) {
107-
final int assignment = assignments[translateOrd.apply(idx)];
108-
if (centroidChanged.get(assignment)) {
109-
float[] centroid = centroids[assignment];
110-
if (centroidCounts[assignment]++ == 0) {
111-
Arrays.fill(centroid, 0.0f);
112-
}
113-
float[] vector = vectors.vectorValue(idx);
126+
return changed;
127+
}
128+
129+
private static void updateCentroids(
130+
FloatVectorValues vectors,
131+
IntToIntFunction translateOrd,
132+
float[][] centroids,
133+
FixedBitSet[] centroidChangedSlices,
134+
int[] centroidCounts,
135+
int[] assignments
136+
) throws IOException {
137+
Arrays.fill(centroidCounts, 0);
138+
FixedBitSet centroidChanged = centroidChangedSlices[0];
139+
for (int j = 1; j < centroidChangedSlices.length; j++) {
140+
centroidChanged.or(centroidChangedSlices[j]);
141+
}
142+
int dim = vectors.dimension();
143+
for (int idx = 0; idx < vectors.size(); idx++) {
144+
final int assignment = assignments[translateOrd.apply(idx)];
145+
if (centroidChanged.get(assignment)) {
146+
float[] centroid = centroids[assignment];
147+
float[] vector = vectors.vectorValue(idx);
148+
if (centroidCounts[assignment]++ == 0) {
149+
System.arraycopy(vector, 0, centroid, 0, dim);
150+
} else {
114151
for (int d = 0; d < dim; d++) {
115152
centroid[d] += vector[d];
116153
}
117154
}
118155
}
156+
}
119157

120-
for (int clusterIdx = 0; clusterIdx < centroids.length; clusterIdx++) {
121-
if (centroidChanged.get(clusterIdx)) {
122-
float count = (float) centroidCounts[clusterIdx];
123-
if (count > 0) {
124-
float[] centroid = centroids[clusterIdx];
125-
for (int d = 0; d < dim; d++) {
126-
centroid[d] /= count;
127-
}
158+
for (int clusterIdx = 0; clusterIdx < centroids.length; clusterIdx++) {
159+
if (centroidChanged.get(clusterIdx)) {
160+
float count = (float) centroidCounts[clusterIdx];
161+
if (count > 0) {
162+
float[] centroid = centroids[clusterIdx];
163+
for (int d = 0; d < dim; d++) {
164+
centroid[d] /= count;
128165
}
129166
}
130167
}
131168
}
132-
133-
return changed;
134169
}
135170

136171
private static int getBestCentroidFromNeighbours(
@@ -211,11 +246,14 @@ private static int getBestCentroid(float[][] centroids, float[] vector, float[]
211246
return bestCentroidOffset;
212247
}
213248

214-
private void assignSpilled(
249+
/** Assign vectors from {@code start} to {@code end} to the SOAR centroid. */
250+
protected static void assignSpilledSlice(
215251
FloatVectorValues vectors,
216252
KMeansIntermediate kmeansIntermediate,
217253
NeighborHood[] neighborhoods,
218-
float soarLambda
254+
float soarLambda,
255+
int start,
256+
int end
219257
) throws IOException {
220258
// SOAR uses an adjusted distance for assigning spilled documents which is
221259
// given by:
@@ -235,7 +273,7 @@ private void assignSpilled(
235273

236274
float[] diffs = new float[vectors.dimension()];
237275
final float[] distances = new float[4];
238-
for (int i = 0; i < vectors.size(); i++) {
276+
for (int i = start; i < end; i++) {
239277
float[] vector = vectors.vectorValue(i);
240278
int currAssignment = assignments[i];
241279
float[] currentCentroid = centroids[currAssignment];
@@ -308,7 +346,7 @@ private void assignSpilled(
308346
* passing in a valid output object with a centroids array that is the size of centroids expected
309347
* @throws IOException is thrown if vectors is inaccessible
310348
*/
311-
void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate) throws IOException {
349+
final void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate) throws IOException {
312350
doCluster(vectors, kMeansIntermediate, -1, -1);
313351
}
314352

@@ -326,7 +364,7 @@ void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate) t
326364
*
327365
* @throws IOException is thrown if vectors is inaccessible or if the clustersPerNeighborhood is less than 2
328366
*/
329-
void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, int clustersPerNeighborhood, float soarLambda)
367+
final void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, int clustersPerNeighborhood, float soarLambda)
330368
throws IOException {
331369
if (clustersPerNeighborhood < 2) {
332370
throw new IllegalArgumentException("clustersPerNeighborhood must be at least 2, got [" + clustersPerNeighborhood + "]");
@@ -370,18 +408,25 @@ private void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansInterme
370408
}
371409

372410
assert assignments.length == n;
373-
FixedBitSet centroidChanged = new FixedBitSet(centroids.length);
411+
FixedBitSet[] centroidChangedSlices = new FixedBitSet[numWorkers()];
412+
for (int i = 0; i < numWorkers(); i++) {
413+
centroidChangedSlices[i] = new FixedBitSet(centroids.length);
414+
}
374415
int[] centroidCounts = new int[centroids.length];
375416
for (int i = 0; i < maxIterations; i++) {
376417
// This is potentially sampled, so we need to translate ordinals
377-
if (stepLloyd(sampledVectors, translateOrd, centroids, centroidChanged, centroidCounts, assignments, neighborhoods) == false) {
418+
if (stepLloyd(sampledVectors, translateOrd, centroids, centroidChangedSlices, assignments, neighborhoods)) {
419+
updateCentroids(sampledVectors, translateOrd, centroids, centroidChangedSlices, centroidCounts, assignments);
420+
} else {
378421
break;
379422
}
380423
}
381424
// If we were sampled, do a once over the full set of vectors to finalize the centroids
382425
if (sampleSize < n || maxIterations == 0) {
383426
// No ordinal translation needed here, we are using the full set of vectors
384-
stepLloyd(vectors, i -> i, centroids, centroidChanged, centroidCounts, assignments, neighborhoods);
427+
if (stepLloyd(vectors, i -> i, centroids, centroidChangedSlices, assignments, neighborhoods)) {
428+
updateCentroids(sampledVectors, translateOrd, centroids, centroidChangedSlices, centroidCounts, assignments);
429+
}
385430
}
386431
}
387432

@@ -396,8 +441,7 @@ private void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansInterme
396441
*/
397442
public static void cluster(FloatVectorValues vectors, float[][] centroids, int sampleSize, int maxIterations) throws IOException {
398443
KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids, new int[vectors.size()], vectors::ordToDoc);
399-
KMeansLocal kMeans = new KMeansLocal(sampleSize, maxIterations);
444+
KMeansLocal kMeans = new KMeansLocalSerial(sampleSize, maxIterations);
400445
kMeans.cluster(vectors, kMeansIntermediate);
401446
}
402-
403447
}

0 commit comments

Comments
 (0)