
Commit c6c47c6

Neural-Link Team authored and tensorflow-copybara committed
Copybara import of the project:
-- 2b68d08 by Otilia Stretcu <otiliastr@gmail.com>: Small fixes.
-- 3265019 by Otilia Stretcu <otiliastr@gmail.com>: Changing model name.
-- dca1457 by Otilia Stretcu <otiliastr@gmail.com>: Refactoring.
-- 8a90308 by Otilia Stretcu <otiliastr@gmail.com>: Refactoring.
-- eef1760 by Otilia Stretcu <otiliastr@gmail.com>: Removed unused function.
-- 7954700 by Otilia Stretcu <otiliastr@gmail.com>: Rename run script.

PiperOrigin-RevId: 272069428
1 parent 0384ee1 commit c6c47c6

6 files changed: +269 −304 lines changed


neural_structured_learning/research/gam/experiments/run_train_mnist.py renamed to neural_structured_learning/research/gam/experiments/run_train_gam.py

Lines changed: 8 additions & 4 deletions
@@ -176,14 +176,14 @@
     'String representing the number of units of the hidden layers of the '
     'aggregation network of the agreement model.')
 flags.DEFINE_float(
-    'weight_decay_cls', 0,
+    'weight_decay_cls', None,
    'Weight of the L2 penalty on the classification model weights.')
 flags.DEFINE_string(
     'weight_decay_schedule_cls', None,
     'Schedule for decaying the weight decay in the classification model. '
     'Choose bewteen None or linear.')
 flags.DEFINE_float(
-    'weight_decay_agr', 0,
+    'weight_decay_agr', None,
     'Weight of the L2 penalty on the agreement model weights.')
 flags.DEFINE_string(
     'weight_decay_schedule_agr', None,
@@ -429,17 +429,21 @@ def main(argv):
   model_name += ('_' + FLAGS.hidden_agr) if FLAGS.model_agr == 'mlp' else ''
   model_name += '-aggr_' + FLAGS.aggregation_agr_inputs
   model_name += ('_' + FLAGS.hidden_aggreg) if FLAGS.hidden_aggreg else ''
-  model_name += ('-add_%d-conf_%.2f-iter_cls_%d-iter_agr_%d-batch_cls_%d' %
+  model_name += ('-add_%d-conf_%.2f-iterCls_%d-iterAgr_%d-batchCls_%d' %
                  (FLAGS.num_samples_to_label, FLAGS.min_confidence_new_label,
                   FLAGS.max_num_iter_cls, FLAGS.max_num_iter_agr,
                   FLAGS.batch_size_cls))
+  model_name += (('-wdecayCls_%.4f' % FLAGS.weight_decay_cls)
+                 if FLAGS.weight_decay_cls else '')
+  model_name += (('-wdecayAgr_%.4f' % FLAGS.weight_decay_agr)
+                 if FLAGS.weight_decay_agr else '')
   model_name += '-LL_%s_LU_%s_UU_%s' % (str(
       FLAGS.reg_weight_ll), str(FLAGS.reg_weight_lu), str(FLAGS.reg_weight_uu))
   model_name += '-perfAgr' if FLAGS.use_perfect_agreement else ''
   model_name += '-perfCls' if FLAGS.use_perfect_classifier else ''
   model_name += '-keepProp' if FLAGS.keep_label_proportions else ''
   model_name += '-PenNegAgr' if FLAGS.penalize_neg_agr else ''
-  model_name += '-transduct' if not FLAGS.inductive else ''
+  model_name += '-transd' if not FLAGS.inductive else ''
   model_name += '-L2' if FLAGS.use_l2_cls else '-CE'
   model_name += '-seed_' + str(FLAGS.seed)
   model_name += FLAGS.experiment_suffix
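The flag changes above replace a default of 0 with None for the two weight-decay flags, and the new lines in main only append a weight-decay suffix to the model name when a value is actually set. A minimal sketch of that pattern, reusing the absl flags call from the diff (build_model_name is a hypothetical helper added here for illustration; the real script builds the name inline in main):

# Minimal sketch (assumptions: absl is available; build_model_name is a
# hypothetical stand-in for the inline name-building code in main(argv)).
from absl import flags

flags.DEFINE_float(
    'weight_decay_cls', None,
    'Weight of the L2 penalty on the classification model weights.')

FLAGS = flags.FLAGS


def build_model_name(base='gam'):
  # With a None default, the suffix is added only when the flag is set.
  model_name = base
  model_name += (('-wdecayCls_%.4f' % FLAGS.weight_decay_cls)
                 if FLAGS.weight_decay_cls else '')
  return model_name

Note that the truthiness check also treats an explicit 0.0 as unset; since a zero penalty contributes nothing to the loss, the resulting name is arguably the right one either way.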

neural_structured_learning/research/gam/models/cnn.py

Lines changed: 1 addition & 1 deletion
@@ -289,7 +289,7 @@ def get_loss(self,
         weight decay is applied.
       **kwargs: Keyword arguments, potentially containing the weight of the
         regularization term, passed under the name `weight_decay`. If this is
-        not provided, it defaults to 0.0.
+        not provided, it defaults to 0.004.
 
     Returns:
       loss: The cummulated loss value.
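The only change to cnn.py is the docstring, which now states that get_loss falls back to a weight decay of 0.004 when none is passed through **kwargs. The implementation is not part of this hunk, but a kwargs default of that kind is typically read as in this hypothetical one-liner (name and placement assumed, not taken from the file):

  weight_decay = kwargs.get('weight_decay', 0.004)  # assumed fallback pattern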

neural_structured_learning/research/gam/models/mlp.py

Lines changed: 1 addition & 1 deletion
@@ -227,7 +227,7 @@ def get_loss(self,
     # Weight decay loss.
     if weight_decay is not None:
      for var in reg_params.values():
-        loss += weight_decay * tf.nn.l2_loss(var)
+        loss = loss + weight_decay * tf.nn.l2_loss(var)
     return loss
 
   def normalize_predictions(self, predictions):
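For context, the line being touched adds the standard L2 weight-decay term to the loss: weight_decay times tf.nn.l2_loss(var), summed over every variable registered for regularization. A small self-contained sketch of that pattern, assuming TF1-style graph code and a reg_params dict of variables as in the file:

# Self-contained sketch of the weight-decay pattern above (assumes TF1-style
# code; tf.nn.l2_loss(t) computes sum(t ** 2) / 2).
import tensorflow as tf


def add_weight_decay(loss, reg_params, weight_decay):
  # Adds weight_decay * ||w||^2 / 2 for every regularized variable.
  if weight_decay is not None:
    for var in reg_params.values():
      loss = loss + weight_decay * tf.nn.l2_loss(var)
  return loss

Whether written as loss += ... or loss = loss + ..., the statement rebinds the Python name to a new tensor, so the edit reads as a stylistic cleanup rather than a behavioral change.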
