1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" 16 17 #include <algorithm> 18 #include <memory> 19 #include <string> 20 #include <unordered_map> 21 #include <utility> 22 #include <vector> 23 24 #include "tensorflow/contrib/lite/toco/toco_port.h" 25 #include "tensorflow/contrib/lite/toco/tooling_util.h" 26 #include "tensorflow/core/platform/logging.h" 27 28 namespace toco { 29 30 namespace { 31 32 void PrintModelStats(const string& label, const Model& model) { 33 int quantized_arrays = 0; 34 for (const auto& array : model.GetArrayMap()) { 35 if (array.second->quantization_params) { 36 quantized_arrays++; 37 } 38 } 39 LOG(INFO) << label << ": " << model.operators.size() << " operators, " 40 << model.GetArrayMap().size() << " arrays (" << quantized_arrays 41 << " quantized)"; 42 } 43 44 // Some graphs have RNN back-edges that are discardable, having been 45 // created typically by TensorFlow import rather than specified by the user. 46 // Such graphs might have cycles (closed by RNN back-edges) that may be pruned. 47 // Local graph transformations can't identify such global features, 48 // so this function performs this global transformation. 49 // 50 // The other (and related) thing that is peculiar about RNN back-edges 51 // is that they do not prevent the arrays that they touch, from being 52 // pruned. Thus, they may refer to array names which no longer exist. 53 // The intent is for that to result in the eventual pruning of such 54 // 'dangling' RNN back-edges. We perform this pruning at the end of this 55 // function, as the pruning of connected components done here may leave 56 // more RNN back-edges dangling. 57 void DiscardUselessConnectedComponentsAndRNNBackEdges(Model* model) { 58 // Identify the set of arrays that are in 'useful' connected components 59 // of the graph, which means connected to output arrays. 60 std::unordered_set<string> useful_arrays; 61 for (const string& output_array : model->flags.output_arrays()) { 62 useful_arrays.insert(output_array); 63 } 64 bool found_new_useful_arrays; 65 do { 66 found_new_useful_arrays = false; 67 for (const auto& op : model->operators) { 68 bool op_touches_useful_arrays = false; 69 for (const string& output : op->outputs) { 70 op_touches_useful_arrays |= useful_arrays.count(output); 71 } 72 if (op_touches_useful_arrays) { 73 for (const string& input : op->inputs) { 74 found_new_useful_arrays |= !useful_arrays.count(input); 75 useful_arrays.insert(input); 76 } 77 for (const string& output : op->outputs) { 78 found_new_useful_arrays |= !useful_arrays.count(output); 79 useful_arrays.insert(output); 80 } 81 } 82 } 83 for (const auto& rnn_state : model->flags.rnn_states()) { 84 bool rnn_back_edge_touches_useful_arrays = 85 useful_arrays.count(rnn_state.state_array()); 86 if (rnn_back_edge_touches_useful_arrays) { 87 found_new_useful_arrays |= 88 !useful_arrays.count(rnn_state.back_edge_source_array()); 89 useful_arrays.insert(rnn_state.back_edge_source_array()); 90 } 91 } 92 } while (found_new_useful_arrays); 93 // Erase arrays that aren't useful, and that are discardable. 94 model->EraseArrays([&](const string& name) { 95 return (!useful_arrays.count(name) && IsDiscardableArray(*model, name)); 96 }); 97 // Erase operators that do not produce a useful output array. 98 for (auto it = model->operators.begin(); it != model->operators.end();) { 99 // Only need to test the first output, as we simultaneously added all of 100 // an operator's outputs to the list of output arrays. 101 if (useful_arrays.count((*it)->outputs[0])) { 102 ++it; 103 } else { 104 for (const string& output : (*it)->outputs) { 105 CHECK(!useful_arrays.count(output)); 106 } 107 it = model->operators.erase(it); 108 } 109 } 110 // Erase RNN back-edges that are 'dangling' i.e. that touch an array 111 // that no longer exists. This should only happen for discardable RNN 112 // back-edges. 113 std::vector<RnnState> rnn_states_to_keep; 114 for (const auto& rnn_state : model->flags.rnn_states()) { 115 const bool dangling = 116 !model->HasArray(rnn_state.back_edge_source_array()) || 117 !model->HasArray(rnn_state.state_array()); 118 if (dangling) { 119 CHECK(rnn_state.discardable()); 120 } else { 121 rnn_states_to_keep.push_back(rnn_state); 122 } 123 } 124 model->flags.clear_rnn_states(); 125 for (const auto& rnn_state : rnn_states_to_keep) { 126 *model->flags.add_rnn_states() = rnn_state; 127 } 128 } 129 130 bool GraphTransformationsPass(int increment, Model* model, 131 const GraphTransformationsSet& transformations) { 132 CHECK(increment == 1 || increment == -1); 133 bool changed = false; 134 if (model->operators.empty()) { 135 LOG(INFO) << "Model is empty!!!"; 136 return false; 137 } 138 int op_index = increment == 1 ? 0 : model->operators.size() - 1; 139 while (true) { 140 bool changed_now = false; 141 // Loop over all transformations at the current position in the graph. 142 for (const auto& transformation : transformations) { 143 CHECK(!changed_now); 144 CHECK(transformation->Messages().empty()); 145 changed_now = transformation->Run(model, op_index); 146 const char* made_a_change_msg = 147 changed_now ? "made a change" : "did NOT make a change"; 148 const int log_level = 149 changed_now ? kLogLevelModelChanged : kLogLevelModelUnchanged; 150 if (transformation->Messages().empty()) { 151 VLOG(log_level) << transformation->Name() << " " << made_a_change_msg 152 << " at op_index=" << op_index << "/" 153 << model->operators.size() - 1; 154 } 155 for (const string& message : transformation->Messages()) { 156 VLOG(log_level) << transformation->Name() << " " << made_a_change_msg 157 << " at op_index=" << op_index << "/" 158 << model->operators.size() - 1 << ": " << message; 159 } 160 transformation->ClearMessages(); 161 if (changed_now) { 162 DumpGraphvizVideoFrame(*model); 163 if (model->operators.empty()) return true; 164 op_index = std::min<int>(op_index, model->operators.size() - 1); 165 // Uncomment for debugging 166 // CheckInvariants(*model); 167 } 168 if (changed_now) { 169 break; 170 } 171 } 172 if (changed_now) { 173 changed = true; 174 } else { 175 const int op_index_last = 176 increment == 1 ? model->operators.size() - 1 : 0; 177 if (op_index == op_index_last) { 178 break; 179 } 180 op_index += increment; 181 } 182 } 183 DiscardUselessConnectedComponentsAndRNNBackEdges(model); 184 return changed; 185 } 186 187 } // namespace 188 189 void RunGraphTransformations(Model* model, const string& msg, 190 const GraphTransformationsSet& transformations) { 191 PrintModelStats(toco::port::StringF("Before %s", msg), *model); 192 int pass_index = 0; 193 while (GraphTransformationsPass((pass_index % 2) ? -1 : 1, model, 194 transformations)) { 195 pass_index++; 196 const auto& label = 197 toco::port::StringF("After %s pass %d", msg, pass_index); 198 PrintModelStats(label, *model); 199 CheckInvariants(*model); 200 } 201 } 202 203 } // namespace toco 204