1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef ANDROID_ML_NN_COMMON_CPU_OPERATION_UTILS_H 18 #define ANDROID_ML_NN_COMMON_CPU_OPERATION_UTILS_H 19 20 #include "OperationsUtils.h" 21 22 #include <algorithm> 23 #include <cmath> 24 #include <limits> 25 26 #include "tensorflow/lite/kernels/internal/types.h" 27 28 namespace android { 29 namespace nn { 30 31 // The implementations in tflite/kernels/internal/ take a Dims<4> object 32 // even if the original tensors were not 4D. 33 inline tflite::Dims<4> convertShapeToDims(const Shape& shape) { 34 nnAssert(shape.dimensions.size() <= 4); 35 tflite::Dims<4> dims; 36 37 // The dimensions are reversed in Dims<4>. 38 for (int i = 0; i < 4; ++i) { 39 int src = static_cast<int>(shape.dimensions.size()) - i - 1; 40 if (src >= 0) { 41 dims.sizes[i] = static_cast<int>(getSizeOfDimension(shape, src)); 42 } else { 43 dims.sizes[i] = 1; 44 } 45 } 46 47 dims.strides[0] = 1; 48 for (int i = 1; i < 4; i++) { 49 dims.strides[i] = dims.strides[i - 1] * dims.sizes[i - 1]; 50 } 51 return dims; 52 } 53 54 inline tflite::RuntimeShape convertShapeToTflshape(const Shape& shape) { 55 nnAssert(shape.dimensions.size() <= 4); 56 57 std::vector<int32_t> tflShapeDim(shape.dimensions.begin(), shape.dimensions.end()); 58 return tflite::RuntimeShape(tflShapeDim.size(), tflShapeDim.data()); 59 } 60 61 inline void convertFloat16ToFloat32(const _Float16* input, std::vector<float>* output) { 62 CHECK(input != nullptr); 63 CHECK(output != nullptr); 64 for (int i = 0; i < output->size(); ++i) { 65 (*output)[i] = static_cast<float>(input[i]); 66 } 67 } 68 69 inline void convertFloat32ToFloat16(const std::vector<float>& input, _Float16* output) { 70 CHECK(output != nullptr); 71 for (int i = 0; i < input.size(); ++i) { 72 output[i] = input[i]; 73 } 74 } 75 76 template <typename T> 77 inline void convertQuantToFloat32(const T* input, float scale, int32_t zeroPoint, 78 std::vector<float>* output) { 79 CHECK(input != nullptr); 80 CHECK(output != nullptr); 81 for (int i = 0; i < output->size(); ++i) { 82 (*output)[i] = (static_cast<float>(input[i]) - zeroPoint) * scale; 83 } 84 } 85 86 template <typename T> 87 inline void convertFloat32ToQuant(const std::vector<float>& input, float scale, int32_t zeroPoint, 88 T* output) { 89 CHECK(output != nullptr); 90 for (int i = 0; i < input.size(); ++i) { 91 int32_t intVal = std::round(input[i] / scale + zeroPoint); 92 intVal = std::min<int32_t>(std::max<int32_t>(intVal, std::numeric_limits<T>::min()), 93 std::numeric_limits<T>::max()); 94 output[i] = static_cast<T>(intVal); 95 } 96 } 97 98 template <typename T> 99 inline bool convertNchwToNhwc(const T* nchw, const Shape& nchwShape, std::vector<T>* nhwc, 100 Shape* nhwcShape) { 101 NN_RET_CHECK_EQ(getNumberOfDimensions(nchwShape), 4) 102 << "Error converting a non-4-D tensor to NHWC layout"; 103 *nhwcShape = nchwShape; 104 const auto& fromDim = nchwShape.dimensions; 105 nhwcShape->dimensions = {fromDim[0], fromDim[2], fromDim[3], fromDim[1]}; 106 nhwc->resize(getNumberOfElements(nchwShape)); 107 auto to = nhwc->data(); 108 uint32_t spatialSize = fromDim[2] * fromDim[3]; 109 for (uint32_t n = 0; n < fromDim[0]; n++) { 110 for (uint32_t hw = 0; hw < spatialSize; hw++) { 111 for (uint32_t c = 0; c < fromDim[1]; c++) { 112 uint32_t fromIndex = n * fromDim[1] * spatialSize + c * spatialSize + hw; 113 *to++ = nchw[fromIndex]; 114 } 115 } 116 } 117 return true; 118 } 119 120 template <typename T> 121 inline bool convertNhwcToNchw(const std::vector<T>& nhwc, const Shape& nhwcShape, T* nchw) { 122 NN_RET_CHECK_EQ(getNumberOfDimensions(nhwcShape), 4) 123 << "Error converting a non-4-D tensor to NCHW layout"; 124 const auto& fromDim = nhwcShape.dimensions; 125 const auto from = nhwc.data(); 126 uint32_t spatialSize = fromDim[1] * fromDim[2]; 127 for (uint32_t n = 0; n < fromDim[0]; n++) { 128 for (uint32_t c = 0; c < fromDim[3]; c++) { 129 for (uint32_t hw = 0; hw < spatialSize; hw++) { 130 uint32_t fromIndex = n * spatialSize * fromDim[3] + hw * fromDim[3] + c; 131 *nchw++ = from[fromIndex]; 132 } 133 } 134 } 135 return true; 136 } 137 138 template <typename T> 139 class InputWithLayout { 140 public: 141 InputWithLayout(bool useNchw) : mDataOriginal(nullptr), mUseNchw(useNchw) {} 142 143 bool initialize(const T* data, const Shape& shape) { 144 mDataOriginal = data; 145 mShape = shape; 146 if (mUseNchw) { 147 return convertNchwToNhwc(mDataOriginal, shape, &mDataNhwc, &mShape); 148 } 149 return true; 150 } 151 152 const T* getNhwcBuffer() { return mUseNchw ? mDataNhwc.data() : mDataOriginal; } 153 const Shape& getNhwcShape() { return mShape; } 154 155 private: 156 const T* mDataOriginal; 157 std::vector<T> mDataNhwc; 158 Shape mShape; 159 bool mUseNchw; 160 }; 161 162 template <typename T> 163 class OutputWithLayout { 164 public: 165 OutputWithLayout(bool useNchw) : mDataOriginal(nullptr), mUseNchw(useNchw) {} 166 167 bool initialize(T* data, const Shape& shape) { 168 NN_RET_CHECK_EQ(getNumberOfDimensions(shape), 4); 169 mDataOriginal = data; 170 mShape = shape; 171 if (mUseNchw) { 172 const auto& dim = shape.dimensions; 173 mShape.dimensions = {dim[0], dim[2], dim[3], dim[1]}; 174 mDataNhwc.resize(getNumberOfElements(shape)); 175 } 176 return true; 177 } 178 179 T* getNhwcBuffer() { return mUseNchw ? mDataNhwc.data() : mDataOriginal; } 180 const Shape& getNhwcShape() { return mShape; } 181 bool commit() { 182 if (mUseNchw) { 183 return convertNhwcToNchw(mDataNhwc, mShape, mDataOriginal); 184 } 185 return true; 186 } 187 188 private: 189 T* mDataOriginal; 190 std::vector<T> mDataNhwc; 191 Shape mShape; 192 bool mUseNchw; 193 }; 194 195 } // namespace nn 196 } // namespace android 197 198 #endif // ANDROID_ML_NN_COMMON_CPU_OPERATION_UTILS_H 199