Home | History | Annotate | Download | only in graph_transformations
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"

#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "tensorflow/contrib/lite/toco/toco_port.h"
#include "tensorflow/contrib/lite/toco/tooling_util.h"
#include "tensorflow/core/platform/logging.h"
     27 
     28 namespace toco {
     29 
     30 namespace {
     31 
     32 void PrintModelStats(const string& label, const Model& model) {
     33   int quantized_arrays = 0;
     34   for (const auto& array : model.GetArrayMap()) {
     35     if (array.second->quantization_params) {
     36       quantized_arrays++;
     37     }
     38   }
     39   LOG(INFO) << label << ": " << model.operators.size() << " operators, "
     40             << model.GetArrayMap().size() << " arrays (" << quantized_arrays
     41             << " quantized)";
     42 }
     43 
     44 // Some graphs have RNN back-edges that are discardable, having been
     45 // created typically by TensorFlow import rather than specified by the user.
     46 // Such graphs might have cycles (closed by RNN back-edges) that may be pruned.
     47 // Local graph transformations can't identify such global features,
     48 // so this function performs this global transformation.
     49 //
     50 // The other (and related) thing that is peculiar about RNN back-edges
     51 // is that they do not prevent the arrays that they touch, from being
     52 // pruned. Thus, they may refer to array names which no longer exist.
     53 // The intent is for that to result in the eventual pruning of such
     54 // 'dangling' RNN back-edges. We perform this pruning at the end of this
     55 // function, as the pruning of connected components done here may leave
     56 // more RNN back-edges dangling.
     57 void DiscardUselessConnectedComponentsAndRNNBackEdges(Model* model) {
     58   // Identify the set of arrays that are in 'useful' connected components
     59   // of the graph, which means connected to output arrays.
     60   std::unordered_set<string> useful_arrays;
     61   for (const string& output_array : model->flags.output_arrays()) {
     62     useful_arrays.insert(output_array);
     63   }
     64   bool found_new_useful_arrays;
     65   do {
     66     found_new_useful_arrays = false;
     67     for (const auto& op : model->operators) {
     68       bool op_touches_useful_arrays = false;
     69       for (const string& output : op->outputs) {
     70         op_touches_useful_arrays |= useful_arrays.count(output);
     71       }
     72       if (op_touches_useful_arrays) {
     73         for (const string& input : op->inputs) {
     74           found_new_useful_arrays |= !useful_arrays.count(input);
     75           useful_arrays.insert(input);
     76         }
     77         for (const string& output : op->outputs) {
     78           found_new_useful_arrays |= !useful_arrays.count(output);
     79           useful_arrays.insert(output);
     80         }
     81       }
     82     }
     83     for (const auto& rnn_state : model->flags.rnn_states()) {
     84       bool rnn_back_edge_touches_useful_arrays =
     85           useful_arrays.count(rnn_state.state_array());
     86       if (rnn_back_edge_touches_useful_arrays) {
     87         found_new_useful_arrays |=
     88             !useful_arrays.count(rnn_state.back_edge_source_array());
     89         useful_arrays.insert(rnn_state.back_edge_source_array());
     90       }
     91     }
     92   } while (found_new_useful_arrays);
     93   // Erase arrays that aren't useful, and that are discardable.
     94   model->EraseArrays([&](const string& name) {
     95     return (!useful_arrays.count(name) && IsDiscardableArray(*model, name));
     96   });
     97   // Erase operators that do not produce a useful output array.
     98   for (auto it = model->operators.begin(); it != model->operators.end();) {
     99     // Only need to test the first output, as we simultaneously added all of
    100     // an operator's outputs to the list of output arrays.
    101     if (useful_arrays.count((*it)->outputs[0])) {
    102       ++it;
    103     } else {
    104       for (const string& output : (*it)->outputs) {
    105         CHECK(!useful_arrays.count(output));
    106       }
    107       it = model->operators.erase(it);
    108     }
    109   }
    110   // Erase RNN back-edges that are 'dangling' i.e. that touch an array
    111   // that no longer exists. This should only happen for discardable RNN
    112   // back-edges.
    113   std::vector<RnnState> rnn_states_to_keep;
    114   for (const auto& rnn_state : model->flags.rnn_states()) {
    115     const bool dangling =
    116         !model->HasArray(rnn_state.back_edge_source_array()) ||
    117         !model->HasArray(rnn_state.state_array());
    118     if (dangling) {
    119       CHECK(rnn_state.discardable());
    120     } else {
    121       rnn_states_to_keep.push_back(rnn_state);
    122     }
    123   }
    124   model->flags.clear_rnn_states();
    125   for (const auto& rnn_state : rnn_states_to_keep) {
    126     *model->flags.add_rnn_states() = rnn_state;
    127   }
    128 }
    129 
    130 bool GraphTransformationsPass(int increment, Model* model,
    131                               const GraphTransformationsSet& transformations) {
    132   CHECK(increment == 1 || increment == -1);
    133   bool changed = false;
    134   if (model->operators.empty()) {
    135     LOG(INFO) << "Model is empty!!!";
    136     return false;
    137   }
    138   int op_index = increment == 1 ? 0 : model->operators.size() - 1;
    139   while (true) {
    140     bool changed_now = false;
    141     // Loop over all transformations at the current position in the graph.
    142     for (const auto& transformation : transformations) {
    143       CHECK(!changed_now);
    144       CHECK(transformation->Messages().empty());
    145       changed_now = transformation->Run(model, op_index);
    146       const char* made_a_change_msg =
    147           changed_now ? "made a change" : "did NOT make a change";
    148       const int log_level =
    149           changed_now ? kLogLevelModelChanged : kLogLevelModelUnchanged;
    150       if (transformation->Messages().empty()) {
    151         VLOG(log_level) << transformation->Name() << " " << made_a_change_msg
    152                         << " at op_index=" << op_index << "/"
    153                         << model->operators.size() - 1;
    154       }
    155       for (const string& message : transformation->Messages()) {
    156         VLOG(log_level) << transformation->Name() << " " << made_a_change_msg
    157                         << " at op_index=" << op_index << "/"
    158                         << model->operators.size() - 1 << ": " << message;
    159       }
    160       transformation->ClearMessages();
    161       if (changed_now) {
    162         DumpGraphvizVideoFrame(*model);
    163         if (model->operators.empty()) return true;
    164         op_index = std::min<int>(op_index, model->operators.size() - 1);
    165         // Uncomment for debugging
    166         // CheckInvariants(*model);
    167       }
    168       if (changed_now) {
    169         break;
    170       }
    171     }
    172     if (changed_now) {
    173       changed = true;
    174     } else {
    175       const int op_index_last =
    176           increment == 1 ? model->operators.size() - 1 : 0;
    177       if (op_index == op_index_last) {
    178         break;
    179       }
    180       op_index += increment;
    181     }
    182   }
    183   DiscardUselessConnectedComponentsAndRNNBackEdges(model);
    184   return changed;
    185 }
    186 
    187 }  // namespace
    188 
    189 void RunGraphTransformations(Model* model, const string& msg,
    190                              const GraphTransformationsSet& transformations) {
    191   PrintModelStats(toco::port::StringF("Before %s", msg), *model);
    192   int pass_index = 0;
    193   while (GraphTransformationsPass((pass_index % 2) ? -1 : 1, model,
    194                                   transformations)) {
    195     pass_index++;
    196     const auto& label =
    197         toco::port::StringF("After %s pass %d", msg, pass_index);
    198     PrintModelStats(label, *model);
    199     CheckInvariants(*model);
    200   }
    201 }
    202 
    203 }  // namespace toco
    204