Skip to content

Commit 77555f7

Browse files
committed
Allow listening to model catalog insertions
This will be used by the Shutdown to track models that were not backed up (instead of dropping the model). Also metrics can be implemented based on this
1 parent 6ec01a2 commit 77555f7

File tree

4 files changed

+63
-0
lines changed

4 files changed

+63
-0
lines changed

model-catalog-api/src/main/java/org/neo4j/gds/core/model/ModelCatalog.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828

2929
public interface ModelCatalog {
3030

31+
void registerListener(ModelCatalogListener listener);
32+
3133
void set(Model<?, ?, ?> model);
3234

3335
<D, C extends ModelConfig, I extends ToMapConvertible> Model<D, C, I> get(
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.core.model;
21+
22+
public interface ModelCatalogListener {
23+
24+
void onInsert(Model<?, ?, ?> model);
25+
}

open-model-catalog/src/main/java/org/neo4j/gds/core/model/OpenModelCatalog.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
import java.util.ArrayList;
2828
import java.util.Collection;
29+
import java.util.List;
2930
import java.util.Locale;
3031
import java.util.Map;
3132
import java.util.NoSuchElementException;
@@ -38,8 +39,16 @@ public final class OpenModelCatalog implements ModelCatalog {
3839

3940
private final Map<String, OpenUserCatalog> userCatalogs;
4041

42+
private final List<ModelCatalogListener> listeners;
43+
4144
public OpenModelCatalog() {
4245
this.userCatalogs = new ConcurrentHashMap<>();
46+
this.listeners = new ArrayList<>();
47+
}
48+
49+
@Override
50+
public void registerListener(ModelCatalogListener listener) {
51+
listeners.add(listener);
4352
}
4453

4554
@Override
@@ -51,6 +60,8 @@ public void set(Model<?, ?, ?> model) {
5160
userCatalog.set(model);
5261
return userCatalog;
5362
});
63+
64+
listeners.forEach(listener -> listener.onInsert(model));
5465
}
5566

5667
@Override

open-model-catalog/src/test/java/org/neo4j/gds/core/model/OpenModelCatalogTest.java

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
*/
2020
package org.neo4j.gds.core.model;
2121

22+
import org.eclipse.collections.impl.Counter;
2223
import org.junit.jupiter.api.Test;
2324
import org.neo4j.gds.annotation.Configuration;
2425
import org.neo4j.gds.annotation.ValueClass;
@@ -118,6 +119,30 @@ void shouldStoreModelsPerType() {
118119
);
119120
}
120121

122+
@Test
123+
void shouldNotifyListeners() {
124+
var counter = new Counter();
125+
modelCatalog.registerListener(model -> counter.increment());
126+
127+
var model = Model.of(
128+
"testAlgo",
129+
GRAPH_SCHEMA,
130+
"testTrainData",
131+
TestTrainConfig.of(USERNAME, "testModel"),
132+
new TestCustomInfo()
133+
);
134+
135+
assertThat(counter.getCount()).isEqualTo(0);
136+
137+
modelCatalog.set(model);
138+
139+
assertThat(counter.getCount()).isEqualTo(1);
140+
141+
assertThatThrownBy(() -> modelCatalog.set(model));
142+
// not be called if the set is not successful
143+
assertThat(counter.getCount()).isEqualTo(1);
144+
}
145+
121146
@Test
122147
void shouldThrowWhenPublishing() {
123148
modelCatalog.set(TEST_MODEL);

0 commit comments

Comments
 (0)