@@ -142,7 +142,7 @@ def __init__(self, manager, architecture):
142142 self .validation_batch_size = 100
143143 self .print_misclassified_test_images = False
144144 self .bottleneck_dir = "/tmp/bottleneck"
145- self .model_dir = "/tmp/imagenet "
145+ self .model_dir = "./cnn_models/cache "
146146 self .final_tensor_name = "final_result"
147147 self .write_logs = False
148148
@@ -160,7 +160,7 @@ def __init__(self, manager, architecture):
160160 raise Exception ("Did not recognize architecture flag'" )
161161
162162 # Set up the pre-trained graph.
163- self .maybe_download_and_extract (self .model_info ['data_url' ])
163+ self .maybe_download_and_extract (self .model_info ['data_url' ], self . model_info [ 'model_dir_name' ] )
164164 self .graph , self .bottleneck_tensor , self .resized_image_tensor = (
165165 self .create_model_graph (self .model_info ))
166166
@@ -517,7 +517,7 @@ def run_bottleneck_on_image(self, sess, image_data, image_data_tensor,
517517 return bottleneck_values
518518
519519
520- def maybe_download_and_extract (self , data_url ):
520+ def maybe_download_and_extract (self , data_url , model_dir_name ):
521521 """Download and extract model tar file.
522522
523523 If the pretrained model we're using doesn't already exist, this function
@@ -526,7 +526,7 @@ def maybe_download_and_extract(self, data_url):
526526 Args:
527527 data_url: Web location of the tar file containing the pretrained model.
528528 """
529- dest_directory = self .model_dir
529+ dest_directory = os . path . join ( self .model_dir , model_dir_name )
530530 if not os .path .exists (dest_directory ):
531531 os .makedirs (dest_directory )
532532 filename = data_url .split ('/' )[- 1 ]
@@ -538,11 +538,10 @@ def _progress(count, block_size, total_size):
538538 (filename ,
539539 float (count * block_size ) / float (total_size ) * 100.0 ))
540540 sys .stdout .flush ()
541-
542541 filepath , _ = urllib .request .urlretrieve (data_url , filepath , _progress )
543542 print ()
544543 statinfo = os .stat (filepath )
545- tf .logging .info ('Successfully downloaded' , filename , statinfo .st_size ,
544+ tf .logging .info ('Successfully downloaded %s %d ' , filename , statinfo .st_size ,
546545 'bytes.' )
547546 tarfile .open (filepath , 'r:gz' ).extractall (dest_directory )
548547
@@ -1084,47 +1083,64 @@ def create_model_info(self, architecture):
10841083 tf .logging .error ("Couldn't understand architecture name '%s'" ,
10851084 architecture )
10861085 return None
1087- version_string = parts [1 ]
1086+ v_string = parts [1 ]
1087+ version_string = parts [2 ]
10881088 if (version_string != '1.0' and version_string != '0.75' and
1089- version_string != '0.50' and version_string != '0.25' ):
1089+ version_string != '0.50' and version_string != '0.5' and
1090+ version_string != '0.35' and version_string != '0.25' ):
10901091 tf .logging .error (
1091- """"The Mobilenet version should be '1.0', '0.75', '0.50', or '0.25',
1092+ """"The Mobilenet version should be '1.0', '0.75', '0.50', '0.35', or '0.25',
10921093 but found '%s' for architecture '%s'""" ,
10931094 version_string , architecture )
10941095 return None
1095- size_string = parts [2 ]
1096+ size_string = parts [3 ]
10961097 if (size_string != '224' and size_string != '192' and
10971098 size_string != '160' and size_string != '128' ):
10981099 tf .logging .error (
10991100 """The Mobilenet input size should be '224', '192', '160', or '128',
11001101 but found '%s' for architecture '%s'""" ,
11011102 size_string , architecture )
11021103 return None
1103- if len (parts ) == 3 :
1104+ if len (parts ) == 4 :
11041105 is_quantized = False
11051106 else :
1106- if parts [3 ] != 'quantized' :
1107+ if parts [4 ] != 'quantized' :
11071108 tf .logging .error (
11081109 "Couldn't understand architecture suffix '%s' for '%s'" , parts [3 ],
11091110 architecture )
11101111 return None
11111112 is_quantized = True
1112- data_url = 'http://download.tensorflow.org/models/mobilenet_v1_'
1113- data_url += version_string + '_' + size_string + '_frozen.tgz'
1114- bottleneck_tensor_name = 'MobilenetV1/Predictions/Reshape:0'
1113+ data_url = 'http://'
1114+ model_file_name = None
1115+ bottleneck_tensor_name = None
1116+ if architecture .startswith ('mobilenet_v1' ):
1117+ data_url += 'download.tensorflow.org/models/mobilenet_v1_'
1118+ data_url += version_string + '_' + size_string + '_frozen.tgz'
1119+ bottleneck_tensor_name = 'MobilenetV1/Predictions/Reshape:0'
1120+ if is_quantized :
1121+ model_base_name = 'quantized_graph.pb'
1122+ else :
1123+ model_base_name = 'frozen_graph.pb'
1124+ model_dir_name = 'mobilenet_v1_'
1125+ model_dir_name += version_string + '_' + size_string
1126+ model_file_name = os .path .join (model_dir_name , model_base_name )
1127+ model_dir_name = ''
1128+ else :
1129+ data_url += 'storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_'
1130+ data_url += version_string + '_' + size_string + '.tgz'
1131+ bottleneck_tensor_name = 'MobilenetV2/Predictions/Reshape:0'
1132+ model_dir_name = 'mobilenet_v2_'
1133+ model_dir_name += version_string + '_' + size_string
1134+ model_base_name = model_dir_name + '_frozen.pb'
1135+ model_file_name = os .path .join (model_dir_name , model_base_name )
11151136 bottleneck_tensor_size = 1001
11161137 input_width = int (size_string )
11171138 input_height = int (size_string )
11181139 input_depth = 3
11191140 resized_input_tensor_name = 'input:0'
1120- if is_quantized :
1121- model_base_name = 'quantized_graph.pb'
1122- else :
1123- model_base_name = 'frozen_graph.pb'
1124- model_dir_name = 'mobilenet_v1_' + version_string + '_' + size_string
1125- model_file_name = os .path .join (model_dir_name , model_base_name )
11261141 input_mean = 127.5
11271142 input_std = 127.5
1143+ print (data_url )
11281144 else :
11291145 tf .logging .error ("Couldn't understand architecture name '%s'" , architecture )
11301146 raise ValueError ('Unknown architecture' , architecture )
@@ -1138,6 +1154,7 @@ def create_model_info(self, architecture):
11381154 'input_depth' : input_depth ,
11391155 'resized_input_tensor_name' : resized_input_tensor_name ,
11401156 'model_file_name' : model_file_name ,
1157+ 'model_dir_name' : model_dir_name ,
11411158 'input_mean' : input_mean ,
11421159 'input_std' : input_std ,
11431160 }
@@ -1170,11 +1187,3 @@ def add_jpeg_decoding(self, input_width, input_height, input_depth, input_mean,
11701187 mul_image = tf .multiply (offset_image , 1.0 / input_std )
11711188 return jpeg_data , mul_image
11721189
1173-
1174- if __name__ == '__main__' :
1175- cnn_trainer = CNNTrainer ("mobilenet_0.50_128" )
1176- cnn_trainer .retrain (
1177- image_dir = "/home/pi/tensorflow/data/applekiwi" ,
1178- output_graph = "./cnn_models/applewiki_0_5_128.pb" ,
1179- training_steps = 10 ,
1180- learning_rate = 0.1 )
0 commit comments