Home | History | Annotate | Download | only in label_image
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import argparse
     21 
     22 import numpy as np
     23 import tensorflow as tf
     24 
     25 
     26 def load_graph(model_file):
     27   graph = tf.Graph()
     28   graph_def = tf.GraphDef()
     29 
     30   with open(model_file, "rb") as f:
     31     graph_def.ParseFromString(f.read())
     32   with graph.as_default():
     33     tf.import_graph_def(graph_def)
     34 
     35   return graph
     36 
     37 
     38 def read_tensor_from_image_file(file_name,
     39                                 input_height=299,
     40                                 input_width=299,
     41                                 input_mean=0,
     42                                 input_std=255):
     43   input_name = "file_reader"
     44   output_name = "normalized"
     45   file_reader = tf.read_file(file_name, input_name)
     46   if file_name.endswith(".png"):
     47     image_reader = tf.image.decode_png(
     48         file_reader, channels=3, name="png_reader")
     49   elif file_name.endswith(".gif"):
     50     image_reader = tf.squeeze(
     51         tf.image.decode_gif(file_reader, name="gif_reader"))
     52   elif file_name.endswith(".bmp"):
     53     image_reader = tf.image.decode_bmp(file_reader, name="bmp_reader")
     54   else:
     55     image_reader = tf.image.decode_jpeg(
     56         file_reader, channels=3, name="jpeg_reader")
     57   float_caster = tf.cast(image_reader, tf.float32)
     58   dims_expander = tf.expand_dims(float_caster, 0)
     59   resized = tf.image.resize_bilinear(dims_expander, [input_height, input_width])
     60   normalized = tf.divide(tf.subtract(resized, [input_mean]), [input_std])
     61   sess = tf.Session()
     62   result = sess.run(normalized)
     63 
     64   return result
     65 
     66 
     67 def load_labels(label_file):
     68   label = []
     69   proto_as_ascii_lines = tf.gfile.GFile(label_file).readlines()
     70   for l in proto_as_ascii_lines:
     71     label.append(l.rstrip())
     72   return label
     73 
     74 
     75 if __name__ == "__main__":
     76   file_name = "tensorflow/examples/label_image/data/grace_hopper.jpg"
     77   model_file = \
     78     "tensorflow/examples/label_image/data/inception_v3_2016_08_28_frozen.pb"
     79   label_file = "tensorflow/examples/label_image/data/imagenet_slim_labels.txt"
     80   input_height = 299
     81   input_width = 299
     82   input_mean = 0
     83   input_std = 255
     84   input_layer = "input"
     85   output_layer = "InceptionV3/Predictions/Reshape_1"
     86 
     87   parser = argparse.ArgumentParser()
     88   parser.add_argument("--image", help="image to be processed")
     89   parser.add_argument("--graph", help="graph/model to be executed")
     90   parser.add_argument("--labels", help="name of file containing labels")
     91   parser.add_argument("--input_height", type=int, help="input height")
     92   parser.add_argument("--input_width", type=int, help="input width")
     93   parser.add_argument("--input_mean", type=int, help="input mean")
     94   parser.add_argument("--input_std", type=int, help="input std")
     95   parser.add_argument("--input_layer", help="name of input layer")
     96   parser.add_argument("--output_layer", help="name of output layer")
     97   args = parser.parse_args()
     98 
     99   if args.graph:
    100     model_file = args.graph
    101   if args.image:
    102     file_name = args.image
    103   if args.labels:
    104     label_file = args.labels
    105   if args.input_height:
    106     input_height = args.input_height
    107   if args.input_width:
    108     input_width = args.input_width
    109   if args.input_mean:
    110     input_mean = args.input_mean
    111   if args.input_std:
    112     input_std = args.input_std
    113   if args.input_layer:
    114     input_layer = args.input_layer
    115   if args.output_layer:
    116     output_layer = args.output_layer
    117 
    118   graph = load_graph(model_file)
    119   t = read_tensor_from_image_file(
    120       file_name,
    121       input_height=input_height,
    122       input_width=input_width,
    123       input_mean=input_mean,
    124       input_std=input_std)
    125 
    126   input_name = "import/" + input_layer
    127   output_name = "import/" + output_layer
    128   input_operation = graph.get_operation_by_name(input_name)
    129   output_operation = graph.get_operation_by_name(output_name)
    130 
    131   with tf.Session(graph=graph) as sess:
    132     results = sess.run(output_operation.outputs[0], {
    133         input_operation.outputs[0]: t
    134     })
    135   results = np.squeeze(results)
    136 
    137   top_k = results.argsort()[-5:][::-1]
    138   labels = load_labels(label_file)
    139   for i in top_k:
    140     print(labels[i], results[i])
    141