1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 // Contains the implementation of the operations. 18 19 #define LOG_TAG "Operations" 20 21 #include "CpuOperationUtils.h" 22 #include "Operations.h" 23 24 #include "tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h" 25 #include "tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h" 26 27 #include "Tracing.h" 28 29 namespace android { 30 namespace nn { 31 32 bool floorFloat16(const _Float16* inputData, _Float16* outputData, const Shape& shape) { 33 NNTRACE_TRANS("floorFloat16"); 34 std::vector<float> inputDataFloat32(getNumberOfElements(shape)); 35 convertFloat16ToFloat32(inputData, &inputDataFloat32); 36 37 std::vector<float> outputDataFloat32(getNumberOfElements(shape)); 38 floorFloat32(inputDataFloat32.data(), outputDataFloat32.data(), shape); 39 convertFloat32ToFloat16(outputDataFloat32, outputData); 40 return true; 41 } 42 43 bool floorFloat32(const float* inputData, float* outputData, const Shape& shape) { 44 NNTRACE_TRANS("floorFloat32"); 45 tflite::Dims<4> dim = convertShapeToDims(shape); 46 NNTRACE_COMP_SWITCH("optimized_ops::Floor"); 47 tflite::optimized_ops::Floor(inputData, dim, outputData, dim); 48 return true; 49 } 50 51 bool meanFloat16(_Float16* inputData, const Shape& inputShape, const int32_t* axis, 52 const Shape& axisShape, bool keepDims, _Float16* outputData, 53 const Shape& outputShape) { 54 NNTRACE_TRANS("meanFloat16"); 55 std::vector<float> inputDataFloat32(getNumberOfElements(inputShape)); 56 convertFloat16ToFloat32(inputData, &inputDataFloat32); 57 58 std::vector<float> outputDataFloat32(getNumberOfElements(outputShape)); 59 meanGeneric<float, float>(inputDataFloat32.data(), inputShape, axis, axisShape, keepDims, 60 outputDataFloat32.data(), outputShape); 61 convertFloat32ToFloat16(outputDataFloat32, outputData); 62 return true; 63 } 64 65 template <typename T, typename U> 66 bool meanGeneric(T* inputData, const Shape& inputShape, const int32_t* axis, const Shape& axisShape, 67 bool keepDims, T* outputData, const Shape& outputShape) { 68 NNTRACE_TRANS("meanGeneric"); 69 // Creates a temp index to iterate through input data. 70 int32_t* scratchBuffer = new int32_t[getNumberOfDimensions(inputShape)]; 71 72 // Creates a temp tensor to store resolved axis given input data. 73 int32_t axisSize = static_cast<int32_t>(getSizeOfDimension(axisShape, 0)); 74 int32_t* resolvedAxis = new int32_t[axisSize]; 75 76 bool result = true; 77 U* tempSumBuffer = new (std::nothrow) U[getNumberOfElements(outputShape)]; 78 if (!tempSumBuffer) { 79 LOG(ERROR) << "Failed to allocate tempSumBuffer for MEAN"; 80 result = false; 81 } else { 82 NNTRACE_COMP_SWITCH("optimized_ops::Mean"); 83 tflite::reference_ops::Mean<T, U>( 84 inputData, reinterpret_cast<const int*>(inputShape.dimensions.data()), 85 getNumberOfDimensions(inputShape), outputData, 86 reinterpret_cast<const int*>(outputShape.dimensions.data()), 87 getNumberOfDimensions(outputShape), axis, axisSize, keepDims, scratchBuffer, 88 resolvedAxis, tempSumBuffer); 89 delete[] tempSumBuffer; 90 } 91 delete[] scratchBuffer; 92 delete[] resolvedAxis; 93 return result; 94 } 95 template bool meanGeneric<float, float>(float* inputData, const Shape& inputShape, 96 const int32_t* axis, const Shape& axisShape, bool keepDims, 97 float* outputData, const Shape& outputShape); 98 template bool meanGeneric<uint8_t, int32_t>(uint8_t* inputData, const Shape& inputShape, 99 const int32_t* axis, const Shape& axisShape, 100 bool keepDims, uint8_t* outputData, 101 const Shape& outputShape); 102 103 } // namespace nn 104 } // namespace android 105