Skip to content

Commit 5faedc4

Browse files
Generalize input nodes and create TargetNodesWProperties
Co-authored-by: Alfred Clemedtson <alfred.clemedtson@neo4j.com>
1 parent f63dd41 commit 5faedc4

File tree

13 files changed

+249
-162
lines changed

13 files changed

+249
-162
lines changed

algo/src/main/java/org/neo4j/gds/pagerank/InitialProbabilityFactory.java

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@
1919
*/
2020
package org.neo4j.gds.pagerank;
2121

22-
import org.neo4j.gds.config.ListSourceNodes;
23-
import org.neo4j.gds.config.MapSourceNodes;
24-
import org.neo4j.gds.config.SourceNodes;
22+
import org.neo4j.gds.config.InputNodes;
23+
import org.neo4j.gds.config.ListInputNodes;
24+
import org.neo4j.gds.config.MapInputNodes;
2525

2626
import java.util.HashMap;
2727
import java.util.function.LongUnaryOperator;
@@ -33,20 +33,20 @@ private InitialProbabilityFactory() {}
3333
public static InitialProbabilityProvider create(
3434
LongUnaryOperator toMappedId,
3535
double alpha,
36-
SourceNodes sourceNodes
36+
InputNodes sourceNodes
3737
) {
38-
if (sourceNodes == SourceNodes.EMPTY_SOURCE_NODES) {
38+
if (sourceNodes == InputNodes.EMPTY_INPUT_NODES) {
3939
return new GlobalRestartProbability(alpha);
40-
} else if (sourceNodes instanceof ListSourceNodes) {
41-
var newSourceNodes = sourceNodes.sourceNodes()
40+
} else if (sourceNodes instanceof ListInputNodes) {
41+
var newSourceNodes = sourceNodes.inputNodes()
4242
.stream()
4343
.mapToLong(toMappedId::applyAsLong)
4444
.boxed()
4545
.toList();
4646
return new SourceBasedRestartProbabilityList(alpha, newSourceNodes);
47-
} else if (sourceNodes instanceof MapSourceNodes) {
47+
} else if (sourceNodes instanceof MapInputNodes) {
4848
var newMap = new HashMap<Long, Double>();
49-
for (var entry : ((MapSourceNodes) sourceNodes).map().entrySet()) {
49+
for (var entry : ((MapInputNodes) sourceNodes).map().entrySet()) {
5050
var newKey = toMappedId.applyAsLong(entry.getKey());
5151
newMap.put(newKey, entry.getValue());
5252
}

algo/src/test/java/org/neo4j/gds/pagerank/InitialProbabilityFactoryTest.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
package org.neo4j.gds.pagerank;
2121

2222
import org.junit.jupiter.api.Test;
23-
import org.neo4j.gds.config.SourceNodes;
24-
import org.neo4j.gds.config.SourceNodesFactory;
23+
import org.neo4j.gds.config.InputNodes;
24+
import org.neo4j.gds.config.InputNodesFactory;
2525

2626
import java.util.List;
2727

@@ -32,23 +32,23 @@ class InitialProbabilityFactoryTest {
3232
@Test
3333
void testInitialProbabilityEmpty() {
3434
var alpha = 0.15;
35-
var sourceNodes = SourceNodes.EMPTY_SOURCE_NODES;
35+
var sourceNodes = InputNodes.EMPTY_INPUT_NODES;
3636
InitialProbabilityProvider initialProbabilityProvider = InitialProbabilityFactory.create((x) -> (2*x), alpha, sourceNodes);
3737
assertThat(initialProbabilityProvider).isInstanceOf(GlobalRestartProbability.class);
3838
}
3939

4040
@Test
4141
void testInitialProbabilityList() {
4242
var alpha = 0.15;
43-
var sourceNodesList = SourceNodesFactory.parse(List.of(0L,2L,10L));
43+
var sourceNodesList = InputNodesFactory.parse(List.of(0L,2L,10L),"FOO");
4444
InitialProbabilityProvider initialProbabilityProvider = InitialProbabilityFactory.create((x) -> (2*x), alpha, sourceNodesList);
4545
assertThat(initialProbabilityProvider).isInstanceOf(SourceBasedRestartProbabilityList.class);
4646
}
4747

4848
@Test
4949
void testInitialProbabilityListOfLists() {
5050
var alpha = 0.15;
51-
var sourceNodesMap = SourceNodesFactory.parse(List.of(List.of(0L, 1D), List.of(5L, 0.1D)));
51+
var sourceNodesMap = InputNodesFactory.parse(List.of(List.of(0L, 1D), List.of(5L, 0.1D)),"FOO");
5252
InitialProbabilityProvider initialProbabilityProvider = InitialProbabilityFactory.create((x) -> (2*x), alpha, sourceNodesMap);
5353
assertThat(initialProbabilityProvider).isInstanceOf(SourceBasedRestartProbability.class);
5454
}

applications/algorithms/centrality/src/main/java/org/neo4j/gds/applications/algorithms/centrality/CentralityAlgorithms.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -266,8 +266,8 @@ private EigenvectorComputation<EigenvectorConfig> eigenvectorComputation(
266266
Graph graph,
267267
EigenvectorConfig configuration
268268
) {
269-
var mappedSourceNodes = new LongScatterSet(configuration.sourceNodes().sourceNodes().size());
270-
configuration.sourceNodes().sourceNodes().stream()
269+
var mappedSourceNodes = new LongScatterSet(configuration.sourceNodes().inputNodes().size());
270+
configuration.sourceNodes().inputNodes().stream()
271271
.mapToLong(graph::toMappedNodeId)
272272
.forEach(mappedSourceNodes::add);
273273

config-api/src/main/java/org/neo4j/gds/config/SourceNodes.java renamed to config-api/src/main/java/org/neo4j/gds/config/InputNodes.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,8 @@
2222
import java.util.Collection;
2323
import java.util.List;
2424

25-
public interface SourceNodes {
26-
SourceNodes EMPTY_SOURCE_NODES = new ListSourceNodes(List.of());
27-
String SOURCE_NODES_KEY = "sourceNodes";
25+
public interface InputNodes {
26+
InputNodes EMPTY_INPUT_NODES = new ListInputNodes(List.of());
2827

29-
Collection<Long> sourceNodes();
28+
Collection<Long> inputNodes();
3029
}
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.config;
21+
22+
import org.neo4j.gds.utils.StringFormatting;
23+
24+
import java.util.ArrayList;
25+
import java.util.List;
26+
27+
import static org.neo4j.gds.config.ConfigNodesValidations.nodesNotNegative;
28+
29+
public final class InputNodesFactory {
30+
31+
private InputNodesFactory() {}
32+
33+
public static InputNodes parse(Object object, String parameterName){
34+
if ( object instanceof InputNodes) {
35+
return (InputNodes) object;
36+
}
37+
38+
else if (object instanceof List<?> list && !list.isEmpty() && list.get(0) instanceof List){
39+
var mapsSourceNodes = new MapInputNodes(NodeIdParser.parseToMapOfNodeIdsWithProperties(object, parameterName));
40+
nodesNotNegative(mapsSourceNodes.inputNodes(), parameterName);
41+
return mapsSourceNodes;
42+
}
43+
var listSourceNodes = new ListInputNodes(NodeIdParser.parseToListOfNodeIds(object, parameterName));
44+
nodesNotNegative(listSourceNodes.inputNodes(), parameterName);
45+
return listSourceNodes;
46+
}
47+
48+
public static InputNodes parseAsList(Object object, String parameterName) {
49+
if ( object instanceof InputNodes) {
50+
return (InputNodes) object;
51+
}
52+
var listSourceNodes = new ListInputNodes(NodeIdParser.parseToListOfNodeIds(object, parameterName));
53+
nodesNotNegative(listSourceNodes.inputNodes(), parameterName);
54+
return listSourceNodes;
55+
}
56+
57+
public static List toMapOutput(InputNodes inputNodes, String parameterName) {
58+
if ( inputNodes instanceof ListInputNodes) {
59+
return new ArrayList<>(inputNodes.inputNodes());
60+
} else if ( inputNodes instanceof MapInputNodes) {
61+
return ((MapInputNodes) inputNodes).map()
62+
.entrySet()
63+
.stream()
64+
.map(entry -> List.of(entry.getKey(), entry.getValue()))
65+
.toList();
66+
}else{
67+
throw new RuntimeException(StringFormatting.formatWithLocale("Not valid %s",parameterName));
68+
}
69+
}
70+
}

config-api/src/main/java/org/neo4j/gds/config/ListSourceNodes.java renamed to config-api/src/main/java/org/neo4j/gds/config/ListInputNodes.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,15 @@
2222
import java.util.Collection;
2323
import java.util.List;
2424

25-
public class ListSourceNodes implements SourceNodes {
25+
public class ListInputNodes implements InputNodes {
2626
private final List<Long> sourceNodes;
2727

28-
public ListSourceNodes(List<Long> sourceNodes) {
28+
public ListInputNodes(List<Long> sourceNodes) {
2929
this.sourceNodes = sourceNodes;
3030
}
3131

3232
@Override
33-
public Collection<Long> sourceNodes() {
33+
public Collection<Long> inputNodes() {
3434
return sourceNodes;
3535
}
3636

config-api/src/main/java/org/neo4j/gds/config/MapSourceNodes.java renamed to config-api/src/main/java/org/neo4j/gds/config/MapInputNodes.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,16 @@
2222
import java.util.Collection;
2323
import java.util.Map;
2424

25-
public class MapSourceNodes implements SourceNodes {
25+
public class MapInputNodes implements InputNodes {
2626
private final Map<Long,Double> sourceNodes;
2727

28-
public MapSourceNodes(Map<Long,Double> sourceNodes) {
28+
public MapInputNodes(Map<Long,Double> sourceNodes) {
2929
this.sourceNodes = sourceNodes;
3030

3131
}
3232

3333
@Override
34-
public Collection<Long> sourceNodes() {
34+
public Collection<Long> inputNodes() {
3535
return sourceNodes.keySet();
3636
}
3737

config-api/src/main/java/org/neo4j/gds/config/SourceNodesFactory.java

Lines changed: 0 additions & 68 deletions
This file was deleted.

config-api/src/main/java/org/neo4j/gds/config/SourceNodesWithPropertiesConfig.java

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,17 +25,18 @@
2525
import org.neo4j.gds.api.GraphStore;
2626

2727
import java.util.Collection;
28+
import java.util.List;
2829

2930
import static org.neo4j.gds.config.ConfigNodesValidations.nodesExistInGraph;
3031

3132
public interface SourceNodesWithPropertiesConfig {
3233

3334
String SOURCE_NODES_KEY = "sourceNodes";
3435

35-
@Configuration.ConvertWith(method = "org.neo4j.gds.config.SourceNodesFactory#parse")
36-
@Configuration.ToMapValue("org.neo4j.gds.config.SourceNodesFactory#toMapOutput")
37-
default SourceNodes sourceNodes() {
38-
return SourceNodes.EMPTY_SOURCE_NODES;
36+
@Configuration.ConvertWith(method = "org.neo4j.gds.config.SourceNodesWithPropertiesConfig.SourceNodesFactory#parse")
37+
@Configuration.ToMapValue("org.neo4j.gds.config.SourceNodesWithPropertiesConfig.SourceNodesFactory#toMapOutput")
38+
default InputNodes sourceNodes() {
39+
return InputNodes.EMPTY_INPUT_NODES;
3940
}
4041

4142
@Configuration.GraphStoreValidationCheck
@@ -44,7 +45,23 @@ default void validateSourceLabels(
4445
Collection<NodeLabel> selectedLabels,
4546
Collection<RelationshipType> selectedRelationshipTypes
4647
) {
47-
nodesExistInGraph(graphStore, selectedLabels, sourceNodes().sourceNodes(), SOURCE_NODES_KEY);
48+
nodesExistInGraph(graphStore, selectedLabels, sourceNodes().inputNodes(), SOURCE_NODES_KEY);
4849
}
4950

51+
final class SourceNodesFactory {
52+
53+
private SourceNodesFactory() {}
54+
55+
public static InputNodes parse(Object object){
56+
return InputNodesFactory.parse(object,SOURCE_NODES_KEY);
57+
}
58+
59+
public static InputNodes parseAsList(Object object) {
60+
return InputNodesFactory.parseAsList(object,SOURCE_NODES_KEY);
61+
}
62+
63+
public static List toMapOutput(InputNodes sourceNodes) {
64+
return InputNodesFactory.toMapOutput(sourceNodes,SOURCE_NODES_KEY);
65+
}
66+
}
5067
}
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.config;
21+
22+
import org.neo4j.gds.NodeLabel;
23+
import org.neo4j.gds.RelationshipType;
24+
import org.neo4j.gds.annotation.Configuration;
25+
import org.neo4j.gds.api.GraphStore;
26+
27+
import java.util.Collection;
28+
import java.util.List;
29+
30+
import static org.neo4j.gds.config.ConfigNodesValidations.nodesExistInGraph;
31+
32+
public interface TargetNodesWithPropertiesConfig {
33+
34+
String TARGET_NODES_KEY = "targetNodes";
35+
36+
@Configuration.ConvertWith(method = "org.neo4j.gds.config.TargetNodesWithPropertiesConfig.TargetNodesFactory#parse")
37+
@Configuration.ToMapValue("org.neo4j.gds.config.TargetNodesWithPropertiesConfig.TargetNodesFactory#toMapOutput")
38+
default InputNodes targetNodes() {
39+
return InputNodes.EMPTY_INPUT_NODES;
40+
}
41+
42+
@Configuration.GraphStoreValidationCheck
43+
default void validateSourceLabels(
44+
GraphStore graphStore,
45+
Collection<NodeLabel> selectedLabels,
46+
Collection<RelationshipType> selectedRelationshipTypes
47+
) {
48+
nodesExistInGraph(graphStore, selectedLabels, targetNodes().inputNodes(), TARGET_NODES_KEY);
49+
}
50+
51+
final class TargetNodesFactory {
52+
53+
private TargetNodesFactory() {}
54+
55+
public static InputNodes parse(Object object){
56+
return InputNodesFactory.parse(object,TARGET_NODES_KEY);
57+
}
58+
59+
public static InputNodes parseAsList(Object object) {
60+
return InputNodesFactory.parseAsList(object,TARGET_NODES_KEY);
61+
}
62+
63+
public static List toMapOutput(InputNodes sourceNodes) {
64+
return InputNodesFactory.toMapOutput(sourceNodes,TARGET_NODES_KEY);
65+
}
66+
}
67+
}

0 commit comments

Comments
 (0)