11import pytest
2- from redis .commands .search .hybrid_result import HybridResult
32
4- from redisvl .index import SearchIndex
5- from redisvl .query .filter import FilterExpression , Geo , GeoRadius , Num , Tag , Text
3+ from redisvl .index import AsyncSearchIndex , SearchIndex
4+ from redisvl .query .filter import Geo , GeoRadius , Num , Tag , Text
65from redisvl .query .hybrid import HybridQuery
76from redisvl .redis .utils import array_to_buffer
8- from tests .conftest import skip_if_redis_version_below
7+ from redisvl .schema import IndexSchema
8+ from tests .conftest import (
9+ get_redis_version_async ,
10+ skip_if_redis_version_below ,
11+ skip_if_redis_version_below_async ,
12+ )
913
1014
1115@pytest .fixture
12- def index (multi_vector_data , redis_url , worker_id ):
13-
14- index = SearchIndex .from_dict (
16+ def index_schema (worker_id ):
17+ return IndexSchema .from_dict (
1518 {
1619 "index" : {
1720 "name" : f"user_index_{ worker_id } " ,
@@ -56,10 +59,14 @@ def index(multi_vector_data, redis_url, worker_id):
5659 },
5760 },
5861 ],
59- },
60- redis_url = redis_url ,
62+ }
6163 )
6264
65+
66+ @pytest .fixture
67+ def index (index_schema , multi_vector_data , redis_url ):
68+ index = SearchIndex (schema = index_schema , redis_url = redis_url )
69+
6370 # create the index (no data yet)
6471 index .create (overwrite = True )
6572
@@ -81,6 +88,24 @@ def hash_preprocess(item: dict) -> dict:
8188 index .delete (drop = True )
8289
8390
91+ @pytest .fixture
92+ async def async_index (index_schema , multi_vector_data , async_client ):
93+ index = AsyncSearchIndex (schema = index_schema , redis_client = async_client )
94+ await index .create (overwrite = True )
95+
96+ def hash_preprocess (item : dict ) -> dict :
97+ return {
98+ ** item ,
99+ "user_embedding" : array_to_buffer (item ["user_embedding" ], "float32" ),
100+ "image_embedding" : array_to_buffer (item ["image_embedding" ], "float32" ),
101+ "audio_embedding" : array_to_buffer (item ["audio_embedding" ], "float64" ),
102+ }
103+
104+ await index .load (multi_vector_data , preprocess = hash_preprocess )
105+ yield index
106+ await index .delete (drop = True )
107+
108+
84109def test_hybrid_query (index ):
85110 skip_if_redis_version_below (index .client , "8.4.0" )
86111
@@ -371,3 +396,62 @@ def test_hybrid_query_word_weights(index, scorer):
371396
372397 weighted_results = index .hybrid_search (weighted_query )
373398 assert weighted_results != unweighted_results
399+
400+
401+ @pytest .mark .asyncio
402+ async def test_hybrid_query_async (async_index ):
403+ await skip_if_redis_version_below_async (async_index .client , "8.4.0" )
404+
405+ text = "a medical professional with expertise in lung cancer"
406+ text_field = "description"
407+ vector = [0.1 , 0.1 , 0.5 ]
408+ vector_field = "user_embedding"
409+ return_fields = ["user" , "credit_score" , "age" , "job" , "location" , "description" ]
410+
411+ hybrid_query = HybridQuery (
412+ text = text ,
413+ text_field_name = text_field ,
414+ yield_text_score_as = "text_score" ,
415+ vector = vector ,
416+ vector_field_name = vector_field ,
417+ yield_vsim_score_as = "vsim_score" ,
418+ combination_method = "RRF" ,
419+ yield_combined_score_as = "hybrid_score" ,
420+ return_fields = return_fields ,
421+ )
422+
423+ results = await async_index .hybrid_search (hybrid_query )
424+ assert isinstance (results , list )
425+ assert len (results ) == 7
426+ for doc in results :
427+ assert doc ["user" ] in [
428+ "john" ,
429+ "derrick" ,
430+ "nancy" ,
431+ "tyler" ,
432+ "tim" ,
433+ "taimur" ,
434+ "joe" ,
435+ "mary" ,
436+ ]
437+ assert int (doc ["age" ]) in [18 , 14 , 94 , 100 , 12 , 15 , 35 ]
438+ assert doc ["job" ] in ["engineer" , "doctor" , "dermatologist" , "CEO" , "dentist" ]
439+ assert doc ["credit_score" ] in ["high" , "low" , "medium" ]
440+
441+ hybrid_query = HybridQuery (
442+ text = text ,
443+ text_field_name = text_field ,
444+ vector = vector ,
445+ vector_field_name = vector_field ,
446+ num_results = 3 ,
447+ combination_method = "RRF" ,
448+ yield_combined_score_as = "hybrid_score" ,
449+ )
450+
451+ results = await async_index .hybrid_search (hybrid_query )
452+ assert len (results ) == 3
453+ assert (
454+ results [0 ]["hybrid_score" ]
455+ >= results [1 ]["hybrid_score" ]
456+ >= results [2 ]["hybrid_score" ]
457+ )
0 commit comments