Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/lite/c/builtin_op_data.h"
     17 #include "tensorflow/lite/c/c_api_internal.h"
     18 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
     19 #include "tensorflow/lite/kernels/internal/tensor.h"
     20 #include "tensorflow/lite/kernels/kernel_util.h"
     21 
     22 namespace tflite {
     23 namespace ops {
     24 namespace builtin {
     25 namespace pack {
     26 namespace {
     27 
// Index of this op's output tensor in the node's output list. Prepare
// requires the node to have exactly one output.
constexpr int kOutputTensor = 0;
     29 
     30 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     31   const TfLitePackParams* data =
     32       reinterpret_cast<TfLitePackParams*>(node->builtin_data);
     33 
     34   TF_LITE_ENSURE_EQ(context, NumInputs(node), data->values_count);
     35   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
     36 
     37   const TfLiteTensor* input0 = GetInput(context, node, 0);
     38   TF_LITE_ENSURE(context, NumDimensions(input0) >= data->axis);
     39   // TODO(renjieliu): Support negative axis.
     40   TF_LITE_ENSURE(context, data->axis >= 0);
     41   if (input0->type != kTfLiteInt32 && input0->type != kTfLiteFloat32 &&
     42       input0->type != kTfLiteUInt8 && input0->type != kTfLiteInt8 &&
     43       input0->type != kTfLiteInt16 && input0->type != kTfLiteInt64) {
     44     context->ReportError(context, "Type '%s' is not supported by pack.",
     45                          TfLiteTypeGetName(input0->type));
     46     return kTfLiteError;
     47   }
     48   // Make sure all inputs have the same shape and type.
     49   for (int i = 1; i < data->values_count; ++i) {
     50     const TfLiteTensor* input = GetInput(context, node, i);
     51     TF_LITE_ENSURE(context, HaveSameShapes(input0, input));
     52     TF_LITE_ENSURE_EQ(context, input0->type, input->type);
     53   }
     54 
     55   // Resize output. rank R will become rank R + 1
     56   const int dimension_size = NumDimensions(input0) + 1;
     57   const TfLiteIntArray* input_shape = input0->dims;
     58   TfLiteIntArray* output_shape = TfLiteIntArrayCreate(dimension_size);
     59   int i = 0;
     60   for (int index = 0; index < dimension_size; ++index) {
     61     if (index == data->axis) {
     62       output_shape->data[index] = data->values_count;
     63     } else {
     64       output_shape->data[index] = input_shape->data[i++];
     65     }
     66   }
     67 
     68   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
     69   TF_LITE_ENSURE_EQ(context, output->type, input0->type);
     70 
     71   // Guarantee input/output quantization params match as we do not support
     72   // packing quantized tensors.
     73   for (int i = 0; i < data->values_count; i++) {
     74     const TfLiteTensor* input = GetInput(context, node, i);
     75     TF_LITE_ENSURE_EQ(context, input->params.zero_point,
     76                       output->params.zero_point);
     77     TF_LITE_ENSURE_EQ(context, input->params.scale, output->params.scale);
     78   }
     79 
     80   return context->ResizeTensor(context, output, output_shape);
     81 }
     82 
     83 template <typename T>
     84 void PackImpl(TfLiteContext* context, TfLiteNode* node, TfLiteTensor* output,
     85               int values_count, int axis) {
     86   VectorOfTensors<T> all_inputs(*context, *node->inputs);
     87   tflite::PackParams op_params;
     88   op_params.axis = axis;
     89   op_params.inputs_count = values_count;
     90 
     91   reference_ops::Pack<T>(op_params, all_inputs.shapes(), all_inputs.data(),
     92                          GetTensorShape(output), GetTensorData<T>(output));
     93 }
     94 
     95 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
     96   const TfLitePackParams* data =
     97       reinterpret_cast<TfLitePackParams*>(node->builtin_data);
     98 
     99   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
    100   switch (output->type) {
    101     case kTfLiteFloat32: {
    102       PackImpl<float>(context, node, output, data->values_count, data->axis);
    103       break;
    104     }
    105     case kTfLiteUInt8: {
    106       PackImpl<uint8_t>(context, node, output, data->values_count, data->axis);
    107       break;
    108     }
    109     case kTfLiteInt8: {
    110       PackImpl<int8_t>(context, node, output, data->values_count, data->axis);
    111       break;
    112     }
    113     case kTfLiteInt32: {
    114       PackImpl<int32_t>(context, node, output, data->values_count, data->axis);
    115       break;
    116     }
    117     case kTfLiteInt64: {
    118       PackImpl<int64_t>(context, node, output, data->values_count, data->axis);
    119       break;
    120     }
    121     default: {
    122       context->ReportError(context, "Type '%s' is not supported by pack.",
    123                            TfLiteTypeGetName(output->type));
    124       return kTfLiteError;
    125     }
    126   }
    127 
    128   return kTfLiteOk;
    129 }
    130 
    131 }  // namespace
    132 }  // namespace pack
    133 
    134 TfLiteRegistration* Register_PACK() {
    135   static TfLiteRegistration r = {nullptr, nullptr, pack::Prepare, pack::Eval};
    136   return &r;
    137 }
    138 
    139 }  // namespace builtin
    140 }  // namespace ops
    141 }  // namespace tflite
    142