Home | History | Annotate | Download | only in operations
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
#include "Operations.h"
#include "CpuOperationUtils.h"

#include <cstddef>

#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
     21 
     22 namespace android {
     23 namespace nn {
     24 
     25 // If possible we will use this static buffer for the tensor.
     26 static constexpr size_t kStaticBufferSize = 1605632;
     27 static char static_scratch_buffer[kStaticBufferSize];
     28 
     29 // executionMutex is used to protect concurrent access of the static_scratch_buffer
     30 // and other non-threadsafe resources like gemmlowp::GemmContext.
     31 // std::mutex is safe for pthreads on Android.
     32 static std::mutex executionMutex;
     33 
     34 #define ANDROID_NN_CONV_PARAMETERS(Type)                                        \
     35     uint32_t height       = getSizeOfDimension(inputShape, 1);                  \
     36     uint32_t width        = getSizeOfDimension(inputShape, 2);                  \
     37     uint32_t filterHeight = getSizeOfDimension(filterShape, 1);                 \
     38     uint32_t filterWidth  = getSizeOfDimension(filterShape, 2);                 \
     39     uint32_t outHeight    = getSizeOfDimension(outputShape, 1);                 \
     40     uint32_t outWidth     = getSizeOfDimension(outputShape, 2);                 \
     41     uint32_t inDepth      = getSizeOfDimension(inputShape, 3);                  \
     42                                                                                 \
     43     uint32_t paddingHeight = (uint32_t)padding_top;                             \
     44     uint32_t paddingWidth = (uint32_t)padding_left;                             \
     45                                                                                 \
     46     tflite::Dims<4> im2colDim;                                                  \
     47     im2colDim.sizes[3] = (int)getSizeOfDimension(outputShape, 0);               \
     48     im2colDim.sizes[2] = (int)getSizeOfDimension(outputShape, 1);               \
     49     im2colDim.sizes[1] = (int)getSizeOfDimension(outputShape, 2);               \
     50     im2colDim.sizes[0] = (int)inDepth * filterHeight * filterWidth;             \
     51                                                                                 \
     52     im2colDim.strides[0] = 1;                                                   \
     53     for (int i=1; i<4; i++) {                                                   \
     54         im2colDim.strides[i] = im2colDim.strides[i-1] * im2colDim.sizes[i-1];   \
     55     }                                                                           \
     56                                                                                 \
     57     Type* im2colData = nullptr;                                                 \
     58     uint64_t im2colByteSize = sizeof(Type);                                     \
     59     std::unique_ptr<Type[]> im2colGuard;                                        \
     60     for (int i=0; i<4; i++) {                                                   \
     61         im2colByteSize *= im2colDim.sizes[i];                                   \
     62     }                                                                           \
     63     /* http://b/77982879, tflite::optimized_ops::Conv uses int for offsets */   \
     64     if (im2colByteSize >= 0x7fffffff)  {                                        \
     65         LOG(ERROR) << "Conv size is too large, not enough memory";              \
     66         return false;                                                           \
     67     }                                                                           \
     68     if (im2colByteSize <= kStaticBufferSize) {                                  \
     69         im2colData = reinterpret_cast<Type *>(static_scratch_buffer);           \
     70     } else {                                                                    \
     71         im2colData = new (std::nothrow) Type[im2colByteSize / sizeof(Type)];    \
     72         if (im2colData == nullptr) {                                            \
     73             LOG(ERROR) << "Conv size is too large, not enough memory";          \
     74             return false;                                                       \
     75         }                                                                       \
     76         im2colGuard.reset(im2colData);                                          \
     77     }
     78 
     79 bool convFloat32(const float* inputData, const Shape& inputShape,
     80                  const float* filterData, const Shape& filterShape,
     81                  const float* biasData, const Shape& biasShape,
     82                  int32_t padding_left, int32_t padding_right,
     83                  int32_t padding_top, int32_t padding_bottom,
     84                  int32_t stride_width, int32_t stride_height,
     85                  int32_t activation,
     86                  float* outputData, const Shape& outputShape) {
     87 
     88     ANDROID_NN_CONV_PARAMETERS(float)
     89 
     90     float output_activation_min, output_activation_max;
     91     CalculateActivationRangeFloat(activation, &output_activation_min,
     92                                   &output_activation_max);
     93 
     94     // Prevent concurrent executions that may access the scratch buffer.
     95     std::unique_lock<std::mutex> lock(executionMutex);
     96     tflite::optimized_ops::Conv(
     97             inputData, convertShapeToDims(inputShape),
     98             filterData, convertShapeToDims(filterShape),
     99             biasData, convertShapeToDims(biasShape),
    100             stride_width, stride_height, paddingWidth, paddingHeight,
    101             output_activation_min, output_activation_max,
    102             outputData, convertShapeToDims(outputShape),
    103             im2colData, im2colDim);
    104     return true;
    105 }
    106 
    107 bool convQuant8(const uint8_t* inputData, const Shape& inputShape,
    108                 const uint8_t* filterData, const Shape& filterShape,
    109                 const int32_t* biasData, const Shape& biasShape,
    110                 int32_t padding_left, int32_t padding_right,
    111                 int32_t padding_top, int32_t padding_bottom,
    112                 int32_t stride_width, int32_t stride_height,
    113                 int32_t activation,
    114                 uint8_t* outputData, const Shape& outputShape) {
    115 
    116     ANDROID_NN_CONV_PARAMETERS(uint8_t)
    117 
    118     int32_t inputOffset = -inputShape.offset;
    119     int32_t filterOffset = -filterShape.offset;
    120     int32_t outputOffset = outputShape.offset;
    121 
    122     float real_multiplier = 0.0;
    123     int32_t output_multiplier = 0;
    124     int32_t output_shift = 0;
    125     int32_t output_activation_min = 0;
    126     int32_t output_activation_max = 0;
    127 
    128     if (!GetQuantizedConvolutionMultipler(inputShape, filterShape, biasShape,
    129                                           outputShape, &real_multiplier) ||
    130             !QuantizeMultiplierSmallerThanOne(real_multiplier, &output_multiplier,
    131                                               &output_shift)){
    132         return false;
    133     }
    134     CalculateActivationRangeUint8(activation, outputShape,
    135                                   &output_activation_min,
    136                                   &output_activation_max);
    137 
    138     static gemmlowp::GemmContext gemm_context;
    139 
    140     // Prevent concurrent executions that may access the scratch buffer and
    141     // gemm_context.
    142     std::unique_lock<std::mutex> lock(executionMutex);
    143     // Alow gemmlowp automatically decide how many threads to use.
    144     gemm_context.set_max_num_threads(0);
    145     tflite::optimized_ops::Conv(
    146             inputData, convertShapeToDims(inputShape), inputOffset,
    147             filterData, convertShapeToDims(filterShape), filterOffset,
    148             biasData, convertShapeToDims(biasShape),
    149             stride_width, stride_height, paddingWidth, paddingHeight,
    150             outputOffset, output_multiplier, output_shift,
    151             output_activation_min, output_activation_max,
    152             outputData, convertShapeToDims(outputShape),
    153             im2colData, im2colDim, &gemm_context);
    154     return true;
    155 }
    156 
    157 #undef ANDROID_NN_CONV_PARAMETERS
    158 }  // namespace nn
    159 }  // namespace android
    160