@@ -31,6 +31,8 @@ class CNNClassifier(object):
3131 def __init__ (self , model_file , label_file , input_layer = "input" , output_layer = "final_result" , input_height = 128 , input_width = 128 , input_mean = 127.5 , input_std = 127.5 ):
3232 self ._graph = self .load_graph (model_file )
3333 self ._labels = self .load_labels (label_file )
34+ self .input_height = input_height
35+ self .input_width = input_width
3436 input_name = "import/" + input_layer
3537 output_name = "import/" + output_layer
3638 self ._input_operation = self ._graph .get_operation_by_name (input_name )
@@ -78,8 +80,8 @@ def read_tensor_from_image_file(self, file_name, input_height=299, input_width=2
7880 image_reader = tf .image .decode_jpeg (file_reader , channels = 3 , name = 'jpeg_reader' )
7981
8082 float_caster = tf .cast (image_reader , tf .float32 )
81- dims_expander = tf .expand_dims (float_caster , 0 )
82- resized = tf .image .resize_bilinear (dims_expander , [input_height , input_width ])
83+ dims_expander = tf .expand_dims (float_caster , 0 );
84+ resized = tf .image .resize_bilinear (dims_expander , [self . input_height , self . input_width ])
8385 normalized = tf .divide (tf .subtract (resized , [input_mean ]), [input_std ])
8486 sess = tf .Session ()
8587
@@ -108,7 +110,6 @@ def classify_image(self,
108110 else :
109111 t = self .read_tensor_from_image_mat (image_file_or_mat )
110112
111-
112113 results = self ._session .run (self ._output_operation .outputs [0 ],
113114 {self ._input_operation .outputs [0 ]: t })
114115
0 commit comments