// tensorflow/lite/toco/graph_transformations/: constant-folds binary operators
// whose inputs are constant arrays.
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include <algorithm>
     16 #include <memory>
     17 #include <string>
     18 #include <unordered_map>
     19 #include <vector>
     20 
     21 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
     22 #include "tensorflow/lite/toco/model.h"
     23 #include "tensorflow/lite/toco/runtime/types.h"
     24 #include "tensorflow/lite/toco/tooling_util.h"
     25 #include "tensorflow/core/platform/logging.h"
     26 
     27 namespace toco {
     28 
     29 namespace {
     30 
     31 std::vector<bool> VectorGreaterThan(const std::vector<int>& a,
     32                                     const std::vector<int>& b) {
     33   DCHECK_EQ(a.size(), b.size());
     34   const int size = a.size();
     35   std::vector<bool> result(size);
     36   for (int i = 0; i < size; i++) {
     37     result[i] = a[i] > b[i];
     38   }
     39   return result;
     40 }
     41 
     42 void PairwiseVectorSelect(const std::vector<bool>& selector,
     43                           const std::vector<int>& input_a,
     44                           const std::vector<int>& input_b,
     45                           std::vector<int>* output_a,
     46                           std::vector<int>* output_b) {
     47   DCHECK_EQ(input_a.size(), input_b.size());
     48   DCHECK_EQ(output_a->size(), output_b->size());
     49   DCHECK_EQ(input_a.size(), output_a->size());
     50   DCHECK_EQ(selector.size(), input_a.size());
     51   const int size = input_a.size();
     52   for (int i = 0; i < size; i++) {
     53     if (selector[i]) {
     54       (*output_a)[i] = input_a[i];
     55       (*output_b)[i] = input_b[i];
     56     } else {
     57       (*output_a)[i] = input_b[i];
     58       (*output_b)[i] = input_a[i];
     59     }
     60   }
     61 }
     62 
// Constant-folds `binary_op`, whose two inputs must both be constant arrays
// of element type InputsDataType, writing the elementwise result into a
// newly created buffer of type OutputDataType on the op's output array.
// Supports broadcasting: each dimension of the smaller input is tiled
// (via modulo indexing) to match the larger one.
template <ArrayDataType InputsDataType, ArrayDataType OutputDataType>
void EvaluateBinaryOperatorOnConstantInputs(Model* model,
                                            const Operator* binary_op) {
  // Preconditions: both inputs are constant parameters, no fused activation,
  // and the arrays' data types match the template parameters.
  CHECK(IsConstantParameterArray(*model, binary_op->inputs[0]));
  CHECK(IsConstantParameterArray(*model, binary_op->inputs[1]));
  CHECK(binary_op->fused_activation_function ==
        FusedActivationFunctionType::kNone);
  const auto& input0_array = model->GetArray(binary_op->inputs[0]);
  const auto& input1_array = model->GetArray(binary_op->inputs[1]);
  const auto& output_name = binary_op->outputs[0];
  auto& output_array = model->GetArray(output_name);
  CHECK(input0_array.data_type == InputsDataType);
  CHECK(input1_array.data_type == InputsDataType);
  CHECK(output_array.data_type == OutputDataType);

  // We have already tested above for existence of input buffers
  // (synonymous to being a constant param).
  CHECK(input0_array.buffer);
  CHECK(input1_array.buffer);
  // On the other hand, the output should not already have a buffer.
  CHECK(!output_array.buffer);

  const auto& input0_data = input0_array.GetBuffer<InputsDataType>().data;
  const auto& input1_data = input1_array.GetBuffer<InputsDataType>().data;
  // Create the buffer on the output array, effectively turning it into
  // a constant parameter

  const Shape& output_shape = output_array.shape();
  // GetMutableBuffer creates the output buffer on first access.
  auto& output_data = output_array.GetMutableBuffer<OutputDataType>().data;
  const int output_buffer_size = RequiredBufferSizeForShape(output_shape);
  output_data.resize(output_buffer_size);
  const int dims_count = output_shape.dimensions_count();

  // It will be convenient here to have copies of the operands shapes
  // extended to match the number of dimensions of the output shape.
  Shape input0_shape = input0_array.shape();
  Shape input1_shape = input1_array.shape();
  ExtendShape(&input0_shape, dims_count);
  ExtendShape(&input1_shape, dims_count);
  // Now we may still have operands of different sizes, which would indicate
  // that we have to "broadcast" the smaller dimension.  We do this using a
  // a vector of Booleans indicating which input is the larger in each
  // dimension.
  CHECK_EQ(input0_shape.dimensions_count(), input1_shape.dimensions_count());
  CHECK_EQ(input0_shape.dimensions_count(), dims_count);
  const std::vector<bool> input0_larger =
      VectorGreaterThan(input0_shape.dims(), input1_shape.dims());

  // big_sizes[i]/small_sizes[i] hold the larger/smaller of the two input
  // extents in dimension i, regardless of which input they came from.
  std::vector<int> big_sizes(dims_count);
  std::vector<int> small_sizes(dims_count);
  PairwiseVectorSelect(input0_larger, input0_shape.dims(), input1_shape.dims(),
                       &big_sizes, &small_sizes);

  // The output should already be correctly sized to match the big dimensions.
  for (int i = 0; i < dims_count; i++) {
    CHECK_EQ(output_shape.dims(i), big_sizes[i]);
  }

  std::vector<int> input0_indices(dims_count);
  std::vector<int> input1_indices(dims_count);
  std::vector<int> modulo_indices(dims_count);

  // Walk every element of the output buffer, mapping its multi-dimensional
  // index back onto each input (wrapping via modulo in broadcast dims).
  for (int k = 0; k < output_buffer_size; k++) {
    const std::vector<int> output_indices = ReverseOffset(output_shape, k);
    for (int i = 0; i < dims_count; i++) {
      // Wrap the output index into the smaller input's extent; for the
      // larger input the plain output index is used (selected below).
      modulo_indices[i] = output_indices[i] % small_sizes[i];
    }
    PairwiseVectorSelect(input0_larger, output_indices, modulo_indices,
                         &input0_indices, &input1_indices);
    const auto val0 = input0_data[Offset(input0_shape, input0_indices)];
    const auto val1 = input1_data[Offset(input1_shape, input1_indices)];

    // Apply the actual scalar operation for this op type.
    DataType<OutputDataType> outval;
    if (binary_op->type == OperatorType::kAdd) {
      outval = val0 + val1;
    } else if (binary_op->type == OperatorType::kMul) {
      outval = val0 * val1;
    } else if (binary_op->type == OperatorType::kSub) {
      outval = val0 - val1;
    } else if (binary_op->type == OperatorType::kDiv) {
      outval = val0 / val1;
    } else if (binary_op->type == OperatorType::kFloorDiv) {
      // NOTE(review): for integer input types, val0 / val1 truncates toward
      // zero before floor() sees it, so negative quotients round toward zero
      // rather than toward -infinity — confirm this matches the runtime
      // FloorDiv kernel's semantics.
      outval = floor(val0 / val1);
    } else if (binary_op->type == OperatorType::kFloorMod) {
      outval = val0 - (floor(val0 / val1) * val1);
    } else if (binary_op->type == OperatorType::kMinimum) {
      outval = std::min(val0, val1);
    } else if (binary_op->type == OperatorType::kMaximum) {
      outval = std::max(val0, val1);
    } else if (binary_op->type == OperatorType::kLess) {
      outval = val0 < val1;
    } else if (binary_op->type == OperatorType::kLessEqual) {
      outval = val0 <= val1;
    } else if (binary_op->type == OperatorType::kGreater) {
      outval = val0 > val1;
    } else if (binary_op->type == OperatorType::kGreaterEqual) {
      outval = val0 >= val1;
    } else {
      // Run() filters the op types before calling; anything else is a bug.
      LOG(FATAL) << "should not get here";
    }
    output_data[Offset(output_shape, output_indices)] = outval;
  }
}
    166 
    167 void EvaluateBinaryOperatorOnConstantInputs(Model* model,
    168                                             const Operator* binary_op) {
    169   const auto inputs_data_type = model->GetArray(binary_op->inputs[0]).data_type;
    170   const auto output_data_type =
    171       model->GetArray(binary_op->outputs[0]).data_type;
    172 #define TOCO_HANDLE_CASE(InputsDataType, OutputDataType)                    \
    173   if (inputs_data_type == InputsDataType &&                                 \
    174       output_data_type == OutputDataType) {                                 \
    175     EvaluateBinaryOperatorOnConstantInputs<InputsDataType, OutputDataType>( \
    176         model, binary_op);                                                  \
    177     return;                                                                 \
    178   }
    179   TOCO_HANDLE_CASE(ArrayDataType::kFloat, ArrayDataType::kFloat)
    180   TOCO_HANDLE_CASE(ArrayDataType::kFloat, ArrayDataType::kBool)
    181   TOCO_HANDLE_CASE(ArrayDataType::kInt32, ArrayDataType::kInt32)
    182   TOCO_HANDLE_CASE(ArrayDataType::kInt32, ArrayDataType::kBool)
    183   TOCO_HANDLE_CASE(ArrayDataType::kInt64, ArrayDataType::kInt64)
    184   TOCO_HANDLE_CASE(ArrayDataType::kInt64, ArrayDataType::kBool)
    185   LOG(FATAL) << "Unimplemented: don't know how to resolve a constant "
    186              << "binary operator for these data types.";
    187 #undef TOCO_HANDLE_CASE
    188 }
    189 }  // namespace
    190 
    191 ::tensorflow::Status ResolveConstantBinaryOperator::Run(Model* model,
    192                                                         std::size_t op_index,
    193                                                         bool* modified) {
    194   *modified = false;
    195   const auto binary_it = model->operators.begin() + op_index;
    196   const auto* binary_op = binary_it->get();
    197   // Test for binary ops of types that we know how to resolve
    198   if (binary_op->type != OperatorType::kAdd &&
    199       binary_op->type != OperatorType::kMul &&
    200       binary_op->type != OperatorType::kSub &&
    201       binary_op->type != OperatorType::kDiv &&
    202       binary_op->type != OperatorType::kFloorDiv &&
    203       binary_op->type != OperatorType::kFloorMod &&
    204       binary_op->type != OperatorType::kMinimum &&
    205       binary_op->type != OperatorType::kMaximum &&
    206       binary_op->type != OperatorType::kLess &&
    207       binary_op->type != OperatorType::kLessEqual &&
    208       binary_op->type != OperatorType::kGreater &&
    209       binary_op->type != OperatorType::kGreaterEqual) {
    210     return ::tensorflow::Status::OK();
    211   }
    212   CHECK_EQ(binary_op->inputs.size(), 2);
    213 
    214   const auto& input0_array = model->GetArray(binary_op->inputs[0]);
    215   const auto& input1_array = model->GetArray(binary_op->inputs[1]);
    216   // Check if both inputs are constant parameters.
    217   if (!input0_array.buffer || !input1_array.buffer) {
    218     return ::tensorflow::Status::OK();
    219   }
    220 
    221   auto& output_array = model->GetArray(binary_op->outputs[0]);
    222   // Yield until the output array dims have been resolved.
    223   if (!output_array.has_shape()) {
    224     return ::tensorflow::Status::OK();
    225   }
    226 
    227   // At the moment we don't want to care about fused activation functions.
    228   // The idea is that we should do the present constants-propagation before
    229   // activation functions get fused.
    230   if (binary_op->fused_activation_function !=
    231       FusedActivationFunctionType::kNone) {
    232     AddMessageF(
    233         "Not resolving constant %s because it has a fused activation function",
    234         LogName(*binary_op));
    235     return ::tensorflow::Status::OK();
    236   }
    237 
    238   // Check that input data types agree.
    239   CHECK(input0_array.data_type == input1_array.data_type)
    240       << "Dissimilar data types given to op outputting \""
    241       << binary_op->outputs[0] << "\". 0:\"" << binary_op->inputs[0] << "\"("
    242       << static_cast<int>(input0_array.data_type) << ")   1:\""
    243       << binary_op->inputs[1] << "\"("
    244       << static_cast<int>(input1_array.data_type) << ").";
    245 
    246   // Do the actual constants propagation
    247   EvaluateBinaryOperatorOnConstantInputs(model, binary_op);
    248 
    249   // Remove the binary operator and its inputs
    250   if (CountOpsWithInput(*model, binary_op->inputs[0]) == 1) {
    251     model->EraseArray(binary_op->inputs[0]);
    252   }
    253   if (CountOpsWithInput(*model, binary_op->inputs[1]) == 1) {
    254     model->EraseArray(binary_op->inputs[1]);
    255   }
    256   AddMessageF("Resolved constant %s to the equivalent constant array",
    257               LogName(*binary_op));
    258   model->operators.erase(binary_it);
    259   *modified = true;
    260   return ::tensorflow::Status::OK();
    261 }
    262 
    263 }  // namespace toco
    264