Home | History | Annotate | Download | only in internal
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #ifndef ANDROID_ML_NN_COMMON_OPERATIONS_INTERNAL_TYPES_H_
     18 #define ANDROID_ML_NN_COMMON_OPERATIONS_INTERNAL_TYPES_H_
     19 
     20 #include "compatibility.h"
     21 
     22 namespace android {
     23 namespace nn {
     24 
     25 enum class FusedActivationFunctionType { kNone, kRelu6, kRelu1, kRelu };
     26 
     27 template <int N>
     28 struct Dims {
     29   int sizes[N];
     30   int strides[N];
     31 };
     32 
     33 struct Shape;
     34 
     35 inline Dims<4> convertShapeToDims(const Shape& shape) {
     36   Dims<4> dims;
     37   for (int i=0; i<4; i++) {
     38     dims.sizes[i] = 1;
     39   }
     40 
     41   if (shape.dimensions.size() == 1) {
     42     dims.sizes[0] = (int)getSizeOfDimension(shape, 0);
     43   } else {
     44     for (int i=0; i<4; i++) {
     45       int src = (int)shape.dimensions.size()-i-1;
     46       if (src >= 0) {
     47         dims.sizes[i] = (int)getSizeOfDimension(shape, src);
     48       }
     49     }
     50   }
     51 
     52   dims.strides[0] = 1;
     53   for (int i = 1; i<4; i++) {
     54     dims.strides[i] = dims.strides[i-1] * dims.sizes[i-1];
     55   }
     56   return dims;
     57 }
     58 
     59 inline int Offset(const Dims<4>& dims, int i0, int i1, int i2, int i3) {
     60   DCHECK(i0 >= 0 && i0 < dims.sizes[0]);
     61   DCHECK(i1 >= 0 && i1 < dims.sizes[1]);
     62   DCHECK(i2 >= 0 && i2 < dims.sizes[2]);
     63   DCHECK(i3 >= 0 && i3 < dims.sizes[3]);
     64   return i0 * dims.strides[0] + i1 * dims.strides[1] + i2 * dims.strides[2] +
     65          i3 * dims.strides[3];
     66 }
     67 
     68 // Get array size, DCHECKing that the dim index is in range.
     69 template <int N>
     70 int ArraySize(const Dims<N>& array, int index) {
     71   DCHECK(index >= 0 && index < N);
     72   return array.sizes[index];
     73 }
     74 
     75 // Get common array size, DCHECKing that they all agree.
     76 template <typename ArrayType1, typename ArrayType2>
     77 int MatchingArraySize(const ArrayType1& array1, int index1,
     78                       const ArrayType2& array2, int index2) {
     79   DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2));
     80   return ArraySize(array1, index1);
     81 }
     82 
     83 template <typename ArrayType1, typename ArrayType2, typename... Args>
     84 int MatchingArraySize(const ArrayType1& array1, int index1,
     85                       const ArrayType2& array2, int index2, Args... args) {
     86   DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2));
     87   return MatchingArraySize(array1, index1, args...);
     88 }
     89 
     90 inline int RequiredBufferSizeForDims(const Dims<4>& dims) {
     91   int max_offset = 0;
     92   for (int i = 0; i < 4; i++) {
     93     max_offset += (dims.sizes[i] - 1) * dims.strides[i];
     94   }
     95   return max_offset + 1;
     96 }
     97 
     98 template <int N>
     99 bool IsPackedWithoutStrides(const Dims<N>& dims) {
    100   int expected_stride = 1;
    101   for (int d = 0; d < N; d++) {
    102     if (dims.strides[d] != expected_stride) return false;
    103     expected_stride *= dims.sizes[d];
    104   }
    105   return true;
    106 }
    107 
    108 }  // namespace nn
    109 }  // namespace android
    110 
    111 #endif  // ANDROID_ML_NN_COMMON_OPERATIONS_INTERNAL_TYPES_H_
    112