1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include "tensorflow/contrib/lite/toco/dump_graphviz.h" 16 17 #include <memory> 18 #include <set> 19 #include <unordered_set> 20 #include <vector> 21 22 #include "absl/strings/str_replace.h" 23 #include "absl/strings/strip.h" 24 #include "tensorflow/contrib/lite/toco/model_flags.pb.h" 25 #include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h" 26 #include "tensorflow/contrib/lite/toco/toco_port.h" 27 #include "tensorflow/contrib/lite/toco/toco_types.h" 28 #include "tensorflow/contrib/lite/toco/tooling_util.h" 29 #include "tensorflow/core/platform/logging.h" 30 31 using toco::port::AppendF; 32 using toco::port::StringF; 33 34 namespace toco { 35 namespace { 36 37 class Color { 38 public: 39 Color() {} 40 Color(uint8 r, uint8 g, uint8 b) : r_(r), g_(g), b_(b) {} 41 // Returns the string serialization of this color in graphviz format, 42 // for use as 'fillcolor' in boxes. 43 string FillColorString() const { return StringF("%.2X%.2X%.2X", r_, g_, b_); } 44 // Returns the serialization in graphviz format of a suitable color to use 45 // 'fontcolor' in the same boxes. It should black or white, whichever offers 46 // the better contrast from FillColorString(). 47 string TextColorString() const { 48 // https://en.wikipedia.org/wiki/Relative_luminance 49 const float luminance = 0.2126f * r_ + 0.7152f * g_ + 0.0722f * b_; 50 const uint8 l = luminance > 128.f ? 0 : 255; 51 return StringF("%.2X%.2X%.2X", l, l, l); 52 } 53 54 private: 55 uint8 r_ = 0, g_ = 0, b_ = 0; 56 }; 57 58 struct NodeProperties { 59 // The text to display inside the box for this node. 60 string label; 61 // The color to use for this node; will be used as 'fillcolor' 62 // for its box. See Color::FillColorString. A suitable, different 63 // color will be chosen for the 'fontcolor' for the inside text 64 // label, see Color::TextColorString. 65 Color color; 66 }; 67 68 // All colors in this file are from: 69 // https://material.io/guidelines/style/color.html 70 71 Color GetColorForArray(const Model& model, const string& array_name) { 72 // Arrays involved in RNN back-edges have a different color 73 for (const auto& rnn_state : model.flags.rnn_states()) { 74 // RNN state, fed by a back-edge. Bold color. 75 if (array_name == rnn_state.state_array()) { 76 return Color(0x0F, 0x9D, 0x58); 77 } 78 // RNN back-edge source, feeding a RNN state. 79 // Light tone of the same color as RNN states. 80 if (array_name == rnn_state.back_edge_source_array()) { 81 return Color(0xB7, 0xE1, 0xCD); 82 } 83 } 84 // Constant parameter arrays have their own bold color 85 if (model.GetArray(array_name).buffer) { 86 return Color(0x42, 0x85, 0xF4); 87 } 88 // Remaining arrays are activations. 89 // We use gray colors for them because they are the majority 90 // of arrays so we want to highlight other arrays instead of them. 91 // First, we use a bolder gray for input/output arrays: 92 const auto& dump_options = *GraphVizDumpOptions::singleton(); 93 if (IsInputArray(model, array_name) || 94 array_name == dump_options.graphviz_first_array || 95 array_name == dump_options.graphviz_last_array) { 96 return Color(0x9E, 0x9E, 0x9E); 97 } 98 for (const string& output_array : model.flags.output_arrays()) { 99 if (array_name == output_array) { 100 return Color(0x9E, 0x9E, 0x9E); 101 } 102 } 103 // Remaining arrays are intermediate activation arrays. 104 // Lighter tone of the same grey as for input/output arrays: 105 // We want these to be very discrete. 106 return Color(0xF5, 0xF5, 0xF5); 107 } 108 109 void AppendArrayVal(string* string, Array const& array, int index) { 110 if (array.buffer->type == ArrayDataType::kFloat) { 111 const auto& data = array.GetBuffer<ArrayDataType::kFloat>().data; 112 if (index >= data.size()) { 113 return; 114 } 115 AppendF(string, "%.3f", data[index]); 116 } else if (array.buffer->type == ArrayDataType::kUint8) { 117 const auto& data = array.GetBuffer<ArrayDataType::kUint8>().data; 118 if (index >= data.size()) { 119 return; 120 } 121 AppendF(string, "%d", data[index]); 122 } else if (array.buffer->type == ArrayDataType::kInt32) { 123 const auto& data = array.GetBuffer<ArrayDataType::kInt32>().data; 124 if (index >= data.size()) { 125 return; 126 } 127 AppendF(string, "%d", data[index]); 128 } else if (array.buffer->type == ArrayDataType::kInt64) { 129 const auto& data = array.GetBuffer<ArrayDataType::kInt64>().data; 130 if (index >= data.size()) { 131 return; 132 } 133 AppendF(string, "%d", data[index]); 134 } 135 } 136 137 NodeProperties GetPropertiesForArray(const Model& model, 138 const string& array_name) { 139 NodeProperties node_properties; 140 node_properties.color = GetColorForArray(model, array_name); 141 node_properties.label = absl::StrReplaceAll(array_name, {{"/", "/\\n"}}); 142 143 // Append array shape to the label. 144 auto& array = model.GetArray(array_name); 145 146 if (array.data_type == ArrayDataType::kFloat) { 147 AppendF(&node_properties.label, "\\nType: float"); 148 } else if (array.data_type == ArrayDataType::kInt32) { 149 AppendF(&node_properties.label, "\\nType: int32"); 150 } else if (array.data_type == ArrayDataType::kUint8) { 151 AppendF(&node_properties.label, "\\nType: uint8"); 152 } 153 154 if (array.has_shape()) { 155 auto& array_shape = array.shape(); 156 node_properties.label += "\\n["; 157 for (int id = 0; id < array_shape.dimensions_count(); id++) { 158 if (id == 0) { 159 AppendF(&node_properties.label, "%d", array_shape.dims(id)); 160 } else { 161 // 0x00D7 is the unicode multiplication symbol 162 AppendF(&node_properties.label, "\u00D7%d", array_shape.dims(id)); 163 } 164 } 165 node_properties.label += "]"; 166 167 if (array.buffer) { 168 const auto& array = model.GetArray(array_name); 169 int buffer_size = RequiredBufferSizeForShape(array.shape()); 170 if (buffer_size <= 4) { 171 AppendF(&node_properties.label, " = "); 172 if (array.shape().dimensions_count() > 0) { 173 AppendF(&node_properties.label, "{"); 174 } 175 for (int i = 0; i < buffer_size; i++) { 176 AppendArrayVal(&node_properties.label, array, i); 177 if (i + 1 < buffer_size) { 178 AppendF(&node_properties.label, ", "); 179 } 180 } 181 } else { 182 AppendF(&node_properties.label, "\\n = "); 183 if (array.shape().dimensions_count() > 0) { 184 AppendF(&node_properties.label, "{"); 185 } 186 AppendArrayVal(&node_properties.label, array, 0); 187 AppendF(&node_properties.label, ", "); 188 AppendArrayVal(&node_properties.label, array, 1); 189 // 0x2026 is the unicode ellipsis symbol 190 AppendF(&node_properties.label, " \u2026 "); 191 AppendArrayVal(&node_properties.label, array, buffer_size - 2); 192 AppendF(&node_properties.label, ", "); 193 AppendArrayVal(&node_properties.label, array, buffer_size - 1); 194 } 195 if (array.shape().dimensions_count() > 0) { 196 AppendF(&node_properties.label, "}"); 197 } 198 } 199 } 200 201 if (array.minmax) { 202 AppendF(&node_properties.label, "\\nMinMax: [%.3g, %.3g]", 203 array.minmax->min, array.minmax->max); 204 } 205 206 if (array.quantization_params) { 207 AppendF(&node_properties.label, "\\nQuantization: %.3g * (x - %d)", 208 array.quantization_params->scale, 209 array.quantization_params->zero_point); 210 } 211 212 if (array.alloc) { 213 AppendF(&node_properties.label, "\\nTransient Alloc: [%d, %d)", 214 array.alloc->start, array.alloc->end); 215 } 216 217 return node_properties; 218 } 219 220 NodeProperties GetPropertiesForOperator(const Operator& op) { 221 NodeProperties node_properties; 222 if (op.type == OperatorType::kTensorFlowUnsupported) { 223 node_properties.label = 224 static_cast<const TensorFlowUnsupportedOperator&>(op).tensorflow_op; 225 } else { 226 node_properties.label = 227 string(absl::StripPrefix(OperatorTypeName(op.type), "TensorFlow")); 228 } 229 switch (op.fused_activation_function) { 230 case FusedActivationFunctionType::kRelu: 231 AppendF(&node_properties.label, "\\nReLU"); 232 break; 233 case FusedActivationFunctionType::kRelu6: 234 AppendF(&node_properties.label, "\\nReLU6"); 235 break; 236 case FusedActivationFunctionType::kRelu1: 237 AppendF(&node_properties.label, "\\nReLU1"); 238 break; 239 default: 240 break; 241 } 242 // Additional information for some of the operators. 243 switch (op.type) { 244 case OperatorType::kConv: { 245 const auto& conv_op = static_cast<const ConvOperator&>(op); 246 node_properties.color = Color(0xC5, 0x39, 0x29); // Bolder color 247 AppendF(&node_properties.label, "\\n%dx%d/%s", conv_op.stride_width, 248 conv_op.stride_height, 249 conv_op.padding.type == PaddingType::kSame ? "S" : "V"); 250 break; 251 } 252 case OperatorType::kDepthwiseConv: { 253 const auto& conv_op = static_cast<const DepthwiseConvOperator&>(op); 254 node_properties.color = Color(0xC5, 0x39, 0x29); // Bolder color 255 AppendF(&node_properties.label, "\\n%dx%d/%s", conv_op.stride_width, 256 conv_op.stride_height, 257 conv_op.padding.type == PaddingType::kSame ? "S" : "V"); 258 break; 259 } 260 case OperatorType::kFullyConnected: { 261 node_properties.color = Color(0xC5, 0x39, 0x29); // Bolder color 262 break; 263 } 264 default: 265 node_properties.color = Color(0xDB, 0x44, 0x37); 266 break; 267 } 268 269 return node_properties; 270 } 271 272 std::vector<const Operator*> OperatorsToDump(const Model& model) { 273 const auto& dump_options = *GraphVizDumpOptions::singleton(); 274 bool first_specified = !dump_options.graphviz_first_array.empty(); 275 bool last_specified = !dump_options.graphviz_last_array.empty(); 276 CHECK_EQ(first_specified, last_specified); 277 std::vector<const Operator*> ops_to_dump; 278 if (last_specified) { 279 // Return only the part of the graph between graphviz_first_array 280 // and graphviz_last_array. 281 CHECK(model.HasArray(dump_options.graphviz_first_array)); 282 CHECK(model.HasArray(dump_options.graphviz_last_array)); 283 std::unordered_set<string> arrays_already_produced; 284 std::vector<string> arrays_to_produce; 285 arrays_to_produce.push_back(dump_options.graphviz_last_array); 286 while (!arrays_to_produce.empty()) { 287 const string array = arrays_to_produce.back(); 288 arrays_to_produce.pop_back(); 289 CHECK(!arrays_already_produced.count(array)); 290 arrays_already_produced.insert(array); 291 const Operator* op = GetOpWithOutput(model, array); 292 if (!op) { 293 continue; 294 } 295 ops_to_dump.push_back(op); 296 for (const string& input : op->inputs) { 297 if (arrays_already_produced.count(input) || 298 input == dump_options.graphviz_first_array) { 299 continue; 300 } 301 arrays_to_produce.push_back(input); 302 } 303 } 304 } else { 305 // Return the whole graph. 306 for (const auto& op : model.operators) { 307 ops_to_dump.push_back(op.get()); 308 } 309 } 310 return ops_to_dump; 311 } 312 313 } // namespace 314 315 void DumpGraphviz(const Model& model, string* output_file_contents) { 316 AppendF(output_file_contents, "digraph Computegraph {\n"); 317 318 constexpr char kNodeFormat[] = 319 "\t \"%s\" [label=\"%s\", shape=%s, style=filled, fillcolor=\"#%s\", " 320 "fontcolor = \"#%sDD\"];\n"; 321 322 constexpr char kEdgeFormat[] = "\t \"%s\" -> \"%s\";\n"; 323 324 constexpr char kRNNBackEdgeFormat[] = 325 "\t \"%s\" -> \"%s\" [color=\"#0F9D58\"];\n"; 326 327 std::vector<const Operator*> ops_to_dump = OperatorsToDump(model); 328 std::set<string> already_added_arrays; 329 for (int op_index = 0; op_index < ops_to_dump.size(); op_index++) { 330 const Operator& op = *ops_to_dump[op_index]; 331 // Add node for operator. 332 auto op_properties = GetPropertiesForOperator(op); 333 string operator_id = StringF("op%05d", op_index); 334 AppendF(output_file_contents, kNodeFormat, operator_id, op_properties.label, 335 "box", op_properties.color.FillColorString().c_str(), 336 op_properties.color.TextColorString().c_str()); 337 // Add nodes and edges for all inputs of the operator. 338 for (const auto& input : op.inputs) { 339 if (!model.HasArray(input)) { 340 // Arrays should _always_ exist. Except, perhaps, during development. 341 continue; 342 } 343 auto array_properties = GetPropertiesForArray(model, input); 344 if (!already_added_arrays.count(input)) { 345 AppendF(output_file_contents, kNodeFormat, input, 346 array_properties.label, "octagon", 347 array_properties.color.FillColorString().c_str(), 348 array_properties.color.TextColorString().c_str()); 349 } 350 AppendF(output_file_contents, kEdgeFormat, input, operator_id); 351 already_added_arrays.insert(input); 352 } 353 // Add nodes and edges for all outputs of the operator. 354 for (const auto& output : op.outputs) { 355 if (!model.HasArray(output)) { 356 // Arrays should _always_ exist. Except, perhaps, during development. 357 continue; 358 } 359 auto array_properties = GetPropertiesForArray(model, output); 360 if (!already_added_arrays.count(output)) { 361 AppendF(output_file_contents, kNodeFormat, output, 362 array_properties.label, "octagon", 363 array_properties.color.FillColorString().c_str(), 364 array_properties.color.TextColorString().c_str()); 365 } 366 AppendF(output_file_contents, kEdgeFormat, operator_id, output); 367 already_added_arrays.insert(output); 368 } 369 } 370 371 for (const auto& rnn_state : model.flags.rnn_states()) { 372 AppendF(output_file_contents, kRNNBackEdgeFormat, 373 rnn_state.back_edge_source_array(), rnn_state.state_array()); 374 } 375 376 AppendF(output_file_contents, "}\n"); 377 } 378 } // namespace toco 379