/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CONTRIB_RNN_KERNELS_LSTM_OPS_H_
#define TENSORFLOW_CONTRIB_RNN_KERNELS_LSTM_OPS_H_

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/contrib/rnn/kernels/blas_gemm.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/eigen_activations.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
class OpKernelContext;

namespace functor {

// Sets all elements of the flat tensor t to zero.
template <typename Device, typename T>
struct TensorZero {
  void operator()(const Device& d, typename TTypes<T>::Flat t) {
    t.device(d) = t.constant(T(0));
  }
};

// Same as TensorZero, but for a potentially unaligned flat view.
template <typename Device, typename T>
struct TensorUnalignedZero {
  void operator()(const Device& d, typename TTypes<T>::UnalignedFlat t) {
    t.device(d) = t.constant(T(0));
  }
};

// Copies src into dst.
template <typename Device, typename T>
struct TensorCopy {
  void operator()(const Device& d, typename TTypes<T>::ConstFlat src,
                  typename TTypes<T>::Flat dst) {
    dst.device(d) = src;
  }
};

// Copies from an unaligned source view into an aligned destination.
template <typename Device, typename T>
struct TensorCopyUnaligned {
  void operator()(const Device& d, typename TTypes<T>::UnalignedConstFlat src,
                  typename TTypes<T>::Flat dst) {
    dst.device(d) = src;
  }
};

// Copies from an aligned source into an unaligned destination view.
template <typename Device, typename T>
struct TensorCopyToUnaligned {
  void operator()(const Device& d, typename TTypes<T>::ConstFlat src,
                  typename TTypes<T>::UnalignedFlat dst) {
    dst.device(d) = src;
  }
};

// Computes c = a + b elementwise.
template <typename Device, typename T>
struct TensorAdd {
  void operator()(const Device& d, typename TTypes<T>::ConstFlat a,
                  typename TTypes<T>::ConstFlat b, typename TTypes<T>::Flat c) {
    c.device(d) = a + b;
  }
};

// Builds mask[b] = (time_idx < seq_len[b]) and multiplies each row of m by it,
// zeroing out the rows of batch entries whose sequence has already ended.
template <typename Device, typename T>
struct TensorZeroPadding {
  void operator()(const Device& d, const int64 time_idx,
                  typename TTypes<int64>::ConstVec seq_len,
                  typename TTypes<float>::Vec mask,
                  typename TTypes<float>::Matrix m) {
    // mask is shape [batch_size].
    mask.device(d) = seq_len.constant(time_idx) < seq_len;

    // m_shape is [batch_size, 1].
    Eigen::array<Eigen::DenseIndex, 2> m_shape({m.dimensions()[0], 1});
    // broadcast_shape is [1, units].
    Eigen::array<Eigen::DenseIndex, 2> broadcast_shape({1, m.dimensions()[1]});

    // m is shape [batch_size, units].
    m.device(d) = m * mask.reshape(m_shape).broadcast(broadcast_shape);
  }
};
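// Illustrative use from an op kernel (a sketch; the actual call sites are in
// lstm_ops.cc):
//
//   functor::TensorZero<Device, T>()(ctx->eigen_device<Device>(),
//                                    tensor->flat<T>());
//   functor::TensorCopy<Device, T>()(ctx->eigen_device<Device>(),
//                                    src.flat<T>(), dst->flat<T>());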
// Shape bookkeeping shared by the LSTM block kernels below. The packed gate
// matrix icfo has shape [batch_size, cell_size * 4] with its columns laid out
// in i|c|f|o order; the concatenated input xh has shape
// [batch_size, input_size + cell_size] with x followed by h.
struct LSTMBlockCell {
  LSTMBlockCell(const int batch_size, const int input_size, const int cell_size)
      : batch_size_(batch_size),
        input_size_(input_size),
        cell_size_(cell_size) {}

  int batch_size() const { return batch_size_; }

  int input_size() const { return input_size_; }

  int cell_size() const { return cell_size_; }

  inline Eigen::array<Eigen::DenseIndex, 2> icfo_i_offsets() const {
    return {0, 0};
  }

  inline Eigen::array<Eigen::DenseIndex, 2> icfo_c_offsets() const {
    return {0, cell_size_};
  }

  inline Eigen::array<Eigen::DenseIndex, 2> icfo_f_offsets() const {
    return {0, cell_size_ * 2};
  }

  inline Eigen::array<Eigen::DenseIndex, 2> icfo_o_offsets() const {
    return {0, cell_size_ * 3};
  }

  inline Eigen::array<Eigen::DenseIndex, 2> cell_extents() const {
    return {batch_size_, cell_size_};
  }

  inline Eigen::array<Eigen::DenseIndex, 2> xh_x_offsets() const {
    return {0, 0};
  }

  inline Eigen::array<Eigen::DenseIndex, 2> xh_x_extents() const {
    return {batch_size_, input_size_};
  }

  inline Eigen::array<Eigen::DenseIndex, 2> xh_h_offsets() const {
    return {0, input_size_};
  }

  inline Eigen::array<Eigen::DenseIndex, 2> xh_h_extents() const {
    return {batch_size_, cell_size_};
  }

 protected:
  const int batch_size_;
  const int input_size_;
  const int cell_size_;
};

// See lstm_ops.cc for CPUDevice implementation and lstm_ops_gpu.cu.cc for
// GPUDevice implementation.
template <typename Device, typename T, bool USE_CUBLAS>
struct LSTMBlockCellFprop : public LSTMBlockCell {
  LSTMBlockCellFprop(const int batch_size, const int input_size,
                     const int cell_size)
      : LSTMBlockCell(batch_size, input_size, cell_size) {}

  void operator()(
      OpKernelContext* ctx, const Device& d, const T forget_bias,
      const T cell_clip, bool use_peephole, typename TTypes<T>::ConstMatrix x,
      typename TTypes<T>::ConstMatrix cs_prev,
      typename TTypes<T>::ConstMatrix h_prev, typename TTypes<T>::ConstMatrix w,
      typename TTypes<T>::ConstVec wci, typename TTypes<T>::ConstVec wcf,
      typename TTypes<T>::ConstVec wco, typename TTypes<T>::ConstVec b,
      typename TTypes<T>::Matrix xh, typename TTypes<T>::Matrix i,
      typename TTypes<T>::Matrix cs, typename TTypes<T>::Matrix f,
      typename TTypes<T>::Matrix o, typename TTypes<T>::Matrix ci,
      typename TTypes<T>::Matrix co, typename TTypes<T>::Matrix icfo,
      typename TTypes<T>::Matrix h);
};

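// Sketch of the forward computation performed by the device-specific
// specializations; notation follows the argument names above, and the
// peephole terms apply only when use_peephole is true:
//
//   xh = [x, h_prev]
//   [i, ci, f, o] = xh * w + b        (written into icfo)
//   i  = sigm(i + cs_prev .* wci)
//   f  = sigm(f + forget_bias + cs_prev .* wcf)
//   ci = tanh(ci)
//   cs = i .* ci + f .* cs_prev       (clipped to [-cell_clip, cell_clip]
//                                      when cell_clip > 0)
//   o  = sigm(o + cs .* wco)
//   co = tanh(cs)
//   h  = o .* co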
// See lstm_ops.cc for CPUDevice implementation and lstm_ops_gpu.cu.cc for
// GPUDevice implementation.
template <typename Device, typename T, bool USE_CUBLAS>
struct LSTMBlockCellBprop : public LSTMBlockCell {
  LSTMBlockCellBprop(const int batch_size, const int input_size,
                     const int cell_size)
      : LSTMBlockCell(batch_size, input_size, cell_size) {}

  void operator()(
      OpKernelContext* ctx, const Device& d, bool use_peephole,
      typename TTypes<T>::ConstMatrix x,
      typename TTypes<T>::ConstMatrix cs_prev,
      typename TTypes<T>::ConstMatrix h_prev, typename TTypes<T>::ConstMatrix w,
      typename TTypes<T>::ConstVec wci, typename TTypes<T>::ConstVec wcf,
      typename TTypes<T>::ConstVec wco, typename TTypes<T>::ConstVec b,
      typename TTypes<T>::ConstMatrix i, typename TTypes<T>::ConstMatrix cs,
      typename TTypes<T>::ConstMatrix f, typename TTypes<T>::ConstMatrix o,
      typename TTypes<T>::ConstMatrix ci, typename TTypes<T>::ConstMatrix co,
      typename TTypes<T>::ConstMatrix cs_grad,
      typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_,
      typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,
      typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,
      typename TTypes<T>::Matrix dicfo, typename TTypes<T>::Matrix cs_prev_grad,
      typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad,
      typename TTypes<T>::Vec wco_grad);
};

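// The per-gate gradients (di, dci, df, do_) are packed into dicfo using the
// same i|c|f|o column layout as the forward pass; BlockLSTMBprop below shows
// the corresponding Eigen arithmetic.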
template <typename Device, typename T, bool USE_CUBLAS>
struct BlockLSTMBprop : public LSTMBlockCell {
  BlockLSTMBprop(const int batch_size, const int input_size,
                 const int cell_size)
      : LSTMBlockCell(batch_size, input_size, cell_size) {}

  void operator()(
      OpKernelContext* ctx, const Device& d, bool use_peephole,
      typename TTypes<T>::ConstMatrix x,
      typename TTypes<T>::ConstMatrix cs_prev,
      typename TTypes<T>::ConstMatrix h_prev, typename TTypes<T>::ConstMatrix w,
      typename TTypes<T>::ConstVec wci, typename TTypes<T>::ConstVec wcf,
      typename TTypes<T>::ConstVec wco, typename TTypes<T>::ConstVec b,
      typename TTypes<T>::Matrix xh, typename TTypes<T>::ConstMatrix i,
      typename TTypes<T>::ConstMatrix cs, typename TTypes<T>::ConstMatrix f,
      typename TTypes<T>::ConstMatrix o, typename TTypes<T>::ConstMatrix ci,
      typename TTypes<T>::ConstMatrix co,
      typename TTypes<T>::ConstMatrix cs_grad,
      typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_,
      typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,
      typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,
      typename TTypes<T>::Matrix dicfo, typename TTypes<T>::Matrix cs_prev_grad,
      typename TTypes<T>::Matrix h_prev_grad,
      typename TTypes<T>::Matrix xh_grad, typename TTypes<T>::Matrix x_grad,
      typename TTypes<T>::Matrix w_grad, typename TTypes<T>::Vec wci_grad,
      typename TTypes<T>::Vec wcf_grad, typename TTypes<T>::Vec wco_grad,
      typename TTypes<T>::Vec b_grad) {
    // do[t] = sigm'(o[t]) .* dh[t] .* co[t]
    do_.device(d) = o * (o.constant(T(1)) - o) * h_grad * co;

    // dcs[t] += tanh'(cs[t]) .* dh[t] .* o[t] + dcs[t + 1] .* f[t + 1]
    // (the dcs[t + 1] .* f[t + 1] term arrives through cs_grad).
    dcs.device(d) = (co.constant(T(1)) - co * co) * h_grad * o + cs_grad;

    Eigen::array<Eigen::DenseIndex, 2> p_shape({1, cell_size_});
    Eigen::array<Eigen::DenseIndex, 2> p_broadcast_shape({batch_size_, 1});
    if (use_peephole) {
      dcs.device(d) =
          dcs + do_ * wco.reshape(p_shape).broadcast(p_broadcast_shape);
    }

    // dci[t] = tanh'(ci[t]) dcs[t] i[t]
    dci.device(d) = (ci.constant(T(1)) - ci * ci) * dcs * i;

    // df[t] = sigm'(f[t]) dcs[t] cs[t - 1]
    df.device(d) = f * (f.constant(T(1)) - f) * dcs * cs_prev;

    // di[t] = sigm'(i[t]) dcs[t] ci[t]
    di.device(d) = i * (i.constant(T(1)) - i) * dcs * ci;

    // Pack the gate gradients into dicfo using the i|c|f|o column layout.
    dicfo.slice(icfo_i_offsets(), cell_extents()).device(d) = di;
    dicfo.slice(icfo_c_offsets(), cell_extents()).device(d) = dci;
    dicfo.slice(icfo_f_offsets(), cell_extents()).device(d) = df;
    dicfo.slice(icfo_o_offsets(), cell_extents()).device(d) = do_;

    // cs_prev_grad[t] = dcs[t] .* f[t], plus the peephole contributions below.
    cs_prev_grad.device(d) = dcs * f;
    if (use_peephole) {
      cs_prev_grad.device(d) =
          cs_prev_grad +
          di * wci.reshape(p_shape).broadcast(p_broadcast_shape) +
          df * wcf.reshape(p_shape).broadcast(p_broadcast_shape);
    }

    // xh_grad = dicfo * w^T.
    typename TTypes<T>::ConstMatrix const_dicfo(dicfo.data(),
                                                dicfo.dimensions());
    TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
        ctx, d, false, true, T(1), const_dicfo, w, T(0), xh_grad);

    // Rebuild xh = [x, h_prev] for the weight-gradient GEMM below.
    xh.slice(xh_x_offsets(), xh_x_extents()).device(d) = x;
    xh.slice(xh_h_offsets(), xh_h_extents()).device(d) = h_prev;
    typename TTypes<T>::ConstMatrix const_xh(xh.data(), xh.dimensions());

    // x_grad and h_prev_grad are the two column blocks of xh_grad.
    x_grad.device(d) = xh_grad.slice(xh_x_offsets(), xh_x_extents());
    h_prev_grad.device(d) = xh_grad.slice(xh_h_offsets(), xh_h_extents());

    // w_grad += xh^T * dicfo.
    TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
        ctx, d, true, false, T(1), const_xh, const_dicfo, T(1), w_grad);

    // b_grad += column sums of dicfo.
    b_grad.device(d) += dicfo.sum(Eigen::array<int, 1>({0}));

    if (use_peephole) {
      wci_grad.device(d) += (di * cs_prev).sum(Eigen::array<int, 1>({0}));
      wcf_grad.device(d) += (df * cs_prev).sum(Eigen::array<int, 1>({0}));
      wco_grad.device(d) += (do_ * cs).sum(Eigen::array<int, 1>({0}));
    }
  }
};
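// Illustrative invocation from an op kernel (a sketch; the tensor names are
// placeholders and the real call sites are in lstm_ops.cc):
//
//   functor::BlockLSTMBprop<Device, T, USE_CUBLAS>(batch_size, input_size,
//                                                  cell_size)(
//       ctx, ctx->eigen_device<Device>(), use_peephole, x, cs_prev, h_prev, w,
//       wci, wcf, wco, b, xh, i, cs, f, o, ci, co, cs_grad, h_grad, do_, dcs,
//       dci, df, di, dicfo, cs_prev_grad, h_prev_grad, xh_grad, x_grad, w_grad,
//       wci_grad, wcf_grad, wco_grad, b_grad);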

}  // namespace functor
}  // namespace tensorflow

#endif  // TENSORFLOW_CONTRIB_RNN_KERNELS_LSTM_OPS_H_