/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CONTRIB_RNN_KERNELS_GRU_OPS_H_
#define TENSORFLOW_CONTRIB_RNN_KERNELS_GRU_OPS_H_

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/contrib/rnn/kernels/blas_gemm.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

class OpKernelContext;

namespace functor {

struct GRUCell {
  GRUCell(const int batch_size, const int input_size, const int cell_size)
      : batch_size_(batch_size),
        input_size_(input_size),
        cell_size_(cell_size) {}

  inline Eigen::array<Eigen::DenseIndex, 2> x_offsets() const { return {0, 0}; }

  inline Eigen::array<Eigen::DenseIndex, 2> x_extends() const {
    return {batch_size_, input_size_};
  }

  inline Eigen::array<Eigen::DenseIndex, 2> h_offsets() const {
    return {0, input_size_};
  }

  inline Eigen::array<Eigen::DenseIndex, 2> h_extends() const {
    return {batch_size_, cell_size_};
  }

  inline Eigen::array<Eigen::DenseIndex, 2> ru_r_offset() const {
    return {0, 0};
  }

  inline Eigen::array<Eigen::DenseIndex, 2> ru_u_offset() const {
    return {0, cell_size_};
  }

  inline Eigen::array<Eigen::DenseIndex, 2> cell_extents() const {
    return {batch_size_, cell_size_};
  }

 protected:
  const int batch_size_;
  const int input_size_;
  const int cell_size_;
};

template <typename Device, typename T, bool USE_CUBLAS>
struct GRUBlockCellFprop : public GRUCell {
  GRUBlockCellFprop(const int batch_size, const int input_size,
                    const int cell_size)
      : GRUCell(batch_size, input_size, cell_size) {}

  void operator()(
      OpKernelContext* ctx, const Device& d, typename TTypes<T>::ConstMatrix x,
      typename TTypes<T>::ConstMatrix h_prev,
      typename TTypes<T>::ConstMatrix w_ru, typename TTypes<T>::ConstMatrix w_c,
      typename TTypes<T>::ConstVec b_ru, typename TTypes<T>::ConstVec b_c,
      typename TTypes<T>::Matrix r_u_bar, typename TTypes<T>::Matrix r,
      typename TTypes<T>::Matrix u, typename TTypes<T>::Matrix c,
      typename TTypes<T>::Matrix h, typename TTypes<T>::Matrix x_h_prev,
      typename TTypes<T>::Matrix x_h_prevr) {
    // Concat x_h_prev = [x, h_prev].
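    // x_h_prev has shape [batch_size, input_size + cell_size]: the first
    // input_size columns hold x and the remaining cell_size columns hold
    // h_prev, so a single GEMM against w_ru produces both contributions to
    // the gate pre-activations.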
    x_h_prev.slice(x_offsets(), x_extends()).device(d) = x;
    x_h_prev.slice(h_offsets(), h_extends()).device(d) = h_prev;

    // r_u_bar = x_h_prev * w_ru + b_ru
    typename TTypes<T>::ConstMatrix const_x_h_prev(x_h_prev.data(),
                                                   x_h_prev.dimensions());
    TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
        ctx, d, false, false, typename gemm_compute_type<T>::type(1.f),
        const_x_h_prev, w_ru, typename gemm_compute_type<T>::type(0.f),
        r_u_bar);

    // Create a bias matrix for the addition by broadcasting 'b_ru'.
    Eigen::array<Eigen::DenseIndex, 2> broadcast_shape({batch_size_, 1});
    Eigen::array<Eigen::DenseIndex, 2> b_ru_shape({1, b_ru.dimensions()[0]});
    r_u_bar.device(d) += b_ru.reshape(b_ru_shape).broadcast(broadcast_shape);

    // Slice r_u_bar into r and u, and apply the sigmoid.
    r.device(d) = (r_u_bar.slice(ru_r_offset(), cell_extents())).sigmoid();
    u.device(d) = (r_u_bar.slice(ru_u_offset(), cell_extents())).sigmoid();

    // Concat x_h_prevr = [x, h_prev * r].
    x_h_prevr.slice(x_offsets(), x_extends()).device(d) = x;
    x_h_prevr.slice(h_offsets(), h_extends()).device(d) = h_prev * r;

    // c = tanh(x_h_prevr * w_c + b_c). Note that b_c is broadcast before adding.
    typename TTypes<T>::ConstMatrix const_x_h_prevr(x_h_prevr.data(),
                                                    x_h_prevr.dimensions());
    TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
        ctx, d, false, false, typename gemm_compute_type<T>::type(1.f),
        const_x_h_prevr, w_c, typename gemm_compute_type<T>::type(0.f), c);

    Eigen::array<Eigen::DenseIndex, 2> b_c_shape({1, b_c.dimensions()[0]});
    c.device(d) += (b_c.reshape(b_c_shape).broadcast(broadcast_shape));
    c.device(d) = c.tanh();

    // h = u * h_prev + (1 - u) * c
    h.device(d) = u * (h_prev - c) + c;
  }
};

template <typename Device, typename T, bool USE_CUBLAS>
struct GRUBlockCellBprop : public GRUCell {
  GRUBlockCellBprop(const int batch_size, const int input_size,
                    const int cell_size)
      : GRUCell(batch_size, input_size, cell_size) {}

  void operator()(
      OpKernelContext* ctx, const Device& d, typename TTypes<T>::ConstMatrix x,
      typename TTypes<T>::ConstMatrix h_prev,
      typename TTypes<T>::ConstMatrix w_ru, typename TTypes<T>::ConstMatrix w_c,
      typename TTypes<T>::ConstVec b_ru, typename TTypes<T>::ConstVec b_c,
      typename TTypes<T>::ConstMatrix r, typename TTypes<T>::ConstMatrix u,
      typename TTypes<T>::ConstMatrix c, typename TTypes<T>::ConstMatrix d_h,
      typename TTypes<T>::Matrix d_x, typename TTypes<T>::Matrix d_h_prev,
      typename TTypes<T>::Matrix d_c_bar,
      typename TTypes<T>::Matrix d_r_bar_u_bar,
      typename TTypes<T>::Matrix d_r_bar, typename TTypes<T>::Matrix d_u_bar,
      typename TTypes<T>::Matrix d_hr,
      typename TTypes<T>::Matrix d_x_comp1_and_h_prev_comp1,
      typename TTypes<T>::Matrix d_x_comp2_and_h_prevr) {
    // d_c_bar = d_h * (1 - u) * (1 - c * c)
    d_c_bar.device(d) =
        ((d_h * (u.constant(T(1)) - u)) * (c.constant(T(1)) - c * c));

    // d_u_bar = d_h * (h_prev - c) * u * (1 - u)
    d_u_bar.device(d) = d_h * (h_prev - c) * u * (u.constant(T(1)) - u);

    // [2nd component of d_x, d_h_prevr] = d_c_bar X w_c^T
    typename TTypes<T>::ConstMatrix const_d_c_bar(d_c_bar.data(),
                                                  d_c_bar.dimensions());
    TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
        ctx, d, false, true, typename gemm_compute_type<T>::type(1.f),
        const_d_c_bar, w_c, typename gemm_compute_type<T>::type(0.f),
        d_x_comp2_and_h_prevr);

    d_hr.device(d) = d_x_comp2_and_h_prevr.slice(h_offsets(), h_extends());

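    // d_hr is the slice of the GEMM output that lines up with (h_prev * r);
    // backpropagating through the reset gate's sigmoid gives
    // d_r_bar = d_hr * h_prev * r * (1 - r).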
    d_r_bar.device(d) = (d_hr * h_prev * r) * (r.constant(T(1)) - r);

    // d_r_bar_u_bar = concatenate(d_r_bar, d_u_bar) along axis = 1.
    d_r_bar_u_bar.slice(ru_r_offset(), cell_extents()).device(d) = d_r_bar;
    d_r_bar_u_bar.slice(ru_u_offset(), cell_extents()).device(d) = d_u_bar;

    // [1st component of d_x, 1st component of d_h_prev] = [d_r_bar d_u_bar] X
    // w_ru^T
    typename TTypes<T>::ConstMatrix const_d_r_bar_u_bar(
        d_r_bar_u_bar.data(), d_r_bar_u_bar.dimensions());
    TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
        ctx, d, false, true, typename gemm_compute_type<T>::type(1.f),
        const_d_r_bar_u_bar, w_ru, typename gemm_compute_type<T>::type(0.f),
        d_x_comp1_and_h_prev_comp1);

    // d_x = d_x_comp1 + d_x_comp2
    d_x.device(d) = (d_x_comp1_and_h_prev_comp1 + d_x_comp2_and_h_prevr)
                        .slice(x_offsets(), x_extends());

    // d_h_prev = d_h_comp1 + d_hr * r + d_h * u
    d_h_prev.device(d) =
        d_x_comp1_and_h_prev_comp1.slice(h_offsets(), h_extends()) +
        (d_hr * r) + (d_h * u);
  }
};

}  // namespace functor
}  // namespace tensorflow

#endif  // TENSORFLOW_CONTRIB_RNN_KERNELS_GRU_OPS_H_
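// Illustrative sketch of how the forward functor might be invoked from an
// OpKernel's Compute(). Assumptions: CPUDevice aliases Eigen::ThreadPoolDevice,
// T = float, USE_CUBLAS = false, and the *_tensor names are placeholders for
// tensors already allocated with the shapes used above.
//
//   functor::GRUBlockCellFprop<CPUDevice, float, /*USE_CUBLAS=*/false>(
//       batch_size, input_size, cell_size)(
//       ctx, ctx->eigen_device<CPUDevice>(), x_tensor->matrix<float>(),
//       h_prev_tensor->matrix<float>(), w_ru_tensor->matrix<float>(),
//       w_c_tensor->matrix<float>(), b_ru_tensor->vec<float>(),
//       b_c_tensor->vec<float>(), r_u_bar_tensor->matrix<float>(),
//       r_tensor->matrix<float>(), u_tensor->matrix<float>(),
//       c_tensor->matrix<float>(), h_tensor->matrix<float>(),
//       x_h_prev_tensor->matrix<float>(), x_h_prevr_tensor->matrix<float>());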