Home | History | Annotate | Download | only in graph_transforms
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // This program prints out a summary of a GraphDef file's contents, listing
     17 // things that are useful for debugging and reusing the model it contains. For
     18 // example it looks at the graph structure and op types to figure out likely
     19 // input and output nodes, and shows which ops are used by the graph. To use it,
     20 // run something like this:
     21 //
     22 // bazel build tensorflow/tools/graph_transforms:summarize_graph
     23 // bazel-bin/tensorflow/tools/graph_transforms/summarize_graph \
     24 // --in_graph=my_graph.pb
     25 
     26 #include "tensorflow/core/framework/attr_value.pb.h"
     27 #include "tensorflow/core/framework/function.pb.h"
     28 #include "tensorflow/core/framework/node_def.pb.h"
     29 #include "tensorflow/core/framework/tensor.h"
     30 #include "tensorflow/core/framework/tensor_shape.pb.h"
     31 #include "tensorflow/core/lib/strings/str_util.h"
     32 #include "tensorflow/core/platform/env.h"
     33 #include "tensorflow/core/platform/init_main.h"
     34 #include "tensorflow/core/platform/logging.h"
     35 #include "tensorflow/core/util/command_line_flags.h"
     36 #include "tensorflow/tools/graph_transforms/file_utils.h"
     37 #include "tensorflow/tools/graph_transforms/transform_utils.h"
     38 
     39 namespace tensorflow {
     40 namespace graph_transforms {
     41 namespace {
     42 
     43 void PrintNodeInfo(const NodeDef* node) {
     44   string shape_description = "None";
     45   if (node->attr().count("shape")) {
     46     TensorShapeProto shape_proto = node->attr().at("shape").shape();
     47     Status shape_status = PartialTensorShape::IsValidShape(shape_proto);
     48     if (shape_status.ok()) {
     49       shape_description = PartialTensorShape(shape_proto).DebugString();
     50     } else {
     51       shape_description = shape_status.error_message();
     52     }
     53   }
     54   DataType dtype = DT_INVALID;
     55   if (node->attr().count("dtype")) {
     56     dtype = node->attr().at("dtype").type();
     57   }
     58   std::cout << "(name=" << node->name();
     59   std::cout << ", type=" << DataTypeString(dtype) << "(" << dtype << ")";
     60   std::cout << ", shape=" << shape_description << ") ";
     61 }
     62 
     63 void PrintBenchmarkUsage(const std::vector<const NodeDef*>& placeholders,
     64                          const std::vector<const NodeDef*>& variables,
     65                          const std::vector<const NodeDef*> outputs,
     66                          const string& graph_path) {
     67   std::vector<const NodeDef*> all_inputs(placeholders);
     68   all_inputs.insert(all_inputs.end(), variables.begin(), variables.end());
     69 
     70   std::vector<string> input_layers;
     71   std::vector<string> input_layer_types;
     72   std::vector<string> input_layer_shapes;
     73   for (const NodeDef* node : all_inputs) {
     74     input_layers.push_back(node->name());
     75     DataType dtype = DT_INVALID;
     76     if (node->attr().count("dtype")) {
     77       dtype = node->attr().at("dtype").type();
     78     }
     79     input_layer_types.push_back(DataTypeString(dtype));
     80     std::vector<int64> sizes;
     81     PartialTensorShape shape;
     82     if (node->attr().count("shape")) {
     83       TensorShapeProto shape_proto = node->attr().at("shape").shape();
     84       if (PartialTensorShape::IsValid(shape_proto)) {
     85         shape = PartialTensorShape(shape_proto);
     86       }
     87     }
     88     string sizes_string;
     89     if (shape.dims() == -1) {
     90       // Unknown shapes can have -1 for dims, so leave these blank.
     91       sizes_string = "";
     92     } else {
     93       sizes.reserve(shape.dims());
     94       for (int i = 0; i < shape.dims(); ++i) {
     95         sizes.push_back(shape.dim_size(i));
     96       }
     97       sizes_string = str_util::Join(sizes, ",");
     98     }
     99     input_layer_shapes.push_back(sizes_string);
    100   }
    101   std::vector<string> output_layers;
    102   output_layers.reserve(outputs.size());
    103   for (const NodeDef* node : outputs) {
    104     output_layers.push_back(node->name());
    105   }
    106   string input_layer_value = str_util::Join(input_layers, ",");
    107   string input_layer_type_value = str_util::Join(input_layer_types, ",");
    108   string input_layer_shape_value = str_util::Join(input_layer_shapes, ":");
    109   string output_layer_value = str_util::Join(output_layers, ",");
    110 
    111   std::cout << "To use with tensorflow/tools/benchmark:benchmark_model try "
    112                "these arguments:"
    113             << std::endl;
    114   std::cout << "bazel run tensorflow/tools/benchmark:benchmark_model --";
    115   std::cout << " --graph=" << graph_path;
    116   std::cout << " --show_flops";
    117   std::cout << " --input_layer=" << input_layer_value;
    118   std::cout << " --input_layer_type=" << input_layer_type_value;
    119   std::cout << " --input_layer_shape=" << input_layer_shape_value;
    120   std::cout << " --output_layer=" << output_layer_value;
    121   std::cout << std::endl;
    122 }
    123 
    124 Status PrintStructure(const GraphDef& graph) {
    125   GraphDef sorted_graph;
    126   TF_RETURN_IF_ERROR(SortByExecutionOrder(graph, &sorted_graph));
    127   for (const NodeDef& node : sorted_graph.node()) {
    128     std::cout << node.name() << " (" << node.op() << "): ["
    129               << str_util::Join(node.input(), ", ") << "]";
    130     if (node.op() == "Const") {
    131       Tensor tensor;
    132       if (node.attr().count("value") &&
    133           tensor.FromProto(node.attr().at("value").tensor())) {
    134         std::cout << ", value=" << tensor.DebugString();
    135       } else {
    136         LOG(WARNING) << "Decoding Tensor failed for node" << node.name();
    137       }
    138     }
    139     std::cout << std::endl;
    140   }
    141   return Status::OK();
    142 }
    143 
    144 Status SummarizeGraph(const GraphDef& graph, const string& graph_path,
    145                       bool print_structure) {
    146   std::vector<const NodeDef*> placeholders;
    147   std::vector<const NodeDef*> variables;
    148   for (const NodeDef& node : graph.node()) {
    149     if (node.op() == "Placeholder") {
    150       placeholders.push_back(&node);
    151     }
    152     if (node.op() == "Variable" || node.op() == "VariableV2") {
    153       variables.push_back(&node);
    154     }
    155   }
    156 
    157   if (placeholders.empty()) {
    158     std::cout << "No inputs spotted." << std::endl;
    159   } else {
    160     std::cout << "Found " << placeholders.size() << " possible inputs: ";
    161     for (const NodeDef* node : placeholders) {
    162       PrintNodeInfo(node);
    163     }
    164     std::cout << std::endl;
    165   }
    166 
    167   if (variables.empty()) {
    168     std::cout << "No variables spotted." << std::endl;
    169   } else {
    170     std::cout << "Found " << variables.size() << " variables: ";
    171     for (const NodeDef* node : variables) {
    172       PrintNodeInfo(node);
    173     }
    174     std::cout << std::endl;
    175   }
    176 
    177   std::map<string, std::vector<const NodeDef*>> output_map;
    178   MapNodesToOutputs(graph, &output_map);
    179   std::vector<const NodeDef*> outputs;
    180   std::unordered_set<string> unlikely_output_types = {"Const", "Assign", "NoOp",
    181                                                       "Placeholder"};
    182   for (const NodeDef& node : graph.node()) {
    183     if ((output_map.count(node.name()) == 0) &&
    184         (unlikely_output_types.count(node.op()) == 0)) {
    185       outputs.push_back(&node);
    186     }
    187   }
    188 
    189   if (outputs.empty()) {
    190     std::cout << "No outputs spotted." << std::endl;
    191   } else {
    192     std::cout << "Found " << outputs.size() << " possible outputs: ";
    193     for (const NodeDef* node : outputs) {
    194       std::cout << "(name=" << node->name();
    195       std::cout << ", op=" << node->op() << ") ";
    196     }
    197     std::cout << std::endl;
    198   }
    199 
    200   int64 const_parameter_count = 0;
    201   int64 variable_parameter_count = 0;
    202   int control_edge_count = 0;
    203   std::map<string, int> device_counts;
    204   for (const NodeDef& node : graph.node()) {
    205     for (const string& input : node.input()) {
    206       if (input.substr(0, 1) == "^") {
    207         ++control_edge_count;
    208       }
    209     }
    210     if (!node.device().empty()) {
    211       ++device_counts[node.device()];
    212     }
    213     if ((node.op() == "Const") || (node.op() == "Variable") ||
    214         (node.op() == "VariableV2")) {
    215       Tensor tensor;
    216       if (node.attr().count("value") &&
    217           tensor.FromProto(node.attr().at("value").tensor())) {
    218         const size_t num_elements = tensor.NumElements();
    219         if (node.op() == "Const") {
    220           const_parameter_count += num_elements;
    221         } else {
    222           variable_parameter_count += num_elements;
    223         }
    224       } else {
    225         LOG(WARNING) << "Decoding Tensor failed for node" << node.name();
    226       }
    227     }
    228   }
    229 
    230   std::cout << "Found " << const_parameter_count << " ("
    231             << strings::HumanReadableNum(const_parameter_count)
    232             << ") const parameters, " << variable_parameter_count << " ("
    233             << strings::HumanReadableNum(variable_parameter_count)
    234             << ") variable parameters, and " << control_edge_count
    235             << " control_edges" << std::endl;
    236   if (!device_counts.empty()) {
    237     for (const auto& device_info : device_counts) {
    238       std::cout << device_info.second << " nodes assigned to device '"
    239                 << device_info.first << "'";
    240     }
    241   }
    242 
    243   std::vector<std::pair<string, string>> invalid_inputs;
    244   FindInvalidInputs(graph, &invalid_inputs);
    245   if (!invalid_inputs.empty()) {
    246     for (const std::pair<string, string>& invalid_input : invalid_inputs) {
    247       std::cout << "Invalid input " << invalid_input.second << " for node "
    248                 << invalid_input.first << std::endl;
    249     }
    250     return errors::Internal(
    251         "Invalid graph with inputs referring to nonexistent nodes");
    252   }
    253 
    254   std::map<string, int> op_counts;
    255   for (const NodeDef& node : graph.node()) {
    256     ++op_counts[node.op()];
    257   }
    258   for (const FunctionDef& function : graph.library().function()) {
    259     for (const NodeDef& node : function.node_def()) {
    260       ++op_counts[node.op()];
    261     }
    262   }
    263   std::vector<std::pair<string, int>> op_counts_vec(op_counts.begin(),
    264                                                     op_counts.end());
    265   std::sort(op_counts_vec.begin(), op_counts_vec.end(),
    266             [](std::pair<string, int> a, std::pair<string, int> b) {
    267               return (a.second > b.second);
    268             });
    269   std::cout << "Op types used: ";
    270   bool is_first = true;
    271   for (const std::pair<string, int>& op_count : op_counts_vec) {
    272     if (!is_first) {
    273       std::cout << ", ";
    274     } else {
    275       is_first = false;
    276     }
    277     std::cout << op_count.second << " " << op_count.first;
    278   }
    279   std::cout << std::endl;
    280 
    281   PrintBenchmarkUsage(placeholders, variables, outputs, graph_path);
    282 
    283   if (print_structure) {
    284     TF_RETURN_IF_ERROR(PrintStructure(graph));
    285   }
    286 
    287   return Status::OK();
    288 }
    289 
    290 int ParseFlagsAndSummarizeGraph(int argc, char* argv[]) {
    291   string in_graph = "";
    292   bool print_structure = false;
    293   std::vector<Flag> flag_list = {
    294       Flag("in_graph", &in_graph, "input graph file name"),
    295       Flag("print_structure", &print_structure,
    296            "whether to print the network connections of the graph"),
    297   };
    298   string usage = Flags::Usage(argv[0], flag_list);
    299 
    300   const bool parse_result = Flags::Parse(&argc, argv, flag_list);
    301   // We need to call this to set up global state for TensorFlow.
    302   port::InitMain(argv[0], &argc, &argv);
    303 
    304   if (!parse_result) {
    305     LOG(ERROR) << usage;
    306     return -1;
    307   }
    308   if (argc > 1) {
    309     LOG(ERROR) << "Unknown argument " << argv[1] << ".\n" << usage;
    310     return -1;
    311   }
    312   if (in_graph.empty()) {
    313     LOG(ERROR) << "in_graph graph can't be empty.\n" << usage;
    314     return -1;
    315   }
    316 
    317   GraphDef graph_def;
    318   Status load_status = LoadTextOrBinaryGraphFile(in_graph, &graph_def);
    319   if (!load_status.ok()) {
    320     LOG(ERROR) << "Loading graph '" << in_graph << "' failed with "
    321                << load_status.error_message();
    322     LOG(ERROR) << usage;
    323     return -1;
    324   }
    325 
    326   Status summarize_result =
    327       SummarizeGraph(graph_def, in_graph, print_structure);
    328   if (!summarize_result.ok()) {
    329     LOG(ERROR) << summarize_result.error_message() << "\n" << usage;
    330     return -1;
    331   }
    332 
    333   return 0;
    334 }
    335 
    336 }  // namespace
    337 }  // namespace graph_transforms
    338 }  // namespace tensorflow
    339 
    340 int main(int argc, char* argv[]) {
    341   return tensorflow::graph_transforms::ParseFlagsAndSummarizeGraph(argc, argv);
    342 }
    343