1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include <algorithm> 16 17 #include "tensorflow/contrib/lite/builtin_op_data.h" 18 #include "tensorflow/contrib/lite/context.h" 19 #include "tensorflow/contrib/lite/kernels/internal/tensor.h" 20 #include "tensorflow/contrib/lite/kernels/kernel_util.h" 21 #include "tensorflow/contrib/lite/kernels/op_macros.h" 22 namespace tflite { 23 namespace ops { 24 namespace builtin { 25 namespace topk_v2 { 26 constexpr int kInputTensor = 0; 27 constexpr int kInputTopK = 1; 28 constexpr int kOutputIndexes = 0; 29 constexpr int kOutputValues = 1; 30 31 namespace { 32 TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) { 33 TfLiteTensor* top_k = GetInput(context, node, kInputTopK); 34 // INT32 number of top results is supported. 35 TF_LITE_ENSURE_EQ(context, top_k->type, kTfLiteInt32); 36 // Check that the tensor contains only one value. 37 TF_LITE_ENSURE_EQ(context, NumDimensions(top_k), 1); 38 TF_LITE_ENSURE_EQ(context, NumElements(top_k), 1); 39 const int32 k = top_k->data.i32[0]; 40 41 TfLiteTensor* input = GetInput(context, node, kInputTensor); 42 const int num_dimensions = NumDimensions(input); 43 // Check that input has one or more dimensions. 
44 TF_LITE_ENSURE_MSG(context, input->dims->size >= 1, 45 "TopK k input must have 1 or more dimensions."); 46 // Check that k is less or equal the internal dimension. 47 TF_LITE_ENSURE_MSG(context, k <= input->dims->data[num_dimensions - 1], 48 "TopK k is higher than the internal dimension."); 49 50 TfLiteIntArray* output_indexes_shape = TfLiteIntArrayCreate(num_dimensions); 51 TfLiteIntArray* output_values_shape = TfLiteIntArrayCreate(num_dimensions); 52 for (int i = 0; i < num_dimensions - 1; ++i) { 53 output_indexes_shape->data[i] = input->dims->data[i]; 54 output_values_shape->data[i] = input->dims->data[i]; 55 } 56 output_indexes_shape->data[num_dimensions - 1] = k; 57 output_values_shape->data[num_dimensions - 1] = k; 58 TfLiteTensor* output_indexes = GetOutput(context, node, kOutputIndexes); 59 TfLiteTensor* output_values = GetOutput(context, node, kOutputValues); 60 auto resize_tensor = [context](TfLiteTensor* tensor, TfLiteIntArray* new_size, 61 TfLiteIntArray* delete_on_error) { 62 TfLiteStatus status = context->ResizeTensor(context, tensor, new_size); 63 if (status != kTfLiteOk) { 64 TfLiteIntArrayFree(new_size); 65 if (delete_on_error != nullptr) { 66 TfLiteIntArrayFree(delete_on_error); 67 } 68 } 69 return status; 70 }; 71 TF_LITE_ENSURE_OK(context, resize_tensor(output_indexes, output_indexes_shape, 72 output_values_shape)); 73 TF_LITE_ENSURE_OK(context, 74 resize_tensor(output_values, output_values_shape, nullptr)); 75 return kTfLiteOk; 76 } 77 78 // The class that collects top indexes of k values. Based on template 79 // tensorflow::gtl::TopN<> but, for optimization, 80 // it re-uses the same container. 
template <typename T>
class TopContainer {
 public:
  TopContainer() = delete;
  // `k` is how many indexes to keep per row; a negative value is treated as 0.
  // `row_size` only sizes the reservation, since at most min(k, row_size) + 1
  // indexes are ever held at once.
  TopContainer(std::int32_t k, std::int32_t row_size)
      : k_(k > 0 ? static_cast<std::size_t>(k) : 0) {
    const std::int32_t kept = std::max<std::int32_t>(std::min(k, row_size), 0);
    container_.reserve(static_cast<std::size_t>(kept) + 1);
  }

  // Begins a new row: drops previously collected indexes and reads candidate
  // values from `values`, which must stay valid through the following push()
  // calls and sorted_result().
  void start_collecting(const T* values) {
    values_ = values;
    container_.clear();
  }

  // Offers column `index` of the current row as a candidate.
  //
  // Once k + 1 indexes have been pushed, container_[0, k) is a heap (ordered by
  // compare_fun) whose front is the worst kept index, and container_.back()
  // is scratch space holding the most recently rejected/evicted index.
  void push(std::int32_t index) {
    const auto better = [this](std::int32_t lhs, std::int32_t rhs) {
      return compare_fun(lhs, rhs);
    };
    if (container_.size() <= k_) {
      container_.push_back(index);
      if (container_.size() == k_ + 1) {
        std::make_heap(container_.begin(), container_.end(), better);
        std::pop_heap(container_.begin(), container_.end(), better);
      }
    } else if (better(index, container_.front())) {
      container_.back() = index;
      std::push_heap(container_.begin(), container_.end(), better);
      std::pop_heap(container_.begin(), container_.end(), better);
    }
  }

  // Returns the kept indexes (at most k), best first: larger values first,
  // equal values in ascending index order. The reference is invalidated by the
  // next start_collecting().
  const std::vector<std::int32_t>& sorted_result() {
    const auto better = [this](std::int32_t lhs, std::int32_t rhs) {
      return compare_fun(lhs, rhs);
    };
    if (container_.size() <= k_) {
      std::sort(container_.begin(), container_.end(), better);
    } else {
      // The last slot is scratch (see push()); only the heap part is sorted.
      std::sort_heap(container_.begin(), container_.end() - 1, better);
      container_.resize(k_);
    }
    return container_;
  }

 private:
  std::size_t k_;
  std::vector<std::int32_t> container_;
  const T* values_ = nullptr;

  // True when index `a` ranks ahead of index `b`: its value is larger, or the
  // values are equal and `a` is the smaller index.
  bool compare_fun(std::int32_t a, std::int32_t b) const {
    if (values_[b] < values_[a]) {
      return true;
    } else if (values_[b] > values_[a]) {
      return false;
    } else {
      return a < b;
    }
  }
};

// Mostly modeled on tensorflow/core/kernels/topk_op.cc for CPU.
136 template <typename T> 137 void TopK(int32 row_size, int32 num_rows, const T* data, int32 k, 138 int32* output_indexes, T* output_values) { 139 TopContainer<T> topc(k, row_size); 140 for (int row = 0; row < num_rows; ++row) { 141 const T* values_row = data + row * row_size; 142 topc.start_collecting(values_row); 143 for (int32 c = 0; c < row_size; ++c) { 144 topc.push(c); 145 } 146 147 // Prepare output buffers. 148 int32* indexes_row = output_indexes + row * k; 149 T* output_row = output_values + row * k; 150 // We always assume that the output is sorted. 151 const auto& top_k = topc.sorted_result(); 152 std::copy(top_k.begin(), top_k.end(), indexes_row); 153 std::transform(top_k.begin(), top_k.end(), output_row, 154 [values_row](const int32 loc) { return values_row[loc]; }); 155 } 156 } 157 158 } // namespace 159 160 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { 161 // Check that the inputs and outputs have the right sizes and types. 162 TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); 163 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2); 164 165 TfLiteTensor* input = GetInput(context, node, kInputTensor); 166 TfLiteTensor* output_values = GetOutput(context, node, kOutputValues); 167 TF_LITE_ENSURE_EQ(context, input->type, output_values->type); 168 169 TfLiteTensor* top_k = GetInput(context, node, kInputTopK); 170 TF_LITE_ENSURE_EQ(context, top_k->type, kTfLiteInt32); 171 172 // Set output dynamic if the input is not const. 
173 if (IsConstantTensor(top_k)) { 174 TF_LITE_ENSURE_OK(context, ResizeOutput(context, node)); 175 } else { 176 TfLiteTensor* output_indexes = GetOutput(context, node, kOutputIndexes); 177 TfLiteTensor* output_values = GetOutput(context, node, kOutputValues); 178 SetTensorToDynamic(output_indexes); 179 SetTensorToDynamic(output_values); 180 } 181 return kTfLiteOk; 182 } 183 184 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { 185 TfLiteTensor* output_values = GetOutput(context, node, kOutputValues); 186 TfLiteTensor* output_indexes = GetOutput(context, node, kOutputIndexes); 187 if (IsDynamicTensor(output_values)) { 188 TF_LITE_ENSURE_OK(context, ResizeOutput(context, node)); 189 } 190 TfLiteTensor* top_k = GetInput(context, node, kInputTopK); 191 const int32 k = top_k->data.i32[0]; 192 // The tensor can have more than 2 dimensions or even be a vector, the code 193 // anyway calls the internal dimension as row; 194 TfLiteTensor* input = GetInput(context, node, kInputTensor); 195 const int32 row_size = input->dims->data[input->dims->size - 1]; 196 int32 num_rows = 1; 197 for (int i = 0; i < input->dims->size - 1; ++i) { 198 num_rows *= input->dims->data[i]; 199 } 200 switch (output_values->type) { 201 case kTfLiteFloat32: 202 TopK(row_size, num_rows, input->data.f, k, output_indexes->data.i32, 203 output_values->data.f); 204 break; 205 case kTfLiteUInt8: 206 TopK(row_size, num_rows, input->data.uint8, k, output_indexes->data.i32, 207 output_values->data.uint8); 208 break; 209 case kTfLiteInt32: 210 TopK(row_size, num_rows, input->data.i32, k, output_indexes->data.i32, 211 output_values->data.i32); 212 break; 213 case kTfLiteInt64: 214 TopK(row_size, num_rows, input->data.i64, k, output_indexes->data.i32, 215 output_values->data.i64); 216 break; 217 default: 218 context->ReportError(context, "Type is currently not supported by TopK."); 219 return kTfLiteError; 220 } 221 222 return kTfLiteOk; 223 } 224 } // namespace topk_v2 225 TfLiteRegistration* 
Register_TOPK_V2() { 226 static TfLiteRegistration r = {nullptr, nullptr, topk_v2::Prepare, 227 topk_v2::Eval}; 228 return &r; 229 } 230 } // namespace builtin 231 } // namespace ops 232 } // namespace tflite 233