Home | History | Annotate | Download | only in graph_transformations
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include <memory>
     16 #include <string>
     17 #include <unordered_map>
     18 #include <vector>
     19 
     20 #include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
     21 #include "tensorflow/contrib/lite/toco/model.h"
     22 #include "tensorflow/contrib/lite/toco/runtime/types.h"
     23 #include "tensorflow/contrib/lite/toco/tooling_util.h"
     24 #include "tensorflow/core/platform/logging.h"
     25 
     26 namespace toco {
     27 
     28 bool ResolveBatchNormalization::Run(Model* model, std::size_t op_index) {
     29   auto bn_it = model->operators.begin() + op_index;
     30   if (bn_it->get()->type != OperatorType::kBatchNormalization) {
     31     return false;
     32   }
     33   const auto* bn_op =
     34       static_cast<const BatchNormalizationOperator*>(bn_it->get());
     35 
     36   const auto& mean_array = model->GetArray(bn_op->inputs[1]);
     37   const auto& multiplier_array = model->GetArray(bn_op->inputs[2]);
     38   const auto& offset_array = model->GetArray(bn_op->inputs[3]);
     39 
     40   CHECK(IsConstantParameterArray(*model, bn_op->inputs[1]) &&
     41         IsConstantParameterArray(*model, bn_op->inputs[2]) &&
     42         IsConstantParameterArray(*model, bn_op->inputs[3]))
     43       << "Batch normalization resolution requires that mean, multiplier and "
     44          "offset arrays be constant.";
     45 
     46   // We should only have *float* BatchNormalizations... let's guard this
     47   // assumption by CHECK's.
     48   CHECK(mean_array.data_type == ArrayDataType::kFloat);
     49   CHECK(multiplier_array.data_type == ArrayDataType::kFloat);
     50   CHECK(offset_array.data_type == ArrayDataType::kFloat);
     51 
     52   // Create the new Mul, Add operators
     53   auto* mul_op = new MulOperator;
     54   auto* add_op = new AddOperator;
     55   const string mul_name =
     56       AvailableArrayName(*model, bn_op->outputs[0] + "_mul");
     57   const string add_name =
     58       AvailableArrayName(*model, bn_op->outputs[0] + "_add");
     59   const string mul_param_name = AvailableArrayName(*model, mul_name + "_param");
     60   const string add_param_name = AvailableArrayName(*model, add_name + "_param");
     61   mul_op->inputs = {bn_op->inputs[0], mul_param_name};
     62   mul_op->outputs = {mul_name};
     63   add_op->inputs = {mul_name, add_param_name};
     64   add_op->outputs = {bn_op->outputs[0]};
     65   AddMessageF("Splitting %s into %s and %s", LogName(*bn_op), LogName(*mul_op),
     66               LogName(*add_op));
     67 
     68   // Create the intermediate activation array (output of mul, input of add)
     69   auto& intermediate_array = model->GetOrCreateArray(mul_op->outputs[0]);
     70   intermediate_array.data_type = model->GetArray(bn_op->inputs[0]).data_type;
     71 
     72   // Insert the new operators in the graph
     73   auto add_it = model->operators.emplace(bn_it, add_op);
     74   auto mul_it = model->operators.emplace(add_it, mul_op);
     75   // update invalidated iterators.
     76   DCHECK_EQ(mul_it->get(), mul_op);
     77   add_it = mul_it + 1;
     78   DCHECK_EQ(add_it->get(), add_op);
     79   bn_it = add_it + 1;
     80   DCHECK_EQ(bn_it->get(), bn_op);
     81 
     82   // Create the new param arrays
     83   const auto& mean_shape = mean_array.shape();
     84   const auto& multiplier_shape = multiplier_array.shape();
     85   const auto& offset_shape = offset_array.shape();
     86   CHECK(mean_shape.dims() == multiplier_shape.dims());
     87   CHECK(mean_shape.dims() == offset_shape.dims());
     88   const auto& param_shape = mean_shape;
     89   const int buffer_size = RequiredBufferSizeForShape(param_shape);
     90   auto& mul_param_array = model->GetOrCreateArray(mul_param_name);
     91   auto& add_param_array = model->GetOrCreateArray(add_param_name);
     92   DropMinMax(model, mul_param_name);
     93   DropMinMax(model, add_param_name);
     94   mul_param_array.copy_shape(param_shape);
     95   add_param_array.copy_shape(param_shape);
     96   mul_param_array.data_type = ArrayDataType::kFloat;
     97   add_param_array.data_type = ArrayDataType::kFloat;
     98   auto& mul_float_data =
     99       mul_param_array.GetMutableBuffer<ArrayDataType::kFloat>().data;
    100   auto& add_float_data =
    101       add_param_array.GetMutableBuffer<ArrayDataType::kFloat>().data;
    102   mul_float_data.resize(buffer_size);
    103   add_float_data.resize(buffer_size);
    104   const auto& mean_float_data =
    105       mean_array.GetBuffer<ArrayDataType::kFloat>().data;
    106   const auto& multiplier_float_data =
    107       multiplier_array.GetBuffer<ArrayDataType::kFloat>().data;
    108   const auto& offset_float_data =
    109       offset_array.GetBuffer<ArrayDataType::kFloat>().data;
    110 
    111   CHECK(mul_float_data.size() == buffer_size);
    112   CHECK(add_float_data.size() == buffer_size);
    113   CHECK(mean_float_data.size() == buffer_size);
    114   CHECK(multiplier_float_data.size() == buffer_size);
    115   CHECK(offset_float_data.size() == buffer_size);
    116 
    117   for (int i = 0; i < buffer_size; i++) {
    118     mul_float_data[i] = multiplier_float_data[i];
    119     add_float_data[i] =
    120         offset_float_data[i] - mean_float_data[i] * multiplier_float_data[i];
    121   }
    122 
    123   // Remove the old param arrays
    124   model->EraseArray(bn_op->inputs[1]);
    125   model->EraseArray(bn_op->inputs[2]);
    126   model->EraseArray(bn_op->inputs[3]);
    127 
    128   // Remove the old operator
    129   DCHECK_EQ(bn_it->get(), bn_op);
    130   model->operators.erase(bn_it);
    131 
    132   return true;
    133 }
    134 
    135 }  // namespace toco
    136