Home | History | Annotate | Download | only in graph_transformations
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include <iterator>
     16 #include <memory>
     17 #include <string>
     18 #include <unordered_map>
     19 #include <vector>
     20 
     21 #include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
     22 #include "tensorflow/contrib/lite/toco/model.h"
     23 #include "tensorflow/contrib/lite/toco/tooling_util.h"
     24 
     25 namespace toco {
     26 
     27 namespace {
     28 
     29 template <typename T>
     30 bool AreAllBufferElementsZero(const std::vector<T>& buffer_data) {
     31   for (auto x : buffer_data) {
     32     if (x != 0) {
     33       return false;
     34     }
     35   }
     36   return true;
     37 }
     38 
     39 template <ArrayDataType Type>
     40 void FillArrayWithZeros(Array* array) {
     41   CHECK(array->data_type == Type);
     42   std::vector<DataType<Type>>& data = array->GetMutableBuffer<Type>().data;
     43   data.resize(RequiredBufferSizeForShape(array->shape()));
     44   for (size_t i = 0; i < data.size(); i++) {
     45     data[i] = 0;
     46   }
     47 }
     48 
     49 }  // namespace
     50 
     51 // Removes a multiplication by array of constant zeros by making the output
     52 // array an array of constant zeros and removing the input arrays if they are no
     53 // longer needed.
     54 bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) {
     55   const auto mul_it = model->operators.begin() + op_index;
     56   auto* mul_op = mul_it->get();
     57   if (mul_op->type != OperatorType::kMul) {
     58     return false;
     59   }
     60   const auto& output_array_name = mul_op->outputs[0];
     61   auto& output_array = model->GetArray(output_array_name);
     62 
     63   // Yield if the output shape is not known yet.
     64   if (!output_array.has_shape()) {
     65     return false;
     66   }
     67 
     68   // This transformation only handles the case where one operand is all 0's and
     69   // the other is non-constant. Other cases are handled by constant propagation
     70   // or the trivial binary removal pass.
     71   const bool is_input_constant[2] = {
     72       IsConstantParameterArray(*model, mul_op->inputs[0]),
     73       IsConstantParameterArray(*model, mul_op->inputs[1]),
     74   };
     75   if (!is_input_constant[0] && !is_input_constant[1]) {
     76     // Neither input is constant, so nothing we can resolve here.
     77     return false;
     78   }
     79   if (is_input_constant[0] && is_input_constant[1]) {
     80     // Both inputs are constants. That's a job for constants propagation, not
     81     // for us to handle here.
     82     return false;
     83   }
     84   const int index_of_constant_input = is_input_constant[0] ? 0 : 1;
     85   const int index_of_variable_input = is_input_constant[0] ? 1 : 0;
     86   CHECK(is_input_constant[index_of_constant_input]);
     87   CHECK(!is_input_constant[index_of_variable_input]);
     88 
     89   const auto& constant_input_array =
     90       model->GetArray(mul_op->inputs[index_of_constant_input]);
     91 
     92   CHECK(constant_input_array.data_type == output_array.data_type);
     93   switch (output_array.data_type) {
     94     case ArrayDataType::kFloat: {
     95       const auto& constant_input_data =
     96           constant_input_array.GetBuffer<ArrayDataType::kFloat>().data;
     97       if (!AreAllBufferElementsZero<DataType<ArrayDataType::kFloat>>(
     98               constant_input_data)) {
     99         return false;
    100       }
    101       FillArrayWithZeros<ArrayDataType::kFloat>(&output_array);
    102     } break;
    103     case ArrayDataType::kUint8: {
    104       const auto& constant_input_data =
    105           constant_input_array.GetBuffer<ArrayDataType::kUint8>().data;
    106       if (!AreAllBufferElementsZero<DataType<ArrayDataType::kUint8>>(
    107               constant_input_data)) {
    108         return false;
    109       }
    110       FillArrayWithZeros<ArrayDataType::kUint8>(&output_array);
    111     } break;
    112     case ArrayDataType::kInt32: {
    113       const auto& constant_input_data =
    114           constant_input_array.GetBuffer<ArrayDataType::kInt32>().data;
    115       if (!AreAllBufferElementsZero<DataType<ArrayDataType::kInt32>>(
    116               constant_input_data)) {
    117         return false;
    118       }
    119       FillArrayWithZeros<ArrayDataType::kInt32>(&output_array);
    120     } break;
    121     case ArrayDataType::kInt64: {
    122       const auto& constant_input_data =
    123           constant_input_array.GetBuffer<ArrayDataType::kInt64>().data;
    124       if (!AreAllBufferElementsZero<DataType<ArrayDataType::kInt64>>(
    125               constant_input_data)) {
    126         return false;
    127       }
    128       FillArrayWithZeros<ArrayDataType::kInt64>(&output_array);
    129     } break;
    130     default:
    131       AddMessageF(
    132           "Cannot resolve multiply by 0 because of unsupported data type\n");
    133       return false;
    134   }
    135 
    136   // Erase input arrays to the multiply if no longer used
    137   if (IsDiscardableArray(*model, mul_op->inputs[0]) &&
    138       CountOpsWithInput(*model, mul_op->inputs[0]) == 1) {
    139     model->EraseArray(mul_op->inputs[0]);
    140   }
    141   if (IsDiscardableArray(*model, mul_op->inputs[1]) &&
    142       CountOpsWithInput(*model, mul_op->inputs[1]) == 1) {
    143     model->EraseArray(mul_op->inputs[1]);
    144   }
    145 
    146   // Erase the multiply operator.
    147   model->operators.erase(mul_it);
    148 
    149   return true;
    150 }
    151 
    152 }  // namespace toco
    153