Skip to content

Commit 9832369

Browse files
committed
Update hybrid search usage based on in-practice constraints
1 parent 80f1927 commit 9832369

File tree

6 files changed

+539
-98
lines changed

6 files changed

+539
-98
lines changed

redisvl/index/index.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,20 @@
4545
from redis.commands.search.aggregation import AggregateResult
4646
from redis.commands.search.document import Document
4747
from redis.commands.search.result import Result
48+
4849
from redisvl.query.query import BaseQuery
4950

51+
try:
52+
from redis.commands.search.hybrid_result import HybridResult
53+
54+
from redisvl.query.hybrid import HybridQuery
55+
56+
REDIS_HYBRID_AVAILABLE = True
57+
except ImportError:
58+
REDIS_HYBRID_AVAILABLE = False
59+
HybridResult = None # type: ignore
60+
HybridQuery = None # type: ignore
61+
5062
from redis import __version__ as redis_version
5163
from redis.client import NEVER_DECODE
5264

@@ -215,6 +227,13 @@ def _process(row):
215227
return [_process(r) for r in results.rows]
216228

217229

230+
if REDIS_HYBRID_AVAILABLE:

    def process_hybrid_results(results: HybridResult) -> List[Dict[str, Any]]:
        """Turn a hybrid search result into a list of document dictionaries.

        Args:
            results: The hybrid result object returned by the client; its
                ``results`` attribute is iterated in order.

        Returns:
            A list containing ``convert_bytes(entry)`` for every entry of
            ``results.results``, in the same order.
        """
        return list(map(convert_bytes, results.results))
235+
236+
218237
class BaseSearchIndex:
219238
"""Base search engine class"""
220239

@@ -1003,6 +1022,23 @@ def search(self, *args, **kwargs) -> "Result":
10031022
except Exception as e:
10041023
raise RedisSearchError(f"Unexpected error while searching: {str(e)}") from e
10051024

1025+
if REDIS_HYBRID_AVAILABLE:

    def hybrid_search(self, query: HybridQuery, **kwargs) -> List[Dict[str, Any]]:
        """Execute a hybrid query against this index and return its documents.

        Args:
            query: The hybrid query to run. Its ``query``,
                ``combination_method`` and ``postprocessing_config``
                attributes are forwarded to the client call.
            **kwargs: Additional keyword arguments passed through unchanged
                to the client's ``hybrid_search`` call.

        Returns:
            The list produced by ``process_hybrid_results`` from the raw
            hybrid result.
        """
        search_client = self._redis_client.ft(self.schema.index.name)

        # Only forward the post-processing config when it actually builds
        # to a non-empty argument list; otherwise pass None.
        post_processing = None
        if query.postprocessing_config.build_args():
            post_processing = query.postprocessing_config

        raw_results: HybridResult = search_client.hybrid_search(
            query=query.query,
            combine_method=query.combination_method,
            post_processing=post_processing,
            **kwargs,
        )  # type: ignore
        return process_hybrid_results(raw_results)
1041+
10061042
def batch_query(
10071043
self, queries: Sequence[BaseQuery], batch_size: int = 10
10081044
) -> List[List[Dict[str, Any]]]:

redisvl/query/aggregate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ def _build_query_string(self) -> str:
235235
# Add distance field alias
236236
knn_query += f" AS {self.DISTANCE_ID}"
237237

238-
return f"{text})=>[{knn_query}]"
238+
return f"{text}=>[{knn_query}]"
239239

240240
def __str__(self) -> str:
241241
"""Return the string representation of the query."""

redisvl/query/hybrid.py

Lines changed: 47 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,17 @@
22

33
from redis.commands.search.query import Filter
44

5+
from redisvl.redis.utils import array_to_buffer
56
from redisvl.utils.full_text_query_helper import FullTextQueryHelper
67

78
try:
89
from redis.commands.search.hybrid_query import (
910
CombinationMethods,
1011
CombineResultsMethod,
1112
HybridPostProcessingConfig,
12-
HybridQuery as RedisHybridQuery,
13+
)
14+
from redis.commands.search.hybrid_query import HybridQuery as RedisHybridQuery
15+
from redis.commands.search.hybrid_query import (
1316
HybridSearchQuery,
1417
HybridVsimQuery,
1518
VectorSearchMethods,
@@ -41,22 +44,26 @@ def __init__(
4144
range_epsilon: Optional[float] = None,
4245
yield_vsim_score_as: Optional[str] = None,
4346
vector_filter_expression: Optional[Union[str, FilterExpression]] = None,
44-
stopwords: Optional[Union[str, Set[str]]] = "english",
45-
text_weights: Optional[Dict[str, float]] = None,
4647
combination_method: Optional[Literal["RRF", "LINEAR"]] = None,
4748
rrf_window: Optional[int] = None,
4849
rrf_constant: Optional[float] = None,
4950
linear_alpha: Optional[float] = None,
5051
linear_beta: Optional[float] = None,
5152
yield_combined_score_as: Optional[str] = None,
53+
dtype: str = "float32",
54+
num_results: Optional[int] = None,
55+
return_fields: Optional[List[str]] = None,
56+
stopwords: Optional[Union[str, Set[str]]] = "english",
57+
text_weights: Optional[Dict[str, float]] = None,
5258
):
5359
"""
5460
Instantiates a HybridQuery object.
5561
5662
Args:
5763
text: The text to search for.
5864
text_field_name: The text field name to search in.
59-
vector: The vector to perform vector similarity search.
65+
vector: The vector to perform vector similarity search, converted to bytes (e.g.
66+
using `redisvl.redis.utils.array_to_buffer`).
6067
vector_field_name: The vector field name to search in.
6168
text_scorer: The text scorer to use. Options are {TFIDF, TFIDF.DOCNORM,
6269
BM25STD, BM25STD.NORM, BM25STD.TANH, DISMAX, DOCSCORE, HAMMING}. Defaults to "BM25STD". For more
@@ -72,6 +79,17 @@ def __init__(
7279
accuracy of the search.
7380
yield_vsim_score_as: The name of the field to yield the vector similarity score as.
7481
vector_filter_expression: The filter expression to use for the vector similarity search. Defaults to None.
82+
combination_method: The combination method to use. Options are {RRF, LINEAR}. Defaults to None.
83+
rrf_window: The window size to use for the reciprocal rank fusion (RRF) combination method. Limits
84+
fusion scope.
85+
rrf_constant: The constant to use for the reciprocal rank fusion (RRF) combination method. Controls decay
86+
of rank influence.
87+
linear_alpha: The weight of the first query for the linear combination method (LINEAR).
88+
linear_beta: The weight of the second query for the linear combination method (LINEAR).
89+
yield_combined_score_as: The name of the field to yield the combined score as.
90+
dtype: The data type of the vector. Defaults to "float32".
91+
num_results: The number of results to return. If not specified, the server default will be used (10).
92+
return_fields: The fields to return. Defaults to None.
7593
stopwords (Optional[Union[str, Set[str]]], optional): The stopwords to remove from the
7694
provided text prior to search-use. If a string such as "english" "german" is
7795
provided then a default set of stopwords for that language will be used. if a list,
@@ -84,14 +102,6 @@ def __init__(
84102
text_weights (Optional[Dict[str, float]]): The importance weighting of individual words
85103
within the query text. Defaults to None, as no modifications will be made to the
86104
text_scorer score.
87-
combination_method: The combination method to use. Options are {RRF, LINEAR}. Defaults to None.
88-
rrf_window: The window size to use for the reciprocal rank fusion (RRF) combination method. Limits
89-
fusion scope.
90-
rrf_constant: The constant to use for the reciprocal rank fusion (RRF) combination method. Controls decay
91-
of rank influence.
92-
linear_alpha: The weight of the first query for the linear combination method (LINEAR).
93-
linear_beta: The weight of the second query for the linear combination method (LINEAR).
94-
yield_combined_score_as: The name of the field to yield the combined score as.
95105
96106
Notes:
97107
If RRF combination method is used, then at least one of `rrf_window` or `rrf_constant` must be provided.
@@ -101,11 +111,16 @@ def __init__(
101111
TypeError: If the stopwords are not a set, list, or tuple of strings.
102112
ValueError: If the text string is empty, or if the text string becomes empty after
103113
stopwords are removed.
104-
ValueError: If `vector_search_method` is not one of {KNN, RANGE} (or None).
114+
ValueError: If `vector_search_method` is defined and isn't one of {KNN, RANGE}.
105115
ValueError: If `vector_search_method` is "KNN" and `knn_k` is not provided.
106116
ValueError: If `vector_search_method` is "RANGE" and `range_radius` is not provided.
107117
"""
108118
self.postprocessing_config = HybridPostProcessingConfig()
119+
if num_results:
120+
self.postprocessing_config.limit(offset=0, num=num_results)
121+
if return_fields:
122+
self.postprocessing_config.load(*(f"@{f}" for f in return_fields))
123+
109124
self._ft_helper = FullTextQueryHelper(
110125
stopwords=stopwords,
111126
text_weights=text_weights,
@@ -128,6 +143,7 @@ def __init__(
128143
range_epsilon=range_epsilon,
129144
yield_vsim_score_as=yield_vsim_score_as,
130145
vector_filter_expression=vector_filter_expression,
146+
dtype=dtype,
131147
)
132148

133149
if combination_method:
@@ -147,7 +163,7 @@ def __init__(
147163
@staticmethod
148164
def build_query(
149165
text_query: str,
150-
vector: Union[bytes, List[float]],
166+
vector: bytes | List[float],
151167
vector_field_name: str,
152168
text_scorer: str = "BM25STD",
153169
yield_text_score_as: Optional[str] = None,
@@ -158,6 +174,7 @@ def build_query(
158174
range_epsilon: Optional[float] = None,
159175
yield_vsim_score_as: Optional[str] = None,
160176
vector_filter_expression: Optional[Union[str, FilterExpression]] = None,
177+
dtype: str = "float32",
161178
) -> RedisHybridQuery:
162179
"""Build a Redis HybridQuery for the hybrid search."""
163180

@@ -168,11 +185,10 @@ def build_query(
168185
yield_score_as=yield_text_score_as,
169186
)
170187

171-
# If the vector isn't already bytes, it needs to be represented as a string
172-
if not isinstance(vector, bytes):
173-
vector_data: Union[str, bytes] = str(vector)
174-
else:
188+
if isinstance(vector, bytes):
175189
vector_data = vector
190+
else:
191+
vector_data = array_to_buffer(vector, dtype)
176192

177193
# Serialize vector similarity search method and params, if specified
178194
vsim_search_method: Optional[VectorSearchMethods] = None
@@ -205,7 +221,7 @@ def build_query(
205221

206222
# Serialize the vector similarity query
207223
vsim_query = HybridVsimQuery(
208-
vector_field_name=vector_field_name,
224+
vector_field_name="@" + vector_field_name,
209225
vector_data=vector_data,
210226
vsim_search_method=vsim_search_method,
211227
vsim_search_method_params=vsim_search_method_params,
@@ -240,15 +256,27 @@ def build_combination_method(
240256
method = CombinationMethods.LINEAR
241257
if linear_alpha:
242258
method_params["ALPHA"] = linear_alpha
259+
if not linear_beta:
260+
method_params["BETA"] = 1 - linear_alpha
261+
243262
if linear_beta:
263+
if not linear_alpha: # Defined first to preserve consistent ordering
264+
method_params["ALPHA"] = 1 - linear_beta
244265
method_params["BETA"] = linear_beta
245266

267+
# TODO: Warn if alpha + beta != 1
268+
246269
else:
247270
raise ValueError(f"Unknown combination method: {combination_method}")
248271

249272
if yield_score_as:
250273
method_params["YIELD_SCORE_AS"] = yield_score_as
251274

275+
if not method_params:
276+
raise ValueError(
277+
"No parameters provided for combination method - must provide at least one parameter."
278+
)
279+
252280
return CombineResultsMethod(
253281
method=method,
254282
**method_params,

redisvl/utils/full_text_query_helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def build_query_string(
6969
if filter_expression and filter_expression != "*":
7070
query += f" AND {filter_expression}"
7171

72-
return query
72+
return query + ")"
7373

7474
def _get_stopwords(
7575
self, stopwords: Optional[Union[str, Set[str]]] = "english"

0 commit comments

Comments
 (0)