Skip to content

Commit 31c4ffc

Browse files
authored
Merge pull request #6676 from Mats-SX/eval-scores
Rename ModelStats to EvaluationScores
2 parents 4212d1f + 1b13760 commit 31c4ffc

File tree

10 files changed

+45
-45
lines changed

10 files changed

+45
-45
lines changed

ml/ml-algo/src/main/java/org/neo4j/gds/ml/metrics/ModelStats.java renamed to ml/ml-algo/src/main/java/org/neo4j/gds/ml/metrics/EvaluationScores.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
* Statistics of the metric of the model candidate over (inner) folds
2929
*/
3030
@ValueClass
31-
public interface ModelStats {
31+
public interface EvaluationScores {
3232

3333
double avg();
3434

@@ -45,7 +45,7 @@ default Map<String, Object> toMap() {
4545
);
4646
}
4747

48-
static ModelStats of(double avg, double min, double max) {
49-
return ImmutableModelStats.of(avg, min, max);
48+
static EvaluationScores of(double avg, double min, double max) {
49+
return ImmutableEvaluationScores.of(avg, min, max);
5050
}
5151
}

ml/ml-algo/src/main/java/org/neo4j/gds/ml/metrics/ModelCandidateStats.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@
3434
@ValueClass
3535
public interface ModelCandidateStats extends ToMapConvertible {
3636
TrainerConfig trainerConfig();
37-
Map<Metric, ModelStats> trainingStats();
38-
Map<Metric, ModelStats> validationStats();
37+
Map<Metric, EvaluationScores> trainingStats();
38+
Map<Metric, EvaluationScores> validationStats();
3939

4040
@Override
4141
@Value.Auxiliary
@@ -98,8 +98,8 @@ private List<Metric> metrics() {
9898

9999
static ModelCandidateStats of(
100100
TrainerConfig trainerConfig,
101-
Map<Metric, ModelStats> trainStats,
102-
Map<Metric, ModelStats> validationStats
101+
Map<Metric, EvaluationScores> trainStats,
102+
Map<Metric, EvaluationScores> validationStats
103103
) {
104104
return ImmutableModelCandidateStats.of(trainerConfig, trainStats, validationStats);
105105
}

ml/ml-algo/src/main/java/org/neo4j/gds/ml/metrics/ModelStatsBuilder.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,15 +47,15 @@ public void update(Metric metric, double value) {
4747
sum.merge(metric, value, Double::sum);
4848
}
4949

50-
public ModelStats build(Metric metric) {
51-
return ModelStats.of(
50+
public EvaluationScores build(Metric metric) {
51+
return EvaluationScores.of(
5252
sum.get(metric) / numberOfSplits,
5353
min.get(metric),
5454
max.get(metric)
5555
);
5656
}
5757

58-
public Map<Metric, ModelStats> build() {
58+
public Map<Metric, EvaluationScores> build() {
5959
return sum.keySet().stream()
6060
.collect(
6161
Collectors.toMap(Function.identity(), this::build)

ml/ml-algo/src/main/java/org/neo4j/gds/ml/training/TrainingStatistics.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@
2222
import org.jetbrains.annotations.TestOnly;
2323
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
2424
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
25-
import org.neo4j.gds.ml.metrics.ImmutableModelStats;
25+
import org.neo4j.gds.ml.metrics.EvaluationScores;
26+
import org.neo4j.gds.ml.metrics.ImmutableEvaluationScores;
2627
import org.neo4j.gds.ml.metrics.Metric;
2728
import org.neo4j.gds.ml.metrics.ModelCandidateStats;
28-
import org.neo4j.gds.ml.metrics.ModelStats;
2929
import org.neo4j.gds.ml.models.TrainerConfig;
3030

3131
import java.util.ArrayList;
@@ -51,12 +51,12 @@ public TrainingStatistics(List<? extends Metric> metrics) {
5151
}
5252

5353
@TestOnly
54-
public List<ModelStats> getTrainStats(Metric metric) {
54+
public List<EvaluationScores> getTrainStats(Metric metric) {
5555
return modelCandidateStats.stream().map(stats -> stats.trainingStats().get(metric)).collect(Collectors.toList());
5656
}
5757

5858
@TestOnly
59-
public List<ModelStats> getValidationStats(Metric metric) {
59+
public List<EvaluationScores> getValidationStats(Metric metric) {
6060
return modelCandidateStats.stream().map(stats -> stats.validationStats().get(metric)).collect(Collectors.toList());
6161
}
6262

@@ -85,7 +85,7 @@ public Map<Metric, Double> trainMetricsAvg(int trial) {
8585
return extractAverage(modelCandidateStats.get(trial).trainingStats());
8686
}
8787

88-
private Map<Metric, Double> extractAverage(Map<Metric, ModelStats> statsMap) {
88+
private Map<Metric, Double> extractAverage(Map<Metric, EvaluationScores> statsMap) {
8989
return statsMap.entrySet().stream()
9090
.collect(Collectors.toMap(
9191
Map.Entry::getKey,
@@ -148,7 +148,7 @@ public static MemoryEstimation memoryEstimationStatsMap(int numberOfMetricsSpeci
148148
public static MemoryEstimation memoryEstimationStatsMap(int numberOfMetricsSpecifications, int numberOfModelCandidates, int numberOfClasses) {
149149
var numberOfMetrics = numberOfMetricsSpecifications * numberOfClasses;
150150
var numberOfModelStats = numberOfMetrics * numberOfModelCandidates;
151-
var sizeOfOneModelStatsInBytes = sizeOfInstance(ImmutableModelStats.class);
151+
var sizeOfOneModelStatsInBytes = sizeOfInstance(ImmutableEvaluationScores.class);
152152
var sizeOfAllModelStatsInBytes = sizeOfOneModelStatsInBytes * numberOfModelStats;
153153
return MemoryEstimations.builder("StatsMap")
154154
.fixed("array list", sizeOfInstance(ArrayList.class))

ml/ml-algo/src/test/java/org/neo4j/gds/ml/metrics/ModelStatsBuilderTest.java renamed to ml/ml-algo/src/test/java/org/neo4j/gds/ml/metrics/EvaluationScoresBuilderTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
import static org.assertj.core.api.Assertions.assertThat;
2626

27-
class ModelStatsBuilderTest {
27+
class EvaluationScoresBuilderTest {
2828

2929
@ParameterizedTest
3030
@CsvSource(value = {

ml/ml-algo/src/test/java/org/neo4j/gds/ml/metrics/ModelCandidateStatsTest.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,18 +34,18 @@ class ModelCandidateStatsTest {
3434
void testRender() {
3535
var candidateStats = ModelCandidateStats.of(
3636
LogisticRegressionTrainConfig.DEFAULT,
37-
Map.of(OUT_OF_BAG_ERROR, ModelStats.of(
37+
Map.of(OUT_OF_BAG_ERROR, EvaluationScores.of(
3838
0.33,
3939
0.13,
4040
0.13
4141
)),
4242
Map.of(
43-
ROOT_MEAN_SQUARED_ERROR, ModelStats.of(
43+
ROOT_MEAN_SQUARED_ERROR, EvaluationScores.of(
4444
0.3,
4545
0.1,
4646
0.1
4747
),
48-
AUCPR, ModelStats.of(
48+
AUCPR, EvaluationScores.of(
4949
0.4,
5050
0.2,
5151
0.2

ml/ml-algo/src/test/java/org/neo4j/gds/ml/training/TrainingStatisticsTest.java

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@
2727
import org.neo4j.gds.collections.LongMultiSet;
2828
import org.neo4j.gds.core.GraphDimensions;
2929
import org.neo4j.gds.ml.core.subgraph.LocalIdMap;
30+
import org.neo4j.gds.ml.metrics.EvaluationScores;
3031
import org.neo4j.gds.ml.metrics.Metric;
3132
import org.neo4j.gds.ml.metrics.ModelCandidateStats;
32-
import org.neo4j.gds.ml.metrics.ModelStats;
3333
import org.neo4j.gds.ml.metrics.classification.F1Weighted;
3434
import org.neo4j.gds.ml.metrics.classification.GlobalAccuracy;
3535
import org.neo4j.gds.ml.models.logisticregression.LogisticRegressionTrainConfig;
@@ -67,12 +67,12 @@ void selectsBestParametersAccordingToMainMetric(List<Metric> metrics, String exp
6767
new TestTrainerConfig("lower average rmse"),
6868
Map.of(),
6969
Map.of(
70-
ROOT_MEAN_SQUARED_ERROR, ModelStats.of(
70+
ROOT_MEAN_SQUARED_ERROR, EvaluationScores.of(
7171
0.2,
7272
0.2,
7373
0.2
7474
),
75-
AUCPR, ModelStats.of(
75+
AUCPR, EvaluationScores.of(
7676
0.0,
7777
1000,
7878
1000
@@ -83,12 +83,12 @@ void selectsBestParametersAccordingToMainMetric(List<Metric> metrics, String exp
8383
new TestTrainerConfig("higher average aucpr"),
8484
Map.of(),
8585
Map.of(
86-
ROOT_MEAN_SQUARED_ERROR, ModelStats.of(
86+
ROOT_MEAN_SQUARED_ERROR, EvaluationScores.of(
8787
0.3,
8888
0.1,
8989
0.1
9090
),
91-
AUCPR, ModelStats.of(
91+
AUCPR, EvaluationScores.of(
9292
0.4,
9393
0.2,
9494
0.2
@@ -108,7 +108,7 @@ void getBestTrialStuff() {
108108
Map.of(),
109109
Map.of(
110110
AUCPR,
111-
ModelStats.of(
111+
EvaluationScores.of(
112112
0.1,
113113
1000,
114114
1000
@@ -120,7 +120,7 @@ void getBestTrialStuff() {
120120
Map.of(),
121121
Map.of(
122122
AUCPR,
123-
ModelStats.of(
123+
EvaluationScores.of(
124124
0.2,
125125
0.2,
126126
0.2
@@ -132,7 +132,7 @@ void getBestTrialStuff() {
132132
Map.of(),
133133
Map.of(
134134
AUCPR,
135-
ModelStats.of(
135+
EvaluationScores.of(
136136
0.2,
137137
0.2,
138138
0.2
@@ -143,12 +143,12 @@ void getBestTrialStuff() {
143143
new TestTrainerConfig("notprimarymetric"),
144144
Map.of(),
145145
Map.of(
146-
AUCPR, ModelStats.of(
146+
AUCPR, EvaluationScores.of(
147147
0.0,
148148
0.0,
149149
0.0
150150
),
151-
F1_WEIGHTED, ModelStats.of(
151+
F1_WEIGHTED, EvaluationScores.of(
152152
5000,
153153
5000,
154154
5000
@@ -165,17 +165,17 @@ void rendersBestModel() {
165165
var trainingStatistics = new TrainingStatistics(List.of(AUCPR, F1_WEIGHTED, OUT_OF_BAG_ERROR));
166166

167167
var candidate = new TestTrainerConfig("train");
168-
ModelStats trainStats = ModelStats.of(
168+
EvaluationScores trainStats = EvaluationScores.of(
169169
0.1,
170170
0.1,
171171
0.1
172172
);
173-
ModelStats validationStats = ModelStats.of(
173+
EvaluationScores validationStats = EvaluationScores.of(
174174
0.4,
175175
0.3,
176176
0.5
177177
);
178-
ModelStats oobStats = ModelStats.of(
178+
EvaluationScores oobStats = EvaluationScores.of(
179179
0.5,
180180
0.4,
181181
0.9
@@ -241,13 +241,13 @@ void toMap() {
241241

242242
selectResult.addCandidateStats(ModelCandidateStats.of(
243243
firstCandidate,
244-
Map.of(ACCURACY, ModelStats.of(0.33, 0.1, 0.6)),
245-
Map.of(ACCURACY, ModelStats.of(0.4, 0.3, 0.5))
244+
Map.of(ACCURACY, EvaluationScores.of(0.33, 0.1, 0.6)),
245+
Map.of(ACCURACY, EvaluationScores.of(0.4, 0.3, 0.5))
246246
));
247247
selectResult.addCandidateStats(ModelCandidateStats.of(
248248
secondCandidate,
249-
Map.of(ACCURACY, ModelStats.of(0.2, 0.01, 0.7)),
250-
Map.of(ACCURACY, ModelStats.of(0.8, 0.7, 0.9))
249+
Map.of(ACCURACY, EvaluationScores.of(0.2, 0.01, 0.7)),
250+
Map.of(ACCURACY, EvaluationScores.of(0.8, 0.7, 0.9))
251251
));
252252

253253
var expectedTrainAccuracyStats1 = Map.of("avg", 0.33, "min", 0.1, "max", 0.6);

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkPredictionTrain.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
import org.neo4j.gds.ml.core.ReadOnlyHugeLongIdentityArray;
3535
import org.neo4j.gds.ml.core.batch.BatchQueue;
3636
import org.neo4j.gds.ml.core.subgraph.LocalIdMap;
37-
import org.neo4j.gds.ml.metrics.ImmutableModelStats;
37+
import org.neo4j.gds.ml.metrics.ImmutableEvaluationScores;
3838
import org.neo4j.gds.ml.metrics.MetricConsumer;
3939
import org.neo4j.gds.ml.metrics.ModelSpecificMetricsHandler;
4040
import org.neo4j.gds.ml.metrics.ModelStatsBuilder;
@@ -325,7 +325,7 @@ public static MemoryEstimation estimate(
325325
.add("Outer train stats map", TrainingStatistics.memoryEstimationStatsMap(numberOfMetrics, 1, 1))
326326
.add("Test stats map", TrainingStatistics.memoryEstimationStatsMap(numberOfMetrics, 1, 1))
327327
.fixed("Best model stats", MemoryRange
328-
.of(MemoryUsage.sizeOfInstance(ImmutableModelStats.class))
328+
.of(MemoryUsage.sizeOfInstance(ImmutableEvaluationScores.class))
329329
.times(2)
330330
.add(Double.BYTES * 2)
331331
.times(numberOfMetrics))

pipeline/src/test/java/org/neo4j/gds/ml/pipeline/nodePipeline/classification/train/NodeClassificationToModelConverterTest.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
import org.neo4j.gds.collections.LongMultiSet;
3333
import org.neo4j.gds.ml.core.subgraph.LocalIdMap;
3434
import org.neo4j.gds.ml.metrics.ModelCandidateStats;
35-
import org.neo4j.gds.ml.metrics.ModelStats;
35+
import org.neo4j.gds.ml.metrics.EvaluationScores;
3636
import org.neo4j.gds.ml.metrics.classification.ClassificationMetricSpecification;
3737
import org.neo4j.gds.ml.models.logisticregression.LogisticRegressionClassifier;
3838
import org.neo4j.gds.ml.models.logisticregression.LogisticRegressionData;
@@ -68,8 +68,8 @@ void convertsModel() {
6868
trainingStatistics.addTestScore(metric, 0.799999);
6969
trainingStatistics.addOuterTrainScore(metric, 0.666666);
7070
trainingStatistics.addCandidateStats(ModelCandidateStats.of(modelCandidate,
71-
Map.of(metric, ModelStats.of(0.89999, 0.79999, 0.99999)),
72-
Map.of(metric, ModelStats.of(0.649999, 0.499999, 0.7999999))
71+
Map.of(metric, EvaluationScores.of(0.89999, 0.79999, 0.99999)),
72+
Map.of(metric, EvaluationScores.of(0.649999, 0.499999, 0.7999999))
7373
));
7474
var ncResult = ImmutableNodeClassificationTrainResult.of(classifier, trainingStatistics, classIdMap, classCounts);
7575
var pipeline = new NodeClassificationTrainingPipeline();

pipeline/src/test/java/org/neo4j/gds/ml/pipeline/nodePipeline/regression/NodeRegressionToModelConverterTest.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121

2222
import org.junit.jupiter.api.Test;
2323
import org.neo4j.gds.gdl.GdlFactory;
24+
import org.neo4j.gds.ml.metrics.EvaluationScores;
2425
import org.neo4j.gds.ml.metrics.ModelCandidateStats;
25-
import org.neo4j.gds.ml.metrics.ModelStats;
2626
import org.neo4j.gds.ml.metrics.regression.RegressionMetrics;
2727
import org.neo4j.gds.ml.models.Regressor;
2828
import org.neo4j.gds.ml.models.TrainingMethod;
@@ -104,8 +104,8 @@ public int featureDimension() {
104104
trainStats.addTestScore(metric, 0.799999);
105105
trainStats.addOuterTrainScore(metric, 0.666666);
106106
trainStats.addCandidateStats(ModelCandidateStats.of(modelCandidate,
107-
Map.of(metric, ModelStats.of(0.89999, 0.79999, 0.99999)),
108-
Map.of(metric, ModelStats.of(0.649999, 0.499999, 0.7999999))
107+
Map.of(metric, EvaluationScores.of(0.89999, 0.79999, 0.99999)),
108+
Map.of(metric, EvaluationScores.of(0.649999, 0.499999, 0.7999999))
109109
));
110110

111111
var trainResult = ImmutableNodeRegressionTrainResult.of(regressor, trainStats);

0 commit comments

Comments
 (0)