1+ /*
2+ * Copyright (c) "Neo4j"
3+ * Neo4j Sweden AB [http://neo4j.com]
4+ *
5+ * This file is part of Neo4j.
6+ *
7+ * Neo4j is free software: you can redistribute it and/or modify
8+ * it under the terms of the GNU General Public License as published by
9+ * the Free Software Foundation, either version 3 of the License, or
10+ * (at your option) any later version.
11+ *
12+ * This program is distributed in the hope that it will be useful,
13+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
14+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+ * GNU General Public License for more details.
16+ *
17+ * You should have received a copy of the GNU General Public License
18+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
19+ */
20+ package org .neo4j .gds .betweenness ;
21+
22+ import com .carrotsearch .hppc .BitSet ;
23+ import com .carrotsearch .hppc .BitSetIterator ;
24+ import org .neo4j .gds .api .Graph ;
25+ import org .neo4j .gds .core .concurrency .RunWithConcurrency ;
26+ import org .neo4j .gds .core .utils .partition .Partition ;
27+ import org .neo4j .gds .core .utils .partition .PartitionUtils ;
28+
29+ import java .util .Collection ;
30+ import java .util .Optional ;
31+ import java .util .SplittableRandom ;
32+ import java .util .concurrent .ExecutorService ;
33+ import java .util .concurrent .atomic .AtomicInteger ;
34+ import java .util .concurrent .atomic .AtomicLong ;
35+ import java .util .stream .Collectors ;
36+
37+ import static com .carrotsearch .hppc .BitSetIterator .NO_MORE ;
38+
39+ public class RandomDegreeSelectionStrategy implements SelectionStrategy {
40+
41+ private final long samplingSize ;
42+ private final Optional <Long > maybeRandomSeed ;
43+ private final AtomicLong nodeQueue = new AtomicLong ();
44+
45+ private long graphSize ;
46+ private BitSet sampleSet ;
47+
48+ public RandomDegreeSelectionStrategy (long samplingSize ) {
49+ this (samplingSize , Optional .empty ());
50+ }
51+
52+ public RandomDegreeSelectionStrategy (long samplingSize , Optional <Long > maybeRandomSeed ) {
53+ this .samplingSize = samplingSize ;
54+ this .maybeRandomSeed = maybeRandomSeed ;
55+ }
56+
57+ @ Override
58+ public void init (Graph graph , ExecutorService executorService , int concurrency ) {
59+ assert samplingSize <= graph .nodeCount ();
60+ this .sampleSet = new BitSet (graph .nodeCount ());
61+ this .graphSize = graph .nodeCount ();
62+ nodeQueue .set (0 );
63+ var partitions = PartitionUtils .numberAlignedPartitioning (concurrency , graph .nodeCount (), Long .SIZE );
64+ var maxDegree = maxDegree (graph , partitions , executorService , concurrency );
65+ selectNodes (graph , partitions , maxDegree , executorService , concurrency );
66+ }
67+
68+ @ Override
69+ public long next () {
70+ long nextNodeId ;
71+ while ((nextNodeId = nodeQueue .getAndIncrement ()) < graphSize ) {
72+ if (sampleSet .get (nextNodeId )) {
73+ return nextNodeId ;
74+ }
75+ }
76+ return NONE_SELECTED ;
77+ }
78+
79+ private static int maxDegree (
80+ Graph graph ,
81+ Collection <Partition > partitions ,
82+ ExecutorService executorService ,
83+ int concurrency
84+ ) {
85+ AtomicInteger maxDegree = new AtomicInteger (0 );
86+
87+ var tasks = partitions .stream ()
88+ .map (partition -> (Runnable ) () -> partition .consume (nodeId -> {
89+ int degree = graph .degree (nodeId );
90+ int current = maxDegree .get ();
91+ while (degree > current ) {
92+ int newCurrent = maxDegree .compareAndExchange (current , degree );
93+ if (newCurrent == current ) {
94+ break ;
95+ }
96+ current = newCurrent ;
97+ }
98+ })).collect (Collectors .toList ());
99+
100+ RunWithConcurrency .builder ()
101+ .concurrency (concurrency )
102+ .tasks (tasks )
103+ .executor (executorService )
104+ .run ();
105+
106+ return maxDegree .get ();
107+ }
108+
109+ private void selectNodes (
110+ Graph graph ,
111+ Collection <Partition > partitions ,
112+ int maxDegree ,
113+ ExecutorService executorService ,
114+ int concurrency
115+ ) {
116+ var random = maybeRandomSeed .map (SplittableRandom ::new ).orElseGet (SplittableRandom ::new );
117+ var selectionSize = new AtomicLong (0 );
118+ var tasks = partitions .stream ()
119+ .map (partition -> (Runnable ) () -> {
120+ var threadLocalRandom = random .split ();
121+ var fromNode = partition .startNode ();
122+ var toNode = partition .startNode () + partition .nodeCount ();
123+
124+ for (long nodeId = fromNode ; nodeId < toNode ; nodeId ++) {
125+ var currentSelectionSize = selectionSize .get ();
126+ if (currentSelectionSize >= samplingSize ) {
127+ break ;
128+ }
129+ int nodeDegree = graph .degree (nodeId );
130+ // probability factor is in range [1, maxDegree] (inclusive both ends)
131+ // the probability of a node being selected is probabilityFactor * (1 / maxDegree)
132+ int probabilityFactor = threadLocalRandom .nextInt (maxDegree ) + 1 ;
133+ if (probabilityFactor <= nodeDegree ) {
134+ while (true ) {
135+ long actualCurrentSelectionSize = selectionSize .compareAndExchange (
136+ currentSelectionSize ,
137+ currentSelectionSize + 1
138+ );
139+ if (currentSelectionSize == actualCurrentSelectionSize ) {
140+ sampleSet .set (nodeId );
141+ break ;
142+ }
143+ if (actualCurrentSelectionSize >= samplingSize ) {
144+ break ;
145+ }
146+ currentSelectionSize = actualCurrentSelectionSize ;
147+ }
148+ }
149+ }
150+ }).collect (Collectors .toList ());
151+
152+ RunWithConcurrency .builder ()
153+ .concurrency (concurrency )
154+ .tasks (tasks )
155+ .executor (executorService )
156+ .run ();
157+
158+ long actualSelectedNodes = selectionSize .get ();
159+
160+ if (actualSelectedNodes < samplingSize ) {
161+ // Flip bitset to be able to iterate unset bits.
162+ // The upper range is Graph#nodeCount() since
163+ // BitSet#size() returns a multiple of 64.
164+ // We need to make sure to stay within bounds.
165+ sampleSet .flip (0 , graph .nodeCount ());
166+ // Potentially iterate the bitset multiple times
167+ // until we have exactly numSeedNodes nodes.
168+ BitSetIterator iterator ;
169+ while (actualSelectedNodes < samplingSize ) {
170+ iterator = sampleSet .iterator ();
171+ var unselectedNode = iterator .nextSetBit ();
172+ while (unselectedNode != NO_MORE && actualSelectedNodes < samplingSize ) {
173+ if (random .nextDouble () >= 0.5 ) {
174+ sampleSet .flip (unselectedNode );
175+ actualSelectedNodes ++;
176+ }
177+ unselectedNode = iterator .nextSetBit ();
178+ }
179+ }
180+ sampleSet .flip (0 , graph .nodeCount ());
181+ }
182+ }
183+ }
0 commit comments