1010package org .elasticsearch .search .diversification ;
1111
1212import org .apache .lucene .search .ScoreDoc ;
13+ import org .apache .lucene .util .SetOnce ;
1314import org .elasticsearch .ElasticsearchException ;
1415import org .elasticsearch .ElasticsearchStatusException ;
1516import org .elasticsearch .action .ActionRequestValidationException ;
2627import org .elasticsearch .search .retriever .CompoundRetrieverBuilder ;
2728import org .elasticsearch .search .retriever .RetrieverBuilder ;
2829import org .elasticsearch .search .retriever .RetrieverParserContext ;
30+ import org .elasticsearch .search .vectors .QueryVectorBuilder ;
2931import org .elasticsearch .search .vectors .VectorData ;
3032import org .elasticsearch .xcontent .ConstructingObjectParser ;
3133import org .elasticsearch .xcontent .ObjectParser ;
4042import java .util .Locale ;
4143import java .util .Map ;
4244import java .util .Objects ;
45+ import java .util .function .Supplier ;
4346
4447import static org .elasticsearch .action .ValidateActions .addValidationError ;
48+ import static org .elasticsearch .common .Strings .format ;
4549import static org .elasticsearch .search .rank .RankBuilder .DEFAULT_RANK_WINDOW_SIZE ;
4650import static org .elasticsearch .xcontent .ConstructingObjectParser .constructorArg ;
4751import static org .elasticsearch .xcontent .ConstructingObjectParser .optionalConstructorArg ;
4852
4953public final class DiversifyRetrieverBuilder extends CompoundRetrieverBuilder <DiversifyRetrieverBuilder > {
5054
51- public static final Float DEFAULT_LAMBDA_VALUE = 0.7f ;
5255 public static final int DEFAULT_SIZE_VALUE = 10 ;
5356
5457 public static final NodeFeature RETRIEVER_RESULT_DIVERSIFICATION_MMR_FEATURE = new NodeFeature ("retriever.result_diversification_mmr" );
@@ -58,6 +61,7 @@ public final class DiversifyRetrieverBuilder extends CompoundRetrieverBuilder<Di
5861 public static final ParseField TYPE_FIELD = new ParseField ("type" );
5962 public static final ParseField FIELD_FIELD = new ParseField ("field" );
6063 public static final ParseField QUERY_VECTOR_FIELD = new ParseField ("query_vector" );
64+ public static final ParseField QUERY_VECTOR_BUILDER_FIELD = new ParseField ("query_vector_builder" );
6165 public static final ParseField LAMBDA_FIELD = new ParseField ("lambda" );
6266 public static final ParseField SIZE_FIELD = new ParseField ("size" );
6367
@@ -83,8 +87,9 @@ public SearchHit hit() {
8387 int rankWindowSize = args [3 ] == null ? DEFAULT_RANK_WINDOW_SIZE : (int ) args [3 ];
8488
8589 VectorData queryVector = args [4 ] == null ? null : (VectorData ) args [4 ];
86- Float lambda = args [5 ] == null ? null : (Float ) args [5 ];
87- Integer size = args [6 ] == null ? null : (Integer ) args [6 ];
90+ QueryVectorBuilder queryVectorBuilder = args [5 ] == null ? null : (QueryVectorBuilder ) args [5 ];
91+ Float lambda = args [6 ] == null ? null : (Float ) args [6 ];
92+ Integer size = args [7 ] == null ? null : (Integer ) args [7 ];
8893
8994 return new DiversifyRetrieverBuilder (
9095 RetrieverSource .from ((RetrieverBuilder ) args [0 ]),
@@ -93,6 +98,7 @@ public SearchHit hit() {
9398 rankWindowSize ,
9499 size ,
95100 queryVector ,
101+ queryVectorBuilder ,
96102 lambda
97103 );
98104 }
@@ -113,17 +119,22 @@ public SearchHit hit() {
113119 QUERY_VECTOR_FIELD ,
114120 ObjectParser .ValueType .OBJECT_ARRAY_STRING_OR_NUMBER
115121 );
122+ PARSER .declareNamedObject (
123+ optionalConstructorArg (),
124+ (p , c , n ) -> p .namedObject (QueryVectorBuilder .class , n , c ),
125+ QUERY_VECTOR_BUILDER_FIELD
126+ );
116127 PARSER .declareFloat (optionalConstructorArg (), LAMBDA_FIELD );
117128 PARSER .declareInt (optionalConstructorArg (), SIZE_FIELD );
118129 RetrieverBuilder .declareBaseParserFields (PARSER );
119130 }
120131
121132 private final ResultDiversificationType diversificationType ;
122133 private final String diversificationField ;
123- private final VectorData queryVector ;
134+ private final Supplier <VectorData > queryVector ;
135+ private final QueryVectorBuilder queryVectorBuilder ;
124136 private final Float lambda ;
125137 private final Integer size ;
126- private ResultDiversificationContext diversificationContext = null ;
127138
128139 DiversifyRetrieverBuilder (
129140 RetrieverSource innerRetriever ,
@@ -132,12 +143,14 @@ public SearchHit hit() {
132143 int rankWindowSize ,
133144 @ Nullable Integer size ,
134145 @ Nullable VectorData queryVector ,
146+ @ Nullable QueryVectorBuilder queryVectorBuilder ,
135147 @ Nullable Float lambda
136148 ) {
137149 super (List .of (innerRetriever ), rankWindowSize );
138150 this .diversificationType = diversificationType ;
139151 this .diversificationField = diversificationField ;
140- this .queryVector = queryVector ;
152+ this .queryVector = queryVector != null ? () -> queryVector : null ;
153+ this .queryVectorBuilder = queryVectorBuilder ;
141154 this .lambda = lambda ;
142155 this .size = size == null ? Math .min (DEFAULT_SIZE_VALUE , rankWindowSize ) : size ;
143156 }
@@ -148,7 +161,8 @@ public SearchHit hit() {
148161 String diversificationField ,
149162 int rankWindowSize ,
150163 @ Nullable Integer size ,
151- @ Nullable VectorData queryVector ,
164+ @ Nullable Supplier <VectorData > queryVector ,
165+ @ Nullable QueryVectorBuilder queryVectorBuilder ,
152166 @ Nullable Float lambda
153167 ) {
154168 super (innerRetrievers , rankWindowSize );
@@ -157,6 +171,7 @@ public SearchHit hit() {
157171 this .diversificationType = diversificationType ;
158172 this .diversificationField = diversificationField ;
159173 this .queryVector = queryVector ;
174+ this .queryVectorBuilder = queryVectorBuilder ;
160175 this .lambda = lambda ;
161176 this .size = size == null ? Math .min (DEFAULT_SIZE_VALUE , rankWindowSize ) : size ;
162177 }
@@ -170,6 +185,7 @@ protected DiversifyRetrieverBuilder clone(List<RetrieverSource> newChildRetrieve
170185 rankWindowSize ,
171186 size ,
172187 queryVector ,
188+ queryVectorBuilder ,
173189 lambda
174190 );
175191 }
@@ -181,6 +197,19 @@ public ActionRequestValidationException validate(
181197 boolean isScroll ,
182198 boolean allowPartialSearchResults
183199 ) {
200+ if (queryVector != null && queryVectorBuilder != null ) {
201+ validationException = addValidationError (
202+ String .format (
203+ Locale .ROOT ,
204+ "[%s] MMR result diversification can have one of [%s] or [%s], but not both" ,
205+ getName (),
206+ QUERY_VECTOR_FIELD .getPreferredName (),
207+ QUERY_VECTOR_BUILDER_FIELD .getPreferredName ()
208+ ),
209+ validationException
210+ );
211+ }
212+
184213 if (diversificationType .equals (ResultDiversificationType .MMR )) {
185214 validationException = validateMMRDiversification (validationException );
186215 }
@@ -235,17 +264,37 @@ private ActionRequestValidationException validateMMRDiversification(ActionReques
235264
236265 @ Override
237266 protected RetrieverBuilder doRewrite (QueryRewriteContext ctx ) {
238- if (diversificationType .equals (ResultDiversificationType .MMR )) {
239- // field vectors will be filled in during the combine
240- diversificationContext = new MMRResultDiversificationContext (
267+ if (queryVectorBuilder != null ) {
268+ SetOnce <VectorData > toSet = new SetOnce <>();
269+ ctx .registerAsyncAction ((c , l ) -> {
270+ queryVectorBuilder .buildVector (c , l .delegateFailureAndWrap ((ll , v ) -> {
271+ toSet .set (v == null ? null : new VectorData (v ));
272+ if (v == null ) {
273+ ll .onFailure (
274+ new IllegalArgumentException (
275+ format (
276+ "[%s] with name [%s] returned null query_vector" ,
277+ QUERY_VECTOR_BUILDER_FIELD .getPreferredName (),
278+ queryVectorBuilder .getWriteableName ()
279+ )
280+ )
281+ );
282+ return ;
283+ }
284+ ll .onResponse (null );
285+ }));
286+ });
287+
288+ return new DiversifyRetrieverBuilder (
289+ innerRetrievers ,
290+ diversificationType ,
241291 diversificationField ,
242- lambda ,
243- size == null ? DEFAULT_SIZE_VALUE : size ,
244- queryVector
292+ rankWindowSize ,
293+ size ,
294+ () -> toSet .get (),
295+ null ,
296+ lambda
245297 );
246- } else {
247- // should not happen
248- throw new IllegalArgumentException ("Unknown diversification type [" + diversificationType + "]" );
249298 }
250299
251300 return this ;
@@ -281,13 +330,6 @@ protected Exception processInnerItemFailureException(Exception ex) {
281330
282331 @ Override
283332 protected RankDoc [] combineInnerRetrieverResults (List <ScoreDoc []> rankResults , boolean explain ) {
284- if (diversificationContext == null ) {
285- throw new ElasticsearchStatusException (
286- "diversificationContext is not set. \" doRewrite\" should have been called beforehand." ,
287- RestStatus .INTERNAL_SERVER_ERROR
288- );
289- }
290-
291333 if (rankResults .isEmpty ()) {
292334 return new RankDoc [0 ];
293335 }
@@ -302,6 +344,8 @@ protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults, b
302344 return new RankDoc [0 ];
303345 }
304346
347+ ResultDiversificationContext diversificationContext = getResultDiversificationContext ();
348+
305349 // gather and set the query vectors
306350 // and create our intermediate results set
307351 RankDoc [] results = new RankDoc [scoreDocs .length ];
@@ -344,6 +388,15 @@ protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults, b
344388 }
345389 }
346390
391+ private ResultDiversificationContext getResultDiversificationContext () {
392+ if (diversificationType .equals (ResultDiversificationType .MMR )) {
393+ return new MMRResultDiversificationContext (diversificationField , lambda , size == null ? DEFAULT_SIZE_VALUE : size , queryVector );
394+ }
395+
396+ // should not happen
397+ throw new IllegalArgumentException ("Unknown diversification type [" + diversificationType + "]" );
398+ }
399+
347400 private void extractFieldVectorData (int docId , Object fieldValue , Map <Integer , VectorData > fieldVectors ) {
348401 switch (fieldValue ) {
349402 case float [] floatArray -> {
@@ -427,7 +480,11 @@ protected void doToXContent(XContentBuilder builder, Params params) throws IOExc
427480 builder .field (RANK_WINDOW_SIZE_FIELD .getPreferredName (), rankWindowSize );
428481
429482 if (queryVector != null ) {
430- builder .field (QUERY_VECTOR_FIELD .getPreferredName (), queryVector );
483+ builder .field (QUERY_VECTOR_FIELD .getPreferredName (), queryVector .get ());
484+ }
485+
486+ if (queryVectorBuilder != null ) {
487+ builder .field (QUERY_VECTOR_BUILDER_FIELD .getPreferredName (), queryVectorBuilder );
431488 }
432489
433490 if (lambda != null ) {
@@ -451,6 +508,8 @@ public boolean doEquals(Object o) {
451508 && this .diversificationType .equals (other .diversificationType )
452509 && this .diversificationField .equals (other .diversificationField )
453510 && Objects .equals (this .lambda , other .lambda )
454- && Objects .equals (this .queryVector , other .queryVector );
511+ && ((queryVector == null && other .queryVector == null )
512+ || (queryVector != null && other .queryVector != null && Objects .equals (queryVector .get (), other .queryVector .get ())))
513+ && Objects .equals (this .queryVectorBuilder , other .queryVectorBuilder );
455514 }
456515}
0 commit comments