// tensorflow/examples/label_image/main.cc
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // A minimal but useful C++ example showing how to load an Imagenet-style object
     17 // recognition TensorFlow model, prepare input images for it, run them through
     18 // the graph, and interpret the results.
     19 //
     20 // It's designed to have as few dependencies and be as clear as possible, so
     21 // it's more verbose than it could be in production code. In particular, using
     22 // auto for the types of a lot of the returned values from TensorFlow calls can
     23 // remove a lot of boilerplate, but I find the explicit types useful in sample
     24 // code to make it simple to look up the classes involved.
     25 //
     26 // To use it, compile and then run in a working directory with the
     27 // learning/brain/tutorials/label_image/data/ folder below it, and you should
// see the top five labels for the example Grace Hopper image output. You can
     29 // customize it to use your own models or images by changing the file names at
     30 // the top of the main() function.
     31 //
     32 // The googlenet_graph.pb file included by default is created from Inception.
     33 //
     34 // Note that, for GIF inputs, to reuse existing code, only single-frame ones
     35 // are supported.
     36 
     37 #include <fstream>
     38 #include <utility>
     39 #include <vector>
     40 
     41 #include "tensorflow/cc/ops/const_op.h"
     42 #include "tensorflow/cc/ops/image_ops.h"
     43 #include "tensorflow/cc/ops/standard_ops.h"
     44 #include "tensorflow/core/framework/graph.pb.h"
     45 #include "tensorflow/core/framework/tensor.h"
     46 #include "tensorflow/core/graph/default_device.h"
     47 #include "tensorflow/core/graph/graph_def_builder.h"
     48 #include "tensorflow/core/lib/core/errors.h"
     49 #include "tensorflow/core/lib/core/stringpiece.h"
     50 #include "tensorflow/core/lib/core/threadpool.h"
     51 #include "tensorflow/core/lib/io/path.h"
     52 #include "tensorflow/core/lib/strings/stringprintf.h"
     53 #include "tensorflow/core/platform/env.h"
     54 #include "tensorflow/core/platform/init_main.h"
     55 #include "tensorflow/core/platform/logging.h"
     56 #include "tensorflow/core/platform/types.h"
     57 #include "tensorflow/core/public/session.h"
     58 #include "tensorflow/core/util/command_line_flags.h"
     59 
     60 // These are all common classes it's handy to reference with no namespace.
     61 using tensorflow::Flag;
     62 using tensorflow::Tensor;
     63 using tensorflow::Status;
     64 using tensorflow::string;
     65 using tensorflow::int32;
     66 
     67 // Takes a file name, and loads a list of labels from it, one per line, and
     68 // returns a vector of the strings. It pads with empty strings so the length
     69 // of the result is a multiple of 16, because our model expects that.
     70 Status ReadLabelsFile(const string& file_name, std::vector<string>* result,
     71                       size_t* found_label_count) {
     72   std::ifstream file(file_name);
     73   if (!file) {
     74     return tensorflow::errors::NotFound("Labels file ", file_name,
     75                                         " not found.");
     76   }
     77   result->clear();
     78   string line;
     79   while (std::getline(file, line)) {
     80     result->push_back(line);
     81   }
     82   *found_label_count = result->size();
     83   const int padding = 16;
     84   while (result->size() % padding) {
     85     result->emplace_back();
     86   }
     87   return Status::OK();
     88 }
     89 
     90 static Status ReadEntireFile(tensorflow::Env* env, const string& filename,
     91                              Tensor* output) {
     92   tensorflow::uint64 file_size = 0;
     93   TF_RETURN_IF_ERROR(env->GetFileSize(filename, &file_size));
     94 
     95   string contents;
     96   contents.resize(file_size);
     97 
     98   std::unique_ptr<tensorflow::RandomAccessFile> file;
     99   TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename, &file));
    100 
    101   tensorflow::StringPiece data;
    102   TF_RETURN_IF_ERROR(file->Read(0, file_size, &data, &(contents)[0]));
    103   if (data.size() != file_size) {
    104     return tensorflow::errors::DataLoss("Truncated read of '", filename,
    105                                         "' expected ", file_size, " got ",
    106                                         data.size());
    107   }
    108   output->scalar<string>()() = data.ToString();
    109   return Status::OK();
    110 }
    111 
// Given an image file name, read in the data, try to decode it as an image,
// resize it to the requested size, and then scale the values as desired.
//
// Builds a small TensorFlow graph (decode -> cast -> batch -> resize ->
// normalize) in a fresh root scope, runs it in a throwaway session, and
// appends the single normalized float image tensor to |out_tensors|.
// Note: any op-construction error is accumulated inside |root| and surfaces
// from ToGraphDef() below, so intermediate calls are not individually checked.
Status ReadTensorFromImageFile(const string& file_name, const int input_height,
                               const int input_width, const float input_mean,
                               const float input_std,
                               std::vector<Tensor>* out_tensors) {
  auto root = tensorflow::Scope::NewRootScope();
  using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)

  string input_name = "file_reader";
  string output_name = "normalized";

  // read file_name into a tensor named input
  Tensor input(tensorflow::DT_STRING, tensorflow::TensorShape());
  TF_RETURN_IF_ERROR(
      ReadEntireFile(tensorflow::Env::Default(), file_name, &input));

  // use a placeholder to read input data
  // The node name "input" must match the feed key used in |inputs| below.
  auto file_reader =
      Placeholder(root.WithOpName("input"), tensorflow::DataType::DT_STRING);

  std::vector<std::pair<string, tensorflow::Tensor>> inputs = {
      {"input", input},
  };

  // Now try to figure out what kind of file it is and decode it.
  // The format is chosen by (case-sensitive) file-name suffix, not by
  // sniffing the bytes; anything unrecognized falls through to JPEG.
  const int wanted_channels = 3;
  tensorflow::Output image_reader;
  if (tensorflow::StringPiece(file_name).ends_with(".png")) {
    image_reader = DecodePng(root.WithOpName("png_reader"), file_reader,
                             DecodePng::Channels(wanted_channels));
  } else if (tensorflow::StringPiece(file_name).ends_with(".gif")) {
    // gif decoder returns 4-D tensor, remove the first dim
    // (Only single-frame GIFs are supported; see the header comment.
    // Note the GIF and BMP decoders take no channels attribute here.)
    image_reader =
        Squeeze(root.WithOpName("squeeze_first_dim"),
                DecodeGif(root.WithOpName("gif_reader"), file_reader));
  } else if (tensorflow::StringPiece(file_name).ends_with(".bmp")) {
    image_reader = DecodeBmp(root.WithOpName("bmp_reader"), file_reader);
  } else {
    // Assume if it's neither a PNG nor a GIF then it must be a JPEG.
    image_reader = DecodeJpeg(root.WithOpName("jpeg_reader"), file_reader,
                              DecodeJpeg::Channels(wanted_channels));
  }
  // Now cast the image data to float so we can do normal math on it.
  auto float_caster =
      Cast(root.WithOpName("float_caster"), image_reader, tensorflow::DT_FLOAT);
  // The convention for image ops in TensorFlow is that all images are expected
  // to be in batches, so that they're four-dimensional arrays with indices of
  // [batch, height, width, channel]. Because we only have a single image, we
  // have to add a batch dimension of 1 to the start with ExpandDims().
  auto dims_expander = ExpandDims(root, float_caster, 0);
  // Bilinearly resize the image to fit the required dimensions.
  auto resized = ResizeBilinear(
      root, dims_expander,
      Const(root.WithOpName("size"), {input_height, input_width}));
  // Subtract the mean and divide by the scale.
  // This final op carries |output_name| so it can be fetched below.
  Div(root.WithOpName(output_name), Sub(root, resized, {input_mean}),
      {input_std});

  // This runs the GraphDef network definition that we've just constructed, and
  // returns the results in the output tensor.
  tensorflow::GraphDef graph;
  TF_RETURN_IF_ERROR(root.ToGraphDef(&graph));

  std::unique_ptr<tensorflow::Session> session(
      tensorflow::NewSession(tensorflow::SessionOptions()));
  TF_RETURN_IF_ERROR(session->Create(graph));
  TF_RETURN_IF_ERROR(session->Run({inputs}, {output_name}, {}, out_tensors));
  return Status::OK();
}
    182 
    183 // Reads a model graph definition from disk, and creates a session object you
    184 // can use to run it.
    185 Status LoadGraph(const string& graph_file_name,
    186                  std::unique_ptr<tensorflow::Session>* session) {
    187   tensorflow::GraphDef graph_def;
    188   Status load_graph_status =
    189       ReadBinaryProto(tensorflow::Env::Default(), graph_file_name, &graph_def);
    190   if (!load_graph_status.ok()) {
    191     return tensorflow::errors::NotFound("Failed to load compute graph at '",
    192                                         graph_file_name, "'");
    193   }
    194   session->reset(tensorflow::NewSession(tensorflow::SessionOptions()));
    195   Status session_create_status = (*session)->Create(graph_def);
    196   if (!session_create_status.ok()) {
    197     return session_create_status;
    198   }
    199   return Status::OK();
    200 }
    201 
    202 // Analyzes the output of the Inception graph to retrieve the highest scores and
    203 // their positions in the tensor, which correspond to categories.
    204 Status GetTopLabels(const std::vector<Tensor>& outputs, int how_many_labels,
    205                     Tensor* indices, Tensor* scores) {
    206   auto root = tensorflow::Scope::NewRootScope();
    207   using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
    208 
    209   string output_name = "top_k";
    210   TopK(root.WithOpName(output_name), outputs[0], how_many_labels);
    211   // This runs the GraphDef network definition that we've just constructed, and
    212   // returns the results in the output tensors.
    213   tensorflow::GraphDef graph;
    214   TF_RETURN_IF_ERROR(root.ToGraphDef(&graph));
    215 
    216   std::unique_ptr<tensorflow::Session> session(
    217       tensorflow::NewSession(tensorflow::SessionOptions()));
    218   TF_RETURN_IF_ERROR(session->Create(graph));
    219   // The TopK node returns two outputs, the scores and their original indices,
    220   // so we have to append :0 and :1 to specify them both.
    221   std::vector<Tensor> out_tensors;
    222   TF_RETURN_IF_ERROR(session->Run({}, {output_name + ":0", output_name + ":1"},
    223                                   {}, &out_tensors));
    224   *scores = out_tensors[0];
    225   *indices = out_tensors[1];
    226   return Status::OK();
    227 }
    228 
    229 // Given the output of a model run, and the name of a file containing the labels
    230 // this prints out the top five highest-scoring values.
    231 Status PrintTopLabels(const std::vector<Tensor>& outputs,
    232                       const string& labels_file_name) {
    233   std::vector<string> labels;
    234   size_t label_count;
    235   Status read_labels_status =
    236       ReadLabelsFile(labels_file_name, &labels, &label_count);
    237   if (!read_labels_status.ok()) {
    238     LOG(ERROR) << read_labels_status;
    239     return read_labels_status;
    240   }
    241   const int how_many_labels = std::min(5, static_cast<int>(label_count));
    242   Tensor indices;
    243   Tensor scores;
    244   TF_RETURN_IF_ERROR(GetTopLabels(outputs, how_many_labels, &indices, &scores));
    245   tensorflow::TTypes<float>::Flat scores_flat = scores.flat<float>();
    246   tensorflow::TTypes<int32>::Flat indices_flat = indices.flat<int32>();
    247   for (int pos = 0; pos < how_many_labels; ++pos) {
    248     const int label_index = indices_flat(pos);
    249     const float score = scores_flat(pos);
    250     LOG(INFO) << labels[label_index] << " (" << label_index << "): " << score;
    251   }
    252   return Status::OK();
    253 }
    254 
    255 // This is a testing function that returns whether the top label index is the
    256 // one that's expected.
    257 Status CheckTopLabel(const std::vector<Tensor>& outputs, int expected,
    258                      bool* is_expected) {
    259   *is_expected = false;
    260   Tensor indices;
    261   Tensor scores;
    262   const int how_many_labels = 1;
    263   TF_RETURN_IF_ERROR(GetTopLabels(outputs, how_many_labels, &indices, &scores));
    264   tensorflow::TTypes<int32>::Flat indices_flat = indices.flat<int32>();
    265   if (indices_flat(0) != expected) {
    266     LOG(ERROR) << "Expected label #" << expected << " but got #"
    267                << indices_flat(0);
    268     *is_expected = false;
    269   } else {
    270     *is_expected = true;
    271   }
    272   return Status::OK();
    273 }
    274 
    275 int main(int argc, char* argv[]) {
    276   // These are the command-line flags the program can understand.
    277   // They define where the graph and input data is located, and what kind of
    278   // input the model expects. If you train your own model, or use something
    279   // other than inception_v3, then you'll need to update these.
    280   string image = "tensorflow/examples/label_image/data/grace_hopper.jpg";
    281   string graph =
    282       "tensorflow/examples/label_image/data/inception_v3_2016_08_28_frozen.pb";
    283   string labels =
    284       "tensorflow/examples/label_image/data/imagenet_slim_labels.txt";
    285   int32 input_width = 299;
    286   int32 input_height = 299;
    287   float input_mean = 0;
    288   float input_std = 255;
    289   string input_layer = "input";
    290   string output_layer = "InceptionV3/Predictions/Reshape_1";
    291   bool self_test = false;
    292   string root_dir = "";
    293   std::vector<Flag> flag_list = {
    294       Flag("image", &image, "image to be processed"),
    295       Flag("graph", &graph, "graph to be executed"),
    296       Flag("labels", &labels, "name of file containing labels"),
    297       Flag("input_width", &input_width, "resize image to this width in pixels"),
    298       Flag("input_height", &input_height,
    299            "resize image to this height in pixels"),
    300       Flag("input_mean", &input_mean, "scale pixel values to this mean"),
    301       Flag("input_std", &input_std, "scale pixel values to this std deviation"),
    302       Flag("input_layer", &input_layer, "name of input layer"),
    303       Flag("output_layer", &output_layer, "name of output layer"),
    304       Flag("self_test", &self_test, "run a self test"),
    305       Flag("root_dir", &root_dir,
    306            "interpret image and graph file names relative to this directory"),
    307   };
    308   string usage = tensorflow::Flags::Usage(argv[0], flag_list);
    309   const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
    310   if (!parse_result) {
    311     LOG(ERROR) << usage;
    312     return -1;
    313   }
    314 
    315   // We need to call this to set up global state for TensorFlow.
    316   tensorflow::port::InitMain(argv[0], &argc, &argv);
    317   if (argc > 1) {
    318     LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
    319     return -1;
    320   }
    321 
    322   // First we load and initialize the model.
    323   std::unique_ptr<tensorflow::Session> session;
    324   string graph_path = tensorflow::io::JoinPath(root_dir, graph);
    325   Status load_graph_status = LoadGraph(graph_path, &session);
    326   if (!load_graph_status.ok()) {
    327     LOG(ERROR) << load_graph_status;
    328     return -1;
    329   }
    330 
    331   // Get the image from disk as a float array of numbers, resized and normalized
    332   // to the specifications the main graph expects.
    333   std::vector<Tensor> resized_tensors;
    334   string image_path = tensorflow::io::JoinPath(root_dir, image);
    335   Status read_tensor_status =
    336       ReadTensorFromImageFile(image_path, input_height, input_width, input_mean,
    337                               input_std, &resized_tensors);
    338   if (!read_tensor_status.ok()) {
    339     LOG(ERROR) << read_tensor_status;
    340     return -1;
    341   }
    342   const Tensor& resized_tensor = resized_tensors[0];
    343 
    344   // Actually run the image through the model.
    345   std::vector<Tensor> outputs;
    346   Status run_status = session->Run({{input_layer, resized_tensor}},
    347                                    {output_layer}, {}, &outputs);
    348   if (!run_status.ok()) {
    349     LOG(ERROR) << "Running model failed: " << run_status;
    350     return -1;
    351   }
    352 
    353   // This is for automated testing to make sure we get the expected result with
    354   // the default settings. We know that label 653 (military uniform) should be
    355   // the top label for the Admiral Hopper image.
    356   if (self_test) {
    357     bool expected_matches;
    358     Status check_status = CheckTopLabel(outputs, 653, &expected_matches);
    359     if (!check_status.ok()) {
    360       LOG(ERROR) << "Running check failed: " << check_status;
    361       return -1;
    362     }
    363     if (!expected_matches) {
    364       LOG(ERROR) << "Self-test failed!";
    365       return -1;
    366     }
    367   }
    368 
    369   // Do something interesting with the results we've generated.
    370   Status print_status = PrintTopLabels(outputs, labels);
    371   if (!print_status.ok()) {
    372     LOG(ERROR) << "Running print failed: " << print_status;
    373     return -1;
    374   }
    375 
    376   return 0;
    377 }
    378