Skip to content

Commit 51ce1e0

Browse files
committed
Abstract full-text query construction into helper class
1 parent d992f8f commit 51ce1e0

File tree

2 files changed

+170
-104
lines changed

2 files changed

+170
-104
lines changed

redisvl/query/aggregate.py

Lines changed: 16 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import warnings
2-
from typing import Any, Dict, List, Optional, Set, Tuple, Union
2+
from typing import Any, Dict, List, Optional, Set, Union
33

44
from pydantic import BaseModel, field_validator, model_validator
55
from redis.commands.search.aggregation import AggregateRequest, Desc
@@ -8,7 +8,7 @@
88
from redisvl.query.filter import FilterExpression
99
from redisvl.redis.utils import array_to_buffer
1010
from redisvl.schema.fields import VectorDataType
11-
from redisvl.utils.token_escaper import TokenEscaper
11+
from redisvl.utils.full_text_query_helper import FullTextQueryHelper
1212
from redisvl.utils.utils import lazy_import
1313

1414
nltk = lazy_import("nltk")
@@ -159,8 +159,11 @@ def __init__(
159159
self._alpha = alpha
160160
self._dtype = dtype
161161
self._num_results = num_results
162-
self._set_stopwords(stopwords)
163-
self._text_weights = self._parse_text_weights(text_weights)
162+
163+
self._ft_helper = FullTextQueryHelper(
164+
stopwords=stopwords,
165+
text_weights=text_weights,
166+
)
164167

165168
query_string = self._build_query_string()
166169
super().__init__(query_string)
@@ -198,115 +201,29 @@ def stopwords(self) -> Set[str]:
198201
Returns:
199202
Set[str]: The stopwords used in the query.
200203
"""
201-
return self._stopwords.copy() if self._stopwords else set()
202-
203-
def _set_stopwords(self, stopwords: Optional[Union[str, Set[str]]] = "english"):
204-
"""Set the stopwords to use in the query.
205-
Args:
206-
stopwords (Optional[Union[str, Set[str]]]): The stopwords to use. If a string
207-
such as "english" "german" is provided then a default set of stopwords for that
208-
language will be used. if a list, set, or tuple of strings is provided then those
209-
will be used as stopwords. Defaults to "english". if set to "None" then no stopwords
210-
will be removed.
211-
212-
Raises:
213-
TypeError: If the stopwords are not a set, list, or tuple of strings.
214-
"""
215-
if not stopwords:
216-
self._stopwords = set()
217-
elif isinstance(stopwords, str):
218-
try:
219-
nltk.download("stopwords", quiet=True)
220-
self._stopwords = set(nltk_stopwords.words(stopwords))
221-
except ImportError:
222-
raise ValueError(
223-
f"Loading stopwords for {stopwords} failed: nltk is not installed."
224-
)
225-
except Exception as e:
226-
raise ValueError(f"Error trying to load {stopwords} from nltk. {e}")
227-
elif isinstance(stopwords, (Set, List, Tuple)) and all( # type: ignore
228-
isinstance(word, str) for word in stopwords
229-
):
230-
self._stopwords = set(stopwords)
231-
else:
232-
raise TypeError("stopwords must be a set, list, or tuple of strings")
204+
return self._ft_helper.stopwords
233205

234-
def _tokenize_and_escape_query(self, user_query: str) -> str:
235-
"""Convert a raw user query to a redis full text query joined by ORs
236-
Args:
237-
user_query (str): The user query to tokenize and escape.
206+
@property
207+
def text_weights(self) -> Dict[str, float]:
208+
"""Get the text weights.
238209
239210
Returns:
240-
str: The tokenized and escaped query string.
241-
242-
Raises:
243-
ValueError: If the text string becomes empty after stopwords are removed.
211+
Dictionary of word:weight mappings.
244212
"""
245-
escaper = TokenEscaper()
246-
247-
tokens = [
248-
escaper.escape(
249-
token.strip().strip(",").replace("“", "").replace("”", "").lower()
250-
)
251-
for token in user_query.split()
252-
]
253-
254-
token_list = [
255-
token for token in tokens if token and token not in self._stopwords
256-
]
257-
for i, token in enumerate(token_list):
258-
if token in self._text_weights:
259-
token_list[i] = f"{token}=>{{$weight:{self._text_weights[token]}}}"
260-
261-
if not token_list:
262-
raise ValueError("text string cannot be empty after removing stopwords")
263-
return " | ".join(token_list)
264-
265-
def _parse_text_weights(
266-
self, weights: Optional[Dict[str, float]]
267-
) -> Dict[str, float]:
268-
parsed_weights: Dict[str, float] = {}
269-
if not weights:
270-
return parsed_weights
271-
for word, weight in weights.items():
272-
word = word.strip().lower()
273-
if not word or " " in word:
274-
raise ValueError(
275-
f"Only individual words may be weighted. Got {{ {word}:{weight} }}"
276-
)
277-
if (
278-
not (isinstance(weight, float) or isinstance(weight, int))
279-
or weight < 0.0
280-
):
281-
raise ValueError(
282-
f"Weights must be positive number. Got {{ {word}:{weight} }}"
283-
)
284-
parsed_weights[word] = weight
285-
return parsed_weights
213+
return self._ft_helper.text_weights
286214

287215
def set_text_weights(self, weights: Dict[str, float]):
288216
"""Set or update the text weights for the query.
289217
290218
Args:
291-
text_weights: Dictionary of word:weight mappings
219+
weights: Dictionary of word:weight mappings
292220
"""
293-
self._text_weights = self._parse_text_weights(weights)
221+
self._ft_helper.set_text_weights(weights)
294222
self._query = self._build_query_string()
295223

296-
@property
297-
def text_weights(self) -> Dict[str, float]:
298-
"""Get the text weights.
299-
300-
Returns:
301-
Dictionary of word:weight mappings.
302-
"""
303-
return self._text_weights
304-
305224
def _build_query_string(self) -> str:
306225
"""Build the full query string for text search with optional filtering."""
307-
filter_expression = self._filter_expression
308-
if isinstance(self._filter_expression, FilterExpression):
309-
filter_expression = str(self._filter_expression)
226+
text = self._ft_helper.build_query_string(self._text, self._text_field, self._filter_expression)
310227

311228
# Build KNN query
312229
knn_query = (
@@ -316,11 +233,6 @@ def _build_query_string(self) -> str:
316233
# Add distance field alias
317234
knn_query += f" AS {self.DISTANCE_ID}"
318235

319-
text = f"(~@{self._text_field}:({self._tokenize_and_escape_query(self._text)})"
320-
321-
if filter_expression and filter_expression != "*":
322-
text += f" AND {filter_expression}"
323-
324236
return f"{text})=>[{knn_query}]"
325237

326238
def __str__(self) -> str:
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
from typing import Optional, Union, Set, Dict, List, Tuple
2+
3+
from redisvl.query.filter import FilterExpression
4+
from redisvl.utils.token_escaper import TokenEscaper
5+
from redisvl.utils.utils import lazy_import
6+
7+
nltk = lazy_import("nltk")
8+
nltk_stopwords = lazy_import("nltk.corpus.stopwords")
9+
10+
11+
def _parse_text_weights(
12+
weights: Optional[Dict[str, float]]
13+
) -> Dict[str, float]:
14+
parsed_weights: Dict[str, float] = {}
15+
if not weights:
16+
return parsed_weights
17+
for word, weight in weights.items():
18+
word = word.strip().lower()
19+
if not word or " " in word:
20+
raise ValueError(
21+
f"Only individual words may be weighted. Got {{ {word}:{weight} }}"
22+
)
23+
if (
24+
not (isinstance(weight, float) or isinstance(weight, int))
25+
or weight < 0.0
26+
):
27+
raise ValueError(
28+
f"Weights must be positive number. Got {{ {word}:{weight} }}"
29+
)
30+
parsed_weights[word] = weight
31+
return parsed_weights
32+
33+
34+
class FullTextQueryHelper:
    """Build Redis full-text query strings from raw user input.

    Handles tokenizing and escaping the user's text, removing stopwords,
    and applying optional per-word weights.
    """

    def __init__(
        self,
        stopwords: Optional[Union[str, Set[str]]] = "english",
        text_weights: Optional[Dict[str, float]] = None,
    ):
        self._stopwords = self._get_stopwords(stopwords)
        self._text_weights = _parse_text_weights(text_weights)

    @property
    def stopwords(self) -> Set[str]:
        """Return the stopwords used in the query.

        Returns:
            Set[str]: The stopwords used in the query.
        """
        # Hand back a copy so callers cannot mutate internal state.
        return set(self._stopwords) if self._stopwords else set()

    @property
    def text_weights(self) -> Dict[str, float]:
        """Get the text weights.

        Returns:
            Dictionary of word:weight mappings.
        """
        return self._text_weights

    def set_text_weights(self, weights: Dict[str, float]) -> None:
        """Set or update the text weights for the query.

        Args:
            weights: Dictionary of word:weight mappings
        """
        self._text_weights = _parse_text_weights(weights)

    def build_query_string(
        self,
        text: str,
        text_field_name: str,
        filter_expression: Optional[Union[str, FilterExpression]] = None,
    ) -> str:
        """Build the full-text query string for text search with optional filtering.

        NOTE: the returned string intentionally carries one unbalanced opening
        parenthesis; the caller appends its own suffix (e.g. a KNN clause) and
        closes it.
        """
        if isinstance(filter_expression, FilterExpression):
            filter_expression = str(filter_expression)

        parts = [f"(~@{text_field_name}:({self._tokenize_and_escape_query(text)})"]
        # "*" matches everything, so appending it would be a no-op filter.
        if filter_expression and filter_expression != "*":
            parts.append(f" AND {filter_expression}")
        return "".join(parts)

    def _get_stopwords(
        self, stopwords: Optional[Union[str, Set[str]]] = "english"
    ) -> Set[str]:
        """Resolve the stopword configuration into a concrete set of words.

        Args:
            stopwords (Optional[Union[str, Set[str]]]): Either a language name
                (e.g. "english", "german") whose nltk stopword list is loaded,
                a set/list/tuple of words to use verbatim, or a falsy value
                meaning no stopwords are removed. Defaults to "english".

        Returns:
            The set of stopwords to use.

        Raises:
            ValueError: If loading the nltk stopword list fails.
            TypeError: If the stopwords are not a set, list, or tuple of strings.
        """
        if not stopwords:
            return set()
        if isinstance(stopwords, str):
            try:
                nltk.download("stopwords", quiet=True)
                return set(nltk_stopwords.words(stopwords))
            except ImportError:
                raise ValueError(
                    f"Loading stopwords for {stopwords} failed: nltk is not installed."
                )
            except Exception as e:
                raise ValueError(f"Error trying to load {stopwords} from nltk. {e}")
        if isinstance(stopwords, (Set, List, Tuple)) and all(  # type: ignore
            isinstance(word, str) for word in stopwords
        ):
            return set(stopwords)
        raise TypeError("stopwords must be a set, list, or tuple of strings")

    def _tokenize_and_escape_query(self, user_query: str) -> str:
        """Convert a raw user query to a redis full text query joined by ORs

        Args:
            user_query (str): The user query to tokenize and escape.

        Returns:
            str: The tokenized and escaped query string.

        Raises:
            ValueError: If the text string becomes empty after stopwords are removed.
        """
        escaper = TokenEscaper()

        def normalize(raw: str) -> str:
            # Trim whitespace/trailing commas, drop curly quotes, lowercase,
            # then escape Redis special characters.
            cleaned = raw.strip().strip(",").replace("“", "").replace("”", "").lower()
            return escaper.escape(cleaned)

        kept = []
        for raw in user_query.split():
            token = normalize(raw)
            if token and token not in self._stopwords:
                kept.append(token)

        # Attach a per-token weight attribute where one is configured.
        weighted = [
            f"{token}=>{{$weight:{self._text_weights[token]}}}"
            if token in self._text_weights
            else token
            for token in kept
        ]

        if not weighted:
            raise ValueError("text string cannot be empty after removing stopwords")
        return " | ".join(weighted)
154+

0 commit comments

Comments
 (0)