1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include <memory> 16 #include <string> 17 #include <unordered_map> 18 #include <vector> 19 20 #include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" 21 #include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h" 22 #include "tensorflow/contrib/lite/toco/model.h" 23 #include "tensorflow/contrib/lite/toco/tooling_util.h" 24 #include "tensorflow/core/platform/logging.h" 25 26 namespace toco { 27 28 namespace { 29 30 template <ArrayDataType A> 31 void DequantizeBuffer(Array* array) { 32 const auto old_data = array->GetBuffer<A>().data; 33 array->buffer = nullptr; 34 array->data_type = ArrayDataType::kFloat; 35 auto& new_data = array->GetMutableBuffer<ArrayDataType::kFloat>().data; 36 new_data.resize(old_data.size()); 37 const auto& qparams = array->GetQuantizationParams(); 38 for (int i = 0; i < old_data.size(); i++) { 39 new_data[i] = qparams.scale * (old_data[i] - qparams.zero_point); 40 } 41 } 42 43 std::vector<std::unique_ptr<Operator>>::iterator FindFirstOpWithInput( 44 Model* model, const string& array_name) { 45 for (auto it = model->operators.begin(); it != model->operators.end(); ++it) { 46 for (const auto& input : it->get()->inputs) { 47 if (input == array_name) { 48 return it; 49 } 50 } 51 } 52 return model->operators.end(); 53 } 54 55 void ClearArrayQuantizationParams(const string& array_name, Model* model) { 56 auto* array = &model->GetArray(array_name); 57 CHECK(array->quantization_params); 58 for (auto& input_array : *model->flags.mutable_input_arrays()) { 59 if (input_array.name() == array_name) { 60 auto& qparams = *array->quantization_params; 61 const double new_std_value = 1. / qparams.scale; 62 const double new_mean_value = qparams.zero_point; 63 if (input_array.has_std_value()) { 64 CHECK_LE(std::abs(new_std_value - input_array.std_value()), 0.001); 65 } else { 66 input_array.set_std_value(new_std_value); 67 } 68 if (input_array.has_mean_value()) { 69 CHECK_LE(std::abs(new_mean_value - input_array.mean_value()), 0.001); 70 } else { 71 input_array.set_mean_value(new_mean_value); 72 } 73 } 74 } 75 array->quantization_params = nullptr; 76 } 77 78 bool DequantizeArray(const string& array_name, 79 GraphTransformation* transformation, Model* model) { 80 auto* array = &model->GetArray(array_name); 81 if (!array->quantization_params) { 82 return false; 83 } 84 transformation->AddMessageF("Dequantizing array: %s", array_name); 85 86 // Dequantize any buffer 87 if (array->buffer) { 88 if (array->data_type == ArrayDataType::kUint8) { 89 DequantizeBuffer<ArrayDataType::kUint8>(array); 90 } else if (array->data_type == ArrayDataType::kInt32) { 91 DequantizeBuffer<ArrayDataType::kInt32>(array); 92 } else { 93 LOG(FATAL) << "Unhandled data type"; 94 } 95 CHECK(array->data_type == ArrayDataType::kFloat); 96 CHECK(array->buffer->type == ArrayDataType::kFloat); 97 98 // Clear quantization params, officially makes this a non-quantized array. 99 ClearArrayQuantizationParams(array_name, model); 100 return true; 101 } else { 102 array->data_type = ArrayDataType::kFloat; 103 } 104 105 // Clear quantization params, officially makes this a non-quantized array. 106 ClearArrayQuantizationParams(array_name, model); 107 108 if (array->buffer) { 109 return true; 110 } 111 112 auto* op_outputting_array = GetOpWithOutput(*model, array_name); 113 if (op_outputting_array) { 114 if (op_outputting_array->type == OperatorType::kTensorFlowReshape) { 115 return true; 116 } 117 } 118 119 // If there was no minmax info, we can return now. Indeed, 120 // the below only serves to create a FakeQuant node, but some arrays are 121 // quantized without MinMax (see the CHECK above) and that corresponds to 122 // places where a FakeQuant node is actually not wanted, because the 123 // quantization params are meant to be inferred in another way (e.g. bias 124 // vector for a Conv op, see their special-casing in quantize.cc). 125 if (!array->minmax) { 126 return true; 127 } 128 129 // Determine whether to insert a FakeQuant before or after 130 // this array. 131 bool must_insert_fakequant_before = false; 132 bool must_insert_fakequant_after = false; 133 if (IsInputArray(*model, array_name)) { 134 must_insert_fakequant_after = true; 135 } 136 for (const string& output_array : model->flags.output_arrays()) { 137 if (array_name == output_array) { 138 must_insert_fakequant_before = true; 139 } 140 } 141 for (const auto& rnn_state : model->flags.rnn_states()) { 142 if (array_name == rnn_state.state_array()) { 143 must_insert_fakequant_after = true; 144 } 145 if (array_name == rnn_state.back_edge_source_array()) { 146 must_insert_fakequant_before = true; 147 } 148 } 149 CHECK(!(must_insert_fakequant_before && must_insert_fakequant_after)); 150 151 // Create and insert the FakeQuant node 152 auto* fakequant_op = new FakeQuantOperator; 153 model->operators.emplace(FindFirstOpWithInput(model, array_name), 154 fakequant_op); 155 const string& new_array_name = AvailableArrayName(*model, array_name); 156 auto& new_array = model->GetOrCreateArray(new_array_name); 157 new_array.data_type = ArrayDataType::kFloat; 158 new_array.copy_shape(array->shape()); 159 new_array.GetOrCreateMinMax() = array->GetMinMax(); 160 fakequant_op->minmax.reset(new MinMax); 161 *fakequant_op->minmax = array->GetMinMax(); 162 if (must_insert_fakequant_before) { 163 for (const auto& op : model->operators) { 164 for (string& output : op->outputs) { 165 if (output == array_name) { 166 output = new_array_name; 167 } 168 } 169 } 170 fakequant_op->inputs = {new_array_name}; 171 fakequant_op->outputs = {array_name}; 172 } else { 173 for (const auto& op : model->operators) { 174 for (string& input : op->inputs) { 175 if (input == array_name) { 176 input = new_array_name; 177 } 178 } 179 } 180 fakequant_op->inputs = {array_name}; 181 fakequant_op->outputs = {new_array_name}; 182 } 183 return true; 184 } 185 186 } // namespace 187 188 bool Dequantize::Run(Model* model, std::size_t op_index) { 189 const auto op_it = model->operators.begin() + op_index; 190 auto* op = op_it->get(); 191 192 if (op->type == OperatorType::kDequantize) { 193 auto& input_array = model->GetArray(op->inputs[0]); 194 if (input_array.data_type == ArrayDataType::kFloat) { 195 return false; 196 } 197 if (input_array.final_data_type != ArrayDataType::kFloat) { 198 return false; 199 } 200 input_array.data_type = ArrayDataType::kFloat; 201 input_array.quantization_params = nullptr; 202 auto& output_array = model->GetArray(op->outputs[0]); 203 output_array.data_type = ArrayDataType::kFloat; 204 output_array.quantization_params = nullptr; 205 return RemoveTrivialPassthroughOp(this, model, op_index); 206 } 207 208 std::vector<string> arrays; 209 for (const string& input : op->inputs) { 210 arrays.push_back(input); 211 } 212 for (const string& output : op->outputs) { 213 arrays.push_back(output); 214 } 215 bool changed = false; 216 for (const string& array : arrays) { 217 if (!model->IsOptionalArray(array)) { 218 changed |= DequantizeArray(array, this, model); 219 } 220 } 221 222 return changed; 223 } 224 225 } // namespace toco 226