1919 */
2020package org .neo4j .gds .embeddings .hashgnn ;
2121
22- import com . carrotsearch . hppc . BitSet ;
22+ import org . apache . commons . lang3 . mutable . MutableLong ;
2323import org .neo4j .gds .api .Graph ;
2424import org .neo4j .gds .core .concurrency .RunWithConcurrency ;
2525import org .neo4j .gds .core .utils .TerminationFlag ;
3030import org .neo4j .gds .ml .core .features .FeatureConsumer ;
3131import org .neo4j .gds .ml .core .features .FeatureExtraction ;
3232import org .neo4j .gds .ml .core .features .FeatureExtractor ;
33- import org .neo4j .gds .ml .util .ShuffleUtil ;
3433
35- import java .util .ArrayList ;
36- import java .util .Arrays ;
3734import java .util .List ;
3835import java .util .SplittableRandom ;
3936import java .util .stream .Collectors ;
4037
41- import static org .neo4j .gds .embeddings . hashgnn . HashGNNCompanion . hashArgMin ;
38+ import static org .neo4j .gds .utils . StringFormatting . formatWithLocale ;
4239
4340class BinarizeTask implements Runnable {
4441 private final Partition partition ;
45- private final HashGNNConfig config ;
4642 private final HugeObjectArray <HugeAtomicBitSet > truncatedFeatures ;
4743 private final List <FeatureExtractor > featureExtractors ;
48- private final int [][] propertyEmbeddings ;
49- private final List < int []> hashesList ;
50- private final HashGNN . MinAndArgmin minAndArgMin ;
51- private final FeatureBinarizationConfig binarizationConfig ;
44+ private final double [][] propertyEmbeddings ;
45+
46+ private final double threshold ;
47+ private final int dimension ;
5248 private final ProgressTracker progressTracker ;
49+ private long totalFeatureCount ;
50+
51+ private double scalarProductSum ;
52+
53+ private double scalarProductSumOfSquares ;
5354
/**
 * Creates a binarization task for the nodes of a single partition.
 *
 * @param partition          the partition whose nodes this task processes
 * @param config             supplies the output dimension and the binarization threshold
 * @param truncatedFeatures  shared output array; this task stores one bit set per processed node
 * @param featureExtractors  extractors used to read the node property features
 * @param propertyEmbeddings one projection vector per input feature
 * @param progressTracker    receives progress updates once the partition has been processed
 */
BinarizeTask(
    Partition partition,
    BinarizeFeaturesConfig config,
    HugeObjectArray<HugeAtomicBitSet> truncatedFeatures,
    List<FeatureExtractor> featureExtractors,
    double[][] propertyEmbeddings,
    ProgressTracker progressTracker
) {
    // values read from the binarization config
    this.dimension = config.dimension();
    this.threshold = config.threshold();

    // collaborators and shared state handed in by the caller
    this.partition = partition;
    this.featureExtractors = featureExtractors;
    this.propertyEmbeddings = propertyEmbeddings;
    this.truncatedFeatures = truncatedFeatures;
    this.progressTracker = progressTracker;
}
7371
@@ -77,36 +75,30 @@ static HugeObjectArray<HugeAtomicBitSet> compute(
7775 HashGNNConfig config ,
7876 SplittableRandom rng ,
7977 ProgressTracker progressTracker ,
80- TerminationFlag terminationFlag
78+ TerminationFlag terminationFlag ,
79+ MutableLong totalFeatureCountOutput
8180 ) {
8281 progressTracker .beginSubTask ("Binarize node property features" );
8382
84- var hashesList = new ArrayList <int []>(config .embeddingDensity ());
85- for (int i = 0 ; i < config .embeddingDensity (); i ++) {
86- hashesList .add (HashGNNCompanion .HashTriple .computeHashesFromTriple (
87- config .binarizeFeatures ().get ().dimension (),
88- HashGNNCompanion .HashTriple .generate (rng )
89- ));
90- }
83+ var binarizationConfig = config .binarizeFeatures ().orElseThrow ();
9184
9285 var featureExtractors = FeatureExtraction .propertyExtractors (
9386 graph ,
9487 config .featureProperties ()
9588 );
9689
9790 var inputDimension = FeatureExtraction .featureCount (featureExtractors );
98- var propertyEmbeddings = embedProperties (config , rng , inputDimension );
91+ var propertyEmbeddings = embedProperties (binarizationConfig . dimension () , rng , inputDimension );
9992
10093 var truncatedFeatures = HugeObjectArray .newArray (HugeAtomicBitSet .class , graph .nodeCount ());
10194
10295 var tasks = partition .stream ()
10396 .map (p -> new BinarizeTask (
10497 p ,
105- config ,
98+ binarizationConfig ,
10699 truncatedFeatures ,
107100 featureExtractors ,
108101 propertyEmbeddings ,
109- hashesList ,
110102 progressTracker
111103 ))
112104 .collect (Collectors .toList ());
@@ -116,90 +108,105 @@ static HugeObjectArray<HugeAtomicBitSet> compute(
116108 .terminationFlag (terminationFlag )
117109 .run ();
118110
111+ totalFeatureCountOutput .add (tasks .stream ().mapToLong (BinarizeTask ::totalFeatureCount ).sum ());
112+
113+ var squaredSum = tasks .stream ().mapToDouble (BinarizeTask ::scalarProductSumOfSquares ).sum ();
114+ var sum = tasks .stream ().mapToDouble (BinarizeTask ::scalarProductSum ).sum ();
115+ long exampleCount = graph .nodeCount () * binarizationConfig .dimension ();
116+ var avg = sum / exampleCount ;
117+
118+ var variance = (squaredSum - exampleCount * avg * avg ) / exampleCount ;
119+ var std = Math .sqrt (variance );
120+
121+ progressTracker .logInfo (formatWithLocale (
122+ "Hyperplane scalar products have mean %.4f and standard deviation %.4f. A threshold for binarization may be set to the mean plus a few standard deviations." ,
123+ avg ,
124+ std
125+ ));
126+
119127 progressTracker .endSubTask ("Binarize node property features" );
120128
121129 return truncatedFeatures ;
122130 }
123131
124- // creates a sparse projection array with one row per input feature
132+ // creates a random projection vector for each feature
125133 // (input features vector for each node is the concatenation of the node's properties)
126- // the first half of each row contains indices of positive output features in the projected space
127- // the second half of each row contains indices of negative output features in the projected space
128134 // this array is used embed the properties themselves from inputDimension to embeddingDimension dimensions.
129- public static int [][] embedProperties (HashGNNConfig config , SplittableRandom rng , int inputDimension ) {
130- var binarizationConfig = config .binarizeFeatures ().orElseThrow ();
131- var permutation = new int [binarizationConfig .dimension ()];
132- Arrays .setAll (permutation , i -> i );
133-
134- var propertyEmbeddings = new int [inputDimension ][];
135+ public static double [][] embedProperties (int vectorDimension , SplittableRandom rng , int inputDimension ) {
136+ var propertyEmbeddings = new double [inputDimension ][];
135137
136138 for (int inputFeature = 0 ; inputFeature < inputDimension ; inputFeature ++) {
137- ShuffleUtil .shuffleArray (permutation , rng );
138- propertyEmbeddings [inputFeature ] = new int [2 * binarizationConfig .densityLevel ()];
139- for (int feature = 0 ; feature < 2 * binarizationConfig .densityLevel (); feature ++) {
140- propertyEmbeddings [inputFeature ][feature ] = permutation [feature ];
139+ propertyEmbeddings [inputFeature ] = new double [vectorDimension ];
140+ for (int feature = 0 ; feature < vectorDimension ; feature ++) {
141+ propertyEmbeddings [inputFeature ][feature ] = boxMullerGaussianRandom (rng );
141142 }
142143 }
143144 return propertyEmbeddings ;
144145 }
145146
147+ private static double boxMullerGaussianRandom (SplittableRandom rng ) {
148+ return Math .sqrt (-2 * Math .log (rng .nextDouble (
149+ 0.0 ,
150+ 1.0
151+ ))) * Math .cos (2 * Math .PI * rng .nextDouble (0.0 , 1.0 ));
152+ }
153+
146154 @ Override
147155 public void run () {
148- var tempFeatureContainer = new BitSet (binarizationConfig .dimension ());
149-
150156 partition .consume (nodeId -> {
151- var featureVector = new float [binarizationConfig . dimension () ];
157+ var featureVector = new float [dimension ];
152158 FeatureExtraction .extract (nodeId , -1 , featureExtractors , new FeatureConsumer () {
153159 @ Override
154160 public void acceptScalar (long nodeOffset , int offset , double value ) {
155- for (int feature = 0 ; feature < binarizationConfig .densityLevel (); feature ++) {
156- int positiveFeature = propertyEmbeddings [offset ][feature ];
157- featureVector [positiveFeature ] += value ;
158- }
159-
160- for (int feature = binarizationConfig .densityLevel (); feature < 2 * binarizationConfig .densityLevel (); feature ++) {
161- int negativeFeature = propertyEmbeddings [offset ][feature ];
162- featureVector [negativeFeature ] -= value ;
163-
161+ for (int feature = 0 ; feature < dimension ; feature ++) {
162+ double featureValue = propertyEmbeddings [offset ][feature ];
163+ featureVector [feature ] += value * featureValue ;
164164 }
165165 }
166166
167167 @ Override
168168 public void acceptArray (long nodeOffset , int offset , double [] values ) {
169169 for (int inputFeatureOffset = 0 ; inputFeatureOffset < values .length ; inputFeatureOffset ++) {
170- for (int feature = 0 ; feature < binarizationConfig .densityLevel (); feature ++) {
171- int positiveFeature = propertyEmbeddings [offset + inputFeatureOffset ][feature ];
172- featureVector [positiveFeature ] += values [inputFeatureOffset ];
173- }
174- for (int feature = binarizationConfig .densityLevel (); feature < 2 * binarizationConfig .densityLevel (); feature ++) {
175- int negativeFeature = propertyEmbeddings [offset + inputFeatureOffset ][feature ];
176- featureVector [negativeFeature ] -= values [inputFeatureOffset ];
170+ double value = values [inputFeatureOffset ];
171+ for (int feature = 0 ; feature < dimension ; feature ++) {
172+ double featureValue = propertyEmbeddings [offset + inputFeatureOffset ][feature ];
173+ featureVector [feature ] += value * featureValue ;
177174 }
178175 }
179176 }
180177 });
181178
182- truncatedFeatures .set (nodeId , roundAndSample (tempFeatureContainer , featureVector ));
179+ var featureSet = round (featureVector );
180+ totalFeatureCount += featureSet .cardinality ();
181+ truncatedFeatures .set (nodeId , featureSet );
183182 });
184183
185184 progressTracker .logProgress (partition .nodeCount ());
186185 }
187186
188- private HugeAtomicBitSet roundAndSample ( BitSet tempBitSet , float [] floatVector ) {
189- tempBitSet . clear ( );
187+ private HugeAtomicBitSet round ( float [] floatVector ) {
188+ var bitset = HugeAtomicBitSet . create ( floatVector . length );
190189 for (int feature = 0 ; feature < floatVector .length ; feature ++) {
191- if (floatVector [feature ] > 0 ) {
192- tempBitSet .set (feature );
193- }
194- }
195- var sampledBitset = HugeAtomicBitSet .create (binarizationConfig .dimension ());
196- for (int i = 0 ; i < config .embeddingDensity (); i ++) {
197- hashArgMin (tempBitSet , hashesList .get (i ), minAndArgMin );
198- if (minAndArgMin .argMin != -1 ) {
199- sampledBitset .set (minAndArgMin .argMin );
190+ var scalarProduct = floatVector [feature ];
191+ scalarProductSum += scalarProduct ;
192+ scalarProductSumOfSquares += scalarProduct * scalarProduct ;
193+ if (scalarProduct > threshold ) {
194+ bitset .set (feature );
200195 }
201196 }
202- return sampledBitset ;
197+ return bitset ;
198+ }
199+
200+ public long totalFeatureCount () {
201+ return totalFeatureCount ;
202+ }
203+
204+ public double scalarProductSum () {
205+ return scalarProductSum ;
206+ }
207+
208+ public double scalarProductSumOfSquares () {
209+ return scalarProductSumOfSquares ;
203210 }
204211
205212}
0 commit comments