/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CONTRIB_RNN_KERNELS_LSTM_OPS_H_
#define TENSORFLOW_CONTRIB_RNN_KERNELS_LSTM_OPS_H_

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/contrib/rnn/kernels/blas_gemm.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/eigen_activations.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
class OpKernelContext;

namespace functor {

// Small device-agnostic tensor utilities used by the fused LSTM kernels.
// Each functor evaluates a single Eigen expression on the given device, so
// the same header serves both the CPUDevice and GPUDevice instantiations.

// Fills the flat tensor `t` with zeros on device `d`.
template <typename Device, typename T>
struct TensorZero {
  void operator()(const Device& d, typename TTypes<T>::Flat t) {
    t.device(d) = t.constant(T(0));
  }
};

// Same as TensorZero, but for a tensor whose backing buffer may not meet
// Eigen's alignment requirements (e.g. a slice into a larger buffer).
template <typename Device, typename T>
struct TensorUnalignedZero {
  void operator()(const Device& d, typename TTypes<T>::UnalignedFlat t) {
    t.device(d) = t.constant(T(0));
  }
};

// Copies `src` into `dst` on device `d`; both buffers aligned.
template <typename Device, typename T>
struct TensorCopy {
  void operator()(const Device& d, typename TTypes<T>::ConstFlat src,
                  typename TTypes<T>::Flat dst) {
    dst.device(d) = src;
  }
};

// Copies a possibly unaligned `src` into an aligned `dst`.
template <typename Device, typename T>
struct TensorCopyUnaligned {
  void operator()(const Device& d, typename TTypes<T>::UnalignedConstFlat src,
                  typename TTypes<T>::Flat dst) {
    dst.device(d) = src;
  }
};

// Copies an aligned `src` into a possibly unaligned `dst`.
template <typename Device, typename T>
struct TensorCopyToUnaligned {
  void operator()(const Device& d, typename TTypes<T>::ConstFlat src,
                  typename TTypes<T>::UnalignedFlat dst) {
    dst.device(d) = src;
  }
};

// Element-wise addition on device `d`: c = a + b.
template <typename Device, typename T>
struct TensorAdd {
  void operator()(const Device& d, typename TTypes<T>::ConstFlat a,
                  typename TTypes<T>::ConstFlat b, typename TTypes<T>::Flat c) {
    c.device(d) = a + b;
  }
};

// Zeroes out the rows of `m` for batch entries whose sequence has already
// ended by `time_idx` (i.e. where seq_len[b] <= time_idx), leaving the other
// rows unchanged. `mask` is scratch output holding the per-batch 0/1 keep
// mask. Note: despite the <Device, T> template, this operates on float data
// and int64 lengths as written.
template <typename Device, typename T>
struct TensorZeroPadding {
  void operator()(const Device& d, const int64 time_idx,
                  typename TTypes<int64>::ConstVec seq_len,
                  typename TTypes<float>::Vec mask,
                  typename TTypes<float>::Matrix m) {
    // mask is shape [batch_size]: 1.0 where time_idx < seq_len[b], else 0.0.
    mask.device(d) = seq_len.constant(time_idx) < seq_len;

    // m_shape is [batch_size, 1].
    Eigen::array<Eigen::DenseIndex, 2> m_shape({m.dimensions()[0], 1});
    // broadcast_shape is [1, units].
    Eigen::array<Eigen::DenseIndex, 2> broadcast_shape({1, m.dimensions()[1]});

    // m is shape [batch_size, units]; broadcast the column mask across units.
    m.device(d) = m * mask.reshape(m_shape).broadcast(broadcast_shape);
  }
};

// Shared geometry for the fused LSTM block kernels. Stores the problem sizes
// and provides the offsets/extents used to slice:
//   - the fused gate matrix `icfo` of shape [batch_size, cell_size * 4],
//     laid out gate-by-gate in the order input, cell, forget, output;
//   - the concatenated input matrix `xh` of shape
//     [batch_size, input_size + cell_size], with x first and h second.
struct LSTMBlockCell {
  LSTMBlockCell(const int batch_size, const int input_size, const int cell_size)
      : batch_size_(batch_size),
        input_size_(input_size),
        cell_size_(cell_size) {}

  int batch_size() const { return batch_size_; }

  int input_size() const { return input_size_; }

  int cell_size() const { return cell_size_; }

  // Start of the input-gate columns inside `icfo`.
  inline Eigen::array<Eigen::DenseIndex, 2> icfo_i_offsets() const {
    return {0, 0};
  }

  // Start of the cell-input columns inside `icfo`.
  inline Eigen::array<Eigen::DenseIndex, 2> icfo_c_offsets() const {
    return {0, cell_size_};
  }

  // Start of the forget-gate columns inside `icfo`.
  inline Eigen::array<Eigen::DenseIndex, 2> icfo_f_offsets() const {
    return {0, cell_size_ * 2};
  }

  // Start of the output-gate columns inside `icfo`.
  inline Eigen::array<Eigen::DenseIndex, 2> icfo_o_offsets() const {
    return {0, cell_size_ * 3};
  }

  // Extent of one gate slice: [batch_size, cell_size].
  inline Eigen::array<Eigen::DenseIndex, 2> cell_extents() const {
    return {batch_size_, cell_size_};
  }

  // Offset/extent of the x part of `xh`: columns [0, input_size).
  inline Eigen::array<Eigen::DenseIndex, 2> xh_x_offsets() const {
    return {0, 0};
  }

  inline Eigen::array<Eigen::DenseIndex, 2> xh_x_extents() const {
    return {batch_size_, input_size_};
  }

  // Offset/extent of the h part of `xh`: columns [input_size,
  // input_size + cell_size).
  inline Eigen::array<Eigen::DenseIndex, 2> xh_h_offsets() const {
    return {0, input_size_};
  }

  inline Eigen::array<Eigen::DenseIndex, 2> xh_h_extents() const {
    return {batch_size_, cell_size_};
  }

 protected:
  const int batch_size_;
  const int input_size_;
  const int cell_size_;
};

// See lstm_ops.cc for CPUDevice implementation and lstm_ops_gpu.cu.cc for
// GPUDevice implementation.
template <typename Device, typename T, bool USE_CUBLAS>
struct LSTMBlockCellFprop : public LSTMBlockCell {
  LSTMBlockCellFprop(const int batch_size, const int input_size,
                     const int cell_size)
      : LSTMBlockCell(batch_size, input_size, cell_size) {}

  // Forward pass of a single fused LSTM cell step (declaration only; the
  // per-device definitions live in lstm_ops.cc / lstm_ops_gpu.cu.cc).
  //
  // Inputs (const):
  //   forget_bias: added to the forget gate pre-activation.
  //   cell_clip:   clipping bound applied to the cell state — presumably
  //                disabled when <= 0; confirm against the .cc definition.
  //   use_peephole: whether the wci/wcf/wco peephole weights are used.
  //   x:       input at this time step.
  //   cs_prev: previous cell state; h_prev: previous output.
  //   w, b:    fused gate weights and bias; wci/wcf/wco: peephole weights.
  // Outputs / scratch (non-const):
  //   xh:   scratch for the concatenated [x, h_prev] matrix.
  //   icfo: scratch for the fused gate pre-activations (see LSTMBlockCell
  //         for the i/c/f/o column layout).
  //   i, cs, f, o, ci, co: per-gate activations and cell states.
  //   h:    output of this step.
  void operator()(
      OpKernelContext* ctx, const Device& d, const T forget_bias,
      const T cell_clip, bool use_peephole, typename TTypes<T>::ConstMatrix x,
      typename TTypes<T>::ConstMatrix cs_prev,
      typename TTypes<T>::ConstMatrix h_prev, typename TTypes<T>::ConstMatrix w,
      typename TTypes<T>::ConstVec wci, typename TTypes<T>::ConstVec wcf,
      typename TTypes<T>::ConstVec wco, typename TTypes<T>::ConstVec b,
      typename TTypes<T>::Matrix xh, typename TTypes<T>::Matrix i,
      typename TTypes<T>::Matrix cs, typename TTypes<T>::Matrix f,
      typename TTypes<T>::Matrix o, typename TTypes<T>::Matrix ci,
      typename TTypes<T>::Matrix co, typename TTypes<T>::Matrix icfo,
      typename TTypes<T>::Matrix h);
};

// See lstm_ops.cc for CPUDevice implementation and lstm_ops_gpu.cu.cc for
// GPUDevice implementation.
template <typename Device, typename T, bool USE_CUBLAS>
struct LSTMBlockCellBprop : public LSTMBlockCell {
  LSTMBlockCellBprop(const int batch_size, const int input_size,
                     const int cell_size)
      : LSTMBlockCell(batch_size, input_size, cell_size) {}

  // Backward pass of a single fused LSTM cell step (declaration only; the
  // per-device definitions live in lstm_ops.cc / lstm_ops_gpu.cu.cc).
  // Takes the forward-pass activations (i, cs, f, o, ci, co) plus the
  // incoming gradients cs_grad/h_grad, and produces the gate gradients
  // (do_, dcs, dci, df, di), the fused gate gradient `dicfo`, the previous
  // cell-state gradient, and the peephole weight gradients.
  void operator()(
      OpKernelContext* ctx, const Device& d, bool use_peephole,
      typename TTypes<T>::ConstMatrix x,
      typename TTypes<T>::ConstMatrix cs_prev,
      typename TTypes<T>::ConstMatrix h_prev, typename TTypes<T>::ConstMatrix w,
      typename TTypes<T>::ConstVec wci, typename TTypes<T>::ConstVec wcf,
      typename TTypes<T>::ConstVec wco, typename TTypes<T>::ConstVec b,
      typename TTypes<T>::ConstMatrix i, typename TTypes<T>::ConstMatrix cs,
      typename TTypes<T>::ConstMatrix f, typename TTypes<T>::ConstMatrix o,
      typename TTypes<T>::ConstMatrix ci, typename TTypes<T>::ConstMatrix co,
      typename TTypes<T>::ConstMatrix cs_grad,
      typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_,
      typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,
      typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,
      typename TTypes<T>::Matrix dicfo, typename TTypes<T>::Matrix cs_prev_grad,
      typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad,
      typename TTypes<T>::Vec wco_grad);
};

// Backward pass of one time step of the fused BlockLSTM op, defined inline
// so it instantiates for both devices. In addition to the cell-level
// gradients it also accumulates the weight gradients: x_grad, h_prev_grad,
// w_grad, b_grad and (optionally) the peephole gradients. w_grad and b_grad
// are accumulated (+= / beta=1 GEMM) so the caller can sum across time
// steps; they must be zero-initialized before the first step.
//
// Shapes, per the extents in LSTMBlockCell and the GEMM calls below:
//   x: [batch_size, input_size], h_prev/cs_prev: [batch_size, cell_size],
//   xh: [batch_size, input_size + cell_size],
//   dicfo/icfo: [batch_size, cell_size * 4],
//   w: [input_size + cell_size, cell_size * 4].
template <typename Device, typename T, bool USE_CUBLAS>
struct BlockLSTMBprop : public LSTMBlockCell {
  BlockLSTMBprop(const int batch_size, const int input_size,
                 const int cell_size)
      : LSTMBlockCell(batch_size, input_size, cell_size) {}

  void operator()(
      OpKernelContext* ctx, const Device& d, bool use_peephole,
      typename TTypes<T>::ConstMatrix x,
      typename TTypes<T>::ConstMatrix cs_prev,
      typename TTypes<T>::ConstMatrix h_prev, typename TTypes<T>::ConstMatrix w,
      typename TTypes<T>::ConstVec wci, typename TTypes<T>::ConstVec wcf,
      typename TTypes<T>::ConstVec wco, typename TTypes<T>::ConstVec b,
      typename TTypes<T>::Matrix xh, typename TTypes<T>::ConstMatrix i,
      typename TTypes<T>::ConstMatrix cs, typename TTypes<T>::ConstMatrix f,
      typename TTypes<T>::ConstMatrix o, typename TTypes<T>::ConstMatrix ci,
      typename TTypes<T>::ConstMatrix co,
      typename TTypes<T>::ConstMatrix cs_grad,
      typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_,
      typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,
      typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,
      typename TTypes<T>::Matrix dicfo, typename TTypes<T>::Matrix cs_prev_grad,
      typename TTypes<T>::Matrix h_prev_grad,
      typename TTypes<T>::Matrix xh_grad, typename TTypes<T>::Matrix x_grad,
      typename TTypes<T>::Matrix w_grad, typename TTypes<T>::Vec wci_grad,
      typename TTypes<T>::Vec wcf_grad, typename TTypes<T>::Vec wco_grad,
      typename TTypes<T>::Vec b_grad) {
    // do[t] = sigm'(o[t]) .* dh[t] .* co[t]
    // sigm'(o) is written in terms of the activation: o * (1 - o).
    do_.device(d) = o * (o.constant(T(1)) - o) * h_grad * co;

    // dcs[t] += tanh'(cs[t]) .* dh[t] .* o[t] + dcs[t + 1] .* f[t + 1]
    // tanh'(cs) = 1 - tanh(cs)^2 = 1 - co^2, since co = tanh(cs).
    // cs_grad carries the dcs[t + 1] .* f[t + 1] contribution from the
    // next time step.
    dcs.device(d) = (co.constant(T(1)) - co * co) * h_grad * o + cs_grad;

    // Broadcast shapes used to apply the [cell_size] peephole vectors to
    // [batch_size, cell_size] matrices; p_shape is reused below for
    // cs_prev_grad.
    Eigen::array<Eigen::DenseIndex, 2> p_shape({1, cell_size_});
    Eigen::array<Eigen::DenseIndex, 2> p_broadcast_shape({batch_size_, 1});
    if (use_peephole) {
      dcs.device(d) =
          dcs + do_ * wco.reshape(p_shape).broadcast(p_broadcast_shape);
    }

    // dci[t] = tanh'(ci[t]) dcs[t] i[t]
    dci.device(d) = (ci.constant(T(1)) - ci * ci) * dcs * i;

    // df[t] = sigm'(f[t]) dcs[t] cs[t - 1]
    df.device(d) = f * (f.constant(T(1)) - f) * dcs * cs_prev;

    // di[t] = sigm'(i[t]) dcs[t] ci[t]
    di.device(d) = i * (i.constant(T(1)) - i) * dcs * ci;

    // Pack the per-gate gradients into the fused [batch, 4 * cell] matrix
    // in the same i/c/f/o column order used by the forward pass.
    dicfo.slice(icfo_i_offsets(), cell_extents()).device(d) = di;
    dicfo.slice(icfo_c_offsets(), cell_extents()).device(d) = dci;
    dicfo.slice(icfo_f_offsets(), cell_extents()).device(d) = df;
    dicfo.slice(icfo_o_offsets(), cell_extents()).device(d) = do_;

    cs_prev_grad.device(d) = dcs * f;
    if (use_peephole) {
      cs_prev_grad.device(d) =
          cs_prev_grad +
          di * wci.reshape(p_shape).broadcast(p_broadcast_shape) +
          df * wcf.reshape(p_shape).broadcast(p_broadcast_shape);
    }

    // xh_grad = dicfo * w^T. const_dicfo aliases dicfo's buffer, so the
    // slice writes above must complete before this GEMM reads it.
    typename TTypes<T>::ConstMatrix const_dicfo(dicfo.data(),
                                                dicfo.dimensions());
    TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
        ctx, d, false, true, T(1), const_dicfo, w, T(0), xh_grad);

    // Rebuild xh = [x, h_prev] for the weight-gradient GEMM below.
    xh.slice(xh_x_offsets(), xh_x_extents()).device(d) = x;
    xh.slice(xh_h_offsets(), xh_h_extents()).device(d) = h_prev;
    typename TTypes<T>::ConstMatrix const_xh(xh.data(), xh.dimensions());

    // Split xh_grad back into its x and h_prev components.
    x_grad.device(d) = xh_grad.slice(xh_x_offsets(), xh_x_extents());
    h_prev_grad.device(d) = xh_grad.slice(xh_h_offsets(), xh_h_extents());

    // w_grad += xh^T * dicfo (beta = 1: accumulate across time steps).
    TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
        ctx, d, true, false, T(1), const_xh, const_dicfo, T(1), w_grad);

    // b_grad += column sums of dicfo (reduce over the batch dimension).
    b_grad.device(d) += dicfo.sum(Eigen::array<int, 1>({0}));

    // Peephole weight gradients, also accumulated over time steps.
    if (use_peephole) {
      wci_grad.device(d) += (di * cs_prev).sum(Eigen::array<int, 1>({0}));
      wcf_grad.device(d) += (df * cs_prev).sum(Eigen::array<int, 1>({0}));
      wco_grad.device(d) += (do_ * cs).sum(Eigen::array<int, 1>({0}));
    }
  }
};

}  // namespace functor
}  // namespace tensorflow

#endif  // TENSORFLOW_CONTRIB_RNN_KERNELS_LSTM_OPS_H_