1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 import argparse 21 22 import numpy as np 23 import tensorflow as tf 24 25 26 def load_graph(model_file): 27 graph = tf.Graph() 28 graph_def = tf.GraphDef() 29 30 with open(model_file, "rb") as f: 31 graph_def.ParseFromString(f.read()) 32 with graph.as_default(): 33 tf.import_graph_def(graph_def) 34 35 return graph 36 37 38 def read_tensor_from_image_file(file_name, 39 input_height=299, 40 input_width=299, 41 input_mean=0, 42 input_std=255): 43 input_name = "file_reader" 44 output_name = "normalized" 45 file_reader = tf.read_file(file_name, input_name) 46 if file_name.endswith(".png"): 47 image_reader = tf.image.decode_png( 48 file_reader, channels=3, name="png_reader") 49 elif file_name.endswith(".gif"): 50 image_reader = tf.squeeze( 51 tf.image.decode_gif(file_reader, name="gif_reader")) 52 elif file_name.endswith(".bmp"): 53 image_reader = tf.image.decode_bmp(file_reader, name="bmp_reader") 54 else: 55 image_reader = tf.image.decode_jpeg( 56 file_reader, channels=3, name="jpeg_reader") 57 float_caster = tf.cast(image_reader, tf.float32) 58 dims_expander = tf.expand_dims(float_caster, 0) 59 resized = tf.image.resize_bilinear(dims_expander, [input_height, input_width]) 60 normalized = tf.divide(tf.subtract(resized, [input_mean]), [input_std]) 61 sess = tf.Session() 62 result = sess.run(normalized) 63 64 return result 65 66 67 def load_labels(label_file): 68 label = [] 69 proto_as_ascii_lines = tf.gfile.GFile(label_file).readlines() 70 for l in proto_as_ascii_lines: 71 label.append(l.rstrip()) 72 return label 73 74 75 if __name__ == "__main__": 76 file_name = "tensorflow/examples/label_image/data/grace_hopper.jpg" 77 model_file = \ 78 "tensorflow/examples/label_image/data/inception_v3_2016_08_28_frozen.pb" 79 label_file = "tensorflow/examples/label_image/data/imagenet_slim_labels.txt" 80 input_height = 299 81 input_width = 299 82 input_mean = 0 83 input_std = 255 84 input_layer = "input" 85 output_layer = "InceptionV3/Predictions/Reshape_1" 86 87 parser = argparse.ArgumentParser() 88 parser.add_argument("--image", help="image to be processed") 89 parser.add_argument("--graph", help="graph/model to be executed") 90 parser.add_argument("--labels", help="name of file containing labels") 91 parser.add_argument("--input_height", type=int, help="input height") 92 parser.add_argument("--input_width", type=int, help="input width") 93 parser.add_argument("--input_mean", type=int, help="input mean") 94 parser.add_argument("--input_std", type=int, help="input std") 95 parser.add_argument("--input_layer", help="name of input layer") 96 parser.add_argument("--output_layer", help="name of output layer") 97 args = parser.parse_args() 98 99 if args.graph: 100 model_file = args.graph 101 if args.image: 102 file_name = args.image 103 if args.labels: 104 label_file = args.labels 105 if args.input_height: 106 input_height = args.input_height 107 if args.input_width: 108 input_width = args.input_width 109 if args.input_mean: 110 input_mean = args.input_mean 111 if args.input_std: 112 input_std = args.input_std 113 if args.input_layer: 114 input_layer = args.input_layer 115 if args.output_layer: 116 output_layer = args.output_layer 117 118 graph = load_graph(model_file) 119 t = read_tensor_from_image_file( 120 file_name, 121 input_height=input_height, 122 input_width=input_width, 123 input_mean=input_mean, 124 input_std=input_std) 125 126 input_name = "import/" + input_layer 127 output_name = "import/" + output_layer 128 input_operation = graph.get_operation_by_name(input_name) 129 output_operation = graph.get_operation_by_name(output_name) 130 131 with tf.Session(graph=graph) as sess: 132 results = sess.run(output_operation.outputs[0], { 133 input_operation.outputs[0]: t 134 }) 135 results = np.squeeze(results) 136 137 top_k = results.argsort()[-5:][::-1] 138 labels = load_labels(label_file) 139 for i in top_k: 140 print(labels[i], results[i]) 141