1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include <unistd.h> 16 #include <cassert> 17 #include <cmath> 18 #include <cstdio> 19 #include <cstdlib> 20 #include <iostream> 21 #include <limits> 22 23 #include "tensorflow/contrib/lite/builtin_op_data.h" 24 #include "tensorflow/contrib/lite/context.h" 25 #include "tensorflow/contrib/lite/kernels/activation_functor.h" 26 #include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" 27 #include "tensorflow/contrib/lite/kernels/op_macros.h" 28 29 namespace tflite { 30 namespace ops { 31 namespace builtin { 32 namespace rnn { 33 34 constexpr int kInputTensor = 0; 35 constexpr int kWeightsTensor = 1; 36 constexpr int kRecurrentWeightsTensor = 2; 37 constexpr int kBiasTensor = 3; 38 constexpr int KHiddenStateTensor = 0; 39 constexpr int kOutputTensor = 1; 40 41 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { 42 // Check we have all the inputs and outputs we need. 43 TF_LITE_ENSURE_EQ(context, node->inputs->size, 4); 44 TF_LITE_ENSURE_EQ(context, node->outputs->size, 2); 45 46 TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]]; 47 TfLiteTensor* input_weights = 48 &context->tensors[node->inputs->data[kWeightsTensor]]; 49 TfLiteTensor* recurrent_weights = 50 &context->tensors[node->inputs->data[kRecurrentWeightsTensor]]; 51 TfLiteTensor* bias = &context->tensors[node->inputs->data[kBiasTensor]]; 52 53 // Check all the parameters of tensor match within themselves and match the 54 // input configuration. 55 const int batch_size = input->dims->data[0]; 56 const int num_units = input_weights->dims->data[0]; 57 TF_LITE_ASSERT_EQ(input->dims->data[1], input_weights->dims->data[1]); 58 TF_LITE_ASSERT_EQ(input_weights->dims->data[0], bias->dims->data[0]); 59 TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[0], bias->dims->data[0]); 60 TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[1], bias->dims->data[0]); 61 62 TfLiteTensor* hidden_state = 63 &context->tensors[node->outputs->data[KHiddenStateTensor]]; 64 TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputTensor]]; 65 66 // Resize state. 67 TfLiteIntArray* hidden_state_size_array = TfLiteIntArrayCreate(2); 68 hidden_state_size_array->data[0] = batch_size; 69 hidden_state_size_array->data[1] = num_units; 70 TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, hidden_state, 71 hidden_state_size_array)); 72 73 // Mark hidden state as a persistent tensor. 74 hidden_state->allocation_type = kTfLiteArenaRwPersistent; 75 76 // Resize output. 77 TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2); 78 output_size_array->data[0] = batch_size; 79 output_size_array->data[1] = num_units; 80 TF_LITE_ENSURE_OK(context, 81 context->ResizeTensor(context, output, output_size_array)); 82 83 return kTfLiteOk; 84 } 85 86 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { 87 auto* params = reinterpret_cast<TfLiteRNNParams*>(node->builtin_data); 88 89 TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]]; 90 TfLiteTensor* input_weights = 91 &context->tensors[node->inputs->data[kWeightsTensor]]; 92 TfLiteTensor* recurrent_weights = 93 &context->tensors[node->inputs->data[kRecurrentWeightsTensor]]; 94 TfLiteTensor* bias = &context->tensors[node->inputs->data[kBiasTensor]]; 95 TfLiteTensor* hidden_state = 96 &context->tensors[node->outputs->data[KHiddenStateTensor]]; 97 TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputTensor]]; 98 99 // Initialize the pointer bias. 100 const float* bias_ptr = bias->data.f; 101 102 const int batch_size = input->dims->data[0]; 103 const int num_units = input_weights->dims->data[0]; 104 const int input_size = input->dims->data[1]; 105 106 // Initialize the pointer to hidden state. 107 float* hidden_state_ptr_batch = hidden_state->data.f; 108 // Initialize the pointer to input and output. 109 const float* input_ptr_batch = input->data.f; 110 float* output_ptr_batch = output->data.f; 111 // Initialize input_weights and recurrent_weights. 112 const float* input_weights_ptr = input_weights->data.f; 113 const float* recurrent_weights_ptr = recurrent_weights->data.f; 114 115 kernel_utils::RnnBatchStep(input_ptr_batch, input_weights_ptr, 116 recurrent_weights_ptr, bias_ptr, input_size, 117 num_units, batch_size, params->activation, 118 hidden_state_ptr_batch, output_ptr_batch); 119 return kTfLiteOk; 120 } 121 122 } // namespace rnn 123 124 TfLiteRegistration* Register_RNN() { 125 static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, 126 rnn::Prepare, rnn::Eval}; 127 return &r; 128 } 129 130 } // namespace builtin 131 } // namespace ops 132 } // namespace tflite 133