/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "tensorflow/contrib/rnn/kernels/lstm_ops.h"

#include <memory>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace functor {

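// Forward pass of a single LSTM time step, as computed below with Eigen.
// Bracketed peephole terms apply only when use_peephole is set:
//
//   xh   = [x, h_prev]
//   icfo = xh * w + b                  (one GEMM covering all four gates)
//   i    = sigmoid(icfo_i [+ wci .* cs_prev])
//   ci   = tanh(icfo_c)
//   f    = sigmoid(icfo_f + forget_bias [+ wcf .* cs_prev])
//   cs   = i .* ci + f .* cs_prev      (clipped afterwards if cell_clip > 0)
//   co   = tanh(cs)
//   o    = sigmoid(icfo_o [+ wco .* cs])
//   h    = o .* co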
template <typename T>
void LSTMBlockCellFpropWithEigen(
    const LSTMBlockCell& cell, OpKernelContext* ctx, const CPUDevice& d,
    const T forget_bias, const T cell_clip, bool use_peephole,
    typename TTypes<T>::ConstMatrix x, typename TTypes<T>::ConstMatrix cs_prev,
    typename TTypes<T>::ConstMatrix h_prev, typename TTypes<T>::ConstMatrix w,
    typename TTypes<T>::ConstVec wci, typename TTypes<T>::ConstVec wcf,
    typename TTypes<T>::ConstVec wco, typename TTypes<T>::ConstVec b,
    typename TTypes<T>::Matrix xh, typename TTypes<T>::Matrix i,
    typename TTypes<T>::Matrix cs, typename TTypes<T>::Matrix f,
    typename TTypes<T>::Matrix o, typename TTypes<T>::Matrix ci,
    typename TTypes<T>::Matrix co, typename TTypes<T>::Matrix icfo,
    typename TTypes<T>::Matrix h) {
  // Concat xh = [x, h].
  xh.slice(cell.xh_x_offsets(), cell.xh_x_extents()).device(d) = x;
  xh.slice(cell.xh_h_offsets(), cell.xh_h_extents()).device(d) = h_prev;

  // icfo = [x, h_prev] * w + b
  typename TTypes<T>::ConstMatrix const_xh(xh.data(), xh.dimensions());
  TensorBlasGemm<CPUDevice, T, false /* USE_CUBLAS */>::compute(
      ctx, d, false, false, T(1), const_xh, w, T(0), icfo);
  Eigen::array<Eigen::DenseIndex, 2> b_shape({1, b.dimensions()[0]});
  Eigen::array<Eigen::DenseIndex, 2> broadcast_shape({cell.batch_size(), 1});
  icfo.device(d) += b.reshape(b_shape).broadcast(broadcast_shape);
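  // icfo packs the four gate pre-activations column-wise as [i | c | f | o],
  // each block cell_size columns wide; the icfo_*_offsets() slices below
  // select the individual gates.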

  Eigen::array<Eigen::DenseIndex, 2> p_shape({1, cell.cell_size()});
  Eigen::array<Eigen::DenseIndex, 2> p_broadcast_shape({cell.batch_size(), 1});

  // Input gate.
  if (use_peephole) {
    auto i_peep = cs_prev * wci.reshape(p_shape).broadcast(p_broadcast_shape);
    i.device(d) =
        (icfo.slice(cell.icfo_i_offsets(), cell.cell_extents()) + i_peep)
            .sigmoid();
  } else {
    i.device(d) =
        icfo.slice(cell.icfo_i_offsets(), cell.cell_extents()).sigmoid();
  }

  // Cell input.
  ci.device(d) = icfo.slice(cell.icfo_c_offsets(), cell.cell_extents()).tanh();

  // Forget gate (w/ bias).
  if (use_peephole) {
    auto f_peep = cs_prev * wcf.reshape(p_shape).broadcast(p_broadcast_shape);
    f.device(d) = (icfo.slice(cell.icfo_f_offsets(), cell.cell_extents()) +
                   f.constant(forget_bias) + f_peep)
                      .sigmoid();
  } else {
    f.device(d) = (icfo.slice(cell.icfo_f_offsets(), cell.cell_extents()) +
                   f.constant(forget_bias))
                      .sigmoid();
  }

  // cs = ci .* i + f .* cs_prev
  cs.device(d) = i * ci + f * cs_prev;

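  // When cell_clip is positive, clamp the cell state elementwise to the clip
  // range (+/- cell_clip).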
  if (cell_clip > 0.0f) {
    cs.device(d) =
        cs.binaryExpr(cs.constant(cell_clip), Eigen::scalar_clip_op<T>());
  }

  // co = tanh(cs)
  co.device(d) = cs.tanh();

  // Output gate.
  if (use_peephole) {
    auto o_peep = cs * wco.reshape(p_shape).broadcast(p_broadcast_shape);
    o.device(d) =
        (icfo.slice(cell.icfo_o_offsets(), cell.cell_extents()) + o_peep)
            .sigmoid();
  } else {
    o.device(d) =
        icfo.slice(cell.icfo_o_offsets(), cell.cell_extents()).sigmoid();
  }

  // h = o .* co
  h.device(d) = o * co;
}

template <typename Device, typename T, bool USE_CUBLAS>
void LSTMBlockCellBpropWithEigen(
    const LSTMBlockCell& cell, OpKernelContext* ctx, const Device& d,
    bool use_peephole, typename TTypes<T>::ConstMatrix x,
    typename TTypes<T>::ConstMatrix cs_prev,
    typename TTypes<T>::ConstMatrix h_prev, typename TTypes<T>::ConstMatrix w,
    typename TTypes<T>::ConstVec wci, typename TTypes<T>::ConstVec wcf,
    typename TTypes<T>::ConstVec wco, typename TTypes<T>::ConstVec b,
    typename TTypes<T>::ConstMatrix i, typename TTypes<T>::ConstMatrix cs,
    typename TTypes<T>::ConstMatrix f, typename TTypes<T>::ConstMatrix o,
    typename TTypes<T>::ConstMatrix ci, typename TTypes<T>::ConstMatrix co,
    typename TTypes<T>::ConstMatrix cs_grad,
    typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_,
    typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,
    typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,
    typename TTypes<T>::Matrix dicfo, typename TTypes<T>::Matrix cs_prev_grad,
    typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad,
    typename TTypes<T>::Vec wco_grad) {
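  // Backprop of a single time step, chaining through the fprop equations
  // above. h_grad is dL/dh[t]; cs_grad already carries dcs[t + 1] .* f[t + 1]
  // from the next time step. Since sigm' = sigm .* (1 - sigm) and
  // tanh' = 1 - tanh^2, all derivatives are evaluated from the saved
  // activations (i, f, o, ci, co) rather than the pre-activations.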
  // do[t] = sigm'(o[t]) .* dh[t] .* co[t]
  do_.device(d) = o * (o.constant(T(1)) - o) * h_grad * co;

  // dcs[t] += tanh'(cs[t]) .* dh[t] .* o[t] + dcs[t + 1] .* f[t + 1]
  dcs.device(d) = (co.constant(T(1)) - co * co) * h_grad * o + cs_grad;

  Eigen::array<Eigen::DenseIndex, 2> p_shape({1, cell.cell_size()});
  Eigen::array<Eigen::DenseIndex, 2> p_broadcast_shape({cell.batch_size(), 1});
  if (use_peephole) {
    dcs.device(d) =
        dcs + do_ * wco.reshape(p_shape).broadcast(p_broadcast_shape);
  }

  // dci[t] = tanh'(ci[t]) dcs[t] i[t]
  dci.device(d) = (ci.constant(T(1)) - ci * ci) * dcs * i;

  // df[t] = sigm'(f[t]) dcs[t] cs[t - 1]
  df.device(d) = f * (f.constant(T(1)) - f) * dcs * cs_prev;

  // di[t] = sigm'(i[t]) dcs[t] ci[t]
  di.device(d) = i * (i.constant(T(1)) - i) * dcs * ci;

  dicfo.slice(cell.icfo_i_offsets(), cell.cell_extents()).device(d) = di;
  dicfo.slice(cell.icfo_c_offsets(), cell.cell_extents()).device(d) = dci;
  dicfo.slice(cell.icfo_f_offsets(), cell.cell_extents()).device(d) = df;
  dicfo.slice(cell.icfo_o_offsets(), cell.cell_extents()).device(d) = do_;

  cs_prev_grad.device(d) = dcs * f;
  if (use_peephole) {
    cs_prev_grad.device(d) =
        cs_prev_grad + di * wci.reshape(p_shape).broadcast(p_broadcast_shape) +
        df * wcf.reshape(p_shape).broadcast(p_broadcast_shape);
    wci_grad.device(d) = (di * cs_prev).sum(Eigen::array<int, 1>({0}));
    wcf_grad.device(d) = (df * cs_prev).sum(Eigen::array<int, 1>({0}));
    wco_grad.device(d) = (do_ * cs).sum(Eigen::array<int, 1>({0}));
  }
}

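// The CPU specializations of the Fprop/Bprop functors simply forward to the
// Eigen implementations above; the macro also emits the explicit template
// instantiations.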
#define DEFINE_CPU_SPECS(T)                                                    \
  template <>                                                                  \
  void LSTMBlockCellFprop<CPUDevice, T, false /* USE_CUBLAS */>::operator()(   \
      OpKernelContext* ctx, const CPUDevice& d, const T forget_bias,           \
      const T cell_clip, bool use_peephole, typename TTypes<T>::ConstMatrix x, \
      typename TTypes<T>::ConstMatrix cs_prev,                                 \
      typename TTypes<T>::ConstMatrix h_prev,                                  \
      typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci,     \
      typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco,      \
      typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh,           \
      typename TTypes<T>::Matrix i, typename TTypes<T>::Matrix cs,             \
      typename TTypes<T>::Matrix f, typename TTypes<T>::Matrix o,              \
      typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co,            \
      typename TTypes<T>::Matrix icfo, typename TTypes<T>::Matrix h) {         \
    LSTMBlockCellFpropWithEigen<T>(                                            \
        *this, ctx, d, forget_bias, cell_clip, use_peephole, x, cs_prev,       \
        h_prev, w, wci, wcf, wco, b, xh, i, cs, f, o, ci, co, icfo, h);        \
  }                                                                            \
  template <>                                                                  \
  void LSTMBlockCellBprop<CPUDevice, T, false /* USE_CUBLAS */>::operator()(   \
      OpKernelContext* ctx, const CPUDevice& d, bool use_peephole,             \
      typename TTypes<T>::ConstMatrix x,                                       \
      typename TTypes<T>::ConstMatrix cs_prev,                                 \
      typename TTypes<T>::ConstMatrix h_prev,                                  \
      typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci,     \
      typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco,      \
      typename TTypes<T>::ConstVec b, typename TTypes<T>::ConstMatrix i,       \
      typename TTypes<T>::ConstMatrix cs, typename TTypes<T>::ConstMatrix f,   \
      typename TTypes<T>::ConstMatrix o, typename TTypes<T>::ConstMatrix ci,   \
      typename TTypes<T>::ConstMatrix co,                                      \
      typename TTypes<T>::ConstMatrix cs_grad,                                 \
      typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_,  \
      typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,          \
      typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,            \
      typename TTypes<T>::Matrix dicfo,                                        \
      typename TTypes<T>::Matrix cs_prev_grad,                                 \
      typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad,      \
      typename TTypes<T>::Vec wco_grad) {                                      \
    LSTMBlockCellBpropWithEigen<CPUDevice, T, false /* USE_CUBLAS */>(         \
        *this, ctx, d, use_peephole, x, cs_prev, h_prev, w, wci, wcf, wco, b,  \
        i, cs, f, o, ci, co, cs_grad, h_grad, do_, dcs, dci, df, di, dicfo,    \
        cs_prev_grad, wci_grad, wcf_grad, wco_grad);                           \
  }                                                                            \
  template struct LSTMBlockCellFprop<CPUDevice, T, false /* USE_CUBLAS */>;    \
  template struct LSTMBlockCellBprop<CPUDevice, T, false /* USE_CUBLAS */>;

DEFINE_CPU_SPECS(float);
#undef DEFINE_CPU_SPECS

}  // namespace functor

template <typename Device, typename T, bool USE_CUBLAS>
class LSTMBlockCellOp : public OpKernel {
 public:
  explicit LSTMBlockCellOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("forget_bias", &forget_bias_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("cell_clip", &cell_clip_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_peephole", &use_peephole_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor* x_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("x", &x_tensor));

    const Tensor* cs_prev_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("cs_prev", &cs_prev_tensor));

    const Tensor* h_prev_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("h_prev", &h_prev_tensor));

    const Tensor* w_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("w", &w_tensor));

    const Tensor* wci_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("wci", &wci_tensor));

    const Tensor* wcf_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("wcf", &wcf_tensor));

    const Tensor* wco_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("wco", &wco_tensor));

    const Tensor* b_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("b", &b_tensor));

    const int64 batch_size = x_tensor->dim_size(0);
    const int64 input_size = x_tensor->dim_size(1);
    const int64 cell_size = cs_prev_tensor->dim_size(1);

    // Sanity checks for our input shapes.
    OP_REQUIRES(ctx, cs_prev_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument("cs_prev.dims(0) != batch_size: ",
                                        cs_prev_tensor->dim_size(0), " vs. ",
                                        batch_size));
    OP_REQUIRES(ctx, cs_prev_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument("cs_prev.dims(1) != cell_size: ",
                                        cs_prev_tensor->dim_size(1), " vs. ",
                                        cell_size));

    OP_REQUIRES(ctx, h_prev_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument("h_prev.dims(0) != batch_size: ",
                                        h_prev_tensor->dim_size(0), " vs. ",
                                        batch_size));
    OP_REQUIRES(ctx, h_prev_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "h_prev.dims(1) != cell_size: ", h_prev_tensor->dim_size(1),
                    " vs. ", cell_size));

    OP_REQUIRES(ctx, w_tensor->dim_size(0) == input_size + cell_size,
                errors::InvalidArgument(
                    "w.dim_size(0) != input_size + cell_size: ",
                    w_tensor->dim_size(0), " vs. ", input_size + cell_size));
    OP_REQUIRES(ctx, w_tensor->dim_size(1) == cell_size * 4,
                errors::InvalidArgument(
                    "w.dim_size(1) != cell_size * 4: ", w_tensor->dim_size(1),
                    " vs. ", cell_size * 4));

    OP_REQUIRES(ctx, b_tensor->dim_size(0) == cell_size * 4,
                errors::InvalidArgument(
                    "b.dim_size(0) != cell_size * 4: ", b_tensor->dim_size(0),
                    " vs. ", cell_size * 4));

    // Allocate our output tensors.
    Tensor* i_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                            {"h_prev"}, "i",
                            TensorShape({batch_size, cell_size}), &i_tensor));

    Tensor* cs_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output("cs", TensorShape({batch_size, cell_size}),
                                  &cs_tensor));

    Tensor* f_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output("f", TensorShape({batch_size, cell_size}),
                                  &f_tensor));

    Tensor* o_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                            {"cs_prev"}, "o",
                            TensorShape({batch_size, cell_size}), &o_tensor));

    Tensor* ci_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output("ci", TensorShape({batch_size, cell_size}),
                                  &ci_tensor));

    Tensor* co_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output("co", TensorShape({batch_size, cell_size}),
                                  &co_tensor));

    Tensor* h_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output("h", TensorShape({batch_size, cell_size}),
                                  &h_tensor));

    // Allocate our temp tensors.
    Tensor xh_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(
                            DataTypeToEnum<T>::v(),
                            TensorShape({batch_size, input_size + cell_size}),
                            &xh_tensor));

    Tensor icfo_tensor;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                      TensorShape({batch_size, cell_size * 4}),
                                      &icfo_tensor));

    const Device& device = ctx->eigen_device<Device>();

    functor::LSTMBlockCellFprop<Device, T, USE_CUBLAS>(batch_size, input_size,
                                                       cell_size)(
        ctx, device, forget_bias_, cell_clip_, use_peephole_,
        x_tensor->matrix<T>(), cs_prev_tensor->matrix<T>(),
        h_prev_tensor->matrix<T>(), w_tensor->matrix<T>(), wci_tensor->vec<T>(),
        wcf_tensor->vec<T>(), wco_tensor->vec<T>(), b_tensor->vec<T>(),
        xh_tensor.matrix<T>(), i_tensor->matrix<T>(), cs_tensor->matrix<T>(),
        f_tensor->matrix<T>(), o_tensor->matrix<T>(), ci_tensor->matrix<T>(),
        co_tensor->matrix<T>(), icfo_tensor.matrix<T>(), h_tensor->matrix<T>());
  }

 private:
  float forget_bias_;
  float cell_clip_;
  bool use_peephole_;
};

#define REGISTER_KERNEL(T)                                             \
  REGISTER_KERNEL_BUILDER(                                             \
      Name("LSTMBlockCell").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      LSTMBlockCellOp<CPUDevice, T, false>);
REGISTER_KERNEL(float);
// REGISTER_KERNEL(double);
#undef REGISTER_KERNEL

#if GOOGLE_CUDA
namespace functor {
#define DECLARE_GPU_SPEC(T)                                                    \
  template <>                                                                  \
  void LSTMBlockCellFprop<GPUDevice, T, true>::operator()(                     \
      OpKernelContext* ctx, const GPUDevice& d, const T forget_bias,           \
      const T cell_clip, bool use_peephole, typename TTypes<T>::ConstMatrix x, \
      typename TTypes<T>::ConstMatrix cs_prev,                                 \
      typename TTypes<T>::ConstMatrix h_prev,                                  \
      typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci,     \
      typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco,      \
      typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh,           \
      typename TTypes<T>::Matrix i, typename TTypes<T>::Matrix cs,             \
      typename TTypes<T>::Matrix f, typename TTypes<T>::Matrix o,              \
      typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co,            \
      typename TTypes<T>::Matrix icfo, typename TTypes<T>::Matrix h);          \
                                                                               \
  extern template struct LSTMBlockCellFprop<GPUDevice, T, true>;

DECLARE_GPU_SPEC(float);
// DECLARE_GPU_SPEC(double);
#undef DECLARE_GPU_SPEC
}  // end namespace functor

#define REGISTER_GPU_KERNEL(T)                                         \
  REGISTER_KERNEL_BUILDER(                                             \
      Name("LSTMBlockCell").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      LSTMBlockCellOp<GPUDevice, T, true>);

REGISTER_GPU_KERNEL(float);
// REGISTER_GPU_KERNEL(double);
#undef REGISTER_GPU_KERNEL
#endif  // GOOGLE_CUDA

template <typename Device, typename T, bool USE_CUBLAS>
class LSTMBlockCellGradOp : public OpKernel {
 public:
  explicit LSTMBlockCellGradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_peephole", &use_peephole_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor* x_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("x", &x_tensor));

    const Tensor* cs_prev_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("cs_prev", &cs_prev_tensor));

    const Tensor* h_prev_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("h_prev", &h_prev_tensor));

    const Tensor* w_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("w", &w_tensor));

    const Tensor* wci_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("wci", &wci_tensor));

    const Tensor* wcf_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("wcf", &wcf_tensor));

    const Tensor* wco_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("wco", &wco_tensor));

    const Tensor* b_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("b", &b_tensor));

    const Tensor* i_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("i", &i_tensor));

    const Tensor* cs_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("cs", &cs_tensor));

    const Tensor* f_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("f", &f_tensor));

    const Tensor* o_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("o", &o_tensor));

    const Tensor* ci_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("ci", &ci_tensor));

    const Tensor* co_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("co", &co_tensor));

    const Tensor* cs_grad_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("cs_grad", &cs_grad_tensor));

    const Tensor* h_grad_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("h_grad", &h_grad_tensor));

    const int64 batch_size = x_tensor->dim_size(0);
    const int64 input_size = x_tensor->dim_size(1);
    const int64 cell_size = cs_prev_tensor->dim_size(1);

    // Sanity checks for our input shapes.
    OP_REQUIRES(ctx, cs_prev_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument("cs_prev.dims(0) != batch_size: ",
                                        cs_prev_tensor->dim_size(0), " vs. ",
                                        batch_size));
    OP_REQUIRES(ctx, cs_prev_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument("cs_prev.dims(1) != cell_size: ",
                                        cs_prev_tensor->dim_size(1), " vs. ",
                                        cell_size));

    OP_REQUIRES(ctx, h_prev_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument("h_prev.dims(0) != batch_size: ",
                                        h_prev_tensor->dim_size(0), " vs. ",
                                        batch_size));
    OP_REQUIRES(ctx, h_prev_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "h_prev.dims(1) != cell_size: ", h_prev_tensor->dim_size(1),
                    " vs. ", cell_size));

    OP_REQUIRES(ctx, w_tensor->dim_size(0) == input_size + cell_size,
                errors::InvalidArgument(
                    "w.dim_size(0) != input_size + cell_size: ",
                    w_tensor->dim_size(0), " vs. ", input_size + cell_size));
    OP_REQUIRES(ctx, w_tensor->dim_size(1) == cell_size * 4,
                errors::InvalidArgument(
                    "w.dim_size(1) != cell_size * 4: ", w_tensor->dim_size(1),
                    " vs. ", cell_size * 4));

    OP_REQUIRES(ctx, b_tensor->dim_size(0) == cell_size * 4,
                errors::InvalidArgument(
                    "b.dim_size(0) != cell_size * 4: ", b_tensor->dim_size(0),
                    " vs. ", cell_size * 4));

    OP_REQUIRES(ctx, i_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument(
                    "i.dim_size(0) != batch_size: ", i_tensor->dim_size(0),
                    " vs. ", batch_size));
    OP_REQUIRES(ctx, i_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "i.dim_size(1) != cell_size: ", i_tensor->dim_size(1),
                    " vs. ", cell_size));

    OP_REQUIRES(ctx, cs_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument(
                    "cs.dim_size(0) != batch_size: ", cs_tensor->dim_size(0),
                    " vs. ", batch_size));
    OP_REQUIRES(ctx, cs_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "cs.dim_size(1) != cell_size: ", cs_tensor->dim_size(1),
                    " vs. ", cell_size));

    OP_REQUIRES(ctx, f_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument(
                    "f.dim_size(0) != batch_size: ", f_tensor->dim_size(0),
                    " vs. ", batch_size));
    OP_REQUIRES(ctx, f_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "f.dim_size(1) != cell_size: ", f_tensor->dim_size(1),
                    " vs. ", cell_size));

    OP_REQUIRES(ctx, o_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument(
                    "o.dim_size(0) != batch_size: ", o_tensor->dim_size(0),
                    " vs. ", batch_size));
    OP_REQUIRES(ctx, o_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "o.dim_size(1) != cell_size: ", o_tensor->dim_size(1),
                    " vs. ", cell_size));

    OP_REQUIRES(ctx, ci_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument(
                    "ci.dim_size(0) != batch_size: ", ci_tensor->dim_size(0),
                    " vs. ", batch_size));
    OP_REQUIRES(ctx, ci_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "ci.dim_size(1) != cell_size: ", ci_tensor->dim_size(1),
                    " vs. ", cell_size));

    OP_REQUIRES(ctx, co_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument(
                    "co.dim_size(0) != batch_size: ", co_tensor->dim_size(0),
                    " vs. ", batch_size));
    OP_REQUIRES(ctx, co_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "co.dim_size(1) != cell_size: ", co_tensor->dim_size(1),
                    " vs. ", cell_size));

    OP_REQUIRES(ctx, cs_grad_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument(
                    "cs_grad_tensor.dims(0) != batch_size: ",
                    cs_grad_tensor->dim_size(0), " vs. ", batch_size));
    OP_REQUIRES(ctx, cs_grad_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument("cs_grad_tensor.dims(1) != cell_size: ",
                                        cs_grad_tensor->dim_size(1), " vs. ",
                                        cell_size));

    OP_REQUIRES(ctx, h_grad_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument("h_grad_tensor.dims(0) != batch_size: ",
                                        h_grad_tensor->dim_size(0), " vs. ",
                                        batch_size));
    OP_REQUIRES(ctx, h_grad_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument("h_grad_tensor.dims(1) != cell_size: ",
                                        h_grad_tensor->dim_size(1), " vs. ",
                                        cell_size));

    // Allocate our output tensors.
    Tensor* cs_prev_grad_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->forward_input_or_allocate_output(
                 {"cs_grad"}, "cs_prev_grad",
                 TensorShape({batch_size, cell_size}), &cs_prev_grad_tensor));

    Tensor* dicfo_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(
                            "dicfo", TensorShape({batch_size, cell_size * 4}),
                            &dicfo_tensor));

    Tensor* wci_grad_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->forward_input_or_allocate_output(
                 {"wci"}, "wci_grad", wci_tensor->shape(), &wci_grad_tensor));

    Tensor* wcf_grad_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->forward_input_or_allocate_output(
                 {"wcf"}, "wcf_grad", wcf_tensor->shape(), &wcf_grad_tensor));

    Tensor* wco_grad_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->forward_input_or_allocate_output(
                 {"wco"}, "wco_grad", wco_tensor->shape(), &wco_grad_tensor));

    // Allocate our temp tensors.
    Tensor do_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           TensorShape({batch_size, cell_size}),
                                           &do_tensor));

    Tensor dcs_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           TensorShape({batch_size, cell_size}),
                                           &dcs_tensor));

    Tensor dci_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           TensorShape({batch_size, cell_size}),
                                           &dci_tensor));

    Tensor df_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           TensorShape({batch_size, cell_size}),
                                           &df_tensor));

    Tensor di_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           TensorShape({batch_size, cell_size}),
                                           &di_tensor));

    const Device& device = ctx->eigen_device<Device>();

    functor::TensorZero<Device, T>()(device, wci_grad_tensor->flat<T>());
    functor::TensorZero<Device, T>()(device, wcf_grad_tensor->flat<T>());
    functor::TensorZero<Device, T>()(device, wco_grad_tensor->flat<T>());

    functor::LSTMBlockCellBprop<Device, T, USE_CUBLAS>(batch_size, input_size,
                                                       cell_size)(
        ctx, device, use_peephole_, x_tensor->matrix<T>(),
        cs_prev_tensor->matrix<T>(), h_prev_tensor->matrix<T>(),
        w_tensor->matrix<T>(), wci_tensor->vec<T>(), wcf_tensor->vec<T>(),
        wco_tensor->vec<T>(), b_tensor->vec<T>(), i_tensor->matrix<T>(),
        cs_tensor->matrix<T>(), f_tensor->matrix<T>(), o_tensor->matrix<T>(),
        ci_tensor->matrix<T>(), co_tensor->matrix<T>(),
        cs_grad_tensor->matrix<T>(), h_grad_tensor->matrix<T>(),
        do_tensor.matrix<T>(), dcs_tensor.matrix<T>(), dci_tensor.matrix<T>(),
        df_tensor.matrix<T>(), di_tensor.matrix<T>(), dicfo_tensor->matrix<T>(),
        cs_prev_grad_tensor->matrix<T>(), wci_grad_tensor->vec<T>(),
        wcf_grad_tensor->vec<T>(), wco_grad_tensor->vec<T>());
  }

 protected:
  bool use_peephole_;
};

#define REGISTER_KERNEL(T)                                                 \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("LSTMBlockCellGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      LSTMBlockCellGradOp<CPUDevice, T, false>);
REGISTER_KERNEL(float);
// REGISTER_KERNEL(double);
#undef REGISTER_KERNEL

#if GOOGLE_CUDA
namespace functor {
#define DECLARE_GPU_SPEC(T)                                                   \
  template <>                                                                 \
  void LSTMBlockCellBprop<GPUDevice, T, true>::operator()(                    \
      OpKernelContext* ctx, const GPUDevice& d, bool use_peephole,            \
      typename TTypes<T>::ConstMatrix x,                                      \
      typename TTypes<T>::ConstMatrix cs_prev,                                \
      typename TTypes<T>::ConstMatrix h_prev,                                 \
      typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci,    \
      typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco,     \
      typename TTypes<T>::ConstVec b, typename TTypes<T>::ConstMatrix i,      \
      typename TTypes<T>::ConstMatrix cs, typename TTypes<T>::ConstMatrix f,  \
      typename TTypes<T>::ConstMatrix o, typename TTypes<T>::ConstMatrix ci,  \
      typename TTypes<T>::ConstMatrix co,                                     \
      typename TTypes<T>::ConstMatrix cs_grad,                                \
      typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_, \
      typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,         \
      typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,           \
      typename TTypes<T>::Matrix dicfo,                                       \
      typename TTypes<T>::Matrix cs_prev_grad,                                \
      typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad,     \
      typename TTypes<T>::Vec wco_grad);                                      \
                                                                              \
  extern template struct LSTMBlockCellBprop<GPUDevice, T,                     \
                                            true /* USE_CUBLAS */>;

DECLARE_GPU_SPEC(float);
// DECLARE_GPU_SPEC(double);
#undef DECLARE_GPU_SPEC
}  // namespace functor

#define REGISTER_GPU_KERNEL(T)                                             \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("LSTMBlockCellGrad").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      LSTMBlockCellGradOp<GPUDevice, T, true>);

REGISTER_GPU_KERNEL(float);
// REGISTER_GPU_KERNEL(double);
#undef REGISTER_GPU_KERNEL
#endif  // GOOGLE_CUDA

namespace {

// This helper class can be used to access timeslices of a 3D tensor. If a slice
// happens to be unaligned (usually because both batch size and number of cells
// are odd - this isn't common) this involves overhead, since data needs to be
// copied. However, if all slices are aligned, the bits aren't copied. In the
// cases where copying is needed, the outputs have to be recopied back.
// At the end of each time step you should call FinishTimeStep which does this,
// and also allows for reuse of temporary tensors.
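//
// A sketch of the intended usage, mirroring BlockLSTMOp::Compute below:
//
//   SliceHelper<Device, T> slicer(ctx);
//   for (int64 t = 0; t < seq_len_max; ++t) {
//     const Tensor x_t = slicer.InputSlice(*x, t, "x");
//     Tensor h_t = slicer.OutputSlice(h_out, t, "h_out");
//     // ... run the cell on x_t / h_t ...
//     slicer.FinishTimeStep();  // copies any unaligned outputs back
//   }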
template <typename Device, typename T>
class SliceHelper {
 public:
  explicit SliceHelper(OpKernelContext* ctx)
      : ctx_(ctx), device_(ctx_->eigen_device<Device>()) {}

  ~SliceHelper() {
    CHECK(copy_out_.empty());
    for (const auto& entry : pool_) {
      CHECK(!entry.second.second);  // nothing is in use
    }
  }

  // Slice through an input tensor. This may copy unaligned slices, but no
  // copying back will be done at the end.
  const Tensor InputSlice(const Tensor& t, int pos, const string& name) {
    Tensor res = UnalignedSlice(t, pos);
    if (res.IsAligned()) {
      return res;
    } else {
      return AlignTensor(res, name);
    }
  }

  // Slice through an output tensor. This may copy unaligned slices, and
  // schedule copying back on destruction.
  Tensor OutputSlice(Tensor* t, int pos, const string& name) {
    Tensor res = UnalignedSlice(*t, pos);
    if (res.IsAligned()) {
      return res;
    } else {
      Tensor aligned = AlignTensor(res, name);
      copy_out_.emplace_back(res, aligned);
      return aligned;
    }
  }

  void FinishTimeStep() {
    for (const auto& p : copy_out_) {
      const Tensor& aligned = p.second;
      Tensor original = p.first;
      // Copy from aligned back to original.
      functor::TensorCopyToUnaligned<Device, T>()(device_, aligned.flat<T>(),
                                                  original.unaligned_flat<T>());
    }
    copy_out_.clear();
    // Mark all entries as not in use.
    for (auto& entry : pool_) {
      entry.second.second = false;
    }
  }

 private:
  // Return a slice at position 'pos'. Result may be unaligned. The resulting
  // tensor always shares data with the source tensor.
  Tensor UnalignedSlice(const Tensor& t, int pos) const {
    Tensor res;
    // CHECK should never fail here, since the number of elements must match
    CHECK(res.CopyFrom(t.Slice(pos, pos + 1), {t.dim_size(1), t.dim_size(2)}));
    return res;
  }

  // Assumes input is not aligned, creates a temporary aligned tensor of the
  // same shape and copies the original tensor's content into it.
  Tensor AlignTensor(const Tensor& t, const string& name) {
    VLOG(1) << "AlignTensor called for " << name << ", shape "
            << t.shape().DebugString()
            << ". This is unnecessary copying. Consider using shapes with even "
            << "sizes";
    Tensor aligned;
    auto found = pool_.find(name);
    if (found != pool_.end()) {  // found in pool
      CHECK(!found->second.second) << "Tensor " << name << " is in use";
      found->second.second = true;  // mark in use
      aligned = found->second.first;
      CHECK(aligned.shape().IsSameSize(t.shape()));
      CHECK_EQ(aligned.dtype(), t.dtype());
    } else {  // allocate a new temporary tensor
      TF_CHECK_OK(ctx_->allocate_temp(t.dtype(), t.shape(), &aligned));
      pool_.emplace(name, std::make_pair(aligned, true));
    }
    functor::TensorCopyUnaligned<Device, T>()(device_, t.unaligned_flat<T>(),
                                              aligned.flat<T>());
    return aligned;
  }

  // Tensors to be copied.
  std::vector<std::pair<Tensor, const Tensor>> copy_out_;
  // A pool of pre-allocated temporary tensors, with an indicator for whether
  // it's in use.
  std::map<string, std::pair<Tensor, bool>> pool_;
  // Op context
  OpKernelContext* ctx_ = nullptr;
  // Device
  const Device& device_;
};

}  // namespace

template <typename Device, typename T, bool USE_CUBLAS>
class BlockLSTMOp : public OpKernel {
 public:
  explicit BlockLSTMOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("forget_bias", &forget_bias_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("cell_clip", &cell_clip_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_peephole", &use_peephole_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor* seq_len_max_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("seq_len_max", &seq_len_max_tensor));

    const Tensor* x;
    OP_REQUIRES_OK(ctx, ctx->input("x", &x));
    OP_REQUIRES(ctx, x->dims() == 3, errors::InvalidArgument("x must be 3D"));
    const int64 timelen = x->dim_size(0);
    const int64 batch_size = x->dim_size(1);
    const int64 input_size = x->dim_size(2);

    const Tensor* cs_prev_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("cs_prev", &cs_prev_tensor));
    OP_REQUIRES(ctx, cs_prev_tensor->dims() == 2,
                errors::InvalidArgument("cs_prev must be 2D"));
    OP_REQUIRES(ctx, cs_prev_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument("cs_prev.dims(0) != batch_size: ",
                                        cs_prev_tensor->dim_size(0), " vs. ",
                                        batch_size));
    const int64 cell_size = cs_prev_tensor->dim_size(1);

    if (batch_size * input_size % 2 == 1) {
      LOG(WARNING) << "BlockLSTMOp is inefficient when both batch_size and "
                   << "input_size are odd. You are using: batch_size="
                   << batch_size << ", input_size=" << input_size;
    }
    if (batch_size * cell_size % 2 == 1) {
      LOG(WARNING) << "BlockLSTMOp is inefficient when both batch_size and "
                   << "cell_size are odd. You are using: batch_size="
                   << batch_size << ", cell_size=" << cell_size;
    }
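    // The odd-size products checked above can leave time slices of the 3-D
    // inputs/outputs unaligned, forcing the SliceHelper in the loop below to
    // copy them back and forth (see the SliceHelper class comment above).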

    const Tensor* h_prev_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("h_prev", &h_prev_tensor));
    OP_REQUIRES(ctx, h_prev_tensor->dims() == 2,
                errors::InvalidArgument("h_prev must be 2D"));
    OP_REQUIRES(ctx, h_prev_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument("h_prev.dims(0) != batch_size: ",
                                        h_prev_tensor->dim_size(0), " vs. ",
                                        batch_size));
    OP_REQUIRES(ctx, h_prev_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "h_prev.dims(1) != cell_size: ", h_prev_tensor->dim_size(1),
                    " vs. ", cell_size));

    const Tensor* w_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("w", &w_tensor));
    OP_REQUIRES(ctx, w_tensor->dims() == 2,
                errors::InvalidArgument("w must be 2D"));
    OP_REQUIRES(ctx, w_tensor->dim_size(0) == input_size + cell_size,
                errors::InvalidArgument(
                    "w.dim_size(0) != input_size + cell_size: ",
                    w_tensor->dim_size(0), " vs. ", input_size + cell_size));
    OP_REQUIRES(ctx, w_tensor->dim_size(1) == cell_size * 4,
                errors::InvalidArgument(
                    "w.dim_size(1) != cell_size * 4: ", w_tensor->dim_size(1),
                    " vs. ", cell_size * 4));

    const Tensor* wci_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("wci", &wci_tensor));
    OP_REQUIRES(ctx, wci_tensor->dims() == 1,
                errors::InvalidArgument("wci must be 1D"));
    OP_REQUIRES(ctx, wci_tensor->dim_size(0) == cell_size,
                errors::InvalidArgument(
                    "wci.dim_size(0) != cell_size: ", wci_tensor->dim_size(0),
                    " vs. ", cell_size));

    const Tensor* wcf_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("wcf", &wcf_tensor));
    OP_REQUIRES(ctx, wcf_tensor->dims() == 1,
                errors::InvalidArgument("wcf must be 1D"));
    OP_REQUIRES(ctx, wcf_tensor->dim_size(0) == cell_size,
                errors::InvalidArgument(
                    "wcf.dim_size(0) != cell_size: ", wcf_tensor->dim_size(0),
                    " vs. ", cell_size));

    const Tensor* wco_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("wco", &wco_tensor));
    OP_REQUIRES(ctx, wco_tensor->dims() == 1,
                errors::InvalidArgument("wco must be 1D"));
    OP_REQUIRES(ctx, wco_tensor->dim_size(0) == cell_size,
                errors::InvalidArgument(
                    "wco.dim_size(0) != cell_size: ", wco_tensor->dim_size(0),
                    " vs. ", cell_size));

    const Tensor* b_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("b", &b_tensor));
    OP_REQUIRES(ctx, b_tensor->dims() == 1,
                errors::InvalidArgument("b must be 1D"));
    OP_REQUIRES(ctx, b_tensor->dim_size(0) == cell_size * 4,
                errors::InvalidArgument(
                    "b.dim_size(0) != cell_size * 4: ", b_tensor->dim_size(0),
                    " vs. ", cell_size * 4));

    TensorShape batch_cell_shape({timelen, batch_size, cell_size});
    Tensor* i_out;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("i", batch_cell_shape, &i_out));

    Tensor* cs_out;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("cs", batch_cell_shape, &cs_out));

    Tensor* f_out;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("f", batch_cell_shape, &f_out));

    Tensor* o_out;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("o", batch_cell_shape, &o_out));

    Tensor* ci_out;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("ci", batch_cell_shape, &ci_out));

    Tensor* co_out;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("co", batch_cell_shape, &co_out));

    Tensor* h_out;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("h", batch_cell_shape, &h_out));

    Tensor xh_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(
                            DataTypeToEnum<T>::v(),
                            TensorShape({batch_size, input_size + cell_size}),
                            &xh_tensor));

    Tensor icfo_tensor;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                      TensorShape({batch_size, cell_size * 4}),
                                      &icfo_tensor));

    const Device& device = ctx->eigen_device<Device>();

    const int64 seq_len_max = seq_len_max_tensor->scalar<int64>()();
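    // Only the first seq_len_max time steps are computed; the cs and h
    // outputs past that point are zero-filled after the loop.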
    SliceHelper<Device, T> slicer(ctx);
    for (int64 t = 0; t < seq_len_max; ++t) {
      const Tensor x_tensor = slicer.InputSlice(*x, t, "x");
      const Tensor& cs_prev_tensor2 =
          t == 0 ? *cs_prev_tensor
                 : slicer.OutputSlice(cs_out, t - 1, "cs_prev");
      const Tensor& h_prev_tensor2 =
          t == 0 ? *h_prev_tensor : slicer.OutputSlice(h_out, t - 1, "h_prev");

      Tensor i_tensor = slicer.OutputSlice(i_out, t, "i_out");
      Tensor cs_tensor = slicer.OutputSlice(cs_out, t, "cs_out");
      Tensor f_tensor = slicer.OutputSlice(f_out, t, "f_out");
      Tensor o_tensor = slicer.OutputSlice(o_out, t, "o_out");
      Tensor ci_tensor = slicer.OutputSlice(ci_out, t, "ci_out");
      Tensor co_tensor = slicer.OutputSlice(co_out, t, "co_out");
      Tensor h_tensor = slicer.OutputSlice(h_out, t, "h_out");

      functor::LSTMBlockCellFprop<Device, T, USE_CUBLAS>(batch_size, input_size,
                                                         cell_size)(
          ctx, device, forget_bias_, cell_clip_, use_peephole_,
          x_tensor.matrix<T>(), cs_prev_tensor2.matrix<T>(),
          h_prev_tensor2.matrix<T>(), w_tensor->matrix<T>(),
          wci_tensor->vec<T>(), wcf_tensor->vec<T>(), wco_tensor->vec<T>(),
          b_tensor->vec<T>(), xh_tensor.matrix<T>(), i_tensor.matrix<T>(),
          cs_tensor.matrix<T>(), f_tensor.matrix<T>(), o_tensor.matrix<T>(),
          ci_tensor.matrix<T>(), co_tensor.matrix<T>(), icfo_tensor.matrix<T>(),
          h_tensor.matrix<T>());
      slicer.FinishTimeStep();
    }

    if (seq_len_max < timelen) {
      Tensor cs_tensor = cs_out->Slice(seq_len_max, timelen);
      Tensor h_tensor = h_out->Slice(seq_len_max, timelen);

      functor::TensorUnalignedZero<Device, T>()(
          device, cs_tensor.unaligned_flat<T>());
      functor::TensorUnalignedZero<Device, T>()(
          device, h_tensor.unaligned_flat<T>());
    991     }
    992   }
    993 
    994  private:
    995   float forget_bias_;
    996   float cell_clip_;
    997   bool use_peephole_;
    998 };
    999 
   1000 #define REGISTER_KERNEL(T)                                         \
   1001   REGISTER_KERNEL_BUILDER(                                         \
   1002       Name("BlockLSTM").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
   1003       BlockLSTMOp<CPUDevice, T, false>);
   1004 REGISTER_KERNEL(float);
   1005 // REGISTER_KERNEL(double);
   1006 #undef REGISTER_KERNEL
   1007 
   1008 #if GOOGLE_CUDA
   1009 namespace functor {
   1010 #define DECLARE_GPU_SPEC(T)                                              \
   1011   template <>                                                            \
   1012   void TensorZero<GPUDevice, T>::operator()(const GPUDevice& d,          \
   1013                                             typename TTypes<T>::Flat t); \
   1014                                                                          \
   1015   extern template struct TensorZero<GPUDevice, T>;                       \
   1016                                                                          \
   1017   template <>                                                            \
   1018   void TensorUnalignedZero<GPUDevice, T>::operator()(                    \
   1019       const GPUDevice& d, typename TTypes<T>::UnalignedFlat t);          \
   1020                                                                          \
   1021   extern template struct TensorUnalignedZero<GPUDevice, T>;
   1022 
   1023 DECLARE_GPU_SPEC(float);
   1024 // DECLARE_GPU_SPEC(double);
   1025 #undef DECLARE_GPU_SPEC
   1026 }  // end namespace functor
   1027 
   1028 #define REGISTER_GPU_KERNEL(T)                           \
   1029   REGISTER_KERNEL_BUILDER(Name("BlockLSTM")              \
   1030                               .Device(DEVICE_GPU)        \
   1031                               .HostMemory("seq_len_max") \
   1032                               .TypeConstraint<T>("T"),   \
   1033                           BlockLSTMOp<GPUDevice, T, true>);
   1034 
   1035 REGISTER_GPU_KERNEL(float);
   1036 // REGISTER_GPU_KERNEL(double);
   1037 #undef REGISTER_GPU_KERNEL
   1038 #endif  // GOOGLE_CUDA
   1039 
   1040 template <typename Device, typename T, bool USE_CUBLAS>
   1041 class BlockLSTMGradOp : public OpKernel {
   1042  public:
   1043   explicit BlockLSTMGradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
   1044     OP_REQUIRES_OK(ctx, ctx->GetAttr("use_peephole", &use_peephole_));
   1045   }
   1046 
   1047   void Compute(OpKernelContext* ctx) override {
   1048     const Tensor* seq_len_max_tensor = nullptr;
   1049     OP_REQUIRES_OK(ctx, ctx->input("seq_len_max", &seq_len_max_tensor));
   1050 
   1051     const Tensor* x;
   1052     OP_REQUIRES_OK(ctx, ctx->input("x", &x));
   1053     OP_REQUIRES(ctx, x->dims() == 3, errors::InvalidArgument("x must be 3D"));
   1054     const int64 timelen = x->dim_size(0);
   1055     const int64 batch_size = x->dim_size(1);
   1056     const int64 input_size = x->dim_size(2);
   1057 
   1058     const Tensor* cs_prev_tensor = nullptr;
   1059     OP_REQUIRES_OK(ctx, ctx->input("cs_prev", &cs_prev_tensor));
   1060 
   1061     const Tensor* h_prev_tensor = nullptr;
   1062     OP_REQUIRES_OK(ctx, ctx->input("h_prev", &h_prev_tensor));
   1063 
   1064     const Tensor* w_tensor = nullptr;
   1065     OP_REQUIRES_OK(ctx, ctx->input("w", &w_tensor));
   1066     const int64 cell_size = w_tensor->dim_size(1) / 4;
   1067     OP_REQUIRES(ctx, input_size + cell_size == w_tensor->dim_size(0),
   1068                 errors::InvalidArgument(
   1069                     "w matrix rows don't match: ", input_size + cell_size,
   1070                     " vs. ", w_tensor->dim_size(0)));
   1071 
    const Tensor* wci_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("wci", &wci_tensor));

    const Tensor* wcf_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("wcf", &wcf_tensor));

    const Tensor* wco_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("wco", &wco_tensor));

    const Tensor* b_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("b", &b_tensor));
    OP_REQUIRES(
        ctx, cell_size == b_tensor->dim_size(0) / 4,
        errors::InvalidArgument("w and b cell_size don't match: ", cell_size,
                                " vs. ", b_tensor->dim_size(0) / 4));

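    // Per-time-step activations saved by the forward BlockLSTM op; reusing
    // them here avoids recomputing the forward pass.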
    const Tensor* i_out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("i", &i_out));

    const Tensor* cs_out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("cs", &cs_out));

    const Tensor* f_out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("f", &f_out));

    const Tensor* o_out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("o", &o_out));

    const Tensor* ci_out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("ci", &ci_out));

    const Tensor* co_out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("co", &co_out));

    const Tensor* h_out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("h", &h_out));

    const Tensor* cs_grad = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("cs_grad", &cs_grad));

    const Tensor* h_grad = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("h_grad", &h_grad));

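    // Each gradient output mirrors the shape of its forward-pass counterpart.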
    TensorShape batch_input_shape({timelen, batch_size, input_size});
    Tensor* x_grad = nullptr;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output("x_grad", batch_input_shape, &x_grad));

    Tensor* cs_prev_grad_tensor = nullptr;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output("cs_prev_grad", cs_prev_tensor->shape(),
                                        &cs_prev_grad_tensor));

    Tensor* h_prev_grad_tensor = nullptr;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output("h_prev_grad", h_prev_tensor->shape(),
                                        &h_prev_grad_tensor));

    Tensor* w_grad_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output("w_grad", w_tensor->shape(), &w_grad_tensor));

    Tensor* wci_grad_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("wci_grad", wci_tensor->shape(),
                                             &wci_grad_tensor));

    Tensor* wcf_grad_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("wcf_grad", wcf_tensor->shape(),
                                             &wcf_grad_tensor));

    Tensor* wco_grad_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("wco_grad", wco_tensor->shape(),
                                             &wco_grad_tensor));

    Tensor* b_grad_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output("b_grad", b_tensor->shape(), &b_grad_tensor));

    TensorShape batch_cell_shape({batch_size, cell_size});

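    // Scratch tensors, allocated once and reused on every iteration of the
    // backward loop below.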
    Tensor xh_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(
                            DataTypeToEnum<T>::v(),
                            TensorShape({batch_size, input_size + cell_size}),
                            &xh_tensor));

    Tensor xh_grad_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           xh_tensor.shape(), &xh_grad_tensor));

    Tensor do_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           batch_cell_shape, &do_tensor));

    Tensor dcs_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           batch_cell_shape, &dcs_tensor));

    Tensor dci_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           batch_cell_shape, &dci_tensor));

    Tensor df_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           batch_cell_shape, &df_tensor));

    Tensor di_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           batch_cell_shape, &di_tensor));

    Tensor dicfo_tensor;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                      TensorShape({batch_size, cell_size * 4}),
                                      &dicfo_tensor));

    Tensor cs_grad_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           batch_cell_shape, &cs_grad_tensor));

    Tensor h_grad_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           batch_cell_shape, &h_grad_tensor));

    const Device& device = ctx->eigen_device<Device>();

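    // Gradients are summed across time steps, so zero every accumulator
    // before entering the loop.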
    functor::TensorZero<Device, T>()(device, cs_grad_tensor.flat<T>());
    functor::TensorZero<Device, T>()(device, cs_prev_grad_tensor->flat<T>());
    functor::TensorZero<Device, T>()(device, h_grad_tensor.flat<T>());
    functor::TensorZero<Device, T>()(device, h_prev_grad_tensor->flat<T>());
    functor::TensorZero<Device, T>()(device, w_grad_tensor->flat<T>());
    functor::TensorZero<Device, T>()(device, wci_grad_tensor->flat<T>());
    functor::TensorZero<Device, T>()(device, wcf_grad_tensor->flat<T>());
    functor::TensorZero<Device, T>()(device, wco_grad_tensor->flat<T>());
    functor::TensorZero<Device, T>()(device, b_grad_tensor->flat<T>());

    const int64 seq_len_max = seq_len_max_tensor->scalar<int64>()();
    SliceHelper<Device, T> slicer(ctx);
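    // Backpropagation through time: walk the sequence in reverse. At t == 0
    // the op's cs_prev/h_prev inputs stand in for the previous step's outputs.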
    for (int64 t = seq_len_max - 1; t >= 0; --t) {
      const Tensor& x_tensor = slicer.InputSlice(*x, t, "x");
      const Tensor& cs_prev_tensor2 =
          t == 0 ? *cs_prev_tensor
                 : slicer.InputSlice(*cs_out, t - 1, "cs_prev");
      const Tensor& h_prev_tensor2 =
          t == 0 ? *h_prev_tensor : slicer.InputSlice(*h_out, t - 1, "h_prev");
      const Tensor& i_tensor = slicer.InputSlice(*i_out, t, "i_out");
      const Tensor& cs_tensor = slicer.InputSlice(*cs_out, t, "cs_out");
      const Tensor& f_tensor = slicer.InputSlice(*f_out, t, "f_out");
      const Tensor& o_tensor = slicer.InputSlice(*o_out, t, "o_out");
      const Tensor& ci_tensor = slicer.InputSlice(*ci_out, t, "ci_out");
      const Tensor& co_tensor = slicer.InputSlice(*co_out, t, "co_out");

      // Accumulate the cell-state gradient: the gradient carried back from
      // step t + 1 plus the incoming cs_grad slice for this step.
      const Tensor& const_cs_prev_grad_tensor = *cs_prev_grad_tensor;
      const Tensor const_cs_grad_slice =
          slicer.InputSlice(*cs_grad, t, "cs_grad");
      functor::TensorAdd<Device, T>()(
          device, const_cs_prev_grad_tensor.flat<T>(),
          const_cs_grad_slice.flat<T>(), cs_grad_tensor.flat<T>());

      // Likewise for the hidden-state gradient: combine the gradient carried
      // back from step t + 1 with the incoming h_grad slice for this step.
      const Tensor& const_h_prev_grad_tensor = *h_prev_grad_tensor;
      const Tensor const_h_grad_slice = slicer.InputSlice(*h_grad, t, "h_grad");
      functor::TensorAdd<Device, T>()(
          device, const_h_prev_grad_tensor.flat<T>(),
          const_h_grad_slice.flat<T>(), h_grad_tensor.flat<T>());

      const Tensor& const_cs_grad_tensor = cs_grad_tensor;
      const Tensor& const_h_grad_tensor = h_grad_tensor;

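      // One fused backward step: computes the gate deltas for step t, writes
      // the gradients w.r.t. x, cs_prev, and h_prev, and accumulates the
      // weight, peephole, and bias gradients in place.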
      Tensor x_grad_tensor = slicer.OutputSlice(x_grad, t, "x_grad");
      functor::BlockLSTMBprop<Device, T, USE_CUBLAS>(batch_size, input_size,
                                                     cell_size)(
          ctx, device, use_peephole_, x_tensor.matrix<T>(),
          cs_prev_tensor2.matrix<T>(), h_prev_tensor2.matrix<T>(),
          w_tensor->matrix<T>(), wci_tensor->vec<T>(), wcf_tensor->vec<T>(),
          wco_tensor->vec<T>(), b_tensor->vec<T>(), xh_tensor.matrix<T>(),
          i_tensor.matrix<T>(), cs_tensor.matrix<T>(), f_tensor.matrix<T>(),
          o_tensor.matrix<T>(), ci_tensor.matrix<T>(), co_tensor.matrix<T>(),
          const_cs_grad_tensor.matrix<T>(), const_h_grad_tensor.matrix<T>(),
          do_tensor.matrix<T>(), dcs_tensor.matrix<T>(), dci_tensor.matrix<T>(),
          df_tensor.matrix<T>(), di_tensor.matrix<T>(),
          dicfo_tensor.matrix<T>(), cs_prev_grad_tensor->matrix<T>(),
          h_prev_grad_tensor->matrix<T>(), xh_grad_tensor.matrix<T>(),
          x_grad_tensor.matrix<T>(), w_grad_tensor->matrix<T>(),
          wci_grad_tensor->vec<T>(), wcf_grad_tensor->vec<T>(),
          wco_grad_tensor->vec<T>(), b_grad_tensor->vec<T>());
      slicer.FinishTimeStep();
    }

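    // Steps at or beyond seq_len_max are padding, so their x gradients are
    // identically zero.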
    if (seq_len_max < timelen) {
      Tensor x_grad_tensor = x_grad->Slice(seq_len_max, timelen);
      functor::TensorUnalignedZero<Device, T>()(
          device, x_grad_tensor.unaligned_flat<T>());
    }
  }

 private:
  bool use_peephole_;
};

#define REGISTER_KERNEL(T)                                             \
  REGISTER_KERNEL_BUILDER(                                             \
      Name("BlockLSTMGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      BlockLSTMGradOp<CPUDevice, T, false>);
REGISTER_KERNEL(float);
// REGISTER_KERNEL(double);
#undef REGISTER_KERNEL

#if GOOGLE_CUDA
namespace functor {
#define DECLARE_GPU_SPEC(T)                                                    \
  template <>                                                                  \
  void TensorCopy<GPUDevice, T>::operator()(const GPUDevice& d,                \
                                            typename TTypes<T>::ConstFlat src, \
                                            typename TTypes<T>::Flat dst);     \
                                                                               \
  template <>                                                                  \
  void TensorCopyUnaligned<GPUDevice, T>::operator()(                          \
      const GPUDevice& d, typename TTypes<T>::UnalignedConstFlat src,          \
      typename TTypes<T>::Flat dst);                                           \
                                                                               \
  template <>                                                                  \
  void TensorCopyToUnaligned<GPUDevice, T>::operator()(                        \
      const GPUDevice& d, typename TTypes<T>::ConstFlat src,                   \
      typename TTypes<T>::UnalignedFlat dst);                                  \
                                                                               \
  template <>                                                                  \
  void TensorAdd<GPUDevice, T>::operator()(                                    \
      const GPUDevice& d, typename TTypes<T>::ConstFlat a,                     \
      typename TTypes<T>::ConstFlat b, typename TTypes<T>::Flat c);            \
                                                                               \
  template <>                                                                  \
  void BlockLSTMBprop<GPUDevice, T, true>::operator()(                         \
      OpKernelContext* ctx, const GPUDevice& d, bool use_peephole,             \
      typename TTypes<T>::ConstMatrix x,                                       \
      typename TTypes<T>::ConstMatrix cs_prev,                                 \
      typename TTypes<T>::ConstMatrix h_prev,                                  \
      typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci,     \
      typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco,      \
      typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh,           \
      typename TTypes<T>::ConstMatrix i, typename TTypes<T>::ConstMatrix cs,   \
      typename TTypes<T>::ConstMatrix f, typename TTypes<T>::ConstMatrix o,    \
      typename TTypes<T>::ConstMatrix ci, typename TTypes<T>::ConstMatrix co,  \
      typename TTypes<T>::ConstMatrix cs_grad,                                 \
      typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_,  \
      typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,          \
      typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,            \
      typename TTypes<T>::Matrix dicfo,                                        \
      typename TTypes<T>::Matrix cs_prev_grad,                                 \
      typename TTypes<T>::Matrix h_prev_grad,                                  \
      typename TTypes<T>::Matrix xh_grad, typename TTypes<T>::Matrix x_grad,   \
      typename TTypes<T>::Matrix w_grad, typename TTypes<T>::Vec wci_grad,     \
      typename TTypes<T>::Vec wcf_grad, typename TTypes<T>::Vec wco_grad,      \
      typename TTypes<T>::Vec b_grad);                                         \
                                                                               \
  extern template struct TensorCopy<GPUDevice, T>;                             \
  extern template struct TensorAdd<GPUDevice, T>;                              \
  extern template struct BlockLSTMBprop<GPUDevice, T, true>;

DECLARE_GPU_SPEC(float);
// DECLARE_GPU_SPEC(double);
#undef DECLARE_GPU_SPEC
}  // end namespace functor

#define REGISTER_GPU_KERNEL(T)                           \
  REGISTER_KERNEL_BUILDER(Name("BlockLSTMGrad")          \
                              .Device(DEVICE_GPU)        \
                              .HostMemory("seq_len_max") \
                              .TypeConstraint<T>("T"),   \
                          BlockLSTMGradOp<GPUDevice, T, true>);

REGISTER_GPU_KERNEL(float);
// REGISTER_GPU_KERNEL(double);
#undef REGISTER_GPU_KERNEL
#endif  // GOOGLE_CUDA

}  // end namespace tensorflow