1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include <algorithm> 16 #include <memory> 17 #include <string> 18 #include <unordered_map> 19 #include <vector> 20 21 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" 22 #include "tensorflow/lite/toco/model.h" 23 #include "tensorflow/lite/toco/runtime/types.h" 24 #include "tensorflow/lite/toco/tooling_util.h" 25 #include "tensorflow/core/platform/logging.h" 26 27 namespace toco { 28 29 namespace { 30 31 std::vector<bool> VectorGreaterThan(const std::vector<int>& a, 32 const std::vector<int>& b) { 33 DCHECK_EQ(a.size(), b.size()); 34 const int size = a.size(); 35 std::vector<bool> result(size); 36 for (int i = 0; i < size; i++) { 37 result[i] = a[i] > b[i]; 38 } 39 return result; 40 } 41 42 void PairwiseVectorSelect(const std::vector<bool>& selector, 43 const std::vector<int>& input_a, 44 const std::vector<int>& input_b, 45 std::vector<int>* output_a, 46 std::vector<int>* output_b) { 47 DCHECK_EQ(input_a.size(), input_b.size()); 48 DCHECK_EQ(output_a->size(), output_b->size()); 49 DCHECK_EQ(input_a.size(), output_a->size()); 50 DCHECK_EQ(selector.size(), input_a.size()); 51 const int size = input_a.size(); 52 for (int i = 0; i < size; i++) { 53 if (selector[i]) { 54 (*output_a)[i] = input_a[i]; 55 (*output_b)[i] = input_b[i]; 56 } else { 57 (*output_a)[i] = input_b[i]; 58 (*output_b)[i] = input_a[i]; 59 } 60 } 61 } 62 63 template <ArrayDataType InputsDataType, ArrayDataType OutputDataType> 64 void EvaluateBinaryOperatorOnConstantInputs(Model* model, 65 const Operator* binary_op) { 66 CHECK(IsConstantParameterArray(*model, binary_op->inputs[0])); 67 CHECK(IsConstantParameterArray(*model, binary_op->inputs[1])); 68 CHECK(binary_op->fused_activation_function == 69 FusedActivationFunctionType::kNone); 70 const auto& input0_array = model->GetArray(binary_op->inputs[0]); 71 const auto& input1_array = model->GetArray(binary_op->inputs[1]); 72 const auto& output_name = binary_op->outputs[0]; 73 auto& output_array = model->GetArray(output_name); 74 CHECK(input0_array.data_type == InputsDataType); 75 CHECK(input1_array.data_type == InputsDataType); 76 CHECK(output_array.data_type == OutputDataType); 77 78 // We have already tested above for existence of input buffers 79 // (synonymous to being a constant param). 80 CHECK(input0_array.buffer); 81 CHECK(input1_array.buffer); 82 // On the other hand, the output should not already have a buffer. 83 CHECK(!output_array.buffer); 84 85 const auto& input0_data = input0_array.GetBuffer<InputsDataType>().data; 86 const auto& input1_data = input1_array.GetBuffer<InputsDataType>().data; 87 // Create the buffer on the output array, effectively turning it into 88 // a constant parameter 89 90 const Shape& output_shape = output_array.shape(); 91 auto& output_data = output_array.GetMutableBuffer<OutputDataType>().data; 92 const int output_buffer_size = RequiredBufferSizeForShape(output_shape); 93 output_data.resize(output_buffer_size); 94 const int dims_count = output_shape.dimensions_count(); 95 96 // It will be convenient here to have copies of the operands shapes 97 // extended to match the number of dimensions of the output shape. 98 Shape input0_shape = input0_array.shape(); 99 Shape input1_shape = input1_array.shape(); 100 ExtendShape(&input0_shape, dims_count); 101 ExtendShape(&input1_shape, dims_count); 102 // Now we may still have operands of different sizes, which would indicate 103 // that we have to "broadcast" the smaller dimension. We do this using a 104 // a vector of Booleans indicating which input is the larger in each 105 // dimension. 106 CHECK_EQ(input0_shape.dimensions_count(), input1_shape.dimensions_count()); 107 CHECK_EQ(input0_shape.dimensions_count(), dims_count); 108 const std::vector<bool> input0_larger = 109 VectorGreaterThan(input0_shape.dims(), input1_shape.dims()); 110 111 std::vector<int> big_sizes(dims_count); 112 std::vector<int> small_sizes(dims_count); 113 PairwiseVectorSelect(input0_larger, input0_shape.dims(), input1_shape.dims(), 114 &big_sizes, &small_sizes); 115 116 // The output should already be correctly sized to match the big dimensions. 117 for (int i = 0; i < dims_count; i++) { 118 CHECK_EQ(output_shape.dims(i), big_sizes[i]); 119 } 120 121 std::vector<int> input0_indices(dims_count); 122 std::vector<int> input1_indices(dims_count); 123 std::vector<int> modulo_indices(dims_count); 124 125 for (int k = 0; k < output_buffer_size; k++) { 126 const std::vector<int> output_indices = ReverseOffset(output_shape, k); 127 for (int i = 0; i < dims_count; i++) { 128 modulo_indices[i] = output_indices[i] % small_sizes[i]; 129 } 130 PairwiseVectorSelect(input0_larger, output_indices, modulo_indices, 131 &input0_indices, &input1_indices); 132 const auto val0 = input0_data[Offset(input0_shape, input0_indices)]; 133 const auto val1 = input1_data[Offset(input1_shape, input1_indices)]; 134 135 DataType<OutputDataType> outval; 136 if (binary_op->type == OperatorType::kAdd) { 137 outval = val0 + val1; 138 } else if (binary_op->type == OperatorType::kMul) { 139 outval = val0 * val1; 140 } else if (binary_op->type == OperatorType::kSub) { 141 outval = val0 - val1; 142 } else if (binary_op->type == OperatorType::kDiv) { 143 outval = val0 / val1; 144 } else if (binary_op->type == OperatorType::kFloorDiv) { 145 outval = floor(val0 / val1); 146 } else if (binary_op->type == OperatorType::kFloorMod) { 147 outval = val0 - (floor(val0 / val1) * val1); 148 } else if (binary_op->type == OperatorType::kMinimum) { 149 outval = std::min(val0, val1); 150 } else if (binary_op->type == OperatorType::kMaximum) { 151 outval = std::max(val0, val1); 152 } else if (binary_op->type == OperatorType::kLess) { 153 outval = val0 < val1; 154 } else if (binary_op->type == OperatorType::kLessEqual) { 155 outval = val0 <= val1; 156 } else if (binary_op->type == OperatorType::kGreater) { 157 outval = val0 > val1; 158 } else if (binary_op->type == OperatorType::kGreaterEqual) { 159 outval = val0 >= val1; 160 } else { 161 LOG(FATAL) << "should not get here"; 162 } 163 output_data[Offset(output_shape, output_indices)] = outval; 164 } 165 } 166 167 void EvaluateBinaryOperatorOnConstantInputs(Model* model, 168 const Operator* binary_op) { 169 const auto inputs_data_type = model->GetArray(binary_op->inputs[0]).data_type; 170 const auto output_data_type = 171 model->GetArray(binary_op->outputs[0]).data_type; 172 #define TOCO_HANDLE_CASE(InputsDataType, OutputDataType) \ 173 if (inputs_data_type == InputsDataType && \ 174 output_data_type == OutputDataType) { \ 175 EvaluateBinaryOperatorOnConstantInputs<InputsDataType, OutputDataType>( \ 176 model, binary_op); \ 177 return; \ 178 } 179 TOCO_HANDLE_CASE(ArrayDataType::kFloat, ArrayDataType::kFloat) 180 TOCO_HANDLE_CASE(ArrayDataType::kFloat, ArrayDataType::kBool) 181 TOCO_HANDLE_CASE(ArrayDataType::kInt32, ArrayDataType::kInt32) 182 TOCO_HANDLE_CASE(ArrayDataType::kInt32, ArrayDataType::kBool) 183 TOCO_HANDLE_CASE(ArrayDataType::kInt64, ArrayDataType::kInt64) 184 TOCO_HANDLE_CASE(ArrayDataType::kInt64, ArrayDataType::kBool) 185 LOG(FATAL) << "Unimplemented: don't know how to resolve a constant " 186 << "binary operator for these data types."; 187 #undef TOCO_HANDLE_CASE 188 } 189 } // namespace 190 191 ::tensorflow::Status ResolveConstantBinaryOperator::Run(Model* model, 192 std::size_t op_index, 193 bool* modified) { 194 *modified = false; 195 const auto binary_it = model->operators.begin() + op_index; 196 const auto* binary_op = binary_it->get(); 197 // Test for binary ops of types that we know how to resolve 198 if (binary_op->type != OperatorType::kAdd && 199 binary_op->type != OperatorType::kMul && 200 binary_op->type != OperatorType::kSub && 201 binary_op->type != OperatorType::kDiv && 202 binary_op->type != OperatorType::kFloorDiv && 203 binary_op->type != OperatorType::kFloorMod && 204 binary_op->type != OperatorType::kMinimum && 205 binary_op->type != OperatorType::kMaximum && 206 binary_op->type != OperatorType::kLess && 207 binary_op->type != OperatorType::kLessEqual && 208 binary_op->type != OperatorType::kGreater && 209 binary_op->type != OperatorType::kGreaterEqual) { 210 return ::tensorflow::Status::OK(); 211 } 212 CHECK_EQ(binary_op->inputs.size(), 2); 213 214 const auto& input0_array = model->GetArray(binary_op->inputs[0]); 215 const auto& input1_array = model->GetArray(binary_op->inputs[1]); 216 // Check if both inputs are constant parameters. 217 if (!input0_array.buffer || !input1_array.buffer) { 218 return ::tensorflow::Status::OK(); 219 } 220 221 auto& output_array = model->GetArray(binary_op->outputs[0]); 222 // Yield until the output array dims have been resolved. 223 if (!output_array.has_shape()) { 224 return ::tensorflow::Status::OK(); 225 } 226 227 // At the moment we don't want to care about fused activation functions. 228 // The idea is that we should do the present constants-propagation before 229 // activation functions get fused. 230 if (binary_op->fused_activation_function != 231 FusedActivationFunctionType::kNone) { 232 AddMessageF( 233 "Not resolving constant %s because it has a fused activation function", 234 LogName(*binary_op)); 235 return ::tensorflow::Status::OK(); 236 } 237 238 // Check that input data types agree. 239 CHECK(input0_array.data_type == input1_array.data_type) 240 << "Dissimilar data types given to op outputting \"" 241 << binary_op->outputs[0] << "\". 0:\"" << binary_op->inputs[0] << "\"(" 242 << static_cast<int>(input0_array.data_type) << ") 1:\"" 243 << binary_op->inputs[1] << "\"(" 244 << static_cast<int>(input1_array.data_type) << ")."; 245 246 // Do the actual constants propagation 247 EvaluateBinaryOperatorOnConstantInputs(model, binary_op); 248 249 // Remove the binary operator and its inputs 250 if (CountOpsWithInput(*model, binary_op->inputs[0]) == 1) { 251 model->EraseArray(binary_op->inputs[0]); 252 } 253 if (CountOpsWithInput(*model, binary_op->inputs[1]) == 1) { 254 model->EraseArray(binary_op->inputs[1]); 255 } 256 AddMessageF("Resolved constant %s to the equivalent constant array", 257 LogName(*binary_op)); 258 model->operators.erase(binary_it); 259 *modified = true; 260 return ::tensorflow::Status::OK(); 261 } 262 263 } // namespace toco 264