Skip to content

Commit 38d5b04

Browse files
committed
Bug fixes.
1 parent f6d0bef commit 38d5b04

File tree

4 files changed

+7
-29
lines changed

4 files changed

+7
-29
lines changed

research/gam/gam/experiments/helper.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def get_model_cls(model_name,
6262
num_classes=data.num_classes,
6363
lrelu_leakiness=0.1,
6464
horizontal_flip=dataset_name in ('cifar10',),
65-
random_translation=True,
65+
random_translation=False,
6666
gaussian_noise=dataset_name not in ('svhn', 'svhn_cropped'),
6767
width=2,
6868
num_residual_units=4,
@@ -125,7 +125,7 @@ def get_model_agr(model_name,
125125
num_classes=1,
126126
lrelu_leakiness=0.1,
127127
horizontal_flip=dataset_name in ('cifar10',),
128-
random_translation=True,
128+
random_translation=False,
129129
gaussian_noise=dataset_name not in ('svhn', 'svhn_cropped'),
130130
width=2,
131131
num_residual_units=4,

research/gam/gam/models/models_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ class attribute. The valid options are:
142142
else:
143143
raise NotImplementedError()
144144

145-
def _project(self, inputs, reuse=tf.AUTO_REUSE):
145+
def _project(self, inputs, reuse=tf.compat.v1.AUTO_REUSE):
146146
"""Projects the provided inputs using a multilayer perceptron.
147147
148148
Arguments:

research/gam/gam/models/wide_resnet.py

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222

2323
import numpy as np
2424
import tensorflow as tf
25-
import tensorflow_addons as tfa
2625

2726

2827
def fast_flip(images, is_training):
@@ -38,27 +37,6 @@ def func(inp):
3837
return tf.cond(is_training, lambda: func(images), lambda: images)
3938

4039

41-
def jitter(input_data, is_training):
42-
"""Applies random noise to input data when training."""
43-
44-
def func(inp):
45-
"""Wrap functionality in a subfunction."""
46-
bsz = tf.shape(inp)[0]
47-
inp = tf.pad(inp, [[0, 0], [2, 2], [2, 2], [0, 0]], mode="REFLECT")
48-
base = tf.constant([1, 0, 0, 0, 1, 0, 0, 0], shape=[1, 8], dtype=tf.float32)
49-
base = tf.tile(base, [bsz, 1])
50-
mask = tf.constant([0, 0, 1, 0, 0, 1, 0, 0], shape=[1, 8], dtype=tf.float32)
51-
mask = tf.tile(mask, [bsz, 1])
52-
jit = tf.random_uniform([bsz, 8], minval=-2, maxval=3, dtype=tf.int32)
53-
jit = tf.cast(jit, tf.float32)
54-
xforms = base + jit * mask
55-
processed_data = tfa.image.transform(images=inp, transforms=xforms)
56-
cropped_data = processed_data[:, 2:-2, 2:-2, :]
57-
return cropped_data
58-
59-
return tf.cond(is_training, lambda: func(input_data), lambda: input_data)
60-
61-
6240
class WideResnet(Model):
6341
"""Resnet implementation from `Realistic Evaluation of Deep SSL Algorithms`.
6442
@@ -289,7 +267,7 @@ def _residual(x,
289267
if self.horizontal_flip:
290268
x = fast_flip(x, is_training=is_train)
291269
if self.random_translation:
292-
x = jitter(x, is_training=is_train)
270+
raise NotImplementedError('Random translations are not implemented yet.')
293271
if self.gaussian_noise:
294272
x = tf.cond(is_train, lambda: x + tf.random_normal(tf.shape(x)) * 0.15,
295273
lambda: x)

research/gam/gam/trainer/trainer_classification_gcn.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
import logging
2222
import os
2323

24-
from .adversarial_dense import entropy_y_x
25-
from .adversarial_dense import get_loss_vat
24+
from .adversarial_sparse import entropy_y_x
25+
from .adversarial_sparse import get_loss_vat
2626
import numpy as np
2727
import tensorflow as tf
2828
from .trainer_base import batch_iterator
@@ -869,7 +869,7 @@ def train(self, data, session=None, **kwargs):
869869

870870
def predict(self, session, indices, is_train):
871871
"""Make predictions for the provided sample indices."""
872-
if not indices:
872+
if len(indices) == 0:
873873
return np.zeros((0, self.data.num_classes), dtype=np.float32)
874874
feed_dict = {
875875
self.input_indices: indices,

0 commit comments

Comments
 (0)