Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include <cassert>
     16 #include <cmath>
     17 #include <cstdio>
     18 #include <cstdlib>
     19 #include <iostream>
     20 #include <limits>
     21 
     22 #include "tensorflow/lite/c/builtin_op_data.h"
     23 #include "tensorflow/lite/c/c_api_internal.h"
     24 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
     25 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
     26 #include "tensorflow/lite/kernels/internal/tensor.h"
     27 #include "tensorflow/lite/kernels/kernel_util.h"
     28 #include "tensorflow/lite/kernels/op_macros.h"
     29 
     30 namespace tflite {
     31 namespace ops {
     32 namespace builtin {
     33 namespace concatenation {
     34 
// This file has two implementations of Concatenation.
enum KernelType {
  kReference,         // Portable reference kernel (reference_ops).
  kGenericOptimized,  // Architecture-generic optimized kernel (optimized_ops).
};
     40 
     41 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     42   auto* params =
     43       reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data);
     44   int axis = params->axis;
     45   int num_inputs = node->inputs->size;
     46 
     47   // The number of dimensions of the input tensors must match, and all
     48   // dimensions except 'axis' must be equal.
     49   TfLiteTensor* t0 = &context->tensors[node->inputs->data[0]];
     50   TfLiteType input_type = t0->type;
     51   if (axis < 0) axis += t0->dims->size;
     52   TF_LITE_ENSURE(context, axis >= 0);
     53   TF_LITE_ENSURE(context, axis < t0->dims->size);
     54 
     55   // TODO(ahentz): These are limitations of our implementation that could be
     56   // removed with a bit of effort.
     57   TF_LITE_ENSURE(context, t0->dims->size <= 4);
     58   TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone);
     59   TF_LITE_ENSURE(context,
     60                  input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
     61                      input_type == kTfLiteInt8 || input_type == kTfLiteInt16 ||
     62                      input_type == kTfLiteInt32 || input_type == kTfLiteInt64);
     63 
     64   // Output dimensions will match input dimensions, except 'axis', which
     65   // will be the sum of inputs
     66   int sum_axis = t0->dims->data[axis];
     67   for (int i = 1; i < num_inputs; ++i) {
     68     TfLiteTensor* t = &context->tensors[node->inputs->data[i]];
     69     TF_LITE_ENSURE_EQ(context, t->dims->size, t0->dims->size);
     70     TF_LITE_ENSURE_EQ(context, t->type, input_type);
     71     for (int d = 0; d < t0->dims->size; ++d) {
     72       if (d == axis) {
     73         sum_axis += t->dims->data[axis];
     74       } else {
     75         TF_LITE_ENSURE_EQ(context, t->dims->data[d], t0->dims->data[d]);
     76       }
     77     }
     78   }
     79 
     80   TfLiteIntArray* output_size = TfLiteIntArrayCreate(t0->dims->size);
     81   for (int d = 0; d < t0->dims->size; ++d) {
     82     output_size->data[d] = (d == axis) ? sum_axis : t0->dims->data[d];
     83   }
     84 
     85   TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
     86   TF_LITE_ENSURE_EQ(context, output->type, input_type);
     87 
     88   if (input_type == kTfLiteInt8) {
     89     // Make sure there is no re-scaling needed for Int8 quantized kernel. This
     90     // is a restriction we introduced to Int8 kernels.
     91     VectorOfTensors<int8_t> all_inputs(*context, *node->inputs);
     92     for (int i = 0; i < node->inputs->size; ++i) {
     93       TfLiteTensor* t = &context->tensors[node->inputs->data[i]];
     94       TF_LITE_ENSURE_EQ(context, t->params.scale, output->params.scale);
     95       TF_LITE_ENSURE_EQ(context, t->params.zero_point,
     96                         output->params.zero_point);
     97     }
     98   }
     99 
    100   return context->ResizeTensor(context, output, output_size);
    101 }
    102 
    103 template <KernelType kernel_type>
    104 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
    105   auto* params =
    106       reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data);
    107   int axis = params->axis;
    108   TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
    109   if (axis < 0) axis += output->dims->size;
    110 
    111 // TODO(ahentz): Creating 'all_inputs' below is not very efficient. We should
    112 // allocate and populate these during Prepare().
    113 // TODO(ycling): Activation function parameter is ignored. For now we dont have
    114 // a model with a Concatenation with fused activation function.
    115 #define TF_LITE_CONCATENATION(type, scalar)                                \
    116   {                                                                        \
    117     VectorOfTensors<scalar> all_inputs(*context, *node->inputs);           \
    118     tflite::ConcatenationParams op_params;                                 \
    119     op_params.axis = axis;                                                 \
    120     op_params.inputs_count = node->inputs->size;                           \
    121     type::Concatenation(op_params, all_inputs.shapes(), all_inputs.data(), \
    122                         GetTensorShape(output),                            \
    123                         GetTensorData<scalar>(output));                    \
    124   }
    125 
    126 #define TF_LITE_CONCATENATION_QUANTIZED(type)                                 \
    127   {                                                                           \
    128     VectorOfQuantizedTensors all_inputs(*context, *node->inputs);             \
    129     tflite::ConcatenationParams op_params;                                    \
    130     op_params.axis = axis;                                                    \
    131     op_params.input_zeropoint = all_inputs.zero_point();                      \
    132     op_params.input_scale = all_inputs.scale();                               \
    133     op_params.inputs_count = node->inputs->size;                              \
    134     op_params.output_zeropoint = output->params.zero_point;                   \
    135     op_params.output_scale = output->params.scale;                            \
    136     type::ConcatenationWithScaling(op_params, all_inputs.shapes(),            \
    137                                    all_inputs.data(), GetTensorShape(output), \
    138                                    GetTensorData<uint8>(output));             \
    139   }
    140 
    141   switch (output->type) {  // Already know in/outtypes are same.
    142     case kTfLiteFloat32:
    143       if (kernel_type == kReference) {
    144         TF_LITE_CONCATENATION(reference_ops, float);
    145       } else {
    146         TF_LITE_CONCATENATION(optimized_ops, float);
    147       }
    148       break;
    149     case kTfLiteInt32:
    150       if (kernel_type == kReference) {
    151         TF_LITE_CONCATENATION(reference_ops, int32);
    152       } else {
    153         TF_LITE_CONCATENATION(optimized_ops, int32);
    154       }
    155       break;
    156     case kTfLiteUInt8:
    157       if (kernel_type == kReference) {
    158         TF_LITE_CONCATENATION_QUANTIZED(reference_ops);
    159       } else {
    160         TF_LITE_CONCATENATION_QUANTIZED(optimized_ops);
    161       }
    162       break;
    163     case kTfLiteInt8: {
    164       if (kernel_type == kReference) {
    165         TF_LITE_CONCATENATION(reference_ops, int8_t);
    166       } else {
    167         TF_LITE_CONCATENATION(optimized_ops, int8_t);
    168       }
    169     } break;
    170     case kTfLiteInt64:
    171       if (kernel_type == kReference) {
    172         TF_LITE_CONCATENATION(reference_ops, int64_t);
    173       } else {
    174         TF_LITE_CONCATENATION(optimized_ops, int64_t);
    175       }
    176       break;
    177 
    178     default:
    179       context->ReportError(context,
    180                            "Only float32 and uint8 are currently supported.");
    181       return kTfLiteError;
    182   }
    183 
    184 #undef TF_LITE_CONCATENATION_QUANTIZED
    185 #undef TF_LITE_CONCATENATION
    186 
    187   return kTfLiteOk;
    188 }
    189 
    190 #undef TF_LITE_MACRO_DISPATCH
    191 
    192 }  // namespace concatenation
    193 
    194 TfLiteRegistration* Register_CONCATENATION_REF() {
    195   static TfLiteRegistration r = {
    196       nullptr, nullptr, concatenation::Prepare,
    197       concatenation::Eval<concatenation::kReference>};
    198   return &r;
    199 }
    200 
    201 TfLiteRegistration* Register_CONCATENATION_GENERIC_OPT() {
    202   static TfLiteRegistration r = {
    203       nullptr, nullptr, concatenation::Prepare,
    204       concatenation::Eval<concatenation::kGenericOptimized>};
    205   return &r;
    206 }
    207 
    208 TfLiteRegistration* Register_CONCATENATION() {
    209   // TODO(ahentz): It turns out the two versions of Concatenation are almost
    210   // identical, so we should consider removing one.
    211   return Register_CONCATENATION_GENERIC_OPT();
    212 }
    213 
    214 }  // namespace builtin
    215 }  // namespace ops
    216 }  // namespace tflite
    217