Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CONTRIB_RNN_KERNELS_LSTM_OPS_H_
     17 #define TENSORFLOW_CONTRIB_RNN_KERNELS_LSTM_OPS_H_
     18 
     19 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     20 #include "tensorflow/contrib/rnn/kernels/blas_gemm.h"
     21 #include "tensorflow/core/framework/tensor_types.h"
     22 #include "tensorflow/core/kernels/eigen_activations.h"
     23 #include "tensorflow/core/platform/types.h"
     24 
     25 namespace tensorflow {
     26 class OpKernelContext;
     27 
     28 namespace functor {
     29 
     30 template <typename Device, typename T>
     31 struct TensorZero {
     32   void operator()(const Device& d, typename TTypes<T>::Flat t) {
     33     t.device(d) = t.constant(T(0));
     34   }
     35 };
     36 
     37 template <typename Device, typename T>
     38 struct TensorUnalignedZero {
     39   void operator()(const Device& d, typename TTypes<T>::UnalignedFlat t) {
     40     t.device(d) = t.constant(T(0));
     41   }
     42 };
     43 
     44 template <typename Device, typename T>
     45 struct TensorCopy {
     46   void operator()(const Device& d, typename TTypes<T>::ConstFlat src,
     47                   typename TTypes<T>::Flat dst) {
     48     dst.device(d) = src;
     49   }
     50 };
     51 
     52 template <typename Device, typename T>
     53 struct TensorCopyUnaligned {
     54   void operator()(const Device& d, typename TTypes<T>::UnalignedConstFlat src,
     55                   typename TTypes<T>::Flat dst) {
     56     dst.device(d) = src;
     57   }
     58 };
     59 
     60 template <typename Device, typename T>
     61 struct TensorCopyToUnaligned {
     62   void operator()(const Device& d, typename TTypes<T>::ConstFlat src,
     63                   typename TTypes<T>::UnalignedFlat dst) {
     64     dst.device(d) = src;
     65   }
     66 };
     67 
     68 template <typename Device, typename T>
     69 struct TensorAdd {
     70   void operator()(const Device& d, typename TTypes<T>::ConstFlat a,
     71                   typename TTypes<T>::ConstFlat b, typename TTypes<T>::Flat c) {
     72     c.device(d) = a + b;
     73   }
     74 };
     75 
     76 template <typename Device, typename T>
     77 struct TensorZeroPadding {
     78   void operator()(const Device& d, const int64 time_idx,
     79                   typename TTypes<int64>::ConstVec seq_len,
     80                   typename TTypes<T>::Vec mask, typename TTypes<T>::Matrix m) {
     81     // mask is shape [batch_size].
     82     mask.device(d) = seq_len.constant(time_idx) < seq_len;
     83 
     84     // m_shape is [batch_size, 1].
     85     Eigen::array<Eigen::DenseIndex, 2> m_shape({m.dimensions()[0], 1});
     86     // broadcast_shape is [1, units].
     87     Eigen::array<Eigen::DenseIndex, 2> broadcast_shape({1, m.dimensions()[1]});
     88 
     89     // m is shape [batch_size, units].
     90     m.device(d) = m * mask.reshape(m_shape).broadcast(broadcast_shape);
     91   }
     92 };
     93 
     94 struct LSTMBlockCell {
     95   LSTMBlockCell(const int batch_size, const int input_size, const int cell_size)
     96       : batch_size_(batch_size),
     97         input_size_(input_size),
     98         cell_size_(cell_size) {}
     99 
    100   int batch_size() const { return batch_size_; }
    101 
    102   int input_size() const { return input_size_; }
    103 
    104   int cell_size() const { return cell_size_; }
    105 
    106   inline Eigen::array<Eigen::DenseIndex, 2> icfo_i_offsets() const {
    107     return {0, 0};
    108   }
    109 
    110   inline Eigen::array<Eigen::DenseIndex, 2> icfo_c_offsets() const {
    111     return {0, cell_size_};
    112   }
    113 
    114   inline Eigen::array<Eigen::DenseIndex, 2> icfo_f_offsets() const {
    115     return {0, cell_size_ * 2};
    116   }
    117 
    118   inline Eigen::array<Eigen::DenseIndex, 2> icfo_o_offsets() const {
    119     return {0, cell_size_ * 3};
    120   }
    121 
    122   inline Eigen::array<Eigen::DenseIndex, 2> cell_extents() const {
    123     return {batch_size_, cell_size_};
    124   }
    125 
    126   inline Eigen::array<Eigen::DenseIndex, 2> xh_x_offsets() const {
    127     return {0, 0};
    128   }
    129 
    130   inline Eigen::array<Eigen::DenseIndex, 2> xh_x_extents() const {
    131     return {batch_size_, input_size_};
    132   }
    133 
    134   inline Eigen::array<Eigen::DenseIndex, 2> xh_h_offsets() const {
    135     return {0, input_size_};
    136   }
    137 
    138   inline Eigen::array<Eigen::DenseIndex, 2> xh_h_extents() const {
    139     return {batch_size_, cell_size_};
    140   }
    141 
    142  protected:
    143   const int batch_size_;
    144   const int input_size_;
    145   const int cell_size_;
    146 };
    147 
    148 // See lstm_ops.cc for CPUDevice implementation and lstm_ops_gpu.cu.cc for
    149 // GPUDevice implementation.
    150 template <typename Device, typename T, bool USE_CUBLAS>
    151 struct LSTMBlockCellFprop : public LSTMBlockCell {
    152   LSTMBlockCellFprop(const int batch_size, const int input_size,
    153                      const int cell_size)
    154       : LSTMBlockCell(batch_size, input_size, cell_size) {}
    155 
    156   void operator()(OpKernelContext* ctx, const Device& d,
    157                   const float forget_bias, const float cell_clip,
    158                   bool use_peephole, typename TTypes<T>::ConstMatrix x,
    159                   typename TTypes<T>::ConstMatrix cs_prev,
    160                   typename TTypes<T>::ConstMatrix h_prev,
    161                   typename TTypes<T>::ConstMatrix w,
    162                   typename TTypes<T>::ConstVec wci,
    163                   typename TTypes<T>::ConstVec wcf,
    164                   typename TTypes<T>::ConstVec wco,
    165                   typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh,
    166                   typename TTypes<T>::Matrix i, typename TTypes<T>::Matrix cs,
    167                   typename TTypes<T>::Matrix f, typename TTypes<T>::Matrix o,
    168                   typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co,
    169                   typename TTypes<T>::Matrix icfo,
    170                   typename TTypes<T>::Matrix h);
    171 };
    172 
    173 // See lstm_ops.cc for CPUDevice implementation and lstm_ops_gpu.cu.cc for
    174 // GPUDevice implementation.
    175 template <typename Device, typename T, bool USE_CUBLAS>
    176 struct LSTMBlockCellBprop : public LSTMBlockCell {
    177   LSTMBlockCellBprop(const int batch_size, const int input_size,
    178                      const int cell_size)
    179       : LSTMBlockCell(batch_size, input_size, cell_size) {}
    180 
    181   void operator()(
    182       OpKernelContext* ctx, const Device& d, bool use_peephole,
    183       typename TTypes<T>::ConstMatrix x,
    184       typename TTypes<T>::ConstMatrix cs_prev,
    185       typename TTypes<T>::ConstMatrix h_prev, typename TTypes<T>::ConstMatrix w,
    186       typename TTypes<T>::ConstVec wci, typename TTypes<T>::ConstVec wcf,
    187       typename TTypes<T>::ConstVec wco, typename TTypes<T>::ConstVec b,
    188       typename TTypes<T>::ConstMatrix i, typename TTypes<T>::ConstMatrix cs,
    189       typename TTypes<T>::ConstMatrix f, typename TTypes<T>::ConstMatrix o,
    190       typename TTypes<T>::ConstMatrix ci, typename TTypes<T>::ConstMatrix co,
    191       typename TTypes<T>::ConstMatrix cs_grad,
    192       typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_,
    193       typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,
    194       typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,
    195       typename TTypes<T>::Matrix dicfo, typename TTypes<T>::Matrix cs_prev_grad,
    196       typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad,
    197       typename TTypes<T>::Vec wco_grad);
    198 };
    199 
    200 template <typename Device, typename T, bool USE_CUBLAS>
    201 struct BlockLSTMBprop : public LSTMBlockCell {
    202   BlockLSTMBprop(const int batch_size, const int input_size,
    203                  const int cell_size)
    204       : LSTMBlockCell(batch_size, input_size, cell_size) {}
    205 
    206   void operator()(
    207       OpKernelContext* ctx, const Device& d, bool use_peephole,
    208       typename TTypes<T>::ConstMatrix x,
    209       typename TTypes<T>::ConstMatrix cs_prev,
    210       typename TTypes<T>::ConstMatrix h_prev, typename TTypes<T>::ConstMatrix w,
    211       typename TTypes<T>::ConstVec wci, typename TTypes<T>::ConstVec wcf,
    212       typename TTypes<T>::ConstVec wco, typename TTypes<T>::ConstVec b,
    213       typename TTypes<T>::Matrix xh, typename TTypes<T>::ConstMatrix i,
    214       typename TTypes<T>::ConstMatrix cs, typename TTypes<T>::ConstMatrix f,
    215       typename TTypes<T>::ConstMatrix o, typename TTypes<T>::ConstMatrix ci,
    216       typename TTypes<T>::ConstMatrix co,
    217       typename TTypes<T>::ConstMatrix cs_grad,
    218       typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_,
    219       typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,
    220       typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,
    221       typename TTypes<T>::Matrix dicfo, typename TTypes<T>::Matrix cs_prev_grad,
    222       typename TTypes<T>::Matrix h_prev_grad,
    223       typename TTypes<T>::Matrix xh_grad, typename TTypes<T>::Matrix x_grad,
    224       typename TTypes<T>::Matrix w_grad, typename TTypes<T>::Vec wci_grad,
    225       typename TTypes<T>::Vec wcf_grad, typename TTypes<T>::Vec wco_grad,
    226       typename TTypes<T>::Vec b_grad) {
    227     // do[t] = sigm'(o[t]) .* dh[t] .* co[t]
    228     do_.device(d) = o * (o.constant(T(1)) - o) * h_grad * co;
    229 
    230     // dcs[t] += tanh'(cs[t]) .* dh[t] .* o[t] + dcs[t + 1] .* f[t + 1]
    231     dcs.device(d) = (co.constant(T(1)) - co * co) * h_grad * o + cs_grad;
    232 
    233     Eigen::array<Eigen::DenseIndex, 2> p_shape({1, cell_size_});
    234     Eigen::array<Eigen::DenseIndex, 2> p_broadcast_shape({batch_size_, 1});
    235     if (use_peephole) {
    236       dcs.device(d) =
    237           dcs + do_ * wco.reshape(p_shape).broadcast(p_broadcast_shape);
    238     }
    239 
    240     // dci[t] = tanh'(ci[t]) dcs[t] i[t]
    241     dci.device(d) = (ci.constant(T(1)) - ci * ci) * dcs * i;
    242 
    243     // df[t] = sigm'(f[t]) dcs[t] cs[t - 1]
    244     df.device(d) = f * (f.constant(T(1)) - f) * dcs * cs_prev;
    245 
    246     // di[t] = sigm'(i[t]) dcs[t] ci[t]
    247     di.device(d) = i * (i.constant(T(1)) - i) * dcs * ci;
    248 
    249     dicfo.slice(icfo_i_offsets(), cell_extents()).device(d) = di;
    250     dicfo.slice(icfo_c_offsets(), cell_extents()).device(d) = dci;
    251     dicfo.slice(icfo_f_offsets(), cell_extents()).device(d) = df;
    252     dicfo.slice(icfo_o_offsets(), cell_extents()).device(d) = do_;
    253 
    254     cs_prev_grad.device(d) = dcs * f;
    255     if (use_peephole) {
    256       cs_prev_grad.device(d) =
    257           cs_prev_grad +
    258           di * wci.reshape(p_shape).broadcast(p_broadcast_shape) +
    259           df * wcf.reshape(p_shape).broadcast(p_broadcast_shape);
    260     }
    261 
    262     // xh_grad.
    263     typename TTypes<T>::ConstMatrix const_dicfo(dicfo.data(),
    264                                                 dicfo.dimensions());
    265     TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
    266         ctx, d, false, true, 1.f, const_dicfo, w, 0.f, xh_grad);
    267 
    268     // xh.
    269     xh.slice(xh_x_offsets(), xh_x_extents()).device(d) = x;
    270     xh.slice(xh_h_offsets(), xh_h_extents()).device(d) = h_prev;
    271     typename TTypes<T>::ConstMatrix const_xh(xh.data(), xh.dimensions());
    272 
    273     // x_grad.
    274     x_grad.device(d) = xh_grad.slice(xh_x_offsets(), xh_x_extents());
    275     h_prev_grad.device(d) = xh_grad.slice(xh_h_offsets(), xh_h_extents());
    276 
    277     // w_grad.
    278     TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
    279         ctx, d, true, false, 1.f, const_xh, const_dicfo, 1.f, w_grad);
    280 
    281     // b_grad.
    282     b_grad.device(d) += dicfo.sum(Eigen::array<int, 1>({0}));
    283 
    284     if (use_peephole) {
    285       wci_grad.device(d) += (di * cs_prev).sum(Eigen::array<int, 1>({0}));
    286       wcf_grad.device(d) += (df * cs_prev).sum(Eigen::array<int, 1>({0}));
    287       wco_grad.device(d) += (do_ * cs).sum(Eigen::array<int, 1>({0}));
    288     }
    289   }
    290 };
    291 
    292 }  // namespace functor
    293 }  // namespace tensorflow
    294 
    295 #endif  // TENSORFLOW_CONTRIB_RNN_KERNELS_LSTM_OPS_H_
    296