Skip to content

Commit 80f1927

Browse files
committed
Add support for combination methods and postprocessing
1 parent da2283e commit 80f1927

File tree

2 files changed

+416
-72
lines changed

2 files changed

+416
-72
lines changed

redisvl/query/hybrid.py

Lines changed: 124 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,11 @@
55
from redisvl.utils.full_text_query_helper import FullTextQueryHelper
66

77
try:
8-
from redis.commands.search.hybrid_query import HybridQuery as _HybridQuery
98
from redis.commands.search.hybrid_query import (
9+
CombinationMethods,
10+
CombineResultsMethod,
11+
HybridPostProcessingConfig,
12+
HybridQuery as RedisHybridQuery,
1013
HybridSearchQuery,
1114
HybridVsimQuery,
1215
VectorSearchMethods,
@@ -17,7 +20,7 @@
1720
from redisvl.query.filter import FilterExpression
1821

1922

20-
class HybridQuery(_HybridQuery):
23+
class HybridQuery:
2124
"""
2225
A hybrid search query that combines text search and vector similarity, with configurable fusion methods.
2326
"""
@@ -34,12 +37,18 @@ def __init__(
3437
vector_search_method: Optional[Literal["KNN", "RANGE"]] = None,
3538
knn_k: Optional[int] = None,
3639
knn_ef_runtime: Optional[int] = None,
37-
range_radius: Optional[float] = None,
40+
range_radius: Optional[int] = None,
3841
range_epsilon: Optional[float] = None,
3942
yield_vsim_score_as: Optional[str] = None,
4043
vector_filter_expression: Optional[Union[str, FilterExpression]] = None,
4144
stopwords: Optional[Union[str, Set[str]]] = "english",
4245
text_weights: Optional[Dict[str, float]] = None,
46+
combination_method: Optional[Literal["RRF", "LINEAR"]] = None,
47+
rrf_window: Optional[int] = None,
48+
rrf_constant: Optional[float] = None,
49+
linear_alpha: Optional[float] = None,
50+
linear_beta: Optional[float] = None,
51+
yield_combined_score_as: Optional[str] = None,
4352
):
4453
"""
4554
Instantiates a HybridQuery object.
@@ -50,7 +59,9 @@ def __init__(
5059
vector: The vector to perform vector similarity search.
5160
vector_field_name: The vector field name to search in.
5261
text_scorer: The text scorer to use. Options are {TFIDF, TFIDF.DOCNORM,
53-
BM25, DISMAX, DOCSCORE, BM25STD}. Defaults to "BM25STD".
62+
BM25STD, BM25STD.NORM, BM25STD.TANH, DISMAX, DOCSCORE, HAMMING}. Defaults to "BM25STD". For more
63+
information about supported scroring algorithms,
64+
see https://redis.io/docs/latest/develop/ai/search-and-query/advanced-concepts/scoring/
5465
text_filter_expression: The filter expression to use for the text search. Defaults to None.
5566
yield_text_score_as: The name of the field to yield the text score as.
5667
vector_search_method: The vector search method to use. Options are {KNN, RANGE}. Defaults to None.
@@ -73,6 +84,18 @@ def __init__(
7384
text_weights (Optional[Dict[str, float]]): The importance weighting of individual words
7485
within the query text. Defaults to None, as no modifications will be made to the
7586
text_scorer score.
87+
combination_method: The combination method to use. Options are {RRF, LINEAR}. Defaults to None.
88+
rrf_window: The window size to use for the reciprocal rank fusion (RRF) combination method. Limits
89+
fusion scope.
90+
rrf_constant: The constant to use for the reciprocal rank fusion (RRF) combination method. Controls decay
91+
of rank influence.
92+
linear_alpha: The weight of the first query for the linear combination method (LINEAR).
93+
linear_beta: The weight of the second query for the linear combination method (LINEAR).
94+
yield_combined_score_as: The name of the field to yield the combined score as.
95+
96+
Notes:
97+
If RRF combination method is used, then at least one of `rrf_window` or `rrf_constant` must be provided.
98+
If LINEAR combination method is used, then at least one of `linear_alpha` or `linear_beta` must be provided.
7699
77100
Raises:
78101
TypeError: If the stopwords are not a set, list, or tuple of strings.
@@ -82,18 +105,65 @@ def __init__(
82105
ValueError: If `vector_search_method` is "KNN" and `knn_k` is not provided.
83106
ValueError: If `vector_search_method` is "RANGE" and `range_radius` is not provided.
84107
"""
108+
self.postprocessing_config = HybridPostProcessingConfig()
85109
self._ft_helper = FullTextQueryHelper(
86110
stopwords=stopwords,
87111
text_weights=text_weights,
88112
)
89113

114+
query_string = self._ft_helper.build_query_string(
115+
text, text_field_name, text_filter_expression
116+
)
117+
118+
self.query = self.build_query(
119+
text_query=query_string,
120+
vector=vector,
121+
vector_field_name=vector_field_name,
122+
text_scorer=text_scorer,
123+
yield_text_score_as=yield_text_score_as,
124+
vector_search_method=vector_search_method,
125+
knn_k=knn_k,
126+
knn_ef_runtime=knn_ef_runtime,
127+
range_radius=range_radius,
128+
range_epsilon=range_epsilon,
129+
yield_vsim_score_as=yield_vsim_score_as,
130+
vector_filter_expression=vector_filter_expression,
131+
)
132+
133+
if combination_method:
134+
self.combination_method: Optional[CombineResultsMethod] = (
135+
self.build_combination_method(
136+
combination_method=combination_method,
137+
rrf_window=rrf_window,
138+
rrf_constant=rrf_constant,
139+
linear_alpha=linear_alpha,
140+
linear_beta=linear_beta,
141+
yield_score_as=yield_combined_score_as,
142+
)
143+
)
144+
else:
145+
self.combination_method = None
146+
147+
@staticmethod
148+
def build_query(
149+
text_query: str,
150+
vector: Union[bytes, List[float]],
151+
vector_field_name: str,
152+
text_scorer: str = "BM25STD",
153+
yield_text_score_as: Optional[str] = None,
154+
vector_search_method: Optional[Literal["KNN", "RANGE"]] = None,
155+
knn_k: Optional[int] = None,
156+
knn_ef_runtime: Optional[int] = None,
157+
range_radius: Optional[int] = None,
158+
range_epsilon: Optional[float] = None,
159+
yield_vsim_score_as: Optional[str] = None,
160+
vector_filter_expression: Optional[Union[str, FilterExpression]] = None,
161+
) -> RedisHybridQuery:
162+
"""Build a Redis HybridQuery for the hybrid search."""
163+
90164
# Serialize the full-text search query
91165
search_query = HybridSearchQuery(
92-
query_string=self._ft_helper.build_query_string(
93-
text=text,
94-
text_field_name=text_field_name,
95-
filter_expression=text_filter_expression,
96-
),
166+
query_string=text_query,
97167
scorer=text_scorer,
98168
yield_score_as=yield_text_score_as,
99169
)
@@ -105,8 +175,8 @@ def __init__(
105175
vector_data = vector
106176

107177
# Serialize vector similarity search method and params, if specified
108-
vsim_search_method = None
109-
vsim_search_method_params = {}
178+
vsim_search_method: Optional[VectorSearchMethods] = None
179+
vsim_search_method_params: Dict[str, Any] = {}
110180
if vector_search_method == "KNN":
111181
vsim_search_method = VectorSearchMethods.KNN
112182
if not knn_k:
@@ -128,18 +198,58 @@ def __init__(
128198
elif vector_search_method is not None:
129199
raise ValueError(f"Unknown vector search method: {vector_search_method}")
130200

201+
if vector_filter_expression:
202+
vsim_filter = Filter("FILTER", str(vector_filter_expression))
203+
else:
204+
vsim_filter = None
205+
131206
# Serialize the vector similarity query
132207
vsim_query = HybridVsimQuery(
133208
vector_field_name=vector_field_name,
134209
vector_data=vector_data,
135210
vsim_search_method=vsim_search_method,
136211
vsim_search_method_params=vsim_search_method_params,
137-
filter=vector_filter_expression and Filter("FILTER", str(vector_filter_expression)),
212+
filter=vsim_filter,
138213
yield_score_as=yield_vsim_score_as,
139214
)
140215

141-
# Initialize the base HybridQuery
142-
super().__init__(
216+
return RedisHybridQuery(
143217
search_query=search_query,
144218
vector_similarity_query=vsim_query,
145219
)
220+
221+
@staticmethod
222+
def build_combination_method(
223+
combination_method: Literal["RRF", "LINEAR"],
224+
rrf_window: Optional[int] = None,
225+
rrf_constant: Optional[float] = None,
226+
linear_alpha: Optional[float] = None,
227+
linear_beta: Optional[float] = None,
228+
yield_score_as: Optional[str] = None,
229+
) -> CombineResultsMethod:
230+
"""Build a configuration for combining hybrid search scores."""
231+
method_params: Dict[str, Any] = {}
232+
if combination_method == "RRF":
233+
method = CombinationMethods.RRF
234+
if rrf_window:
235+
method_params["WINDOW"] = rrf_window
236+
if rrf_constant:
237+
method_params["CONSTANT"] = rrf_constant
238+
239+
elif combination_method == "LINEAR":
240+
method = CombinationMethods.LINEAR
241+
if linear_alpha:
242+
method_params["ALPHA"] = linear_alpha
243+
if linear_beta:
244+
method_params["BETA"] = linear_beta
245+
246+
else:
247+
raise ValueError(f"Unknown combination method: {combination_method}")
248+
249+
if yield_score_as:
250+
method_params["YIELD_SCORE_AS"] = yield_score_as
251+
252+
return CombineResultsMethod(
253+
method=method,
254+
**method_params,
255+
)

0 commit comments

Comments
 (0)