Skip to content

Commit 70f0c23

Browse files
Merge pull request #1 from otiliastr:import-fixes
PiperOrigin-RevId: 267162545
2 parents 128a3ed + ee3246b commit 70f0c23

File tree

8 files changed

+46
-41
lines changed

8 files changed

+46
-41
lines changed

neural_structured_learning/research/gam/data/loaders.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,16 @@
1515

1616
from __future__ import absolute_import
1717
from __future__ import division
18-
from __future__ import google_type_annotations
1918
from __future__ import print_function
2019

2120
import json
2221
import logging
2322
import pickle
2423

25-
from gam.data import convert_image
26-
from gam.data import FixedDataset
27-
from gam.data import split_train_val_unlabeled
24+
from gam.data.dataset import FixedDataset
25+
from gam.data.preprocessing import convert_image
26+
from gam.data.preprocessing import split_train_val_unlabeled
27+
2828
import numpy as np
2929
import tensorflow_datasets as tfds
3030

neural_structured_learning/research/gam/experiments/run_train_mnist.py

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,13 @@
2929
from absl import app
3030
from absl import flags
3131

32-
from gam.data import Dataset
33-
from gam.data import load_data_realistic_ssl
34-
from gam.data import load_data_tf_datasets
35-
from gam.models import ImageCNNAgreement
36-
from gam.models import MLP
37-
from gam.models import WideResnet
38-
from gam.trainer import TrainerCotraining
32+
from gam.data.dataset import Dataset
33+
from gam.data.loaders import load_data_realistic_ssl
34+
from gam.data.loaders import load_data_tf_datasets
35+
from gam.models.cnn import ImageCNNAgreement
36+
from gam.models.mlp import MLP
37+
from gam.models.wide_resnet import WideResnet
38+
from gam.trainer.trainer_cotrain import TrainerCotraining
3939
import numpy as np
4040
import tensorflow as tf
4141

@@ -73,10 +73,10 @@
7373
'Path to the json files containing the label sample indices for '
7474
'Realistic SSL.')
7575
flags.DEFINE_string(
76-
'data_output_dir', '',
76+
'data_output_dir', './outputs',
7777
'Path to a folder where to save the preprocessed dataset.')
7878
flags.DEFINE_string(
79-
'output_dir', '',
79+
'output_dir', './outputs',
8080
'Path to a folder where checkpoints, summaries and other outputs are '
8181
'stored.')
8282
flags.DEFINE_string(
@@ -101,14 +101,18 @@
101101
flags.DEFINE_float(
102102
'learning_rate_decay_agr', None,
103103
'Learning rate decay factor for the agreement model.')
104-
flags.DEFINE_float('lr_decay_rate_cls', None,
105-
'Learning rate decay rate for the classification model.')
106-
flags.DEFINE_integer('lr_decay_steps_cls', None,
107-
'Learning rate decay steps for the classification model.')
108-
flags.DEFINE_float('lr_decay_rate_agr', None,
109-
'Learning rate decay rate for the agreement model.')
110-
flags.DEFINE_integer('lr_decay_steps_agr', None,
111-
'Learning rate decay steps for the agreement model.')
104+
flags.DEFINE_float(
105+
'lr_decay_rate_cls', None,
106+
'Learning rate decay rate for the classification model.')
107+
flags.DEFINE_integer(
108+
'lr_decay_steps_cls', None,
109+
'Learning rate decay steps for the classification model.')
110+
flags.DEFINE_float(
111+
'lr_decay_rate_agr', None,
112+
'Learning rate decay rate for the agreement model.')
113+
flags.DEFINE_integer(
114+
'lr_decay_steps_agr', None,
115+
'Learning rate decay steps for the agreement model.')
112116
flags.DEFINE_integer(
113117
'num_epochs_per_decay_cls', 350,
114118
'Number of epochs after which the learning rate decays for the '
@@ -178,8 +182,9 @@
178182
'weight_decay_schedule_cls', None,
179183
'Schedule for decaying the weight decay in the classification model. '
180184
'Choose bewteen None or linear.')
181-
flags.DEFINE_float('weight_decay_agr', 0,
182-
'Weight of the L2 penalty on the agreement model weights.')
185+
flags.DEFINE_float(
186+
'weight_decay_agr', 0,
187+
'Weight of the L2 penalty on the agreement model weights.')
183188
flags.DEFINE_string(
184189
'weight_decay_schedule_agr', None,
185190
'Schedule for decaying the weight decay in the agreement model. Choose '
@@ -438,9 +443,9 @@ def main(argv):
438443
checkpoints_dir = os.path.join(FLAGS.output_dir, 'checkpoints', model_name)
439444
data_dir = os.path.join(FLAGS.data_output_dir, 'data_checkpoints', model_name)
440445
if not os.path.exists(checkpoints_dir):
441-
os.makedir(checkpoints_dir)
446+
os.makedirs(checkpoints_dir)
442447
if not os.path.exists(data_dir):
443-
os.makedir(data_dir)
448+
os.makedirs(data_dir)
444449

445450
# Select the model based on the provided FLAGS.
446451
model_cls, model_agr = pick_model(data)

neural_structured_learning/research/gam/models/cnn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from __future__ import division
2222
from __future__ import print_function
2323

24-
from gam.models import Model
24+
from gam.models.models_base import Model
2525
import numpy as np
2626
import tensorflow as tf
2727

neural_structured_learning/research/gam/models/mlp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
from __future__ import division
1818
from __future__ import print_function
1919

20-
from gam.models import glorot
21-
from gam.models import Model
20+
from gam.models.models_base import glorot
21+
from gam.models.models_base import Model
2222
import numpy as np
2323
import tensorflow as tf
2424

neural_structured_learning/research/gam/models/wide_resnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
and this Github repository:
1919
https://github.com/brain-research/realistic-ssl-evaluation
2020
"""
21-
from gam.models import Model
21+
from gam.models.models_base import Model
2222

2323
import numpy as np
2424
import tensorflow as tf

neural_structured_learning/research/gam/trainer/trainer_agreement.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@
2727
import logging
2828
import os
2929

30-
from gam.trainer import batch_iterator
31-
from gam.trainer import Trainer
30+
from gam.trainer.trainer_base import batch_iterator
31+
from gam.trainer.trainer_base import Trainer
3232

3333
import numpy as np
3434
import tensorflow as tf

neural_structured_learning/research/gam/trainer/trainer_classification.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
import logging
2222
import os
2323

24-
from gam.trainer import batch_iterator
25-
from gam.trainer import Trainer
24+
from gam.trainer.trainer_base import batch_iterator
25+
from gam.trainer.trainer_base import Trainer
2626

2727
import numpy as np
2828
import tensorflow as tf

neural_structured_learning/research/gam/trainer/trainer_cotrain.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,12 @@
3030
import logging
3131
import os
3232

33-
from gam.data import CotrainDataset
34-
from gam.trainer import Trainer
35-
from gam.trainer import TrainerAgreement
36-
from gam.trainer import TrainerClassification
37-
from gam.trainer import TrainerPerfectAgreement
38-
from gam.trainer import TrainerPerfectClassification
33+
from gam.data.dataset import CotrainDataset
34+
from gam.trainer.trainer_agreement import TrainerAgreement
35+
from gam.trainer.trainer_agreement import TrainerPerfectAgreement
36+
from gam.trainer.trainer_base import Trainer
37+
from gam.trainer.trainer_classification import TrainerClassification
38+
from gam.trainer.trainer_classification import TrainerPerfectClassification
3939

4040
import numpy as np
4141
import tensorflow as tf
@@ -530,10 +530,10 @@ def train(self, data, **kwargs):
530530
# If a checkpoint with the variables already exists, we restore them.
531531
if self.checkpoints_dir:
532532
checkpts_path_cotrain = os.path.join(self.checkpoints_dir, 'cotrain.ckpt')
533-
if os.path.exists(checkpts_path_cotrain+'.index'):
533+
if os.path.exists(checkpts_path_cotrain):
534534
saver.restore(session, checkpts_path_cotrain)
535535
else:
536-
os.makedir(checkpts_path_cotrain)
536+
os.makedirs(checkpts_path_cotrain)
537537
else:
538538
checkpts_path_cotrain = None
539539

0 commit comments

Comments
 (0)