1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include <memory> 16 #include <string> 17 #include <unordered_map> 18 #include <vector> 19 20 #include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" 21 #include "tensorflow/contrib/lite/toco/model.h" 22 #include "tensorflow/contrib/lite/toco/tooling_util.h" 23 #include "tensorflow/core/platform/logging.h" 24 25 namespace toco { 26 27 bool ResolveTensorFlowSwitch::Run(Model* model, std::size_t op_index) { 28 const auto switch_it = model->operators.begin() + op_index; 29 const auto* switch_op = switch_it->get(); 30 if (switch_op->type != OperatorType::kTensorFlowSwitch) { 31 return false; 32 } 33 34 CHECK_EQ(switch_op->inputs.size(), 2); 35 CHECK_EQ(switch_op->outputs.size(), 2); 36 const string& predicate_name = switch_op->inputs[1]; 37 // If the predicate array hasn't been resolved to a constant yet, 38 // we need to yield. 39 if (!IsConstantParameterArray(*model, predicate_name)) { 40 AddMessageF( 41 "Waiting for the boolean predicate of %s to be resolved to a constant", 42 LogName(*switch_op)); 43 return false; 44 } 45 46 // The predicate should be boolean, and should consist of a single value. 47 const auto& predicate_array = model->GetArray(predicate_name); 48 CHECK(predicate_array.data_type == ArrayDataType::kBool); 49 for (const auto& dim : predicate_array.shape().dims()) { 50 CHECK_EQ(dim, 1); 51 } 52 53 // Obtain the predicate boolean value. 54 const auto& predicate_data = 55 predicate_array.GetBuffer<ArrayDataType::kBool>().data; 56 CHECK_EQ(predicate_data.size(), 1); 57 const bool predicate_value = predicate_data[0]; 58 59 // From the TensorFlow docs on .switch() in 60 // third_party/tensorflow/python/ops/control_flow_ops.py 61 // 62 // If `pred` is false, the `data` input is forwarded to the first output. 63 // Otherwise, the data goes to the second output. 64 // 65 // Note that this comment used to say the opposite and was recently fixed: 66 // https://github.com/tensorflow/tensorflow/commit/bc456e361d49d1d89a74b80060c70efb51fd7d87#diff-76ab9dafbe12c20ddc3769c6b108986c 67 const int selected_output_index = predicate_value ? 1 : 0; 68 const int nonselected_output_index = predicate_value ? 0 : 1; 69 70 // Update the edges of the graph ahead of removing the node: 71 // edges that were pointing to the selected output, should instead 72 // point to the input of the Switch node. 73 for (const auto& other_op : model->operators) { 74 for (auto& input : other_op->inputs) { 75 if (input == switch_op->outputs[selected_output_index]) { 76 input = switch_op->inputs[0]; 77 } 78 } 79 } 80 81 // There remains to handle the edges that were pointing to the nonselected 82 // output. We will just discard those edges. Concretely, at the moment, 83 // our only examples of graphs with Switch nodes have them feeding into Merge 84 // nodes, so what we're saying here is that we'll make the convention, 85 // in our toco internal representation, that Merge nodes with only 1 input 86 // are Merge nodes that have been resolved already and should be have as 87 // Identity nodes, simply forwarding their input. 88 // 89 for (const auto& other_op : model->operators) { 90 auto input_it = other_op->inputs.begin(); 91 while (input_it != other_op->inputs.end()) { 92 if (*input_it == switch_op->outputs[nonselected_output_index]) { 93 // Let us guard our assumption that only Merge nodes consume the outputs 94 // of Switch nodes: 95 CHECK(other_op->type == OperatorType::kTensorFlowMerge); 96 input_it = other_op->inputs.erase(input_it); 97 } else { 98 ++input_it; 99 } 100 } 101 } 102 103 // Remove the output arrays if they are now unused. 104 for (int i = 0; i < 2; i++) { 105 if (!GetOpWithInput(*model, switch_op->outputs[i])) { 106 model->EraseArray(switch_op->outputs[i]); 107 } 108 } 109 // Remove input arrays if they are only used by the switch itself and aren't 110 // the output of another op (will get handled by RemoveUnusedOp in that case). 111 for (const auto& input : switch_op->inputs) { 112 if (CountOpsWithInput(*model, input) == 1 && 113 !GetOpWithOutput(*model, input)) { 114 model->EraseArray(input); 115 } 116 } 117 // Remove the switch node itself. 118 AddMessageF("Removing already-resolved %s", LogName(*switch_op)); 119 model->operators.erase(switch_it); 120 return true; 121 } 122 123 } // namespace toco 124