Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include <algorithm>
     16 
     17 #include "tensorflow/contrib/lite/builtin_op_data.h"
     18 #include "tensorflow/contrib/lite/context.h"
     19 #include "tensorflow/contrib/lite/kernels/internal/tensor.h"
     20 #include "tensorflow/contrib/lite/kernels/kernel_util.h"
     21 #include "tensorflow/contrib/lite/kernels/op_macros.h"
     22 namespace tflite {
     23 namespace ops {
     24 namespace builtin {
     25 namespace topk_v2 {
     26 constexpr int kInputTensor = 0;
     27 constexpr int kInputTopK = 1;
     28 constexpr int kOutputIndexes = 0;
     29 constexpr int kOutputValues = 1;
     30 
     31 namespace {
     32 TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) {
     33   TfLiteTensor* top_k = GetInput(context, node, kInputTopK);
     34   // INT32 number of top results is supported.
     35   TF_LITE_ENSURE_EQ(context, top_k->type, kTfLiteInt32);
     36   // Check that the tensor contains only one value.
     37   TF_LITE_ENSURE_EQ(context, NumDimensions(top_k), 1);
     38   TF_LITE_ENSURE_EQ(context, NumElements(top_k), 1);
     39   const int32 k = top_k->data.i32[0];
     40 
     41   TfLiteTensor* input = GetInput(context, node, kInputTensor);
     42   const int num_dimensions = NumDimensions(input);
     43   // Check that input has one or more dimensions.
     44   TF_LITE_ENSURE_MSG(context, input->dims->size >= 1,
     45                      "TopK k input must have 1 or more dimensions.");
     46   // Check that k is less or equal the internal dimension.
     47   TF_LITE_ENSURE_MSG(context, k <= input->dims->data[num_dimensions - 1],
     48                      "TopK k is higher than the internal dimension.");
     49 
     50   TfLiteIntArray* output_indexes_shape = TfLiteIntArrayCreate(num_dimensions);
     51   TfLiteIntArray* output_values_shape = TfLiteIntArrayCreate(num_dimensions);
     52   for (int i = 0; i < num_dimensions - 1; ++i) {
     53     output_indexes_shape->data[i] = input->dims->data[i];
     54     output_values_shape->data[i] = input->dims->data[i];
     55   }
     56   output_indexes_shape->data[num_dimensions - 1] = k;
     57   output_values_shape->data[num_dimensions - 1] = k;
     58   TfLiteTensor* output_indexes = GetOutput(context, node, kOutputIndexes);
     59   TfLiteTensor* output_values = GetOutput(context, node, kOutputValues);
     60   auto resize_tensor = [context](TfLiteTensor* tensor, TfLiteIntArray* new_size,
     61                                  TfLiteIntArray* delete_on_error) {
     62     TfLiteStatus status = context->ResizeTensor(context, tensor, new_size);
     63     if (status != kTfLiteOk) {
     64       TfLiteIntArrayFree(new_size);
     65       if (delete_on_error != nullptr) {
     66         TfLiteIntArrayFree(delete_on_error);
     67       }
     68     }
     69     return status;
     70   };
     71   TF_LITE_ENSURE_OK(context, resize_tensor(output_indexes, output_indexes_shape,
     72                                            output_values_shape));
     73   TF_LITE_ENSURE_OK(context,
     74                     resize_tensor(output_values, output_values_shape, nullptr));
     75   return kTfLiteOk;
     76 }
     77 
// The class that collects top indexes of k values. Based on template
// tensorflow::gtl::TopN<> but, for optimization,
// it re-uses the same container.
//
// Usage: call start_collecting() with a pointer to the row's values, push()
// each candidate index, then read sorted_result(). Indexes are ordered by
// the value they point at (larger value first); ties break in favor of the
// lower index (see compare_fun).
template <typename T>
class TopContainer {
 public:
  TopContainer() = delete;
  // k: number of top entries to keep. row_size: upper bound on how many
  // indexes will be pushed, so at most min(k, row_size) + 1 slots are ever
  // needed — the extra slot is scratch space for the element being evicted.
  TopContainer(int32 k, int32 row_size) : k_(k) {
    container_.reserve(std::min(k, row_size) + 1);
  }

  // Starts a new collection round over `values`. Previously collected
  // indexes are dropped but the buffer's capacity is retained.
  void start_collecting(const T* values) {
    values_ = values;
    container_.clear();
  }
  // Considers index `a` for membership in the top-k set.
  void push(int32 a) {
    auto comparator = [this](int32 a, int32 b) { return compare_fun(a, b); };
    if (container_.size() <= k_) {
      // Still filling: accept unconditionally. Once k + 1 entries exist,
      // heapify and eject the weakest into the back slot, leaving a
      // k-element heap in [begin, end - 1) whose front is the weakest kept
      // entry (under this comparator, make_heap places it at the front).
      container_.push_back(a);
      if (container_.size() == k_ + 1) {
        std::make_heap(container_.begin(), container_.end(), comparator);
        std::pop_heap(container_.begin(), container_.end(), comparator);
      }
    } else if (comparator(a, container_.front())) {
      // `a` beats the current weakest kept entry: overwrite the scratch
      // back slot, sift it into the heap, then eject the new weakest back
      // into the scratch slot again.
      container_.back() = a;
      std::push_heap(container_.begin(), container_.end(), comparator);
      std::pop_heap(container_.begin(), container_.end(), comparator);
    }
  }

  // Returns the collected indexes, best first. This destroys the heap
  // structure, so push() must not be called again before the next
  // start_collecting().
  const std::vector<int32>& sorted_result() {
    auto comparator = [this](int32 a, int32 b) { return compare_fun(a, b); };
    if (container_.size() <= k_) {
      // Fewer than k + 1 candidates were pushed; no heap was ever built.
      std::sort(container_.begin(), container_.end(), comparator);
    } else {
      // Sort the k-element heap; the last slot holds an ejected loser and
      // is trimmed off by the resize.
      std::sort_heap(container_.begin(), container_.end() - 1, comparator);
      container_.resize(k_);
    }
    return container_;
  }

 private:
  int32 k_;
  // Candidate indexes into values_; layout described in push().
  std::vector<int32> container_;
  const T* values_ = nullptr;

  // Strict ordering on indexes: the index pointing at the larger value
  // comes first; equal values order by the lower index.
  bool compare_fun(int32 a, int32 b) const {
    if (values_[b] < values_[a]) {
      return true;
    } else if (values_[b] > values_[a]) {
      return false;
    } else {
      return a < b;
    }
  }
};
    134 
    135 // Mostly modeled on tensorflow/core/kernels/topk_op.cc for CPU.
    136 template <typename T>
    137 void TopK(int32 row_size, int32 num_rows, const T* data, int32 k,
    138           int32* output_indexes, T* output_values) {
    139   TopContainer<T> topc(k, row_size);
    140   for (int row = 0; row < num_rows; ++row) {
    141     const T* values_row = data + row * row_size;
    142     topc.start_collecting(values_row);
    143     for (int32 c = 0; c < row_size; ++c) {
    144       topc.push(c);
    145     }
    146 
    147     // Prepare output buffers.
    148     int32* indexes_row = output_indexes + row * k;
    149     T* output_row = output_values + row * k;
    150     // We always assume that the output is sorted.
    151     const auto& top_k = topc.sorted_result();
    152     std::copy(top_k.begin(), top_k.end(), indexes_row);
    153     std::transform(top_k.begin(), top_k.end(), output_row,
    154                    [values_row](const int32 loc) { return values_row[loc]; });
    155   }
    156 }
    157 
    158 }  // namespace
    159 
    160 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
    161   // Check that the inputs and outputs have the right sizes and types.
    162   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
    163   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);
    164 
    165   TfLiteTensor* input = GetInput(context, node, kInputTensor);
    166   TfLiteTensor* output_values = GetOutput(context, node, kOutputValues);
    167   TF_LITE_ENSURE_EQ(context, input->type, output_values->type);
    168 
    169   TfLiteTensor* top_k = GetInput(context, node, kInputTopK);
    170   TF_LITE_ENSURE_EQ(context, top_k->type, kTfLiteInt32);
    171 
    172   // Set output dynamic if the input is not const.
    173   if (IsConstantTensor(top_k)) {
    174     TF_LITE_ENSURE_OK(context, ResizeOutput(context, node));
    175   } else {
    176     TfLiteTensor* output_indexes = GetOutput(context, node, kOutputIndexes);
    177     TfLiteTensor* output_values = GetOutput(context, node, kOutputValues);
    178     SetTensorToDynamic(output_indexes);
    179     SetTensorToDynamic(output_values);
    180   }
    181   return kTfLiteOk;
    182 }
    183 
    184 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
    185   TfLiteTensor* output_values = GetOutput(context, node, kOutputValues);
    186   TfLiteTensor* output_indexes = GetOutput(context, node, kOutputIndexes);
    187   if (IsDynamicTensor(output_values)) {
    188     TF_LITE_ENSURE_OK(context, ResizeOutput(context, node));
    189   }
    190   TfLiteTensor* top_k = GetInput(context, node, kInputTopK);
    191   const int32 k = top_k->data.i32[0];
    192   // The tensor can have more than 2 dimensions or even be a vector, the code
    193   // anyway calls the internal dimension as row;
    194   TfLiteTensor* input = GetInput(context, node, kInputTensor);
    195   const int32 row_size = input->dims->data[input->dims->size - 1];
    196   int32 num_rows = 1;
    197   for (int i = 0; i < input->dims->size - 1; ++i) {
    198     num_rows *= input->dims->data[i];
    199   }
    200   switch (output_values->type) {
    201     case kTfLiteFloat32:
    202       TopK(row_size, num_rows, input->data.f, k, output_indexes->data.i32,
    203            output_values->data.f);
    204       break;
    205     case kTfLiteUInt8:
    206       TopK(row_size, num_rows, input->data.uint8, k, output_indexes->data.i32,
    207            output_values->data.uint8);
    208       break;
    209     case kTfLiteInt32:
    210       TopK(row_size, num_rows, input->data.i32, k, output_indexes->data.i32,
    211            output_values->data.i32);
    212       break;
    213     case kTfLiteInt64:
    214       TopK(row_size, num_rows, input->data.i64, k, output_indexes->data.i32,
    215            output_values->data.i64);
    216       break;
    217     default:
    218       context->ReportError(context, "Type is currently not supported by TopK.");
    219       return kTfLiteError;
    220   }
    221 
    222   return kTfLiteOk;
    223 }
    224 }  // namespace topk_v2
    225 TfLiteRegistration* Register_TOPK_V2() {
    226   static TfLiteRegistration r = {nullptr, nullptr, topk_v2::Prepare,
    227                                  topk_v2::Eval};
    228   return &r;
    229 }
    230 }  // namespace builtin
    231 }  // namespace ops
    232 }  // namespace tflite
    233