Skip to content

Commit 782b46e

Browse files
.Net: Implementing cosine similarity in duckdb engine (#2638)
### Motivation and Context <!-- Thank you for your contribution to the semantic-kernel repo! Please help reviewers and future users, providing the following information: 1. Why is this change required? 2. What problem does it solve? 3. What scenario does it contribute to? 4. If it fixes an open issue, please link to the issue here. --> This pr moves cosine similarity calculation into duckDb engine to improve performance and memory usage. ### Description Moving cosine similarity calculation on the server there is no need to marshal out more data than the one scoring enough and meeting the result limit. <!-- Describe your changes, the overall approach, the underlying design. These notes will help understanding how your code works. Thanks! --> ### Contribution Checklist <!-- Before submitting this PR, please make sure: --> - [x] The code builds clean without any errors or warnings - [x] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [x] All unit tests pass, and I have added new tests where possible - [x] I didn't break anyone 😄 --------- Co-authored-by: Dmytro Struk <13853051+dmytrostruk@users.noreply.github.com>
1 parent f6fb9f9 commit 782b46e

File tree

5 files changed

+77
-76
lines changed

5 files changed

+77
-76
lines changed

dotnet/SK-dotnet.sln.DotSettings

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ public void It$SOMENAME$()
223223
<s:Boolean x:Key="/Default/UserDictionary/Words/=testsettings/@EntryIndexedValue">True</s:Boolean>
224224
<s:Boolean x:Key="/Default/UserDictionary/Words/=tldr/@EntryIndexedValue">True</s:Boolean>
225225
<s:Boolean x:Key="/Default/UserDictionary/Words/=Untrust/@EntryIndexedValue">True</s:Boolean>
226+
<s:Boolean x:Key="/Default/UserDictionary/Words/=Upsert/@EntryIndexedValue">True</s:Boolean>
226227
<s:Boolean x:Key="/Default/UserDictionary/Words/=upserted/@EntryIndexedValue">True</s:Boolean>
227228
<s:Boolean x:Key="/Default/UserDictionary/Words/=Upserts/@EntryIndexedValue">True</s:Boolean>
228229
<s:Boolean x:Key="/Default/UserDictionary/Words/=wellknown/@EntryIndexedValue">True</s:Boolean>

dotnet/src/Connectors/Connectors.Memory.DuckDB/Database.cs

Lines changed: 55 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
// Copyright (c) Microsoft. All rights reserved.
22

3+
using System;
34
using System.Collections.Generic;
5+
using System.Globalization;
46
using System.Linq;
57
using System.Runtime.CompilerServices;
68
using System.Threading;
@@ -18,13 +20,25 @@ internal struct DatabaseEntry
1820
public string EmbeddingString { get; set; }
1921

2022
public string? Timestamp { get; set; }
23+
24+
public float Score { get; set; }
2125
}
2226

2327
internal sealed class Database
2428
{
2529
private const string TableName = "SKMemoryTable";
2630

27-
public Database() { }
31+
public Task CreateFunctionsAsync(DuckDBConnection conn, CancellationToken cancellationToken)
32+
{
33+
using var cmd = conn.CreateCommand();
34+
cmd.CommandText = @"
35+
CREATE OR REPLACE MACRO cosine_similarity(a,b) AS (select sum (xy) from (select x * y as xy from (select UNNEST(a) as x, UNNEST(b) as y))) / sqrt(list_aggregate(list_transform(a, x -> x * x), 'sum') * list_aggregate(list_transform(b, x -> x * x), 'sum'));
36+
CREATE OR REPLACE MACRO split_string_of_numbers(t) AS regexp_extract_all(regexp_replace(t,'(\[|\])', '', 'g'), '([+-]?([0-9]*[.])?[0-9]+)(\s*;\s*)?',1);
37+
CREATE OR REPLACE MACRO number_vector_decoder(t) AS list_transform(split_string_of_numbers(t), x -> cast(x AS double));
38+
CREATE OR REPLACE MACRO encode_number_vector(t) AS concat('[',list_aggregate(list_transform(t, x -> cast(x AS string)), 'string_agg', '; '),']');
39+
";
40+
return cmd.ExecuteNonQueryAsync(cancellationToken);
41+
}
2842

2943
public Task CreateTableAsync(DuckDBConnection conn, CancellationToken cancellationToken = default)
3044
{
@@ -34,7 +48,7 @@ public Task CreateTableAsync(DuckDBConnection conn, CancellationToken cancellati
3448
collection TEXT,
3549
key TEXT,
3650
metadata TEXT,
37-
embedding TEXT,
51+
embedding FLOAT[],
3852
timestamp TEXT,
3953
PRIMARY KEY(collection, key))";
4054
return cmd.ExecuteNonQueryAsync(cancellationToken);
@@ -50,27 +64,32 @@ public async Task CreateCollectionAsync(DuckDBConnection conn, string collection
5064

5165
using var cmd = conn.CreateCommand();
5266
cmd.CommandText = $@"
53-
INSERT INTO {TableName} VALUES (?1,?2,?3,?4,?5 ); ";
67+
INSERT INTO {TableName} VALUES (?1,?2,?3, [], ?4 ); ";
5468
cmd.Parameters.Add(new DuckDBParameter(collectionName));
5569
cmd.Parameters.Add(new DuckDBParameter(string.Empty));
5670
cmd.Parameters.Add(new DuckDBParameter(string.Empty));
5771
cmd.Parameters.Add(new DuckDBParameter(string.Empty));
58-
cmd.Parameters.Add(new DuckDBParameter(string.Empty));
5972

6073
await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
6174
}
6275

76+
private static string EncodeFloatArrayToString(float[]? data)
77+
{
78+
var dataArrayString = $"[{string.Join(", ", (data ?? Array.Empty<float>()).Select(n => n.ToString("F10", CultureInfo.InvariantCulture)))}]";
79+
return dataArrayString;
80+
}
81+
82+
[System.Diagnostics.CodeAnalysis.SuppressMessage("Security", "CA2100:Review SQL queries for security vulnerabilities", Justification = "Internal method serializing array of float and numbers")]
6383
public async Task UpdateOrInsertAsync(DuckDBConnection conn,
64-
string collection, string key, string? metadata, string? embedding, string? timestamp, CancellationToken cancellationToken = default)
84+
string collection, string key, string? metadata, float[]? embedding, string? timestamp, CancellationToken cancellationToken = default)
6585
{
86+
await this.DeleteAsync(conn, collection, key, cancellationToken).ConfigureAwait(true);
87+
var embeddingArrayString = EncodeFloatArrayToString(embedding ?? Array.Empty<float>());
6688
using var cmd = conn.CreateCommand();
67-
cmd.CommandText = $@"
68-
INSERT INTO {TableName} VALUES(?1, ?2, ?3, ?4, ?5)
69-
ON CONFLICT (collection, key) DO UPDATE SET metadata=?3, embedding=?4, timestamp=?5; ";
89+
cmd.CommandText = $"INSERT INTO {TableName} VALUES(?1, ?2, ?3, {embeddingArrayString}, ?4)";
7090
cmd.Parameters.Add(new DuckDBParameter(collection));
7191
cmd.Parameters.Add(new DuckDBParameter(key));
7292
cmd.Parameters.Add(new DuckDBParameter(metadata ?? string.Empty));
73-
cmd.Parameters.Add(new DuckDBParameter(embedding ?? string.Empty));
7493
cmd.Parameters.Add(new DuckDBParameter(timestamp ?? string.Empty));
7594
await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
7695
}
@@ -98,14 +117,23 @@ SELECT DISTINCT collection
98117
}
99118
}
100119

101-
public async IAsyncEnumerable<DatabaseEntry> ReadAllAsync(DuckDBConnection conn,
120+
[System.Diagnostics.CodeAnalysis.SuppressMessage("Security", "CA2100:Review SQL queries for security vulnerabilities", Justification = "Internal method serializing array of float and numbers")]
121+
public async IAsyncEnumerable<DatabaseEntry> GetNearestMatchesAsync(
122+
DuckDBConnection conn,
102123
string collectionName,
124+
float[]? embedding,
125+
int limit,
126+
double minRelevanceScore = 0,
103127
[EnumeratorCancellation] CancellationToken cancellationToken = default)
104128
{
129+
var embeddingArrayString = EncodeFloatArrayToString(embedding ?? Array.Empty<float>());
130+
105131
using var cmd = conn.CreateCommand();
106132
cmd.CommandText = $@"
107-
SELECT * FROM {TableName}
108-
WHERE collection=?1;";
133+
SELECT key, metadata, timestamp, cast(embedding as string) as embeddingAsString, cast(cosine_similarity(embedding,{embeddingArrayString}) as FLOAT) as score FROM {TableName}
134+
WHERE collection=?1 AND score >= {minRelevanceScore.ToString("F12", CultureInfo.InvariantCulture)}
135+
ORDER BY score DESC
136+
LIMIT {limit};";
109137
cmd.Parameters.Add(new DuckDBParameter(collectionName));
110138

111139
using var dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false);
@@ -116,10 +144,19 @@ public async IAsyncEnumerable<DatabaseEntry> ReadAllAsync(DuckDBConnection conn,
116144
{
117145
continue;
118146
}
147+
119148
string metadata = dataReader.GetString("metadata");
120-
string embedding = dataReader.GetString("embedding");
149+
string embeddingAsString = dataReader.GetString("embeddingAsString");
121150
string timestamp = dataReader.GetString("timestamp");
122-
yield return new DatabaseEntry() { Key = key, MetadataString = metadata, EmbeddingString = embedding, Timestamp = timestamp };
151+
float score = dataReader.GetFloat("score");
152+
yield return new DatabaseEntry
153+
{
154+
Key = key,
155+
MetadataString = metadata,
156+
EmbeddingString = embeddingAsString,
157+
Timestamp = timestamp,
158+
Score = score
159+
};
123160
}
124161
}
125162

@@ -130,7 +167,7 @@ public async IAsyncEnumerable<DatabaseEntry> ReadAllAsync(DuckDBConnection conn,
130167
{
131168
using var cmd = conn.CreateCommand();
132169
cmd.CommandText = $@"
133-
SELECT * FROM {TableName}
170+
SELECT metadata, timestamp, cast(embedding as string) as embeddingAsString FROM {TableName}
134171
WHERE collection=?1
135172
AND key=?2; ";
136173
cmd.Parameters.Add(new DuckDBParameter(collectionName));
@@ -140,13 +177,13 @@ public async IAsyncEnumerable<DatabaseEntry> ReadAllAsync(DuckDBConnection conn,
140177
if (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false))
141178
{
142179
string metadata = dataReader.GetString(dataReader.GetOrdinal("metadata"));
143-
string embedding = dataReader.GetString(dataReader.GetOrdinal("embedding"));
180+
string embeddingAsString = dataReader.GetString(dataReader.GetOrdinal("embeddingAsString"));
144181
string timestamp = dataReader.GetString(dataReader.GetOrdinal("timestamp"));
145-
return new DatabaseEntry()
182+
return new DatabaseEntry
146183
{
147184
Key = key,
148185
MetadataString = metadata,
149-
EmbeddingString = embedding,
186+
EmbeddingString = embeddingAsString,
150187
Timestamp = timestamp
151188
};
152189
}
@@ -175,15 +212,4 @@ DELETE FROM {TableName}
175212
cmd.Parameters.Add(new DuckDBParameter(key));
176213
return cmd.ExecuteNonQueryAsync(cancellationToken);
177214
}
178-
179-
public Task DeleteEmptyAsync(DuckDBConnection conn, string collectionName, CancellationToken cancellationToken = default)
180-
{
181-
using var cmd = conn.CreateCommand();
182-
cmd.CommandText = $@"
183-
DELETE FROM {TableName}
184-
WHERE collection=?1
185-
AND key IS NULL";
186-
cmd.Parameters.Add(new DuckDBParameter(collectionName));
187-
return cmd.ExecuteNonQueryAsync(cancellationToken);
188-
}
189215
}

dotnet/src/Connectors/Connectors.Memory.DuckDB/DuckDBExtensions.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,10 @@ public static string GetString(this DbDataReader reader, string fieldName)
1111
int ordinal = reader.GetOrdinal(fieldName);
1212
return reader.GetString(ordinal);
1313
}
14+
15+
public static float GetFloat(this DbDataReader reader, string fieldName)
16+
{
17+
int ordinal = reader.GetOrdinal(fieldName);
18+
return reader.GetFloat(ordinal);
19+
}
1420
}

dotnet/src/Connectors/Connectors.Memory.DuckDB/DuckDBMemoryStore.cs

Lines changed: 14 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
using System.Threading;
1010
using System.Threading.Tasks;
1111
using DuckDB.NET.Data;
12-
using Microsoft.SemanticKernel.AI.Embeddings.VectorOperations;
1312
using Microsoft.SemanticKernel.Memory;
1413
using Microsoft.SemanticKernel.Text;
1514

@@ -21,7 +20,7 @@ namespace Microsoft.SemanticKernel.Connectors.Memory.DuckDB;
2120
/// <remarks>The data is saved to a database file, specified in the constructor.
2221
/// The data persists between subsequent instances. Only one instance may access the file at a time.
2322
/// The caller is responsible for deleting the file.</remarks>
24-
public class DuckDBMemoryStore : IMemoryStore, IDisposable
23+
public sealed class DuckDBMemoryStore : IMemoryStore, IDisposable
2524
{
2625
/// <summary>
2726
/// Connect a DuckDB database
@@ -31,19 +30,18 @@ public class DuckDBMemoryStore : IMemoryStore, IDisposable
3130
public static async Task<DuckDBMemoryStore> ConnectAsync(string filename,
3231
CancellationToken cancellationToken = default)
3332
{
34-
var memoryStore = new DuckDBMemoryStore($"Data Source={filename}");
33+
var memoryStore = new DuckDBMemoryStore(filename);
3534
return await InitialiseMemoryStoreAsync(memoryStore, cancellationToken).ConfigureAwait(false);
3635
}
3736

3837
/// <summary>
3938
/// Connect a in memory DuckDB database
4039
/// </summary>
4140
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
42-
public static async Task<DuckDBMemoryStore> ConnectAsync(
41+
public static Task<DuckDBMemoryStore> ConnectAsync(
4342
CancellationToken cancellationToken = default)
4443
{
45-
var memoryStore = new DuckDBMemoryStore(":memory:");
46-
return await InitialiseMemoryStoreAsync(memoryStore, cancellationToken).ConfigureAwait(false);
44+
return ConnectAsync(":memory:", cancellationToken);
4745
}
4846

4947
/// <summary>
@@ -154,19 +152,14 @@ public async Task RemoveBatchAsync(string collectionName, IEnumerable<string> ke
154152
var collectionMemories = new List<MemoryRecord>();
155153
List<(MemoryRecord Record, double Score)> embeddings = new();
156154

157-
await foreach (var record in this.GetAllAsync(collectionName, cancellationToken))
155+
await foreach (var dbEntry in this._dbConnector.GetNearestMatchesAsync(this._dbConnection, collectionName, embedding.ToArray(), limit, minRelevanceScore, cancellationToken))
158156
{
159-
if (record != null)
160-
{
161-
double similarity = embedding
162-
.Span
163-
.CosineSimilarity(record.Embedding.Span);
164-
if (similarity >= minRelevanceScore)
165-
{
166-
var entry = withEmbeddings ? record : MemoryRecord.FromMetadata(record.Metadata, ReadOnlyMemory<float>.Empty, record.Key, record.Timestamp);
167-
embeddings.Add(new(entry, similarity));
168-
}
169-
}
157+
var entry = MemoryRecord.FromJsonMetadata(
158+
json: dbEntry.MetadataString,
159+
withEmbeddings ? JsonSerializer.Deserialize<ReadOnlyMemory<float>>(dbEntry.EmbeddingString, s_jsonSerializerOptions) : Array.Empty<float>(),
160+
dbEntry.Key,
161+
ParseTimestamp(dbEntry.Timestamp));
162+
embeddings.Add(new(entry, dbEntry.Score));
170163
}
171164

172165
foreach (var item in embeddings.OrderByDescending(l => l.Score).Take(limit))
@@ -197,7 +190,7 @@ public void Dispose()
197190

198191
#region protected ================================================================================
199192

200-
protected virtual void Dispose(bool disposing)
193+
private void Dispose(bool disposing)
201194
{
202195
if (!this._disposedValue)
203196
{
@@ -223,6 +216,7 @@ private static async Task<DuckDBMemoryStore> InitialiseMemoryStoreAsync(DuckDBMe
223216
{
224217
await memoryStore._dbConnection.OpenAsync(cancellationToken).ConfigureAwait(false);
225218
await memoryStore._dbConnector.CreateTableAsync(memoryStore._dbConnection, cancellationToken).ConfigureAwait(false);
219+
await memoryStore._dbConnector.CreateFunctionsAsync(memoryStore._dbConnection, cancellationToken).ConfigureAwait(false);
226220
return memoryStore;
227221
}
228222

@@ -237,16 +231,6 @@ private DuckDBMemoryStore(string filename)
237231
this._disposedValue = false;
238232
}
239233

240-
/// <summary>
241-
/// Constructor
242-
/// </summary>
243-
private DuckDBMemoryStore()
244-
{
245-
this._dbConnector = new Database();
246-
this._dbConnection = new DuckDBConnection("Data Source=:memory:;");
247-
this._disposedValue = false;
248-
}
249-
250234
/// <summary>
251235
/// Constructor
252236
/// </summary>
@@ -274,22 +258,6 @@ private DuckDBMemoryStore(DuckDBConnection connection)
274258
return null;
275259
}
276260

277-
private async IAsyncEnumerable<MemoryRecord> GetAllAsync(string collectionName, [EnumeratorCancellation] CancellationToken cancellationToken = default)
278-
{
279-
// delete empty entry in the database if it exists (see CreateCollection)
280-
await this._dbConnector.DeleteEmptyAsync(this._dbConnection, collectionName, cancellationToken).ConfigureAwait(false);
281-
282-
await foreach (DatabaseEntry dbEntry in this._dbConnector.ReadAllAsync(this._dbConnection, collectionName, cancellationToken))
283-
{
284-
var dbEntryEmbeddingString = dbEntry.EmbeddingString;
285-
ReadOnlyMemory<float> vector = JsonSerializer.Deserialize<ReadOnlyMemory<float>>(dbEntryEmbeddingString, s_jsonSerializerOptions);
286-
287-
var record = MemoryRecord.FromJsonMetadata(dbEntry.MetadataString, vector, dbEntry.Key, ParseTimestamp(dbEntry.Timestamp));
288-
289-
yield return record;
290-
}
291-
}
292-
293261
private async Task<string> InternalUpsertAsync(DuckDBConnection connection, string collectionName, MemoryRecord record, CancellationToken cancellationToken)
294262
{
295263
record.Key = record.Metadata.Id;
@@ -298,7 +266,7 @@ await this._dbConnector.UpdateOrInsertAsync(conn: connection,
298266
collection: collectionName,
299267
key: record.Key,
300268
metadata: record.GetSerializedMetadata(),
301-
embedding: JsonSerializer.Serialize(record.Embedding, s_jsonSerializerOptions),
269+
embedding: record.Embedding.ToArray(),
302270
timestamp: ToTimestampString(record.Timestamp),
303271
cancellationToken: cancellationToken).ConfigureAwait(false);
304272

dotnet/src/Connectors/Connectors.UnitTests/Memory/DuckDB/DuckDBMemoryStoreTests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ public async Task GetAsyncReturnsEmptyEmbeddingUnlessSpecifiedAsync()
182182
Assert.NotNull(actualDefault);
183183
Assert.NotNull(actualWithEmbedding);
184184
Assert.True(actualDefault.Embedding.IsEmpty);
185-
Assert.False(actualWithEmbedding.Embedding.IsEmpty);
185+
Assert.Equal(actualWithEmbedding.Embedding.ToArray(), testRecord.Embedding.ToArray());
186186
}
187187

188188
[Fact]

0 commit comments

Comments
 (0)