/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CONTRIB_RNN_KERNELS_LSTM_OPS_H_
#define TENSORFLOW_CONTRIB_RNN_KERNELS_LSTM_OPS_H_

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/contrib/rnn/kernels/blas_gemm.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/eigen_activations.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
class OpKernelContext;

namespace functor {

// Fills the aligned flat tensor `t` with zeros on device `d`.
template <typename Device, typename T>
struct TensorZero {
  void operator()(const Device& d, typename TTypes<T>::Flat t) {
    t.device(d) = t.constant(T(0));
  }
};

// Same as TensorZero, but for a tensor whose backing memory is not
// guaranteed to be aligned (e.g. a slice into a larger buffer).
template <typename Device, typename T>
struct TensorUnalignedZero {
  void operator()(const Device& d, typename TTypes<T>::UnalignedFlat t) {
    t.device(d) = t.constant(T(0));
  }
};

// Copies `src` into `dst` on device `d`; both operands are aligned.
template <typename Device, typename T>
struct TensorCopy {
  void operator()(const Device& d, typename TTypes<T>::ConstFlat src,
                  typename TTypes<T>::Flat dst) {
    dst.device(d) = src;
  }
};

// Copy where only the source may be unaligned.
template <typename Device, typename T>
struct TensorCopyUnaligned {
  void operator()(const Device& d, typename TTypes<T>::UnalignedConstFlat src,
                  typename TTypes<T>::Flat dst) {
    dst.device(d) = src;
  }
};

// Copy where only the destination may be unaligned.
template <typename Device, typename T>
struct TensorCopyToUnaligned {
  void operator()(const Device& d, typename TTypes<T>::ConstFlat src,
                  typename TTypes<T>::UnalignedFlat dst) {
    dst.device(d) = src;
  }
};

// Element-wise sum: c = a + b, evaluated on device `d`.
template <typename Device, typename T>
struct TensorAdd {
  void operator()(const Device& d, typename TTypes<T>::ConstFlat a,
                  typename TTypes<T>::ConstFlat b, typename TTypes<T>::Flat c) {
    c.device(d) = a + b;
  }
};

// Zeroes out the rows of `m` that correspond to batch entries whose
// sequence has already ended at `time_idx`, and records the per-batch
// validity mask in `mask` (1 while time_idx < seq_len[b], 0 afterwards).
template <typename Device, typename T>
struct TensorZeroPadding {
  void operator()(const Device& d, const int64 time_idx,
                  typename TTypes<int64>::ConstVec seq_len,
                  typename TTypes<T>::Vec mask, typename TTypes<T>::Matrix m) {
    // mask is shape [batch_size].
    mask.device(d) = seq_len.constant(time_idx) < seq_len;

    // m_shape is [batch_size, 1].
    Eigen::array<Eigen::DenseIndex, 2> m_shape({m.dimensions()[0], 1});
    // broadcast_shape is [1, units].
    Eigen::array<Eigen::DenseIndex, 2> broadcast_shape({1, m.dimensions()[1]});

    // m is shape [batch_size, units]; broadcasting the [batch_size, 1]
    // mask across the units dimension zeroes whole rows at once.
    m.device(d) = m * mask.reshape(m_shape).broadcast(broadcast_shape);
  }
};

// Shared geometry for the LSTM block kernels: holds the problem sizes and
// provides the offsets/extents used to slice the fused gate matrix `icfo`
// (gates concatenated along the last dim in the order i, c, f, o, so its
// shape is [batch_size, cell_size * 4]) and the fused input matrix `xh`
// ([x, h] concatenated along the last dim, shape
// [batch_size, input_size + cell_size]).
struct LSTMBlockCell {
  LSTMBlockCell(const int batch_size, const int input_size,
                const int cell_size)
      : batch_size_(batch_size),
        input_size_(input_size),
        cell_size_(cell_size) {}

  int batch_size() const { return batch_size_; }

  int input_size() const { return input_size_; }

  int cell_size() const { return cell_size_; }

  // Offset of the input (i) gate columns within icfo.
  inline Eigen::array<Eigen::DenseIndex, 2> icfo_i_offsets() const {
    return {0, 0};
  }

  // Offset of the cell-candidate (c) gate columns within icfo.
  inline Eigen::array<Eigen::DenseIndex, 2> icfo_c_offsets() const {
    return {0, cell_size_};
  }

  // Offset of the forget (f) gate columns within icfo.
  inline Eigen::array<Eigen::DenseIndex, 2> icfo_f_offsets() const {
    return {0, cell_size_ * 2};
  }

  // Offset of the output (o) gate columns within icfo.
  inline Eigen::array<Eigen::DenseIndex, 2> icfo_o_offsets() const {
    return {0, cell_size_ * 3};
  }

  // Extent of one gate slice: [batch_size, cell_size].
  inline Eigen::array<Eigen::DenseIndex, 2> cell_extents() const {
    return {batch_size_, cell_size_};
  }

  // Offset/extent of the x portion of xh.
  inline Eigen::array<Eigen::DenseIndex, 2> xh_x_offsets() const {
    return {0, 0};
  }

  inline Eigen::array<Eigen::DenseIndex, 2> xh_x_extents() const {
    return {batch_size_, input_size_};
  }

  // Offset/extent of the h portion of xh (starts after the input columns).
  inline Eigen::array<Eigen::DenseIndex, 2> xh_h_offsets() const {
    return {0, input_size_};
  }

  inline Eigen::array<Eigen::DenseIndex, 2> xh_h_extents() const {
    return {batch_size_, cell_size_};
  }

 protected:
  const int batch_size_;
  const int input_size_;
  const int cell_size_;
};

// See lstm_ops.cc for CPUDevice implementation and lstm_ops_gpu.cu.cc for
// GPUDevice implementation.
// Forward pass for a single LSTM time step. Declaration only; the
// per-device implementations live in lstm_ops.cc (CPU) and
// lstm_ops_gpu.cu.cc (GPU), selected via the Device/USE_CUBLAS template
// parameters.
template <typename Device, typename T, bool USE_CUBLAS>
struct LSTMBlockCellFprop : public LSTMBlockCell {
  LSTMBlockCellFprop(const int batch_size, const int input_size,
                     const int cell_size)
      : LSTMBlockCell(batch_size, input_size, cell_size) {}

  // Inputs: x (current input), cs_prev/h_prev (previous cell/hidden state),
  // w (fused gate weights), wci/wcf/wco (peephole weights, used only when
  // use_peephole is true), b (fused gate bias).
  // Scratch/outputs: xh ([x, h_prev] concat buffer), i/cs/f/o/ci/co
  // (per-gate activations and cell states), icfo (fused pre-activation gate
  // matrix), h (new hidden state).
  // forget_bias is added to the forget gate pre-activation; cell_clip,
  // when positive, presumably bounds cs — confirm in the .cc implementation.
  void operator()(OpKernelContext* ctx, const Device& d,
                  const float forget_bias, const float cell_clip,
                  bool use_peephole, typename TTypes<T>::ConstMatrix x,
                  typename TTypes<T>::ConstMatrix cs_prev,
                  typename TTypes<T>::ConstMatrix h_prev,
                  typename TTypes<T>::ConstMatrix w,
                  typename TTypes<T>::ConstVec wci,
                  typename TTypes<T>::ConstVec wcf,
                  typename TTypes<T>::ConstVec wco,
                  typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh,
                  typename TTypes<T>::Matrix i, typename TTypes<T>::Matrix cs,
                  typename TTypes<T>::Matrix f, typename TTypes<T>::Matrix o,
                  typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co,
                  typename TTypes<T>::Matrix icfo,
                  typename TTypes<T>::Matrix h);
};

// See lstm_ops.cc for CPUDevice implementation and lstm_ops_gpu.cu.cc for
// GPUDevice implementation.
// Backward pass for a single LSTM time step (single-cell op). Declaration
// only; per-device implementations are in lstm_ops.cc / lstm_ops_gpu.cu.cc.
template <typename Device, typename T, bool USE_CUBLAS>
struct LSTMBlockCellBprop : public LSTMBlockCell {
  LSTMBlockCellBprop(const int batch_size, const int input_size,
                     const int cell_size)
      : LSTMBlockCell(batch_size, input_size, cell_size) {}

  // Inputs mirror the forward pass plus the saved activations
  // (i/cs/f/o/ci/co) and the incoming gradients cs_grad/h_grad.
  // Outputs: per-gate gradients (do_/dcs/dci/df/di), the fused gate
  // gradient dicfo, cs_prev_grad, and the peephole weight gradients.
  void operator()(
      OpKernelContext* ctx, const Device& d, bool use_peephole,
      typename TTypes<T>::ConstMatrix x,
      typename TTypes<T>::ConstMatrix cs_prev,
      typename TTypes<T>::ConstMatrix h_prev, typename TTypes<T>::ConstMatrix w,
      typename TTypes<T>::ConstVec wci, typename TTypes<T>::ConstVec wcf,
      typename TTypes<T>::ConstVec wco, typename TTypes<T>::ConstVec b,
      typename TTypes<T>::ConstMatrix i, typename TTypes<T>::ConstMatrix cs,
      typename TTypes<T>::ConstMatrix f, typename TTypes<T>::ConstMatrix o,
      typename TTypes<T>::ConstMatrix ci, typename TTypes<T>::ConstMatrix co,
      typename TTypes<T>::ConstMatrix cs_grad,
      typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_,
      typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,
      typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,
      typename TTypes<T>::Matrix dicfo, typename TTypes<T>::Matrix cs_prev_grad,
      typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad,
      typename TTypes<T>::Vec wco_grad);
};

// Backward pass used by the fused whole-sequence BlockLSTM op. Unlike
// LSTMBlockCellBprop this version also produces x_grad, h_prev_grad and the
// weight/bias gradients (w_grad, b_grad, peepholes), which are accumulated
// across time steps via `+=` — the caller presumably zero-initializes the
// accumulators before the first step (see the TensorZero functors above);
// confirm against the op implementation.
template <typename Device, typename T, bool USE_CUBLAS>
struct BlockLSTMBprop : public LSTMBlockCell {
  BlockLSTMBprop(const int batch_size, const int input_size,
                 const int cell_size)
      : LSTMBlockCell(batch_size, input_size, cell_size) {}

  void operator()(
      OpKernelContext* ctx, const Device& d, bool use_peephole,
      typename TTypes<T>::ConstMatrix x,
      typename TTypes<T>::ConstMatrix cs_prev,
      typename TTypes<T>::ConstMatrix h_prev, typename TTypes<T>::ConstMatrix w,
      typename TTypes<T>::ConstVec wci, typename TTypes<T>::ConstVec wcf,
      typename TTypes<T>::ConstVec wco, typename TTypes<T>::ConstVec b,
      typename TTypes<T>::Matrix xh, typename TTypes<T>::ConstMatrix i,
      typename TTypes<T>::ConstMatrix cs, typename TTypes<T>::ConstMatrix f,
      typename TTypes<T>::ConstMatrix o, typename TTypes<T>::ConstMatrix ci,
      typename TTypes<T>::ConstMatrix co,
      typename TTypes<T>::ConstMatrix cs_grad,
      typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_,
      typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,
      typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,
      typename TTypes<T>::Matrix dicfo, typename TTypes<T>::Matrix cs_prev_grad,
      typename TTypes<T>::Matrix h_prev_grad,
      typename TTypes<T>::Matrix xh_grad, typename TTypes<T>::Matrix x_grad,
      typename TTypes<T>::Matrix w_grad, typename TTypes<T>::Vec wci_grad,
      typename TTypes<T>::Vec wcf_grad, typename TTypes<T>::Vec wco_grad,
      typename TTypes<T>::Vec b_grad) {
    // do[t] = sigm'(o[t]) .* dh[t] .* co[t]
    // (sigmoid derivative written in terms of the saved activation: o(1-o)).
    do_.device(d) = o * (o.constant(T(1)) - o) * h_grad * co;

    // dcs[t] = tanh'(cs[t]) .* dh[t] .* o[t] + dcs[t + 1] .* f[t + 1];
    // the second term arrives pre-folded into cs_grad from the caller.
    dcs.device(d) = (co.constant(T(1)) - co * co) * h_grad * o + cs_grad;

    // Peephole weights are [cell_size]; reshape to [1, cell_size] and
    // broadcast over the batch dimension to combine with [batch, cell] terms.
    Eigen::array<Eigen::DenseIndex, 2> p_shape({1, cell_size_});
    Eigen::array<Eigen::DenseIndex, 2> p_broadcast_shape({batch_size_, 1});
    if (use_peephole) {
      dcs.device(d) =
          dcs + do_ * wco.reshape(p_shape).broadcast(p_broadcast_shape);
    }

    // dci[t] = tanh'(ci[t]) dcs[t] i[t]
    dci.device(d) = (ci.constant(T(1)) - ci * ci) * dcs * i;

    // df[t] = sigm'(f[t]) dcs[t] cs[t - 1]
    df.device(d) = f * (f.constant(T(1)) - f) * dcs * cs_prev;

    // di[t] = sigm'(i[t]) dcs[t] ci[t]
    di.device(d) = i * (i.constant(T(1)) - i) * dcs * ci;

    // Pack the four gate gradients into the fused matrix in icfo order,
    // matching the forward layout (see LSTMBlockCell offsets).
    dicfo.slice(icfo_i_offsets(), cell_extents()).device(d) = di;
    dicfo.slice(icfo_c_offsets(), cell_extents()).device(d) = dci;
    dicfo.slice(icfo_f_offsets(), cell_extents()).device(d) = df;
    dicfo.slice(icfo_o_offsets(), cell_extents()).device(d) = do_;

    // Gradient flowing to the previous cell state through the forget gate,
    // plus peephole contributions when enabled.
    cs_prev_grad.device(d) = dcs * f;
    if (use_peephole) {
      cs_prev_grad.device(d) =
          cs_prev_grad +
          di * wci.reshape(p_shape).broadcast(p_broadcast_shape) +
          df * wcf.reshape(p_shape).broadcast(p_broadcast_shape);
    }

    // xh_grad = dicfo * w^T  (GEMM with transpose_b = true; dicfo must be
    // fully written above before this read-only view is consumed).
    typename TTypes<T>::ConstMatrix const_dicfo(dicfo.data(),
                                                dicfo.dimensions());
    TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
        ctx, d, false, true, 1.f, const_dicfo, w, 0.f, xh_grad);

    // Rebuild xh = [x, h_prev] (needed below for w_grad).
    xh.slice(xh_x_offsets(), xh_x_extents()).device(d) = x;
    xh.slice(xh_h_offsets(), xh_h_extents()).device(d) = h_prev;
    typename TTypes<T>::ConstMatrix const_xh(xh.data(), xh.dimensions());

    // Split xh_grad back into its x and h components.
    x_grad.device(d) = xh_grad.slice(xh_x_offsets(), xh_x_extents());
    h_prev_grad.device(d) = xh_grad.slice(xh_h_offsets(), xh_h_extents());

    // w_grad += xh^T * dicfo  (transpose_a = true, beta = 1 accumulates
    // across time steps).
    TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
        ctx, d, true, false, 1.f, const_xh, const_dicfo, 1.f, w_grad);

    // b_grad += column sums of dicfo (reduce over the batch dimension).
    b_grad.device(d) += dicfo.sum(Eigen::array<int, 1>({0}));

    // Peephole weight gradients, likewise reduced over the batch dimension.
    if (use_peephole) {
      wci_grad.device(d) += (di * cs_prev).sum(Eigen::array<int, 1>({0}));
      wcf_grad.device(d) += (df * cs_prev).sum(Eigen::array<int, 1>({0}));
      wco_grad.device(d) += (do_ * cs).sum(Eigen::array<int, 1>({0}));
    }
  }
};

}  // namespace functor
}  // namespace tensorflow

#endif  // TENSORFLOW_CONTRIB_RNN_KERNELS_LSTM_OPS_H_