1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <cstdarg> 17 #include <cstdio> 18 #include <cstdlib> 19 #include <fstream> 20 #include <iostream> 21 #include <memory> 22 #include <sstream> 23 #include <string> 24 #include <unordered_set> 25 #include <vector> 26 27 #include <fcntl.h> // NOLINT(build/include_order) 28 #include <getopt.h> // NOLINT(build/include_order) 29 #include <sys/time.h> // NOLINT(build/include_order) 30 #include <sys/types.h> // NOLINT(build/include_order) 31 #include <sys/uio.h> // NOLINT(build/include_order) 32 #include <unistd.h> // NOLINT(build/include_order) 33 34 #include "tensorflow/contrib/lite/kernels/register.h" 35 #include "tensorflow/contrib/lite/model.h" 36 #include "tensorflow/contrib/lite/optional_debug_tools.h" 37 #include "tensorflow/contrib/lite/string_util.h" 38 39 #include "tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h" 40 #include "tensorflow/contrib/lite/examples/label_image/get_top_n.h" 41 42 #define LOG(x) std::cerr 43 44 namespace tflite { 45 namespace label_image { 46 47 double get_us(struct timeval t) { return (t.tv_sec * 1000000 + t.tv_usec); } 48 49 // Takes a file name, and loads a list of labels from it, one per line, and 50 // returns a vector of the strings. It pads with empty strings so the length 51 // of the result is a multiple of 16, because our model expects that. 52 TfLiteStatus ReadLabelsFile(const string& file_name, 53 std::vector<string>* result, 54 size_t* found_label_count) { 55 std::ifstream file(file_name); 56 if (!file) { 57 LOG(FATAL) << "Labels file " << file_name << " not found\n"; 58 return kTfLiteError; 59 } 60 result->clear(); 61 string line; 62 while (std::getline(file, line)) { 63 result->push_back(line); 64 } 65 *found_label_count = result->size(); 66 const int padding = 16; 67 while (result->size() % padding) { 68 result->emplace_back(); 69 } 70 return kTfLiteOk; 71 } 72 73 void RunInference(Settings* s) { 74 if (!s->model_name.c_str()) { 75 LOG(ERROR) << "no model file name\n"; 76 exit(-1); 77 } 78 79 std::unique_ptr<tflite::FlatBufferModel> model; 80 std::unique_ptr<tflite::Interpreter> interpreter; 81 model = tflite::FlatBufferModel::BuildFromFile(s->model_name.c_str()); 82 if (!model) { 83 LOG(FATAL) << "\nFailed to mmap model " << s->model_name << "\n"; 84 exit(-1); 85 } 86 LOG(INFO) << "Loaded model " << s->model_name << "\n"; 87 model->error_reporter(); 88 LOG(INFO) << "resolved reporter\n"; 89 90 tflite::ops::builtin::BuiltinOpResolver resolver; 91 92 tflite::InterpreterBuilder(*model, resolver)(&interpreter); 93 if (!interpreter) { 94 LOG(FATAL) << "Failed to construct interpreter\n"; 95 exit(-1); 96 } 97 98 interpreter->UseNNAPI(s->accel); 99 100 if (s->verbose) { 101 LOG(INFO) << "tensors size: " << interpreter->tensors_size() << "\n"; 102 LOG(INFO) << "nodes size: " << interpreter->nodes_size() << "\n"; 103 LOG(INFO) << "inputs: " << interpreter->inputs().size() << "\n"; 104 LOG(INFO) << "input(0) name: " << interpreter->GetInputName(0) << "\n"; 105 106 int t_size = interpreter->tensors_size(); 107 for (int i = 0; i < t_size; i++) { 108 if (interpreter->tensor(i)->name) 109 LOG(INFO) << i << ": " << interpreter->tensor(i)->name << ", " 110 << interpreter->tensor(i)->bytes << ", " 111 << interpreter->tensor(i)->type << ", " 112 << interpreter->tensor(i)->params.scale << ", " 113 << interpreter->tensor(i)->params.zero_point << "\n"; 114 } 115 } 116 117 if (s->number_of_threads != -1) { 118 interpreter->SetNumThreads(s->number_of_threads); 119 } 120 121 int image_width = 224; 122 int image_height = 224; 123 int image_channels = 3; 124 uint8_t* in = read_bmp(s->input_bmp_name, &image_width, &image_height, 125 &image_channels, s); 126 127 int input = interpreter->inputs()[0]; 128 if (s->verbose) LOG(INFO) << "input: " << input << "\n"; 129 130 const std::vector<int> inputs = interpreter->inputs(); 131 const std::vector<int> outputs = interpreter->outputs(); 132 133 if (s->verbose) { 134 LOG(INFO) << "number of inputs: " << inputs.size() << "\n"; 135 LOG(INFO) << "number of outputs: " << outputs.size() << "\n"; 136 } 137 138 if (interpreter->AllocateTensors() != kTfLiteOk) { 139 LOG(FATAL) << "Failed to allocate tensors!"; 140 } 141 142 if (s->verbose) PrintInterpreterState(interpreter.get()); 143 144 // get input dimension from the input tensor metadata 145 // assuming one input only 146 TfLiteIntArray* dims = interpreter->tensor(input)->dims; 147 int wanted_height = dims->data[1]; 148 int wanted_width = dims->data[2]; 149 int wanted_channels = dims->data[3]; 150 151 switch (interpreter->tensor(input)->type) { 152 case kTfLiteFloat32: 153 s->input_floating = true; 154 resize<float>(interpreter->typed_tensor<float>(input), in, image_height, 155 image_width, image_channels, wanted_height, wanted_width, 156 wanted_channels, s); 157 break; 158 case kTfLiteUInt8: 159 resize<uint8_t>(interpreter->typed_tensor<uint8_t>(input), in, 160 image_height, image_width, image_channels, wanted_height, 161 wanted_width, wanted_channels, s); 162 break; 163 default: 164 LOG(FATAL) << "cannot handle input type " 165 << interpreter->tensor(input)->type << " yet"; 166 exit(-1); 167 } 168 169 struct timeval start_time, stop_time; 170 gettimeofday(&start_time, NULL); 171 for (int i = 0; i < s->loop_count; i++) { 172 if (interpreter->Invoke() != kTfLiteOk) { 173 LOG(FATAL) << "Failed to invoke tflite!\n"; 174 } 175 } 176 gettimeofday(&stop_time, NULL); 177 LOG(INFO) << "invoked \n"; 178 LOG(INFO) << "average time: " 179 << (get_us(stop_time) - get_us(start_time)) / (s->loop_count * 1000) 180 << " ms \n"; 181 182 const int output_size = 1000; 183 const size_t num_results = 5; 184 const float threshold = 0.001f; 185 186 std::vector<std::pair<float, int>> top_results; 187 188 int output = interpreter->outputs()[0]; 189 switch (interpreter->tensor(output)->type) { 190 case kTfLiteFloat32: 191 get_top_n<float>(interpreter->typed_output_tensor<float>(0), output_size, 192 num_results, threshold, &top_results, true); 193 break; 194 case kTfLiteUInt8: 195 get_top_n<uint8_t>(interpreter->typed_output_tensor<uint8_t>(0), 196 output_size, num_results, threshold, &top_results, 197 false); 198 break; 199 default: 200 LOG(FATAL) << "cannot handle output type " 201 << interpreter->tensor(input)->type << " yet"; 202 exit(-1); 203 } 204 205 std::vector<string> labels; 206 size_t label_count; 207 208 if (ReadLabelsFile(s->labels_file_name, &labels, &label_count) != kTfLiteOk) 209 exit(-1); 210 211 for (const auto& result : top_results) { 212 const float confidence = result.first; 213 const int index = result.second; 214 LOG(INFO) << confidence << ": " << index << " " << labels[index] << "\n"; 215 } 216 } 217 218 void display_usage() { 219 LOG(INFO) << "label_image\n" 220 << "--accelerated, -a: [0|1], use Android NNAPI or note\n" 221 << "--count, -c: loop interpreter->Invoke() for certain times\n" 222 << "--input_mean, -b: input mean\n" 223 << "--input_std, -s: input standard deviation\n" 224 << "--image, -i: image_name.bmp\n" 225 << "--labels, -l: labels for the model\n" 226 << "--tflite_model, -m: model_name.tflite\n" 227 << "--threads, -t: number of threads\n" 228 << "--verbose, -v: [0|1] print more information\n" 229 << "\n"; 230 } 231 232 int Main(int argc, char** argv) { 233 Settings s; 234 235 int c; 236 while (1) { 237 static struct option long_options[] = { 238 {"accelerated", required_argument, 0, 'a'}, 239 {"count", required_argument, 0, 'c'}, 240 {"verbose", required_argument, 0, 'v'}, 241 {"image", required_argument, 0, 'i'}, 242 {"labels", required_argument, 0, 'l'}, 243 {"tflite_model", required_argument, 0, 'm'}, 244 {"threads", required_argument, 0, 't'}, 245 {"input_mean", required_argument, 0, 'b'}, 246 {"input_std", required_argument, 0, 's'}, 247 {0, 0, 0, 0}}; 248 249 /* getopt_long stores the option index here. */ 250 int option_index = 0; 251 252 c = getopt_long(argc, argv, "a:b:c:f:i:l:m:s:t:v:", long_options, 253 &option_index); 254 255 /* Detect the end of the options. */ 256 if (c == -1) break; 257 258 switch (c) { 259 case 'a': 260 s.accel = strtol( // NOLINT(runtime/deprecated_fn) 261 optarg, (char**)NULL, 10); 262 break; 263 case 'b': 264 s.input_mean = strtod(optarg, NULL); 265 break; 266 case 'c': 267 s.loop_count = strtol( // NOLINT(runtime/deprecated_fn) 268 optarg, (char**)NULL, 10); 269 break; 270 case 'i': 271 s.input_bmp_name = optarg; 272 break; 273 case 'l': 274 s.labels_file_name = optarg; 275 break; 276 case 'm': 277 s.model_name = optarg; 278 break; 279 case 's': 280 s.input_std = strtod(optarg, NULL); 281 break; 282 case 't': 283 s.number_of_threads = strtol( // NOLINT(runtime/deprecated_fn) 284 optarg, (char**)NULL, 10); 285 break; 286 case 'v': 287 s.verbose = strtol( // NOLINT(runtime/deprecated_fn) 288 optarg, (char**)NULL, 10); 289 break; 290 case 'h': 291 case '?': 292 /* getopt_long already printed an error message. */ 293 display_usage(); 294 exit(-1); 295 default: 296 exit(-1); 297 } 298 } 299 RunInference(&s); 300 return 0; 301 } 302 303 } // namespace label_image 304 } // namespace tflite 305 306 int main(int argc, char** argv) { 307 return tflite::label_image::Main(argc, argv); 308 } 309