11import warnings
2- from typing import Any , Dict , List , Optional , Set , Tuple , Union
2+ from typing import Any , Dict , List , Optional , Set , Union
33
44from pydantic import BaseModel , field_validator , model_validator
55from redis .commands .search .aggregation import AggregateRequest , Desc
88from redisvl .query .filter import FilterExpression
99from redisvl .redis .utils import array_to_buffer
1010from redisvl .schema .fields import VectorDataType
11- from redisvl .utils .token_escaper import TokenEscaper
11+ from redisvl .utils .full_text_query_helper import FullTextQueryHelper
1212from redisvl .utils .utils import lazy_import
1313
1414nltk = lazy_import ("nltk" )
@@ -159,8 +159,11 @@ def __init__(
159159 self ._alpha = alpha
160160 self ._dtype = dtype
161161 self ._num_results = num_results
162- self ._set_stopwords (stopwords )
163- self ._text_weights = self ._parse_text_weights (text_weights )
162+
163+ self ._ft_helper = FullTextQueryHelper (
164+ stopwords = stopwords ,
165+ text_weights = text_weights ,
166+ )
164167
165168 query_string = self ._build_query_string ()
166169 super ().__init__ (query_string )
@@ -198,115 +201,29 @@ def stopwords(self) -> Set[str]:
198201 Returns:
199202 Set[str]: The stopwords used in the query.
200203 """
201- return self ._stopwords .copy () if self ._stopwords else set ()
202-
203- def _set_stopwords (self , stopwords : Optional [Union [str , Set [str ]]] = "english" ):
204- """Set the stopwords to use in the query.
205- Args:
206- stopwords (Optional[Union[str, Set[str]]]): The stopwords to use. If a string
207- such as "english" "german" is provided then a default set of stopwords for that
208- language will be used. if a list, set, or tuple of strings is provided then those
209- will be used as stopwords. Defaults to "english". if set to "None" then no stopwords
210- will be removed.
211-
212- Raises:
213- TypeError: If the stopwords are not a set, list, or tuple of strings.
214- """
215- if not stopwords :
216- self ._stopwords = set ()
217- elif isinstance (stopwords , str ):
218- try :
219- nltk .download ("stopwords" , quiet = True )
220- self ._stopwords = set (nltk_stopwords .words (stopwords ))
221- except ImportError :
222- raise ValueError (
223- f"Loading stopwords for { stopwords } failed: nltk is not installed."
224- )
225- except Exception as e :
226- raise ValueError (f"Error trying to load { stopwords } from nltk. { e } " )
227- elif isinstance (stopwords , (Set , List , Tuple )) and all ( # type: ignore
228- isinstance (word , str ) for word in stopwords
229- ):
230- self ._stopwords = set (stopwords )
231- else :
232- raise TypeError ("stopwords must be a set, list, or tuple of strings" )
204+ return self ._ft_helper .stopwords
233205
234- def _tokenize_and_escape_query (self , user_query : str ) -> str :
235- """Convert a raw user query to a redis full text query joined by ORs
236- Args:
237- user_query (str): The user query to tokenize and escape.
206+ @property
207+ def text_weights (self ) -> Dict [str , float ]:
208+ """Get the text weights.
238209
239210 Returns:
240- str: The tokenized and escaped query string.
241-
242- Raises:
243- ValueError: If the text string becomes empty after stopwords are removed.
211+ Dictionary of word:weight mappings.
244212 """
245- escaper = TokenEscaper ()
246-
247- tokens = [
248- escaper .escape (
249- token .strip ().strip ("," ).replace ("“" , "" ).replace ("”" , "" ).lower ()
250- )
251- for token in user_query .split ()
252- ]
253-
254- token_list = [
255- token for token in tokens if token and token not in self ._stopwords
256- ]
257- for i , token in enumerate (token_list ):
258- if token in self ._text_weights :
259- token_list [i ] = f"{ token } =>{{$weight:{ self ._text_weights [token ]} }}"
260-
261- if not token_list :
262- raise ValueError ("text string cannot be empty after removing stopwords" )
263- return " | " .join (token_list )
264-
265- def _parse_text_weights (
266- self , weights : Optional [Dict [str , float ]]
267- ) -> Dict [str , float ]:
268- parsed_weights : Dict [str , float ] = {}
269- if not weights :
270- return parsed_weights
271- for word , weight in weights .items ():
272- word = word .strip ().lower ()
273- if not word or " " in word :
274- raise ValueError (
275- f"Only individual words may be weighted. Got {{ { word } :{ weight } }}"
276- )
277- if (
278- not (isinstance (weight , float ) or isinstance (weight , int ))
279- or weight < 0.0
280- ):
281- raise ValueError (
282- f"Weights must be positive number. Got {{ { word } :{ weight } }}"
283- )
284- parsed_weights [word ] = weight
285- return parsed_weights
213+ return self ._ft_helper .text_weights
286214
287215 def set_text_weights (self , weights : Dict [str , float ]):
288216 """Set or update the text weights for the query.
289217
290218 Args:
291- text_weights : Dictionary of word:weight mappings
219+ weights : Dictionary of word:weight mappings
292220 """
293- self ._text_weights = self . _parse_text_weights (weights )
221+ self ._ft_helper . set_text_weights (weights )
294222 self ._query = self ._build_query_string ()
295223
296- @property
297- def text_weights (self ) -> Dict [str , float ]:
298- """Get the text weights.
299-
300- Returns:
301- Dictionary of word:weight mappings.
302- """
303- return self ._text_weights
304-
305224 def _build_query_string (self ) -> str :
306225 """Build the full query string for text search with optional filtering."""
307- filter_expression = self ._filter_expression
308- if isinstance (self ._filter_expression , FilterExpression ):
309- filter_expression = str (self ._filter_expression )
226+ text = self ._ft_helper .build_query_string (self ._text , self ._text_field , self ._filter_expression )
310227
311228 # Build KNN query
312229 knn_query = (
@@ -316,11 +233,6 @@ def _build_query_string(self) -> str:
316233 # Add distance field alias
317234 knn_query += f" AS { self .DISTANCE_ID } "
318235
319- text = f"(~@{ self ._text_field } :({ self ._tokenize_and_escape_query (self ._text )} )"
320-
321- if filter_expression and filter_expression != "*" :
322- text += f" AND { filter_expression } "
323-
324236 return f"{ text } )=>[{ knn_query } ]"
325237
326238 def __str__ (self ) -> str :
0 commit comments