Home | History | Annotate | Download | only in operations
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 // Contains the implementation of the operations.
     18 
     19 #define LOG_TAG "Operations"
     20 
     21 #include "Operations.h"
     22 #include "CpuOperationUtils.h"
     23 
     24 #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
     25 #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
     26 
     27 namespace android {
     28 namespace nn {
     29 
     30 bool reshapeGeneric(const void* inputData, const Shape& inputShape,
     31                     void* outputData, const Shape& outputShape) {
     32     size_t count = sizeOfData(inputShape.type, inputShape.dimensions);
     33     memcpy(outputData, inputData, count);
     34     return true;
     35 }
     36 
     37 bool resizeBilinearFloat32(const float* inputData, const Shape& inputShape,
     38                            float* outputData, const Shape& outputShape) {
     39     int32_t height = (int32_t) getSizeOfDimension(outputShape, 1);
     40     int32_t width  = (int32_t) getSizeOfDimension(outputShape, 2);
     41 
     42     int32_t outDimData[2] = {height, width};
     43     // We have to fake a tensor here, to satisfy ResizeBilinear().
     44     Shape outDimShape;
     45     outDimShape.dimensions = {1, 1, 1, 2};
     46 
     47     tflite::optimized_ops::ResizeBilinear(
     48             inputData, convertShapeToDims(inputShape),
     49             outDimData, convertShapeToDims(outDimShape),
     50             outputData, convertShapeToDims(outputShape));
     51     return true;
     52 }
     53 
     54 bool depthToSpaceGeneric(const uint8_t* inputData, const Shape& inputShape,
     55                          int32_t blockSize,
     56                          uint8_t* outputData, const Shape& outputShape) {
     57     if (inputShape.type == OperandType::TENSOR_FLOAT32) {
     58        tflite::optimized_ops::DepthToSpace(
     59                 reinterpret_cast<const float*>(inputData),
     60                 convertShapeToDims(inputShape),
     61                 blockSize,
     62                 reinterpret_cast<float*>(outputData),
     63                 convertShapeToDims(outputShape));
     64     } else if (inputShape.type == OperandType::TENSOR_QUANT8_ASYMM) {
     65         tflite::optimized_ops::DepthToSpace(
     66                 reinterpret_cast<const uint8_t*>(inputData),
     67                 convertShapeToDims(inputShape),
     68                 blockSize,
     69                 reinterpret_cast<uint8_t*>(outputData),
     70                 convertShapeToDims(outputShape));
     71     } else {
     72         LOG(ERROR) << "Unsupported data type";
     73         return false;
     74     }
     75     return true;
     76 }
     77 
     78 bool spaceToDepthGeneric(const uint8_t* inputData, const Shape& inputShape,
     79                          int32_t blockSize,
     80                          uint8_t* outputData, const Shape& outputShape) {
     81     if (inputShape.type == OperandType::TENSOR_FLOAT32) {
     82         tflite::optimized_ops::SpaceToDepth(
     83                 reinterpret_cast<const float*>(inputData),
     84                 convertShapeToDims(inputShape),
     85                 blockSize,
     86                 reinterpret_cast<float*>(outputData),
     87                 convertShapeToDims(outputShape));
     88     } else if (inputShape.type == OperandType::TENSOR_QUANT8_ASYMM) {
     89         tflite::optimized_ops::SpaceToDepth(
     90                 reinterpret_cast<const uint8_t*>(inputData),
     91                 convertShapeToDims(inputShape),
     92                 blockSize,
     93                 reinterpret_cast<uint8_t*>(outputData),
     94                 convertShapeToDims(outputShape));
     95     } else {
     96         LOG(ERROR) << "Unsupported data type";
     97         return false;
     98     }
     99     return true;
    100 }
    101 
    102 bool padGeneric(const uint8_t* inputData, const Shape& inputShape,
    103                 const int32_t* paddings,
    104                 uint8_t* outputData, const Shape& outputShape) {
    105     int32_t numInputDims = static_cast<int32_t>(getNumberOfDimensions(inputShape));
    106 
    107     std::vector<int> beforePadding;
    108     std::vector<int> afterPadding;
    109     // The lower level implementation expects the paddings in the reverse order.
    110     for (int32_t i = numInputDims - 1; i >= 0; --i) {
    111         beforePadding.push_back(paddings[i * 2]);
    112         afterPadding.push_back(paddings[i * 2 + 1]);
    113     }
    114 
    115     if (inputShape.type == OperandType::TENSOR_FLOAT32) {
    116         tflite::optimized_ops::Pad(
    117                 reinterpret_cast<const float*>(inputData),
    118                 convertShapeToDims(inputShape),
    119                 beforePadding, afterPadding,
    120                 reinterpret_cast<float*>(outputData),
    121                 convertShapeToDims(outputShape));
    122     } else if (inputShape.type == OperandType::TENSOR_QUANT8_ASYMM) {
    123         tflite::optimized_ops::Pad(
    124                 reinterpret_cast<const uint8_t*>(inputData),
    125                 convertShapeToDims(inputShape),
    126                 beforePadding, afterPadding,
    127                 reinterpret_cast<uint8_t*>(outputData),
    128                 convertShapeToDims(outputShape));
    129     } else {
    130         LOG(ERROR) << "Unsupported data type";
    131         return false;
    132     }
    133     return true;
    134 }
    135 
    136 bool batchToSpaceGeneric(const uint8_t* inputData, const Shape& inputShape,
    137                          const int32_t* blockSize,
    138                          uint8_t* outputData, const Shape& outputShape) {
    139     // Needed by low level implementation, but not really used.
    140     tflite::Dims<4> blockSizeDim;
    141     if (inputShape.type == OperandType::TENSOR_FLOAT32) {
    142        tflite::optimized_ops::BatchToSpaceND(
    143                 reinterpret_cast<const float*>(inputData),
    144                 convertShapeToDims(inputShape),
    145                 blockSize, blockSizeDim,
    146                 reinterpret_cast<float*>(outputData),
    147                 convertShapeToDims(outputShape));
    148     } else if (inputShape.type == OperandType::TENSOR_QUANT8_ASYMM) {
    149         tflite::optimized_ops::BatchToSpaceND(
    150                 reinterpret_cast<const uint8_t*>(inputData),
    151                 convertShapeToDims(inputShape),
    152                 blockSize, blockSizeDim,
    153                 reinterpret_cast<uint8_t*>(outputData),
    154                 convertShapeToDims(outputShape));
    155     } else {
    156         LOG(ERROR) << "Unsupported data type";
    157         return false;
    158     }
    159     return true;
    160 }
    161 
    162 bool spaceToBatchGeneric(const uint8_t* inputData, const Shape& inputShape,
    163                          const int32_t* blockSize,
    164                          const int32_t* padding, const Shape& paddingShape,
    165                          uint8_t* outputData, const Shape& outputShape) {
    166     // Needed by low level implementation, but not really used.
    167     tflite::Dims<4> blockSizeDim;
    168     if (inputShape.type == OperandType::TENSOR_FLOAT32) {
    169         tflite::optimized_ops::SpaceToBatchND(
    170                 reinterpret_cast<const float*>(inputData),
    171                 convertShapeToDims(inputShape),
    172                 blockSize, blockSizeDim,
    173                 padding, convertShapeToDims(paddingShape),
    174                 reinterpret_cast<float*>(outputData),
    175                 convertShapeToDims(outputShape));
    176     } else if (inputShape.type == OperandType::TENSOR_QUANT8_ASYMM) {
    177         tflite::optimized_ops::SpaceToBatchND(
    178                 reinterpret_cast<const uint8_t*>(inputData),
    179                 convertShapeToDims(inputShape),
    180                 blockSize, blockSizeDim,
    181                 padding, convertShapeToDims(paddingShape),
    182                 reinterpret_cast<uint8_t*>(outputData),
    183                 convertShapeToDims(outputShape));
    184     } else {
    185         LOG(ERROR) << "Unsupported data type";
    186         return false;
    187     }
    188     return true;
    189 }
    190 
    191 bool squeezeGeneric(const void* inputData, const Shape& inputShape,
    192                     void* outputData, const Shape& outputShape) {
    193     size_t count = sizeOfData(inputShape.type, inputShape.dimensions);
    194     memcpy(outputData, inputData, count);
    195     return true;
    196 }
    197 
    198 bool transposeGeneric(const uint8_t* inputData, const Shape& inputShape,
    199                       const int32_t* perm, const Shape& permShape,
    200                       uint8_t* outputData, const Shape& outputShape) {
    201     // Reverse the permuted axes and convert to 4D due to the way Dims are
    202     // constructed.
    203     const int32_t kOutputDimensionNum = 4;
    204 
    205     int32_t permSize = static_cast<int32_t>(getSizeOfDimension(permShape, 0));
    206     int32_t reversed_perm[kOutputDimensionNum];
    207     for (int32_t output_k = 0, input_k = permSize - 1; output_k < permSize;
    208              ++output_k, --input_k) {
    209         reversed_perm[output_k] = permSize - perm[input_k] - 1;
    210     }
    211     for (int32_t k = permSize; k < kOutputDimensionNum; ++k) {
    212         reversed_perm[k] = k;
    213     }
    214     if (inputShape.type == OperandType::TENSOR_FLOAT32) {
    215         tflite::reference_ops::Transpose(
    216                 reinterpret_cast<const float*>(inputData),
    217                 convertShapeToDims(inputShape),
    218                 reinterpret_cast<float*>(outputData),
    219                 convertShapeToDims(outputShape),
    220                 reversed_perm);
    221     } else if (inputShape.type == OperandType::TENSOR_QUANT8_ASYMM) {
    222         tflite::reference_ops::Transpose(
    223                 reinterpret_cast<const uint8_t*>(inputData),
    224                 convertShapeToDims(inputShape),
    225                 reinterpret_cast<uint8_t*>(outputData),
    226                 convertShapeToDims(outputShape),
    227                 reversed_perm);
    228     } else {
    229         LOG(ERROR) << "Unsupported data type";
    230         return false;
    231     }
    232     return true;
    233 }
    234 } // namespace nn
    235 } // namespace android
    236