Skip to content

Commit 527b024

Browse files
committed
Implement HybridQuery with tests
1 parent 99f9d99 commit 527b024

File tree

3 files changed

+687
-1
lines changed

3 files changed

+687
-1
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ classifiers = [
2424
dependencies = [
2525
"numpy>=1.26.0,<3",
2626
"pyyaml>=5.4,<7.0",
27-
"redis>=5.0,<7.0",
27+
"redis>=5.0,<7.2",
2828
"pydantic>=2,<3",
2929
"tenacity>=8.2.2",
3030
"ml-dtypes>=0.4.0,<1.0.0",

redisvl/query/hybrid.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from typing import Any, Dict, List, Literal, Optional, Set, Union
2+
3+
from redisvl.utils.full_text_query_helper import FullTextQueryHelper
4+
5+
try:
6+
from redis.commands.search.hybrid_query import HybridQuery as _HybridQuery
7+
from redis.commands.search.hybrid_query import (
8+
HybridSearchQuery,
9+
HybridVsimQuery,
10+
VectorSearchMethods,
11+
)
12+
except ImportError:
13+
raise ImportError("Hybrid queries require redis>=7.1.0")
14+
15+
from redisvl.query.filter import FilterExpression
16+
17+
18+
class HybridQuery(_HybridQuery):
19+
"""TBD"""
20+
21+
def __init__(
22+
self,
23+
text: str,
24+
text_field_name: str,
25+
vector: Union[bytes, List[float]],
26+
vector_field_name: str,
27+
text_scorer: str = "BM25STD",
28+
filter_expression: Optional[Union[str, FilterExpression]] = None,
29+
vector_search_method: Optional[Literal["KNN", "RANGE"]] = None,
30+
vector_search_method_params: Optional[Dict[str, Any]] = None,
31+
stopwords: Optional[Union[str, Set[str]]] = "english",
32+
text_weights: Optional[Dict[str, float]] = None,
33+
):
34+
self._ft_helper = FullTextQueryHelper(
35+
stopwords=stopwords,
36+
text_weights=text_weights,
37+
)
38+
39+
search_query = HybridSearchQuery(
40+
query_string=self._ft_helper.build_query_string(
41+
text=text,
42+
text_field_name=text_field_name,
43+
filter_expression=filter_expression,
44+
),
45+
scorer=text_scorer,
46+
)
47+
48+
if not isinstance(vector, bytes):
49+
vector_data: Union[str, bytes] = str(vector)
50+
else:
51+
vector_data = vector
52+
53+
vsim_search_method = None
54+
if vector_search_method:
55+
vsim_search_method = VectorSearchMethods(vector_search_method)
56+
57+
vsim_query = HybridVsimQuery(
58+
vector_field_name=vector_field_name,
59+
vector_data=vector_data,
60+
vsim_search_method=vsim_search_method,
61+
vsim_search_method_params=vector_search_method_params,
62+
# TODO: Implement filter
63+
)
64+
65+
super().__init__(
66+
search_query=search_query,
67+
vector_similarity_query=vsim_query,
68+
)

0 commit comments

Comments
 (0)