Skip to content

Commit c6e35ab

Browse files
Merge pull request #49 from otiliastr:master
PiperOrigin-RevId: 301199341
2 parents 76b0a59 + c74dd54 commit c6e35ab

File tree

8 files changed

+175
-35
lines changed

8 files changed

+175
-35
lines changed

research/gam/README.md

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,23 +34,48 @@ folder on a strict "as is" basis, without warranties or conditions of any kind.
3434
Also, these implementations may not be compatible with certain TensorFlow
3535
versions (such as 2.0 or above) or Python versions.
3636

37+
More details can be found in our
38+
[paper](https://papers.nips.cc/paper/9076-graph-agreement-models-for-semi-supervised-learning.pdf),
39+
[supplementary material](https://papers.nips.cc/paper/9076-graph-agreement-models-for-semi-supervised-learning-supplemental.zip),
40+
[slides](https://drive.google.com/open?id=1tWEMoyrbLnzfSfTfYFi9eWgZWaPKF3Uu) or
41+
[poster](https://drive.google.com/file/d/1BZNR4B-xM41hdLLqx4mLsQ4KKJOhjgqV/view).
42+
3743
## How to run
3844

3945
To run GAM on a graph-based dataset (e.g., Cora, Citeseer, Pubmed), from this
40-
folder run: `bash python3.7 -m gam.experiments.run_train_gam_graph
46+
folder run: `$ python3.7 -m gam.experiments.run_train_gam_graph
4147
--data_path=<path_to_data>`
4248

43-
To run GAM on datasets without a graph (e.g., CIFAR10), from this folder run:
44-
`bash python3.7 -m gam.experiments.run_train_gam`
49+
To run GAM on datasets without a graph (e.g., CIFAR10), from this folder run: `$
50+
python3.7 -m gam.experiments.run_train_gam`
51+
52+
We recommend running on a GPU. With CUDA, this can be done by prepending
53+
`CUDA_VISIBLE_DEVICES=<your-gpu-number>` in front of the run command.
4554

4655
For running on different datasets and configuration, please check the command
47-
line flags in each of the run scripts.
56+
line flags in each of the run scripts. The configurations used in our paper can
57+
be found in the file `run_configs.txt`.
58+
59+
## Visualizing the results.
60+
61+
To visualize the results in Tensorboard, use the following command, adjusting
62+
the dataset name accordingly: `$ tensorboard --logdir=outputs/summaries/cora`
63+
64+
An example of such visualization for Cora with GCN + GAM model on the Pubmed
65+
dataset is the following:
66+
![Tensorboard plot](gam_gcn_pubmed.png?raw=true "GCN + GAM on Pubmed")
67+
68+
Similarly, we can run with multiple different parameter configurations and plot
69+
the results together for comparison. An example showing the accuracy per
70+
co-train iteration of a GCN + GAM model on the Cora dataset for 3 runs with 3
71+
different random seeds is the following:
72+
![Tensorboard plot](gam_gcn_cora_multiple_seeds.png?raw=true "GCN + GAM on Cora")
4873

4974
## References
5075

51-
[[1] O. Stretcu, K. Viswanathan, D. Movshovitz-Attias, E.A. Platanios, A.
52-
Tomkins, S. Ravi. "Graph Agreement Models for Semi-Supervised Learning." NeurIPS
53-
2019](https://nips.cc/Conferences/2019/Schedule?showEvent=13925)
76+
[[1] O. Stretcu, K. Viswanathan, D. Movshovitz-Attias, E.A. Platanios, S. Ravi,
77+
A. Tomkins. "Graph Agreement Models for Semi-Supervised Learning." NeurIPS
78+
2019](https://papers.nips.cc/paper/9076-graph-agreement-models-for-semi-supervised-learning)
5479

5580
[[2] T. Bui, S. Ravi and V. Ramavajjala. "Neural Graph Learning: Training Neural
5681
Networks Using Graphs." WSDM 2018](https://research.google/pubs/pub46568.pdf)

research/gam/gam/experiments/helper.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def get_model_cls(model_name,
6262
num_classes=data.num_classes,
6363
lrelu_leakiness=0.1,
6464
horizontal_flip=dataset_name in ('cifar10',),
65-
random_translation=True,
65+
random_translation=False,
6666
gaussian_noise=dataset_name not in ('svhn', 'svhn_cropped'),
6767
width=2,
6868
num_residual_units=4,
@@ -125,7 +125,7 @@ def get_model_agr(model_name,
125125
num_classes=1,
126126
lrelu_leakiness=0.1,
127127
horizontal_flip=dataset_name in ('cifar10',),
128-
random_translation=True,
128+
random_translation=False,
129129
gaussian_noise=dataset_name not in ('svhn', 'svhn_cropped'),
130130
width=2,
131131
num_residual_units=4,

research/gam/gam/models/models_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ class attribute. The valid options are:
142142
else:
143143
raise NotImplementedError()
144144

145-
def _project(self, inputs, reuse=tf.AUTO_REUSE):
145+
def _project(self, inputs, reuse=tf.compat.v1.AUTO_REUSE):
146146
"""Projects the provided inputs using a multilayer perceptron.
147147
148148
Arguments:

research/gam/gam/models/wide_resnet.py

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222

2323
import numpy as np
2424
import tensorflow as tf
25-
import tensorflow_addons as tfa
2625

2726

2827
def fast_flip(images, is_training):
@@ -38,27 +37,6 @@ def func(inp):
3837
return tf.cond(is_training, lambda: func(images), lambda: images)
3938

4039

41-
def jitter(input_data, is_training):
42-
"""Applies random noise to input data when training."""
43-
44-
def func(inp):
45-
"""Wrap functionality in a subfunction."""
46-
bsz = tf.shape(inp)[0]
47-
inp = tf.pad(inp, [[0, 0], [2, 2], [2, 2], [0, 0]], mode="REFLECT")
48-
base = tf.constant([1, 0, 0, 0, 1, 0, 0, 0], shape=[1, 8], dtype=tf.float32)
49-
base = tf.tile(base, [bsz, 1])
50-
mask = tf.constant([0, 0, 1, 0, 0, 1, 0, 0], shape=[1, 8], dtype=tf.float32)
51-
mask = tf.tile(mask, [bsz, 1])
52-
jit = tf.random_uniform([bsz, 8], minval=-2, maxval=3, dtype=tf.int32)
53-
jit = tf.cast(jit, tf.float32)
54-
xforms = base + jit * mask
55-
processed_data = tfa.image.transform(images=inp, transforms=xforms)
56-
cropped_data = processed_data[:, 2:-2, 2:-2, :]
57-
return cropped_data
58-
59-
return tf.cond(is_training, lambda: func(input_data), lambda: input_data)
60-
61-
6240
class WideResnet(Model):
6341
"""Resnet implementation from `Realistic Evaluation of Deep SSL Algorithms`.
6442
@@ -289,7 +267,7 @@ def _residual(x,
289267
if self.horizontal_flip:
290268
x = fast_flip(x, is_training=is_train)
291269
if self.random_translation:
292-
x = jitter(x, is_training=is_train)
270+
raise NotImplementedError("Random translations are not implemented yet.")
293271
if self.gaussian_noise:
294272
x = tf.cond(is_train, lambda: x + tf.random_normal(tf.shape(x)) * 0.15,
295273
lambda: x)

research/gam/gam/trainer/trainer_classification_gcn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
import logging
2222
import os
2323

24-
from .adversarial_dense import entropy_y_x
25-
from .adversarial_dense import get_loss_vat
24+
from .adversarial_sparse import entropy_y_x
25+
from .adversarial_sparse import get_loss_vat
2626
import numpy as np
2727
import tensorflow as tf
2828
from .trainer_base import batch_iterator
193 KB
Loading

research/gam/gam_gcn_pubmed.png

106 KB
Loading

0 commit comments

Comments
 (0)