Home | History | Annotate | Download | only in toco
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include "tensorflow/contrib/lite/toco/dump_graphviz.h"
     16 
     17 #include <memory>
     18 #include <set>
     19 #include <unordered_set>
     20 #include <vector>
     21 
     22 #include "absl/strings/str_replace.h"
     23 #include "absl/strings/strip.h"
     24 #include "tensorflow/contrib/lite/toco/model_flags.pb.h"
     25 #include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h"
     26 #include "tensorflow/contrib/lite/toco/toco_port.h"
     27 #include "tensorflow/contrib/lite/toco/toco_types.h"
     28 #include "tensorflow/contrib/lite/toco/tooling_util.h"
     29 #include "tensorflow/core/platform/logging.h"
     30 
     31 using toco::port::AppendF;
     32 using toco::port::StringF;
     33 
     34 namespace toco {
     35 namespace {
     36 
     37 class Color {
     38  public:
     39   Color() {}
     40   Color(uint8 r, uint8 g, uint8 b) : r_(r), g_(g), b_(b) {}
     41   // Returns the string serialization of this color in graphviz format,
     42   // for use as 'fillcolor' in boxes.
     43   string FillColorString() const { return StringF("%.2X%.2X%.2X", r_, g_, b_); }
     44   // Returns the serialization in graphviz format of a suitable color to use
     45   // 'fontcolor' in the same boxes. It should black or white, whichever offers
     46   // the better contrast from FillColorString().
     47   string TextColorString() const {
     48     // https://en.wikipedia.org/wiki/Relative_luminance
     49     const float luminance = 0.2126f * r_ + 0.7152f * g_ + 0.0722f * b_;
     50     const uint8 l = luminance > 128.f ? 0 : 255;
     51     return StringF("%.2X%.2X%.2X", l, l, l);
     52   }
     53 
     54  private:
     55   uint8 r_ = 0, g_ = 0, b_ = 0;
     56 };
     57 
// Display attributes of a single graphviz node. Used in this file both for
// array nodes and for operator nodes.
struct NodeProperties {
  // The text to display inside the box for this node.
  string label;
  // The color to use for this node; it is used as the 'fillcolor' of its
  // box, see Color::FillColorString. A suitable, different color is derived
  // from it for the 'fontcolor' of the label text, see
  // Color::TextColorString.
  Color color;
};
     67 
     68 // All colors in this file are from:
     69 // https://material.io/guidelines/style/color.html
     70 
     71 Color GetColorForArray(const Model& model, const string& array_name) {
     72   // Arrays involved in RNN back-edges have a different color
     73   for (const auto& rnn_state : model.flags.rnn_states()) {
     74     // RNN state, fed by a back-edge. Bold color.
     75     if (array_name == rnn_state.state_array()) {
     76       return Color(0x0F, 0x9D, 0x58);
     77     }
     78     // RNN back-edge source, feeding a RNN state.
     79     // Light tone of the same color as RNN states.
     80     if (array_name == rnn_state.back_edge_source_array()) {
     81       return Color(0xB7, 0xE1, 0xCD);
     82     }
     83   }
     84   // Constant parameter arrays have their own bold color
     85   if (model.GetArray(array_name).buffer) {
     86     return Color(0x42, 0x85, 0xF4);
     87   }
     88   // Remaining arrays are activations.
     89   // We use gray colors for them because they are the majority
     90   // of arrays so we want to highlight other arrays instead of them.
     91   // First, we use a bolder gray for input/output arrays:
     92   const auto& dump_options = *GraphVizDumpOptions::singleton();
     93   if (IsInputArray(model, array_name) ||
     94       array_name == dump_options.graphviz_first_array ||
     95       array_name == dump_options.graphviz_last_array) {
     96     return Color(0x9E, 0x9E, 0x9E);
     97   }
     98   for (const string& output_array : model.flags.output_arrays()) {
     99     if (array_name == output_array) {
    100       return Color(0x9E, 0x9E, 0x9E);
    101     }
    102   }
    103   // Remaining arrays are intermediate activation arrays.
    104   // Lighter tone of the same grey as for input/output arrays:
    105   // We want these to be very discrete.
    106   return Color(0xF5, 0xF5, 0xF5);
    107 }
    108 
    109 void AppendArrayVal(string* string, Array const& array, int index) {
    110   if (array.buffer->type == ArrayDataType::kFloat) {
    111     const auto& data = array.GetBuffer<ArrayDataType::kFloat>().data;
    112     if (index >= data.size()) {
    113       return;
    114     }
    115     AppendF(string, "%.3f", data[index]);
    116   } else if (array.buffer->type == ArrayDataType::kUint8) {
    117     const auto& data = array.GetBuffer<ArrayDataType::kUint8>().data;
    118     if (index >= data.size()) {
    119       return;
    120     }
    121     AppendF(string, "%d", data[index]);
    122   } else if (array.buffer->type == ArrayDataType::kInt32) {
    123     const auto& data = array.GetBuffer<ArrayDataType::kInt32>().data;
    124     if (index >= data.size()) {
    125       return;
    126     }
    127     AppendF(string, "%d", data[index]);
    128   } else if (array.buffer->type == ArrayDataType::kInt64) {
    129     const auto& data = array.GetBuffer<ArrayDataType::kInt64>().data;
    130     if (index >= data.size()) {
    131       return;
    132     }
    133     AppendF(string, "%d", data[index]);
    134   }
    135 }
    136 
    137 NodeProperties GetPropertiesForArray(const Model& model,
    138                                      const string& array_name) {
    139   NodeProperties node_properties;
    140   node_properties.color = GetColorForArray(model, array_name);
    141   node_properties.label = absl::StrReplaceAll(array_name, {{"/", "/\\n"}});
    142 
    143   // Append array shape to the label.
    144   auto& array = model.GetArray(array_name);
    145 
    146   if (array.data_type == ArrayDataType::kFloat) {
    147     AppendF(&node_properties.label, "\\nType: float");
    148   } else if (array.data_type == ArrayDataType::kInt32) {
    149     AppendF(&node_properties.label, "\\nType: int32");
    150   } else if (array.data_type == ArrayDataType::kUint8) {
    151     AppendF(&node_properties.label, "\\nType: uint8");
    152   }
    153 
    154   if (array.has_shape()) {
    155     auto& array_shape = array.shape();
    156     node_properties.label += "\\n[";
    157     for (int id = 0; id < array_shape.dimensions_count(); id++) {
    158       if (id == 0) {
    159         AppendF(&node_properties.label, "%d", array_shape.dims(id));
    160       } else {
    161         // 0x00D7 is the unicode multiplication symbol
    162         AppendF(&node_properties.label, "\u00D7%d", array_shape.dims(id));
    163       }
    164     }
    165     node_properties.label += "]";
    166 
    167     if (array.buffer) {
    168       const auto& array = model.GetArray(array_name);
    169       int buffer_size = RequiredBufferSizeForShape(array.shape());
    170       if (buffer_size <= 4) {
    171         AppendF(&node_properties.label, " = ");
    172         if (array.shape().dimensions_count() > 0) {
    173           AppendF(&node_properties.label, "{");
    174         }
    175         for (int i = 0; i < buffer_size; i++) {
    176           AppendArrayVal(&node_properties.label, array, i);
    177           if (i + 1 < buffer_size) {
    178             AppendF(&node_properties.label, ", ");
    179           }
    180         }
    181       } else {
    182         AppendF(&node_properties.label, "\\n = ");
    183         if (array.shape().dimensions_count() > 0) {
    184           AppendF(&node_properties.label, "{");
    185         }
    186         AppendArrayVal(&node_properties.label, array, 0);
    187         AppendF(&node_properties.label, ", ");
    188         AppendArrayVal(&node_properties.label, array, 1);
    189         // 0x2026 is the unicode ellipsis symbol
    190         AppendF(&node_properties.label, " \u2026 ");
    191         AppendArrayVal(&node_properties.label, array, buffer_size - 2);
    192         AppendF(&node_properties.label, ", ");
    193         AppendArrayVal(&node_properties.label, array, buffer_size - 1);
    194       }
    195       if (array.shape().dimensions_count() > 0) {
    196         AppendF(&node_properties.label, "}");
    197       }
    198     }
    199   }
    200 
    201   if (array.minmax) {
    202     AppendF(&node_properties.label, "\\nMinMax: [%.3g, %.3g]",
    203             array.minmax->min, array.minmax->max);
    204   }
    205 
    206   if (array.quantization_params) {
    207     AppendF(&node_properties.label, "\\nQuantization: %.3g * (x - %d)",
    208             array.quantization_params->scale,
    209             array.quantization_params->zero_point);
    210   }
    211 
    212   if (array.alloc) {
    213     AppendF(&node_properties.label, "\\nTransient Alloc: [%d, %d)",
    214             array.alloc->start, array.alloc->end);
    215   }
    216 
    217   return node_properties;
    218 }
    219 
    220 NodeProperties GetPropertiesForOperator(const Operator& op) {
    221   NodeProperties node_properties;
    222   if (op.type == OperatorType::kTensorFlowUnsupported) {
    223     node_properties.label =
    224         static_cast<const TensorFlowUnsupportedOperator&>(op).tensorflow_op;
    225   } else {
    226     node_properties.label =
    227         string(absl::StripPrefix(OperatorTypeName(op.type), "TensorFlow"));
    228   }
    229   switch (op.fused_activation_function) {
    230     case FusedActivationFunctionType::kRelu:
    231       AppendF(&node_properties.label, "\\nReLU");
    232       break;
    233     case FusedActivationFunctionType::kRelu6:
    234       AppendF(&node_properties.label, "\\nReLU6");
    235       break;
    236     case FusedActivationFunctionType::kRelu1:
    237       AppendF(&node_properties.label, "\\nReLU1");
    238       break;
    239     default:
    240       break;
    241   }
    242   // Additional information for some of the operators.
    243   switch (op.type) {
    244     case OperatorType::kConv: {
    245       const auto& conv_op = static_cast<const ConvOperator&>(op);
    246       node_properties.color = Color(0xC5, 0x39, 0x29);  // Bolder color
    247       AppendF(&node_properties.label, "\\n%dx%d/%s", conv_op.stride_width,
    248               conv_op.stride_height,
    249               conv_op.padding.type == PaddingType::kSame ? "S" : "V");
    250       break;
    251     }
    252     case OperatorType::kDepthwiseConv: {
    253       const auto& conv_op = static_cast<const DepthwiseConvOperator&>(op);
    254       node_properties.color = Color(0xC5, 0x39, 0x29);  // Bolder color
    255       AppendF(&node_properties.label, "\\n%dx%d/%s", conv_op.stride_width,
    256               conv_op.stride_height,
    257               conv_op.padding.type == PaddingType::kSame ? "S" : "V");
    258       break;
    259     }
    260     case OperatorType::kFullyConnected: {
    261       node_properties.color = Color(0xC5, 0x39, 0x29);  // Bolder color
    262       break;
    263     }
    264     default:
    265       node_properties.color = Color(0xDB, 0x44, 0x37);
    266       break;
    267   }
    268 
    269   return node_properties;
    270 }
    271 
    272 std::vector<const Operator*> OperatorsToDump(const Model& model) {
    273   const auto& dump_options = *GraphVizDumpOptions::singleton();
    274   bool first_specified = !dump_options.graphviz_first_array.empty();
    275   bool last_specified = !dump_options.graphviz_last_array.empty();
    276   CHECK_EQ(first_specified, last_specified);
    277   std::vector<const Operator*> ops_to_dump;
    278   if (last_specified) {
    279     // Return only the part of the graph between graphviz_first_array
    280     // and graphviz_last_array.
    281     CHECK(model.HasArray(dump_options.graphviz_first_array));
    282     CHECK(model.HasArray(dump_options.graphviz_last_array));
    283     std::unordered_set<string> arrays_already_produced;
    284     std::vector<string> arrays_to_produce;
    285     arrays_to_produce.push_back(dump_options.graphviz_last_array);
    286     while (!arrays_to_produce.empty()) {
    287       const string array = arrays_to_produce.back();
    288       arrays_to_produce.pop_back();
    289       CHECK(!arrays_already_produced.count(array));
    290       arrays_already_produced.insert(array);
    291       const Operator* op = GetOpWithOutput(model, array);
    292       if (!op) {
    293         continue;
    294       }
    295       ops_to_dump.push_back(op);
    296       for (const string& input : op->inputs) {
    297         if (arrays_already_produced.count(input) ||
    298             input == dump_options.graphviz_first_array) {
    299           continue;
    300         }
    301         arrays_to_produce.push_back(input);
    302       }
    303     }
    304   } else {
    305     // Return the whole graph.
    306     for (const auto& op : model.operators) {
    307       ops_to_dump.push_back(op.get());
    308     }
    309   }
    310   return ops_to_dump;
    311 }
    312 
    313 }  // namespace
    314 
    315 void DumpGraphviz(const Model& model, string* output_file_contents) {
    316   AppendF(output_file_contents, "digraph Computegraph {\n");
    317 
    318   constexpr char kNodeFormat[] =
    319       "\t \"%s\" [label=\"%s\", shape=%s, style=filled, fillcolor=\"#%s\", "
    320       "fontcolor = \"#%sDD\"];\n";
    321 
    322   constexpr char kEdgeFormat[] = "\t \"%s\" -> \"%s\";\n";
    323 
    324   constexpr char kRNNBackEdgeFormat[] =
    325       "\t \"%s\" -> \"%s\" [color=\"#0F9D58\"];\n";
    326 
    327   std::vector<const Operator*> ops_to_dump = OperatorsToDump(model);
    328   std::set<string> already_added_arrays;
    329   for (int op_index = 0; op_index < ops_to_dump.size(); op_index++) {
    330     const Operator& op = *ops_to_dump[op_index];
    331     // Add node for operator.
    332     auto op_properties = GetPropertiesForOperator(op);
    333     string operator_id = StringF("op%05d", op_index);
    334     AppendF(output_file_contents, kNodeFormat, operator_id, op_properties.label,
    335             "box", op_properties.color.FillColorString().c_str(),
    336             op_properties.color.TextColorString().c_str());
    337     // Add nodes and edges for all inputs of the operator.
    338     for (const auto& input : op.inputs) {
    339       if (!model.HasArray(input)) {
    340         // Arrays should _always_ exist. Except, perhaps, during development.
    341         continue;
    342       }
    343       auto array_properties = GetPropertiesForArray(model, input);
    344       if (!already_added_arrays.count(input)) {
    345         AppendF(output_file_contents, kNodeFormat, input,
    346                 array_properties.label, "octagon",
    347                 array_properties.color.FillColorString().c_str(),
    348                 array_properties.color.TextColorString().c_str());
    349       }
    350       AppendF(output_file_contents, kEdgeFormat, input, operator_id);
    351       already_added_arrays.insert(input);
    352     }
    353     // Add nodes and edges for all outputs of the operator.
    354     for (const auto& output : op.outputs) {
    355       if (!model.HasArray(output)) {
    356         // Arrays should _always_ exist. Except, perhaps, during development.
    357         continue;
    358       }
    359       auto array_properties = GetPropertiesForArray(model, output);
    360       if (!already_added_arrays.count(output)) {
    361         AppendF(output_file_contents, kNodeFormat, output,
    362                 array_properties.label, "octagon",
    363                 array_properties.color.FillColorString().c_str(),
    364                 array_properties.color.TextColorString().c_str());
    365       }
    366       AppendF(output_file_contents, kEdgeFormat, operator_id, output);
    367       already_added_arrays.insert(output);
    368     }
    369   }
    370 
    371   for (const auto& rnn_state : model.flags.rnn_states()) {
    372     AppendF(output_file_contents, kRNNBackEdgeFormat,
    373             rnn_state.back_edge_source_array(), rnn_state.state_array());
    374   }
    375 
    376   AppendF(output_file_contents, "}\n");
    377 }
    378 }  // namespace toco
    379