Skip to content

Commit 14b76bc

Browse files
committed
Make splits invariant to class values
1 parent c8aa819 commit 14b76bc

File tree

7 files changed

+208
-11
lines changed

7 files changed

+208
-11
lines changed

ml/ml-algo/src/main/java/org/neo4j/gds/ml/metrics/classification/GlobalAccuracy.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import java.math.BigDecimal;
2525
import java.math.RoundingMode;
2626
import java.util.Comparator;
27+
import java.util.Objects;
2728

2829
import static org.neo4j.gds.utils.StringFormatting.formatWithLocale;
2930

@@ -69,7 +70,7 @@ public double compute(HugeIntArray targets, HugeIntArray predictions) {
6970

7071
@Override
7172
public int hashCode() {
72-
return super.hashCode();
73+
return Objects.hash(NAME);
7374
}
7475

7576
@Override

ml/ml-algo/src/main/java/org/neo4j/gds/ml/splitting/StratifiedKFoldSplitter.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ public class StratifiedKFoldSplitter {
5050
private final ReadOnlyHugeLongArray ids;
5151
private final LongToLongFunction targets;
5252
private final SplittableRandom random;
53-
private final SortedSet<Long> distinctTargets;
53+
private final SortedSet<Long> distinctInternalTargets;
5454

5555
public static MemoryEstimation memoryEstimationForNodeSet(int k, double trainFraction) {
5656
return memoryEstimation(k, dim -> (long) (dim.nodeCount() * trainFraction));
@@ -79,12 +79,12 @@ public static MemoryEstimation memoryEstimation(int k, ToLongFunction<GraphDimen
7979
);
8080
}
8181

82-
public StratifiedKFoldSplitter(int k, ReadOnlyHugeLongArray ids, LongToLongFunction targets, Optional<Long> randomSeed, SortedSet<Long> distinctTargets) {
82+
public StratifiedKFoldSplitter(int k, ReadOnlyHugeLongArray ids, LongToLongFunction targets, Optional<Long> randomSeed, SortedSet<Long> distinctInternalTargets) {
8383
this.k = k;
8484
this.ids = ids;
8585
this.targets = targets;
8686
this.random = ShuffleUtil.createRandomDataGenerator(randomSeed);
87-
this.distinctTargets = distinctTargets;
87+
this.distinctInternalTargets = distinctInternalTargets;
8888
}
8989

9090
public List<TrainingExamplesSplit> splits() {
@@ -97,7 +97,7 @@ public List<TrainingExamplesSplit> splits() {
9797
allocateArrays(nodeCount, trainSets, testSets);
9898

9999
var roundRobinPointer = new MutableInt();
100-
distinctTargets.forEach(currentClass -> {
100+
distinctInternalTargets.forEach(currentClass -> {
101101
for (long offset = 0; offset < ids.size(); offset++) {
102102
var id = ids.get(offset);
103103
if (targets.applyAsLong(id) == currentClass) {

ml/ml-algo/src/main/java/org/neo4j/gds/ml/training/CrossValidation.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ public CrossValidation(
8585
public void selectModel(
8686
ReadOnlyHugeLongArray outerTrainSet,
8787
LongToLongFunction targets,
88-
SortedSet<Long> distinctTargets,
88+
SortedSet<Long> distinctInternalTargets,
8989
TrainingStatistics trainingStatistics,
9090
Iterator<TrainerConfig> modelCandidates
9191
) {
@@ -95,7 +95,7 @@ public void selectModel(
9595
outerTrainSet,
9696
targets,
9797
randomSeed,
98-
distinctTargets
98+
distinctInternalTargets
9999
).splits();
100100
progressTracker.endSubTask("Create validation folds");
101101

ml/ml-algo/src/main/java/org/neo4j/gds/ml/training/TrainingStatistics.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,11 @@ public List<EvaluationScores> getValidationStats(Metric metric) {
6060
return modelCandidateStats.stream().map(stats -> stats.validationStats().get(metric)).collect(Collectors.toList());
6161
}
6262

63+
@TestOnly
64+
public Double getTestScore(Metric metric) {
65+
return testScores.get(metric);
66+
}
67+
6368
/**
6469
* Turns this class into a Cypher map, to be returned in a procedure YIELD field.
6570
* This is intentionally omitting the test scores.

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkPredictionTrain.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ private void findBestModelCandidate(
207207
crossValidation.selectModel(
208208
trainRelationshipIds,
209209
trainData.labels()::get,
210+
//LP always have 2 classes 0,1 the original ids happen to be the same as internal
210211
new TreeSet<>(classIdMap.originalIdsList()),
211212
trainingStatistics,
212213
modelCandidates

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/nodePipeline/classification/train/NodeClassificationTrain.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
import java.util.TreeSet;
7171
import java.util.function.LongUnaryOperator;
7272
import java.util.stream.Collectors;
73+
import java.util.stream.LongStream;
7374

7475
import static org.neo4j.gds.core.utils.mem.MemoryEstimations.delegateEstimation;
7576
import static org.neo4j.gds.core.utils.mem.MemoryEstimations.maxEstimation;
@@ -374,10 +375,10 @@ private void findBestModelCandidate(ReadOnlyHugeLongArray trainNodeIds, Features
374375
trainConfig.randomSeed()
375376
);
376377

377-
var sortedClassIds = new TreeSet<Long>();
378-
for (long clazz : classCounts.keys()) {
379-
sortedClassIds.add(clazz);
380-
}
378+
var sortedClassIds = LongStream
379+
.range(0, classCounts.size())
380+
.boxed()
381+
.collect(Collectors.toCollection(TreeSet::new));
381382

382383
crossValidation.selectModel(
383384
trainNodeIds,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.ml.pipeline.nodePipeline.classification.train;
21+
22+
import org.junit.jupiter.api.Test;
23+
import org.neo4j.gds.api.GraphStore;
24+
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
25+
import org.neo4j.gds.executor.ExecutionContext;
26+
import org.neo4j.gds.extension.GdlExtension;
27+
import org.neo4j.gds.extension.GdlGraph;
28+
import org.neo4j.gds.extension.Inject;
29+
import org.neo4j.gds.ml.metrics.classification.Accuracy;
30+
import org.neo4j.gds.ml.metrics.classification.ClassificationMetricSpecification;
31+
import org.neo4j.gds.ml.metrics.classification.GlobalAccuracy;
32+
import org.neo4j.gds.ml.models.logisticregression.LogisticRegressionTrainConfigImpl;
33+
import org.neo4j.gds.ml.pipeline.nodePipeline.NodeFeatureProducer;
34+
import org.neo4j.gds.ml.pipeline.nodePipeline.NodeFeatureStep;
35+
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.NodeClassificationTrainingPipeline;
36+
37+
import java.util.List;
38+
39+
import static org.assertj.core.api.Assertions.assertThat;
40+
41+
@GdlExtension
42+
public class NodeClassificationTrainClassValueInvarianceTest {
43+
44+
private static final String GRAPH_NAME_1 = "G11";
45+
46+
@GdlGraph(graphNamePrefix = "nodes1")
47+
private static final String DB_QUERY1 =
48+
"CREATE " +
49+
" (a1:N {bananas: 100.0, arrayProperty: [1.2, 1.2], a: 1.2, b: 1.2, t: 0})" +
50+
", (a2:N {bananas: 100.0, arrayProperty: [2.8, 2.5], a: 2.8, b: 2.5, t: 0})" +
51+
", (a3:N {bananas: 100.0, arrayProperty: [3.3, 0.5], a: 3.3, b: 0.5, t: 0})" +
52+
", (a4:N {bananas: 100.0, arrayProperty: [1.0, 0.5], a: 1.0, b: 0.5, t: 0})" +
53+
", (a5:N {bananas: 100.0, arrayProperty: [1.32, 0.5], a: 1.32, b: 0.5, t: 0})" +
54+
", (a6:N {bananas: 100.0, arrayProperty: [1.3, 1.5], a: 1.3, b: 1.5, t: 1})" +
55+
", (a7:N {bananas: 100.0, arrayProperty: [5.3, 10.5], a: 5.3, b: 10.5, t: 1})" +
56+
", (a8:N {bananas: 100.0, arrayProperty: [1.3, 2.5], a: 1.3, b: 2.5, t: 1})" +
57+
", (a9:N {bananas: 100.0, arrayProperty: [0.0, 66.8], a: 0.0, b: 66.8, t: 1})" +
58+
", (a10:N {bananas: 100.0, arrayProperty: [0.1, 2.8], a: 0.1, b: 2.8, t: 1})" +
59+
", (a11:N {bananas: 100.0, arrayProperty: [0.66, 2.8], a: 0.66, b: 2.8, t: 1})" +
60+
", (a12:N {bananas: 100.0, arrayProperty: [2.0, 10.8], a: 2.0, b: 10.8, t: 1})" +
61+
", (a13:N {bananas: 100.0, arrayProperty: [5.0, 7.8], a: 5.0, b: 7.8, t: 2})" +
62+
", (a14:N {bananas: 100.0, arrayProperty: [4.0, 5.8], a: 4.0, b: 5.8, t: 2})" +
63+
", (a15:N {bananas: 100.0, arrayProperty: [1.0, 0.9], a: 1.0, b: 0.9, t: 2})";
64+
65+
@Inject
66+
private GraphStore nodes1GraphStore;
67+
68+
private static final String GRAPH_NAME_2 = "G2";
69+
70+
@GdlGraph(graphNamePrefix = "nodes2")
71+
private static final String DB_QUERY2 =
72+
"CREATE " +
73+
" (a1:N {bananas: 100.0, arrayProperty: [1.2, 1.2], a: 1.2, b: 1.2, t: 0})" +
74+
", (a2:N {bananas: 100.0, arrayProperty: [2.8, 2.5], a: 2.8, b: 2.5, t: 0})" +
75+
", (a3:N {bananas: 100.0, arrayProperty: [3.3, 0.5], a: 3.3, b: 0.5, t: 0})" +
76+
", (a4:N {bananas: 100.0, arrayProperty: [1.0, 0.5], a: 1.0, b: 0.5, t: 0})" +
77+
", (a5:N {bananas: 100.0, arrayProperty: [1.32, 0.5], a: 1.32, b: 0.5, t: 0})" +
78+
", (a6:N {bananas: 100.0, arrayProperty: [1.3, 1.5], a: 1.3, b: 1.5, t: 222})" +
79+
", (a7:N {bananas: 100.0, arrayProperty: [5.3, 10.5], a: 5.3, b: 10.5, t: 222})" +
80+
", (a8:N {bananas: 100.0, arrayProperty: [1.3, 2.5], a: 1.3, b: 2.5, t: 222})" +
81+
", (a9:N {bananas: 100.0, arrayProperty: [0.0, 66.8], a: 0.0, b: 66.8, t: 222})" +
82+
", (a10:N {bananas: 100.0, arrayProperty: [0.1, 2.8], a: 0.1, b: 2.8, t: 222})" +
83+
", (a11:N {bananas: 100.0, arrayProperty: [0.66, 2.8], a: 0.66, b: 2.8, t: 222})" +
84+
", (a12:N {bananas: 100.0, arrayProperty: [2.0, 10.8], a: 2.0, b: 10.8, t: 222})" +
85+
", (a13:N {bananas: 100.0, arrayProperty: [5.0, 7.8], a: 5.0, b: 7.8, t: 333})" +
86+
", (a14:N {bananas: 100.0, arrayProperty: [4.0, 5.8], a: 4.0, b: 5.8, t: 333})" +
87+
", (a15:N {bananas: 100.0, arrayProperty: [1.0, 0.9], a: 1.0, b: 0.9, t: 333})";
88+
89+
@Inject
90+
private GraphStore nodes2GraphStore;
91+
92+
/**
93+
* This tests that the specific class values do not matter, as long as the ordering is the same.
94+
* However, if the *ordering* of the class values differ, the splits in cross-validation could differ, resulting in different accuracies.
95+
*/
96+
@Test
97+
void trainWithDifferentClassValues() {
98+
var pipeline = new NodeClassificationTrainingPipeline();
99+
pipeline.addFeatureStep(NodeFeatureStep.of("a"));
100+
pipeline.addFeatureStep(NodeFeatureStep.of("b"));
101+
102+
var lrTrainerConfig = LogisticRegressionTrainConfigImpl.builder().build();
103+
pipeline.addTrainerConfig(lrTrainerConfig);
104+
105+
var accuracyMetricSpec = ClassificationMetricSpecification.Parser.parse("accuracy");
106+
var accuracyPerClassMetricSpec = ClassificationMetricSpecification.Parser.parse("accuracy(class=*)");
107+
108+
var config01 = createConfig("model1", GRAPH_NAME_1, List.of(accuracyMetricSpec, accuracyPerClassMetricSpec), 1L);
109+
var ncTrain01 = createWithExecutionContext(
110+
nodes1GraphStore,
111+
pipeline,
112+
config01,
113+
ProgressTracker.NULL_TRACKER
114+
);
115+
var result01 = ncTrain01.run();
116+
assertThat(result01.classifier().data().featureDimension()).isEqualTo(2);
117+
118+
//Run with graph that have class values 0 and 2
119+
var config02 = createConfig("model2", GRAPH_NAME_2, List.of(accuracyMetricSpec, accuracyPerClassMetricSpec), 1L);
120+
var ncTrain02 = createWithExecutionContext(
121+
nodes2GraphStore,
122+
pipeline,
123+
config02,
124+
ProgressTracker.NULL_TRACKER
125+
);
126+
var result02 = ncTrain02.run();
127+
assertThat(result01.classifier().data().featureDimension()).isEqualTo(2);
128+
129+
var globalAccuracy = new GlobalAccuracy();
130+
var accuracyForClass1 = new Accuracy(0, 0);
131+
var accuracyForClass2 = new Accuracy(1, 1);
132+
var accuracyForClass3 = new Accuracy(2, 2);
133+
134+
var accuracyForClass0 = new Accuracy(0, 0);
135+
var accuracyForClass222 = new Accuracy(222, 1);
136+
var accuracyForClass333 = new Accuracy(333, 2);
137+
138+
assertThat(result01.trainingStatistics().getTrainStats(globalAccuracy)).isEqualTo(result02.trainingStatistics().getTrainStats(globalAccuracy));
139+
assertThat(result01.trainingStatistics().getTrainStats(accuracyForClass1)).isEqualTo(result02.trainingStatistics().getTrainStats(accuracyForClass0));
140+
assertThat(result01.trainingStatistics().getTrainStats(accuracyForClass2)).isEqualTo(result02.trainingStatistics().getTrainStats(accuracyForClass222));
141+
assertThat(result01.trainingStatistics().getTrainStats(accuracyForClass3)).isEqualTo(result02.trainingStatistics().getTrainStats(accuracyForClass333));
142+
143+
assertThat(result01.trainingStatistics().getValidationStats(globalAccuracy)).isEqualTo(result02.trainingStatistics().getValidationStats(globalAccuracy));
144+
assertThat(result01.trainingStatistics().getValidationStats(accuracyForClass1)).isEqualTo(result02.trainingStatistics().getValidationStats(accuracyForClass0));
145+
assertThat(result01.trainingStatistics().getValidationStats(accuracyForClass2)).isEqualTo(result02.trainingStatistics().getValidationStats(accuracyForClass222));
146+
assertThat(result01.trainingStatistics().getValidationStats(accuracyForClass3)).isEqualTo(result02.trainingStatistics().getValidationStats(accuracyForClass333));
147+
148+
assertThat(result01.trainingStatistics().getTestScore(globalAccuracy)).isEqualTo(result02.trainingStatistics().getTestScore(globalAccuracy));
149+
assertThat(result01.trainingStatistics().getTestScore(accuracyForClass1)).isEqualTo(result02.trainingStatistics().getTestScore(accuracyForClass0));
150+
assertThat(result01.trainingStatistics().getTestScore(accuracyForClass2)).isEqualTo(result02.trainingStatistics().getTestScore(accuracyForClass222));
151+
assertThat(result01.trainingStatistics().getTestScore(accuracyForClass3)).isEqualTo(result02.trainingStatistics().getTestScore(accuracyForClass333));
152+
153+
}
154+
155+
private NodeClassificationPipelineTrainConfig createConfig(
156+
String modelName,
157+
String graphName,
158+
List<ClassificationMetricSpecification> metricSpecification,
159+
long randomSeed
160+
) {
161+
return NodeClassificationPipelineTrainConfigImpl.builder()
162+
.pipeline("")
163+
.graphName(graphName)
164+
.modelUser("DUMMY")
165+
.modelName(modelName)
166+
.concurrency(1)
167+
.randomSeed(randomSeed)
168+
.targetProperty("t")
169+
.metrics(metricSpecification)
170+
.build();
171+
}
172+
173+
static NodeClassificationTrain createWithExecutionContext(
174+
GraphStore graphStore,
175+
NodeClassificationTrainingPipeline pipeline,
176+
NodeClassificationPipelineTrainConfig config,
177+
ProgressTracker progressTracker
178+
) {
179+
var nodeFeatureProducer = NodeFeatureProducer.create(graphStore, config, ExecutionContext.EMPTY, progressTracker);
180+
return NodeClassificationTrain.create(
181+
graphStore,
182+
pipeline,
183+
config,
184+
nodeFeatureProducer,
185+
progressTracker
186+
);
187+
}
188+
189+
}

0 commit comments

Comments
 (0)