1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_ 16 #define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_ 17 18 #include "tensorflow/contrib/lite/kernels/internal/compatibility.h" 19 20 namespace tflite { 21 22 enum class FusedActivationFunctionType : uint8 { kNone, kRelu6, kRelu1, kRelu }; 23 24 template <int N> 25 struct Dims { 26 int sizes[N]; 27 int strides[N]; 28 }; 29 30 // Gets next index to iterate through a multidimensional array. 31 inline bool NextIndex(const int num_dims, const int* dims, int* current) { 32 TFLITE_DCHECK_GT(num_dims, 0); 33 TFLITE_DCHECK(dims != nullptr); 34 TFLITE_DCHECK(current != nullptr); 35 int carry = 1; 36 for (int idx = num_dims - 1; idx >= 0; --idx) { 37 int current_val = current[idx] + carry; 38 TFLITE_DCHECK_GE(dims[idx], current_val); 39 if (dims[idx] == current_val) { 40 current[idx] = 0; 41 } else { 42 current[idx] = current_val; 43 carry = 0; 44 break; 45 } 46 } 47 return (carry == 0); 48 } 49 50 // Gets offset of index if reducing on axis. When reducing, the flattened offset 51 // will not change, if the input index changes on the given axis. For example, 52 // if you have a 3D tensor and you are reducing to 2D by eliminating axis 0, 53 // then index (0, 1, 2) and index (1, 1, 2) will map to the same flattened 54 // offset. 55 // TODO(kanlig): uses Dims to represent dimensions. 56 inline size_t ReducedOutputOffset(const int num_dims, const int* dims, 57 const int* index, const int num_axis, 58 const int* axis) { 59 TFLITE_DCHECK_GT(num_dims, 0); 60 TFLITE_DCHECK(dims != nullptr); 61 TFLITE_DCHECK(index != nullptr); 62 size_t offset = 0; 63 for (int idx = 0; idx < num_dims; ++idx) { 64 // if we need to skip this axis 65 bool is_axis = false; 66 if (axis != nullptr) { 67 for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx) { 68 if (idx == axis[axis_idx]) { 69 is_axis = true; 70 break; 71 } 72 } 73 } 74 if (!is_axis) { 75 offset = offset * static_cast<size_t>(dims[idx]) + 76 static_cast<size_t>(index[idx]); 77 } 78 } 79 return offset; 80 } 81 82 inline int Offset(const Dims<4>& dims, int i0, int i1, int i2, int i3) { 83 TFLITE_DCHECK(i0 >= 0 && i0 < dims.sizes[0]); 84 TFLITE_DCHECK(i1 >= 0 && i1 < dims.sizes[1]); 85 TFLITE_DCHECK(i2 >= 0 && i2 < dims.sizes[2]); 86 TFLITE_DCHECK(i3 >= 0 && i3 < dims.sizes[3]); 87 return i0 * dims.strides[0] + i1 * dims.strides[1] + i2 * dims.strides[2] + 88 i3 * dims.strides[3]; 89 } 90 91 inline int Offset(const Dims<4>& dims, int* index) { 92 return Offset(dims, index[0], index[1], index[2], index[3]); 93 } 94 95 // Get array size, DCHECKing that the dim index is in range. 96 template <int N> 97 int ArraySize(const Dims<N>& array, int index) { 98 TFLITE_DCHECK(index >= 0 && index < N); 99 return array.sizes[index]; 100 } 101 102 // Get common array size, DCHECKing that they all agree. 103 template <typename ArrayType1, typename ArrayType2> 104 int MatchingArraySize(const ArrayType1& array1, int index1, 105 const ArrayType2& array2, int index2) { 106 TFLITE_DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2)); 107 return ArraySize(array1, index1); 108 } 109 110 template <typename ArrayType1, typename ArrayType2, typename... Args> 111 int MatchingArraySize(const ArrayType1& array1, int index1, 112 const ArrayType2& array2, int index2, Args... args) { 113 TFLITE_DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2)); 114 return MatchingArraySize(array1, index1, args...); 115 } 116 117 inline int RequiredBufferSizeForDims(const Dims<4>& dims) { 118 int max_offset = 0; 119 for (int i = 0; i < 4; i++) { 120 max_offset += (dims.sizes[i] - 1) * dims.strides[i]; 121 } 122 return max_offset + 1; 123 } 124 125 template <int N> 126 bool IsPackedWithoutStrides(const Dims<N>& dims) { 127 int expected_stride = 1; 128 for (int d = 0; d < N; d++) { 129 if (dims.strides[d] != expected_stride) return false; 130 expected_stride *= dims.sizes[d]; 131 } 132 return true; 133 } 134 135 } // namespace tflite 136 137 #endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_ 138