// tensorflow/examples/label_image/main.cc
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // A minimal but useful C++ example showing how to load an Imagenet-style object
     17 // recognition TensorFlow model, prepare input images for it, run them through
     18 // the graph, and interpret the results.
     19 //
     20 // It's designed to have as few dependencies and be as clear as possible, so
     21 // it's more verbose than it could be in production code. In particular, using
     22 // auto for the types of a lot of the returned values from TensorFlow calls can
     23 // remove a lot of boilerplate, but I find the explicit types useful in sample
     24 // code to make it simple to look up the classes involved.
     25 //
     26 // To use it, compile and then run in a working directory with the
     27 // learning/brain/tutorials/label_image/data/ folder below it, and you should
     28 // see the top five labels for the example Lena image output. You can then
     29 // customize it to use your own models or images by changing the file names at
     30 // the top of the main() function.
     31 //
     32 // The googlenet_graph.pb file included by default is created from Inception.
     33 //
     34 // Note that, for GIF inputs, to reuse existing code, only single-frame ones
     35 // are supported.
     36 
     37 #include <fstream>
     38 #include <utility>
     39 #include <vector>
     40 
     41 #include "tensorflow/cc/ops/const_op.h"
     42 #include "tensorflow/cc/ops/image_ops.h"
     43 #include "tensorflow/cc/ops/standard_ops.h"
     44 #include "tensorflow/core/framework/graph.pb.h"
     45 #include "tensorflow/core/framework/tensor.h"
     46 #include "tensorflow/core/graph/default_device.h"
     47 #include "tensorflow/core/graph/graph_def_builder.h"
     48 #include "tensorflow/core/lib/core/errors.h"
     49 #include "tensorflow/core/lib/core/stringpiece.h"
     50 #include "tensorflow/core/lib/core/threadpool.h"
     51 #include "tensorflow/core/lib/io/path.h"
     52 #include "tensorflow/core/lib/strings/str_util.h"
     53 #include "tensorflow/core/lib/strings/stringprintf.h"
     54 #include "tensorflow/core/platform/env.h"
     55 #include "tensorflow/core/platform/init_main.h"
     56 #include "tensorflow/core/platform/logging.h"
     57 #include "tensorflow/core/platform/types.h"
     58 #include "tensorflow/core/public/session.h"
     59 #include "tensorflow/core/util/command_line_flags.h"
     60 
     61 // These are all common classes it's handy to reference with no namespace.
     62 using tensorflow::Flag;
     63 using tensorflow::Tensor;
     64 using tensorflow::Status;
     65 using tensorflow::string;
     66 using tensorflow::int32;
     67 
     68 // Takes a file name, and loads a list of labels from it, one per line, and
     69 // returns a vector of the strings. It pads with empty strings so the length
     70 // of the result is a multiple of 16, because our model expects that.
     71 Status ReadLabelsFile(const string& file_name, std::vector<string>* result,
     72                       size_t* found_label_count) {
     73   std::ifstream file(file_name);
     74   if (!file) {
     75     return tensorflow::errors::NotFound("Labels file ", file_name,
     76                                         " not found.");
     77   }
     78   result->clear();
     79   string line;
     80   while (std::getline(file, line)) {
     81     result->push_back(line);
     82   }
     83   *found_label_count = result->size();
     84   const int padding = 16;
     85   while (result->size() % padding) {
     86     result->emplace_back();
     87   }
     88   return Status::OK();
     89 }
     90 
     91 static Status ReadEntireFile(tensorflow::Env* env, const string& filename,
     92                              Tensor* output) {
     93   tensorflow::uint64 file_size = 0;
     94   TF_RETURN_IF_ERROR(env->GetFileSize(filename, &file_size));
     95 
     96   string contents;
     97   contents.resize(file_size);
     98 
     99   std::unique_ptr<tensorflow::RandomAccessFile> file;
    100   TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename, &file));
    101 
    102   tensorflow::StringPiece data;
    103   TF_RETURN_IF_ERROR(file->Read(0, file_size, &data, &(contents)[0]));
    104   if (data.size() != file_size) {
    105     return tensorflow::errors::DataLoss("Truncated read of '", filename,
    106                                         "' expected ", file_size, " got ",
    107                                         data.size());
    108   }
    109   output->scalar<string>()() = string(data);
    110   return Status::OK();
    111 }
    112 
    113 // Given an image file name, read in the data, try to decode it as an image,
    114 // resize it to the requested size, and then scale the values as desired.
    115 Status ReadTensorFromImageFile(const string& file_name, const int input_height,
    116                                const int input_width, const float input_mean,
    117                                const float input_std,
    118                                std::vector<Tensor>* out_tensors) {
    119   auto root = tensorflow::Scope::NewRootScope();
    120   using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
    121 
    122   string input_name = "file_reader";
    123   string output_name = "normalized";
    124 
    125   // read file_name into a tensor named input
    126   Tensor input(tensorflow::DT_STRING, tensorflow::TensorShape());
    127   TF_RETURN_IF_ERROR(
    128       ReadEntireFile(tensorflow::Env::Default(), file_name, &input));
    129 
    130   // use a placeholder to read input data
    131   auto file_reader =
    132       Placeholder(root.WithOpName("input"), tensorflow::DataType::DT_STRING);
    133 
    134   std::vector<std::pair<string, tensorflow::Tensor>> inputs = {
    135       {"input", input},
    136   };
    137 
    138   // Now try to figure out what kind of file it is and decode it.
    139   const int wanted_channels = 3;
    140   tensorflow::Output image_reader;
    141   if (tensorflow::str_util::EndsWith(file_name, ".png")) {
    142     image_reader = DecodePng(root.WithOpName("png_reader"), file_reader,
    143                              DecodePng::Channels(wanted_channels));
    144   } else if (tensorflow::str_util::EndsWith(file_name, ".gif")) {
    145     // gif decoder returns 4-D tensor, remove the first dim
    146     image_reader =
    147         Squeeze(root.WithOpName("squeeze_first_dim"),
    148                 DecodeGif(root.WithOpName("gif_reader"), file_reader));
    149   } else if (tensorflow::str_util::EndsWith(file_name, ".bmp")) {
    150     image_reader = DecodeBmp(root.WithOpName("bmp_reader"), file_reader);
    151   } else {
    152     // Assume if it's neither a PNG nor a GIF then it must be a JPEG.
    153     image_reader = DecodeJpeg(root.WithOpName("jpeg_reader"), file_reader,
    154                               DecodeJpeg::Channels(wanted_channels));
    155   }
    156   // Now cast the image data to float so we can do normal math on it.
    157   auto float_caster =
    158       Cast(root.WithOpName("float_caster"), image_reader, tensorflow::DT_FLOAT);
    159   // The convention for image ops in TensorFlow is that all images are expected
    160   // to be in batches, so that they're four-dimensional arrays with indices of
    161   // [batch, height, width, channel]. Because we only have a single image, we
    162   // have to add a batch dimension of 1 to the start with ExpandDims().
    163   auto dims_expander = ExpandDims(root, float_caster, 0);
    164   // Bilinearly resize the image to fit the required dimensions.
    165   auto resized = ResizeBilinear(
    166       root, dims_expander,
    167       Const(root.WithOpName("size"), {input_height, input_width}));
    168   // Subtract the mean and divide by the scale.
    169   Div(root.WithOpName(output_name), Sub(root, resized, {input_mean}),
    170       {input_std});
    171 
    172   // This runs the GraphDef network definition that we've just constructed, and
    173   // returns the results in the output tensor.
    174   tensorflow::GraphDef graph;
    175   TF_RETURN_IF_ERROR(root.ToGraphDef(&graph));
    176 
    177   std::unique_ptr<tensorflow::Session> session(
    178       tensorflow::NewSession(tensorflow::SessionOptions()));
    179   TF_RETURN_IF_ERROR(session->Create(graph));
    180   TF_RETURN_IF_ERROR(session->Run({inputs}, {output_name}, {}, out_tensors));
    181   return Status::OK();
    182 }
    183 
    184 // Reads a model graph definition from disk, and creates a session object you
    185 // can use to run it.
    186 Status LoadGraph(const string& graph_file_name,
    187                  std::unique_ptr<tensorflow::Session>* session) {
    188   tensorflow::GraphDef graph_def;
    189   Status load_graph_status =
    190       ReadBinaryProto(tensorflow::Env::Default(), graph_file_name, &graph_def);
    191   if (!load_graph_status.ok()) {
    192     return tensorflow::errors::NotFound("Failed to load compute graph at '",
    193                                         graph_file_name, "'");
    194   }
    195   session->reset(tensorflow::NewSession(tensorflow::SessionOptions()));
    196   Status session_create_status = (*session)->Create(graph_def);
    197   if (!session_create_status.ok()) {
    198     return session_create_status;
    199   }
    200   return Status::OK();
    201 }
    202 
    203 // Analyzes the output of the Inception graph to retrieve the highest scores and
    204 // their positions in the tensor, which correspond to categories.
    205 Status GetTopLabels(const std::vector<Tensor>& outputs, int how_many_labels,
    206                     Tensor* indices, Tensor* scores) {
    207   auto root = tensorflow::Scope::NewRootScope();
    208   using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
    209 
    210   string output_name = "top_k";
    211   TopK(root.WithOpName(output_name), outputs[0], how_many_labels);
    212   // This runs the GraphDef network definition that we've just constructed, and
    213   // returns the results in the output tensors.
    214   tensorflow::GraphDef graph;
    215   TF_RETURN_IF_ERROR(root.ToGraphDef(&graph));
    216 
    217   std::unique_ptr<tensorflow::Session> session(
    218       tensorflow::NewSession(tensorflow::SessionOptions()));
    219   TF_RETURN_IF_ERROR(session->Create(graph));
    220   // The TopK node returns two outputs, the scores and their original indices,
    221   // so we have to append :0 and :1 to specify them both.
    222   std::vector<Tensor> out_tensors;
    223   TF_RETURN_IF_ERROR(session->Run({}, {output_name + ":0", output_name + ":1"},
    224                                   {}, &out_tensors));
    225   *scores = out_tensors[0];
    226   *indices = out_tensors[1];
    227   return Status::OK();
    228 }
    229 
    230 // Given the output of a model run, and the name of a file containing the labels
    231 // this prints out the top five highest-scoring values.
    232 Status PrintTopLabels(const std::vector<Tensor>& outputs,
    233                       const string& labels_file_name) {
    234   std::vector<string> labels;
    235   size_t label_count;
    236   Status read_labels_status =
    237       ReadLabelsFile(labels_file_name, &labels, &label_count);
    238   if (!read_labels_status.ok()) {
    239     LOG(ERROR) << read_labels_status;
    240     return read_labels_status;
    241   }
    242   const int how_many_labels = std::min(5, static_cast<int>(label_count));
    243   Tensor indices;
    244   Tensor scores;
    245   TF_RETURN_IF_ERROR(GetTopLabels(outputs, how_many_labels, &indices, &scores));
    246   tensorflow::TTypes<float>::Flat scores_flat = scores.flat<float>();
    247   tensorflow::TTypes<int32>::Flat indices_flat = indices.flat<int32>();
    248   for (int pos = 0; pos < how_many_labels; ++pos) {
    249     const int label_index = indices_flat(pos);
    250     const float score = scores_flat(pos);
    251     LOG(INFO) << labels[label_index] << " (" << label_index << "): " << score;
    252   }
    253   return Status::OK();
    254 }
    255 
    256 // This is a testing function that returns whether the top label index is the
    257 // one that's expected.
    258 Status CheckTopLabel(const std::vector<Tensor>& outputs, int expected,
    259                      bool* is_expected) {
    260   *is_expected = false;
    261   Tensor indices;
    262   Tensor scores;
    263   const int how_many_labels = 1;
    264   TF_RETURN_IF_ERROR(GetTopLabels(outputs, how_many_labels, &indices, &scores));
    265   tensorflow::TTypes<int32>::Flat indices_flat = indices.flat<int32>();
    266   if (indices_flat(0) != expected) {
    267     LOG(ERROR) << "Expected label #" << expected << " but got #"
    268                << indices_flat(0);
    269     *is_expected = false;
    270   } else {
    271     *is_expected = true;
    272   }
    273   return Status::OK();
    274 }
    275 
    276 int main(int argc, char* argv[]) {
    277   // These are the command-line flags the program can understand.
    278   // They define where the graph and input data is located, and what kind of
    279   // input the model expects. If you train your own model, or use something
    280   // other than inception_v3, then you'll need to update these.
    281   string image = "tensorflow/examples/label_image/data/grace_hopper.jpg";
    282   string graph =
    283       "tensorflow/examples/label_image/data/inception_v3_2016_08_28_frozen.pb";
    284   string labels =
    285       "tensorflow/examples/label_image/data/imagenet_slim_labels.txt";
    286   int32 input_width = 299;
    287   int32 input_height = 299;
    288   float input_mean = 0;
    289   float input_std = 255;
    290   string input_layer = "input";
    291   string output_layer = "InceptionV3/Predictions/Reshape_1";
    292   bool self_test = false;
    293   string root_dir = "";
    294   std::vector<Flag> flag_list = {
    295       Flag("image", &image, "image to be processed"),
    296       Flag("graph", &graph, "graph to be executed"),
    297       Flag("labels", &labels, "name of file containing labels"),
    298       Flag("input_width", &input_width, "resize image to this width in pixels"),
    299       Flag("input_height", &input_height,
    300            "resize image to this height in pixels"),
    301       Flag("input_mean", &input_mean, "scale pixel values to this mean"),
    302       Flag("input_std", &input_std, "scale pixel values to this std deviation"),
    303       Flag("input_layer", &input_layer, "name of input layer"),
    304       Flag("output_layer", &output_layer, "name of output layer"),
    305       Flag("self_test", &self_test, "run a self test"),
    306       Flag("root_dir", &root_dir,
    307            "interpret image and graph file names relative to this directory"),
    308   };
    309   string usage = tensorflow::Flags::Usage(argv[0], flag_list);
    310   const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
    311   if (!parse_result) {
    312     LOG(ERROR) << usage;
    313     return -1;
    314   }
    315 
    316   // We need to call this to set up global state for TensorFlow.
    317   tensorflow::port::InitMain(argv[0], &argc, &argv);
    318   if (argc > 1) {
    319     LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
    320     return -1;
    321   }
    322 
    323   // First we load and initialize the model.
    324   std::unique_ptr<tensorflow::Session> session;
    325   string graph_path = tensorflow::io::JoinPath(root_dir, graph);
    326   Status load_graph_status = LoadGraph(graph_path, &session);
    327   if (!load_graph_status.ok()) {
    328     LOG(ERROR) << load_graph_status;
    329     return -1;
    330   }
    331 
    332   // Get the image from disk as a float array of numbers, resized and normalized
    333   // to the specifications the main graph expects.
    334   std::vector<Tensor> resized_tensors;
    335   string image_path = tensorflow::io::JoinPath(root_dir, image);
    336   Status read_tensor_status =
    337       ReadTensorFromImageFile(image_path, input_height, input_width, input_mean,
    338                               input_std, &resized_tensors);
    339   if (!read_tensor_status.ok()) {
    340     LOG(ERROR) << read_tensor_status;
    341     return -1;
    342   }
    343   const Tensor& resized_tensor = resized_tensors[0];
    344 
    345   // Actually run the image through the model.
    346   std::vector<Tensor> outputs;
    347   Status run_status = session->Run({{input_layer, resized_tensor}},
    348                                    {output_layer}, {}, &outputs);
    349   if (!run_status.ok()) {
    350     LOG(ERROR) << "Running model failed: " << run_status;
    351     return -1;
    352   }
    353 
    354   // This is for automated testing to make sure we get the expected result with
    355   // the default settings. We know that label 653 (military uniform) should be
    356   // the top label for the Admiral Hopper image.
    357   if (self_test) {
    358     bool expected_matches;
    359     Status check_status = CheckTopLabel(outputs, 653, &expected_matches);
    360     if (!check_status.ok()) {
    361       LOG(ERROR) << "Running check failed: " << check_status;
    362       return -1;
    363     }
    364     if (!expected_matches) {
    365       LOG(ERROR) << "Self-test failed!";
    366       return -1;
    367     }
    368   }
    369 
    370   // Do something interesting with the results we've generated.
    371   Status print_status = PrintTopLabels(outputs, labels);
    372   if (!print_status.ok()) {
    373     LOG(ERROR) << "Running print failed: " << print_status;
    374     return -1;
    375   }
    376 
    377   return 0;
    378 }
    379