Home | History | Annotate | Download | only in graph_transformations
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include <memory>
     16 #include <string>
     17 #include <unordered_map>
     18 #include <vector>
     19 
     20 #include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
     21 #include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h"
     22 #include "tensorflow/contrib/lite/toco/model.h"
     23 #include "tensorflow/contrib/lite/toco/tooling_util.h"
     24 #include "tensorflow/core/platform/logging.h"
     25 
     26 namespace toco {
     27 
     28 namespace {
     29 
     30 template <ArrayDataType A>
     31 void DequantizeBuffer(Array* array) {
     32   const auto old_data = array->GetBuffer<A>().data;
     33   array->buffer = nullptr;
     34   array->data_type = ArrayDataType::kFloat;
     35   auto& new_data = array->GetMutableBuffer<ArrayDataType::kFloat>().data;
     36   new_data.resize(old_data.size());
     37   const auto& qparams = array->GetQuantizationParams();
     38   for (int i = 0; i < old_data.size(); i++) {
     39     new_data[i] = qparams.scale * (old_data[i] - qparams.zero_point);
     40   }
     41 }
     42 
     43 std::vector<std::unique_ptr<Operator>>::iterator FindFirstOpWithInput(
     44     Model* model, const string& array_name) {
     45   for (auto it = model->operators.begin(); it != model->operators.end(); ++it) {
     46     for (const auto& input : it->get()->inputs) {
     47       if (input == array_name) {
     48         return it;
     49       }
     50     }
     51   }
     52   return model->operators.end();
     53 }
     54 
     55 void ClearArrayQuantizationParams(const string& array_name, Model* model) {
     56   auto* array = &model->GetArray(array_name);
     57   CHECK(array->quantization_params);
     58   for (auto& input_array : *model->flags.mutable_input_arrays()) {
     59     if (input_array.name() == array_name) {
     60       auto& qparams = *array->quantization_params;
     61       const double new_std_value = 1. / qparams.scale;
     62       const double new_mean_value = qparams.zero_point;
     63       if (input_array.has_std_value()) {
     64         CHECK_LE(std::abs(new_std_value - input_array.std_value()), 0.001);
     65       } else {
     66         input_array.set_std_value(new_std_value);
     67       }
     68       if (input_array.has_mean_value()) {
     69         CHECK_LE(std::abs(new_mean_value - input_array.mean_value()), 0.001);
     70       } else {
     71         input_array.set_mean_value(new_mean_value);
     72       }
     73     }
     74   }
     75   array->quantization_params = nullptr;
     76 }
     77 
     78 bool DequantizeArray(const string& array_name,
     79                      GraphTransformation* transformation, Model* model) {
     80   auto* array = &model->GetArray(array_name);
     81   if (!array->quantization_params) {
     82     return false;
     83   }
     84   transformation->AddMessageF("Dequantizing array: %s", array_name);
     85 
     86   // Dequantize any buffer
     87   if (array->buffer) {
     88     if (array->data_type == ArrayDataType::kUint8) {
     89       DequantizeBuffer<ArrayDataType::kUint8>(array);
     90     } else if (array->data_type == ArrayDataType::kInt32) {
     91       DequantizeBuffer<ArrayDataType::kInt32>(array);
     92     } else {
     93       LOG(FATAL) << "Unhandled data type";
     94     }
     95     CHECK(array->data_type == ArrayDataType::kFloat);
     96     CHECK(array->buffer->type == ArrayDataType::kFloat);
     97 
     98     // Clear quantization params, officially makes this a non-quantized array.
     99     ClearArrayQuantizationParams(array_name, model);
    100     return true;
    101   } else {
    102     array->data_type = ArrayDataType::kFloat;
    103   }
    104 
    105   // Clear quantization params, officially makes this a non-quantized array.
    106   ClearArrayQuantizationParams(array_name, model);
    107 
    108   if (array->buffer) {
    109     return true;
    110   }
    111 
    112   auto* op_outputting_array = GetOpWithOutput(*model, array_name);
    113   if (op_outputting_array) {
    114     if (op_outputting_array->type == OperatorType::kTensorFlowReshape) {
    115       return true;
    116     }
    117   }
    118 
    119   // If there was no minmax info, we can return now. Indeed,
    120   // the below only serves to create a FakeQuant node, but some arrays are
    121   // quantized without MinMax (see the CHECK above) and that corresponds to
    122   // places where a FakeQuant node is actually not wanted, because the
    123   // quantization params are meant to be inferred in another way (e.g. bias
    124   // vector for a Conv op, see their special-casing in quantize.cc).
    125   if (!array->minmax) {
    126     return true;
    127   }
    128 
    129   // Determine whether to insert a FakeQuant before or after
    130   // this array.
    131   bool must_insert_fakequant_before = false;
    132   bool must_insert_fakequant_after = false;
    133   if (IsInputArray(*model, array_name)) {
    134     must_insert_fakequant_after = true;
    135   }
    136   for (const string& output_array : model->flags.output_arrays()) {
    137     if (array_name == output_array) {
    138       must_insert_fakequant_before = true;
    139     }
    140   }
    141   for (const auto& rnn_state : model->flags.rnn_states()) {
    142     if (array_name == rnn_state.state_array()) {
    143       must_insert_fakequant_after = true;
    144     }
    145     if (array_name == rnn_state.back_edge_source_array()) {
    146       must_insert_fakequant_before = true;
    147     }
    148   }
    149   CHECK(!(must_insert_fakequant_before && must_insert_fakequant_after));
    150 
    151   // Create and insert the FakeQuant node
    152   auto* fakequant_op = new FakeQuantOperator;
    153   model->operators.emplace(FindFirstOpWithInput(model, array_name),
    154                            fakequant_op);
    155   const string& new_array_name = AvailableArrayName(*model, array_name);
    156   auto& new_array = model->GetOrCreateArray(new_array_name);
    157   new_array.data_type = ArrayDataType::kFloat;
    158   new_array.copy_shape(array->shape());
    159   new_array.GetOrCreateMinMax() = array->GetMinMax();
    160   fakequant_op->minmax.reset(new MinMax);
    161   *fakequant_op->minmax = array->GetMinMax();
    162   if (must_insert_fakequant_before) {
    163     for (const auto& op : model->operators) {
    164       for (string& output : op->outputs) {
    165         if (output == array_name) {
    166           output = new_array_name;
    167         }
    168       }
    169     }
    170     fakequant_op->inputs = {new_array_name};
    171     fakequant_op->outputs = {array_name};
    172   } else {
    173     for (const auto& op : model->operators) {
    174       for (string& input : op->inputs) {
    175         if (input == array_name) {
    176           input = new_array_name;
    177         }
    178       }
    179     }
    180     fakequant_op->inputs = {array_name};
    181     fakequant_op->outputs = {new_array_name};
    182   }
    183   return true;
    184 }
    185 
    186 }  // namespace
    187 
    188 bool Dequantize::Run(Model* model, std::size_t op_index) {
    189   const auto op_it = model->operators.begin() + op_index;
    190   auto* op = op_it->get();
    191 
    192   if (op->type == OperatorType::kDequantize) {
    193     auto& input_array = model->GetArray(op->inputs[0]);
    194     if (input_array.data_type == ArrayDataType::kFloat) {
    195       return false;
    196     }
    197     if (input_array.final_data_type != ArrayDataType::kFloat) {
    198       return false;
    199     }
    200     input_array.data_type = ArrayDataType::kFloat;
    201     input_array.quantization_params = nullptr;
    202     auto& output_array = model->GetArray(op->outputs[0]);
    203     output_array.data_type = ArrayDataType::kFloat;
    204     output_array.quantization_params = nullptr;
    205     return RemoveTrivialPassthroughOp(this, model, op_index);
    206   }
    207 
    208   std::vector<string> arrays;
    209   for (const string& input : op->inputs) {
    210     arrays.push_back(input);
    211   }
    212   for (const string& output : op->outputs) {
    213     arrays.push_back(output);
    214   }
    215   bool changed = false;
    216   for (const string& array : arrays) {
    217     if (!model->IsOptionalArray(array)) {
    218       changed |= DequantizeArray(array, this, model);
    219     }
    220   }
    221 
    222   return changed;
    223 }
    224 
    225 }  // namespace toco
    226