/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "tensorflow/contrib/rnn/kernels/lstm_ops.h"

#include <memory>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace functor {

template <typename T>
void LSTMBlockCellFpropWithEigen(
    const LSTMBlockCell& cell, OpKernelContext* ctx, const CPUDevice& d,
    const float forget_bias, const float cell_clip, bool use_peephole,
    typename TTypes<T>::ConstMatrix x, typename TTypes<T>::ConstMatrix cs_prev,
    typename TTypes<T>::ConstMatrix h_prev, typename TTypes<T>::ConstMatrix w,
    typename TTypes<T>::ConstVec wci, typename TTypes<T>::ConstVec wcf,
    typename TTypes<T>::ConstVec wco, typename TTypes<T>::ConstVec b,
    typename TTypes<T>::Matrix xh, typename TTypes<T>::Matrix i,
    typename TTypes<T>::Matrix cs, typename TTypes<T>::Matrix f,
    typename TTypes<T>::Matrix o, typename TTypes<T>::Matrix ci,
    typename TTypes<T>::Matrix co, typename TTypes<T>::Matrix icfo,
    typename TTypes<T>::Matrix h) {
  // Concat xh = [x, h].
  xh.slice(cell.xh_x_offsets(), cell.xh_x_extents()).device(d) = x;
  xh.slice(cell.xh_h_offsets(), cell.xh_h_extents()).device(d) = h_prev;

  // icfo = [x, h_prev] * w + b (fused pre-activations for the four gates).
  typename TTypes<T>::ConstMatrix const_xh(xh.data(), xh.dimensions());
  TensorBlasGemm<CPUDevice, T, false /* USE_CUBLAS */>::compute(
      ctx, d, false, false, typename gemm_compute_type<T>::type(1.f), const_xh,
      w, typename gemm_compute_type<T>::type(0.f), icfo);
  Eigen::array<Eigen::DenseIndex, 2> b_shape({1, b.dimensions()[0]});
  Eigen::array<Eigen::DenseIndex, 2> broadcast_shape({cell.batch_size(), 1});
  icfo.device(d) += b.reshape(b_shape).broadcast(broadcast_shape);

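  // The peephole weights are per-cell vectors; reshaping them to
  // [1, cell_size] and broadcasting over the batch gives [batch, cell_size].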
  Eigen::array<Eigen::DenseIndex, 2> p_shape({1, cell.cell_size()});
  Eigen::array<Eigen::DenseIndex, 2> p_broadcast_shape({cell.batch_size(), 1});

  // Input gate.
  if (use_peephole) {
    auto i_peep = cs_prev * wci.reshape(p_shape).broadcast(p_broadcast_shape);
    i.device(d) =
        (icfo.slice(cell.icfo_i_offsets(), cell.cell_extents()) + i_peep)
            .sigmoid();
  } else {
    i.device(d) =
        icfo.slice(cell.icfo_i_offsets(), cell.cell_extents()).sigmoid();
  }

  // Cell input.
  ci.device(d) = icfo.slice(cell.icfo_c_offsets(), cell.cell_extents()).tanh();

  // Forget gate (w/ bias).
  if (use_peephole) {
    auto f_peep = cs_prev * wcf.reshape(p_shape).broadcast(p_broadcast_shape);
    f.device(d) = (icfo.slice(cell.icfo_f_offsets(), cell.cell_extents()) +
                   f.constant(T(forget_bias)) + f_peep)
                      .sigmoid();
  } else {
    f.device(d) = (icfo.slice(cell.icfo_f_offsets(), cell.cell_extents()) +
                   f.constant(T(forget_bias)))
                      .sigmoid();
  }

  // cs = ci .* i + f .* cs_prev
  cs.device(d) = i * ci + f * cs_prev;

  if (cell_clip > 0.0f) {
    cs.device(d) =
        cs.binaryExpr(cs.constant(T(cell_clip)), Eigen::scalar_clip_op<T>());
  }

  // co = tanh(cs)
  co.device(d) = cs.tanh();

  // Output gate.
  if (use_peephole) {
    auto o_peep = cs * wco.reshape(p_shape).broadcast(p_broadcast_shape);
    o.device(d) =
        (icfo.slice(cell.icfo_o_offsets(), cell.cell_extents()) + o_peep)
            .sigmoid();
  } else {
    o.device(d) =
        icfo.slice(cell.icfo_o_offsets(), cell.cell_extents()).sigmoid();
  }

  // h = o .* co
  h.device(d) = o * co;
}

template <typename Device, typename T, bool USE_CUBLAS>
void LSTMBlockCellBpropWithEigen(
    const LSTMBlockCell& cell, OpKernelContext* ctx, const Device& d,
    bool use_peephole, typename TTypes<T>::ConstMatrix x,
    typename TTypes<T>::ConstMatrix cs_prev,
    typename TTypes<T>::ConstMatrix h_prev, typename TTypes<T>::ConstMatrix w,
    typename TTypes<T>::ConstVec wci, typename TTypes<T>::ConstVec wcf,
    typename TTypes<T>::ConstVec wco, typename TTypes<T>::ConstVec b,
    typename TTypes<T>::ConstMatrix i, typename TTypes<T>::ConstMatrix cs,
    typename TTypes<T>::ConstMatrix f, typename TTypes<T>::ConstMatrix o,
    typename TTypes<T>::ConstMatrix ci, typename TTypes<T>::ConstMatrix co,
    typename TTypes<T>::ConstMatrix cs_grad,
    typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_,
    typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,
    typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,
    typename TTypes<T>::Matrix dicfo, typename TTypes<T>::Matrix cs_prev_grad,
    typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad,
    typename TTypes<T>::Vec wco_grad) {
  // do[t] = sigm'(o[t]) .* dh[t] .* co[t]
  do_.device(d) = o * (o.constant(T(1)) - o) * h_grad * co;

  // dcs[t] += tanh'(cs[t]) .* dh[t] .* o[t] + dcs[t + 1] .* f[t + 1]
  dcs.device(d) = (co.constant(T(1)) - co * co) * h_grad * o + cs_grad;

  Eigen::array<Eigen::DenseIndex, 2> p_shape({1, cell.cell_size()});
  Eigen::array<Eigen::DenseIndex, 2> p_broadcast_shape({cell.batch_size(), 1});
  if (use_peephole) {
    dcs.device(d) =
        dcs + do_ * wco.reshape(p_shape).broadcast(p_broadcast_shape);
  }

  // dci[t] = tanh'(ci[t]) dcs[t] i[t]
  dci.device(d) = (ci.constant(T(1)) - ci * ci) * dcs * i;

  // df[t] = sigm'(f[t]) dcs[t] cs[t - 1]
  df.device(d) = f * (f.constant(T(1)) - f) * dcs * cs_prev;

  // di[t] = sigm'(i[t]) dcs[t] ci[t]
  di.device(d) = i * (i.constant(T(1)) - i) * dcs * ci;

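  // Pack the per-gate gradients into the fused dicfo layout [i, ci, f, o],
  // mirroring the icfo tensor used in the forward pass.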
  dicfo.slice(cell.icfo_i_offsets(), cell.cell_extents()).device(d) = di;
  dicfo.slice(cell.icfo_c_offsets(), cell.cell_extents()).device(d) = dci;
  dicfo.slice(cell.icfo_f_offsets(), cell.cell_extents()).device(d) = df;
  dicfo.slice(cell.icfo_o_offsets(), cell.cell_extents()).device(d) = do_;

  cs_prev_grad.device(d) = dcs * f;
  if (use_peephole) {
    cs_prev_grad.device(d) =
        cs_prev_grad + di * wci.reshape(p_shape).broadcast(p_broadcast_shape) +
        df * wcf.reshape(p_shape).broadcast(p_broadcast_shape);
    wci_grad.device(d) = (di * cs_prev).sum(Eigen::array<int, 1>({0}));
    wcf_grad.device(d) = (df * cs_prev).sum(Eigen::array<int, 1>({0}));
    wco_grad.device(d) = (do_ * cs).sum(Eigen::array<int, 1>({0}));
  }
}

#define DEFINE_CPU_SPECS(T)                                                   \
  template <>                                                                 \
  void LSTMBlockCellFprop<CPUDevice, T, false /* USE_CUBLAS */>::operator()(  \
      OpKernelContext* ctx, const CPUDevice& d, const float forget_bias,      \
      const float cell_clip, bool use_peephole,                               \
      typename TTypes<T>::ConstMatrix x,                                      \
      typename TTypes<T>::ConstMatrix cs_prev,                                \
      typename TTypes<T>::ConstMatrix h_prev,                                 \
      typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci,    \
      typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco,     \
      typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh,          \
      typename TTypes<T>::Matrix i, typename TTypes<T>::Matrix cs,            \
      typename TTypes<T>::Matrix f, typename TTypes<T>::Matrix o,             \
      typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co,           \
      typename TTypes<T>::Matrix icfo, typename TTypes<T>::Matrix h) {        \
    LSTMBlockCellFpropWithEigen<T>(                                           \
        *this, ctx, d, forget_bias, cell_clip, use_peephole, x, cs_prev,      \
        h_prev, w, wci, wcf, wco, b, xh, i, cs, f, o, ci, co, icfo, h);       \
  }                                                                           \
  template <>                                                                 \
  void LSTMBlockCellBprop<CPUDevice, T, false /* USE_CUBLAS */>::operator()(  \
      OpKernelContext* ctx, const CPUDevice& d, bool use_peephole,            \
      typename TTypes<T>::ConstMatrix x,                                      \
      typename TTypes<T>::ConstMatrix cs_prev,                                \
      typename TTypes<T>::ConstMatrix h_prev,                                 \
      typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci,    \
      typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco,     \
      typename TTypes<T>::ConstVec b, typename TTypes<T>::ConstMatrix i,      \
      typename TTypes<T>::ConstMatrix cs, typename TTypes<T>::ConstMatrix f,  \
      typename TTypes<T>::ConstMatrix o, typename TTypes<T>::ConstMatrix ci,  \
      typename TTypes<T>::ConstMatrix co,                                     \
      typename TTypes<T>::ConstMatrix cs_grad,                                \
      typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_, \
      typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,         \
      typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,           \
      typename TTypes<T>::Matrix dicfo,                                       \
      typename TTypes<T>::Matrix cs_prev_grad,                                \
      typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad,     \
      typename TTypes<T>::Vec wco_grad) {                                     \
    LSTMBlockCellBpropWithEigen<CPUDevice, T, false /* USE_CUBLAS */>(        \
        *this, ctx, d, use_peephole, x, cs_prev, h_prev, w, wci, wcf, wco, b, \
        i, cs, f, o, ci, co, cs_grad, h_grad, do_, dcs, dci, df, di, dicfo,   \
        cs_prev_grad, wci_grad, wcf_grad, wco_grad);                          \
  }                                                                           \
  template struct LSTMBlockCellFprop<CPUDevice, T, false /* USE_CUBLAS */>;   \
  template struct LSTMBlockCellBprop<CPUDevice, T, false /* USE_CUBLAS */>;

DEFINE_CPU_SPECS(float);
DEFINE_CPU_SPECS(Eigen::half);
#undef DEFINE_CPU_SPECS

}  // namespace functor

template <typename Device, typename T, bool USE_CUBLAS>
class LSTMBlockCellOp : public OpKernel {
 public:
  explicit LSTMBlockCellOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("forget_bias", &forget_bias_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("cell_clip", &cell_clip_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_peephole", &use_peephole_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor* x_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("x", &x_tensor));

    const Tensor* cs_prev_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("cs_prev", &cs_prev_tensor));

    const Tensor* h_prev_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("h_prev", &h_prev_tensor));

    const Tensor* w_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("w", &w_tensor));

    const Tensor* wci_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("wci", &wci_tensor));

    const Tensor* wcf_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("wcf", &wcf_tensor));

    const Tensor* wco_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("wco", &wco_tensor));

    const Tensor* b_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("b", &b_tensor));

    const int64 batch_size = x_tensor->dim_size(0);
    const int64 input_size = x_tensor->dim_size(1);
    const int64 cell_size = cs_prev_tensor->dim_size(1);

    // Sanity checks for our input shapes.
    OP_REQUIRES(ctx, cs_prev_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument("cs_prev.dims(0) != batch_size: ",
                                        cs_prev_tensor->dim_size(0), " vs. ",
                                        batch_size));
    OP_REQUIRES(ctx, cs_prev_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument("cs_prev.dims(1) != cell_size: ",
                                        cs_prev_tensor->dim_size(1), " vs. ",
                                        cell_size));

    OP_REQUIRES(ctx, h_prev_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument("h_prev.dims(0) != batch_size: ",
                                        h_prev_tensor->dim_size(0), " vs. ",
                                        batch_size));
    OP_REQUIRES(ctx, h_prev_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "h_prev.dims(1) != cell_size: ", h_prev_tensor->dim_size(1),
                    " vs. ", cell_size));

    OP_REQUIRES(ctx, w_tensor->dim_size(0) == input_size + cell_size,
                errors::InvalidArgument(
                    "w.dim_size(0) != input_size + cell_size: ",
                    w_tensor->dim_size(0), " vs. ", input_size + cell_size));
    OP_REQUIRES(ctx, w_tensor->dim_size(1) == cell_size * 4,
                errors::InvalidArgument(
                    "w.dim_size(1) != cell_size * 4: ", w_tensor->dim_size(1),
                    " vs. ", cell_size * 4));

    OP_REQUIRES(ctx, b_tensor->dim_size(0) == cell_size * 4,
                errors::InvalidArgument(
                    "b.dim_size(0) != cell_size * 4: ", b_tensor->dim_size(0),
                    " vs. ", cell_size * 4));

    // Allocate our output tensors.
    Tensor* i_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                            {"h_prev"}, "i",
                            TensorShape({batch_size, cell_size}), &i_tensor));

    Tensor* cs_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output("cs", TensorShape({batch_size, cell_size}),
                                  &cs_tensor));

    Tensor* f_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output("f", TensorShape({batch_size, cell_size}),
                                  &f_tensor));

    Tensor* o_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                            {"cs_prev"}, "o",
                            TensorShape({batch_size, cell_size}), &o_tensor));

    Tensor* ci_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output("ci", TensorShape({batch_size, cell_size}),
                                  &ci_tensor));

    Tensor* co_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output("co", TensorShape({batch_size, cell_size}),
                                  &co_tensor));

    Tensor* h_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output("h", TensorShape({batch_size, cell_size}),
                                  &h_tensor));

    // Allocate our temp tensors.
    Tensor xh_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(
                            DataTypeToEnum<T>::v(),
                            TensorShape({batch_size, input_size + cell_size}),
                            &xh_tensor));

    Tensor icfo_tensor;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                      TensorShape({batch_size, cell_size * 4}),
                                      &icfo_tensor));

    const Device& device = ctx->eigen_device<Device>();

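    // Run the fused cell forward pass; xh and icfo are scratch buffers for
    // the concatenated [x, h_prev] input and the pre-activation gate values.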
    functor::LSTMBlockCellFprop<Device, T, USE_CUBLAS>(batch_size, input_size,
                                                       cell_size)(
        ctx, device, forget_bias_, cell_clip_, use_peephole_,
        x_tensor->matrix<T>(), cs_prev_tensor->matrix<T>(),
        h_prev_tensor->matrix<T>(), w_tensor->matrix<T>(), wci_tensor->vec<T>(),
        wcf_tensor->vec<T>(), wco_tensor->vec<T>(), b_tensor->vec<T>(),
        xh_tensor.matrix<T>(), i_tensor->matrix<T>(), cs_tensor->matrix<T>(),
        f_tensor->matrix<T>(), o_tensor->matrix<T>(), ci_tensor->matrix<T>(),
        co_tensor->matrix<T>(), icfo_tensor.matrix<T>(), h_tensor->matrix<T>());
  }

 private:
  float forget_bias_;
  float cell_clip_;
  bool use_peephole_;
};

#define REGISTER_KERNEL(T)                                             \
  REGISTER_KERNEL_BUILDER(                                             \
      Name("LSTMBlockCell").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      LSTMBlockCellOp<CPUDevice, T, false>);
REGISTER_KERNEL(float);
REGISTER_KERNEL(Eigen::half);
#undef REGISTER_KERNEL

#if GOOGLE_CUDA
namespace functor {
#define DECLARE_GPU_SPEC(T)                                                \
  template <>                                                              \
  void LSTMBlockCellFprop<GPUDevice, T, true>::operator()(                 \
      OpKernelContext* ctx, const GPUDevice& d, const float forget_bias,   \
      const float cell_clip, bool use_peephole,                            \
      typename TTypes<T>::ConstMatrix x,                                   \
      typename TTypes<T>::ConstMatrix cs_prev,                             \
      typename TTypes<T>::ConstMatrix h_prev,                              \
      typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci, \
      typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco,  \
      typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh,       \
      typename TTypes<T>::Matrix i, typename TTypes<T>::Matrix cs,         \
      typename TTypes<T>::Matrix f, typename TTypes<T>::Matrix o,          \
      typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co,        \
      typename TTypes<T>::Matrix icfo, typename TTypes<T>::Matrix h);      \
                                                                           \
  extern template struct LSTMBlockCellFprop<GPUDevice, T, true>;

DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(Eigen::half);
#undef DECLARE_GPU_SPEC
}  // end namespace functor

#define REGISTER_GPU_KERNEL(T)                                         \
  REGISTER_KERNEL_BUILDER(                                             \
      Name("LSTMBlockCell").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      LSTMBlockCellOp<GPUDevice, T, true>);

REGISTER_GPU_KERNEL(float);
REGISTER_GPU_KERNEL(Eigen::half);
// REGISTER_GPU_KERNEL(double);
#undef REGISTER_GPU_KERNEL
#endif  // GOOGLE_CUDA

template <typename Device, typename T, bool USE_CUBLAS>
class LSTMBlockCellGradOp : public OpKernel {
 public:
  explicit LSTMBlockCellGradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_peephole", &use_peephole_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor* x_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("x", &x_tensor));

    const Tensor* cs_prev_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("cs_prev", &cs_prev_tensor));

    const Tensor* h_prev_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("h_prev", &h_prev_tensor));

    const Tensor* w_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("w", &w_tensor));

    const Tensor* wci_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("wci", &wci_tensor));

    const Tensor* wcf_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("wcf", &wcf_tensor));

    const Tensor* wco_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("wco", &wco_tensor));

    const Tensor* b_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("b", &b_tensor));

    const Tensor* i_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("i", &i_tensor));

    const Tensor* cs_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("cs", &cs_tensor));

    const Tensor* f_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("f", &f_tensor));

    const Tensor* o_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("o", &o_tensor));

    const Tensor* ci_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("ci", &ci_tensor));

    const Tensor* co_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("co", &co_tensor));

    const Tensor* cs_grad_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("cs_grad", &cs_grad_tensor));

    const Tensor* h_grad_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("h_grad", &h_grad_tensor));

    const int64 batch_size = x_tensor->dim_size(0);
    const int64 input_size = x_tensor->dim_size(1);
    const int64 cell_size = cs_prev_tensor->dim_size(1);

    // Sanity checks for our input shapes.
    OP_REQUIRES(ctx, cs_prev_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument("cs_prev.dims(0) != batch_size: ",
                                        cs_prev_tensor->dim_size(0), " vs. ",
                                        batch_size));
    OP_REQUIRES(ctx, cs_prev_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument("cs_prev.dims(1) != cell_size: ",
                                        cs_prev_tensor->dim_size(1), " vs. ",
                                        cell_size));

    OP_REQUIRES(ctx, h_prev_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument("h_prev.dims(0) != batch_size: ",
                                        h_prev_tensor->dim_size(0), " vs. ",
                                        batch_size));
    OP_REQUIRES(ctx, h_prev_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "h_prev.dims(1) != cell_size: ", h_prev_tensor->dim_size(1),
                    " vs. ", cell_size));

    OP_REQUIRES(ctx, w_tensor->dim_size(0) == input_size + cell_size,
                errors::InvalidArgument(
                    "w.dim_size(0) != input_size + cell_size: ",
                    w_tensor->dim_size(0), " vs. ", input_size + cell_size));
    OP_REQUIRES(ctx, w_tensor->dim_size(1) == cell_size * 4,
                errors::InvalidArgument(
                    "w.dim_size(1) != cell_size * 4: ", w_tensor->dim_size(1),
                    " vs. ", cell_size * 4));

    OP_REQUIRES(ctx, b_tensor->dim_size(0) == cell_size * 4,
                errors::InvalidArgument(
                    "b.dim_size(0) != cell_size * 4: ", b_tensor->dim_size(0),
                    " vs. ", cell_size * 4));

    OP_REQUIRES(ctx, i_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument(
                    "i.dim_size(0) != batch_size: ", i_tensor->dim_size(0),
                    " vs. ", batch_size));
    OP_REQUIRES(ctx, i_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "i.dim_size(1) != cell_size: ", i_tensor->dim_size(1),
                    " vs. ", cell_size));

    OP_REQUIRES(ctx, cs_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument(
                    "cs.dim_size(0) != batch_size: ", cs_tensor->dim_size(0),
                    " vs. ", batch_size));
    OP_REQUIRES(ctx, cs_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "cs.dim_size(1) != cell_size: ", cs_tensor->dim_size(1),
                    " vs. ", cell_size));

    OP_REQUIRES(ctx, f_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument(
                    "f.dim_size(0) != batch_size: ", f_tensor->dim_size(0),
                    " vs. ", batch_size));
    OP_REQUIRES(ctx, f_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "f.dim_size(1) != cell_size: ", f_tensor->dim_size(1),
                    " vs. ", cell_size));

    OP_REQUIRES(ctx, o_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument(
                    "o.dim_size(0) != batch_size: ", o_tensor->dim_size(0),
                    " vs. ", batch_size));
    OP_REQUIRES(ctx, o_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "o.dim_size(1) != cell_size: ", o_tensor->dim_size(1),
                    " vs. ", cell_size));

    OP_REQUIRES(ctx, ci_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument(
                    "ci.dim_size(0) != batch_size: ", ci_tensor->dim_size(0),
                    " vs. ", batch_size));
    OP_REQUIRES(ctx, ci_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "ci.dim_size(1) != cell_size: ", ci_tensor->dim_size(1),
                    " vs. ", cell_size));

    OP_REQUIRES(ctx, co_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument(
                    "co.dim_size(0) != batch_size: ", co_tensor->dim_size(0),
                    " vs. ", batch_size));
    OP_REQUIRES(ctx, co_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "co.dim_size(1) != cell_size: ", co_tensor->dim_size(1),
                    " vs. ", cell_size));

    OP_REQUIRES(ctx, cs_grad_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument(
                    "cs_grad_tensor.dims(0) != batch_size: ",
                    cs_grad_tensor->dim_size(0), " vs. ", batch_size));
    OP_REQUIRES(ctx, cs_grad_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument("cs_grad_tensor.dims(1) != cell_size: ",
                                        cs_grad_tensor->dim_size(1), " vs. ",
                                        cell_size));

    OP_REQUIRES(ctx, h_grad_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument("h_grad_tensor.dims(0) != batch_size: ",
                                        h_grad_tensor->dim_size(0), " vs. ",
                                        batch_size));
    OP_REQUIRES(ctx, h_grad_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument("h_grad_tensor.dims(1) != cell_size: ",
                                        h_grad_tensor->dim_size(1), " vs. ",
                                        cell_size));

    // Allocate our output tensors.
    Tensor* cs_prev_grad_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->forward_input_or_allocate_output(
                 {"cs_grad"}, "cs_prev_grad",
                 TensorShape({batch_size, cell_size}), &cs_prev_grad_tensor));

    Tensor* dicfo_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(
                            "dicfo", TensorShape({batch_size, cell_size * 4}),
                            &dicfo_tensor));

    Tensor* wci_grad_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->forward_input_or_allocate_output(
                 {"wci"}, "wci_grad", wci_tensor->shape(), &wci_grad_tensor));

    Tensor* wcf_grad_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->forward_input_or_allocate_output(
                 {"wcf"}, "wcf_grad", wcf_tensor->shape(), &wcf_grad_tensor));

    Tensor* wco_grad_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->forward_input_or_allocate_output(
                 {"wco"}, "wco_grad", wco_tensor->shape(), &wco_grad_tensor));

    // Allocate our temp tensors.
    Tensor do_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           TensorShape({batch_size, cell_size}),
                                           &do_tensor));

    Tensor dcs_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           TensorShape({batch_size, cell_size}),
                                           &dcs_tensor));

    Tensor dci_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           TensorShape({batch_size, cell_size}),
                                           &dci_tensor));

    Tensor df_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           TensorShape({batch_size, cell_size}),
                                           &df_tensor));

    Tensor di_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           TensorShape({batch_size, cell_size}),
                                           &di_tensor));

    const Device& device = ctx->eigen_device<Device>();

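    // Zero the peephole gradient outputs up front: LSTMBlockCellBprop writes
    // them only when use_peephole is true, but they must always be defined.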
    functor::TensorZero<Device, T>()(device, wci_grad_tensor->flat<T>());
    functor::TensorZero<Device, T>()(device, wcf_grad_tensor->flat<T>());
    functor::TensorZero<Device, T>()(device, wco_grad_tensor->flat<T>());

    functor::LSTMBlockCellBprop<Device, T, USE_CUBLAS>(batch_size, input_size,
                                                       cell_size)(
        ctx, device, use_peephole_, x_tensor->matrix<T>(),
        cs_prev_tensor->matrix<T>(), h_prev_tensor->matrix<T>(),
        w_tensor->matrix<T>(), wci_tensor->vec<T>(), wcf_tensor->vec<T>(),
        wco_tensor->vec<T>(), b_tensor->vec<T>(), i_tensor->matrix<T>(),
        cs_tensor->matrix<T>(), f_tensor->matrix<T>(), o_tensor->matrix<T>(),
        ci_tensor->matrix<T>(), co_tensor->matrix<T>(),
        cs_grad_tensor->matrix<T>(), h_grad_tensor->matrix<T>(),
        do_tensor.matrix<T>(), dcs_tensor.matrix<T>(), dci_tensor.matrix<T>(),
        df_tensor.matrix<T>(), di_tensor.matrix<T>(), dicfo_tensor->matrix<T>(),
        cs_prev_grad_tensor->matrix<T>(), wci_grad_tensor->vec<T>(),
        wcf_grad_tensor->vec<T>(), wco_grad_tensor->vec<T>());
  }

 protected:
  bool use_peephole_;
};

#define REGISTER_KERNEL(T)                                                 \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("LSTMBlockCellGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      LSTMBlockCellGradOp<CPUDevice, T, false>);
REGISTER_KERNEL(float);
REGISTER_KERNEL(Eigen::half);
#undef REGISTER_KERNEL

#if GOOGLE_CUDA
namespace functor {
#define DECLARE_GPU_SPEC(T)                                                   \
  template <>                                                                 \
  void LSTMBlockCellBprop<GPUDevice, T, true>::operator()(                    \
      OpKernelContext* ctx, const GPUDevice& d, bool use_peephole,            \
      typename TTypes<T>::ConstMatrix x,                                      \
      typename TTypes<T>::ConstMatrix cs_prev,                                \
      typename TTypes<T>::ConstMatrix h_prev,                                 \
      typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci,    \
      typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco,     \
      typename TTypes<T>::ConstVec b, typename TTypes<T>::ConstMatrix i,      \
      typename TTypes<T>::ConstMatrix cs, typename TTypes<T>::ConstMatrix f,  \
      typename TTypes<T>::ConstMatrix o, typename TTypes<T>::ConstMatrix ci,  \
      typename TTypes<T>::ConstMatrix co,                                     \
      typename TTypes<T>::ConstMatrix cs_grad,                                \
      typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_, \
      typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,         \
      typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,           \
      typename TTypes<T>::Matrix dicfo,                                       \
      typename TTypes<T>::Matrix cs_prev_grad,                                \
      typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad,     \
      typename TTypes<T>::Vec wco_grad);                                      \
                                                                              \
  extern template struct LSTMBlockCellBprop<GPUDevice, T,                     \
                                            true /* USE_CUBLAS */>;

DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(Eigen::half);
// DECLARE_GPU_SPEC(double);
#undef DECLARE_GPU_SPEC
}  // namespace functor

#define REGISTER_GPU_KERNEL(T)                                             \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("LSTMBlockCellGrad").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      LSTMBlockCellGradOp<GPUDevice, T, true>);

REGISTER_GPU_KERNEL(float);
REGISTER_GPU_KERNEL(Eigen::half);
// REGISTER_GPU_KERNEL(double);
#undef REGISTER_GPU_KERNEL
#endif  // GOOGLE_CUDA

namespace {

// This helper class provides access to timeslices of a 3D tensor. If a slice
// happens to be unaligned (usually because both the batch size and the number
// of cells are odd - this is uncommon), accessing it involves copying the
// data into an aligned temporary; if all slices are aligned, nothing is
// copied. When an output slice had to be copied, the result must be copied
// back into the original tensor. Call FinishTimeStep at the end of each time
// step to do this and to allow the temporary tensors to be reused.
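// A rough usage sketch (names are illustrative only):
//   SliceHelper<Device, T> slicer(ctx);
//   for (int64 t = 0; t < seq_len_max; ++t) {
//     Tensor step = slicer.OutputSlice(out_tensor, t, "out");
//     // ... write into step ...
//     slicer.FinishTimeStep();  // copies unaligned slices back to out_tensor
//   }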
template <typename Device, typename T>
class SliceHelper {
 public:
  explicit SliceHelper(OpKernelContext* ctx)
      : ctx_(ctx), device_(ctx_->eigen_device<Device>()) {}

  ~SliceHelper() {
    CHECK(copy_out_.empty());
    for (const auto& entry : pool_) {
      CHECK(!entry.second.second);  // nothing is in use
    }
  }

  // Slice through an input tensor. This may copy unaligned slices, but no
  // copying back will be done at the end.
  const Tensor InputSlice(const Tensor& t, int pos, const string& name) {
    Tensor res = UnalignedSlice(t, pos);
    if (res.IsAligned()) {
      return res;
    } else {
      return AlignTensor(res, name);
    }
  }

  // Slice through an output tensor. This may copy unaligned slices; any copy
  // is scheduled to be written back when FinishTimeStep() is called.
  Tensor OutputSlice(Tensor* t, int pos, const string& name) {
    Tensor res = UnalignedSlice(*t, pos);
    if (res.IsAligned()) {
      return res;
    } else {
      Tensor aligned = AlignTensor(res, name);
      copy_out_.emplace_back(res, aligned);
      return aligned;
    }
  }

  void FinishTimeStep() {
    for (const auto& p : copy_out_) {
      const Tensor& aligned = p.second;
      Tensor original = p.first;
      // Copy from aligned back to original.
      functor::TensorCopyToUnaligned<Device, T>()(device_, aligned.flat<T>(),
                                                  original.unaligned_flat<T>());
    }
    copy_out_.clear();
    // Mark all entries as not in use.
    for (auto& entry : pool_) {
      entry.second.second = false;
    }
  }

 private:
  // Return a slice at position 'pos'. Result may be unaligned. The resulting
  // tensor always shares data with the source tensor.
  Tensor UnalignedSlice(const Tensor& t, int pos) const {
    Tensor res;
    // CHECK should never fail here, since the number of elements must match
    CHECK(res.CopyFrom(t.Slice(pos, pos + 1), {t.dim_size(1), t.dim_size(2)}));
    return res;
  }

  // Assumes input is not aligned, creates a temporary aligned tensor of the
  // same shape and copies the original tensor's content into it.
  Tensor AlignTensor(const Tensor& t, const string& name) {
    VLOG(1) << "AlignTensor called for " << name << ", shape "
            << t.shape().DebugString()
            << ". This is unnecessary copying. Consider using shapes with even "
            << "sizes";
    Tensor aligned;
    auto found = pool_.find(name);
    if (found != pool_.end()) {  // found in pool
      CHECK(!found->second.second) << "Tensor " << name << " is in use";
      found->second.second = true;  // mark in use
      aligned = found->second.first;
      CHECK(aligned.shape().IsSameSize(t.shape()));
      CHECK_EQ(aligned.dtype(), t.dtype());
    } else {  // allocate a new temporary tensor
      TF_CHECK_OK(ctx_->allocate_temp(t.dtype(), t.shape(), &aligned));
      pool_.emplace(name, std::make_pair(aligned, true));
    }
    functor::TensorCopyUnaligned<Device, T>()(device_, t.unaligned_flat<T>(),
                                              aligned.flat<T>());
    return aligned;
  }

  // Tensors to be copied.
  std::vector<std::pair<Tensor, const Tensor>> copy_out_;
  // A pool of pre-allocated temporary tensors, with an indicator for whether
  // it's in use.
  std::map<string, std::pair<Tensor, bool>> pool_;
  // Op context
  OpKernelContext* ctx_ = nullptr;
  // Device
  const Device& device_;
};

}  // namespace

template <typename Device, typename T, bool USE_CUBLAS>
class BlockLSTMOp : public OpKernel {
 public:
  explicit BlockLSTMOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("forget_bias", &forget_bias_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("cell_clip", &cell_clip_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_peephole", &use_peephole_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor* seq_len_max_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("seq_len_max", &seq_len_max_tensor));

    const Tensor* x;
    OP_REQUIRES_OK(ctx, ctx->input("x", &x));
    OP_REQUIRES(ctx, x->dims() == 3, errors::InvalidArgument("x must be 3D"));
    const int64 timelen = x->dim_size(0);
    const int64 batch_size = x->dim_size(1);
    const int64 input_size = x->dim_size(2);

    const Tensor* cs_prev_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("cs_prev", &cs_prev_tensor));
    OP_REQUIRES(ctx, cs_prev_tensor->dims() == 2,
                errors::InvalidArgument("cs_prev must be 2D"));
    OP_REQUIRES(ctx, cs_prev_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument("cs_prev.dims(0) != batch_size: ",
                                        cs_prev_tensor->dim_size(0), " vs. ",
                                        batch_size));
    const int64 cell_size = cs_prev_tensor->dim_size(1);

    if (batch_size * input_size % 2 == 1) {
      LOG(WARNING) << "BlockLSTMOp is inefficient when both batch_size and "
                   << "input_size are odd. You are using: batch_size="
                   << batch_size << ", input_size=" << input_size;
    }
    if (batch_size * cell_size % 2 == 1) {
      LOG(WARNING) << "BlockLSTMOp is inefficient when both batch_size and "
                   << "cell_size are odd. You are using: batch_size="
                   << batch_size << ", cell_size=" << cell_size;
    }

    const Tensor* h_prev_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("h_prev", &h_prev_tensor));
    OP_REQUIRES(ctx, h_prev_tensor->dims() == 2,
                errors::InvalidArgument("h_prev must be 2D"));
    OP_REQUIRES(ctx, h_prev_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument("h_prev.dims(0) != batch_size: ",
                                        h_prev_tensor->dim_size(0), " vs. ",
                                        batch_size));
    OP_REQUIRES(ctx, h_prev_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "h_prev.dims(1) != cell_size: ", h_prev_tensor->dim_size(1),
                    " vs. ", cell_size));

    const Tensor* w_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("w", &w_tensor));
    OP_REQUIRES(ctx, w_tensor->dims() == 2,
                errors::InvalidArgument("w must be 2D"));
    OP_REQUIRES(ctx, w_tensor->dim_size(0) == input_size + cell_size,
                errors::InvalidArgument(
                    "w.dim_size(0) != input_size + cell_size: ",
                    w_tensor->dim_size(0), " vs. ", input_size + cell_size));
    OP_REQUIRES(ctx, w_tensor->dim_size(1) == cell_size * 4,
                errors::InvalidArgument(
                    "w.dim_size(1) != cell_size * 4: ", w_tensor->dim_size(1),
                    " vs. ", cell_size * 4));

    const Tensor* wci_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("wci", &wci_tensor));
    OP_REQUIRES(ctx, wci_tensor->dims() == 1,
                errors::InvalidArgument("wci must be 1D"));
    OP_REQUIRES(ctx, wci_tensor->dim_size(0) == cell_size,
                errors::InvalidArgument(
                    "wci.dim_size(0) != cell_size: ", wci_tensor->dim_size(0),
                    " vs. ", cell_size));

    const Tensor* wcf_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("wcf", &wcf_tensor));
    OP_REQUIRES(ctx, wcf_tensor->dims() == 1,
                errors::InvalidArgument("wcf must be 1D"));
    OP_REQUIRES(ctx, wcf_tensor->dim_size(0) == cell_size,
                errors::InvalidArgument(
                    "wcf.dim_size(0) != cell_size: ", wcf_tensor->dim_size(0),
                    " vs. ", cell_size));

    const Tensor* wco_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("wco", &wco_tensor));
    OP_REQUIRES(ctx, wco_tensor->dims() == 1,
                errors::InvalidArgument("wco must be 1D"));
    OP_REQUIRES(ctx, wco_tensor->dim_size(0) == cell_size,
                errors::InvalidArgument(
                    "wco.dim_size(0) != cell_size: ", wco_tensor->dim_size(0),
                    " vs. ", cell_size));

    const Tensor* b_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("b", &b_tensor));
    OP_REQUIRES(ctx, b_tensor->dims() == 1,
                errors::InvalidArgument("b must be 1D"));
    OP_REQUIRES(ctx, b_tensor->dim_size(0) == cell_size * 4,
                errors::InvalidArgument(
                    "b.dim_size(0) != cell_size * 4: ", b_tensor->dim_size(0),
                    " vs. ", cell_size * 4));

    TensorShape batch_cell_shape({timelen, batch_size, cell_size});
    Tensor* i_out;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("i", batch_cell_shape, &i_out));

    Tensor* cs_out;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("cs", batch_cell_shape, &cs_out));

    Tensor* f_out;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("f", batch_cell_shape, &f_out));

    Tensor* o_out;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("o", batch_cell_shape, &o_out));

    Tensor* ci_out;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("ci", batch_cell_shape, &ci_out));

    Tensor* co_out;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("co", batch_cell_shape, &co_out));

    Tensor* h_out;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("h", batch_cell_shape, &h_out));

    Tensor xh_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(
                            DataTypeToEnum<T>::v(),
                            TensorShape({batch_size, input_size + cell_size}),
                            &xh_tensor));

    Tensor icfo_tensor;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                      TensorShape({batch_size, cell_size * 4}),
                                      &icfo_tensor));

    const Device& device = ctx->eigen_device<Device>();

    const int64 seq_len_max = seq_len_max_tensor->scalar<int64>()();
    SliceHelper<Device, T> slicer(ctx);
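    // Unroll the LSTM over time: step t consumes the cs and h slices written
    // at t - 1 (or the initial cs_prev / h_prev when t == 0) and fills slice
    // t of every per-step output.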
    for (int64 t = 0; t < seq_len_max; ++t) {
      const Tensor x_tensor = slicer.InputSlice(*x, t, "x");
      const Tensor& cs_prev_tensor2 =
          t == 0 ? *cs_prev_tensor
                 : slicer.OutputSlice(cs_out, t - 1, "cs_prev");
      const Tensor& h_prev_tensor2 =
          t == 0 ? *h_prev_tensor : slicer.OutputSlice(h_out, t - 1, "h_prev");

      Tensor i_tensor = slicer.OutputSlice(i_out, t, "i_out");
      Tensor cs_tensor = slicer.OutputSlice(cs_out, t, "cs_out");
      Tensor f_tensor = slicer.OutputSlice(f_out, t, "f_out");
      Tensor o_tensor = slicer.OutputSlice(o_out, t, "o_out");
      Tensor ci_tensor = slicer.OutputSlice(ci_out, t, "ci_out");
      Tensor co_tensor = slicer.OutputSlice(co_out, t, "co_out");
      Tensor h_tensor = slicer.OutputSlice(h_out, t, "h_out");

      functor::LSTMBlockCellFprop<Device, T, USE_CUBLAS>(batch_size, input_size,
                                                         cell_size)(
          ctx, device, forget_bias_, cell_clip_, use_peephole_,
          x_tensor.matrix<T>(), cs_prev_tensor2.matrix<T>(),
          h_prev_tensor2.matrix<T>(), w_tensor->matrix<T>(),
          wci_tensor->vec<T>(), wcf_tensor->vec<T>(), wco_tensor->vec<T>(),
          b_tensor->vec<T>(), xh_tensor.matrix<T>(), i_tensor.matrix<T>(),
          cs_tensor.matrix<T>(), f_tensor.matrix<T>(), o_tensor.matrix<T>(),
          ci_tensor.matrix<T>(), co_tensor.matrix<T>(), icfo_tensor.matrix<T>(),
          h_tensor.matrix<T>());
      slicer.FinishTimeStep();
    }

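    // Steps at or beyond seq_len_max were never computed; zero-fill their cs
    // and h slices so the outputs are fully defined.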
    990     if (seq_len_max < timelen) {
    991       Tensor cs_tensor = cs_out->Slice(seq_len_max, timelen);
    992       Tensor h_tensor = h_out->Slice(seq_len_max, timelen);
    993 
    994       functor::TensorUnalignedZero<Device, T>()(device,
    995                                                 cs_tensor.unaligned_flat<T>());
    996       functor::TensorUnalignedZero<Device, T>()(device,
    997                                                 h_tensor.unaligned_flat<T>());
    998     }
    999   }
   1000 
   1001  private:
   1002   float forget_bias_;
   1003   float cell_clip_;
   1004   bool use_peephole_;
   1005 };
   1006 
   1007 #define REGISTER_KERNEL(T)                                         \
   1008   REGISTER_KERNEL_BUILDER(                                         \
   1009       Name("BlockLSTM").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
   1010       BlockLSTMOp<CPUDevice, T, false>);
   1011 REGISTER_KERNEL(float);
   1012 REGISTER_KERNEL(Eigen::half);
   1013 #undef REGISTER_KERNEL
   1014 
   1015 #if GOOGLE_CUDA
   1016 namespace functor {
   1017 #define DECLARE_GPU_SPEC(T)                                              \
   1018   template <>                                                            \
   1019   void TensorZero<GPUDevice, T>::operator()(const GPUDevice& d,          \
   1020                                             typename TTypes<T>::Flat t); \
   1021                                                                          \
   1022   extern template struct TensorZero<GPUDevice, T>;                       \
   1023                                                                          \
   1024   template <>                                                            \
   1025   void TensorUnalignedZero<GPUDevice, T>::operator()(                    \
   1026       const GPUDevice& d, typename TTypes<T>::UnalignedFlat t);          \
   1027                                                                          \
   1028   extern template struct TensorUnalignedZero<GPUDevice, T>;
   1029 
   1030 DECLARE_GPU_SPEC(float);
   1031 DECLARE_GPU_SPEC(Eigen::half);
   1032 // DECLARE_GPU_SPEC(double);
   1033 #undef DECLARE_GPU_SPEC
   1034 }  // end namespace functor
   1035 
   1036 #define REGISTER_GPU_KERNEL(T)                           \
   1037   REGISTER_KERNEL_BUILDER(Name("BlockLSTM")              \
   1038                               .Device(DEVICE_GPU)        \
   1039                               .HostMemory("seq_len_max") \
   1040                               .TypeConstraint<T>("T"),   \
   1041                           BlockLSTMOp<GPUDevice, T, true>);
   1042 
   1043 REGISTER_GPU_KERNEL(float);
   1044 REGISTER_GPU_KERNEL(Eigen::half);
   1045 // REGISTER_GPU_KERNEL(double);
   1046 #undef REGISTER_GPU_KERNEL
   1047 #endif  // GOOGLE_CUDA
   1048 
   1049 template <typename Device, typename T, bool USE_CUBLAS>
   1050 class BlockLSTMGradOp : public OpKernel {
   1051  public:
   1052   explicit BlockLSTMGradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
   1053     OP_REQUIRES_OK(ctx, ctx->GetAttr("use_peephole", &use_peephole_));
   1054   }
   1055 
   1056   void Compute(OpKernelContext* ctx) override {
   1057     const Tensor* seq_len_max_tensor = nullptr;
   1058     OP_REQUIRES_OK(ctx, ctx->input("seq_len_max", &seq_len_max_tensor));
   1059 
   1060     const Tensor* x;
   1061     OP_REQUIRES_OK(ctx, ctx->input("x", &x));
   1062     OP_REQUIRES(ctx, x->dims() == 3, errors::InvalidArgument("x must be 3D"));
   1063     const int64 timelen = x->dim_size(0);
   1064     const int64 batch_size = x->dim_size(1);
   1065     const int64 input_size = x->dim_size(2);
   1066 
   1067     const Tensor* cs_prev_tensor = nullptr;
   1068     OP_REQUIRES_OK(ctx, ctx->input("cs_prev", &cs_prev_tensor));
   1069 
   1070     const Tensor* h_prev_tensor = nullptr;
   1071     OP_REQUIRES_OK(ctx, ctx->input("h_prev", &h_prev_tensor));
   1072 
   1073     const Tensor* w_tensor = nullptr;
   1074     OP_REQUIRES_OK(ctx, ctx->input("w", &w_tensor));
   1075     const int64 cell_size = w_tensor->dim_size(1) / 4;
   1076     OP_REQUIRES(ctx, input_size + cell_size == w_tensor->dim_size(0),
   1077                 errors::InvalidArgument(
   1078                     "w matrix rows don't match: ", input_size + cell_size,
   1079                     " vs. ", w_tensor->dim_size(0)));
   1080 
   1081     const Tensor* wci_tensor = nullptr;
   1082     OP_REQUIRES_OK(ctx, ctx->input("wci", &wci_tensor));
   1083 
   1084     const Tensor* wcf_tensor = nullptr;
   1085     OP_REQUIRES_OK(ctx, ctx->input("wcf", &wcf_tensor));
   1086 
   1087     const Tensor* wco_tensor = nullptr;
   1088     OP_REQUIRES_OK(ctx, ctx->input("wco", &wco_tensor));
   1089 
   1090     const Tensor* b_tensor = nullptr;
   1091     OP_REQUIRES_OK(ctx, ctx->input("b", &b_tensor));
   1092     OP_REQUIRES(
   1093         ctx, cell_size == b_tensor->dim_size(0) / 4,
   1094         errors::InvalidArgument("w and b cell_size don't match: ", cell_size,
    1095                                 " vs. ", b_tensor->dim_size(0) / 4));
   1096 
   1097     const Tensor* i_out = nullptr;
   1098     OP_REQUIRES_OK(ctx, ctx->input("i", &i_out));
   1099 
   1100     const Tensor* cs_out = nullptr;
   1101     OP_REQUIRES_OK(ctx, ctx->input("cs", &cs_out));
   1102 
   1103     const Tensor* f_out = nullptr;
   1104     OP_REQUIRES_OK(ctx, ctx->input("f", &f_out));
   1105 
   1106     const Tensor* o_out = nullptr;
   1107     OP_REQUIRES_OK(ctx, ctx->input("o", &o_out));
   1108 
   1109     const Tensor* ci_out = nullptr;
   1110     OP_REQUIRES_OK(ctx, ctx->input("ci", &ci_out));
   1111 
   1112     const Tensor* co_out = nullptr;
   1113     OP_REQUIRES_OK(ctx, ctx->input("co", &co_out));
   1114 
   1115     const Tensor* h_out = nullptr;
   1116     OP_REQUIRES_OK(ctx, ctx->input("h", &h_out));
   1117 
   1118     const Tensor* cs_grad = nullptr;
   1119     OP_REQUIRES_OK(ctx, ctx->input("cs_grad", &cs_grad));
   1120 
   1121     const Tensor* h_grad = nullptr;
   1122     OP_REQUIRES_OK(ctx, ctx->input("h_grad", &h_grad));
   1123 
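             // Allocate gradient outputs with the same shapes as the corresponding
             // forward inputs.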
   1124     TensorShape batch_input_shape({timelen, batch_size, input_size});
    1125     Tensor* x_grad = nullptr;
   1126     OP_REQUIRES_OK(ctx,
   1127                    ctx->allocate_output("x_grad", batch_input_shape, &x_grad));
   1128 
   1129     Tensor* cs_prev_grad_tensor = nullptr;
   1130     OP_REQUIRES_OK(ctx,
   1131                    ctx->allocate_output("cs_prev_grad", cs_prev_tensor->shape(),
   1132                                         &cs_prev_grad_tensor));
   1133 
   1134     Tensor* h_prev_grad_tensor = nullptr;
   1135     OP_REQUIRES_OK(ctx,
   1136                    ctx->allocate_output("h_prev_grad", h_prev_tensor->shape(),
   1137                                         &h_prev_grad_tensor));
   1138 
   1139     Tensor* w_grad_tensor = nullptr;
   1140     OP_REQUIRES_OK(
   1141         ctx, ctx->allocate_output("w_grad", w_tensor->shape(), &w_grad_tensor));
   1142 
   1143     Tensor* wci_grad_tensor = nullptr;
   1144     OP_REQUIRES_OK(ctx, ctx->allocate_output("wci_grad", wci_tensor->shape(),
   1145                                              &wci_grad_tensor));
   1146 
   1147     Tensor* wcf_grad_tensor = nullptr;
   1148     OP_REQUIRES_OK(ctx, ctx->allocate_output("wcf_grad", wcf_tensor->shape(),
   1149                                              &wcf_grad_tensor));
   1150 
   1151     Tensor* wco_grad_tensor = nullptr;
   1152     OP_REQUIRES_OK(ctx, ctx->allocate_output("wco_grad", wco_tensor->shape(),
   1153                                              &wco_grad_tensor));
   1154 
   1155     Tensor* b_grad_tensor = nullptr;
   1156     OP_REQUIRES_OK(
   1157         ctx, ctx->allocate_output("b_grad", b_tensor->shape(), &b_grad_tensor));
   1158 
   1159     TensorShape batch_cell_shape({batch_size, cell_size});
   1160 
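             // Scratch tensors reused on every time step: the concatenated [x, h]
             // input, the per-gate deltas, and the accumulated cs and h gradients.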
   1161     Tensor xh_tensor;
   1162     OP_REQUIRES_OK(ctx, ctx->allocate_temp(
   1163                             DataTypeToEnum<T>::v(),
   1164                             TensorShape({batch_size, input_size + cell_size}),
   1165                             &xh_tensor));
   1166 
   1167     Tensor xh_grad_tensor;
   1168     OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
   1169                                            xh_tensor.shape(), &xh_grad_tensor));
   1170 
   1171     Tensor do_tensor;
   1172     OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
   1173                                            batch_cell_shape, &do_tensor));
   1174 
   1175     Tensor dcs_tensor;
   1176     OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
   1177                                            batch_cell_shape, &dcs_tensor));
   1178 
   1179     Tensor dci_tensor;
   1180     OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
   1181                                            batch_cell_shape, &dci_tensor));
   1182 
   1183     Tensor df_tensor;
   1184     OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
   1185                                            batch_cell_shape, &df_tensor));
   1186 
   1187     Tensor di_tensor;
   1188     OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
   1189                                            batch_cell_shape, &di_tensor));
   1190 
   1191     Tensor dicfo_tensor;
   1192     OP_REQUIRES_OK(ctx,
   1193                    ctx->allocate_temp(DataTypeToEnum<T>::v(),
   1194                                       TensorShape({batch_size, cell_size * 4}),
   1195                                       &dicfo_tensor));
   1196 
   1197     Tensor cs_grad_tensor;
   1198     OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
   1199                                            batch_cell_shape, &cs_grad_tensor));
   1200 
   1201     Tensor h_grad_tensor;
   1202     OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
   1203                                            batch_cell_shape, &h_grad_tensor));
   1204 
   1205     const Device& device = ctx->eigen_device<Device>();
   1206 
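             // Zero the accumulators: the weight, peephole, and bias gradients are
             // summed over all time steps by the per-step backprop below.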
   1207     functor::TensorZero<Device, T>()(device, cs_grad_tensor.flat<T>());
   1208     functor::TensorZero<Device, T>()(device, cs_prev_grad_tensor->flat<T>());
   1209     functor::TensorZero<Device, T>()(device, h_grad_tensor.flat<T>());
   1210     functor::TensorZero<Device, T>()(device, h_prev_grad_tensor->flat<T>());
   1211     functor::TensorZero<Device, T>()(device, w_grad_tensor->flat<T>());
   1212     functor::TensorZero<Device, T>()(device, wci_grad_tensor->flat<T>());
   1213     functor::TensorZero<Device, T>()(device, wcf_grad_tensor->flat<T>());
   1214     functor::TensorZero<Device, T>()(device, wco_grad_tensor->flat<T>());
   1215     functor::TensorZero<Device, T>()(device, b_grad_tensor->flat<T>());
   1216 
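             // Backpropagation through time: walk from the last valid step down to
             // t = 0, slicing out the activations saved by the forward pass.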
   1217     const int64 seq_len_max = seq_len_max_tensor->scalar<int64>()();
   1218     SliceHelper<Device, T> slicer(ctx);
   1219     for (int64 t = seq_len_max - 1; t >= 0; --t) {
   1220       const Tensor& x_tensor = slicer.InputSlice(*x, t, "x");
   1221       const Tensor& cs_prev_tensor2 =
   1222           t == 0 ? *cs_prev_tensor
   1223                  : slicer.InputSlice(*cs_out, t - 1, "cs_prev");
   1224       const Tensor& h_prev_tensor2 =
   1225           t == 0 ? *h_prev_tensor : slicer.InputSlice(*h_out, t - 1, "h_prev");
   1226       const Tensor& i_tensor = slicer.InputSlice(*i_out, t, "i_out");
   1227       const Tensor& cs_tensor = slicer.InputSlice(*cs_out, t, "cs_out");
   1228       const Tensor& f_tensor = slicer.InputSlice(*f_out, t, "f_out");
   1229       const Tensor& o_tensor = slicer.InputSlice(*o_out, t, "o_out");
   1230       const Tensor& ci_tensor = slicer.InputSlice(*ci_out, t, "ci_out");
   1231       const Tensor& co_tensor = slicer.InputSlice(*co_out, t, "co_out");
   1232 
    1233       // Combine this step's cs_grad slice with the gradient from step t + 1.
   1234       const Tensor& const_cs_prev_grad_tensor = *cs_prev_grad_tensor;
   1235       const Tensor const_cs_grad_slice =
   1236           slicer.InputSlice(*cs_grad, t, "cs_grad");
   1237       functor::TensorAdd<Device, T>()(
   1238           device, const_cs_prev_grad_tensor.flat<T>(),
   1239           const_cs_grad_slice.flat<T>(), cs_grad_tensor.flat<T>());
   1240 
    1241       // Combine this step's h_grad slice with the gradient from step t + 1.
   1242       const Tensor& const_h_prev_grad_tensor = *h_prev_grad_tensor;
   1243       const Tensor const_h_grad_slice = slicer.InputSlice(*h_grad, t, "h_grad");
   1244       functor::TensorAdd<Device, T>()(
   1245           device, const_h_prev_grad_tensor.flat<T>(),
   1246           const_h_grad_slice.flat<T>(), h_grad_tensor.flat<T>());
   1247 
   1248       const Tensor& const_cs_grad_tensor = cs_grad_tensor;
   1249       const Tensor& const_h_grad_tensor = h_grad_tensor;
   1250 
   1251       Tensor x_grad_tensor = slicer.OutputSlice(x_grad, t, "x_grad");
   1252       functor::BlockLSTMBprop<Device, T, USE_CUBLAS>(batch_size, input_size,
   1253                                                      cell_size)(
   1254           ctx, device, use_peephole_, x_tensor.matrix<T>(),
   1255           cs_prev_tensor2.matrix<T>(), h_prev_tensor2.matrix<T>(),
   1256           w_tensor->matrix<T>(), wci_tensor->vec<T>(), wcf_tensor->vec<T>(),
   1257           wco_tensor->vec<T>(), b_tensor->vec<T>(), xh_tensor.matrix<T>(),
   1258           i_tensor.matrix<T>(), cs_tensor.matrix<T>(), f_tensor.matrix<T>(),
   1259           o_tensor.matrix<T>(), ci_tensor.matrix<T>(), co_tensor.matrix<T>(),
   1260           const_cs_grad_tensor.matrix<T>(), const_h_grad_tensor.matrix<T>(),
   1261           do_tensor.matrix<T>(), dcs_tensor.matrix<T>(), dci_tensor.matrix<T>(),
   1262           df_tensor.matrix<T>(), di_tensor.matrix<T>(),
   1263           dicfo_tensor.matrix<T>(), cs_prev_grad_tensor->matrix<T>(),
   1264           h_prev_grad_tensor->matrix<T>(), xh_grad_tensor.matrix<T>(),
   1265           x_grad_tensor.matrix<T>(), w_grad_tensor->matrix<T>(),
   1266           wci_grad_tensor->vec<T>(), wcf_grad_tensor->vec<T>(),
   1267           wco_grad_tensor->vec<T>(), b_grad_tensor->vec<T>());
   1268       slicer.FinishTimeStep();
   1269     }
   1270 
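             // Steps past seq_len_max never ran in the forward pass, so their x
             // gradients are zero.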
   1271     if (seq_len_max < timelen) {
   1272       Tensor x_grad_tensor = x_grad->Slice(seq_len_max, timelen);
   1273       functor::TensorUnalignedZero<Device, T>()(
   1274           device, x_grad_tensor.unaligned_flat<T>());
   1275     }
   1276   }
   1277 
   1278  private:
   1279   bool use_peephole_;
   1280 };
   1281 
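         // Register the BlockLSTMGrad op on the CPU for float and Eigen::half.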
   1282 #define REGISTER_KERNEL(T)                                             \
   1283   REGISTER_KERNEL_BUILDER(                                             \
   1284       Name("BlockLSTMGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
   1285       BlockLSTMGradOp<CPUDevice, T, false>);
   1286 REGISTER_KERNEL(float);
   1287 REGISTER_KERNEL(Eigen::half);
   1288 #undef REGISTER_KERNEL
   1289 
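         // On the GPU build, declare the copy, add, and backprop functor
         // specializations used by BlockLSTMGradOp; they are instantiated in the
         // CUDA-compiled translation unit.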
   1290 #if GOOGLE_CUDA
   1291 namespace functor {
   1292 #define DECLARE_GPU_SPEC(T)                                                    \
   1293   template <>                                                                  \
   1294   void TensorCopy<GPUDevice, T>::operator()(const GPUDevice& d,                \
   1295                                             typename TTypes<T>::ConstFlat src, \
   1296                                             typename TTypes<T>::Flat dst);     \
   1297                                                                                \
   1298   template <>                                                                  \
   1299   void TensorCopyUnaligned<GPUDevice, T>::operator()(                          \
   1300       const GPUDevice& d, typename TTypes<T>::UnalignedConstFlat src,          \
   1301       typename TTypes<T>::Flat dst);                                           \
   1302                                                                                \
   1303   template <>                                                                  \
   1304   void TensorCopyToUnaligned<GPUDevice, T>::operator()(                        \
   1305       const GPUDevice& d, typename TTypes<T>::ConstFlat src,                   \
   1306       typename TTypes<T>::UnalignedFlat dst);                                  \
   1307                                                                                \
   1308   template <>                                                                  \
   1309   void TensorAdd<GPUDevice, T>::operator()(                                    \
   1310       const GPUDevice& d, typename TTypes<T>::ConstFlat a,                     \
   1311       typename TTypes<T>::ConstFlat b, typename TTypes<T>::Flat c);            \
   1312                                                                                \
   1313   template <>                                                                  \
   1314   void BlockLSTMBprop<GPUDevice, T, true>::operator()(                         \
   1315       OpKernelContext* ctx, const GPUDevice& d, bool use_peephole,             \
   1316       typename TTypes<T>::ConstMatrix x,                                       \
   1317       typename TTypes<T>::ConstMatrix cs_prev,                                 \
   1318       typename TTypes<T>::ConstMatrix h_prev,                                  \
   1319       typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci,     \
   1320       typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco,      \
   1321       typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh,           \
   1322       typename TTypes<T>::ConstMatrix i, typename TTypes<T>::ConstMatrix cs,   \
   1323       typename TTypes<T>::ConstMatrix f, typename TTypes<T>::ConstMatrix o,    \
   1324       typename TTypes<T>::ConstMatrix ci, typename TTypes<T>::ConstMatrix co,  \
   1325       typename TTypes<T>::ConstMatrix cs_grad,                                 \
   1326       typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_,  \
   1327       typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,          \
   1328       typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,            \
   1329       typename TTypes<T>::Matrix dicfo,                                        \
   1330       typename TTypes<T>::Matrix cs_prev_grad,                                 \
   1331       typename TTypes<T>::Matrix h_prev_grad,                                  \
   1332       typename TTypes<T>::Matrix xh_grad, typename TTypes<T>::Matrix x_grad,   \
   1333       typename TTypes<T>::Matrix w_grad, typename TTypes<T>::Vec wci_grad,     \
   1334       typename TTypes<T>::Vec wcf_grad, typename TTypes<T>::Vec wco_grad,      \
   1335       typename TTypes<T>::Vec b_grad);                                         \
   1336                                                                                \
   1337   extern template struct TensorCopy<GPUDevice, T>;                             \
   1338   extern template struct TensorAdd<GPUDevice, T>;                              \
   1339   extern template struct BlockLSTMBprop<GPUDevice, T, true>;
   1340 
   1341 DECLARE_GPU_SPEC(float);
   1342 DECLARE_GPU_SPEC(Eigen::half);
   1343 // DECLARE_GPU_SPEC(double);
   1344 #undef DECLARE_GPU_SPEC
   1345 }  // end namespace functor
   1346 
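         // GPU registrations of BlockLSTMGrad, again pinning seq_len_max to host
         // memory.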
   1347 #define REGISTER_GPU_KERNEL(T)                           \
   1348   REGISTER_KERNEL_BUILDER(Name("BlockLSTMGrad")          \
   1349                               .Device(DEVICE_GPU)        \
   1350                               .HostMemory("seq_len_max") \
   1351                               .TypeConstraint<T>("T"),   \
   1352                           BlockLSTMGradOp<GPUDevice, T, true>);
   1353 
   1354 REGISTER_GPU_KERNEL(float);
   1355 REGISTER_GPU_KERNEL(Eigen::half);
   1356 // REGISTER_GPU_KERNEL(double);
   1357 #undef REGISTER_GPU_KERNEL
   1358 #endif  // GOOGLE_CUDA
   1359 
   1360 }  // end namespace tensorflow
   1361