Home | History | Annotate | Download | only in operations
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 // Contains the implementation of the operations.
     18 
     19 #define LOG_TAG "Operations"
     20 
     21 #include "CpuOperationUtils.h"
     22 #include "Operations.h"
     23 
     24 #include "tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h"
     25 #include "tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h"
     26 
     27 #include "Tracing.h"
     28 
     29 namespace android {
     30 namespace nn {
     31 
     32 bool floorFloat16(const _Float16* inputData, _Float16* outputData, const Shape& shape) {
     33     NNTRACE_TRANS("floorFloat16");
     34     std::vector<float> inputDataFloat32(getNumberOfElements(shape));
     35     convertFloat16ToFloat32(inputData, &inputDataFloat32);
     36 
     37     std::vector<float> outputDataFloat32(getNumberOfElements(shape));
     38     floorFloat32(inputDataFloat32.data(), outputDataFloat32.data(), shape);
     39     convertFloat32ToFloat16(outputDataFloat32, outputData);
     40     return true;
     41 }
     42 
     43 bool floorFloat32(const float* inputData, float* outputData, const Shape& shape) {
     44     NNTRACE_TRANS("floorFloat32");
     45     tflite::Dims<4> dim = convertShapeToDims(shape);
     46     NNTRACE_COMP_SWITCH("optimized_ops::Floor");
     47     tflite::optimized_ops::Floor(inputData, dim, outputData, dim);
     48     return true;
     49 }
     50 
     51 bool meanFloat16(_Float16* inputData, const Shape& inputShape, const int32_t* axis,
     52                  const Shape& axisShape, bool keepDims, _Float16* outputData,
     53                  const Shape& outputShape) {
     54     NNTRACE_TRANS("meanFloat16");
     55     std::vector<float> inputDataFloat32(getNumberOfElements(inputShape));
     56     convertFloat16ToFloat32(inputData, &inputDataFloat32);
     57 
     58     std::vector<float> outputDataFloat32(getNumberOfElements(outputShape));
     59     meanGeneric<float, float>(inputDataFloat32.data(), inputShape, axis, axisShape, keepDims,
     60                               outputDataFloat32.data(), outputShape);
     61     convertFloat32ToFloat16(outputDataFloat32, outputData);
     62     return true;
     63 }
     64 
     65 template <typename T, typename U>
     66 bool meanGeneric(T* inputData, const Shape& inputShape, const int32_t* axis, const Shape& axisShape,
     67                  bool keepDims, T* outputData, const Shape& outputShape) {
     68     NNTRACE_TRANS("meanGeneric");
     69     // Creates a temp index to iterate through input data.
     70     int32_t* scratchBuffer = new int32_t[getNumberOfDimensions(inputShape)];
     71 
     72     // Creates a temp tensor to store resolved axis given input data.
     73     int32_t axisSize = static_cast<int32_t>(getSizeOfDimension(axisShape, 0));
     74     int32_t* resolvedAxis = new int32_t[axisSize];
     75 
     76     bool result = true;
     77     U* tempSumBuffer = new (std::nothrow) U[getNumberOfElements(outputShape)];
     78     if (!tempSumBuffer) {
     79         LOG(ERROR) << "Failed to allocate tempSumBuffer for MEAN";
     80         result = false;
     81     } else {
     82         NNTRACE_COMP_SWITCH("optimized_ops::Mean");
     83         tflite::reference_ops::Mean<T, U>(
     84                 inputData, reinterpret_cast<const int*>(inputShape.dimensions.data()),
     85                 getNumberOfDimensions(inputShape), outputData,
     86                 reinterpret_cast<const int*>(outputShape.dimensions.data()),
     87                 getNumberOfDimensions(outputShape), axis, axisSize, keepDims, scratchBuffer,
     88                 resolvedAxis, tempSumBuffer);
     89         delete[] tempSumBuffer;
     90     }
     91     delete[] scratchBuffer;
     92     delete[] resolvedAxis;
     93     return result;
     94 }
     95 template bool meanGeneric<float, float>(float* inputData, const Shape& inputShape,
     96                                         const int32_t* axis, const Shape& axisShape, bool keepDims,
     97                                         float* outputData, const Shape& outputShape);
     98 template bool meanGeneric<uint8_t, int32_t>(uint8_t* inputData, const Shape& inputShape,
     99                                             const int32_t* axis, const Shape& axisShape,
    100                                             bool keepDims, uint8_t* outputData,
    101                                             const Shape& outputShape);
    102 
    103 }  // namespace nn
    104 }  // namespace android
    105