Skip to content

Commit 4b3a1fe

Browse files
committed
Implement async hybrid search
1 parent b691255 commit 4b3a1fe

File tree

2 files changed

+113
-11
lines changed

2 files changed

+113
-11
lines changed

redisvl/index/index.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,6 @@
5656
REDIS_HYBRID_AVAILABLE = True
5757
except ImportError:
5858
REDIS_HYBRID_AVAILABLE = False
59-
HybridResult = None # type: ignore
60-
HybridQuery = None # type: ignore
6159

6260
from redis import __version__ as redis_version
6361
from redis.client import NEVER_DECODE
@@ -1860,6 +1858,26 @@ async def search(self, *args, **kwargs) -> "Result":
18601858
except Exception as e:
18611859
raise RedisSearchError(f"Unexpected error while searching: {str(e)}") from e
18621860

1861+
if REDIS_HYBRID_AVAILABLE:
1862+
1863+
async def hybrid_search(
1864+
self, query: HybridQuery, **kwargs
1865+
) -> List[Dict[str, Any]]:
1866+
client = await self._get_client()
1867+
results: HybridResult = await client.ft(
1868+
self.schema.index.name
1869+
).hybrid_search(
1870+
query=query.query,
1871+
combine_method=query.combination_method,
1872+
post_processing=(
1873+
query.postprocessing_config
1874+
if query.postprocessing_config.build_args()
1875+
else None
1876+
),
1877+
**kwargs,
1878+
) # type: ignore
1879+
return process_hybrid_results(results)
1880+
18631881
async def batch_query(
18641882
self, queries: List[BaseQuery], batch_size: int = 10
18651883
) -> List[List[Dict[str, Any]]]:

tests/integration/test_hybrid.py

Lines changed: 93 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,20 @@
11
import pytest
2-
from redis.commands.search.hybrid_result import HybridResult
32

4-
from redisvl.index import SearchIndex
5-
from redisvl.query.filter import FilterExpression, Geo, GeoRadius, Num, Tag, Text
3+
from redisvl.index import AsyncSearchIndex, SearchIndex
4+
from redisvl.query.filter import Geo, GeoRadius, Num, Tag, Text
65
from redisvl.query.hybrid import HybridQuery
76
from redisvl.redis.utils import array_to_buffer
8-
from tests.conftest import skip_if_redis_version_below
7+
from redisvl.schema import IndexSchema
8+
from tests.conftest import (
9+
get_redis_version_async,
10+
skip_if_redis_version_below,
11+
skip_if_redis_version_below_async,
12+
)
913

1014

1115
@pytest.fixture
12-
def index(multi_vector_data, redis_url, worker_id):
13-
14-
index = SearchIndex.from_dict(
16+
def index_schema(worker_id):
17+
return IndexSchema.from_dict(
1518
{
1619
"index": {
1720
"name": f"user_index_{worker_id}",
@@ -56,10 +59,14 @@ def index(multi_vector_data, redis_url, worker_id):
5659
},
5760
},
5861
],
59-
},
60-
redis_url=redis_url,
62+
}
6163
)
6264

65+
66+
@pytest.fixture
67+
def index(index_schema, multi_vector_data, redis_url):
68+
index = SearchIndex(schema=index_schema, redis_url=redis_url)
69+
6370
# create the index (no data yet)
6471
index.create(overwrite=True)
6572

@@ -81,6 +88,24 @@ def hash_preprocess(item: dict) -> dict:
8188
index.delete(drop=True)
8289

8390

91+
@pytest.fixture
92+
async def async_index(index_schema, multi_vector_data, async_client):
93+
index = AsyncSearchIndex(schema=index_schema, redis_client=async_client)
94+
await index.create(overwrite=True)
95+
96+
def hash_preprocess(item: dict) -> dict:
97+
return {
98+
**item,
99+
"user_embedding": array_to_buffer(item["user_embedding"], "float32"),
100+
"image_embedding": array_to_buffer(item["image_embedding"], "float32"),
101+
"audio_embedding": array_to_buffer(item["audio_embedding"], "float64"),
102+
}
103+
104+
await index.load(multi_vector_data, preprocess=hash_preprocess)
105+
yield index
106+
await index.delete(drop=True)
107+
108+
84109
def test_hybrid_query(index):
85110
skip_if_redis_version_below(index.client, "8.4.0")
86111

@@ -371,3 +396,62 @@ def test_hybrid_query_word_weights(index, scorer):
371396

372397
weighted_results = index.hybrid_search(weighted_query)
373398
assert weighted_results != unweighted_results
399+
400+
401+
@pytest.mark.asyncio
402+
async def test_hybrid_query_async(async_index):
403+
await skip_if_redis_version_below_async(async_index.client, "8.4.0")
404+
405+
text = "a medical professional with expertise in lung cancer"
406+
text_field = "description"
407+
vector = [0.1, 0.1, 0.5]
408+
vector_field = "user_embedding"
409+
return_fields = ["user", "credit_score", "age", "job", "location", "description"]
410+
411+
hybrid_query = HybridQuery(
412+
text=text,
413+
text_field_name=text_field,
414+
yield_text_score_as="text_score",
415+
vector=vector,
416+
vector_field_name=vector_field,
417+
yield_vsim_score_as="vsim_score",
418+
combination_method="RRF",
419+
yield_combined_score_as="hybrid_score",
420+
return_fields=return_fields,
421+
)
422+
423+
results = await async_index.hybrid_search(hybrid_query)
424+
assert isinstance(results, list)
425+
assert len(results) == 7
426+
for doc in results:
427+
assert doc["user"] in [
428+
"john",
429+
"derrick",
430+
"nancy",
431+
"tyler",
432+
"tim",
433+
"taimur",
434+
"joe",
435+
"mary",
436+
]
437+
assert int(doc["age"]) in [18, 14, 94, 100, 12, 15, 35]
438+
assert doc["job"] in ["engineer", "doctor", "dermatologist", "CEO", "dentist"]
439+
assert doc["credit_score"] in ["high", "low", "medium"]
440+
441+
hybrid_query = HybridQuery(
442+
text=text,
443+
text_field_name=text_field,
444+
vector=vector,
445+
vector_field_name=vector_field,
446+
num_results=3,
447+
combination_method="RRF",
448+
yield_combined_score_as="hybrid_score",
449+
)
450+
451+
results = await async_index.hybrid_search(hybrid_query)
452+
assert len(results) == 3
453+
assert (
454+
results[0]["hybrid_score"]
455+
>= results[1]["hybrid_score"]
456+
>= results[2]["hybrid_score"]
457+
)

0 commit comments

Comments
 (0)