1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include <iterator> 16 #include <memory> 17 #include <string> 18 #include <unordered_map> 19 #include <vector> 20 21 #include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" 22 #include "tensorflow/contrib/lite/toco/model.h" 23 #include "tensorflow/contrib/lite/toco/tooling_util.h" 24 25 namespace toco { 26 27 namespace { 28 29 template <typename T> 30 bool AreAllBufferElementsZero(const std::vector<T>& buffer_data) { 31 for (auto x : buffer_data) { 32 if (x != 0) { 33 return false; 34 } 35 } 36 return true; 37 } 38 39 template <ArrayDataType Type> 40 void FillArrayWithZeros(Array* array) { 41 CHECK(array->data_type == Type); 42 std::vector<DataType<Type>>& data = array->GetMutableBuffer<Type>().data; 43 data.resize(RequiredBufferSizeForShape(array->shape())); 44 for (size_t i = 0; i < data.size(); i++) { 45 data[i] = 0; 46 } 47 } 48 49 } // namespace 50 51 // Removes a multiplication by array of constant zeros by making the output 52 // array an array of constant zeros and removing the input arrays if they are no 53 // longer needed. 54 bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) { 55 const auto mul_it = model->operators.begin() + op_index; 56 auto* mul_op = mul_it->get(); 57 if (mul_op->type != OperatorType::kMul) { 58 return false; 59 } 60 const auto& output_array_name = mul_op->outputs[0]; 61 auto& output_array = model->GetArray(output_array_name); 62 63 // Yield if the output shape is not known yet. 64 if (!output_array.has_shape()) { 65 return false; 66 } 67 68 // This transformation only handles the case where one operand is all 0's and 69 // the other is non-constant. Other cases are handled by constant propagation 70 // or the trivial binary removal pass. 71 const bool is_input_constant[2] = { 72 IsConstantParameterArray(*model, mul_op->inputs[0]), 73 IsConstantParameterArray(*model, mul_op->inputs[1]), 74 }; 75 if (!is_input_constant[0] && !is_input_constant[1]) { 76 // Neither input is constant, so nothing we can resolve here. 77 return false; 78 } 79 if (is_input_constant[0] && is_input_constant[1]) { 80 // Both inputs are constants. That's a job for constants propagation, not 81 // for us to handle here. 82 return false; 83 } 84 const int index_of_constant_input = is_input_constant[0] ? 0 : 1; 85 const int index_of_variable_input = is_input_constant[0] ? 1 : 0; 86 CHECK(is_input_constant[index_of_constant_input]); 87 CHECK(!is_input_constant[index_of_variable_input]); 88 89 const auto& constant_input_array = 90 model->GetArray(mul_op->inputs[index_of_constant_input]); 91 92 CHECK(constant_input_array.data_type == output_array.data_type); 93 switch (output_array.data_type) { 94 case ArrayDataType::kFloat: { 95 const auto& constant_input_data = 96 constant_input_array.GetBuffer<ArrayDataType::kFloat>().data; 97 if (!AreAllBufferElementsZero<DataType<ArrayDataType::kFloat>>( 98 constant_input_data)) { 99 return false; 100 } 101 FillArrayWithZeros<ArrayDataType::kFloat>(&output_array); 102 } break; 103 case ArrayDataType::kUint8: { 104 const auto& constant_input_data = 105 constant_input_array.GetBuffer<ArrayDataType::kUint8>().data; 106 if (!AreAllBufferElementsZero<DataType<ArrayDataType::kUint8>>( 107 constant_input_data)) { 108 return false; 109 } 110 FillArrayWithZeros<ArrayDataType::kUint8>(&output_array); 111 } break; 112 case ArrayDataType::kInt32: { 113 const auto& constant_input_data = 114 constant_input_array.GetBuffer<ArrayDataType::kInt32>().data; 115 if (!AreAllBufferElementsZero<DataType<ArrayDataType::kInt32>>( 116 constant_input_data)) { 117 return false; 118 } 119 FillArrayWithZeros<ArrayDataType::kInt32>(&output_array); 120 } break; 121 case ArrayDataType::kInt64: { 122 const auto& constant_input_data = 123 constant_input_array.GetBuffer<ArrayDataType::kInt64>().data; 124 if (!AreAllBufferElementsZero<DataType<ArrayDataType::kInt64>>( 125 constant_input_data)) { 126 return false; 127 } 128 FillArrayWithZeros<ArrayDataType::kInt64>(&output_array); 129 } break; 130 default: 131 AddMessageF( 132 "Cannot resolve multiply by 0 because of unsupported data type\n"); 133 return false; 134 } 135 136 // Erase input arrays to the multiply if no longer used 137 if (IsDiscardableArray(*model, mul_op->inputs[0]) && 138 CountOpsWithInput(*model, mul_op->inputs[0]) == 1) { 139 model->EraseArray(mul_op->inputs[0]); 140 } 141 if (IsDiscardableArray(*model, mul_op->inputs[1]) && 142 CountOpsWithInput(*model, mul_op->inputs[1]) == 1) { 143 model->EraseArray(mul_op->inputs[1]); 144 } 145 146 // Erase the multiply operator. 147 model->operators.erase(mul_it); 148 149 return true; 150 } 151 152 } // namespace toco 153