Skip to content

Commit 8f5e51a

Browse files
authored
Add query_vector_builder Support to Diversify Retriever (#139094)
* Add query_vector_builder support to diversify
* Add docs; rewrite tests; cleanups
* Context: return null if no query vector
* Add YAML test to verify use of query_vector_builder
* Rename YAML test; validate builder; cleanups
1 parent 2db252e commit 8f5e51a

File tree

10 files changed

+382
-63
lines changed

10 files changed

+382
-63
lines changed

docs/reference/elasticsearch/rest-apis/retrievers/diversify-retriever.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,14 @@ The ordering of results returned from the inner retriever is preserved.
5050

5151
Query vector. Must have the same number of dimensions as the vector field you are searching against.
5252
Must be either an array of floats or a hex-encoded byte vector.
53+
If you provide a `query_vector`, you cannot also provide a `query_vector_builder`.
54+
55+
`query_vector_builder`
56+
: (Optional, query vector builder object)
57+
58+
Defines a [model](docs-content://solutions/search/vector/knn.md#knn-semantic-search) to build a query vector.
59+
If you provide a `query_vector_builder`, you cannot also provide a `query_vector`.
60+
5361

5462
`lambda`
5563
: (Required for `mmr`, float)

server/src/main/java/org/elasticsearch/search/diversification/DiversifyRetrieverBuilder.java

Lines changed: 84 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
package org.elasticsearch.search.diversification;
1111

1212
import org.apache.lucene.search.ScoreDoc;
13+
import org.apache.lucene.util.SetOnce;
1314
import org.elasticsearch.ElasticsearchException;
1415
import org.elasticsearch.ElasticsearchStatusException;
1516
import org.elasticsearch.action.ActionRequestValidationException;
@@ -26,6 +27,7 @@
2627
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
2728
import org.elasticsearch.search.retriever.RetrieverBuilder;
2829
import org.elasticsearch.search.retriever.RetrieverParserContext;
30+
import org.elasticsearch.search.vectors.QueryVectorBuilder;
2931
import org.elasticsearch.search.vectors.VectorData;
3032
import org.elasticsearch.xcontent.ConstructingObjectParser;
3133
import org.elasticsearch.xcontent.ObjectParser;
@@ -40,15 +42,16 @@
4042
import java.util.Locale;
4143
import java.util.Map;
4244
import java.util.Objects;
45+
import java.util.function.Supplier;
4346

4447
import static org.elasticsearch.action.ValidateActions.addValidationError;
48+
import static org.elasticsearch.common.Strings.format;
4549
import static org.elasticsearch.search.rank.RankBuilder.DEFAULT_RANK_WINDOW_SIZE;
4650
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
4751
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
4852

4953
public final class DiversifyRetrieverBuilder extends CompoundRetrieverBuilder<DiversifyRetrieverBuilder> {
5054

51-
public static final Float DEFAULT_LAMBDA_VALUE = 0.7f;
5255
public static final int DEFAULT_SIZE_VALUE = 10;
5356

5457
public static final NodeFeature RETRIEVER_RESULT_DIVERSIFICATION_MMR_FEATURE = new NodeFeature("retriever.result_diversification_mmr");
@@ -58,6 +61,7 @@ public final class DiversifyRetrieverBuilder extends CompoundRetrieverBuilder<Di
5861
public static final ParseField TYPE_FIELD = new ParseField("type");
5962
public static final ParseField FIELD_FIELD = new ParseField("field");
6063
public static final ParseField QUERY_VECTOR_FIELD = new ParseField("query_vector");
64+
public static final ParseField QUERY_VECTOR_BUILDER_FIELD = new ParseField("query_vector_builder");
6165
public static final ParseField LAMBDA_FIELD = new ParseField("lambda");
6266
public static final ParseField SIZE_FIELD = new ParseField("size");
6367

@@ -83,8 +87,9 @@ public SearchHit hit() {
8387
int rankWindowSize = args[3] == null ? DEFAULT_RANK_WINDOW_SIZE : (int) args[3];
8488

8589
VectorData queryVector = args[4] == null ? null : (VectorData) args[4];
86-
Float lambda = args[5] == null ? null : (Float) args[5];
87-
Integer size = args[6] == null ? null : (Integer) args[6];
90+
QueryVectorBuilder queryVectorBuilder = args[5] == null ? null : (QueryVectorBuilder) args[5];
91+
Float lambda = args[6] == null ? null : (Float) args[6];
92+
Integer size = args[7] == null ? null : (Integer) args[7];
8893

8994
return new DiversifyRetrieverBuilder(
9095
RetrieverSource.from((RetrieverBuilder) args[0]),
@@ -93,6 +98,7 @@ public SearchHit hit() {
9398
rankWindowSize,
9499
size,
95100
queryVector,
101+
queryVectorBuilder,
96102
lambda
97103
);
98104
}
@@ -113,17 +119,22 @@ public SearchHit hit() {
113119
QUERY_VECTOR_FIELD,
114120
ObjectParser.ValueType.OBJECT_ARRAY_STRING_OR_NUMBER
115121
);
122+
PARSER.declareNamedObject(
123+
optionalConstructorArg(),
124+
(p, c, n) -> p.namedObject(QueryVectorBuilder.class, n, c),
125+
QUERY_VECTOR_BUILDER_FIELD
126+
);
116127
PARSER.declareFloat(optionalConstructorArg(), LAMBDA_FIELD);
117128
PARSER.declareInt(optionalConstructorArg(), SIZE_FIELD);
118129
RetrieverBuilder.declareBaseParserFields(PARSER);
119130
}
120131

121132
private final ResultDiversificationType diversificationType;
122133
private final String diversificationField;
123-
private final VectorData queryVector;
134+
private final Supplier<VectorData> queryVector;
135+
private final QueryVectorBuilder queryVectorBuilder;
124136
private final Float lambda;
125137
private final Integer size;
126-
private ResultDiversificationContext diversificationContext = null;
127138

128139
DiversifyRetrieverBuilder(
129140
RetrieverSource innerRetriever,
@@ -132,12 +143,14 @@ public SearchHit hit() {
132143
int rankWindowSize,
133144
@Nullable Integer size,
134145
@Nullable VectorData queryVector,
146+
@Nullable QueryVectorBuilder queryVectorBuilder,
135147
@Nullable Float lambda
136148
) {
137149
super(List.of(innerRetriever), rankWindowSize);
138150
this.diversificationType = diversificationType;
139151
this.diversificationField = diversificationField;
140-
this.queryVector = queryVector;
152+
this.queryVector = queryVector != null ? () -> queryVector : null;
153+
this.queryVectorBuilder = queryVectorBuilder;
141154
this.lambda = lambda;
142155
this.size = size == null ? Math.min(DEFAULT_SIZE_VALUE, rankWindowSize) : size;
143156
}
@@ -148,7 +161,8 @@ public SearchHit hit() {
148161
String diversificationField,
149162
int rankWindowSize,
150163
@Nullable Integer size,
151-
@Nullable VectorData queryVector,
164+
@Nullable Supplier<VectorData> queryVector,
165+
@Nullable QueryVectorBuilder queryVectorBuilder,
152166
@Nullable Float lambda
153167
) {
154168
super(innerRetrievers, rankWindowSize);
@@ -157,6 +171,7 @@ public SearchHit hit() {
157171
this.diversificationType = diversificationType;
158172
this.diversificationField = diversificationField;
159173
this.queryVector = queryVector;
174+
this.queryVectorBuilder = queryVectorBuilder;
160175
this.lambda = lambda;
161176
this.size = size == null ? Math.min(DEFAULT_SIZE_VALUE, rankWindowSize) : size;
162177
}
@@ -170,6 +185,7 @@ protected DiversifyRetrieverBuilder clone(List<RetrieverSource> newChildRetrieve
170185
rankWindowSize,
171186
size,
172187
queryVector,
188+
queryVectorBuilder,
173189
lambda
174190
);
175191
}
@@ -181,6 +197,19 @@ public ActionRequestValidationException validate(
181197
boolean isScroll,
182198
boolean allowPartialSearchResults
183199
) {
200+
if (queryVector != null && queryVectorBuilder != null) {
201+
validationException = addValidationError(
202+
String.format(
203+
Locale.ROOT,
204+
"[%s] MMR result diversification can have one of [%s] or [%s], but not both",
205+
getName(),
206+
QUERY_VECTOR_FIELD.getPreferredName(),
207+
QUERY_VECTOR_BUILDER_FIELD.getPreferredName()
208+
),
209+
validationException
210+
);
211+
}
212+
184213
if (diversificationType.equals(ResultDiversificationType.MMR)) {
185214
validationException = validateMMRDiversification(validationException);
186215
}
@@ -235,17 +264,37 @@ private ActionRequestValidationException validateMMRDiversification(ActionReques
235264

236265
@Override
237266
protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) {
238-
if (diversificationType.equals(ResultDiversificationType.MMR)) {
239-
// field vectors will be filled in during the combine
240-
diversificationContext = new MMRResultDiversificationContext(
267+
if (queryVectorBuilder != null) {
268+
SetOnce<VectorData> toSet = new SetOnce<>();
269+
ctx.registerAsyncAction((c, l) -> {
270+
queryVectorBuilder.buildVector(c, l.delegateFailureAndWrap((ll, v) -> {
271+
toSet.set(v == null ? null : new VectorData(v));
272+
if (v == null) {
273+
ll.onFailure(
274+
new IllegalArgumentException(
275+
format(
276+
"[%s] with name [%s] returned null query_vector",
277+
QUERY_VECTOR_BUILDER_FIELD.getPreferredName(),
278+
queryVectorBuilder.getWriteableName()
279+
)
280+
)
281+
);
282+
return;
283+
}
284+
ll.onResponse(null);
285+
}));
286+
});
287+
288+
return new DiversifyRetrieverBuilder(
289+
innerRetrievers,
290+
diversificationType,
241291
diversificationField,
242-
lambda,
243-
size == null ? DEFAULT_SIZE_VALUE : size,
244-
queryVector
292+
rankWindowSize,
293+
size,
294+
() -> toSet.get(),
295+
null,
296+
lambda
245297
);
246-
} else {
247-
// should not happen
248-
throw new IllegalArgumentException("Unknown diversification type [" + diversificationType + "]");
249298
}
250299

251300
return this;
@@ -281,13 +330,6 @@ protected Exception processInnerItemFailureException(Exception ex) {
281330

282331
@Override
283332
protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults, boolean explain) {
284-
if (diversificationContext == null) {
285-
throw new ElasticsearchStatusException(
286-
"diversificationContext is not set. \"doRewrite\" should have been called beforehand.",
287-
RestStatus.INTERNAL_SERVER_ERROR
288-
);
289-
}
290-
291333
if (rankResults.isEmpty()) {
292334
return new RankDoc[0];
293335
}
@@ -302,6 +344,8 @@ protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults, b
302344
return new RankDoc[0];
303345
}
304346

347+
ResultDiversificationContext diversificationContext = getResultDiversificationContext();
348+
305349
// gather and set the query vectors
306350
// and create our intermediate results set
307351
RankDoc[] results = new RankDoc[scoreDocs.length];
@@ -344,6 +388,15 @@ protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults, b
344388
}
345389
}
346390

391+
/**
 * Creates the diversification context that matches this retriever's diversification type.
 * <p>
 * For {@code MMR} a new {@link MMRResultDiversificationContext} is created from the
 * diversification field, lambda, size (falling back to {@code DEFAULT_SIZE_VALUE} when
 * {@code size} is null) and the query vector supplier, which is passed through as-is.
 * <p>
 * NOTE(review): {@code lambda} is a nullable {@code Float} that is unboxed into a primitive
 * {@code float} parameter here; presumably validation guarantees it is set for MMR - confirm.
 *
 * @return a new context for the configured diversification type
 * @throws IllegalArgumentException if the diversification type is not MMR
 */
private ResultDiversificationContext getResultDiversificationContext() {
392+
if (diversificationType.equals(ResultDiversificationType.MMR)) {
393+
return new MMRResultDiversificationContext(diversificationField, lambda, size == null ? DEFAULT_SIZE_VALUE : size, queryVector);
394+
}
395+
396+
// should not happen
397+
throw new IllegalArgumentException("Unknown diversification type [" + diversificationType + "]");
398+
}
399+
347400
private void extractFieldVectorData(int docId, Object fieldValue, Map<Integer, VectorData> fieldVectors) {
348401
switch (fieldValue) {
349402
case float[] floatArray -> {
@@ -427,7 +480,11 @@ protected void doToXContent(XContentBuilder builder, Params params) throws IOExc
427480
builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize);
428481

429482
if (queryVector != null) {
430-
builder.field(QUERY_VECTOR_FIELD.getPreferredName(), queryVector);
483+
builder.field(QUERY_VECTOR_FIELD.getPreferredName(), queryVector.get());
484+
}
485+
486+
if (queryVectorBuilder != null) {
487+
builder.field(QUERY_VECTOR_BUILDER_FIELD.getPreferredName(), queryVectorBuilder);
431488
}
432489

433490
if (lambda != null) {
@@ -451,6 +508,8 @@ public boolean doEquals(Object o) {
451508
&& this.diversificationType.equals(other.diversificationType)
452509
&& this.diversificationField.equals(other.diversificationField)
453510
&& Objects.equals(this.lambda, other.lambda)
454-
&& Objects.equals(this.queryVector, other.queryVector);
511+
&& ((queryVector == null && other.queryVector == null)
512+
|| (queryVector != null && other.queryVector != null && Objects.equals(queryVector.get(), other.queryVector.get())))
513+
&& Objects.equals(this.queryVectorBuilder, other.queryVectorBuilder);
455514
}
456515
}

server/src/main/java/org/elasticsearch/search/diversification/ResultDiversificationContext.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,15 @@
1414

1515
import java.util.Map;
1616
import java.util.Set;
17+
import java.util.function.Supplier;
1718

1819
public abstract class ResultDiversificationContext {
1920
private final String field;
2021
private final int size;
21-
private final VectorData queryVector;
22+
private final Supplier<VectorData> queryVector;
2223
private Map<Integer, VectorData> fieldVectors = null;
2324

24-
protected ResultDiversificationContext(String field, int size, @Nullable VectorData queryVector) {
25+
protected ResultDiversificationContext(String field, int size, @Nullable Supplier<VectorData> queryVector) {
2526
this.field = field;
2627
this.size = size;
2728
this.queryVector = queryVector;
@@ -45,7 +46,7 @@ public void setFieldVectors(Map<Integer, VectorData> fieldVectors) {
4546
}
4647

4748
/**
 * Returns the query vector for this context.
 * <p>
 * With this change the vector is held behind a {@code Supplier}: when no supplier was
 * provided the result is {@code null}, otherwise the supplier is invoked on each call and
 * its value is returned unchanged.
 *
 * @return the supplied query vector, or {@code null} if no supplier is configured
 */
public VectorData getQueryVector() {
48-
return queryVector;
49+
return queryVector == null ? null : queryVector.get();
4950
}
5051

5152
public VectorData getFieldVector(int rank) {

server/src/main/java/org/elasticsearch/search/diversification/mmr/MMRResultDiversificationContext.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,13 @@
1313
import org.elasticsearch.search.diversification.ResultDiversificationContext;
1414
import org.elasticsearch.search.vectors.VectorData;
1515

16+
import java.util.function.Supplier;
17+
1618
public class MMRResultDiversificationContext extends ResultDiversificationContext {
1719

1820
private final float lambda;
1921

20-
public MMRResultDiversificationContext(String field, float lambda, int size, @Nullable VectorData queryVector) {
22+
/**
 * Creates an MMR diversification context.
 * <p>
 * {@code field}, {@code size} and {@code queryVector} are forwarded to the
 * {@link ResultDiversificationContext} constructor; {@code lambda} is stored on this instance.
 *
 * @param field       forwarded to the parent context as its field
 * @param lambda      the MMR lambda value kept by this context
 * @param size        forwarded to the parent context as its size
 * @param queryVector supplier of the query vector; may be {@code null}
 */
public MMRResultDiversificationContext(String field, float lambda, int size, @Nullable Supplier<VectorData> queryVector) {
2123
super(field, size, queryVector);
2224
this.lambda = lambda;
2325
}

server/src/test/java/org/elasticsearch/search/diversification/DiversifyRetrieverBuilderParsingTests.java

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ protected DiversifyRetrieverBuilder createTestInstance() {
5959
rankWindowSize,
6060
size,
6161
queryVector,
62+
null,
6263
lambda
6364
);
6465
}
@@ -92,11 +93,7 @@ protected NamedXContentRegistry xContentRegistry() {
9293

9394
private VectorData getRandomQueryVector() {
9495
if (randomBoolean()) {
95-
float[] queryVector = new float[randomIntBetween(5, 256)];
96-
for (int i = 0; i < queryVector.length; i++) {
97-
queryVector[i] = randomFloatBetween(0.0f, 1.0f, true);
98-
}
99-
return new VectorData(queryVector);
96+
return new VectorData(getRandomFloatQueryVector());
10097
}
10198

10299
byte[] queryVector = new byte[randomIntBetween(5, 256)];
@@ -105,4 +102,12 @@ private VectorData getRandomQueryVector() {
105102
}
106103
return new VectorData(queryVector);
107104
}
105+
106+
/**
 * Builds a random float vector for use as a test query vector.
 * <p>
 * The array length comes from {@code randomIntBetween(5, 256)} and every element is filled
 * from {@code randomFloatBetween(0.0f, 1.0f, true)}.
 * NOTE(review): the trailing {@code true} presumably makes the upper bound inclusive - confirm
 * against the test base class.
 *
 * @return a newly allocated float array filled with random values
 */
private float[] getRandomFloatQueryVector() {
107+
float[] queryVector = new float[randomIntBetween(5, 256)];
108+
for (int i = 0; i < queryVector.length; i++) {
109+
queryVector[i] = randomFloatBetween(0.0f, 1.0f, true);
110+
}
111+
return queryVector;
112+
}
108113
}

0 commit comments

Comments
 (0)