Home | History | Annotate | Download | only in label_image
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include <cstdarg>
     17 #include <cstdio>
     18 #include <cstdlib>
     19 #include <fstream>
     20 #include <iostream>
     21 #include <memory>
     22 #include <sstream>
     23 #include <string>
     24 #include <unordered_set>
     25 #include <vector>
     26 
     27 #include <fcntl.h>      // NOLINT(build/include_order)
     28 #include <getopt.h>     // NOLINT(build/include_order)
     29 #include <sys/time.h>   // NOLINT(build/include_order)
     30 #include <sys/types.h>  // NOLINT(build/include_order)
     31 #include <sys/uio.h>    // NOLINT(build/include_order)
     32 #include <unistd.h>     // NOLINT(build/include_order)
     33 
     34 #include "tensorflow/contrib/lite/kernels/register.h"
     35 #include "tensorflow/contrib/lite/model.h"
     36 #include "tensorflow/contrib/lite/optional_debug_tools.h"
     37 #include "tensorflow/contrib/lite/string_util.h"
     38 
     39 #include "tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h"
     40 #include "tensorflow/contrib/lite/examples/label_image/get_top_n.h"
     41 
// Minimal logging stand-in: every LOG(severity) expands to std::cerr, so the
// severity argument is ignored entirely. In particular LOG(FATAL) only prints;
// it does not terminate the process, so callers must exit explicitly.
#define LOG(x) std::cerr
     43 
     44 namespace tflite {
     45 namespace label_image {
     46 
     47 double get_us(struct timeval t) { return (t.tv_sec * 1000000 + t.tv_usec); }
     48 
     49 // Takes a file name, and loads a list of labels from it, one per line, and
     50 // returns a vector of the strings. It pads with empty strings so the length
     51 // of the result is a multiple of 16, because our model expects that.
     52 TfLiteStatus ReadLabelsFile(const string& file_name,
     53                             std::vector<string>* result,
     54                             size_t* found_label_count) {
     55   std::ifstream file(file_name);
     56   if (!file) {
     57     LOG(FATAL) << "Labels file " << file_name << " not found\n";
     58     return kTfLiteError;
     59   }
     60   result->clear();
     61   string line;
     62   while (std::getline(file, line)) {
     63     result->push_back(line);
     64   }
     65   *found_label_count = result->size();
     66   const int padding = 16;
     67   while (result->size() % padding) {
     68     result->emplace_back();
     69   }
     70   return kTfLiteOk;
     71 }
     72 
     73 void RunInference(Settings* s) {
     74   if (!s->model_name.c_str()) {
     75     LOG(ERROR) << "no model file name\n";
     76     exit(-1);
     77   }
     78 
     79   std::unique_ptr<tflite::FlatBufferModel> model;
     80   std::unique_ptr<tflite::Interpreter> interpreter;
     81   model = tflite::FlatBufferModel::BuildFromFile(s->model_name.c_str());
     82   if (!model) {
     83     LOG(FATAL) << "\nFailed to mmap model " << s->model_name << "\n";
     84     exit(-1);
     85   }
     86   LOG(INFO) << "Loaded model " << s->model_name << "\n";
     87   model->error_reporter();
     88   LOG(INFO) << "resolved reporter\n";
     89 
     90   tflite::ops::builtin::BuiltinOpResolver resolver;
     91 
     92   tflite::InterpreterBuilder(*model, resolver)(&interpreter);
     93   if (!interpreter) {
     94     LOG(FATAL) << "Failed to construct interpreter\n";
     95     exit(-1);
     96   }
     97 
     98   interpreter->UseNNAPI(s->accel);
     99 
    100   if (s->verbose) {
    101     LOG(INFO) << "tensors size: " << interpreter->tensors_size() << "\n";
    102     LOG(INFO) << "nodes size: " << interpreter->nodes_size() << "\n";
    103     LOG(INFO) << "inputs: " << interpreter->inputs().size() << "\n";
    104     LOG(INFO) << "input(0) name: " << interpreter->GetInputName(0) << "\n";
    105 
    106     int t_size = interpreter->tensors_size();
    107     for (int i = 0; i < t_size; i++) {
    108       if (interpreter->tensor(i)->name)
    109         LOG(INFO) << i << ": " << interpreter->tensor(i)->name << ", "
    110                   << interpreter->tensor(i)->bytes << ", "
    111                   << interpreter->tensor(i)->type << ", "
    112                   << interpreter->tensor(i)->params.scale << ", "
    113                   << interpreter->tensor(i)->params.zero_point << "\n";
    114     }
    115   }
    116 
    117   if (s->number_of_threads != -1) {
    118     interpreter->SetNumThreads(s->number_of_threads);
    119   }
    120 
    121   int image_width = 224;
    122   int image_height = 224;
    123   int image_channels = 3;
    124   uint8_t* in = read_bmp(s->input_bmp_name, &image_width, &image_height,
    125                          &image_channels, s);
    126 
    127   int input = interpreter->inputs()[0];
    128   if (s->verbose) LOG(INFO) << "input: " << input << "\n";
    129 
    130   const std::vector<int> inputs = interpreter->inputs();
    131   const std::vector<int> outputs = interpreter->outputs();
    132 
    133   if (s->verbose) {
    134     LOG(INFO) << "number of inputs: " << inputs.size() << "\n";
    135     LOG(INFO) << "number of outputs: " << outputs.size() << "\n";
    136   }
    137 
    138   if (interpreter->AllocateTensors() != kTfLiteOk) {
    139     LOG(FATAL) << "Failed to allocate tensors!";
    140   }
    141 
    142   if (s->verbose) PrintInterpreterState(interpreter.get());
    143 
    144   // get input dimension from the input tensor metadata
    145   // assuming one input only
    146   TfLiteIntArray* dims = interpreter->tensor(input)->dims;
    147   int wanted_height = dims->data[1];
    148   int wanted_width = dims->data[2];
    149   int wanted_channels = dims->data[3];
    150 
    151   switch (interpreter->tensor(input)->type) {
    152     case kTfLiteFloat32:
    153       s->input_floating = true;
    154       resize<float>(interpreter->typed_tensor<float>(input), in, image_height,
    155                     image_width, image_channels, wanted_height, wanted_width,
    156                     wanted_channels, s);
    157       break;
    158     case kTfLiteUInt8:
    159       resize<uint8_t>(interpreter->typed_tensor<uint8_t>(input), in,
    160                       image_height, image_width, image_channels, wanted_height,
    161                       wanted_width, wanted_channels, s);
    162       break;
    163     default:
    164       LOG(FATAL) << "cannot handle input type "
    165                  << interpreter->tensor(input)->type << " yet";
    166       exit(-1);
    167   }
    168 
    169   struct timeval start_time, stop_time;
    170   gettimeofday(&start_time, NULL);
    171   for (int i = 0; i < s->loop_count; i++) {
    172     if (interpreter->Invoke() != kTfLiteOk) {
    173       LOG(FATAL) << "Failed to invoke tflite!\n";
    174     }
    175   }
    176   gettimeofday(&stop_time, NULL);
    177   LOG(INFO) << "invoked \n";
    178   LOG(INFO) << "average time: "
    179             << (get_us(stop_time) - get_us(start_time)) / (s->loop_count * 1000)
    180             << " ms \n";
    181 
    182   const int output_size = 1000;
    183   const size_t num_results = 5;
    184   const float threshold = 0.001f;
    185 
    186   std::vector<std::pair<float, int>> top_results;
    187 
    188   int output = interpreter->outputs()[0];
    189   switch (interpreter->tensor(output)->type) {
    190     case kTfLiteFloat32:
    191       get_top_n<float>(interpreter->typed_output_tensor<float>(0), output_size,
    192                        num_results, threshold, &top_results, true);
    193       break;
    194     case kTfLiteUInt8:
    195       get_top_n<uint8_t>(interpreter->typed_output_tensor<uint8_t>(0),
    196                          output_size, num_results, threshold, &top_results,
    197                          false);
    198       break;
    199     default:
    200       LOG(FATAL) << "cannot handle output type "
    201                  << interpreter->tensor(input)->type << " yet";
    202       exit(-1);
    203   }
    204 
    205   std::vector<string> labels;
    206   size_t label_count;
    207 
    208   if (ReadLabelsFile(s->labels_file_name, &labels, &label_count) != kTfLiteOk)
    209     exit(-1);
    210 
    211   for (const auto& result : top_results) {
    212     const float confidence = result.first;
    213     const int index = result.second;
    214     LOG(INFO) << confidence << ": " << index << " " << labels[index] << "\n";
    215   }
    216 }
    217 
    218 void display_usage() {
    219   LOG(INFO) << "label_image\n"
    220             << "--accelerated, -a: [0|1], use Android NNAPI or note\n"
    221             << "--count, -c: loop interpreter->Invoke() for certain times\n"
    222             << "--input_mean, -b: input mean\n"
    223             << "--input_std, -s: input standard deviation\n"
    224             << "--image, -i: image_name.bmp\n"
    225             << "--labels, -l: labels for the model\n"
    226             << "--tflite_model, -m: model_name.tflite\n"
    227             << "--threads, -t: number of threads\n"
    228             << "--verbose, -v: [0|1] print more information\n"
    229             << "\n";
    230 }
    231 
    232 int Main(int argc, char** argv) {
    233   Settings s;
    234 
    235   int c;
    236   while (1) {
    237     static struct option long_options[] = {
    238         {"accelerated", required_argument, 0, 'a'},
    239         {"count", required_argument, 0, 'c'},
    240         {"verbose", required_argument, 0, 'v'},
    241         {"image", required_argument, 0, 'i'},
    242         {"labels", required_argument, 0, 'l'},
    243         {"tflite_model", required_argument, 0, 'm'},
    244         {"threads", required_argument, 0, 't'},
    245         {"input_mean", required_argument, 0, 'b'},
    246         {"input_std", required_argument, 0, 's'},
    247         {0, 0, 0, 0}};
    248 
    249     /* getopt_long stores the option index here. */
    250     int option_index = 0;
    251 
    252     c = getopt_long(argc, argv, "a:b:c:f:i:l:m:s:t:v:", long_options,
    253                     &option_index);
    254 
    255     /* Detect the end of the options. */
    256     if (c == -1) break;
    257 
    258     switch (c) {
    259       case 'a':
    260         s.accel = strtol(  // NOLINT(runtime/deprecated_fn)
    261             optarg, (char**)NULL, 10);
    262         break;
    263       case 'b':
    264         s.input_mean = strtod(optarg, NULL);
    265         break;
    266       case 'c':
    267         s.loop_count = strtol(  // NOLINT(runtime/deprecated_fn)
    268             optarg, (char**)NULL, 10);
    269         break;
    270       case 'i':
    271         s.input_bmp_name = optarg;
    272         break;
    273       case 'l':
    274         s.labels_file_name = optarg;
    275         break;
    276       case 'm':
    277         s.model_name = optarg;
    278         break;
    279       case 's':
    280         s.input_std = strtod(optarg, NULL);
    281         break;
    282       case 't':
    283         s.number_of_threads = strtol(  // NOLINT(runtime/deprecated_fn)
    284             optarg, (char**)NULL, 10);
    285         break;
    286       case 'v':
    287         s.verbose = strtol(  // NOLINT(runtime/deprecated_fn)
    288             optarg, (char**)NULL, 10);
    289         break;
    290       case 'h':
    291       case '?':
    292         /* getopt_long already printed an error message. */
    293         display_usage();
    294         exit(-1);
    295       default:
    296         exit(-1);
    297     }
    298   }
    299   RunInference(&s);
    300   return 0;
    301 }
    302 
    303 }  // namespace label_image
    304 }  // namespace tflite
    305 
    306 int main(int argc, char** argv) {
    307   return tflite::label_image::Main(argc, argv);
    308 }
    309