2121
2222import org .apache .commons .lang3 .mutable .MutableInt ;
2323import org .neo4j .gds .api .Graph ;
24+ import org .neo4j .gds .beta .pregel .BasePregelComputation ;
2425import org .neo4j .gds .beta .pregel .Messenger ;
2526import org .neo4j .gds .beta .pregel .NodeValue ;
2627import org .neo4j .gds .beta .pregel .PregelConfig ;
3637 */
3738public class ComputeContext <CONFIG extends PregelConfig > extends NodeCentricContext <CONFIG > {
3839
39- final RelationshipWeightApplier relationshipWeightApplier ;
4040 private final HugeAtomicBitSet voteBits ;
4141
4242 private final Messenger <?> messenger ;
4343 private final MutableInt iteration ;
4444 private final AtomicBoolean hasSendMessage ;
45+ protected BasePregelComputation <CONFIG > computation ;
4546
4647 public ComputeContext (Graph graph ,
4748 CONFIG config ,
48- RelationshipWeightApplier relationshipWeightApplier ,
49+ BasePregelComputation < CONFIG > computation ,
4950 NodeValue nodeValue ,
5051 Messenger <?> messenger ,
5152 HugeAtomicBitSet voteBits ,
5253 MutableInt iteration ,
5354 AtomicBoolean hasSendMessage ,
5455 ProgressTracker progressTracker ) {
5556 super (graph , config , nodeValue , progressTracker );
56- this .relationshipWeightApplier = relationshipWeightApplier ;
57+ this .computation = computation ;
5758 this .sendMessagesFunction = config .hasRelationshipWeightProperty ()
5859 ? this ::sendToNeighborsWeighted
5960 : this ::sendToNeighbors ;
@@ -163,7 +164,7 @@ private void sendToNeighbors(long sourceNodeId, double message) {
163164
164165 private void sendToNeighborsWeighted (long sourceNodeId , double message ) {
165166 graph .forEachRelationship (sourceNodeId , 1.0 , (ignored , targetNodeId , weight ) -> {
166- sendTo (targetNodeId , relationshipWeightApplier .applyRelationshipWeight (message , weight ));
167+ sendTo (targetNodeId , computation .applyRelationshipWeight (message , weight ));
167168 return true ;
168169 });
169170 }
@@ -173,19 +174,14 @@ interface SendMessagesFunction {
173174 void sendToNeighbors (long sourceNodeId , double message );
174175 }
175176
176- @ FunctionalInterface
177- public interface RelationshipWeightApplier {
178- double applyRelationshipWeight (double nodeValue , double relationshipWeight );
179- }
180-
181177 public static final class BidirectionalComputeContext <CONFIG extends PregelConfig > extends ComputeContext <CONFIG > implements BidirectionalNodeCentricContext {
182178
183179 private final SendMessagesIncomingFunction sendMessagesIncomingFunction ;
184180
185181 public BidirectionalComputeContext (
186182 Graph graph ,
187183 CONFIG config ,
188- RelationshipWeightApplier relationshipWeightApplier ,
184+ BasePregelComputation < CONFIG > computation ,
189185 NodeValue nodeValue ,
190186 Messenger <?> messenger ,
191187 HugeAtomicBitSet voteBits ,
@@ -196,7 +192,7 @@ public BidirectionalComputeContext(
196192 super (
197193 graph ,
198194 config ,
199- relationshipWeightApplier ,
195+ computation ,
200196 nodeValue ,
201197 messenger ,
202198 voteBits ,
@@ -226,7 +222,7 @@ private void sendToIncomingNeighbors(long sourceNodeId, double message) {
226222
227223 private void sendToIncomingNeighborsWeighted (long sourceNodeId , double message ) {
228224 graph .forEachInverseRelationship (sourceNodeId , 1.0 , (ignored , targetNodeId , weight ) -> {
229- sendTo (targetNodeId , relationshipWeightApplier .applyRelationshipWeight (message , weight ));
225+ sendTo (targetNodeId , computation .applyRelationshipWeight (message , weight ));
230226 return true ;
231227 });
232228 }
0 commit comments