33
44import com .fasterxml .jackson .core .JsonProcessingException ;
55import com .fasterxml .jackson .databind .ObjectMapper ;
6+ import com .microsoft .semantickernel .SKException ;
67import com .microsoft .semantickernel .ai .embeddings .Embedding ;
8+ import com .microsoft .semantickernel .memory .MemoryException ;
9+ import com .microsoft .semantickernel .memory .MemoryException .ErrorCodes ;
710import com .microsoft .semantickernel .memory .MemoryRecord ;
811import com .microsoft .semantickernel .memory .MemoryStore ;
9-
10- import reactor .core .publisher .Flux ;
11- import reactor .core .publisher .Mono ;
12- import reactor .util .function .Tuple2 ;
13-
1412import java .sql .*;
15- import java .time .ZonedDateTime ;
1613import java .util .ArrayList ;
1714import java .util .Collection ;
15+ import java .util .Comparator ;
1816import java .util .List ;
17+ import java .util .Objects ;
1918import java .util .stream .Collectors ;
20-
2119import javax .annotation .Nonnull ;
20+ import reactor .core .publisher .Flux ;
21+ import reactor .core .publisher .Mono ;
22+ import reactor .util .function .Tuple2 ;
23+ import reactor .util .function .Tuples ;
2224
23- public class SqliteMemoryStore implements MemoryStore {
25+ public class SQLiteMemoryStore implements MemoryStore {
2426
2527 private final Database dbConnector ;
2628 private Connection dbConnection ;
2729
28- public SqliteMemoryStore () {
30+ public SQLiteMemoryStore () {
2931 this .dbConnector = new Database ();
3032 }
3133
32- public Mono <Void > connectAsync (String filename ) throws SQLException {
34+ public Mono <Void > connectAsync (@ Nonnull String filename ) throws SQLException {
35+ Objects .requireNonNull (filename );
3336 this .dbConnection = DriverManager .getConnection ("jdbc:sqlite:" + filename );
3437 return this .dbConnector .createTableAsync (this .dbConnection );
3538 }
3639
3740 @ Override
3841 public Mono <Void > createCollectionAsync (@ Nonnull String collectionName ) {
42+ Objects .requireNonNull (collectionName );
3943 return this .dbConnector .createCollectionAsync (this .dbConnection , collectionName );
4044 }
4145
@@ -46,62 +50,101 @@ public Mono<List<String>> getCollectionsAsync() {
4650
4751 @ Override
4852 public Mono <Boolean > doesCollectionExistAsync (@ Nonnull String collectionName ) {
53+ Objects .requireNonNull (collectionName );
4954 return this .dbConnector .doesCollectionExistsAsync (this .dbConnection , collectionName );
5055 }
5156
5257 @ Override
5358 public Mono <Void > deleteCollectionAsync (@ Nonnull String collectionName ) {
59+ Objects .requireNonNull (collectionName );
5460 return this .dbConnector .deleteCollectionAsync (this .dbConnection , collectionName );
5561 }
5662
5763 @ Override
5864 public Mono <String > upsertAsync (@ Nonnull String collectionName , @ Nonnull MemoryRecord record ) {
59- return this .internalUpsertAsync (collectionName , record );
65+ Objects .requireNonNull (collectionName );
66+ Objects .requireNonNull (record );
67+ return doesCollectionExistAsync (collectionName )
68+ .handle (
69+ (exists , sink ) -> {
70+ if (!exists ) {
71+ sink .error (
72+ new MemoryException (
73+ ErrorCodes
74+ .ATTEMPTED_TO_ACCESS_NONEXISTENT_COLLECTION ,
75+ collectionName ));
76+ return ;
77+ }
78+ sink .next (exists );
79+ })
80+ .then (internalUpsertAsync (collectionName , record ));
6081 }
6182
6283 private Mono <String > internalUpsertAsync (
6384 @ Nonnull String collectionName , @ Nonnull MemoryRecord record ) {
85+ Objects .requireNonNull (collectionName );
86+ Objects .requireNonNull (record );
6487 try {
6588 Mono <Void > update =
6689 this .dbConnector .updateAsync (
6790 this .dbConnection ,
6891 collectionName ,
69- record .getKey (),
92+ record .getMetadata (). getId (),
7093 record .getSerializedMetadata (),
7194 record .getSerializedEmbedding (),
72- record .getTimestamp (). toString () );
95+ record .getTimestamp ());
7396
7497 Mono <Void > insert =
7598 this .dbConnector .insertOrIgnoreAsync (
7699 this .dbConnection ,
77100 collectionName ,
78- record .getKey (),
101+ record .getMetadata (). getId (),
79102 record .getSerializedMetadata (),
80103 record .getSerializedEmbedding (),
81- record .getTimestamp (). toString () );
104+ record .getTimestamp ());
82105
83- return update .then (insert ).then (Mono .just (record .getKey ()));
106+ return update .then (insert ).then (Mono .just (record .getMetadata (). getId ()));
84107 } catch (JsonProcessingException e ) {
85- throw new RuntimeException ( e );
108+ throw new SQLConnectorException ( "Error serializing MemoryRecord" , e );
86109 }
87110 }
88111
89112 @ Override
90113 public Mono <Collection <String >> upsertBatchAsync (
91114 @ Nonnull String collectionName , @ Nonnull Collection <MemoryRecord > records ) {
92- return Flux .fromIterable (records )
93- .flatMap (record -> internalUpsertAsync (collectionName , record ))
94- .collect (Collectors .toCollection (ArrayList ::new ));
115+ Objects .requireNonNull (collectionName );
116+ Objects .requireNonNull (records );
117+ return doesCollectionExistAsync (collectionName )
118+ .handle (
119+ (exists , sink ) -> {
120+ if (!exists ) {
121+ sink .error (
122+ new MemoryException (
123+ ErrorCodes
124+ .ATTEMPTED_TO_ACCESS_NONEXISTENT_COLLECTION ,
125+ collectionName ));
126+ return ;
127+ }
128+ sink .next (exists );
129+ })
130+ .then (
131+ Flux .fromIterable (records )
132+ .concatMap (record -> internalUpsertAsync (collectionName , record ))
133+ .collect (Collectors .toCollection (ArrayList ::new )));
95134 }
96135
97136 @ Override
98137 public Mono <MemoryRecord > getAsync (
99138 @ Nonnull String collectionName , @ Nonnull String key , boolean withEmbedding ) {
139+ Objects .requireNonNull (collectionName );
140+ Objects .requireNonNull (key );
100141 return this .internalGetAsync (collectionName , key , withEmbedding );
101142 }
102143
103144 private Mono <MemoryRecord > internalGetAsync (
104145 @ Nonnull String collectionName , @ Nonnull String key , boolean withEmbedding ) {
146+ Objects .requireNonNull (collectionName );
147+ Objects .requireNonNull (key );
105148 Mono <Database .DatabaseEntry > entry =
106149 this .dbConnector .readAsync (this .dbConnection , collectionName , key );
107150
@@ -124,17 +167,15 @@ private Mono<MemoryRecord> internalGetAsync(
124167 .getEmbedding (),
125168 Embedding .class ),
126169 databaseEntry .getKey (),
127- ZonedDateTime .parse (
128- databaseEntry .getTimestamp ()));
170+ databaseEntry .getTimestamp ());
129171 }
130172 return MemoryRecord .fromJsonMetadata (
131173 databaseEntry .getMetadata (),
132174 Embedding .empty (),
133175 databaseEntry .getKey (),
134- ZonedDateTime .parse (
135- databaseEntry .getTimestamp ()));
176+ databaseEntry .getTimestamp ());
136177 } catch (JsonProcessingException e ) {
137- throw new RuntimeException ( e );
178+ throw new SQLConnectorException ( "Error deserializing database entry" , e );
138179 }
139180 });
140181 });
@@ -145,19 +186,25 @@ public Mono<Collection<MemoryRecord>> getBatchAsync(
145186 @ Nonnull String collectionName ,
146187 @ Nonnull Collection <String > keys ,
147188 boolean withEmbeddings ) {
189+ Objects .requireNonNull (collectionName );
190+ Objects .requireNonNull (keys );
148191 return Flux .fromIterable (keys )
149192 .flatMap (key -> internalGetAsync (collectionName , key , withEmbeddings ))
150193 .collect (Collectors .toCollection (ArrayList ::new ));
151194 }
152195
153196 @ Override
154197 public Mono <Void > removeAsync (@ Nonnull String collectionName , @ Nonnull String key ) {
198+ Objects .requireNonNull (collectionName );
199+ Objects .requireNonNull (key );
155200 return this .dbConnector .deleteAsync (this .dbConnection , collectionName , key );
156201 }
157202
158203 @ Override
159204 public Mono <Void > removeBatchAsync (
160205 @ Nonnull String collectionName , @ Nonnull Collection <String > keys ) {
206+ Objects .requireNonNull (collectionName );
207+ Objects .requireNonNull (keys );
161208 return Flux .fromIterable (keys )
162209 .flatMap (
163210 key -> this .dbConnector .deleteAsync (this .dbConnection , collectionName , key ))
@@ -171,7 +218,46 @@ public Mono<Collection<Tuple2<MemoryRecord, Float>>> getNearestMatchesAsync(
171218 int limit ,
172219 double minRelevanceScore ,
173220 boolean withEmbeddings ) {
174- return null ;
221+ Objects .requireNonNull (collectionName );
222+ Objects .requireNonNull (embedding );
223+ Mono <List <Database .DatabaseEntry >> entries =
224+ this .dbConnector .readAllAsync (this .dbConnection , collectionName );
225+
226+ return entries .flatMap (
227+ databaseEntries -> {
228+ List <Tuple2 <MemoryRecord , Float >> nearestMatches = new ArrayList <>();
229+ for (Database .DatabaseEntry entry : databaseEntries ) {
230+ if (entry .getEmbedding () == null || entry .getEmbedding ().isEmpty ()) {
231+ continue ;
232+ }
233+ try {
234+ Embedding recordEmbedding =
235+ new ObjectMapper ()
236+ .readValue (entry .getEmbedding (), Embedding .class );
237+ float similarity = embedding .cosineSimilarity (recordEmbedding );
238+ if (similarity >= (float ) minRelevanceScore ) {
239+ MemoryRecord record =
240+ MemoryRecord .fromJsonMetadata (
241+ entry .getMetadata (),
242+ withEmbeddings ? recordEmbedding : null ,
243+ entry .getKey (),
244+ entry .getTimestamp ());
245+ nearestMatches .add (Tuples .of (record , similarity ));
246+ }
247+ } catch (JsonProcessingException e ) {
248+ throw new SQLConnectorException ("Error deserializing database entry" , e );
249+ }
250+ }
251+ List <Tuple2 <MemoryRecord , Float >> results =
252+ nearestMatches .stream ()
253+ .sorted (
254+ Comparator .comparing (
255+ Tuple2 ::getT2 , (a , b ) -> Float .compare (b , a )))
256+ .limit (limit )
257+ .collect (Collectors .toList ());
258+
259+ return Mono .just (results );
260+ });
175261 }
176262
177263 @ Override
@@ -180,13 +266,23 @@ public Mono<Tuple2<MemoryRecord, Float>> getNearestMatchAsync(
180266 @ Nonnull Embedding embedding ,
181267 double minRelevanceScore ,
182268 boolean withEmbedding ) {
183- return null ;
269+ Objects .requireNonNull (collectionName );
270+ Objects .requireNonNull (embedding );
271+ return getNearestMatchesAsync (
272+ collectionName , embedding , 1 , minRelevanceScore , withEmbedding )
273+ .flatMap (
274+ nearestMatches -> {
275+ if (nearestMatches .isEmpty ()) {
276+ return Mono .empty ();
277+ }
278+ return Mono .just (nearestMatches .iterator ().next ());
279+ });
184280 }
185281
186282 public static class Builder implements MemoryStore .Builder {
187283 @ Override
188284 public MemoryStore build () {
189- return new SqliteMemoryStore ();
285+ return new SQLiteMemoryStore ();
190286 }
191287 }
192288}
0 commit comments