1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #include "TopK_V2.h" 17 18 #include "OperationsUtils.h" 19 20 #include <algorithm> 21 22 namespace android { 23 namespace nn { 24 namespace topk_v2 { 25 26 namespace { 27 28 template <typename T> 29 bool evalGeneric(const T* inputData, const Shape& inputShape, const int32_t k, T* valuesData, 30 const Shape& /*valuesShape*/, int32_t* indicesData, 31 const Shape& /*indicesShape*/) { 32 const int rowSize = inputShape.dimensions.back(); 33 const int totalSize = getNumberOfElements(inputShape); 34 std::vector<std::pair<T, int32_t>> values(rowSize); 35 T* curOutputValue = valuesData; 36 int32_t* curOutputIndex = indicesData; 37 for (int rowBegin = 0; rowBegin < totalSize; rowBegin += rowSize) { 38 for (int i = 0; i < rowSize; ++i) { 39 values[i] = std::make_pair(inputData[rowBegin + i], i); 40 } 41 std::nth_element(values.begin(), values.begin() + (rowSize - k), values.end()); 42 std::sort(values.begin() + (rowSize - k), values.end()); 43 std::reverse(values.begin(), values.end()); 44 for (int i = 0; i < k; ++i) { 45 *curOutputValue = values[i].first; 46 *curOutputIndex = values[i].second; 47 curOutputValue++; 48 curOutputIndex++; 49 } 50 } 51 return true; 52 } 53 54 } // namespace 55 56 bool prepare(const Shape& input, int32_t k, Shape* values, Shape* indices) { 57 NN_CHECK(k > 0); 58 NN_CHECK(k <= input.dimensions.back()); 59 60 values->dimensions = input.dimensions; 61 values->dimensions.back() = k; 62 indices->dimensions = input.dimensions; 63 indices->dimensions.back() = k; 64 return true; 65 } 66 67 bool eval(const void* inputData, const Shape& inputShape, const int32_t k, void* valuesData, 68 const Shape& valuesShape, void* indicesData, const Shape& indicesShape) { 69 switch (inputShape.type) { 70 case OperandType::TENSOR_FLOAT16: { 71 return evalGeneric(reinterpret_cast<const _Float16*>(inputData), inputShape, k, 72 reinterpret_cast<_Float16*>(valuesData), valuesShape, 73 reinterpret_cast<int32_t*>(indicesData), indicesShape); 74 } break; 75 case OperandType::TENSOR_FLOAT32: { 76 return evalGeneric(reinterpret_cast<const float*>(inputData), inputShape, k, 77 reinterpret_cast<float*>(valuesData), valuesShape, 78 reinterpret_cast<int32_t*>(indicesData), indicesShape); 79 } break; 80 case OperandType::TENSOR_INT32: { 81 return evalGeneric(reinterpret_cast<const int32_t*>(inputData), inputShape, k, 82 reinterpret_cast<int32_t*>(valuesData), valuesShape, 83 reinterpret_cast<int32_t*>(indicesData), indicesShape); 84 } break; 85 case OperandType::TENSOR_QUANT8_ASYMM: { 86 return evalGeneric(reinterpret_cast<const uint8_t*>(inputData), inputShape, k, 87 reinterpret_cast<uint8_t*>(valuesData), valuesShape, 88 reinterpret_cast<int32_t*>(indicesData), indicesShape); 89 } break; 90 default: { 91 LOG(ERROR) << "Unsupported data type: " << toString(inputShape.type); 92 return false; 93 } 94 } 95 } 96 97 } // namespace topk_v2 98 } // namespace nn 99 } // namespace android 100