Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include <unistd.h>
     16 #include <cassert>
     17 #include <cmath>
     18 #include <cstdio>
     19 #include <cstdlib>
     20 #include <iostream>
     21 #include <limits>
     22 
     23 #include "tensorflow/contrib/lite/builtin_op_data.h"
     24 #include "tensorflow/contrib/lite/context.h"
     25 #include "tensorflow/contrib/lite/kernels/activation_functor.h"
     26 #include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
     27 #include "tensorflow/contrib/lite/kernels/op_macros.h"
     28 
     29 namespace tflite {
     30 namespace ops {
     31 namespace builtin {
     32 namespace rnn {
     33 
     34 constexpr int kInputTensor = 0;
     35 constexpr int kWeightsTensor = 1;
     36 constexpr int kRecurrentWeightsTensor = 2;
     37 constexpr int kBiasTensor = 3;
     38 constexpr int KHiddenStateTensor = 0;
     39 constexpr int kOutputTensor = 1;
     40 
     41 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     42   // Check we have all the inputs and outputs we need.
     43   TF_LITE_ENSURE_EQ(context, node->inputs->size, 4);
     44   TF_LITE_ENSURE_EQ(context, node->outputs->size, 2);
     45 
     46   TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]];
     47   TfLiteTensor* input_weights =
     48       &context->tensors[node->inputs->data[kWeightsTensor]];
     49   TfLiteTensor* recurrent_weights =
     50       &context->tensors[node->inputs->data[kRecurrentWeightsTensor]];
     51   TfLiteTensor* bias = &context->tensors[node->inputs->data[kBiasTensor]];
     52 
     53   // Check all the parameters of tensor match within themselves and match the
     54   // input configuration.
     55   const int batch_size = input->dims->data[0];
     56   const int num_units = input_weights->dims->data[0];
     57   TF_LITE_ASSERT_EQ(input->dims->data[1], input_weights->dims->data[1]);
     58   TF_LITE_ASSERT_EQ(input_weights->dims->data[0], bias->dims->data[0]);
     59   TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[0], bias->dims->data[0]);
     60   TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[1], bias->dims->data[0]);
     61 
     62   TfLiteTensor* hidden_state =
     63       &context->tensors[node->outputs->data[KHiddenStateTensor]];
     64   TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputTensor]];
     65 
     66   // Resize state.
     67   TfLiteIntArray* hidden_state_size_array = TfLiteIntArrayCreate(2);
     68   hidden_state_size_array->data[0] = batch_size;
     69   hidden_state_size_array->data[1] = num_units;
     70   TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, hidden_state,
     71                                                    hidden_state_size_array));
     72 
     73   // Mark hidden state as a persistent tensor.
     74   hidden_state->allocation_type = kTfLiteArenaRwPersistent;
     75 
     76   // Resize output.
     77   TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2);
     78   output_size_array->data[0] = batch_size;
     79   output_size_array->data[1] = num_units;
     80   TF_LITE_ENSURE_OK(context,
     81                     context->ResizeTensor(context, output, output_size_array));
     82 
     83   return kTfLiteOk;
     84 }
     85 
     86 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
     87   auto* params = reinterpret_cast<TfLiteRNNParams*>(node->builtin_data);
     88 
     89   TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]];
     90   TfLiteTensor* input_weights =
     91       &context->tensors[node->inputs->data[kWeightsTensor]];
     92   TfLiteTensor* recurrent_weights =
     93       &context->tensors[node->inputs->data[kRecurrentWeightsTensor]];
     94   TfLiteTensor* bias = &context->tensors[node->inputs->data[kBiasTensor]];
     95   TfLiteTensor* hidden_state =
     96       &context->tensors[node->outputs->data[KHiddenStateTensor]];
     97   TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputTensor]];
     98 
     99   // Initialize the pointer bias.
    100   const float* bias_ptr = bias->data.f;
    101 
    102   const int batch_size = input->dims->data[0];
    103   const int num_units = input_weights->dims->data[0];
    104   const int input_size = input->dims->data[1];
    105 
    106   // Initialize the pointer to hidden state.
    107   float* hidden_state_ptr_batch = hidden_state->data.f;
    108   // Initialize the pointer to input and output.
    109   const float* input_ptr_batch = input->data.f;
    110   float* output_ptr_batch = output->data.f;
    111   // Initialize input_weights and recurrent_weights.
    112   const float* input_weights_ptr = input_weights->data.f;
    113   const float* recurrent_weights_ptr = recurrent_weights->data.f;
    114 
    115   kernel_utils::RnnBatchStep(input_ptr_batch, input_weights_ptr,
    116                              recurrent_weights_ptr, bias_ptr, input_size,
    117                              num_units, batch_size, params->activation,
    118                              hidden_state_ptr_batch, output_ptr_batch);
    119   return kTfLiteOk;
    120 }
    121 
    122 }  // namespace rnn
    123 
    124 TfLiteRegistration* Register_RNN() {
    125   static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
    126                                  rnn::Prepare, rnn::Eval};
    127   return &r;
    128 }
    129 
    130 }  // namespace builtin
    131 }  // namespace ops
    132 }  // namespace tflite
    133