Home | History | Annotate | Download | only in include
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #ifndef ANDROID_ML_NN_COMMON_CPU_OPERATION_UTILS_H
     18 #define ANDROID_ML_NN_COMMON_CPU_OPERATION_UTILS_H
     19 
     20 #include "OperationsUtils.h"
     21 
     22 #include <algorithm>
     23 #include <cmath>
     24 #include <limits>
     25 
     26 #include "tensorflow/lite/kernels/internal/types.h"
     27 
     28 namespace android {
     29 namespace nn {
     30 
     31 // The implementations in tflite/kernels/internal/ take a Dims<4> object
     32 // even if the original tensors were not 4D.
     33 inline tflite::Dims<4> convertShapeToDims(const Shape& shape) {
     34     nnAssert(shape.dimensions.size() <= 4);
     35     tflite::Dims<4> dims;
     36 
     37     // The dimensions are reversed in Dims<4>.
     38     for (int i = 0; i < 4; ++i) {
     39         int src = static_cast<int>(shape.dimensions.size()) - i - 1;
     40         if (src >= 0) {
     41             dims.sizes[i] = static_cast<int>(getSizeOfDimension(shape, src));
     42         } else {
     43             dims.sizes[i] = 1;
     44         }
     45     }
     46 
     47     dims.strides[0] = 1;
     48     for (int i = 1; i < 4; i++) {
     49         dims.strides[i] = dims.strides[i - 1] * dims.sizes[i - 1];
     50     }
     51     return dims;
     52 }
     53 
     54 inline tflite::RuntimeShape convertShapeToTflshape(const Shape& shape) {
     55     nnAssert(shape.dimensions.size() <= 4);
     56 
     57     std::vector<int32_t> tflShapeDim(shape.dimensions.begin(), shape.dimensions.end());
     58     return tflite::RuntimeShape(tflShapeDim.size(), tflShapeDim.data());
     59 }
     60 
     61 inline void convertFloat16ToFloat32(const _Float16* input, std::vector<float>* output) {
     62     CHECK(input != nullptr);
     63     CHECK(output != nullptr);
     64     for (int i = 0; i < output->size(); ++i) {
     65         (*output)[i] = static_cast<float>(input[i]);
     66     }
     67 }
     68 
     69 inline void convertFloat32ToFloat16(const std::vector<float>& input, _Float16* output) {
     70     CHECK(output != nullptr);
     71     for (int i = 0; i < input.size(); ++i) {
     72         output[i] = input[i];
     73     }
     74 }
     75 
     76 template <typename T>
     77 inline void convertQuantToFloat32(const T* input, float scale, int32_t zeroPoint,
     78                                   std::vector<float>* output) {
     79     CHECK(input != nullptr);
     80     CHECK(output != nullptr);
     81     for (int i = 0; i < output->size(); ++i) {
     82         (*output)[i] = (static_cast<float>(input[i]) - zeroPoint) * scale;
     83     }
     84 }
     85 
     86 template <typename T>
     87 inline void convertFloat32ToQuant(const std::vector<float>& input, float scale, int32_t zeroPoint,
     88                                   T* output) {
     89     CHECK(output != nullptr);
     90     for (int i = 0; i < input.size(); ++i) {
     91         int32_t intVal = std::round(input[i] / scale + zeroPoint);
     92         intVal = std::min<int32_t>(std::max<int32_t>(intVal, std::numeric_limits<T>::min()),
     93                                    std::numeric_limits<T>::max());
     94         output[i] = static_cast<T>(intVal);
     95     }
     96 }
     97 
     98 template <typename T>
     99 inline bool convertNchwToNhwc(const T* nchw, const Shape& nchwShape, std::vector<T>* nhwc,
    100                               Shape* nhwcShape) {
    101     NN_RET_CHECK_EQ(getNumberOfDimensions(nchwShape), 4)
    102             << "Error converting a non-4-D tensor to NHWC layout";
    103     *nhwcShape = nchwShape;
    104     const auto& fromDim = nchwShape.dimensions;
    105     nhwcShape->dimensions = {fromDim[0], fromDim[2], fromDim[3], fromDim[1]};
    106     nhwc->resize(getNumberOfElements(nchwShape));
    107     auto to = nhwc->data();
    108     uint32_t spatialSize = fromDim[2] * fromDim[3];
    109     for (uint32_t n = 0; n < fromDim[0]; n++) {
    110         for (uint32_t hw = 0; hw < spatialSize; hw++) {
    111             for (uint32_t c = 0; c < fromDim[1]; c++) {
    112                 uint32_t fromIndex = n * fromDim[1] * spatialSize + c * spatialSize + hw;
    113                 *to++ = nchw[fromIndex];
    114             }
    115         }
    116     }
    117     return true;
    118 }
    119 
    120 template <typename T>
    121 inline bool convertNhwcToNchw(const std::vector<T>& nhwc, const Shape& nhwcShape, T* nchw) {
    122     NN_RET_CHECK_EQ(getNumberOfDimensions(nhwcShape), 4)
    123             << "Error converting a non-4-D tensor to NCHW layout";
    124     const auto& fromDim = nhwcShape.dimensions;
    125     const auto from = nhwc.data();
    126     uint32_t spatialSize = fromDim[1] * fromDim[2];
    127     for (uint32_t n = 0; n < fromDim[0]; n++) {
    128         for (uint32_t c = 0; c < fromDim[3]; c++) {
    129             for (uint32_t hw = 0; hw < spatialSize; hw++) {
    130                 uint32_t fromIndex = n * spatialSize * fromDim[3] + hw * fromDim[3] + c;
    131                 *nchw++ = from[fromIndex];
    132             }
    133         }
    134     }
    135     return true;
    136 }
    137 
    138 template <typename T>
    139 class InputWithLayout {
    140    public:
    141     InputWithLayout(bool useNchw) : mDataOriginal(nullptr), mUseNchw(useNchw) {}
    142 
    143     bool initialize(const T* data, const Shape& shape) {
    144         mDataOriginal = data;
    145         mShape = shape;
    146         if (mUseNchw) {
    147             return convertNchwToNhwc(mDataOriginal, shape, &mDataNhwc, &mShape);
    148         }
    149         return true;
    150     }
    151 
    152     const T* getNhwcBuffer() { return mUseNchw ? mDataNhwc.data() : mDataOriginal; }
    153     const Shape& getNhwcShape() { return mShape; }
    154 
    155    private:
    156     const T* mDataOriginal;
    157     std::vector<T> mDataNhwc;
    158     Shape mShape;
    159     bool mUseNchw;
    160 };
    161 
    162 template <typename T>
    163 class OutputWithLayout {
    164    public:
    165     OutputWithLayout(bool useNchw) : mDataOriginal(nullptr), mUseNchw(useNchw) {}
    166 
    167     bool initialize(T* data, const Shape& shape) {
    168         NN_RET_CHECK_EQ(getNumberOfDimensions(shape), 4);
    169         mDataOriginal = data;
    170         mShape = shape;
    171         if (mUseNchw) {
    172             const auto& dim = shape.dimensions;
    173             mShape.dimensions = {dim[0], dim[2], dim[3], dim[1]};
    174             mDataNhwc.resize(getNumberOfElements(shape));
    175         }
    176         return true;
    177     }
    178 
    179     T* getNhwcBuffer() { return mUseNchw ? mDataNhwc.data() : mDataOriginal; }
    180     const Shape& getNhwcShape() { return mShape; }
    181     bool commit() {
    182         if (mUseNchw) {
    183             return convertNhwcToNchw(mDataNhwc, mShape, mDataOriginal);
    184         }
    185         return true;
    186     }
    187 
    188    private:
    189     T* mDataOriginal;
    190     std::vector<T> mDataNhwc;
    191     Shape mShape;
    192     bool mUseNchw;
    193 };
    194 
    195 }  // namespace nn
    196 }  // namespace android
    197 
    198 #endif  // ANDROID_ML_NN_COMMON_CPU_OPERATION_UTILS_H
    199