|
| 1 | +from typing import Any, Dict, List, Literal, Optional, Set, Union |
| 2 | + |
| 3 | +from redisvl.utils.full_text_query_helper import FullTextQueryHelper |
| 4 | + |
| 5 | +try: |
| 6 | + from redis.commands.search.hybrid_query import HybridQuery as _HybridQuery |
| 7 | + from redis.commands.search.hybrid_query import ( |
| 8 | + HybridSearchQuery, |
| 9 | + HybridVsimQuery, |
| 10 | + VectorSearchMethods, |
| 11 | + ) |
| 12 | +except ImportError: |
| 13 | + raise ImportError("Hybrid queries require redis>=7.1.0") |
| 14 | + |
| 15 | +from redisvl.query.filter import FilterExpression |
| 16 | + |
| 17 | + |
| 18 | +class HybridQuery(_HybridQuery): |
| 19 | + """TBD""" |
| 20 | + |
| 21 | + def __init__( |
| 22 | + self, |
| 23 | + text: str, |
| 24 | + text_field_name: str, |
| 25 | + vector: Union[bytes, List[float]], |
| 26 | + vector_field_name: str, |
| 27 | + text_scorer: str = "BM25STD", |
| 28 | + filter_expression: Optional[Union[str, FilterExpression]] = None, |
| 29 | + vector_search_method: Optional[Literal["KNN", "RANGE"]] = None, |
| 30 | + vector_search_method_params: Optional[Dict[str, Any]] = None, |
| 31 | + stopwords: Optional[Union[str, Set[str]]] = "english", |
| 32 | + text_weights: Optional[Dict[str, float]] = None, |
| 33 | + ): |
| 34 | + self._ft_helper = FullTextQueryHelper( |
| 35 | + stopwords=stopwords, |
| 36 | + text_weights=text_weights, |
| 37 | + ) |
| 38 | + |
| 39 | + search_query = HybridSearchQuery( |
| 40 | + query_string=self._ft_helper.build_query_string( |
| 41 | + text=text, |
| 42 | + text_field_name=text_field_name, |
| 43 | + filter_expression=filter_expression, |
| 44 | + ), |
| 45 | + scorer=text_scorer, |
| 46 | + ) |
| 47 | + |
| 48 | + if not isinstance(vector, bytes): |
| 49 | + vector_data: Union[str, bytes] = str(vector) |
| 50 | + else: |
| 51 | + vector_data = vector |
| 52 | + |
| 53 | + vsim_search_method = None |
| 54 | + if vector_search_method: |
| 55 | + vsim_search_method = VectorSearchMethods(vector_search_method) |
| 56 | + |
| 57 | + vsim_query = HybridVsimQuery( |
| 58 | + vector_field_name=vector_field_name, |
| 59 | + vector_data=vector_data, |
| 60 | + vsim_search_method=vsim_search_method, |
| 61 | + vsim_search_method_params=vector_search_method_params, |
| 62 | + # TODO: Implement filter |
| 63 | + ) |
| 64 | + |
| 65 | + super().__init__( |
| 66 | + search_query=search_query, |
| 67 | + vector_similarity_query=vsim_query, |
| 68 | + ) |
0 commit comments