Skip to content

Commit 23328fa

Browse files
authored
Java SQLite connector: implement getNearestMatchAsync and getNearestMatchesAsync (#2311)
### Motivation and Context

<!-- Thank you for your contribution to the semantic-kernel repo! Please help reviewers and future users, providing the following information: 1. Why is this change required? 2. What problem does it solve? 3. What scenario does it contribute to? 4. If it fixes an open issue, please link to the issue here. -->

These methods were stubbed to return null.

### Description

<!-- Describe your changes, the overall approach, the underlying design. These notes will help understanding how your code works. Thanks! -->

Provide implementations of these methods. Because the methods use vector operations that are encapsulated in SemanticVector, the vectoroperations package and the SemanticVector class were moved to the api module under com.microsoft.semantickernel.ai.embeddings.

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [ ] The code builds clean without any errors or warnings
- [ ] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations
- [ ] All unit tests pass, and I have added new tests where possible
- [ ] I didn't break anyone 😄
1 parent 974dabf commit 23328fa

File tree

8 files changed

+1329
-359
lines changed

8 files changed

+1329
-359
lines changed

java/connectors/semantickernel-connectors-memory-sqlite/src/main/java/com/microsoft/semantickernel/connectors/memory/sqlite/Database.java

Lines changed: 132 additions & 62 deletions
Large diffs are not rendered by default.

java/connectors/semantickernel-connectors-memory-sqlite/src/main/java/com/microsoft/semantickernel/connectors/memory/sqlite/SQLConnectorException.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
1+
// Copyright (c) Microsoft. All rights reserved.
12
package com.microsoft.semantickernel.connectors.memory.sqlite;
23

34
import com.microsoft.semantickernel.SKException;
45

5-
/**
6-
* Exception thrown by the SQLite connector.
7-
*/
6+
/** Exception thrown by the SQLite connector. */
87
public class SQLConnectorException extends SKException {
98

109
/**
1110
* Create an exception with a message
11+
*
1212
* @param message a description of the cause of the exception
1313
*/
1414
public SQLConnectorException(String message) {
@@ -17,6 +17,7 @@ public SQLConnectorException(String message) {
1717

1818
/**
1919
* Create an exception with a message and a cause
20+
*
2021
* @param message a description of the cause of the exception
2122
* @param cause the cause of the exception
2223
*/

java/connectors/semantickernel-connectors-memory-sqlite/src/main/java/com/microsoft/semantickernel/connectors/memory/sqlite/SqliteMemoryStore.java renamed to java/connectors/semantickernel-connectors-memory-sqlite/src/main/java/com/microsoft/semantickernel/connectors/memory/sqlite/SQLiteMemoryStore.java

Lines changed: 124 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -3,39 +3,43 @@
33

44
import com.fasterxml.jackson.core.JsonProcessingException;
55
import com.fasterxml.jackson.databind.ObjectMapper;
6+
import com.microsoft.semantickernel.SKException;
67
import com.microsoft.semantickernel.ai.embeddings.Embedding;
8+
import com.microsoft.semantickernel.memory.MemoryException;
9+
import com.microsoft.semantickernel.memory.MemoryException.ErrorCodes;
710
import com.microsoft.semantickernel.memory.MemoryRecord;
811
import com.microsoft.semantickernel.memory.MemoryStore;
9-
10-
import reactor.core.publisher.Flux;
11-
import reactor.core.publisher.Mono;
12-
import reactor.util.function.Tuple2;
13-
1412
import java.sql.*;
15-
import java.time.ZonedDateTime;
1613
import java.util.ArrayList;
1714
import java.util.Collection;
15+
import java.util.Comparator;
1816
import java.util.List;
17+
import java.util.Objects;
1918
import java.util.stream.Collectors;
20-
2119
import javax.annotation.Nonnull;
20+
import reactor.core.publisher.Flux;
21+
import reactor.core.publisher.Mono;
22+
import reactor.util.function.Tuple2;
23+
import reactor.util.function.Tuples;
2224

23-
public class SqliteMemoryStore implements MemoryStore {
25+
public class SQLiteMemoryStore implements MemoryStore {
2426

2527
private final Database dbConnector;
2628
private Connection dbConnection;
2729

28-
public SqliteMemoryStore() {
30+
public SQLiteMemoryStore() {
2931
this.dbConnector = new Database();
3032
}
3133

32-
public Mono<Void> connectAsync(String filename) throws SQLException {
34+
public Mono<Void> connectAsync(@Nonnull String filename) throws SQLException {
35+
Objects.requireNonNull(filename);
3336
this.dbConnection = DriverManager.getConnection("jdbc:sqlite:" + filename);
3437
return this.dbConnector.createTableAsync(this.dbConnection);
3538
}
3639

3740
@Override
3841
public Mono<Void> createCollectionAsync(@Nonnull String collectionName) {
42+
Objects.requireNonNull(collectionName);
3943
return this.dbConnector.createCollectionAsync(this.dbConnection, collectionName);
4044
}
4145

@@ -46,62 +50,101 @@ public Mono<List<String>> getCollectionsAsync() {
4650

4751
@Override
4852
public Mono<Boolean> doesCollectionExistAsync(@Nonnull String collectionName) {
53+
Objects.requireNonNull(collectionName);
4954
return this.dbConnector.doesCollectionExistsAsync(this.dbConnection, collectionName);
5055
}
5156

5257
@Override
5358
public Mono<Void> deleteCollectionAsync(@Nonnull String collectionName) {
59+
Objects.requireNonNull(collectionName);
5460
return this.dbConnector.deleteCollectionAsync(this.dbConnection, collectionName);
5561
}
5662

5763
@Override
5864
public Mono<String> upsertAsync(@Nonnull String collectionName, @Nonnull MemoryRecord record) {
59-
return this.internalUpsertAsync(collectionName, record);
65+
Objects.requireNonNull(collectionName);
66+
Objects.requireNonNull(record);
67+
return doesCollectionExistAsync(collectionName)
68+
.handle(
69+
(exists, sink) -> {
70+
if (!exists) {
71+
sink.error(
72+
new MemoryException(
73+
ErrorCodes
74+
.ATTEMPTED_TO_ACCESS_NONEXISTENT_COLLECTION,
75+
collectionName));
76+
return;
77+
}
78+
sink.next(exists);
79+
})
80+
.then(internalUpsertAsync(collectionName, record));
6081
}
6182

6283
private Mono<String> internalUpsertAsync(
6384
@Nonnull String collectionName, @Nonnull MemoryRecord record) {
85+
Objects.requireNonNull(collectionName);
86+
Objects.requireNonNull(record);
6487
try {
6588
Mono<Void> update =
6689
this.dbConnector.updateAsync(
6790
this.dbConnection,
6891
collectionName,
69-
record.getKey(),
92+
record.getMetadata().getId(),
7093
record.getSerializedMetadata(),
7194
record.getSerializedEmbedding(),
72-
record.getTimestamp().toString());
95+
record.getTimestamp());
7396

7497
Mono<Void> insert =
7598
this.dbConnector.insertOrIgnoreAsync(
7699
this.dbConnection,
77100
collectionName,
78-
record.getKey(),
101+
record.getMetadata().getId(),
79102
record.getSerializedMetadata(),
80103
record.getSerializedEmbedding(),
81-
record.getTimestamp().toString());
104+
record.getTimestamp());
82105

83-
return update.then(insert).then(Mono.just(record.getKey()));
106+
return update.then(insert).then(Mono.just(record.getMetadata().getId()));
84107
} catch (JsonProcessingException e) {
85-
throw new RuntimeException(e);
108+
throw new SQLConnectorException("Error serializing MemoryRecord", e);
86109
}
87110
}
88111

89112
@Override
90113
public Mono<Collection<String>> upsertBatchAsync(
91114
@Nonnull String collectionName, @Nonnull Collection<MemoryRecord> records) {
92-
return Flux.fromIterable(records)
93-
.flatMap(record -> internalUpsertAsync(collectionName, record))
94-
.collect(Collectors.toCollection(ArrayList::new));
115+
Objects.requireNonNull(collectionName);
116+
Objects.requireNonNull(records);
117+
return doesCollectionExistAsync(collectionName)
118+
.handle(
119+
(exists, sink) -> {
120+
if (!exists) {
121+
sink.error(
122+
new MemoryException(
123+
ErrorCodes
124+
.ATTEMPTED_TO_ACCESS_NONEXISTENT_COLLECTION,
125+
collectionName));
126+
return;
127+
}
128+
sink.next(exists);
129+
})
130+
.then(
131+
Flux.fromIterable(records)
132+
.concatMap(record -> internalUpsertAsync(collectionName, record))
133+
.collect(Collectors.toCollection(ArrayList::new)));
95134
}
96135

97136
@Override
98137
public Mono<MemoryRecord> getAsync(
99138
@Nonnull String collectionName, @Nonnull String key, boolean withEmbedding) {
139+
Objects.requireNonNull(collectionName);
140+
Objects.requireNonNull(key);
100141
return this.internalGetAsync(collectionName, key, withEmbedding);
101142
}
102143

103144
private Mono<MemoryRecord> internalGetAsync(
104145
@Nonnull String collectionName, @Nonnull String key, boolean withEmbedding) {
146+
Objects.requireNonNull(collectionName);
147+
Objects.requireNonNull(key);
105148
Mono<Database.DatabaseEntry> entry =
106149
this.dbConnector.readAsync(this.dbConnection, collectionName, key);
107150

@@ -124,17 +167,15 @@ private Mono<MemoryRecord> internalGetAsync(
124167
.getEmbedding(),
125168
Embedding.class),
126169
databaseEntry.getKey(),
127-
ZonedDateTime.parse(
128-
databaseEntry.getTimestamp()));
170+
databaseEntry.getTimestamp());
129171
}
130172
return MemoryRecord.fromJsonMetadata(
131173
databaseEntry.getMetadata(),
132174
Embedding.empty(),
133175
databaseEntry.getKey(),
134-
ZonedDateTime.parse(
135-
databaseEntry.getTimestamp()));
176+
databaseEntry.getTimestamp());
136177
} catch (JsonProcessingException e) {
137-
throw new RuntimeException(e);
178+
throw new SQLConnectorException("Error deserializing database entry", e);
138179
}
139180
});
140181
});
@@ -145,19 +186,25 @@ public Mono<Collection<MemoryRecord>> getBatchAsync(
145186
@Nonnull String collectionName,
146187
@Nonnull Collection<String> keys,
147188
boolean withEmbeddings) {
189+
Objects.requireNonNull(collectionName);
190+
Objects.requireNonNull(keys);
148191
return Flux.fromIterable(keys)
149192
.flatMap(key -> internalGetAsync(collectionName, key, withEmbeddings))
150193
.collect(Collectors.toCollection(ArrayList::new));
151194
}
152195

153196
@Override
154197
public Mono<Void> removeAsync(@Nonnull String collectionName, @Nonnull String key) {
198+
Objects.requireNonNull(collectionName);
199+
Objects.requireNonNull(key);
155200
return this.dbConnector.deleteAsync(this.dbConnection, collectionName, key);
156201
}
157202

158203
@Override
159204
public Mono<Void> removeBatchAsync(
160205
@Nonnull String collectionName, @Nonnull Collection<String> keys) {
206+
Objects.requireNonNull(collectionName);
207+
Objects.requireNonNull(keys);
161208
return Flux.fromIterable(keys)
162209
.flatMap(
163210
key -> this.dbConnector.deleteAsync(this.dbConnection, collectionName, key))
@@ -171,7 +218,46 @@ public Mono<Collection<Tuple2<MemoryRecord, Float>>> getNearestMatchesAsync(
171218
int limit,
172219
double minRelevanceScore,
173220
boolean withEmbeddings) {
174-
return null;
221+
Objects.requireNonNull(collectionName);
222+
Objects.requireNonNull(embedding);
223+
Mono<List<Database.DatabaseEntry>> entries =
224+
this.dbConnector.readAllAsync(this.dbConnection, collectionName);
225+
226+
return entries.flatMap(
227+
databaseEntries -> {
228+
List<Tuple2<MemoryRecord, Float>> nearestMatches = new ArrayList<>();
229+
for (Database.DatabaseEntry entry : databaseEntries) {
230+
if (entry.getEmbedding() == null || entry.getEmbedding().isEmpty()) {
231+
continue;
232+
}
233+
try {
234+
Embedding recordEmbedding =
235+
new ObjectMapper()
236+
.readValue(entry.getEmbedding(), Embedding.class);
237+
float similarity = embedding.cosineSimilarity(recordEmbedding);
238+
if (similarity >= (float) minRelevanceScore) {
239+
MemoryRecord record =
240+
MemoryRecord.fromJsonMetadata(
241+
entry.getMetadata(),
242+
withEmbeddings ? recordEmbedding : null,
243+
entry.getKey(),
244+
entry.getTimestamp());
245+
nearestMatches.add(Tuples.of(record, similarity));
246+
}
247+
} catch (JsonProcessingException e) {
248+
throw new SQLConnectorException("Error deserializing database entry", e);
249+
}
250+
}
251+
List<Tuple2<MemoryRecord, Float>> results =
252+
nearestMatches.stream()
253+
.sorted(
254+
Comparator.comparing(
255+
Tuple2::getT2, (a, b) -> Float.compare(b, a)))
256+
.limit(limit)
257+
.collect(Collectors.toList());
258+
259+
return Mono.just(results);
260+
});
175261
}
176262

177263
@Override
@@ -180,13 +266,23 @@ public Mono<Tuple2<MemoryRecord, Float>> getNearestMatchAsync(
180266
@Nonnull Embedding embedding,
181267
double minRelevanceScore,
182268
boolean withEmbedding) {
183-
return null;
269+
Objects.requireNonNull(collectionName);
270+
Objects.requireNonNull(embedding);
271+
return getNearestMatchesAsync(
272+
collectionName, embedding, 1, minRelevanceScore, withEmbedding)
273+
.flatMap(
274+
nearestMatches -> {
275+
if (nearestMatches.isEmpty()) {
276+
return Mono.empty();
277+
}
278+
return Mono.just(nearestMatches.iterator().next());
279+
});
184280
}
185281

186282
public static class Builder implements MemoryStore.Builder {
187283
@Override
188284
public MemoryStore build() {
189-
return new SqliteMemoryStore();
285+
return new SQLiteMemoryStore();
190286
}
191287
}
192288
}

0 commit comments

Comments
 (0)