Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include <map>
     17 
     18 #include "tensorflow/lite/c/builtin_op_data.h"
     19 #include "tensorflow/lite/c/c_api_internal.h"
     20 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
     21 #include "tensorflow/lite/kernels/internal/tensor.h"
     22 #include "tensorflow/lite/kernels/kernel_util.h"
     23 
     24 namespace tflite {
     25 namespace ops {
     26 namespace builtin {
     27 namespace unique {
     28 
     29 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
     30   return nullptr;
     31 }
     32 
     33 void Free(TfLiteContext* context, void* buffer) {}
     34 
     35 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     36   static const int kOutputUniqueTensor = 0;
     37   static const int kOutputIndexTensor = 1;
     38 
     39   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
     40   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);
     41   const TfLiteTensor* input = GetInput(context, node, 0);
     42   TfLiteTensor* output_unique_tensor =
     43       GetOutput(context, node, kOutputUniqueTensor);
     44   TfLiteTensor* output_index_tensor =
     45       GetOutput(context, node, kOutputIndexTensor);
     46 
     47   // The op only supports 1D input.
     48   TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1);
     49   TfLiteIntArray* output_index_shape = TfLiteIntArrayCopy(input->dims);
     50   // The unique values are determined during evaluation, so we don't know yet
     51   // the size of the output tensor.
     52   SetTensorToDynamic(output_unique_tensor);
     53   return context->ResizeTensor(context, output_index_tensor,
     54                                output_index_shape);
     55 }
     56 
     57 namespace {
     58 
     59 // Actual evaluation for the unique op.
     60 template <typename T, typename I>
     61 TfLiteStatus EvalImpl(TfLiteContext* context, const TfLiteTensor* input,
     62                       TfLiteNode* node) {
     63   // Map from value, to index in the unique elements vector.
     64   // Note that we prefer to use map than unordered_map as it showed less
     65   // increase in the binary size.
     66   std::map<T, int> unique_values;
     67   TfLiteTensor* output_indexes = GetOutput(context, node, 1);
     68   I* indexes = GetTensorData<I>(output_indexes);
     69   const T* data = GetTensorData<T>(input);
     70   const int num_elements = NumElements(input);
     71 
     72   for (int i = 0; i < num_elements; ++i) {
     73     const auto element_it = unique_values.find(data[i]);
     74     if (element_it != unique_values.end()) {
     75       indexes[i] = element_it->second;
     76     } else {
     77       const int unique_index = unique_values.size();
     78       unique_values[data[i]] = unique_index;
     79       indexes[i] = unique_index;
     80     }
     81   }
     82   // Allocate output tensor.
     83   TfLiteTensor* unique_output = GetOutput(context, node, 0);
     84   std::unique_ptr<TfLiteIntArray, void (*)(TfLiteIntArray*)> shape(
     85       TfLiteIntArrayCreate(NumDimensions(input)), TfLiteIntArrayFree);
     86   shape->data[0] = unique_values.size();
     87   TF_LITE_ENSURE_STATUS(
     88       context->ResizeTensor(context, unique_output, shape.release()));
     89   // Set the values in the output tensor.
     90   T* output_unique_values = GetTensorData<T>(unique_output);
     91   for (int i = 0; i < unique_values.size(); ++i) {
     92     output_unique_values[i] = data[indexes[i]];
     93   }
     94   return kTfLiteOk;
     95 }
     96 
     97 template <typename T>
     98 TfLiteStatus EvalImpl(TfLiteContext* context, const TfLiteTensor* input,
     99                       TfLiteNode* node) {
    100   auto* params = reinterpret_cast<TfLiteUniqueParams*>(node->builtin_data);
    101   if (params == nullptr) {
    102     context->ReportError(context, "Null params passed");
    103     return kTfLiteError;
    104   }
    105   switch (params->index_out_type) {
    106     case kTfLiteInt32:
    107       return EvalImpl<T, int32_t>(context, input, node);
    108     case kTfLiteInt64:
    109       return EvalImpl<T, int64_t>(context, input, node);
    110     default:
    111       context->ReportError(
    112           context,
    113           "Unique index output array can only be Int32 or In64, requested: ",
    114           TfLiteTypeGetName(params->index_out_type));
    115   }
    116   return kTfLiteError;
    117 }
    118 
    119 }  // namespace
    120 
    121 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
    122   const TfLiteTensor* input = GetInput(context, node, 0);
    123   TfLiteTensor* output_index_tensor = GetOutput(context, node, 1);
    124   TF_LITE_ENSURE_EQ(context, NumElements(output_index_tensor),
    125                     NumElements(input));
    126 
    127   switch (input->type) {
    128     case kTfLiteInt8:
    129       TF_LITE_ENSURE_STATUS(EvalImpl<int8_t>(context, input, node));
    130       break;
    131     case kTfLiteInt16:
    132       TF_LITE_ENSURE_STATUS(EvalImpl<int16_t>(context, input, node));
    133       break;
    134     case kTfLiteInt32:
    135       TF_LITE_ENSURE_STATUS(EvalImpl<int32_t>(context, input, node));
    136       break;
    137     case kTfLiteInt64:
    138       TF_LITE_ENSURE_STATUS(EvalImpl<int64_t>(context, input, node));
    139       break;
    140     case kTfLiteFloat32:
    141       TF_LITE_ENSURE_STATUS(EvalImpl<float>(context, input, node));
    142       break;
    143     case kTfLiteUInt8:
    144       TF_LITE_ENSURE_STATUS(EvalImpl<uint8_t>(context, input, node));
    145       break;
    146     default:
    147       context->ReportError(context, "Currently Unique doesn't support type: %s",
    148                            TfLiteTypeGetName(input->type));
    149       return kTfLiteError;
    150   }
    151   return kTfLiteOk;
    152 }
    153 
    154 }  // namespace unique
    155 
    156 TfLiteRegistration* Register_UNIQUE() {
    157   static TfLiteRegistration r = {unique::Init, unique::Free, unique::Prepare,
    158                                  unique::Eval};
    159   return &r;
    160 }
    161 
    162 }  // namespace builtin
    163 }  // namespace ops
    164 }  // namespace tflite
    165