Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CONTRIB_RNN_KERNELS_GRU_OPS_H_
     17 #define TENSORFLOW_CONTRIB_RNN_KERNELS_GRU_OPS_H_
     18 
     19 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     20 #include "tensorflow/contrib/rnn/kernels/blas_gemm.h"
     21 #include "tensorflow/core/framework/tensor_types.h"
     22 #include "tensorflow/core/platform/types.h"
     23 
     24 namespace tensorflow {
     25 
     26 class OpKernelContext;
     27 
     28 namespace functor {
     29 
     30 struct GRUCell {
     31   GRUCell(const int batch_size, const int input_size, const int cell_size)
     32       : batch_size_(batch_size),
     33         input_size_(input_size),
     34         cell_size_(cell_size) {}
     35 
     36   inline Eigen::array<Eigen::DenseIndex, 2> x_offsets() const { return {0, 0}; }
     37 
     38   inline Eigen::array<Eigen::DenseIndex, 2> x_extends() const {
     39     return {batch_size_, input_size_};
     40   }
     41 
     42   inline Eigen::array<Eigen::DenseIndex, 2> h_offsets() const {
     43     return {0, input_size_};
     44   }
     45 
     46   inline Eigen::array<Eigen::DenseIndex, 2> h_extends() const {
     47     return {batch_size_, cell_size_};
     48   }
     49 
     50   inline Eigen::array<Eigen::DenseIndex, 2> ru_r_offset() const {
     51     return {0, 0};
     52   }
     53 
     54   inline Eigen::array<Eigen::DenseIndex, 2> ru_u_offset() const {
     55     return {0, cell_size_};
     56   }
     57 
     58   inline Eigen::array<Eigen::DenseIndex, 2> cell_extents() const {
     59     return {batch_size_, cell_size_};
     60   }
     61 
     62  protected:
     63   const int batch_size_;
     64   const int input_size_;
     65   const int cell_size_;
     66 };
     67 
     68 template <typename Device, typename T, bool USE_CUBLAS>
     69 struct GRUBlockCellFprop : public GRUCell {
     70   GRUBlockCellFprop(const int batch_size, const int input_size,
     71                     const int cell_size)
     72       : GRUCell(batch_size, input_size, cell_size) {}
     73 
     74   void operator()(
     75       OpKernelContext* ctx, const Device& d, typename TTypes<T>::ConstMatrix x,
     76       typename TTypes<T>::ConstMatrix h_prev,
     77       typename TTypes<T>::ConstMatrix w_ru, typename TTypes<T>::ConstMatrix w_c,
     78       typename TTypes<T>::ConstVec b_ru, typename TTypes<T>::ConstVec b_c,
     79       typename TTypes<T>::Matrix r_u_bar, typename TTypes<T>::Matrix r,
     80       typename TTypes<T>::Matrix u, typename TTypes<T>::Matrix c,
     81       typename TTypes<T>::Matrix h, typename TTypes<T>::Matrix x_h_prev,
     82       typename TTypes<T>::Matrix x_h_prevr) {
     83     // Concat x_h_prev = [x, h_prev].
     84     x_h_prev.slice(x_offsets(), x_extends()).device(d) = x;
     85     x_h_prev.slice(h_offsets(), h_extends()).device(d) = h_prev;
     86 
     87     // r_u_bar = x_h_prev * w_ru + b_ru
     88     typename TTypes<T>::ConstMatrix const_x_h_prev(x_h_prev.data(),
     89                                                    x_h_prev.dimensions());
     90     TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
     91         ctx, d, false, false, T(1), const_x_h_prev, w_ru, T(0), r_u_bar);
     92 
     93     // Creating a bias matrix for adding by broadcasting 'b_ru'
     94     Eigen::array<Eigen::DenseIndex, 2> broadcast_shape({batch_size_, 1});
     95     Eigen::array<Eigen::DenseIndex, 2> b_ru_shape({1, b_ru.dimensions()[0]});
     96     r_u_bar.device(d) += b_ru.reshape(b_ru_shape).broadcast(broadcast_shape);
     97 
     98     // Slice r_u_bar into r, u and apply the sigmoid.
     99     r.device(d) = (r_u_bar.slice(ru_r_offset(), cell_extents())).sigmoid();
    100     u.device(d) = (r_u_bar.slice(ru_u_offset(), cell_extents())).sigmoid();
    101 
    102     // Concat x_h_prevr = [x,h_prev*r]
    103     x_h_prevr.slice(x_offsets(), x_extends()).device(d) = x;
    104     x_h_prevr.slice(h_offsets(), h_extends()).device(d) = h_prev * r;
    105 
    106     // c = tanh(x_h_prevr*w_c+b_c), Note b_c is broadcasted before adding.
    107     typename TTypes<T>::ConstMatrix const_x_h_prevr(x_h_prevr.data(),
    108                                                     x_h_prevr.dimensions());
    109     TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
    110         ctx, d, false, false, T(1), const_x_h_prevr, w_c, T(0), c);
    111 
    112     Eigen::array<Eigen::DenseIndex, 2> b_c_shape({1, b_c.dimensions()[0]});
    113     c.device(d) += (b_c.reshape(b_c_shape).broadcast(broadcast_shape));
    114     c.device(d) = c.tanh();
    115 
    116     // h= u*h_prev + (1-u)*c
    117     h.device(d) = u * (h_prev - c) + c;
    118   }
    119 };
    120 
    121 template <typename Device, typename T, bool USE_CUBLAS>
    122 struct GRUBlockCellBprop : public GRUCell {
    123   GRUBlockCellBprop(const int batch_size, const int input_size,
    124                     const int cell_size)
    125       : GRUCell(batch_size, input_size, cell_size) {}
    126 
    127   void operator()(
    128       OpKernelContext* ctx, const Device& d, typename TTypes<T>::ConstMatrix x,
    129       typename TTypes<T>::ConstMatrix h_prev,
    130       typename TTypes<T>::ConstMatrix w_ru, typename TTypes<T>::ConstMatrix w_c,
    131       typename TTypes<T>::ConstVec b_ru, typename TTypes<T>::ConstVec b_c,
    132       typename TTypes<T>::ConstMatrix r, typename TTypes<T>::ConstMatrix u,
    133       typename TTypes<T>::ConstMatrix c, typename TTypes<T>::ConstMatrix d_h,
    134       typename TTypes<T>::Matrix d_x, typename TTypes<T>::Matrix d_h_prev,
    135       typename TTypes<T>::Matrix d_c_bar,
    136       typename TTypes<T>::Matrix d_r_bar_u_bar,
    137       typename TTypes<T>::Matrix d_r_bar, typename TTypes<T>::Matrix d_u_bar,
    138       typename TTypes<T>::Matrix d_hr,
    139       typename TTypes<T>::Matrix d_x_comp1_and_h_prev_comp1,
    140       typename TTypes<T>::Matrix d_x_comp2_and_h_prevr) {
    141     // d_c_bar = d_h*(1-u)*(1-(c*c))
    142     d_c_bar.device(d) =
    143         ((d_h * (u.constant(T(1)) - u)) * (c.constant(T(1)) - c * c));
    144 
    145     // d_u_bar = d_h*(h-c)*(u*(1-u))
    146     d_u_bar.device(d) = d_h * (h_prev - c) * u * (u.constant(T(1)) - u);
    147 
    148     // [2nd_component_of_d_x d_h_prevr] = d_c_bar X w_c^T
    149     typename TTypes<T>::ConstMatrix const_d_c_bar(d_c_bar.data(),
    150                                                   d_c_bar.dimensions());
    151     TensorBlasGemm<Device, T, USE_CUBLAS>::compute(ctx, d, false, true, T(1),
    152                                                    const_d_c_bar, w_c, T(0),
    153                                                    d_x_comp2_and_h_prevr);
    154 
    155     d_hr.device(d) = d_x_comp2_and_h_prevr.slice(h_offsets(), h_extends());
    156     d_r_bar.device(d) = (d_hr * h_prev * r) * (r.constant(T(1)) - r);
    157 
    158     // d_r_bar_u_bar = concatenate(d_r_bar, d_u_bar) along axis = 1.
    159     d_r_bar_u_bar.slice(ru_r_offset(), cell_extents()).device(d) = d_r_bar;
    160     d_r_bar_u_bar.slice(ru_u_offset(), cell_extents()).device(d) = d_u_bar;
    161 
    162     // [1st_component_of_d_x 1st_component_of_d_h_prev] = [d_r_bar d_u_bar] X
    163     // w_ru^T
    164     typename TTypes<T>::ConstMatrix const_d_r_bar_u_bar(
    165         d_r_bar_u_bar.data(), d_r_bar_u_bar.dimensions());
    166     TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
    167         ctx, d, false, true, T(1), const_d_r_bar_u_bar, w_ru, T(0),
    168         d_x_comp1_and_h_prev_comp1);
    169 
    170     // d_x = d_x_comp1 + d_x_comp2
    171     d_x.device(d) = (d_x_comp1_and_h_prev_comp1 + d_x_comp2_and_h_prevr)
    172                         .slice(x_offsets(), x_extends());
    173 
    174     // d_h_prev = d_h_comp1 + d_hr*r + d_h*u
    175     d_h_prev.device(d) =
    176         d_x_comp1_and_h_prev_comp1.slice(h_offsets(), h_extends()) +
    177         (d_hr * r) + (d_h * u);
    178   }
    179 };
    180 
    181 }  // namespace functor
    182 }  // namespace tensorflow
    183 
    184 #endif  // TENSORFLOW_CONTRIB_RNN_KERNELS_GRU_OPS_H_
    185