1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include <memory> 16 #include <string> 17 #include <unordered_map> 18 #include <vector> 19 20 #include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" 21 #include "tensorflow/contrib/lite/toco/model.h" 22 #include "tensorflow/contrib/lite/toco/runtime/types.h" 23 #include "tensorflow/contrib/lite/toco/tooling_util.h" 24 #include "tensorflow/core/platform/logging.h" 25 26 namespace toco { 27 28 bool ResolveBatchNormalization::Run(Model* model, std::size_t op_index) { 29 auto bn_it = model->operators.begin() + op_index; 30 if (bn_it->get()->type != OperatorType::kBatchNormalization) { 31 return false; 32 } 33 const auto* bn_op = 34 static_cast<const BatchNormalizationOperator*>(bn_it->get()); 35 36 const auto& mean_array = model->GetArray(bn_op->inputs[1]); 37 const auto& multiplier_array = model->GetArray(bn_op->inputs[2]); 38 const auto& offset_array = model->GetArray(bn_op->inputs[3]); 39 40 CHECK(IsConstantParameterArray(*model, bn_op->inputs[1]) && 41 IsConstantParameterArray(*model, bn_op->inputs[2]) && 42 IsConstantParameterArray(*model, bn_op->inputs[3])) 43 << "Batch normalization resolution requires that mean, multiplier and " 44 "offset arrays be constant."; 45 46 // We should only have *float* BatchNormalizations... let's guard this 47 // assumption by CHECK's. 48 CHECK(mean_array.data_type == ArrayDataType::kFloat); 49 CHECK(multiplier_array.data_type == ArrayDataType::kFloat); 50 CHECK(offset_array.data_type == ArrayDataType::kFloat); 51 52 // Create the new Mul, Add operators 53 auto* mul_op = new MulOperator; 54 auto* add_op = new AddOperator; 55 const string mul_name = 56 AvailableArrayName(*model, bn_op->outputs[0] + "_mul"); 57 const string add_name = 58 AvailableArrayName(*model, bn_op->outputs[0] + "_add"); 59 const string mul_param_name = AvailableArrayName(*model, mul_name + "_param"); 60 const string add_param_name = AvailableArrayName(*model, add_name + "_param"); 61 mul_op->inputs = {bn_op->inputs[0], mul_param_name}; 62 mul_op->outputs = {mul_name}; 63 add_op->inputs = {mul_name, add_param_name}; 64 add_op->outputs = {bn_op->outputs[0]}; 65 AddMessageF("Splitting %s into %s and %s", LogName(*bn_op), LogName(*mul_op), 66 LogName(*add_op)); 67 68 // Create the intermediate activation array (output of mul, input of add) 69 auto& intermediate_array = model->GetOrCreateArray(mul_op->outputs[0]); 70 intermediate_array.data_type = model->GetArray(bn_op->inputs[0]).data_type; 71 72 // Insert the new operators in the graph 73 auto add_it = model->operators.emplace(bn_it, add_op); 74 auto mul_it = model->operators.emplace(add_it, mul_op); 75 // update invalidated iterators. 76 DCHECK_EQ(mul_it->get(), mul_op); 77 add_it = mul_it + 1; 78 DCHECK_EQ(add_it->get(), add_op); 79 bn_it = add_it + 1; 80 DCHECK_EQ(bn_it->get(), bn_op); 81 82 // Create the new param arrays 83 const auto& mean_shape = mean_array.shape(); 84 const auto& multiplier_shape = multiplier_array.shape(); 85 const auto& offset_shape = offset_array.shape(); 86 CHECK(mean_shape.dims() == multiplier_shape.dims()); 87 CHECK(mean_shape.dims() == offset_shape.dims()); 88 const auto& param_shape = mean_shape; 89 const int buffer_size = RequiredBufferSizeForShape(param_shape); 90 auto& mul_param_array = model->GetOrCreateArray(mul_param_name); 91 auto& add_param_array = model->GetOrCreateArray(add_param_name); 92 DropMinMax(model, mul_param_name); 93 DropMinMax(model, add_param_name); 94 mul_param_array.copy_shape(param_shape); 95 add_param_array.copy_shape(param_shape); 96 mul_param_array.data_type = ArrayDataType::kFloat; 97 add_param_array.data_type = ArrayDataType::kFloat; 98 auto& mul_float_data = 99 mul_param_array.GetMutableBuffer<ArrayDataType::kFloat>().data; 100 auto& add_float_data = 101 add_param_array.GetMutableBuffer<ArrayDataType::kFloat>().data; 102 mul_float_data.resize(buffer_size); 103 add_float_data.resize(buffer_size); 104 const auto& mean_float_data = 105 mean_array.GetBuffer<ArrayDataType::kFloat>().data; 106 const auto& multiplier_float_data = 107 multiplier_array.GetBuffer<ArrayDataType::kFloat>().data; 108 const auto& offset_float_data = 109 offset_array.GetBuffer<ArrayDataType::kFloat>().data; 110 111 CHECK(mul_float_data.size() == buffer_size); 112 CHECK(add_float_data.size() == buffer_size); 113 CHECK(mean_float_data.size() == buffer_size); 114 CHECK(multiplier_float_data.size() == buffer_size); 115 CHECK(offset_float_data.size() == buffer_size); 116 117 for (int i = 0; i < buffer_size; i++) { 118 mul_float_data[i] = multiplier_float_data[i]; 119 add_float_data[i] = 120 offset_float_data[i] - mean_float_data[i] * multiplier_float_data[i]; 121 } 122 123 // Remove the old param arrays 124 model->EraseArray(bn_op->inputs[1]); 125 model->EraseArray(bn_op->inputs[2]); 126 model->EraseArray(bn_op->inputs[3]); 127 128 // Remove the old operator 129 DCHECK_EQ(bn_it->get(), bn_op); 130 model->operators.erase(bn_it); 131 132 return true; 133 } 134 135 } // namespace toco 136