/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CONTRIB_RNN_KERNELS_GRU_OPS_H_
#define TENSORFLOW_CONTRIB_RNN_KERNELS_GRU_OPS_H_

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/contrib/rnn/kernels/blas_gemm.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

class OpKernelContext;

namespace functor {

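// Shared sizes and Eigen slice helpers for the GRU block kernels below.
// The forward and backward functors concatenate the input x and the previous
// state h_prev into a single [batch_size, input_size + cell_size] matrix; the
// offsets/extents here describe the x block, the h block, and the r/u gate
// columns of that concatenation.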
struct GRUCell {
  GRUCell(const int batch_size, const int input_size, const int cell_size)
      : batch_size_(batch_size),
        input_size_(input_size),
        cell_size_(cell_size) {}

  inline Eigen::array<Eigen::DenseIndex, 2> x_offsets() const { return {0, 0}; }

  inline Eigen::array<Eigen::DenseIndex, 2> x_extends() const {
    return {batch_size_, input_size_};
  }

  inline Eigen::array<Eigen::DenseIndex, 2> h_offsets() const {
    return {0, input_size_};
  }

  inline Eigen::array<Eigen::DenseIndex, 2> h_extends() const {
    return {batch_size_, cell_size_};
  }

  inline Eigen::array<Eigen::DenseIndex, 2> ru_r_offset() const {
    return {0, 0};
  }

  inline Eigen::array<Eigen::DenseIndex, 2> ru_u_offset() const {
    return {0, cell_size_};
  }

  inline Eigen::array<Eigen::DenseIndex, 2> cell_extents() const {
    return {batch_size_, cell_size_};
  }

 protected:
  const int batch_size_;
  const int input_size_;
  const int cell_size_;
};

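// Forward pass of a single GRU time step. A sketch of the math implemented
// below, using the same names as the code ('*' denotes matmul, '.*' an
// element-wise product):
//   x_h_prev       = [x, h_prev]
//   [r_bar, u_bar] = x_h_prev * w_ru + b_ru
//   r              = sigmoid(r_bar)
//   u              = sigmoid(u_bar)
//   x_h_prevr      = [x, h_prev .* r]
//   c              = tanh(x_h_prevr * w_c + b_c)
//   h              = u .* h_prev + (1 - u) .* c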
template <typename Device, typename T, bool USE_CUBLAS>
struct GRUBlockCellFprop : public GRUCell {
  GRUBlockCellFprop(const int batch_size, const int input_size,
                    const int cell_size)
      : GRUCell(batch_size, input_size, cell_size) {}

  void operator()(
      OpKernelContext* ctx, const Device& d, typename TTypes<T>::ConstMatrix x,
      typename TTypes<T>::ConstMatrix h_prev,
      typename TTypes<T>::ConstMatrix w_ru, typename TTypes<T>::ConstMatrix w_c,
      typename TTypes<T>::ConstVec b_ru, typename TTypes<T>::ConstVec b_c,
      typename TTypes<T>::Matrix r_u_bar, typename TTypes<T>::Matrix r,
      typename TTypes<T>::Matrix u, typename TTypes<T>::Matrix c,
      typename TTypes<T>::Matrix h, typename TTypes<T>::Matrix x_h_prev,
      typename TTypes<T>::Matrix x_h_prevr) {
    // Concatenate x_h_prev = [x, h_prev].
    x_h_prev.slice(x_offsets(), x_extends()).device(d) = x;
    x_h_prev.slice(h_offsets(), h_extends()).device(d) = h_prev;

    // r_u_bar = x_h_prev * w_ru + b_ru
    typename TTypes<T>::ConstMatrix const_x_h_prev(x_h_prev.data(),
                                                   x_h_prev.dimensions());
    TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
        ctx, d, false, false, typename gemm_compute_type<T>::type(1.f),
        const_x_h_prev, w_ru, typename gemm_compute_type<T>::type(0.f),
        r_u_bar);

    // Add the bias by broadcasting 'b_ru' along the batch dimension.
    Eigen::array<Eigen::DenseIndex, 2> broadcast_shape({batch_size_, 1});
    Eigen::array<Eigen::DenseIndex, 2> b_ru_shape({1, b_ru.dimensions()[0]});
    r_u_bar.device(d) += b_ru.reshape(b_ru_shape).broadcast(broadcast_shape);

    // Slice r_u_bar into r and u, and apply the sigmoid.
    r.device(d) = (r_u_bar.slice(ru_r_offset(), cell_extents())).sigmoid();
    u.device(d) = (r_u_bar.slice(ru_u_offset(), cell_extents())).sigmoid();

    // Concatenate x_h_prevr = [x, h_prev * r].
    x_h_prevr.slice(x_offsets(), x_extends()).device(d) = x;
    x_h_prevr.slice(h_offsets(), h_extends()).device(d) = h_prev * r;

    // c = tanh(x_h_prevr * w_c + b_c); note that b_c is broadcast before
    // adding.
    typename TTypes<T>::ConstMatrix const_x_h_prevr(x_h_prevr.data(),
                                                    x_h_prevr.dimensions());
    TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
        ctx, d, false, false, typename gemm_compute_type<T>::type(1.f),
        const_x_h_prevr, w_c, typename gemm_compute_type<T>::type(0.f), c);

    Eigen::array<Eigen::DenseIndex, 2> b_c_shape({1, b_c.dimensions()[0]});
    c.device(d) += (b_c.reshape(b_c_shape).broadcast(broadcast_shape));
    c.device(d) = c.tanh();

    // h = u * h_prev + (1 - u) * c, rewritten as u * (h_prev - c) + c.
    h.device(d) = u * (h_prev - c) + c;
  }
};

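// Backward pass of a single GRU time step: the gradients of the forward
// equations above, written in the same notation. d_h is the incoming gradient
// w.r.t. h; *_bar quantities are gradients w.r.t. pre-activation values.
//   d_c_bar                     = d_h .* (1 - u) .* (1 - c .* c)
//   d_u_bar                     = d_h .* (h_prev - c) .* u .* (1 - u)
//   [d_x_comp2, d_hr]           = d_c_bar * w_c^T
//   d_r_bar                     = d_hr .* h_prev .* r .* (1 - r)
//   [d_x_comp1, d_h_prev_comp1] = [d_r_bar, d_u_bar] * w_ru^T
//   d_x                         = d_x_comp1 + d_x_comp2
//   d_h_prev                    = d_h_prev_comp1 + d_hr .* r + d_h .* u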
template <typename Device, typename T, bool USE_CUBLAS>
struct GRUBlockCellBprop : public GRUCell {
  GRUBlockCellBprop(const int batch_size, const int input_size,
                    const int cell_size)
      : GRUCell(batch_size, input_size, cell_size) {}

  void operator()(
      OpKernelContext* ctx, const Device& d, typename TTypes<T>::ConstMatrix x,
      typename TTypes<T>::ConstMatrix h_prev,
      typename TTypes<T>::ConstMatrix w_ru, typename TTypes<T>::ConstMatrix w_c,
      typename TTypes<T>::ConstVec b_ru, typename TTypes<T>::ConstVec b_c,
      typename TTypes<T>::ConstMatrix r, typename TTypes<T>::ConstMatrix u,
      typename TTypes<T>::ConstMatrix c, typename TTypes<T>::ConstMatrix d_h,
      typename TTypes<T>::Matrix d_x, typename TTypes<T>::Matrix d_h_prev,
      typename TTypes<T>::Matrix d_c_bar,
      typename TTypes<T>::Matrix d_r_bar_u_bar,
      typename TTypes<T>::Matrix d_r_bar, typename TTypes<T>::Matrix d_u_bar,
      typename TTypes<T>::Matrix d_hr,
      typename TTypes<T>::Matrix d_x_comp1_and_h_prev_comp1,
      typename TTypes<T>::Matrix d_x_comp2_and_h_prevr) {
    // d_c_bar = d_h * (1 - u) * (1 - c * c)
    d_c_bar.device(d) =
        ((d_h * (u.constant(T(1)) - u)) * (c.constant(T(1)) - c * c));

    // d_u_bar = d_h * (h_prev - c) * u * (1 - u)
    d_u_bar.device(d) = d_h * (h_prev - c) * u * (u.constant(T(1)) - u);

    // [2nd_component_of_d_x d_hr] = d_c_bar X w_c^T
    typename TTypes<T>::ConstMatrix const_d_c_bar(d_c_bar.data(),
                                                  d_c_bar.dimensions());
    TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
        ctx, d, false, true, typename gemm_compute_type<T>::type(1.f),
        const_d_c_bar, w_c, typename gemm_compute_type<T>::type(0.f),
        d_x_comp2_and_h_prevr);

    d_hr.device(d) = d_x_comp2_and_h_prevr.slice(h_offsets(), h_extends());
    d_r_bar.device(d) = (d_hr * h_prev * r) * (r.constant(T(1)) - r);

    // d_r_bar_u_bar = concatenate(d_r_bar, d_u_bar) along axis = 1.
    d_r_bar_u_bar.slice(ru_r_offset(), cell_extents()).device(d) = d_r_bar;
    d_r_bar_u_bar.slice(ru_u_offset(), cell_extents()).device(d) = d_u_bar;

    // [1st_component_of_d_x 1st_component_of_d_h_prev] = [d_r_bar d_u_bar] X
    // w_ru^T
    typename TTypes<T>::ConstMatrix const_d_r_bar_u_bar(
        d_r_bar_u_bar.data(), d_r_bar_u_bar.dimensions());
    TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
        ctx, d, false, true, typename gemm_compute_type<T>::type(1.f),
        const_d_r_bar_u_bar, w_ru, typename gemm_compute_type<T>::type(0.f),
        d_x_comp1_and_h_prev_comp1);

    // d_x = d_x_comp1 + d_x_comp2
    d_x.device(d) = (d_x_comp1_and_h_prev_comp1 + d_x_comp2_and_h_prevr)
                        .slice(x_offsets(), x_extends());

    // d_h_prev = d_h_prev_comp1 + d_hr * r + d_h * u
    d_h_prev.device(d) =
        d_x_comp1_and_h_prev_comp1.slice(h_offsets(), h_extends()) +
        (d_hr * r) + (d_h * u);
  }
};

}  // namespace functor
}  // namespace tensorflow

#endif  // TENSORFLOW_CONTRIB_RNN_KERNELS_GRU_OPS_H_