1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CONTRIB_RNN_KERNELS_GRU_OPS_H_ 17 #define TENSORFLOW_CONTRIB_RNN_KERNELS_GRU_OPS_H_ 18 19 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 20 #include "tensorflow/contrib/rnn/kernels/blas_gemm.h" 21 #include "tensorflow/core/framework/tensor_types.h" 22 #include "tensorflow/core/platform/types.h" 23 24 namespace tensorflow { 25 26 class OpKernelContext; 27 28 namespace functor { 29 30 struct GRUCell { 31 GRUCell(const int batch_size, const int input_size, const int cell_size) 32 : batch_size_(batch_size), 33 input_size_(input_size), 34 cell_size_(cell_size) {} 35 36 inline Eigen::array<Eigen::DenseIndex, 2> x_offsets() const { return {0, 0}; } 37 38 inline Eigen::array<Eigen::DenseIndex, 2> x_extends() const { 39 return {batch_size_, input_size_}; 40 } 41 42 inline Eigen::array<Eigen::DenseIndex, 2> h_offsets() const { 43 return {0, input_size_}; 44 } 45 46 inline Eigen::array<Eigen::DenseIndex, 2> h_extends() const { 47 return {batch_size_, cell_size_}; 48 } 49 50 inline Eigen::array<Eigen::DenseIndex, 2> ru_r_offset() const { 51 return {0, 0}; 52 } 53 54 inline Eigen::array<Eigen::DenseIndex, 2> ru_u_offset() const { 55 return {0, cell_size_}; 56 } 57 58 inline Eigen::array<Eigen::DenseIndex, 2> cell_extents() const { 59 return {batch_size_, cell_size_}; 60 } 61 62 protected: 63 const int batch_size_; 64 const int input_size_; 65 const int cell_size_; 66 }; 67 68 template <typename Device, typename T, bool USE_CUBLAS> 69 struct GRUBlockCellFprop : public GRUCell { 70 GRUBlockCellFprop(const int batch_size, const int input_size, 71 const int cell_size) 72 : GRUCell(batch_size, input_size, cell_size) {} 73 74 void operator()( 75 OpKernelContext* ctx, const Device& d, typename TTypes<T>::ConstMatrix x, 76 typename TTypes<T>::ConstMatrix h_prev, 77 typename TTypes<T>::ConstMatrix w_ru, typename TTypes<T>::ConstMatrix w_c, 78 typename TTypes<T>::ConstVec b_ru, typename TTypes<T>::ConstVec b_c, 79 typename TTypes<T>::Matrix r_u_bar, typename TTypes<T>::Matrix r, 80 typename TTypes<T>::Matrix u, typename TTypes<T>::Matrix c, 81 typename TTypes<T>::Matrix h, typename TTypes<T>::Matrix x_h_prev, 82 typename TTypes<T>::Matrix x_h_prevr) { 83 // Concat x_h_prev = [x, h_prev]. 84 x_h_prev.slice(x_offsets(), x_extends()).device(d) = x; 85 x_h_prev.slice(h_offsets(), h_extends()).device(d) = h_prev; 86 87 // r_u_bar = x_h_prev * w_ru + b_ru 88 typename TTypes<T>::ConstMatrix const_x_h_prev(x_h_prev.data(), 89 x_h_prev.dimensions()); 90 TensorBlasGemm<Device, T, USE_CUBLAS>::compute( 91 ctx, d, false, false, T(1), const_x_h_prev, w_ru, T(0), r_u_bar); 92 93 // Creating a bias matrix for adding by broadcasting 'b_ru' 94 Eigen::array<Eigen::DenseIndex, 2> broadcast_shape({batch_size_, 1}); 95 Eigen::array<Eigen::DenseIndex, 2> b_ru_shape({1, b_ru.dimensions()[0]}); 96 r_u_bar.device(d) += b_ru.reshape(b_ru_shape).broadcast(broadcast_shape); 97 98 // Slice r_u_bar into r, u and apply the sigmoid. 99 r.device(d) = (r_u_bar.slice(ru_r_offset(), cell_extents())).sigmoid(); 100 u.device(d) = (r_u_bar.slice(ru_u_offset(), cell_extents())).sigmoid(); 101 102 // Concat x_h_prevr = [x,h_prev*r] 103 x_h_prevr.slice(x_offsets(), x_extends()).device(d) = x; 104 x_h_prevr.slice(h_offsets(), h_extends()).device(d) = h_prev * r; 105 106 // c = tanh(x_h_prevr*w_c+b_c), Note b_c is broadcasted before adding. 107 typename TTypes<T>::ConstMatrix const_x_h_prevr(x_h_prevr.data(), 108 x_h_prevr.dimensions()); 109 TensorBlasGemm<Device, T, USE_CUBLAS>::compute( 110 ctx, d, false, false, T(1), const_x_h_prevr, w_c, T(0), c); 111 112 Eigen::array<Eigen::DenseIndex, 2> b_c_shape({1, b_c.dimensions()[0]}); 113 c.device(d) += (b_c.reshape(b_c_shape).broadcast(broadcast_shape)); 114 c.device(d) = c.tanh(); 115 116 // h= u*h_prev + (1-u)*c 117 h.device(d) = u * (h_prev - c) + c; 118 } 119 }; 120 121 template <typename Device, typename T, bool USE_CUBLAS> 122 struct GRUBlockCellBprop : public GRUCell { 123 GRUBlockCellBprop(const int batch_size, const int input_size, 124 const int cell_size) 125 : GRUCell(batch_size, input_size, cell_size) {} 126 127 void operator()( 128 OpKernelContext* ctx, const Device& d, typename TTypes<T>::ConstMatrix x, 129 typename TTypes<T>::ConstMatrix h_prev, 130 typename TTypes<T>::ConstMatrix w_ru, typename TTypes<T>::ConstMatrix w_c, 131 typename TTypes<T>::ConstVec b_ru, typename TTypes<T>::ConstVec b_c, 132 typename TTypes<T>::ConstMatrix r, typename TTypes<T>::ConstMatrix u, 133 typename TTypes<T>::ConstMatrix c, typename TTypes<T>::ConstMatrix d_h, 134 typename TTypes<T>::Matrix d_x, typename TTypes<T>::Matrix d_h_prev, 135 typename TTypes<T>::Matrix d_c_bar, 136 typename TTypes<T>::Matrix d_r_bar_u_bar, 137 typename TTypes<T>::Matrix d_r_bar, typename TTypes<T>::Matrix d_u_bar, 138 typename TTypes<T>::Matrix d_hr, 139 typename TTypes<T>::Matrix d_x_comp1_and_h_prev_comp1, 140 typename TTypes<T>::Matrix d_x_comp2_and_h_prevr) { 141 // d_c_bar = d_h*(1-u)*(1-(c*c)) 142 d_c_bar.device(d) = 143 ((d_h * (u.constant(T(1)) - u)) * (c.constant(T(1)) - c * c)); 144 145 // d_u_bar = d_h*(h-c)*(u*(1-u)) 146 d_u_bar.device(d) = d_h * (h_prev - c) * u * (u.constant(T(1)) - u); 147 148 // [2nd_component_of_d_x d_h_prevr] = d_c_bar X w_c^T 149 typename TTypes<T>::ConstMatrix const_d_c_bar(d_c_bar.data(), 150 d_c_bar.dimensions()); 151 TensorBlasGemm<Device, T, USE_CUBLAS>::compute(ctx, d, false, true, T(1), 152 const_d_c_bar, w_c, T(0), 153 d_x_comp2_and_h_prevr); 154 155 d_hr.device(d) = d_x_comp2_and_h_prevr.slice(h_offsets(), h_extends()); 156 d_r_bar.device(d) = (d_hr * h_prev * r) * (r.constant(T(1)) - r); 157 158 // d_r_bar_u_bar = concatenate(d_r_bar, d_u_bar) along axis = 1. 159 d_r_bar_u_bar.slice(ru_r_offset(), cell_extents()).device(d) = d_r_bar; 160 d_r_bar_u_bar.slice(ru_u_offset(), cell_extents()).device(d) = d_u_bar; 161 162 // [1st_component_of_d_x 1st_component_of_d_h_prev] = [d_r_bar d_u_bar] X 163 // w_ru^T 164 typename TTypes<T>::ConstMatrix const_d_r_bar_u_bar( 165 d_r_bar_u_bar.data(), d_r_bar_u_bar.dimensions()); 166 TensorBlasGemm<Device, T, USE_CUBLAS>::compute( 167 ctx, d, false, true, T(1), const_d_r_bar_u_bar, w_ru, T(0), 168 d_x_comp1_and_h_prev_comp1); 169 170 // d_x = d_x_comp1 + d_x_comp2 171 d_x.device(d) = (d_x_comp1_and_h_prev_comp1 + d_x_comp2_and_h_prevr) 172 .slice(x_offsets(), x_extends()); 173 174 // d_h_prev = d_h_comp1 + d_hr*r + d_h*u 175 d_h_prev.device(d) = 176 d_x_comp1_and_h_prev_comp1.slice(h_offsets(), h_extends()) + 177 (d_hr * r) + (d_h * u); 178 } 179 }; 180 181 } // namespace functor 182 } // namespace tensorflow 183 184 #endif // TENSORFLOW_CONTRIB_RNN_KERNELS_GRU_OPS_H_ 185