@@ -729,5 +729,130 @@ def _unpack_neighbor_features(features):
729729 self .assertAllEqual (nbr_weights , expected_neighbor_weights )
730730
731731
732+ class StripNeighborFeaturesTest (tf .test .TestCase ):
733+ """Tests removal of neighbor features from a feature dictionary."""
734+
735+ def testEmptyFeatures (self ):
736+ """Tests strip_neighbor_features with empty input."""
737+ features = dict ()
738+ neighbor_config = configs .GraphNeighborConfig ()
739+ sample_features = utils .strip_neighbor_features (features , neighbor_config )
740+
741+ # We create a dummy tensor so that the computation graph is not empty.
742+ dummy_tensor = tf .constant (1.0 )
743+ sample_features , dummy_tensor = self .evaluate (
744+ [sample_features , dummy_tensor ])
745+ self .assertEmpty (sample_features )
746+
747+ def testNoNeighborFeatures (self ):
748+ """Tests strip_neighbor_features when there are no neighbor features."""
749+ features = {'F0' : tf .constant (11.0 , shape = [2 , 2 ])}
750+ neighbor_config = configs .GraphNeighborConfig ()
751+ sample_features = utils .strip_neighbor_features (features , neighbor_config )
752+
753+ expected_sample_features = {'F0' : tf .constant (11.0 , shape = [2 , 2 ])}
754+
755+ sample_features = self .evaluate (sample_features )
756+
757+ # Check that only the sample features are retained.
758+ feature_keys = sorted (sample_features .keys ())
759+ self .assertListEqual (feature_keys , ['F0' ])
760+
761+ # Check that the values of the sample feature remains unchanged.
762+ self .assertAllEqual (sample_features ['F0' ], expected_sample_features ['F0' ])
763+
764+ def testBatchedFeatures (self ):
765+ """Tests strip_neighbor_features with batched input features."""
766+ features = {
767+ 'F0' :
768+ tf .constant (11.0 , shape = [2 , 2 ]),
769+ 'F1' :
770+ tf .SparseTensor (
771+ indices = [[0 , 0 ], [0 , 1 ]], values = [1.0 , 2.0 ], dense_shape = [2 ,
772+ 4 ]),
773+ 'NL_nbr_0_F0' :
774+ tf .constant (22.0 , shape = [2 , 2 ]),
775+ 'NL_nbr_0_F1' :
776+ tf .SparseTensor (
777+ indices = [[1 , 0 ], [1 , 1 ]], values = [3.0 , 4.0 ], dense_shape = [2 ,
778+ 4 ]),
779+ 'NL_nbr_0_weight' :
780+ tf .constant (0.25 , shape = [2 , 1 ]),
781+ }
782+ neighbor_config = configs .GraphNeighborConfig ()
783+ sample_features = utils .strip_neighbor_features (features , neighbor_config )
784+
785+ expected_sample_features = {
786+ 'F0' :
787+ tf .constant (11.0 , shape = [2 , 2 ]),
788+ 'F1' :
789+ tf .SparseTensor (
790+ indices = [[0 , 0 ], [0 , 1 ]], values = [1.0 , 2.0 ], dense_shape = [2 ,
791+ 4 ]),
792+ }
793+
794+ sample_features = self .evaluate (sample_features )
795+
796+ # Check that only the sample features are retained.
797+ feature_keys = sorted (sample_features .keys ())
798+ self .assertListEqual (feature_keys , ['F0' , 'F1' ])
799+
800+ # Check that the values of the sample features remain unchanged.
801+ self .assertAllEqual (sample_features ['F0' ], expected_sample_features ['F0' ])
802+ self .assertAllEqual (sample_features ['F1' ].values ,
803+ expected_sample_features ['F1' ].values )
804+ self .assertAllEqual (sample_features ['F1' ].indices ,
805+ expected_sample_features ['F1' ].indices )
806+ self .assertAllEqual (sample_features ['F1' ].dense_shape ,
807+ expected_sample_features ['F1' ].dense_shape )
808+
809+ def testFeaturesWithDynamicBatchSizeAndFeatureShape (self ):
810+ """Tests the case when the batch size and feature shape are both dynamic."""
811+ # Use a dynamic batch size and a dynamic feature shape. The former
812+ # corresponds to the first dimension of the tensors defined below, and the
813+ # latter corresonponds to the second dimension of 'sample_features' and
814+ # 'neighbor_i_features'.
815+
816+ feature_specs = {
817+ 'F0' : tf .TensorSpec ((None , None , 3 ), tf .float32 ),
818+ 'NL_nbr_0_F0' : tf .TensorSpec ((None , None , 3 ), tf .float32 ),
819+ 'NL_nbr_0_weight' : tf .TensorSpec ((None , 1 ), tf .float32 ),
820+ }
821+
822+ # Specify a batch size of 3 and a pre-batching feature shape of 2x3 at run
823+ # time.
824+ sample1 = [[1 , 2 , 3 ], [3 , 2 , 1 ]]
825+ sample2 = [[4 , 5 , 6 ], [6 , 5 , 4 ]]
826+ sample3 = [[7 , 8 , 9 ], [9 , 8 , 7 ]]
827+ sample_features = [sample1 , sample2 , sample3 ] # 3x2x3
828+
829+ neighbor_0_features = [[[1 , 3 , 5 ], [5 , 3 , 1 ]], [[7 , 9 , 11 ], [11 , 9 , 7 ]],
830+ [[13 , 15 , 17 ], [17 , 15 , 13 ]]] # 3x2x3
831+ neighbor_0_weights = [[0.25 ], [0.5 ], [0.75 ]] # 3x1
832+
833+ expected_sample_features = {'F0' : sample_features }
834+
835+ features = {
836+ 'F0' : sample_features ,
837+ 'NL_nbr_0_F0' : neighbor_0_features ,
838+ 'NL_nbr_0_weight' : neighbor_0_weights ,
839+ }
840+
841+ neighbor_config = configs .GraphNeighborConfig ()
842+
843+ @tf .function (input_signature = [feature_specs ])
844+ def _strip_neighbor_features (features ):
845+ return utils .strip_neighbor_features (features , neighbor_config )
846+
847+ sample_features = self .evaluate (_strip_neighbor_features (features ))
848+
849+ # Check that only the sample features are retained.
850+ feature_keys = sorted (sample_features .keys ())
851+ self .assertListEqual (feature_keys , ['F0' ])
852+
853+ # Check that the value of the sample feature remains unchanged.
854+ self .assertAllEqual (sample_features ['F0' ], expected_sample_features ['F0' ])
855+
856+
732857if __name__ == '__main__' :
733858 tf .test .main ()
0 commit comments