@@ -37,6 +37,8 @@ class CNNClassifier(object):
3737 def __init__ (self , model_file , label_file , input_layer = "input" , output_layer = "final_result" , input_height = 128 , input_width = 128 , input_mean = 127.5 , input_std = 127.5 ):
3838 self ._graph = self .load_graph (model_file )
3939 self ._labels = self .load_labels (label_file )
40+ self .input_height = input_height
41+ self .input_width = input_width
4042 input_name = "import/" + input_layer
4143 output_name = "import/" + output_layer
4244 self ._input_operation = self ._graph .get_operation_by_name (input_name );
@@ -88,7 +90,7 @@ def read_tensor_from_image_file(self, file_name, input_height=299, input_width=2
8890
8991 float_caster = tf .cast (image_reader , tf .float32 )
9092 dims_expander = tf .expand_dims (float_caster , 0 );
91- resized = tf .image .resize_bilinear (dims_expander , [input_height , input_width ])
93+ resized = tf .image .resize_bilinear (dims_expander , [self . input_height , self . input_width ])
9294 normalized = tf .divide (tf .subtract (resized , [input_mean ]), [input_std ])
9395 sess = tf .Session ()
9496
@@ -111,21 +113,15 @@ def load_labels(self, label_file):
111113 def classify_image (self ,
112114 image_file_or_mat ,
113115 top_results = 3 ):
114- s_t = time .time ()
115116 t = None
116117 if type (image_file_or_mat ) == str :
117118 t = self .read_tensor_from_image_file (file_name = image_file_or_mat )
118119 else :
119120 t = self .read_tensor_from_image_mat (image_file_or_mat )
120121
121- #logging.info( "time.norm: " + str(time.time() - s_t))
122- s_t = time .time ()
123-
124122 results = self ._session .run (self ._output_operation .outputs [0 ],
125123 {self ._input_operation .outputs [0 ]: t })
126124
127- #logging.info( "time.cls: " + str(time.time() - s_t))
128-
129125 top_results = min (top_results , len (self ._labels ))
130126 results = np .squeeze (results )
131127 results_idx = np .argpartition (results , - top_results )[- top_results :]
0 commit comments