Skip to content

Commit b88347f

Browse files
arjungtensorflow-copybara
authored andcommitted
Define a function to strip neighbor features from a feature dictionary.
This will be used by the graph Estimator and graph Keras wrapper APIs to ignore neighbor features in all modes except training. PiperOrigin-RevId: 283655271
1 parent ddf7a70 commit b88347f

File tree

3 files changed

+148
-0
lines changed

3 files changed

+148
-0
lines changed

neural_structured_learning/lib/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from neural_structured_learning.lib.utils import maximize_within_unit_norm
1414
from neural_structured_learning.lib.utils import normalize
1515
from neural_structured_learning.lib.utils import replicate_embeddings
16+
from neural_structured_learning.lib.utils import strip_neighbor_features
1617
from neural_structured_learning.lib.utils import unpack_neighbor_features
1718

1819
__all__ = [
@@ -28,6 +29,7 @@
2829
'normalize',
2930
'pairwise_distance_wrapper',
3031
'replicate_embeddings',
32+
'strip_neighbor_features',
3133
'unpack_neighbor_features',
3234
'virtual_adv_regularizer',
3335
]

neural_structured_learning/lib/utils.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -560,3 +560,24 @@ def check_shape_compatibility(tensors, expected_shape):
560560
keep_rank)
561561

562562
return sample_features, neighbor_features, neighbor_weights
563+
564+
565+
def strip_neighbor_features(features, neighbor_config):
566+
"""Strips graph neighbor features from a feature dictionary.
567+
568+
Args:
569+
features: Dictionary of tensors mapping feature names to tensors. This
570+
dictionary includes sample features but may or may not include
571+
corresponding neighbor features for each sample feature.
572+
neighbor_config: An instance of `nsl.configs.GraphNeighborConfig`.
573+
574+
Returns:
575+
A dictionary mapping only sample feature names to tensors. Neighbor
576+
features in the input are not included.
577+
"""
578+
579+
return {
580+
key: value
581+
for key, value in features.items()
582+
if not key.startswith(neighbor_config.prefix)
583+
}

neural_structured_learning/lib/utils_test.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -729,5 +729,130 @@ def _unpack_neighbor_features(features):
729729
self.assertAllEqual(nbr_weights, expected_neighbor_weights)
730730

731731

732+
class StripNeighborFeaturesTest(tf.test.TestCase):
733+
"""Tests removal of neighbor features from a feature dictionary."""
734+
735+
def testEmptyFeatures(self):
736+
"""Tests strip_neighbor_features with empty input."""
737+
features = dict()
738+
neighbor_config = configs.GraphNeighborConfig()
739+
sample_features = utils.strip_neighbor_features(features, neighbor_config)
740+
741+
# We create a dummy tensor so that the computation graph is not empty.
742+
dummy_tensor = tf.constant(1.0)
743+
sample_features, dummy_tensor = self.evaluate(
744+
[sample_features, dummy_tensor])
745+
self.assertEmpty(sample_features)
746+
747+
def testNoNeighborFeatures(self):
748+
"""Tests strip_neighbor_features when there are no neighbor features."""
749+
features = {'F0': tf.constant(11.0, shape=[2, 2])}
750+
neighbor_config = configs.GraphNeighborConfig()
751+
sample_features = utils.strip_neighbor_features(features, neighbor_config)
752+
753+
expected_sample_features = {'F0': tf.constant(11.0, shape=[2, 2])}
754+
755+
sample_features = self.evaluate(sample_features)
756+
757+
# Check that only the sample features are retained.
758+
feature_keys = sorted(sample_features.keys())
759+
self.assertListEqual(feature_keys, ['F0'])
760+
761+
# Check that the values of the sample feature remains unchanged.
762+
self.assertAllEqual(sample_features['F0'], expected_sample_features['F0'])
763+
764+
def testBatchedFeatures(self):
765+
"""Tests strip_neighbor_features with batched input features."""
766+
features = {
767+
'F0':
768+
tf.constant(11.0, shape=[2, 2]),
769+
'F1':
770+
tf.SparseTensor(
771+
indices=[[0, 0], [0, 1]], values=[1.0, 2.0], dense_shape=[2,
772+
4]),
773+
'NL_nbr_0_F0':
774+
tf.constant(22.0, shape=[2, 2]),
775+
'NL_nbr_0_F1':
776+
tf.SparseTensor(
777+
indices=[[1, 0], [1, 1]], values=[3.0, 4.0], dense_shape=[2,
778+
4]),
779+
'NL_nbr_0_weight':
780+
tf.constant(0.25, shape=[2, 1]),
781+
}
782+
neighbor_config = configs.GraphNeighborConfig()
783+
sample_features = utils.strip_neighbor_features(features, neighbor_config)
784+
785+
expected_sample_features = {
786+
'F0':
787+
tf.constant(11.0, shape=[2, 2]),
788+
'F1':
789+
tf.SparseTensor(
790+
indices=[[0, 0], [0, 1]], values=[1.0, 2.0], dense_shape=[2,
791+
4]),
792+
}
793+
794+
sample_features = self.evaluate(sample_features)
795+
796+
# Check that only the sample features are retained.
797+
feature_keys = sorted(sample_features.keys())
798+
self.assertListEqual(feature_keys, ['F0', 'F1'])
799+
800+
# Check that the values of the sample features remain unchanged.
801+
self.assertAllEqual(sample_features['F0'], expected_sample_features['F0'])
802+
self.assertAllEqual(sample_features['F1'].values,
803+
expected_sample_features['F1'].values)
804+
self.assertAllEqual(sample_features['F1'].indices,
805+
expected_sample_features['F1'].indices)
806+
self.assertAllEqual(sample_features['F1'].dense_shape,
807+
expected_sample_features['F1'].dense_shape)
808+
809+
def testFeaturesWithDynamicBatchSizeAndFeatureShape(self):
810+
"""Tests the case when the batch size and feature shape are both dynamic."""
811+
# Use a dynamic batch size and a dynamic feature shape. The former
812+
# corresponds to the first dimension of the tensors defined below, and the
813+
# latter corresonponds to the second dimension of 'sample_features' and
814+
# 'neighbor_i_features'.
815+
816+
feature_specs = {
817+
'F0': tf.TensorSpec((None, None, 3), tf.float32),
818+
'NL_nbr_0_F0': tf.TensorSpec((None, None, 3), tf.float32),
819+
'NL_nbr_0_weight': tf.TensorSpec((None, 1), tf.float32),
820+
}
821+
822+
# Specify a batch size of 3 and a pre-batching feature shape of 2x3 at run
823+
# time.
824+
sample1 = [[1, 2, 3], [3, 2, 1]]
825+
sample2 = [[4, 5, 6], [6, 5, 4]]
826+
sample3 = [[7, 8, 9], [9, 8, 7]]
827+
sample_features = [sample1, sample2, sample3] # 3x2x3
828+
829+
neighbor_0_features = [[[1, 3, 5], [5, 3, 1]], [[7, 9, 11], [11, 9, 7]],
830+
[[13, 15, 17], [17, 15, 13]]] # 3x2x3
831+
neighbor_0_weights = [[0.25], [0.5], [0.75]] # 3x1
832+
833+
expected_sample_features = {'F0': sample_features}
834+
835+
features = {
836+
'F0': sample_features,
837+
'NL_nbr_0_F0': neighbor_0_features,
838+
'NL_nbr_0_weight': neighbor_0_weights,
839+
}
840+
841+
neighbor_config = configs.GraphNeighborConfig()
842+
843+
@tf.function(input_signature=[feature_specs])
844+
def _strip_neighbor_features(features):
845+
return utils.strip_neighbor_features(features, neighbor_config)
846+
847+
sample_features = self.evaluate(_strip_neighbor_features(features))
848+
849+
# Check that only the sample features are retained.
850+
feature_keys = sorted(sample_features.keys())
851+
self.assertListEqual(feature_keys, ['F0'])
852+
853+
# Check that the value of the sample feature remains unchanged.
854+
self.assertAllEqual(sample_features['F0'], expected_sample_features['F0'])
855+
856+
732857
if __name__ == '__main__':
733858
tf.test.main()

0 commit comments

Comments
 (0)