3030import org .neo4j .gds .ml .core .features .FeatureConsumer ;
3131import org .neo4j .gds .ml .core .features .FeatureExtraction ;
3232import org .neo4j .gds .ml .core .features .FeatureExtractor ;
33- import org .neo4j .gds .ml .util .ShuffleUtil ;
3433
35- import java .util .Arrays ;
3634import java .util .List ;
3735import java .util .SplittableRandom ;
3836import java .util .stream .Collectors ;
3937
38+ import static org .neo4j .gds .utils .StringFormatting .formatWithLocale ;
39+
4040class BinarizeTask implements Runnable {
4141 private final Partition partition ;
4242 private final HugeObjectArray <HugeAtomicBitSet > truncatedFeatures ;
4343 private final List <FeatureExtractor > featureExtractors ;
44- private final int [][] propertyEmbeddings ;
45- private final FeatureBinarizationConfig binarizationConfig ;
44+ private final double [][] propertyEmbeddings ;
45+
46+ private final double threshold ;
47+ private final BinarizeFeaturesConfig binarizationConfig ;
4648 private final ProgressTracker progressTracker ;
4749 private long totalNumFeatures ;
4850
51+ private double scalarProductSum ;
52+
53+ private double scalarProductSumOfSquares ;
54+
4955 BinarizeTask (
5056 Partition partition ,
5157 HashGNNConfig config ,
5258 HugeObjectArray <HugeAtomicBitSet > truncatedFeatures ,
5359 List <FeatureExtractor > featureExtractors ,
54- int [][] propertyEmbeddings ,
60+ double [][] propertyEmbeddings ,
5561 ProgressTracker progressTracker
5662 ) {
5763 this .partition = partition ;
5864 this .binarizationConfig = config .binarizeFeatures ().orElseThrow ();
65+ this .threshold = binarizationConfig .threshold ();
5966 this .truncatedFeatures = truncatedFeatures ;
6067 this .featureExtractors = featureExtractors ;
6168 this .propertyEmbeddings = propertyEmbeddings ;
@@ -73,6 +80,8 @@ static HugeObjectArray<HugeAtomicBitSet> compute(
7380 ) {
7481 progressTracker .beginSubTask ("Binarize node property features" );
7582
83+ var binarizationConfig = config .binarizeFeatures ().orElseThrow ();
84+
7685 var featureExtractors = FeatureExtraction .propertyExtractors (
7786 graph ,
7887 config .featureProperties ()
@@ -101,28 +110,33 @@ static HugeObjectArray<HugeAtomicBitSet> compute(
101110
102111 totalNumFeaturesOutput .add (tasks .stream ().mapToLong (BinarizeTask ::totalNumFeatures ).sum ());
103112
113+ var squaredSum = tasks .stream ().mapToDouble (BinarizeTask ::scalarProductSumOfSquares ).sum ();
114+ var sum = tasks .stream ().mapToDouble (BinarizeTask ::scalarProductSum ).sum ();
115+ long exampleCount = graph .nodeCount () * binarizationConfig .dimension ();
116+ var avg = sum / exampleCount ;
117+
118+ var variance = (squaredSum - exampleCount * avg * avg ) / exampleCount ;
119+ var std = Math .sqrt (variance );
120+
121+ progressTracker .logInfo (formatWithLocale ("Hyperplane scalar products have mean %.4f and standard deviation %.4f. A threshold for binarization may be set to the average plus a few standard deviations." , avg , std ));
122+
104123 progressTracker .endSubTask ("Binarize node property features" );
105124
106125 return truncatedFeatures ;
107126 }
108-
109- // creates a sparse projection array with one row per input feature
127+ // creates a random projection vector for each feature
110128 // (input features vector for each node is the concatenation of the node's properties)
111- // the first half of each row contains indices of positive output features in the projected space
112- // the second half of each row contains indices of negative output features in the projected space
113129 // this array is used embed the properties themselves from inputDimension to embeddingDimension dimensions.
114- public static int [][] embedProperties (HashGNNConfig config , SplittableRandom rng , int inputDimension ) {
130+ public static double [][] embedProperties (HashGNNConfig config , SplittableRandom rng , int inputDimension ) {
115131 var binarizationConfig = config .binarizeFeatures ().orElseThrow ();
116- var permutation = new int [binarizationConfig .dimension ()];
117- Arrays .setAll (permutation , i -> i );
118-
119- var propertyEmbeddings = new int [inputDimension ][];
132+ var propertyEmbeddings = new double [inputDimension ][];
120133
121134 for (int inputFeature = 0 ; inputFeature < inputDimension ; inputFeature ++) {
122- ShuffleUtil .shuffleArray (permutation , rng );
123- propertyEmbeddings [inputFeature ] = new int [2 * binarizationConfig .densityLevel ()];
124- for (int feature = 0 ; feature < 2 * binarizationConfig .densityLevel (); feature ++) {
125- propertyEmbeddings [inputFeature ][feature ] = permutation [feature ];
135+ propertyEmbeddings [inputFeature ] = new double [binarizationConfig .dimension ()];
136+ for (int feature = 0 ; feature < binarizationConfig .dimension (); feature ++) {
137+ // Box-muller transformation to generate gaussian
138+ double matrixValue = Math .sqrt (-2 *Math .log (rng .nextDouble (0.0 , 1.0 ))) * Math .cos (2 *Math .PI * rng .nextDouble (0.0 , 1.0 ));
139+ propertyEmbeddings [inputFeature ][feature ] = matrixValue ;
126140 }
127141 }
128142 return propertyEmbeddings ;
@@ -135,48 +149,54 @@ public void run() {
135149 FeatureExtraction .extract (nodeId , -1 , featureExtractors , new FeatureConsumer () {
136150 @ Override
137151 public void acceptScalar (long nodeOffset , int offset , double value ) {
138- for (int feature = 0 ; feature < binarizationConfig .densityLevel (); feature ++) {
139- int positiveFeature = propertyEmbeddings [offset ][feature ];
140- featureVector [positiveFeature ] += value ;
141- }
142-
143- for (int feature = binarizationConfig .densityLevel (); feature < 2 * binarizationConfig .densityLevel (); feature ++) {
144- int negativeFeature = propertyEmbeddings [offset ][feature ];
145- featureVector [negativeFeature ] -= value ;
146-
152+ for (int feature = 0 ; feature < binarizationConfig .dimension (); feature ++) {
153+ double featureValue = propertyEmbeddings [offset ][feature ];
154+ featureVector [feature ] += value * featureValue ;
147155 }
148156 }
149157
150158 @ Override
151159 public void acceptArray (long nodeOffset , int offset , double [] values ) {
152160 for (int inputFeatureOffset = 0 ; inputFeatureOffset < values .length ; inputFeatureOffset ++) {
153- for (int feature = 0 ; feature < binarizationConfig .densityLevel (); feature ++) {
154- int positiveFeature = propertyEmbeddings [offset + inputFeatureOffset ][feature ];
155- featureVector [positiveFeature ] += values [inputFeatureOffset ];
156- }
157- for (int feature = binarizationConfig .densityLevel (); feature < 2 * binarizationConfig .densityLevel (); feature ++) {
158- int negativeFeature = propertyEmbeddings [offset + inputFeatureOffset ][feature ];
159- featureVector [negativeFeature ] -= values [inputFeatureOffset ];
161+ double value = values [inputFeatureOffset ];
162+ for (int feature = 0 ; feature < binarizationConfig .dimension (); feature ++) {
163+ double featureValue = propertyEmbeddings [offset + inputFeatureOffset ][feature ];
164+ featureVector [feature ] += value * featureValue ;
160165 }
161166 }
162167 }
163168 });
164169
165- var bitSet = HugeAtomicBitSet .create (binarizationConfig .dimension ());
166- for (int feature = 0 ; feature < featureVector .length ; feature ++) {
167- if (featureVector [feature ] > 0 ) {
168- bitSet .set (feature );
169- }
170- }
171- totalNumFeatures += bitSet .cardinality ();
172- truncatedFeatures .set (nodeId , bitSet );
170+ var featureSet = round (featureVector );
171+ totalNumFeatures += featureSet .cardinality ();
172+ truncatedFeatures .set (nodeId , featureSet );
173173 });
174174
175175 progressTracker .logProgress (partition .nodeCount ());
176176 }
177177
178+ private HugeAtomicBitSet round (float [] floatVector ) {
179+ var bitset = HugeAtomicBitSet .create (floatVector .length );
180+ for (int feature = 0 ; feature < floatVector .length ; feature ++) {
181+ var scalarProduct = floatVector [feature ];
182+ scalarProductSum += scalarProduct ;
183+ scalarProductSumOfSquares += scalarProduct * scalarProduct ;
184+ if (scalarProduct > threshold ) {
185+ bitset .set (feature );
186+ }
187+ }
188+ return bitset ;
189+ }
190+
178191 public long totalNumFeatures () {
179192 return totalNumFeatures ;
180193 }
181194
195+ public double scalarProductSum () {
196+ return scalarProductSum ;
197+ }
198+ public double scalarProductSumOfSquares () {
199+ return scalarProductSumOfSquares ;
200+ }
201+
182202}
0 commit comments