4949 'data_source' , 'tensorflow_datasets' , 'Data source. Valid options are: '
5050 '`tensorflow_datasets`, `realistic_ssl`' )
5151flags .DEFINE_integer (
52- 'target_num_train_per_class' , 20 ,
52+ 'target_num_train_per_class' , 400 ,
5353 'Number of samples per class to use for training.' )
5454flags .DEFINE_integer (
55- 'target_num_val' , 10000 ,
55+ 'target_num_val' , 1000 ,
5656 'Number of samples to be used for validation.' )
5757flags .DEFINE_integer (
58- 'seed' , 1234 ,
58+ 'seed' , 123 ,
5959 'Seed used by the random number generators.' )
6060flags .DEFINE_bool (
6161 'load_preprocessed' , False ,
142142 'Minimum number of iterations to train the agreement model for after '
143143 'the best validation accuracy is improved.' )
144144flags .DEFINE_integer (
145- 'num_samples_to_label' , 200 ,
145+ 'num_samples_to_label' , 500 ,
146146 'Number of samples to label after each co-train iteration.' )
147147flags .DEFINE_float (
148148 'min_confidence_new_label' , 0.4 ,
156156 'Minimum number of co-train iterations the agreement must be trained '
157157 'before it is used in the classifier.' )
158158flags .DEFINE_float (
159- 'ratio_valid_agr' , 0.2 ,
159+ 'ratio_valid_agr' , 0.1 ,
160160 'Ratio of edges used for validating the agreement model.' )
161161flags .DEFINE_integer (
162162 'max_samples_valid_agr' , 10000 ,
190190 'Schedule for decaying the weight decay in the agreement model. Choose '
191191 'between None or linear.' )
192192flags .DEFINE_integer (
193- 'batch_size_agr' , 32 , 'Batch size for agreement model.' )
193+ 'batch_size_agr' , 512 , 'Batch size for agreement model.' )
194194flags .DEFINE_integer (
195- 'batch_size_cls' , 32 , 'Batch size for classification model.' )
195+ 'batch_size_cls' , 512 , 'Batch size for classification model.' )
196196flags .DEFINE_float (
197197 'gradient_clip' , None ,
198198 'The gradient clipping global norm value. If None, no clipping is done.' )
240240 'reg_weight_uu' , 0.05 ,
241241 'Regularization weight for unlabeled-unlabeled edges.' )
242242flags .DEFINE_integer (
243- 'num_pairs_reg' , 512 ,
243+ 'num_pairs_reg' , 128 ,
244244 'Number of pairs of nodes to use in the agreement loss term of the '
245245 'classification model.' )
246246flags .DEFINE_string (
252252 'penalize_neg_agr' , True ,
253253 'Whether to encourage differences when agreement is negative.' )
254254flags .DEFINE_bool (
255- 'use_l2_cls' , True ,
255+ 'use_l2_cls' , False ,
256256 'Whether to use L2 loss for the classifier, not cross entropy.' )
257257flags .DEFINE_bool (
258258 'first_iter_original' , True ,
259259 'Whether to use the original model in the first iteration, without self '
260260 'labeling or agreement loss.' )
261261flags .DEFINE_bool (
262- 'inductive' , False ,
262+ 'inductive' , True ,
263263 'Whether to use an inductive or transductive SSL setting.' )
264264flags .DEFINE_string (
265265 'experiment_suffix' , '' ,
277277flags .DEFINE_string (
278278 'optimizer' , 'adam' ,
279279 'Which optimizer to use. Valid options are `adam`, `amsgrad`.' )
280+ flags .DEFINE_bool (
281+ 'load_from_checkpoint' , False ,
282+ 'Whether to load the trained model and the data that has been self-labeled '
283+ 'from a previous run, if available. This is useful if a process can get '
284+ 'preempted or interrupted.' )
280285
281286
282287def parse_layers_string (layers_string ):
@@ -306,11 +311,12 @@ def pick_model(data):
306311 """Picks the models depending on the provided configuration flags."""
307312 # Create model classification.
308313 if FLAGS .model_cls == 'mlp' :
309- hidden_classif = (parse_layers_string (FLAGS .hidden_cls )
310- if FLAGS .hidden_cls is not None else [])
314+ hidden_cls = (
315+ parse_layers_string (FLAGS .hidden_cls )
316+ if FLAGS .hidden_cls is not None else [])
311317 model_cls = MLP (
312318 output_dim = data .num_classes ,
313- hidden_sizes = hidden_classif ,
319+ hidden_sizes = hidden_cls ,
314320 activation = tf .nn .leaky_relu ,
315321 name = 'mlp_cls' )
316322 elif FLAGS .model_cls == 'cnn' :
@@ -417,22 +423,25 @@ def main(argv):
417423 logging .info ('Preprocessed data saved to %s.' , path )
418424
419425 # Put together parameters to create a model name.
420- model_name = FLAGS .model_cls + (( '_' + FLAGS . hidden_cls )
421- if FLAGS .model_cls == 'mlp' else '' )
422- model_name += '-' + FLAGS .model_agr + (( '_' + FLAGS . hidden_agr )
423- if FLAGS .model_agr == 'mlp' else '' )
424- model_name += ( '-aggr_' + FLAGS .aggregation_agr_inputs + '_' +
425- FLAGS .hidden_aggreg )
426+ model_name = FLAGS .model_cls
427+ model_name += ( '_' + FLAGS . hidden_cls ) if FLAGS .model_cls == 'mlp' else ''
428+ model_name += '-' + FLAGS .model_agr
429+ model_name += ( '_' + FLAGS . hidden_agr ) if FLAGS .model_agr == 'mlp' else ''
430+ model_name += '-aggr_' + FLAGS .aggregation_agr_inputs
431+ model_name += ( '_' + FLAGS . hidden_aggreg ) if FLAGS .hidden_aggreg else ''
426432 model_name += ('-add_%d-conf_%.2f-iter_cls_%d-iter_agr_%d-batch_cls_%d' %
427433 (FLAGS .num_samples_to_label , FLAGS .min_confidence_new_label ,
428434 FLAGS .max_num_iter_cls , FLAGS .max_num_iter_agr ,
429435 FLAGS .batch_size_cls ))
430- model_name += '-perfectAgr' if FLAGS .use_perfect_agreement else ''
431- model_name += '-perfectCls' if FLAGS .use_perfect_classifier else ''
436+ model_name += '-LL_%s_LU_%s_UU_%s' % (str (
437+ FLAGS .reg_weight_ll ), str (FLAGS .reg_weight_lu ), str (FLAGS .reg_weight_uu ))
438+ model_name += '-perfAgr' if FLAGS .use_perfect_agreement else ''
439+ model_name += '-perfCls' if FLAGS .use_perfect_classifier else ''
432440 model_name += '-keepProp' if FLAGS .keep_label_proportions else ''
433441 model_name += '-PenNegAgr' if FLAGS .penalize_neg_agr else ''
434- model_name += '-inductive' if FLAGS .inductive else ''
435- model_name += '-L2Loss' if FLAGS .use_l2_cls else '-CELoss'
442+ model_name += '-transduct' if not FLAGS .inductive else ''
443+ model_name += '-L2' if FLAGS .use_l2_cls else '-CE'
444+ model_name += '-seed_' + str (FLAGS .seed )
436445 model_name += FLAGS .experiment_suffix
437446 logging .info ('Model name: %s' , model_name )
438447
@@ -451,7 +460,6 @@ def main(argv):
451460 model_cls , model_agr = pick_model (data )
452461
453462 # Train.
454- optimizer = tf .train .AdamOptimizer
455463 trainer = TrainerCotraining (
456464 model_cls = model_cls ,
457465 model_agr = model_agr ,
@@ -466,7 +474,7 @@ def main(argv):
466474 min_confidence_new_label = FLAGS .min_confidence_new_label ,
467475 keep_label_proportions = FLAGS .keep_label_proportions ,
468476 num_warm_up_iter_agr = FLAGS .num_warm_up_iter_agr ,
469- optimizer = optimizer ,
477+ optimizer = tf . train . AdamOptimizer ,
470478 gradient_clip = FLAGS .gradient_clip ,
471479 batch_size_agr = FLAGS .batch_size_agr ,
472480 batch_size_cls = FLAGS .batch_size_cls ,
@@ -511,7 +519,8 @@ def main(argv):
511519 lr_decay_rate_cls = FLAGS .lr_decay_rate_cls ,
512520 lr_decay_steps_cls = FLAGS .lr_decay_steps_cls ,
513521 lr_decay_rate_agr = FLAGS .lr_decay_rate_agr ,
514- lr_decay_steps_agr = FLAGS .lr_decay_steps_agr )
522+ lr_decay_steps_agr = FLAGS .lr_decay_steps_agr ,
523+ load_from_checkpoint = FLAGS .load_from_checkpoint )
515524
516525 trainer .train (data )
517526
0 commit comments