Skip to content

Commit 9f30e9a

Browse files
committed
Add count method to ModelCatalog
1 parent 9e9ef23 commit 9f30e9a

File tree

5 files changed

+56
-0
lines changed

5 files changed

+56
-0
lines changed

model-catalog-api/src/main/java/org/neo4j/gds/core/model/ModelCatalog.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ <D, C extends ModelConfig, I extends ToMapConvertible> Model<D, C, I> get(
4444

4545
Stream<Model<?, ?, ?>> getAllModels();
4646

47+
long modelsCount();
48+
4749
boolean exists(String username, String modelName);
4850

4951
Model<?, ?, ?> dropOrThrow(String username, String modelName);

model-catalog-api/src/main/java/org/neo4j/gds/core/model/UserCatalog.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,4 +56,6 @@ <D, C extends ModelConfig, I extends ToMapConvertible> Model<D, C, I> get(
5656
void removeAllLoadedModels();
5757

5858
void verifyModelCanBeStored(String modelName, String modelType);
59+
60+
long size();
5961
}

open-model-catalog/src/main/java/org/neo4j/gds/core/model/OpenModelCatalog.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,11 @@ public <D, C extends ModelConfig, I extends ToMapConvertible> Model<D, C, I> get
9292
.flatMap(entry -> entry.getValue().streamModels());
9393
}
9494

95+
@Override
96+
public long modelsCount() {
97+
return userCatalogs.values().stream().mapToLong(OpenUserCatalog::size).sum();
98+
}
99+
95100
@Override
96101
public boolean exists(String username, String modelName) {
97102
return getUserCatalog(username).exists(modelName);

open-model-catalog/src/main/java/org/neo4j/gds/core/model/OpenUserCatalog.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,11 @@ public void verifyModelCanBeStored(String modelName, String modelType) {
120120
verifyModelsLimit(modelType);
121121
}
122122

123+
@Override
124+
public long size() {
125+
return userModels.size();
126+
}
127+
123128
private <D, C extends ModelConfig, I extends ToMapConvertible> Model<D, C, I> get(
124129
Model<?, ?, ?> model,
125130
Class<D> dataClass,

open-model-catalog/src/test/java/org/neo4j/gds/core/model/OpenModelCatalogTest.java

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import java.util.Map;
3434
import java.util.NoSuchElementException;
3535

36+
import static org.assertj.core.api.Assertions.assertThat;
3637
import static org.assertj.core.api.Assertions.assertThatThrownBy;
3738
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
3839
import static org.junit.jupiter.api.Assertions.assertEquals;
@@ -156,6 +157,47 @@ void shouldStoreModels() {
156157
);
157158
}
158159

160+
@Test
161+
void shouldCountModels() {
162+
var model1 = Model.of(
163+
"testAlgo1",
164+
GRAPH_SCHEMA,
165+
"modelData1",
166+
TestTrainConfig.of(USERNAME, "testModel1"),
167+
Map::of
168+
);
169+
170+
var model2 = Model.of(
171+
"testAlgo2",
172+
GRAPH_SCHEMA,
173+
1337L,
174+
TestTrainConfig.of(USERNAME, "testModel2"),
175+
Map::of
176+
);
177+
178+
var publicModel = Model.of(
179+
"testAlgo2",
180+
GRAPH_SCHEMA,
181+
1337L,
182+
TestTrainConfig.of("anotherUser", "testModel2"),
183+
Map::of
184+
);
185+
186+
assertThat(modelCatalog.modelsCount()).isEqualTo(0);
187+
188+
modelCatalog.set(model1);
189+
190+
assertThat(modelCatalog.modelsCount()).isEqualTo(1);
191+
192+
modelCatalog.set(model2);
193+
194+
assertThat(modelCatalog.modelsCount()).isEqualTo(2);
195+
196+
modelCatalog.set(publicModel);
197+
198+
assertThat(modelCatalog.modelsCount()).isEqualTo(3);
199+
}
200+
159201
@Test
160202
void shouldThrowWhenTryingToGetOtherUsersModel() {
161203
modelCatalog.set(TEST_MODEL);

0 commit comments

Comments
 (0)