
Commit 897567e

arjung authored and tensorflow-copybara committed
Update graph-NSL tutorials not to extract neighbor features during evaluation. Neighbor features are required only when training a graph-regularized model.
Also, 'graph_loss' as a metric is no longer available, as evaluation of a graph-regularized model uses the underlying base model.

PiperOrigin-RevId: 315991629
1 parent a3a64e3 commit 897567e
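For context, here is a minimal sketch (not part of this commit) of the graph-regularized training loop these tutorials build, assuming the IMDB notebook's setup; the model sizes and hyperparameter values below are illustrative. It shows why neighbor features are training-only: the nsl.keras.GraphRegularization wrapper consumes them to compute the graph loss during fit(), while evaluate() and predict() run through the underlying base model, which reads only the 'words' feature (hence no 'graph_loss' metric at evaluation).

import neural_structured_learning as nsl
import tensorflow as tf

# Base model: consumes only the 'words' feature. Sizes are illustrative.
inputs = tf.keras.Input(shape=(256,), dtype='int64', name='words')
x = tf.keras.layers.Embedding(10000, 16)(inputs)
x = tf.keras.layers.GlobalAveragePooling1D()(x)
outputs = tf.keras.layers.Dense(1)(x)
base_model = tf.keras.Model(inputs, outputs)

# Wrap with graph regularization: during training, NSL reads the
# NL_nbr_<i>_words / NL_nbr_<i>_weight features and adds a graph loss term.
graph_reg_config = nsl.configs.make_graph_reg_config(
    max_neighbors=2,  # corresponds to HPARAMS.num_neighbors in the tutorials
    multiplier=0.1)   # weight of the graph regularization term
graph_reg_model = nsl.keras.GraphRegularization(base_model, graph_reg_config)
graph_reg_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=['accuracy'])

# fit() needs the neighbor features produced by make_dataset(..., training=True);
# evaluate() delegates to base_model, so the test dataset needs only 'words'.
graph_reg_model.fit(train_dataset, epochs=10)
graph_reg_model.evaluate(test_dataset)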

2 files changed: +111, -112 lines changed


g3doc/tutorials/graph_keras_lstm_imdb.ipynb

Lines changed: 62 additions & 61 deletions
@@ -1,4 +1,4 @@
-{
+{
 "cells": [
 {
 "cell_type": "markdown",
@@ -794,58 +794,6 @@
 },
 "outputs": [],
 "source": [
-"def pad_sequence(sequence, max_seq_length):\n",
-" \"\"\"Pads the input sequence (a `tf.SparseTensor`) to `max_seq_length`.\"\"\"\n",
-" pad_size = tf.maximum([0], max_seq_length - tf.shape(sequence)[0])\n",
-" padded = tf.concat(\n",
-" [sequence.values,\n",
-" tf.fill((pad_size), tf.cast(0, sequence.dtype))],\n",
-" axis=0)\n",
-" # The input sequence may be larger than max_seq_length. Truncate down if\n",
-" # necessary.\n",
-" return tf.slice(padded, [0], [max_seq_length])\n",
-"\n",
-"def parse_example(example_proto):\n",
-" \"\"\"Extracts relevant fields from the `example_proto`.\n",
-"\n",
-" Args:\n",
-" example_proto: An instance of `tf.train.Example`.\n",
-"\n",
-" Returns:\n",
-" A pair whose first value is a dictionary containing relevant features\n",
-" and whose second value contains the ground truth labels.\n",
-" \"\"\"\n",
-" # The 'words' feature is a variable length word ID vector.\n",
-" feature_spec = {\n",
-" 'words': tf.io.VarLenFeature(tf.int64),\n",
-" 'label': tf.io.FixedLenFeature((), tf.int64, default_value=-1),\n",
-" }\n",
-" # We also extract corresponding neighbor features in a similar manner to\n",
-" # the features above.\n",
-" for i in range(HPARAMS.num_neighbors):\n",
-" nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'words')\n",
-" nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, i, NBR_WEIGHT_SUFFIX)\n",
-" feature_spec[nbr_feature_key] = tf.io.VarLenFeature(tf.int64)\n",
-"\n",
-" # We assign a default value of 0.0 for the neighbor weight so that\n",
-" # graph regularization is done on samples based on their exact number\n",
-" # of neighbors. In other words, non-existent neighbors are discounted.\n",
-" feature_spec[nbr_weight_key] = tf.io.FixedLenFeature(\n",
-" [1], tf.float32, default_value=tf.constant([0.0]))\n",
-"\n",
-" features = tf.io.parse_single_example(example_proto, feature_spec)\n",
-"\n",
-" # Since the 'words' feature is a variable length word vector, we pad it to a\n",
-" # constant maximum length based on HPARAMS.max_seq_length\n",
-" features['words'] = pad_sequence(features['words'], HPARAMS.max_seq_length)\n",
-" for i in range(HPARAMS.num_neighbors):\n",
-" nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'words')\n",
-" features[nbr_feature_key] = pad_sequence(features[nbr_feature_key],\n",
-" HPARAMS.max_seq_length)\n",
-"\n",
-" labels = features.pop('label')\n",
-" return features, labels\n",
-"\n",
 "def make_dataset(file_path, training=False):\n",
 " \"\"\"Creates a `tf.data.TFRecordDataset`.\n",
 "\n",
@@ -858,13 +806,70 @@
 " An instance of `tf.data.TFRecordDataset` containing the `tf.train.Example`\n",
 " objects.\n",
 " \"\"\"\n",
+"\n",
+" def pad_sequence(sequence, max_seq_length):\n",
+" \"\"\"Pads the input sequence (a `tf.SparseTensor`) to `max_seq_length`.\"\"\"\n",
+" pad_size = tf.maximum([0], max_seq_length - tf.shape(sequence)[0])\n",
+" padded = tf.concat(\n",
+" [sequence.values,\n",
+" tf.fill((pad_size), tf.cast(0, sequence.dtype))],\n",
+" axis=0)\n",
+" # The input sequence may be larger than max_seq_length. Truncate down if\n",
+" # necessary.\n",
+" return tf.slice(padded, [0], [max_seq_length])\n",
+"\n",
+" def parse_example(example_proto):\n",
+" \"\"\"Extracts relevant fields from the `example_proto`.\n",
+"\n",
+" Args:\n",
+" example_proto: An instance of `tf.train.Example`.\n",
+"\n",
+" Returns:\n",
+" A pair whose first value is a dictionary containing relevant features\n",
+" and whose second value contains the ground truth labels.\n",
+" \"\"\"\n",
+" # The 'words' feature is a variable length word ID vector.\n",
+" feature_spec = {\n",
+" 'words': tf.io.VarLenFeature(tf.int64),\n",
+" 'label': tf.io.FixedLenFeature((), tf.int64, default_value=-1),\n",
+" }\n",
+" # We also extract corresponding neighbor features in a similar manner to\n",
+" # the features above during training.\n",
+" if training:\n",
+" for i in range(HPARAMS.num_neighbors):\n",
+" nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'words')\n",
+" nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, i,\n",
+" NBR_WEIGHT_SUFFIX)\n",
+" feature_spec[nbr_feature_key] = tf.io.VarLenFeature(tf.int64)\n",
+"\n",
+" # We assign a default value of 0.0 for the neighbor weight so that\n",
+" # graph regularization is done on samples based on their exact number\n",
+" # of neighbors. In other words, non-existent neighbors are discounted.\n",
+" feature_spec[nbr_weight_key] = tf.io.FixedLenFeature(\n",
+" [1], tf.float32, default_value=tf.constant([0.0]))\n",
+"\n",
+" features = tf.io.parse_single_example(example_proto, feature_spec)\n",
+"\n",
+" # Since the 'words' feature is a variable length word vector, we pad it to a\n",
+" # constant maximum length based on HPARAMS.max_seq_length\n",
+" features['words'] = pad_sequence(features['words'], HPARAMS.max_seq_length)\n",
+" if training:\n",
+" for i in range(HPARAMS.num_neighbors):\n",
+" nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'words')\n",
+" features[nbr_feature_key] = pad_sequence(features[nbr_feature_key],\n",
+" HPARAMS.max_seq_length)\n",
+"\n",
+" labels = features.pop('label')\n",
+" return features, labels\n",
+"\n",
 " dataset = tf.data.TFRecordDataset([file_path])\n",
 " if training:\n",
 " dataset = dataset.shuffle(10000)\n",
 " dataset = dataset.map(parse_example)\n",
 " dataset = dataset.batch(HPARAMS.batch_size)\n",
 " return dataset\n",
 "\n",
+"\n",
 "train_dataset = make_dataset('/tmp/imdb/nsl_train_data.tfr', True)\n",
 "test_dataset = make_dataset('/tmp/imdb/test_data.tfr')"
 ]
@@ -1357,11 +1362,10 @@
 "id": "yBrp0Y0jHu5k"
 },
 "source": [
-"There are six entries: one for each monitored metric -- loss, graph loss, and\n",
-"accuracy -- during training and validation. We can use these to plot the\n",
-"training, graph, and validation losses for comparison, as well as the training\n",
-"and validation accuracy. Note that the graph loss is only computed during\n",
-"training; so its value will be 0 during validation."
+"There are five entries in total in the dictionary: training loss, training\n",
+"accuracy, training graph loss, validation loss, and validation accuracy. We can\n",
+"plot them all together for comparison. Note that the graph loss is only computed\n",
+"during training."
 ]
 },
 {
@@ -1379,7 +1383,6 @@
 "loss = graph_reg_history_dict['loss']\n",
 "graph_loss = graph_reg_history_dict['graph_loss']\n",
 "val_loss = graph_reg_history_dict['val_loss']\n",
-"val_graph_loss = graph_reg_history_dict['val_graph_loss']\n",
 "\n",
 "epochs = range(1, len(acc) + 1)\n",
 "\n",
@@ -1391,8 +1394,6 @@
 "plt.plot(epochs, graph_loss, '-gD', label='Training graph loss')\n",
 "# \"-b0\" is for solid blue line with circle markers.\n",
 "plt.plot(epochs, val_loss, '-bo', label='Validation loss')\n",
-"# \"-ms\" is for solid magenta line with square markers.\n",
-"plt.plot(epochs, val_graph_loss, '-ms', label='Validation graph loss')\n",
 "plt.title('Training and validation loss')\n",
 "plt.xlabel('Epochs')\n",
 "plt.ylabel('Loss')\n",

g3doc/tutorials/graph_keras_mlp_cora.ipynb

Lines changed: 49 additions & 51 deletions
@@ -402,52 +402,6 @@
 },
 "outputs": [],
 "source": [
-"def parse_example(example_proto):\n",
-" \"\"\"Extracts relevant fields from the `example_proto`.\n",
-"\n",
-" Args:\n",
-" example_proto: An instance of `tf.train.Example`.\n",
-"\n",
-" Returns:\n",
-" A pair whose first value is a dictionary containing relevant features\n",
-" and whose second value contains the ground truth label.\n",
-" \"\"\"\n",
-" # The 'words' feature is a multi-hot, bag-of-words representation of the\n",
-" # original raw text. A default value is required for examples that don't\n",
-" # have the feature.\n",
-" feature_spec = {\n",
-" 'words':\n",
-" tf.io.FixedLenFeature([HPARAMS.max_seq_length],\n",
-" tf.int64,\n",
-" default_value=tf.constant(\n",
-" 0,\n",
-" dtype=tf.int64,\n",
-" shape=[HPARAMS.max_seq_length])),\n",
-" 'label':\n",
-" tf.io.FixedLenFeature((), tf.int64, default_value=-1),\n",
-" }\n",
-" # We also extract corresponding neighbor features in a similar manner to\n",
-" # the features above.\n",
-" for i in range(HPARAMS.num_neighbors):\n",
-" nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'words')\n",
-" nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, i, NBR_WEIGHT_SUFFIX)\n",
-" feature_spec[nbr_feature_key] = tf.io.FixedLenFeature(\n",
-" [HPARAMS.max_seq_length],\n",
-" tf.int64,\n",
-" default_value=tf.constant(\n",
-" 0, dtype=tf.int64, shape=[HPARAMS.max_seq_length]))\n",
-"\n",
-" # We assign a default value of 0.0 for the neighbor weight so that\n",
-" # graph regularization is done on samples based on their exact number\n",
-" # of neighbors. In other words, non-existent neighbors are discounted.\n",
-" feature_spec[nbr_weight_key] = tf.io.FixedLenFeature(\n",
-" [1], tf.float32, default_value=tf.constant([0.0]))\n",
-"\n",
-" features = tf.io.parse_single_example(example_proto, feature_spec)\n",
-" label = features.pop('label')\n",
-" return features, label\n",
-"\n",
-"\n",
 "def make_dataset(file_path, training=False):\n",
 " \"\"\"Creates a `tf.data.TFRecordDataset`.\n",
 "\n",
@@ -460,6 +414,55 @@
 " An instance of `tf.data.TFRecordDataset` containing the `tf.train.Example`\n",
 " objects.\n",
 " \"\"\"\n",
+"\n",
+" def parse_example(example_proto):\n",
+" \"\"\"Extracts relevant fields from the `example_proto`.\n",
+"\n",
+" Args:\n",
+" example_proto: An instance of `tf.train.Example`.\n",
+"\n",
+" Returns:\n",
+" A pair whose first value is a dictionary containing relevant features\n",
+" and whose second value contains the ground truth label.\n",
+" \"\"\"\n",
+" # The 'words' feature is a multi-hot, bag-of-words representation of the\n",
+" # original raw text. A default value is required for examples that don't\n",
+" # have the feature.\n",
+" feature_spec = {\n",
+" 'words':\n",
+" tf.io.FixedLenFeature([HPARAMS.max_seq_length],\n",
+" tf.int64,\n",
+" default_value=tf.constant(\n",
+" 0,\n",
+" dtype=tf.int64,\n",
+" shape=[HPARAMS.max_seq_length])),\n",
+" 'label':\n",
+" tf.io.FixedLenFeature((), tf.int64, default_value=-1),\n",
+" }\n",
+" # We also extract corresponding neighbor features in a similar manner to\n",
+" # the features above during training.\n",
+" if training:\n",
+" for i in range(HPARAMS.num_neighbors):\n",
+" nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'words')\n",
+" nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, i,\n",
+" NBR_WEIGHT_SUFFIX)\n",
+" feature_spec[nbr_feature_key] = tf.io.FixedLenFeature(\n",
+" [HPARAMS.max_seq_length],\n",
+" tf.int64,\n",
+" default_value=tf.constant(\n",
+" 0, dtype=tf.int64, shape=[HPARAMS.max_seq_length]))\n",
+"\n",
+" # We assign a default value of 0.0 for the neighbor weight so that\n",
+" # graph regularization is done on samples based on their exact number\n",
+" # of neighbors. In other words, non-existent neighbors are discounted.\n",
+" feature_spec[nbr_weight_key] = tf.io.FixedLenFeature(\n",
+" [1], tf.float32, default_value=tf.constant([0.0]))\n",
+"\n",
+" features = tf.io.parse_single_example(example_proto, feature_spec)\n",
+"\n",
+" label = features.pop('label')\n",
+" return features, label\n",
+"\n",
 " dataset = tf.data.TFRecordDataset([file_path])\n",
 " if training:\n",
 " dataset = dataset.shuffle(10000)\n",
@@ -526,11 +529,6 @@
 "for feature_batch, label_batch in test_dataset.take(1):\n",
 " print('Feature list:', list(feature_batch.keys()))\n",
 " print('Batch of inputs:', feature_batch['words'])\n",
-" nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, 0, 'words')\n",
-" nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, 0, NBR_WEIGHT_SUFFIX)\n",
-" print('Batch of neighbor inputs:', feature_batch[nbr_feature_key])\n",
-" print('Batch of neighbor weights:',\n",
-" tf.reshape(feature_batch[nbr_weight_key], [-1]))\n",
 " print('Batch of labels:', label_batch)"
 ]
 },
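With these changes in place, a quick sanity check (hypothetical, assuming the updated make_dataset and file paths from the IMDB notebook above, and NBR_FEATURE_PREFIX = 'NL_nbr_' as defined in these tutorials) confirms the new contract: only the training dataset carries neighbor features.

train_dataset = make_dataset('/tmp/imdb/nsl_train_data.tfr', True)
test_dataset = make_dataset('/tmp/imdb/test_data.tfr')

train_features, _ = next(iter(train_dataset))
test_features, _ = next(iter(test_dataset))

# Neighbor keys ('NL_nbr_<i>_words', 'NL_nbr_<i>_weight') appear only when
# training=True; evaluation batches contain just the 'words' feature.
assert any(k.startswith('NL_nbr_') for k in train_features)
assert not any(k.startswith('NL_nbr_') for k in test_features)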
