/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This program prints out a summary of a GraphDef file's contents, listing
// things that are useful for debugging and reusing the model it contains. For
// example, it looks at the graph structure and op types to figure out likely
// input and output nodes, and shows which ops are used by the graph.
To use it, 20 // run something like this: 21 // 22 // bazel build tensorflow/tools/graph_transforms:summarize_graph 23 // bazel-bin/tensorflow/tools/graph_transforms/summarize_graph \ 24 // --in_graph=my_graph.pb 25 26 #include "tensorflow/core/framework/attr_value.pb.h" 27 #include "tensorflow/core/framework/function.pb.h" 28 #include "tensorflow/core/framework/node_def.pb.h" 29 #include "tensorflow/core/framework/tensor.h" 30 #include "tensorflow/core/framework/tensor_shape.pb.h" 31 #include "tensorflow/core/lib/strings/str_util.h" 32 #include "tensorflow/core/platform/env.h" 33 #include "tensorflow/core/platform/init_main.h" 34 #include "tensorflow/core/platform/logging.h" 35 #include "tensorflow/core/util/command_line_flags.h" 36 #include "tensorflow/tools/graph_transforms/file_utils.h" 37 #include "tensorflow/tools/graph_transforms/transform_utils.h" 38 39 namespace tensorflow { 40 namespace graph_transforms { 41 namespace { 42 43 void PrintNodeInfo(const NodeDef* node) { 44 string shape_description = "None"; 45 if (node->attr().count("shape")) { 46 TensorShapeProto shape_proto = node->attr().at("shape").shape(); 47 Status shape_status = PartialTensorShape::IsValidShape(shape_proto); 48 if (shape_status.ok()) { 49 shape_description = PartialTensorShape(shape_proto).DebugString(); 50 } else { 51 shape_description = shape_status.error_message(); 52 } 53 } 54 DataType dtype = DT_INVALID; 55 if (node->attr().count("dtype")) { 56 dtype = node->attr().at("dtype").type(); 57 } 58 std::cout << "(name=" << node->name(); 59 std::cout << ", type=" << DataTypeString(dtype) << "(" << dtype << ")"; 60 std::cout << ", shape=" << shape_description << ") "; 61 } 62 63 void PrintBenchmarkUsage(const std::vector<const NodeDef*>& placeholders, 64 const std::vector<const NodeDef*>& variables, 65 const std::vector<const NodeDef*> outputs, 66 const string& graph_path) { 67 std::vector<const NodeDef*> all_inputs(placeholders); 68 all_inputs.insert(all_inputs.end(), variables.begin(), 
variables.end()); 69 70 std::vector<string> input_layers; 71 std::vector<string> input_layer_types; 72 std::vector<string> input_layer_shapes; 73 for (const NodeDef* node : all_inputs) { 74 input_layers.push_back(node->name()); 75 DataType dtype = DT_INVALID; 76 if (node->attr().count("dtype")) { 77 dtype = node->attr().at("dtype").type(); 78 } 79 input_layer_types.push_back(DataTypeString(dtype)); 80 std::vector<int64> sizes; 81 PartialTensorShape shape; 82 if (node->attr().count("shape")) { 83 TensorShapeProto shape_proto = node->attr().at("shape").shape(); 84 if (PartialTensorShape::IsValid(shape_proto)) { 85 shape = PartialTensorShape(shape_proto); 86 } 87 } 88 string sizes_string; 89 if (shape.dims() == -1) { 90 // Unknown shapes can have -1 for dims, so leave these blank. 91 sizes_string = ""; 92 } else { 93 sizes.reserve(shape.dims()); 94 for (int i = 0; i < shape.dims(); ++i) { 95 sizes.push_back(shape.dim_size(i)); 96 } 97 sizes_string = str_util::Join(sizes, ","); 98 } 99 input_layer_shapes.push_back(sizes_string); 100 } 101 std::vector<string> output_layers; 102 output_layers.reserve(outputs.size()); 103 for (const NodeDef* node : outputs) { 104 output_layers.push_back(node->name()); 105 } 106 string input_layer_value = str_util::Join(input_layers, ","); 107 string input_layer_type_value = str_util::Join(input_layer_types, ","); 108 string input_layer_shape_value = str_util::Join(input_layer_shapes, ":"); 109 string output_layer_value = str_util::Join(output_layers, ","); 110 111 std::cout << "To use with tensorflow/tools/benchmark:benchmark_model try " 112 "these arguments:" 113 << std::endl; 114 std::cout << "bazel run tensorflow/tools/benchmark:benchmark_model --"; 115 std::cout << " --graph=" << graph_path; 116 std::cout << " --show_flops"; 117 std::cout << " --input_layer=" << input_layer_value; 118 std::cout << " --input_layer_type=" << input_layer_type_value; 119 std::cout << " --input_layer_shape=" << input_layer_shape_value; 120 std::cout << " 
--output_layer=" << output_layer_value; 121 std::cout << std::endl; 122 } 123 124 Status PrintStructure(const GraphDef& graph) { 125 GraphDef sorted_graph; 126 TF_RETURN_IF_ERROR(SortByExecutionOrder(graph, &sorted_graph)); 127 for (const NodeDef& node : sorted_graph.node()) { 128 std::cout << node.name() << " (" << node.op() << "): [" 129 << str_util::Join(node.input(), ", ") << "]"; 130 if (node.op() == "Const") { 131 Tensor tensor; 132 if (node.attr().count("value") && 133 tensor.FromProto(node.attr().at("value").tensor())) { 134 std::cout << ", value=" << tensor.DebugString(); 135 } else { 136 LOG(WARNING) << "Decoding Tensor failed for node" << node.name(); 137 } 138 } 139 std::cout << std::endl; 140 } 141 return Status::OK(); 142 } 143 144 Status SummarizeGraph(const GraphDef& graph, const string& graph_path, 145 bool print_structure) { 146 std::vector<const NodeDef*> placeholders; 147 std::vector<const NodeDef*> variables; 148 for (const NodeDef& node : graph.node()) { 149 if (node.op() == "Placeholder") { 150 placeholders.push_back(&node); 151 } 152 if (node.op() == "Variable" || node.op() == "VariableV2") { 153 variables.push_back(&node); 154 } 155 } 156 157 if (placeholders.empty()) { 158 std::cout << "No inputs spotted." << std::endl; 159 } else { 160 std::cout << "Found " << placeholders.size() << " possible inputs: "; 161 for (const NodeDef* node : placeholders) { 162 PrintNodeInfo(node); 163 } 164 std::cout << std::endl; 165 } 166 167 if (variables.empty()) { 168 std::cout << "No variables spotted." 
<< std::endl; 169 } else { 170 std::cout << "Found " << variables.size() << " variables: "; 171 for (const NodeDef* node : variables) { 172 PrintNodeInfo(node); 173 } 174 std::cout << std::endl; 175 } 176 177 std::map<string, std::vector<const NodeDef*>> output_map; 178 MapNodesToOutputs(graph, &output_map); 179 std::vector<const NodeDef*> outputs; 180 std::unordered_set<string> unlikely_output_types = {"Const", "Assign", "NoOp", 181 "Placeholder"}; 182 for (const NodeDef& node : graph.node()) { 183 if ((output_map.count(node.name()) == 0) && 184 (unlikely_output_types.count(node.op()) == 0)) { 185 outputs.push_back(&node); 186 } 187 } 188 189 if (outputs.empty()) { 190 std::cout << "No outputs spotted." << std::endl; 191 } else { 192 std::cout << "Found " << outputs.size() << " possible outputs: "; 193 for (const NodeDef* node : outputs) { 194 std::cout << "(name=" << node->name(); 195 std::cout << ", op=" << node->op() << ") "; 196 } 197 std::cout << std::endl; 198 } 199 200 int64 const_parameter_count = 0; 201 int64 variable_parameter_count = 0; 202 int control_edge_count = 0; 203 std::map<string, int> device_counts; 204 for (const NodeDef& node : graph.node()) { 205 for (const string& input : node.input()) { 206 if (input.substr(0, 1) == "^") { 207 ++control_edge_count; 208 } 209 } 210 if (!node.device().empty()) { 211 ++device_counts[node.device()]; 212 } 213 if ((node.op() == "Const") || (node.op() == "Variable") || 214 (node.op() == "VariableV2")) { 215 Tensor tensor; 216 if (node.attr().count("value") && 217 tensor.FromProto(node.attr().at("value").tensor())) { 218 const size_t num_elements = tensor.NumElements(); 219 if (node.op() == "Const") { 220 const_parameter_count += num_elements; 221 } else { 222 variable_parameter_count += num_elements; 223 } 224 } else { 225 LOG(WARNING) << "Decoding Tensor failed for node" << node.name(); 226 } 227 } 228 } 229 230 std::cout << "Found " << const_parameter_count << " (" 231 << 
strings::HumanReadableNum(const_parameter_count) 232 << ") const parameters, " << variable_parameter_count << " (" 233 << strings::HumanReadableNum(variable_parameter_count) 234 << ") variable parameters, and " << control_edge_count 235 << " control_edges" << std::endl; 236 if (!device_counts.empty()) { 237 for (const auto& device_info : device_counts) { 238 std::cout << device_info.second << " nodes assigned to device '" 239 << device_info.first << "'"; 240 } 241 } 242 243 std::vector<std::pair<string, string>> invalid_inputs; 244 FindInvalidInputs(graph, &invalid_inputs); 245 if (!invalid_inputs.empty()) { 246 for (const std::pair<string, string>& invalid_input : invalid_inputs) { 247 std::cout << "Invalid input " << invalid_input.second << " for node " 248 << invalid_input.first << std::endl; 249 } 250 return errors::Internal( 251 "Invalid graph with inputs referring to nonexistent nodes"); 252 } 253 254 std::map<string, int> op_counts; 255 for (const NodeDef& node : graph.node()) { 256 ++op_counts[node.op()]; 257 } 258 for (const FunctionDef& function : graph.library().function()) { 259 for (const NodeDef& node : function.node_def()) { 260 ++op_counts[node.op()]; 261 } 262 } 263 std::vector<std::pair<string, int>> op_counts_vec(op_counts.begin(), 264 op_counts.end()); 265 std::sort(op_counts_vec.begin(), op_counts_vec.end(), 266 [](std::pair<string, int> a, std::pair<string, int> b) { 267 return (a.second > b.second); 268 }); 269 std::cout << "Op types used: "; 270 bool is_first = true; 271 for (const std::pair<string, int>& op_count : op_counts_vec) { 272 if (!is_first) { 273 std::cout << ", "; 274 } else { 275 is_first = false; 276 } 277 std::cout << op_count.second << " " << op_count.first; 278 } 279 std::cout << std::endl; 280 281 PrintBenchmarkUsage(placeholders, variables, outputs, graph_path); 282 283 if (print_structure) { 284 TF_RETURN_IF_ERROR(PrintStructure(graph)); 285 } 286 287 return Status::OK(); 288 } 289 290 int ParseFlagsAndSummarizeGraph(int 
argc, char* argv[]) { 291 string in_graph = ""; 292 bool print_structure = false; 293 std::vector<Flag> flag_list = { 294 Flag("in_graph", &in_graph, "input graph file name"), 295 Flag("print_structure", &print_structure, 296 "whether to print the network connections of the graph"), 297 }; 298 string usage = Flags::Usage(argv[0], flag_list); 299 300 const bool parse_result = Flags::Parse(&argc, argv, flag_list); 301 // We need to call this to set up global state for TensorFlow. 302 port::InitMain(argv[0], &argc, &argv); 303 304 if (!parse_result) { 305 LOG(ERROR) << usage; 306 return -1; 307 } 308 if (argc > 1) { 309 LOG(ERROR) << "Unknown argument " << argv[1] << ".\n" << usage; 310 return -1; 311 } 312 if (in_graph.empty()) { 313 LOG(ERROR) << "in_graph graph can't be empty.\n" << usage; 314 return -1; 315 } 316 317 GraphDef graph_def; 318 Status load_status = LoadTextOrBinaryGraphFile(in_graph, &graph_def); 319 if (!load_status.ok()) { 320 LOG(ERROR) << "Loading graph '" << in_graph << "' failed with " 321 << load_status.error_message(); 322 LOG(ERROR) << usage; 323 return -1; 324 } 325 326 Status summarize_result = 327 SummarizeGraph(graph_def, in_graph, print_structure); 328 if (!summarize_result.ok()) { 329 LOG(ERROR) << summarize_result.error_message() << "\n" << usage; 330 return -1; 331 } 332 333 return 0; 334 } 335 336 } // namespace 337 } // namespace graph_transforms 338 } // namespace tensorflow 339 340 int main(int argc, char* argv[]) { 341 return tensorflow::graph_transforms::ParseFlagsAndSummarizeGraph(argc, argv); 342 } 343