Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include <string.h>
     16 #include <vector>
     17 #include "tensorflow/lite/c/builtin_op_data.h"
     18 #include "tensorflow/lite/c/c_api_internal.h"
     19 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
     20 #include "tensorflow/lite/kernels/internal/tensor.h"
     21 #include "tensorflow/lite/kernels/kernel_util.h"
     22 #include "tensorflow/lite/kernels/op_macros.h"
     23 
     24 namespace tflite {
     25 namespace ops {
     26 namespace builtin {
     27 namespace transpose {
     28 
     29 // This file has two implementations of Transpose.
     30 enum KernelType {
     31   kReference,
     32 };
     33 
     34 struct TransposeContext {
     35   TransposeContext(TfLiteContext* context, TfLiteNode* node) {
     36     input = GetInput(context, node, 0);
     37     perm = GetInput(context, node, 1);
     38     output = GetOutput(context, node, 0);
     39   }
     40   const TfLiteTensor* input;
     41   const TfLiteTensor* perm;
     42   TfLiteTensor* output;
     43 };
     44 
     45 TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
     46                                 TransposeContext* op_context) {
     47   int dims = NumDimensions(op_context->input);
     48   const int* perm_data = GetTensorData<int32_t>(op_context->perm);
     49 
     50   // Ensure validity of the permutations tensor as a 1D tensor.
     51   TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->perm), 1);
     52   TF_LITE_ENSURE_EQ(context, op_context->perm->dims->data[0], dims);
     53   for (int idx = 0; idx < dims; ++idx) {
     54     TF_LITE_ENSURE_MSG(context, (perm_data[idx] >= 0 && perm_data[idx] < dims),
     55                        "Transpose op permutations array is out of bounds.");
     56   }
     57 
     58   // Determine size of output tensor.
     59   TfLiteIntArray* input_size = op_context->input->dims;
     60   TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_size);
     61   for (int idx = 0; idx < dims; ++idx) {
     62     output_size->data[idx] = input_size->data[perm_data[idx]];
     63   }
     64 
     65   return context->ResizeTensor(context, op_context->output, output_size);
     66 }
     67 
     68 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     69   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
     70   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
     71 
     72   TransposeContext op_context(context, node);
     73 
     74   // Ensure validity of input tensor.
     75   TF_LITE_ENSURE_MSG(context, NumDimensions(op_context.input) <= 4,
     76                      "Transpose op only supports 1D-4D input arrays.");
     77   TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);
     78 
     79   if (!IsConstantTensor(op_context.perm)) {
     80     SetTensorToDynamic(op_context.output);
     81     return kTfLiteOk;
     82   }
     83   return ResizeOutputTensor(context, &op_context);
     84 }
     85 
     86 template <KernelType kernel_type>
     87 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
     88   TransposeContext op_context(context, node);
     89 
     90   // Resize the output tensor if the output tensor is dynamic.
     91   if (IsDynamicTensor(op_context.output)) {
     92     TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
     93   }
     94 
     95   const int* perm_data = GetTensorData<int32_t>(op_context.perm);
     96   const int size = op_context.perm->dims->data[0];
     97   TransposeParams params;
     98   params.perm_count = size;
     99   for (int i = 0; i < size; ++i) {
    100     params.perm[i] = perm_data[i];
    101   }
    102 
    103 #define TF_LITE_TRANSPOSE(type, scalar)                     \
    104   type::Transpose(params, GetTensorShape(op_context.input), \
    105                   GetTensorData<scalar>(op_context.input),  \
    106                   GetTensorShape(op_context.output),        \
    107                   GetTensorData<scalar>(op_context.output))
    108 
    109   switch (op_context.input->type) {
    110     case kTfLiteFloat32:
    111       if (kernel_type == kReference) {
    112         TF_LITE_TRANSPOSE(reference_ops, float);
    113       }
    114       break;
    115     case kTfLiteUInt8:
    116       if (kernel_type == kReference) {
    117         TF_LITE_TRANSPOSE(reference_ops, uint8_t);
    118       }
    119       break;
    120     case kTfLiteInt8:
    121       if (kernel_type == kReference) {
    122         TF_LITE_TRANSPOSE(reference_ops, int8_t);
    123       }
    124       break;
    125     case kTfLiteInt32:
    126       if (kernel_type == kReference) {
    127         TF_LITE_TRANSPOSE(reference_ops, int32_t);
    128       }
    129       break;
    130     case kTfLiteInt64:
    131       if (kernel_type == kReference) {
    132         TF_LITE_TRANSPOSE(reference_ops, int64_t);
    133       }
    134       break;
    135     default:
    136       context->ReportError(context,
    137                            "Type %d is currently not supported by Transpose.",
    138                            op_context.input->type);
    139       return kTfLiteError;
    140   }
    141 #undef TF_LITE_TRANSPOSE
    142 
    143   return kTfLiteOk;
    144 }
    145 
    146 }  // namespace transpose
    147 
    148 TfLiteRegistration* Register_TRANSPOSE_REF() {
    149   static TfLiteRegistration r = {nullptr, nullptr, transpose::Prepare,
    150                                  transpose::Eval<transpose::kReference>};
    151   return &r;
    152 }
    153 
    154 TfLiteRegistration* Register_TRANSPOSE() { return Register_TRANSPOSE_REF(); }
    155 
    156 }  // namespace builtin
    157 }  // namespace ops
    158 }  // namespace tflite
    159