Home | History | Annotate | Download | only in internal
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_
     16 #define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_
     17 
     18 #include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
     19 
     20 namespace tflite {
     21 
     22 enum class FusedActivationFunctionType : uint8 { kNone, kRelu6, kRelu1, kRelu };
     23 
     24 template <int N>
     25 struct Dims {
     26   int sizes[N];
     27   int strides[N];
     28 };
     29 
     30 // Gets next index to iterate through a multidimensional array.
     31 inline bool NextIndex(const int num_dims, const int* dims, int* current) {
     32   TFLITE_DCHECK_GT(num_dims, 0);
     33   TFLITE_DCHECK(dims != nullptr);
     34   TFLITE_DCHECK(current != nullptr);
     35   int carry = 1;
     36   for (int idx = num_dims - 1; idx >= 0; --idx) {
     37     int current_val = current[idx] + carry;
     38     TFLITE_DCHECK_GE(dims[idx], current_val);
     39     if (dims[idx] == current_val) {
     40       current[idx] = 0;
     41     } else {
     42       current[idx] = current_val;
     43       carry = 0;
     44       break;
     45     }
     46   }
     47   return (carry == 0);
     48 }
     49 
     50 // Gets offset of index if reducing on axis. When reducing, the flattened offset
     51 // will not change, if the input index changes on the given axis. For example,
     52 // if you have a 3D tensor and you are reducing to 2D by eliminating axis 0,
     53 // then index (0, 1, 2) and index (1, 1, 2) will map to the same flattened
     54 // offset.
     55 // TODO(kanlig): uses Dims to represent dimensions.
     56 inline size_t ReducedOutputOffset(const int num_dims, const int* dims,
     57                                   const int* index, const int num_axis,
     58                                   const int* axis) {
     59   TFLITE_DCHECK_GT(num_dims, 0);
     60   TFLITE_DCHECK(dims != nullptr);
     61   TFLITE_DCHECK(index != nullptr);
     62   size_t offset = 0;
     63   for (int idx = 0; idx < num_dims; ++idx) {
     64     // if we need to skip this axis
     65     bool is_axis = false;
     66     if (axis != nullptr) {
     67       for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx) {
     68         if (idx == axis[axis_idx]) {
     69           is_axis = true;
     70           break;
     71         }
     72       }
     73     }
     74     if (!is_axis) {
     75       offset = offset * static_cast<size_t>(dims[idx]) +
     76                static_cast<size_t>(index[idx]);
     77     }
     78   }
     79   return offset;
     80 }
     81 
     82 inline int Offset(const Dims<4>& dims, int i0, int i1, int i2, int i3) {
     83   TFLITE_DCHECK(i0 >= 0 && i0 < dims.sizes[0]);
     84   TFLITE_DCHECK(i1 >= 0 && i1 < dims.sizes[1]);
     85   TFLITE_DCHECK(i2 >= 0 && i2 < dims.sizes[2]);
     86   TFLITE_DCHECK(i3 >= 0 && i3 < dims.sizes[3]);
     87   return i0 * dims.strides[0] + i1 * dims.strides[1] + i2 * dims.strides[2] +
     88          i3 * dims.strides[3];
     89 }
     90 
     91 inline int Offset(const Dims<4>& dims, int* index) {
     92   return Offset(dims, index[0], index[1], index[2], index[3]);
     93 }
     94 
     95 // Get array size, DCHECKing that the dim index is in range.
     96 template <int N>
     97 int ArraySize(const Dims<N>& array, int index) {
     98   TFLITE_DCHECK(index >= 0 && index < N);
     99   return array.sizes[index];
    100 }
    101 
    102 // Get common array size, DCHECKing that they all agree.
    103 template <typename ArrayType1, typename ArrayType2>
    104 int MatchingArraySize(const ArrayType1& array1, int index1,
    105                       const ArrayType2& array2, int index2) {
    106   TFLITE_DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2));
    107   return ArraySize(array1, index1);
    108 }
    109 
    110 template <typename ArrayType1, typename ArrayType2, typename... Args>
    111 int MatchingArraySize(const ArrayType1& array1, int index1,
    112                       const ArrayType2& array2, int index2, Args... args) {
    113   TFLITE_DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2));
    114   return MatchingArraySize(array1, index1, args...);
    115 }
    116 
    117 inline int RequiredBufferSizeForDims(const Dims<4>& dims) {
    118   int max_offset = 0;
    119   for (int i = 0; i < 4; i++) {
    120     max_offset += (dims.sizes[i] - 1) * dims.strides[i];
    121   }
    122   return max_offset + 1;
    123 }
    124 
    125 template <int N>
    126 bool IsPackedWithoutStrides(const Dims<N>& dims) {
    127   int expected_stride = 1;
    128   for (int d = 0; d < N; d++) {
    129     if (dims.strides[d] != expected_stride) return false;
    130     expected_stride *= dims.sizes[d];
    131   }
    132   return true;
    133 }
    134 
    135 }  // namespace tflite
    136 
    137 #endif  // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_
    138