@@ -63,7 +63,7 @@ class HParams(object):
   ### model architecture
   num_fc_units = attr.ib(default=[50, 50])
   ### training parameters
-  train_epochs = attr.ib(default=100)
+  train_epochs = attr.ib(default=10)
   batch_size = attr.ib(default=128)
   dropout_rate = attr.ib(default=0.5)
   ### eval parameters
@@ -93,8 +93,8 @@ def load_dataset(filename):
   return tf.data.TFRecordDataset([filename])
 
 
-def make_datasets(train_data_path, test_data_path, hparams):
-  """Returns training and test data as a pair of `tf.data.Dataset` instances."""
+def make_dataset(file_path, training, include_nbr_features, hparams):
+  """Returns a `tf.data.Dataset` instance based on data in `file_path`."""
 
   def parse_example(example_proto):
     """Extracts relevant fields from the `example_proto`.
@@ -120,35 +120,33 @@ def parse_example(example_proto):
         'label':
             tf.io.FixedLenFeature((), tf.int64, default_value=-1),
     }
-    for i in range(hparams.num_neighbors):
-      nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'words')
-      nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, i, NBR_WEIGHT_SUFFIX)
-      feature_spec[nbr_feature_key] = tf.io.FixedLenFeature(
-          [hparams.max_seq_length],
-          tf.int64,
-          default_value=tf.constant(
-              0, dtype=tf.int64, shape=[hparams.max_seq_length]))
-      feature_spec[nbr_weight_key] = tf.io.FixedLenFeature(
-          [1], tf.float32, default_value=tf.constant([0.0]))
+    if include_nbr_features:
+      for i in range(hparams.num_neighbors):
+        nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'words')
+        nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, i,
+                                         NBR_WEIGHT_SUFFIX)
+        feature_spec[nbr_feature_key] = tf.io.FixedLenFeature(
+            [hparams.max_seq_length],
+            tf.int64,
+            default_value=tf.constant(
+                0, dtype=tf.int64, shape=[hparams.max_seq_length]))
+        feature_spec[nbr_weight_key] = tf.io.FixedLenFeature(
+            [1], tf.float32, default_value=tf.constant([0.0]))
 
     features = tf.io.parse_single_example(example_proto, feature_spec)
 
     labels = features.pop('label')
     return features, labels
 
-  def make_dataset(file_path, training=False):
-    """Creates a `tf.data.Dataset`."""
-    # If the dataset is sharded, the following code may be required:
-    # filenames = tf.data.Dataset.list_files(file_path, shuffle=True)
-    # dataset = filenames.interleave(load_dataset, cycle_length=1)
-    dataset = load_dataset(file_path)
-    if training:
-      dataset = dataset.shuffle(10000)
-    dataset = dataset.map(parse_example)
-    dataset = dataset.batch(hparams.batch_size)
-    return dataset
-
-  return make_dataset(train_data_path, True), make_dataset(test_data_path)
+  # If the dataset is sharded, the following code may be required:
+  # filenames = tf.data.Dataset.list_files(file_path, shuffle=True)
+  # dataset = filenames.interleave(load_dataset, cycle_length=1)
+  dataset = load_dataset(file_path)
+  if training:
+    dataset = dataset.shuffle(10000)
+  dataset = dataset.map(parse_example)
+  dataset = dataset.batch(hparams.batch_size)
+  return dataset
 
 
 def make_mlp_sequential_model(hparams):
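For readers unfamiliar with the record layout, here is a minimal sketch (not part of the commit) of a tf.train.Example that the feature spec above would accept. It assumes max_seq_length=3 and num_neighbors=1, that the spec also declares a 'words' key as in the NSL Cora example, and that NBR_FEATURE_PREFIX and NBR_WEIGHT_SUFFIX take their usual values 'NL_nbr_' and '_weight'; all feature values are illustrative.

import tensorflow as tf

# Hypothetical record matching the spec built in parse_example above.
example = tf.train.Example(
    features=tf.train.Features(
        feature={
            # Token IDs of the sample itself, padded to max_seq_length.
            'words': tf.train.Feature(
                int64_list=tf.train.Int64List(value=[4, 8, 15])),
            'label': tf.train.Feature(
                int64_list=tf.train.Int64List(value=[2])),
            # Neighbor features, parsed only when include_nbr_features
            # is True.
            'NL_nbr_0_words': tf.train.Feature(
                int64_list=tf.train.Int64List(value=[16, 23, 42])),
            'NL_nbr_0_weight': tf.train.Feature(
                float_list=tf.train.FloatList(value=[0.75])),
        }))

When include_nbr_features is False, the NL_nbr_* entries are simply ignored, since tf.io.parse_single_example extracts only the keys named in the feature spec.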
@@ -270,7 +268,8 @@ def main(argv):
                          (len(argv) - 1))
 
   hparams = get_hyper_parameters()
-  train_dataset, test_dataset = make_datasets(argv[1], argv[2], hparams)
+  train_data_path = argv[1]
+  test_data_path = argv[2]
 
   # Graph regularization configuration.
   graph_reg_config = nsl.configs.make_graph_reg_config(
@@ -287,6 +286,8 @@ def main(argv):
   }
   for base_model_tag, base_model in base_models.items():
     logging.info('\n====== %s BASE MODEL TEST BEGIN ======', base_model_tag)
+    train_dataset = make_dataset(train_data_path, True, False, hparams)
+    test_dataset = make_dataset(test_data_path, False, False, hparams)
     train_and_evaluate(base_model, 'Base MLP model', train_dataset,
                        test_dataset, hparams)
 
@@ -295,6 +296,8 @@ def main(argv):
     # Wrap the base MLP model with graph regularization.
     graph_reg_model = nsl.keras.GraphRegularization(base_model,
                                                     graph_reg_config)
+    train_dataset = make_dataset(train_data_path, True, True, hparams)
+    test_dataset = make_dataset(test_data_path, False, False, hparams)
     train_and_evaluate(graph_reg_model, 'MLP + graph regularization',
                        train_dataset, test_dataset, hparams)
 
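Taken together, the refactor moves dataset construction into the per-model loop. A condensed sketch of the resulting call pattern (the file paths are placeholders): neighbor features are parsed only for the graph-regularized model's training split, since graph regularization adds a training-time loss and evaluation runs on plain examples.

hparams = get_hyper_parameters()

# Base MLP model: neighbor features are not parsed for either split.
train_dataset = make_dataset('/tmp/train.tfr', True, False, hparams)
test_dataset = make_dataset('/tmp/test.tfr', False, False, hparams)

# Graph-regularized model: neighbor features are parsed for training only;
# the test split is evaluated on plain examples.
train_dataset = make_dataset('/tmp/train.tfr', True, True, hparams)
test_dataset = make_dataset('/tmp/test.tfr', False, False, hparams)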