|
1 | | -{ |
| 1 | +{ |
2 | 2 | "cells": [ |
3 | 3 | { |
4 | 4 | "cell_type": "markdown", |
|
794 | 794 | }, |
795 | 795 | "outputs": [], |
796 | 796 | "source": [ |
797 | | - "def pad_sequence(sequence, max_seq_length):\n", |
798 | | - " \"\"\"Pads the input sequence (a `tf.SparseTensor`) to `max_seq_length`.\"\"\"\n", |
799 | | - " pad_size = tf.maximum([0], max_seq_length - tf.shape(sequence)[0])\n", |
800 | | - " padded = tf.concat(\n", |
801 | | - " [sequence.values,\n", |
802 | | - " tf.fill((pad_size), tf.cast(0, sequence.dtype))],\n", |
803 | | - " axis=0)\n", |
804 | | - " # The input sequence may be larger than max_seq_length. Truncate down if\n", |
805 | | - " # necessary.\n", |
806 | | - " return tf.slice(padded, [0], [max_seq_length])\n", |
807 | | - "\n", |
808 | | - "def parse_example(example_proto):\n", |
809 | | - " \"\"\"Extracts relevant fields from the `example_proto`.\n", |
810 | | - "\n", |
811 | | - " Args:\n", |
812 | | - " example_proto: An instance of `tf.train.Example`.\n", |
813 | | - "\n", |
814 | | - " Returns:\n", |
815 | | - " A pair whose first value is a dictionary containing relevant features\n", |
816 | | - " and whose second value contains the ground truth labels.\n", |
817 | | - " \"\"\"\n", |
818 | | - " # The 'words' feature is a variable length word ID vector.\n", |
819 | | - " feature_spec = {\n", |
820 | | - " 'words': tf.io.VarLenFeature(tf.int64),\n", |
821 | | - " 'label': tf.io.FixedLenFeature((), tf.int64, default_value=-1),\n", |
822 | | - " }\n", |
823 | | - " # We also extract corresponding neighbor features in a similar manner to\n", |
824 | | - " # the features above.\n", |
825 | | - " for i in range(HPARAMS.num_neighbors):\n", |
826 | | - " nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'words')\n", |
827 | | - " nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, i, NBR_WEIGHT_SUFFIX)\n", |
828 | | - " feature_spec[nbr_feature_key] = tf.io.VarLenFeature(tf.int64)\n", |
829 | | - "\n", |
830 | | - " # We assign a default value of 0.0 for the neighbor weight so that\n", |
831 | | - " # graph regularization is done on samples based on their exact number\n", |
832 | | - " # of neighbors. In other words, non-existent neighbors are discounted.\n", |
833 | | - " feature_spec[nbr_weight_key] = tf.io.FixedLenFeature(\n", |
834 | | - " [1], tf.float32, default_value=tf.constant([0.0]))\n", |
835 | | - "\n", |
836 | | - " features = tf.io.parse_single_example(example_proto, feature_spec)\n", |
837 | | - "\n", |
838 | | - " # Since the 'words' feature is a variable length word vector, we pad it to a\n", |
839 | | - " # constant maximum length based on HPARAMS.max_seq_length\n", |
840 | | - " features['words'] = pad_sequence(features['words'], HPARAMS.max_seq_length)\n", |
841 | | - " for i in range(HPARAMS.num_neighbors):\n", |
842 | | - " nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'words')\n", |
843 | | - " features[nbr_feature_key] = pad_sequence(features[nbr_feature_key],\n", |
844 | | - " HPARAMS.max_seq_length)\n", |
845 | | - "\n", |
846 | | - " labels = features.pop('label')\n", |
847 | | - " return features, labels\n", |
848 | | - "\n", |
849 | 797 | "def make_dataset(file_path, training=False):\n", |
850 | 798 | " \"\"\"Creates a `tf.data.TFRecordDataset`.\n", |
851 | 799 | "\n", |
|
858 | 806 | " An instance of `tf.data.TFRecordDataset` containing the `tf.train.Example`\n", |
859 | 807 | " objects.\n", |
860 | 808 | " \"\"\"\n", |
| 809 | + "\n", |
| 810 | + " def pad_sequence(sequence, max_seq_length):\n", |
| 811 | + " \"\"\"Pads the input sequence (a `tf.SparseTensor`) to `max_seq_length`.\"\"\"\n", |
| 812 | + " pad_size = tf.maximum([0], max_seq_length - tf.shape(sequence)[0])\n", |
| 813 | + " padded = tf.concat(\n", |
| 814 | + " [sequence.values,\n", |
| 815 | + " tf.fill((pad_size), tf.cast(0, sequence.dtype))],\n", |
| 816 | + " axis=0)\n", |
| 817 | + " # The input sequence may be larger than max_seq_length. Truncate down if\n", |
| 818 | + " # necessary.\n", |
| 819 | + " return tf.slice(padded, [0], [max_seq_length])\n", |
| 820 | + "\n", |
| 821 | + " def parse_example(example_proto):\n", |
| 822 | + " \"\"\"Extracts relevant fields from the `example_proto`.\n", |
| 823 | + "\n", |
| 824 | + " Args:\n", |
| 825 | + " example_proto: An instance of `tf.train.Example`.\n", |
| 826 | + "\n", |
| 827 | + " Returns:\n", |
| 828 | + " A pair whose first value is a dictionary containing relevant features\n", |
| 829 | + " and whose second value contains the ground truth labels.\n", |
| 830 | + " \"\"\"\n", |
| 831 | + " # The 'words' feature is a variable length word ID vector.\n", |
| 832 | + " feature_spec = {\n", |
| 833 | + " 'words': tf.io.VarLenFeature(tf.int64),\n", |
| 834 | + " 'label': tf.io.FixedLenFeature((), tf.int64, default_value=-1),\n", |
| 835 | + " }\n", |
| 836 | + " # We also extract corresponding neighbor features in a similar manner to\n", |
| 837 | + " # the features above during training.\n", |
| 838 | + " if training:\n", |
| 839 | + " for i in range(HPARAMS.num_neighbors):\n", |
| 840 | + " nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'words')\n", |
| 841 | + " nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, i,\n", |
| 842 | + " NBR_WEIGHT_SUFFIX)\n", |
| 843 | + " feature_spec[nbr_feature_key] = tf.io.VarLenFeature(tf.int64)\n", |
| 844 | + "\n", |
| 845 | + " # We assign a default value of 0.0 for the neighbor weight so that\n", |
| 846 | + " # graph regularization is done on samples based on their exact number\n", |
| 847 | + " # of neighbors. In other words, non-existent neighbors are discounted.\n", |
| 848 | + " feature_spec[nbr_weight_key] = tf.io.FixedLenFeature(\n", |
| 849 | + " [1], tf.float32, default_value=tf.constant([0.0]))\n", |
| 850 | + "\n", |
| 851 | + " features = tf.io.parse_single_example(example_proto, feature_spec)\n", |
| 852 | + "\n", |
| 853 | + " # Since the 'words' feature is a variable length word vector, we pad it to a\n", |
| 854 | + " # constant maximum length based on HPARAMS.max_seq_length\n", |
| 855 | + " features['words'] = pad_sequence(features['words'], HPARAMS.max_seq_length)\n", |
| 856 | + " if training:\n", |
| 857 | + " for i in range(HPARAMS.num_neighbors):\n", |
| 858 | + " nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'words')\n", |
| 859 | + " features[nbr_feature_key] = pad_sequence(features[nbr_feature_key],\n", |
| 860 | + " HPARAMS.max_seq_length)\n", |
| 861 | + "\n", |
| 862 | + " labels = features.pop('label')\n", |
| 863 | + " return features, labels\n", |
| 864 | + "\n", |
861 | 865 | " dataset = tf.data.TFRecordDataset([file_path])\n", |
862 | 866 | " if training:\n", |
863 | 867 | " dataset = dataset.shuffle(10000)\n", |
864 | 868 | " dataset = dataset.map(parse_example)\n", |
865 | 869 | " dataset = dataset.batch(HPARAMS.batch_size)\n", |
866 | 870 | " return dataset\n", |
867 | 871 | "\n", |
| 872 | + "\n", |
868 | 873 | "train_dataset = make_dataset('/tmp/imdb/nsl_train_data.tfr', True)\n", |
869 | 874 | "test_dataset = make_dataset('/tmp/imdb/test_data.tfr')" |
870 | 875 | ] |
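
Aside for readers skimming this hunk: the change nests `pad_sequence` and `parse_example` inside `make_dataset` and gates the neighbor-feature parsing on `training`, since only the NSL-augmented training records carry neighbor features; the plain test records have just `words` and `label`. Below is a minimal standalone sketch of the pad/truncate behavior, with made-up word-ID vectors and an arbitrary `max_seq_length` of 5 (the helper is copied out of the hunk above purely for illustration):

```python
import tensorflow as tf

# Copied from the hunk above so it can run standalone; in the notebook it is
# nested inside make_dataset and applied within dataset.map.
def pad_sequence(sequence, max_seq_length):
  """Pads the input sequence (a `tf.SparseTensor`) to `max_seq_length`."""
  pad_size = tf.maximum([0], max_seq_length - tf.shape(sequence)[0])
  padded = tf.concat(
      [sequence.values,
       tf.fill((pad_size), tf.cast(0, sequence.dtype))],
      axis=0)
  # The input may be longer than max_seq_length; truncate if necessary.
  return tf.slice(padded, [0], [max_seq_length])

# Made-up word-ID vectors of lengths 3 and 7 (no zeros, since zero is the pad
# value and would be dropped by the sparse conversion).
short_seq = tf.sparse.from_dense(tf.constant([3, 7, 11], dtype=tf.int64))
long_seq = tf.sparse.from_dense(
    tf.constant([1, 2, 3, 4, 5, 6, 7], dtype=tf.int64))
print(pad_sequence(short_seq, 5).numpy())  # [ 3  7 11  0  0] -- zero-padded
print(pad_sequence(long_seq, 5).numpy())   # [1 2 3 4 5]      -- truncated
```

And a quick, hypothetical peek at one batch to confirm that the `training=False` path yields only the `words` feature:

```python
features, labels = next(iter(test_dataset))
print(list(features.keys()))    # expected: ['words'] -- no neighbor features
print(features['words'].shape)  # (batch_size, max_seq_length)
```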
|
1357 | 1362 | "id": "yBrp0Y0jHu5k" |
1358 | 1363 | }, |
1359 | 1364 | "source": [ |
1360 | | - "There are six entries: one for each monitored metric -- loss, graph loss, and\n", |
1361 | | - "accuracy -- during training and validation. We can use these to plot the\n", |
1362 | | - "training, graph, and validation losses for comparison, as well as the training\n", |
1363 | | - "and validation accuracy. Note that the graph loss is only computed during\n", |
1364 | | - "training; so its value will be 0 during validation." |
| 1365 | + "There are five entries in total in the dictionary: training loss, training\n", |
| 1366 | + "accuracy, training graph loss, validation loss, and validation accuracy. We can\n", |
| 1367 | + "plot them all together for comparison. Note that the graph loss is only computed\n", |
| 1368 | + "during training." |
1365 | 1369 | ] |
1366 | 1370 | }, |
1367 | 1371 | { |
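
A quick way to verify the five-entry claim before plotting is to print the keys of the history dictionary. This is a hypothetical check; the exact key names assume the model was compiled with `metrics=['accuracy']`, so that `accuracy`/`val_accuracy` appear alongside the losses:

```python
print(sorted(graph_reg_history_dict.keys()))
# Expected, per the text above:
# ['accuracy', 'graph_loss', 'loss', 'val_accuracy', 'val_loss']
```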
|
1379 | 1383 | "loss = graph_reg_history_dict['loss']\n", |
1380 | 1384 | "graph_loss = graph_reg_history_dict['graph_loss']\n", |
1381 | 1385 | "val_loss = graph_reg_history_dict['val_loss']\n", |
1382 | | - "val_graph_loss = graph_reg_history_dict['val_graph_loss']\n", |
1383 | 1386 | "\n", |
1384 | 1387 | "epochs = range(1, len(acc) + 1)\n", |
1385 | 1388 | "\n", |
|
1391 | 1394 | "plt.plot(epochs, graph_loss, '-gD', label='Training graph loss')\n", |
1392 | 1395 | "# \"-bo\" is for solid blue line with circle markers.\n", |
1393 | 1396 | "plt.plot(epochs, val_loss, '-bo', label='Validation loss')\n", |
1394 | | - "# \"-ms\" is for solid magenta line with square markers.\n", |
1395 | | - "plt.plot(epochs, val_graph_loss, '-ms', label='Validation graph loss')\n", |
1396 | 1397 | "plt.title('Training and validation loss')\n", |
1397 | 1398 | "plt.xlabel('Epochs')\n", |
1398 | 1399 | "plt.ylabel('Loss')\n", |
|