Skip to content

Commit b15ead3

Browse files
committed
Refactor FastRPParameters
1 parent bd15381 commit b15ead3

File tree

4 files changed

+21
-17
lines changed

4 files changed

+21
-17
lines changed

algo-params/node-embeddings-params/src/main/java/org/neo4j/gds/embeddings/fastrp/FastRPParameters.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,14 @@ public record FastRPParameters(
3131
List<String> featureProperties,
3232
List<Number> iterationWeights,
3333
int embeddingDimension,
34-
int propertyDimension,
34+
double propertyRatio,
3535
Optional<String> relationshipWeightProperty,
3636
float normalizationStrength,
3737
Number nodeSelfInfluence,
3838
Concurrency concurrency,
3939
Optional<Long> randomSeed
40-
) implements AlgorithmParameters { }
40+
) implements AlgorithmParameters {
41+
public int propertyDimension() {
42+
return (int) (embeddingDimension() * propertyRatio());
43+
}
44+
}

algo/src/main/java/org/neo4j/gds/embeddings/fastrp/FastRPConfigTransformer.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ public static FastRPParameters toParameters(FastRPBaseConfig config) {
2828
config.featureProperties(),
2929
config.iterationWeights(),
3030
config.embeddingDimension(),
31-
config.propertyDimension(),
31+
config.propertyRatio(),
3232
config.relationshipWeightProperty(),
3333
config.normalizationStrength(),
3434
config.nodeSelfInfluence(),

algo/src/test/java/org/neo4j/gds/embeddings/fastrp/FastRPTest.java

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ void shouldSwapInitialRandomVectors() {
130130
List.of("f1", "f2", "f3"),
131131
List.of(1.0D),
132132
DEFAULT_EMBEDDING_DIMENSION,
133-
(int) (0.5 * DEFAULT_EMBEDDING_DIMENSION),
133+
0.5,
134134
Optional.empty(),
135135
0.0F,
136136
0,
@@ -176,7 +176,7 @@ void shouldAverageNeighbors() {
176176
List.of("f1", "f2", "f3"),
177177
List.of(1.0D),
178178
DEFAULT_EMBEDDING_DIMENSION,
179-
(int) (0.5 * DEFAULT_EMBEDDING_DIMENSION),
179+
0.5,
180180
Optional.empty(),
181181
0.0F,
182182
0,
@@ -219,7 +219,7 @@ void shouldAddInitialVectors() {
219219
List.of("f1", "f2", "f3"),
220220
List.of(0.0D),
221221
embeddingDimension,
222-
(int) (0.5 * embeddingDimension),
222+
0.5,
223223
Optional.empty(),
224224
0.0F,
225225
0.6,
@@ -297,7 +297,7 @@ void shouldInitialisePropertyEmbeddingsCorrectly() {
297297
List.of("f1", "f2", "f3"),
298298
List.of(1.0D),
299299
DEFAULT_EMBEDDING_DIMENSION,
300-
(int) (0.5 * DEFAULT_EMBEDDING_DIMENSION),
300+
0.5,
301301
Optional.empty(),
302302
0.0F,
303303
0,
@@ -754,7 +754,7 @@ void shouldBeDeterministicInParallel() {
754754
List.of("f1", "f2", "f3"),
755755
List.of(1.0D),
756756
DEFAULT_EMBEDDING_DIMENSION,
757-
(int) (0.5 * DEFAULT_EMBEDDING_DIMENSION),
757+
0.5,
758758
Optional.empty(),
759759
0.0F,
760760
0,
@@ -766,7 +766,7 @@ void shouldBeDeterministicInParallel() {
766766
List.of("f1", "f2", "f3"),
767767
List.of(1.0D),
768768
DEFAULT_EMBEDDING_DIMENSION,
769-
(int) (0.5 * DEFAULT_EMBEDDING_DIMENSION),
769+
0.5,
770770
Optional.empty(),
771771
0.0F,
772772
0,
@@ -818,7 +818,7 @@ void shouldAverageNeighborsWeighted() {
818818
List.of("f1", "f2", "f3"),
819819
List.of(1.0D),
820820
DEFAULT_EMBEDDING_DIMENSION,
821-
(int) (0.5 * DEFAULT_EMBEDDING_DIMENSION),
821+
0.5,
822822
Optional.of("weight"),
823823
0.0F,
824824
0,
@@ -866,7 +866,7 @@ void shouldDistributeValuesCorrectly() {
866866
List.of(),
867867
List.of(1.0D),
868868
512,
869-
(int) (0.5 * DEFAULT_EMBEDDING_DIMENSION),
869+
0.125,
870870
Optional.empty(),
871871
0.0F,
872872
0,
@@ -924,7 +924,7 @@ void shouldYieldEmptyEmbeddingForIsolatedNodes() {
924924
List.of("f1", "f2", "f3"),
925925
List.of(1.0D),
926926
DEFAULT_EMBEDDING_DIMENSION,
927-
(int) (0.5 * DEFAULT_EMBEDDING_DIMENSION),
927+
0.5,
928928
Optional.empty(),
929929
0.0F,
930930
0,
@@ -1036,7 +1036,7 @@ void shouldFailWhenNodePropertiesAreMissing() {
10361036
List.of("prop"),
10371037
List.of(1.0D, 1.0D, 1.0D, 1.0D),
10381038
64,
1039-
0,
1039+
0.5,
10401040
Optional.empty(),
10411041
0.0F,
10421042
0,
@@ -1075,7 +1075,7 @@ void shouldFailWhenRelationshipWeightIsMissing() {
10751075
List.of(),
10761076
List.of(1.0D),
10771077
DEFAULT_EMBEDDING_DIMENSION,
1078-
0,
1078+
0.5,
10791079
Optional.of("weight"),
10801080
0.0F,
10811081
0,
@@ -1174,7 +1174,7 @@ void shouldBeDeterministicGivenSameOriginalIds() {
11741174
List.of(),
11751175
List.of(1.0D),
11761176
embeddingDimension,
1177-
0,
1177+
0.5,
11781178
Optional.empty(),
11791179
0.0F,
11801180
0,
@@ -1217,7 +1217,7 @@ private HugeObjectArray<float[]> embeddings(Graph graph, List<String> properties
12171217
List.of("f1", "f2", "f3"),
12181218
List.of(1.0D),
12191219
DEFAULT_EMBEDDING_DIMENSION,
1220-
(int) (0.5 * DEFAULT_EMBEDDING_DIMENSION),
1220+
0.5,
12211221
Optional.empty(),
12221222
0.0F,
12231223
0,

algorithms-compute-facade/src/test/java/org/neo4j/gds/embeddings/NodeEmbeddingComputeFacadeTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ void fastRP() {
118118
List.of("f"),
119119
List.of(1.0D),
120120
DEFAULT_EMBEDDING_DIMENSION,
121-
(int) (0.5 * DEFAULT_EMBEDDING_DIMENSION),
121+
0.0,
122122
Optional.empty(),
123123
0.0F,
124124
0,

0 commit comments

Comments
 (0)