1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef ANDROID_ML_NN_COMMON_OPERATIONS_INTERNAL_TYPES_H_ 18 #define ANDROID_ML_NN_COMMON_OPERATIONS_INTERNAL_TYPES_H_ 19 20 #include "compatibility.h" 21 22 namespace android { 23 namespace nn { 24 25 enum class FusedActivationFunctionType { kNone, kRelu6, kRelu1, kRelu }; 26 27 template <int N> 28 struct Dims { 29 int sizes[N]; 30 int strides[N]; 31 }; 32 33 struct Shape; 34 35 inline Dims<4> convertShapeToDims(const Shape& shape) { 36 Dims<4> dims; 37 for (int i=0; i<4; i++) { 38 dims.sizes[i] = 1; 39 } 40 41 if (shape.dimensions.size() == 1) { 42 dims.sizes[0] = (int)getSizeOfDimension(shape, 0); 43 } else { 44 for (int i=0; i<4; i++) { 45 int src = (int)shape.dimensions.size()-i-1; 46 if (src >= 0) { 47 dims.sizes[i] = (int)getSizeOfDimension(shape, src); 48 } 49 } 50 } 51 52 dims.strides[0] = 1; 53 for (int i = 1; i<4; i++) { 54 dims.strides[i] = dims.strides[i-1] * dims.sizes[i-1]; 55 } 56 return dims; 57 } 58 59 inline int Offset(const Dims<4>& dims, int i0, int i1, int i2, int i3) { 60 DCHECK(i0 >= 0 && i0 < dims.sizes[0]); 61 DCHECK(i1 >= 0 && i1 < dims.sizes[1]); 62 DCHECK(i2 >= 0 && i2 < dims.sizes[2]); 63 DCHECK(i3 >= 0 && i3 < dims.sizes[3]); 64 return i0 * dims.strides[0] + i1 * dims.strides[1] + i2 * dims.strides[2] + 65 i3 * dims.strides[3]; 66 } 67 68 // Get array size, DCHECKing that the dim index is in range. 69 template <int N> 70 int ArraySize(const Dims<N>& array, int index) { 71 DCHECK(index >= 0 && index < N); 72 return array.sizes[index]; 73 } 74 75 // Get common array size, DCHECKing that they all agree. 76 template <typename ArrayType1, typename ArrayType2> 77 int MatchingArraySize(const ArrayType1& array1, int index1, 78 const ArrayType2& array2, int index2) { 79 DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2)); 80 return ArraySize(array1, index1); 81 } 82 83 template <typename ArrayType1, typename ArrayType2, typename... Args> 84 int MatchingArraySize(const ArrayType1& array1, int index1, 85 const ArrayType2& array2, int index2, Args... args) { 86 DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2)); 87 return MatchingArraySize(array1, index1, args...); 88 } 89 90 inline int RequiredBufferSizeForDims(const Dims<4>& dims) { 91 int max_offset = 0; 92 for (int i = 0; i < 4; i++) { 93 max_offset += (dims.sizes[i] - 1) * dims.strides[i]; 94 } 95 return max_offset + 1; 96 } 97 98 template <int N> 99 bool IsPackedWithoutStrides(const Dims<N>& dims) { 100 int expected_stride = 1; 101 for (int d = 0; d < N; d++) { 102 if (dims.strides[d] != expected_stride) return false; 103 expected_stride *= dims.sizes[d]; 104 } 105 return true; 106 } 107 108 } // namespace nn 109 } // namespace android 110 111 #endif // ANDROID_ML_NN_COMMON_OPERATIONS_INTERNAL_TYPES_H_ 112