Home | History | Annotate | Download | only in operations
      1 /*
      2  * Copyright (C) 2018 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 #include "TopK_V2.h"
     17 
     18 #include "OperationsUtils.h"
     19 
     20 #include <algorithm>
     21 
     22 namespace android {
     23 namespace nn {
     24 namespace topk_v2 {
     25 
     26 namespace {
     27 
     28 template <typename T>
     29 bool evalGeneric(const T* inputData, const Shape& inputShape, const int32_t k, T* valuesData,
     30                  const Shape& /*valuesShape*/, int32_t* indicesData,
     31                  const Shape& /*indicesShape*/) {
     32     const int rowSize = inputShape.dimensions.back();
     33     const int totalSize = getNumberOfElements(inputShape);
     34     std::vector<std::pair<T, int32_t>> values(rowSize);
     35     T* curOutputValue = valuesData;
     36     int32_t* curOutputIndex = indicesData;
     37     for (int rowBegin = 0; rowBegin < totalSize; rowBegin += rowSize) {
     38         for (int i = 0; i < rowSize; ++i) {
     39             values[i] = std::make_pair(inputData[rowBegin + i], i);
     40         }
     41         std::nth_element(values.begin(), values.begin() + (rowSize - k), values.end());
     42         std::sort(values.begin() + (rowSize - k), values.end());
     43         std::reverse(values.begin(), values.end());
     44         for (int i = 0; i < k; ++i) {
     45             *curOutputValue = values[i].first;
     46             *curOutputIndex = values[i].second;
     47             curOutputValue++;
     48             curOutputIndex++;
     49         }
     50     }
     51     return true;
     52 }
     53 
     54 }  // namespace
     55 
     56 bool prepare(const Shape& input, int32_t k, Shape* values, Shape* indices) {
     57     NN_CHECK(k > 0);
     58     NN_CHECK(k <= input.dimensions.back());
     59 
     60     values->dimensions = input.dimensions;
     61     values->dimensions.back() = k;
     62     indices->dimensions = input.dimensions;
     63     indices->dimensions.back() = k;
     64     return true;
     65 }
     66 
     67 bool eval(const void* inputData, const Shape& inputShape, const int32_t k, void* valuesData,
     68           const Shape& valuesShape, void* indicesData, const Shape& indicesShape) {
     69     switch (inputShape.type) {
     70         case OperandType::TENSOR_FLOAT16: {
     71             return evalGeneric(reinterpret_cast<const _Float16*>(inputData), inputShape, k,
     72                                reinterpret_cast<_Float16*>(valuesData), valuesShape,
     73                                reinterpret_cast<int32_t*>(indicesData), indicesShape);
     74         } break;
     75         case OperandType::TENSOR_FLOAT32: {
     76             return evalGeneric(reinterpret_cast<const float*>(inputData), inputShape, k,
     77                                reinterpret_cast<float*>(valuesData), valuesShape,
     78                                reinterpret_cast<int32_t*>(indicesData), indicesShape);
     79         } break;
     80         case OperandType::TENSOR_INT32: {
     81             return evalGeneric(reinterpret_cast<const int32_t*>(inputData), inputShape, k,
     82                                reinterpret_cast<int32_t*>(valuesData), valuesShape,
     83                                reinterpret_cast<int32_t*>(indicesData), indicesShape);
     84         } break;
     85         case OperandType::TENSOR_QUANT8_ASYMM: {
     86             return evalGeneric(reinterpret_cast<const uint8_t*>(inputData), inputShape, k,
     87                                reinterpret_cast<uint8_t*>(valuesData), valuesShape,
     88                                reinterpret_cast<int32_t*>(indicesData), indicesShape);
     89         } break;
     90         default: {
     91             LOG(ERROR) << "Unsupported data type: " << toString(inputShape.type);
     92             return false;
     93         }
     94     }
     95 }
     96 
     97 }  // namespace topk_v2
     98 }  // namespace nn
     99 }  // namespace android
    100