1414using NRedisStack . Search ;
1515using NRedisStack . Search . Literals . Enums ;
1616using StackExchange . Redis ;
17+ using static NRedisStack . Search . Schema . VectorField ;
1718
1819namespace Microsoft . SemanticKernel . Connectors . Memory . Redis ;
1920
@@ -23,18 +24,65 @@ namespace Microsoft.SemanticKernel.Connectors.Memory.Redis;
2324/// <remarks>The embedded data is saved to the Redis server database specified in the constructor.
2425/// Similarity search capability is provided through the RediSearch module. Use RediSearch's "Index" to implement "Collection".
2526/// </remarks>
26- public sealed class RedisMemoryStore : IMemoryStore
27+ public class RedisMemoryStore : IMemoryStore , IDisposable
2728{
2829 /// <summary>
2930 /// Create a new instance of semantic memory using Redis.
3031 /// </summary>
31- /// <param name="database">The database of the redis server.</param>
32- /// <param name="vectorSize">Embedding vector size</param>
33- public RedisMemoryStore ( IDatabase database , int vectorSize )
32+ /// <param name="database">The database of the Redis server.</param>
33+ /// <param name="vectorSize">Embedding vector size, defaults to 1536</param>
34+ /// <param name="vectorIndexAlgorithm">Indexing algorithm for vectors, defaults to "HNSW"</param>
35+ /// <param name="vectorDistanceMetric">Metric for measuring vector distances, defaults to "COSINE"</param>
36+ /// <param name="queryDialect">Query dialect, must be 2 or greater for vector similarity searching, defaults to 2</param>
37+ public RedisMemoryStore (
38+ IDatabase database ,
39+ int vectorSize = DefaultVectorSize ,
40+ VectorAlgo vectorIndexAlgorithm = DefaultIndexAlgorithm ,
41+ VectorDistanceMetric vectorDistanceMetric = DefaultDistanceMetric ,
42+ int queryDialect = DefaultQueryDialect )
3443 {
44+ if ( vectorSize <= 0 )
45+ {
46+ throw new ArgumentException (
47+ $ "Invalid vector size: { vectorSize } . Vector size must be a positive integer.", nameof ( vectorSize ) ) ;
48+ }
49+
3550 this . _database = database ;
3651 this . _vectorSize = vectorSize ;
3752 this . _ft = database . FT ( ) ;
53+ this . _vectorIndexAlgorithm = vectorIndexAlgorithm ;
54+ this . _vectorDistanceMetric = vectorDistanceMetric . ToString ( ) ;
55+ this . _queryDialect = queryDialect ;
56+ }
57+
58+ /// <summary>
59+ /// Create a new instance of semantic memory using Redis.
60+ /// </summary>
61+ /// <param name="connectionString">Provide connection URL to a Redis instance</param>
62+ /// <param name="vectorSize">Embedding vector size, defaults to 1536</param>
63+ /// <param name="vectorIndexAlgorithm">Indexing algorithm for vectors, defaults to "HNSW"</param>
64+ /// <param name="vectorDistanceMetric">Metric for measuring vector distances, defaults to "COSINE"</param>
65+ /// <param name="queryDialect">Query dialect, must be 2 or greater for vector similarity searching, defaults to 2</param>
66+ public RedisMemoryStore (
67+ string connectionString ,
68+ int vectorSize = DefaultVectorSize ,
69+ VectorAlgo vectorIndexAlgorithm = DefaultIndexAlgorithm ,
70+ VectorDistanceMetric vectorDistanceMetric = DefaultDistanceMetric ,
71+ int queryDialect = DefaultQueryDialect )
72+ {
73+ if ( vectorSize <= 0 )
74+ {
75+ throw new ArgumentException (
76+ $ "Invalid vector size: { vectorSize } . Vector size must be a positive integer.", nameof ( vectorSize ) ) ;
77+ }
78+
79+ this . _connection = ConnectionMultiplexer . Connect ( connectionString ) ;
80+ this . _database = this . _connection . GetDatabase ( ) ;
81+ this . _vectorSize = vectorSize ;
82+ this . _ft = this . _database . FT ( ) ;
83+ this . _vectorIndexAlgorithm = vectorIndexAlgorithm ;
84+ this . _vectorDistanceMetric = vectorDistanceMetric . ToString ( ) ;
85+ this . _queryDialect = queryDialect ;
3886 }
3987
4088 /// <inheritdoc />
@@ -54,10 +102,10 @@ public async Task CreateCollectionAsync(string collectionName, CancellationToken
54102 . AddTextField ( "key" )
55103 . AddTextField ( "metadata" )
56104 . AddNumericField ( "timestamp" )
57- . AddVectorField ( "embedding" , VECTOR_INDEX_ALGORITHM , new Dictionary < string , object > {
58- { "TYPE" , VECTOR_TYPE } ,
105+ . AddVectorField ( "embedding" , this . _vectorIndexAlgorithm , new Dictionary < string , object > {
106+ { "TYPE" , DefaultVectorType } ,
59107 { "DIM" , this . _vectorSize } ,
60- { "DISTANCE_METRIC" , VECTOR_DISTANCE_METRIC } ,
108+ { "DISTANCE_METRIC" , this . _vectorDistanceMetric } ,
61109 } ) ;
62110
63111 await this . _ft . CreateAsync ( collectionName , ftCreateParams , schema ) . ConfigureAwait ( false ) ;
@@ -71,7 +119,7 @@ public async Task<bool> DoesCollectionExistAsync(string collectionName, Cancella
71119 await this . _ft . InfoAsync ( collectionName ) . ConfigureAwait ( false ) ;
72120 return true ;
73121 }
74- catch ( RedisServerException ex ) when ( ex . Message == MESSAGE_WHEN_INDEX_DOES_NOT_EXIST )
122+ catch ( RedisServerException ex ) when ( ex . Message . Equals ( IndexDoesNotExistErrorMessage , StringComparison . Ordinal ) )
75123 {
76124 return false ;
77125 }
@@ -112,7 +160,7 @@ public async Task<string> UpsertAsync(string collectionName, MemoryRecord record
112160 await this . _database . HashSetAsync ( GetRedisKey ( collectionName , record . Key ) , new [ ] {
113161 new HashEntry ( "key" , record . Key ) ,
114162 new HashEntry ( "metadata" , record . GetSerializedMetadata ( ) ) ,
115- new HashEntry ( "embedding" , MemoryMarshal . AsBytes ( record . Embedding . Span ) . ToArray ( ) ) ,
163+ new HashEntry ( "embedding" , this . ConvertEmbeddingToBytes ( record . Embedding ) ) ,
116164 new HashEntry ( "timestamp" , ToTimestampLong ( record . Timestamp ) )
117165 } , flags : CommandFlags . None ) . ConfigureAwait ( false ) ;
118166
@@ -155,11 +203,11 @@ public async Task RemoveBatchAsync(string collectionName, IEnumerable<string> ke
155203 }
156204
157205 var query = new Query ( $ "*=>[KNN { limit } @embedding $embedding AS vector_score]")
158- . AddParam ( "embedding" , MemoryMarshal . AsBytes ( embedding . Span ) . ToArray ( ) )
206+ . AddParam ( "embedding" , this . ConvertEmbeddingToBytes ( embedding ) )
159207 . SetSortBy ( "vector_score" )
160208 . ReturnFields ( "key" , "metadata" , "embedding" , "timestamp" , "vector_score" )
161209 . Limit ( 0 , limit )
162- . Dialect ( QUERY_DIALECT ) ;
210+ . Dialect ( this . _queryDialect ) ;
163211
164212 var results = await this . _ft . SearchAsync ( collectionName , query ) . ConfigureAwait ( false ) ;
165213
@@ -198,43 +246,63 @@ public async Task RemoveBatchAsync(string collectionName, IEnumerable<string> ke
198246 cancellationToken : cancellationToken ) . FirstOrDefaultAsync ( cancellationToken : cancellationToken ) . ConfigureAwait ( false ) ;
199247 }
200248
201- #region constants ================================================================================
249+ public void Dispose ( )
250+ {
251+ this . Dispose ( true ) ;
252+ GC . SuppressFinalize ( this ) ;
253+ }
254+
255+ protected virtual void Dispose ( bool disposing )
256+ {
257+ if ( disposing )
258+ {
259+ this . _connection ? . Dispose ( ) ;
260+ }
261+ }
262+
263+ #region private ================================================================================
202264
203265 /// <summary>
204- /// Vector similarity index algorithm. The default value is "HNSW".
205- /// <see href="https://redis.io/docs/stack /search/reference /vectors/#create-a-vector-field"/>
266+ /// Vector similarity index algorithm. Supported algorithms are {FLAT, HNSW}. The default value is "HNSW".
267+ /// <see href="https://redis.io/docs/interact /search-and-query/search /vectors/#create-a-vector-field"/>
206268 /// </summary>
207- internal const Schema . VectorField . VectorAlgo VECTOR_INDEX_ALGORITHM = Schema . VectorField . VectorAlgo . HNSW ;
269+ private const VectorAlgo DefaultIndexAlgorithm = VectorAlgo . HNSW ;
208270
209271 /// <summary>
210- /// Vector type. Supported types are FLOAT32 and FLOAT64. The default value is "FLOAT32".
272+ /// Vector type. Available values are {FLOAT32, FLOAT64}.
273+ /// Value "FLOAT32" is used by default based on <see cref="MemoryRecord.Embedding"/> <see cref="float"/> type.
211274 /// </summary>
212- internal const string VECTOR_TYPE = "FLOAT32" ;
275+ private const string DefaultVectorType = "FLOAT32" ;
213276
214277 /// <summary>
215- /// Supported distance metric, one of {L2, IP, COSINE}. The default value is "COSINE".
278+ /// Supported distance metrics are {L2, IP, COSINE}. The default value is "COSINE".
216279 /// </summary>
217- internal const string VECTOR_DISTANCE_METRIC = " COSINE" ;
280+ private const VectorDistanceMetric DefaultDistanceMetric = VectorDistanceMetric . COSINE ;
218281
219282 /// <summary>
220283 /// Query dialect. To use a vector similarity query, specify DIALECT 2 or higher. The default value is "2".
221- /// <see href="https://redis.io/docs/stack /search/reference /vectors/#querying-vector-fields"/>
284+ /// <see href="https://redis.io/docs/interact /search-and-query/search /vectors/#querying-vector-fields"/>
222285 /// </summary>
223- internal const int QUERY_DIALECT = 2 ;
286+ private const int DefaultQueryDialect = 2 ;
224287
225288 /// <summary>
226- /// Message when index does not exist.
227- /// <see href="https://github.com/RediSearch/RediSearch/blob/master/src/info_command.c#L96"/>
289+ /// Embedding vector size.
228290 /// </summary>
229- internal const string MESSAGE_WHEN_INDEX_DOES_NOT_EXIST = "Unknown Index name" ;
230-
231- #endregion
291+ private const int DefaultVectorSize = 1536 ;
232292
233- #region private ================================================================================
293+ /// <summary>
294+ /// Message when index does not exist.
295+ /// <see href="https://github.com/RediSearch/RediSearch/blob/master/src/info_command.c#L97"/>
296+ /// </summary>
297+ private const string IndexDoesNotExistErrorMessage = "Unknown Index name" ;
234298
235299 private readonly IDatabase _database ;
236300 private readonly int _vectorSize ;
237301 private readonly SearchCommands _ft ;
302+ private readonly ConnectionMultiplexer ? _connection ;
303+ private readonly Schema . VectorField . VectorAlgo _vectorIndexAlgorithm ;
304+ private readonly string _vectorDistanceMetric ;
305+ private readonly int _queryDialect ;
238306
239307 private static long ToTimestampLong ( DateTimeOffset ? timestamp )
240308 {
@@ -295,5 +363,10 @@ private double GetSimilarity(Document document)
295363 return 1 - vectorScore ;
296364 }
297365
366+ private byte [ ] ConvertEmbeddingToBytes ( ReadOnlyMemory < float > embedding )
367+ {
368+ return MemoryMarshal . AsBytes ( embedding . Span ) . ToArray ( ) ;
369+ }
370+
298371 #endregion
299372}
0 commit comments