Skip to content

Commit 8cdd32c

Browse files
committed
Some more records
1 parent 1e6a400 commit 8cdd32c

File tree

13 files changed

+138
-141
lines changed

13 files changed

+138
-141
lines changed

algo/src/main/java/org/neo4j/gds/algorithms/machinelearning/KGEPredictResult.java

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,6 @@
1919
*/
2020
package org.neo4j.gds.algorithms.machinelearning;
2121

22-
import org.neo4j.gds.annotation.ValueClass;
2322
import org.neo4j.gds.similarity.nodesim.TopKMap;
2423

25-
@ValueClass
26-
public interface KGEPredictResult {
27-
TopKMap topKMap();
28-
static KGEPredictResult of(TopKMap topKMap) {
29-
return ImmutableKGEPredictResult.of(topKMap);
30-
}
31-
}
24+
public record KGEPredictResult(TopKMap topKMap) {}

algo/src/main/java/org/neo4j/gds/algorithms/machinelearning/TopKMapComputer.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ public KGEPredictResult compute() {
128128

129129
progressTracker.endSubTask();
130130

131-
return KGEPredictResult.of(topKMap);
131+
return new KGEPredictResult(topKMap);
132132
}
133133

134134

algo/src/main/java/org/neo4j/gds/embeddings/graphsage/GraphSageModelTrainer.java

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -292,15 +292,10 @@ private EpochResult trainEpoch(
292292
progressTracker.endSubTask("Iteration");
293293
}
294294

295-
return ImmutableEpochResult.of(converged, iterationLosses);
295+
return new EpochResult(converged, iterationLosses);
296296
}
297297

298-
@ValueClass
299-
interface EpochResult {
300-
boolean converged();
301-
302-
List<Double> losses();
303-
}
298+
private record EpochResult(boolean converged, List<Double> losses) {}
304299

305300
static class BatchTask implements Runnable {
306301

@@ -384,22 +379,13 @@ default Map<String, Object> toMap() {
384379
}
385380
}
386381

387-
@ValueClass
388-
public interface ModelTrainResult {
389-
390-
GraphSageTrainMetrics metrics();
391-
392-
Layer[] layers();
393-
382+
public record ModelTrainResult(GraphSageTrainMetrics metrics, Layer[] layers) {
394383
static ModelTrainResult of(
395384
List<List<Double>> iterationLossesPerEpoch,
396385
boolean converged,
397386
Layer[] layers
398387
) {
399-
return ImmutableModelTrainResult.builder()
400-
.layers(layers)
401-
.metrics(ImmutableGraphSageTrainMetrics.of(iterationLossesPerEpoch, converged))
402-
.build();
388+
return new ModelTrainResult(ImmutableGraphSageTrainMetrics.of(iterationLossesPerEpoch, converged), layers);
403389
}
404390
}
405391
}

algo/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2Vec.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ public Node2VecResult compute() {
9393
);
9494
}
9595

96-
var probabilitiesBuilder = new RandomWalkProbabilities.Builder(
96+
var probabilitiesBuilder = new RandomWalkProbabilitiesBuilder(
9797
graph.nodeCount(),
9898
concurrency,
9999
samplingWalkParameters.positiveSamplingFactor(),
@@ -121,7 +121,7 @@ public Node2VecResult compute() {
121121

122122
private List<Node2VecRandomWalkTask> walkTasks(
123123
CompressedRandomWalks compressedRandomWalks,
124-
RandomWalkProbabilities.Builder randomWalkPropabilitiesBuilder,
124+
RandomWalkProbabilitiesBuilder randomWalkPropabilitiesBuilder,
125125
Graph graph,
126126
Optional<Long> maybeRandomSeed,
127127
Concurrency concurrency,
@@ -166,7 +166,7 @@ private List<Node2VecRandomWalkTask> walkTasks(
166166
}
167167

168168

169-
CompressedRandomWalks createWalks(RandomWalkProbabilities.Builder probabilitiesBuilder){
169+
CompressedRandomWalks createWalks(RandomWalkProbabilitiesBuilder probabilitiesBuilder){
170170
var walks = new CompressedRandomWalks(graph.nodeCount() * samplingWalkParameters.walksPerNode());
171171

172172
progressTracker.beginSubTask("RandomWalk");

algo/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2VecRandomWalkTask.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ final class Node2VecRandomWalkTask implements Runnable {
3636
private final TerminationFlag terminationFlag;
3737
private final AtomicLong walkIndex;
3838
private final CompressedRandomWalks compressedRandomWalks;
39-
private final RandomWalkProbabilities.Builder randomWalkProbabilitiesBuilder;
39+
private final RandomWalkProbabilitiesBuilder randomWalkProbabilitiesBuilder;
4040
private final RandomWalkSampler sampler;
4141
private final int walkBufferSize;
4242
private int walks;
@@ -52,7 +52,7 @@ final class Node2VecRandomWalkTask implements Runnable {
5252
TerminationFlag terminationFlag,
5353
AtomicLong walkIndex,
5454
CompressedRandomWalks compressedRandomWalks,
55-
RandomWalkProbabilities.Builder randomWalkProbabilitiesBuilder,
55+
RandomWalkProbabilitiesBuilder randomWalkProbabilitiesBuilder,
5656
int walkBufferSize,
5757
long randomSeed,
5858
int walkLength,

algo/src/main/java/org/neo4j/gds/embeddings/node2vec/RandomWalkProbabilities.java

Lines changed: 6 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -19,106 +19,13 @@
1919
*/
2020
package org.neo4j.gds.embeddings.node2vec;
2121

22-
import org.neo4j.gds.annotation.ValueClass;
2322
import org.neo4j.gds.collections.ha.HugeDoubleArray;
2423
import org.neo4j.gds.collections.ha.HugeLongArray;
2524
import org.neo4j.gds.collections.haa.HugeAtomicLongArray;
26-
import org.neo4j.gds.core.concurrency.Concurrency;
27-
import org.neo4j.gds.core.concurrency.ParallelUtil;
28-
import org.neo4j.gds.core.utils.paged.ParalleLongPageCreator;
29-
import org.neo4j.gds.termination.TerminationFlag;
3025

31-
import java.util.concurrent.atomic.LongAdder;
32-
import java.util.stream.LongStream;
33-
34-
import static java.lang.Math.addExact;
35-
36-
@ValueClass
37-
interface RandomWalkProbabilities {
38-
39-
HugeAtomicLongArray nodeFrequencies();
40-
HugeDoubleArray positiveSamplingProbabilities();
41-
HugeLongArray negativeSamplingDistribution();
42-
long sampleCount();
43-
44-
@SuppressWarnings("immutables:incompat")
45-
class Builder {
46-
47-
private final long nodeCount;
48-
private final Concurrency concurrency;
49-
private final double positiveSamplingFactor;
50-
private final double negativeSamplingExponent;
51-
private final HugeAtomicLongArray nodeFrequencies;
52-
private final LongAdder sampleCount;
53-
54-
Builder(
55-
long nodeCount,
56-
Concurrency concurrency,
57-
double positiveSamplingFactor,
58-
double negativeSamplingExponent
59-
) {
60-
this.nodeCount = nodeCount;
61-
this.concurrency = concurrency;
62-
this.positiveSamplingFactor = positiveSamplingFactor;
63-
this.negativeSamplingExponent = negativeSamplingExponent;
64-
65-
this.nodeFrequencies = HugeAtomicLongArray.of(nodeCount, ParalleLongPageCreator.passThrough(concurrency));
66-
this.sampleCount = new LongAdder();
67-
}
68-
69-
//wip to break for the day
70-
void registerWalk(long[] walk) {
71-
for (long node : walk) {
72-
nodeFrequencies.getAndAdd(node, 1);
73-
}
74-
this.sampleCount.add(walk.length);
75-
76-
}
77-
78-
RandomWalkProbabilities build() {
79-
var centerProbabilities = computePositiveSamplingProbabilities();
80-
var contextDistribution = computeNegativeSamplingDistribution();
81-
82-
return ImmutableRandomWalkProbabilities
83-
.builder()
84-
.nodeFrequencies(nodeFrequencies)
85-
.positiveSamplingProbabilities(centerProbabilities)
86-
.negativeSamplingDistribution(contextDistribution)
87-
.sampleCount(sampleCount.longValue())
88-
.build();
89-
}
90-
91-
private HugeDoubleArray computePositiveSamplingProbabilities() {
92-
var centerProbabilities = HugeDoubleArray.newArray(nodeCount);
93-
var sum = sampleCount.longValue();
94-
95-
ParallelUtil.parallelStreamConsume(
96-
LongStream.range(0, nodeCount),
97-
concurrency,
98-
TerminationFlag.RUNNING_TRUE,
99-
nodeStream -> nodeStream.forEach(nodeId -> {
100-
double frequency = ((double) nodeFrequencies.get(nodeId)) / sum;
101-
centerProbabilities.set(
102-
nodeId,
103-
(Math.sqrt(frequency / positiveSamplingFactor) + 1) * (positiveSamplingFactor / frequency)
104-
);
105-
})
106-
);
107-
108-
return centerProbabilities;
109-
}
110-
111-
private HugeLongArray computeNegativeSamplingDistribution() {
112-
var contextDistribution = HugeLongArray.newArray(nodeCount);
113-
long sum = 0;
114-
for (var i = 0L; i < nodeCount; i++) {
115-
sum += Math.pow(nodeFrequencies.get(i), negativeSamplingExponent);
116-
sum = addExact(sum, (long) Math.pow(nodeFrequencies.get(i), negativeSamplingExponent));
117-
contextDistribution.set(i, sum);
118-
}
119-
120-
return contextDistribution;
121-
}
122-
}
123-
124-
}
26+
record RandomWalkProbabilities(
27+
HugeAtomicLongArray nodeFrequencies,
28+
HugeDoubleArray positiveSamplingProbabilities,
29+
HugeLongArray negativeSamplingDistribution,
30+
long sampleCount
31+
) {}

algo/src/main/java/org/neo4j/gds/embeddings/node2vec/RandomWalkProbabilitiesBuilder.java

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.embeddings.node2vec;
21+
22+
import org.neo4j.gds.collections.ha.HugeDoubleArray;
23+
import org.neo4j.gds.collections.ha.HugeLongArray;
24+
import org.neo4j.gds.collections.haa.HugeAtomicLongArray;
25+
import org.neo4j.gds.core.concurrency.Concurrency;
26+
import org.neo4j.gds.core.concurrency.ParallelUtil;
27+
import org.neo4j.gds.core.utils.paged.ParalleLongPageCreator;
28+
import org.neo4j.gds.termination.TerminationFlag;
29+
30+
import java.util.concurrent.atomic.LongAdder;
31+
import java.util.stream.LongStream;
32+
33+
import static java.lang.Math.addExact;
34+
35+
class RandomWalkProbabilitiesBuilder {
36+
37+
private final long nodeCount;
38+
private final Concurrency concurrency;
39+
private final double positiveSamplingFactor;
40+
private final double negativeSamplingExponent;
41+
private final HugeAtomicLongArray nodeFrequencies;
42+
private final LongAdder sampleCount;
43+
44+
RandomWalkProbabilitiesBuilder(
45+
long nodeCount,
46+
Concurrency concurrency,
47+
double positiveSamplingFactor,
48+
double negativeSamplingExponent
49+
) {
50+
this.nodeCount = nodeCount;
51+
this.concurrency = concurrency;
52+
this.positiveSamplingFactor = positiveSamplingFactor;
53+
this.negativeSamplingExponent = negativeSamplingExponent;
54+
55+
this.nodeFrequencies = HugeAtomicLongArray.of(nodeCount, ParalleLongPageCreator.passThrough(concurrency));
56+
this.sampleCount = new LongAdder();
57+
}
58+
59+
// Tallies one walk: increments the frequency of every node it visits and adds its length to the sample count.
60+
void registerWalk(long[] walk) {
61+
for (long node : walk) {
62+
nodeFrequencies.getAndAdd(node, 1);
63+
}
64+
this.sampleCount.add(walk.length);
65+
66+
}
67+
68+
RandomWalkProbabilities build() {
69+
var centerProbabilities = computePositiveSamplingProbabilities();
70+
var contextDistribution = computeNegativeSamplingDistribution();
71+
72+
return new RandomWalkProbabilities(
73+
nodeFrequencies,
74+
centerProbabilities,
75+
contextDistribution,
76+
sampleCount.longValue()
77+
);
78+
}
79+
80+
private HugeDoubleArray computePositiveSamplingProbabilities() {
81+
var centerProbabilities = HugeDoubleArray.newArray(nodeCount);
82+
var sum = sampleCount.longValue();
83+
84+
ParallelUtil.parallelStreamConsume(
85+
LongStream.range(0, nodeCount),
86+
concurrency,
87+
TerminationFlag.RUNNING_TRUE,
88+
nodeStream -> nodeStream.forEach(nodeId -> {
89+
double frequency = ((double) nodeFrequencies.get(nodeId)) / sum;
90+
centerProbabilities.set(
91+
nodeId,
92+
(Math.sqrt(frequency / positiveSamplingFactor) + 1) * (positiveSamplingFactor / frequency)
93+
);
94+
})
95+
);
96+
97+
return centerProbabilities;
98+
}
99+
100+
private HugeLongArray computeNegativeSamplingDistribution() {
101+
var contextDistribution = HugeLongArray.newArray(nodeCount);
102+
long sum = 0;
103+
for (var i = 0L; i < nodeCount; i++) {
104+
sum += Math.pow(nodeFrequencies.get(i), negativeSamplingExponent);
105+
sum = addExact(sum, (long) Math.pow(nodeFrequencies.get(i), negativeSamplingExponent));
106+
contextDistribution.set(i, sum);
107+
}
108+
109+
return contextDistribution;
110+
}
111+
}

algo/src/test/java/org/neo4j/gds/embeddings/node2vec/NegativeSampleProducerTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class NegativeSampleProducerTest {
3333
@Test
3434
void shouldProduceSamplesAccordingToNodeDistribution() {
3535

36-
var builder = new RandomWalkProbabilities.Builder(
36+
var builder = new RandomWalkProbabilitiesBuilder(
3737
2,
3838
new Concurrency(4),
3939
0.001,

algo/src/test/java/org/neo4j/gds/embeddings/node2vec/Node2VecMemoryEstimateDefinitionTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ void shouldEstimateMemory() {
4343

4444
MemoryEstimationAssert.assertThat(memoryEstimation)
4545
.memoryRange(1000, new Concurrency(1))
46-
.hasSameMinAndMaxEqualTo(7_688_448L);
46+
.hasSameMinAndMaxEqualTo(7_688_464L);
4747
}
4848

4949
}

algo/src/test/java/org/neo4j/gds/embeddings/node2vec/Node2VecModelTest.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ void testModel() {
6363
int numberOfWalks = 10;
6464
int walkLength = 80;
6565

66-
var probabilitiesBuilder = new RandomWalkProbabilities.Builder(
66+
var probabilitiesBuilder = new RandomWalkProbabilitiesBuilder(
6767
numberOfClusters * clusterSize,
6868
new Concurrency(4),
6969
0.001,
@@ -173,7 +173,7 @@ void twoRunsSingleThreadedWithTheSameRandomSeed(int iterations) {
173173
int numberOfWalks = 10;
174174
int walkLength = 80;
175175

176-
var probabilitiesBuilder = new RandomWalkProbabilities.Builder(
176+
var probabilitiesBuilder = new RandomWalkProbabilitiesBuilder(
177177
numberOfClusters * clusterSize,
178178
new Concurrency(4),
179179
0.001,
@@ -264,7 +264,7 @@ void shouldCreateTrainingTasksWithCorrectRandomSeed() {
264264
}
265265

266266
private static CompressedRandomWalks generateRandomWalks(
267-
RandomWalkProbabilities.Builder probabilitiesBuilder,
267+
RandomWalkProbabilitiesBuilder probabilitiesBuilder,
268268
long numberOfClusters,
269269
int clusterSize,
270270
long numberOfWalks,

0 commit comments

Comments
 (0)