Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      7     http://www.apache.org/licenses/LICENSE-2.0
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     16 #define EIGEN_USE_THREADS
     18 #if GOOGLE_CUDA
     19 #define EIGEN_USE_GPU
     20 #endif  // GOOGLE_CUDA
     22 #include "tensorflow/contrib/rnn/kernels/lstm_ops.h"
     24 #include <memory>
     25 #include <vector>
     27 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     28 #include "tensorflow/core/framework/op_kernel.h"
     29 #include "tensorflow/core/framework/register_types.h"
     30 #include "tensorflow/core/framework/tensor.h"
     31 #include "tensorflow/core/framework/tensor_shape.h"
     32 #include "tensorflow/core/framework/tensor_types.h"
     33 #include "tensorflow/core/framework/types.h"
     34 #include "tensorflow/core/platform/logging.h"
     35 #include "tensorflow/core/platform/macros.h"
     37 namespace tensorflow {
     39 typedef Eigen::ThreadPoolDevice CPUDevice;
     40 typedef Eigen::GpuDevice GPUDevice;
     42 namespace functor {
     44 template <typename T>
     45 void LSTMBlockCellFpropWithEigen(
     46     const LSTMBlockCell& cell, OpKernelContext* ctx, const CPUDevice& d,
     47     const float forget_bias, const float cell_clip, bool use_peephole,
     48     typename TTypes<T>::ConstMatrix x, typename TTypes<T>::ConstMatrix cs_prev,
     49     typename TTypes<T>::ConstMatrix h_prev, typename TTypes<T>::ConstMatrix w,
     50     typename TTypes<T>::ConstVec wci, typename TTypes<T>::ConstVec wcf,
     51     typename TTypes<T>::ConstVec wco, typename TTypes<T>::ConstVec b,
     52     typename TTypes<T>::Matrix xh, typename TTypes<T>::Matrix i,
     53     typename TTypes<T>::Matrix cs, typename TTypes<T>::Matrix f,
     54     typename TTypes<T>::Matrix o, typename TTypes<T>::Matrix ci,
     55     typename TTypes<T>::Matrix co, typename TTypes<T>::Matrix icfo,
     56     typename TTypes<T>::Matrix h) {
     57   // Concat xh = [x, h].
     58   xh.slice(cell.xh_x_offsets(), cell.xh_x_extents()).device(d) = x;
     59   xh.slice(cell.xh_h_offsets(), cell.xh_h_extents()).device(d) = h_prev;
     61   // states1 = xh * w + b
     62   typename TTypes<T>::ConstMatrix const_xh(xh.data(), xh.dimensions());
     63   TensorBlasGemm<CPUDevice, T, false /* USE_CUBLAS */>::compute(
     64       ctx, d, false, false, typename gemm_compute_type<T>::type(1.f), const_xh,
     65       w, typename gemm_compute_type<T>::type(0.f), icfo);
     66   Eigen::array<Eigen::DenseIndex, 2> b_shape({1, b.dimensions()[0]});
     67   Eigen::array<Eigen::DenseIndex, 2> broadcast_shape({cell.batch_size(), 1});
     68   icfo.device(d) += b.reshape(b_shape).broadcast(broadcast_shape);
     70   Eigen::array<Eigen::DenseIndex, 2> p_shape({1, cell.cell_size()});
     71   Eigen::array<Eigen::DenseIndex, 2> p_broadcast_shape({cell.batch_size(), 1});
     73   // Input gate.
     74   if (use_peephole) {
     75     auto i_peep = cs_prev * wci.reshape(p_shape).broadcast(p_broadcast_shape);
     76     i.device(d) =
     77         (icfo.slice(cell.icfo_i_offsets(), cell.cell_extents()) + i_peep)
     78             .sigmoid();
     79   } else {
     80     i.device(d) =
     81         icfo.slice(cell.icfo_i_offsets(), cell.cell_extents()).sigmoid();
     82   }
     84   // Cell input.
     85   ci.device(d) = icfo.slice(cell.icfo_c_offsets(), cell.cell_extents()).tanh();
     87   // Forget gate (w/ bias).
     88   if (use_peephole) {
     89     auto f_peep = cs_prev * wcf.reshape(p_shape).broadcast(p_broadcast_shape);
     90     f.device(d) = (icfo.slice(cell.icfo_f_offsets(), cell.cell_extents()) +
     91                    f.constant(T(forget_bias)) + f_peep)
     92                       .sigmoid();
     93   } else {
     94     f.device(d) = (icfo.slice(cell.icfo_f_offsets(), cell.cell_extents()) +
     95                    f.constant(T(forget_bias)))
     96                       .sigmoid();
     97   }
     99   // cs = ci .* i + f .* cs_prev
    100   cs.device(d) = i * ci + f * cs_prev;
    102   if (cell_clip > 0.0f) {
    103     cs.device(d) =
    104         cs.binaryExpr(cs.constant(T(cell_clip)), Eigen::scalar_clip_op<T>());
    105   }
    107   // co = tanh(cs)
    108   co.device(d) = cs.tanh();
    110   // Output gate.
    111   if (use_peephole) {
    112     auto o_peep = cs * wco.reshape(p_shape).broadcast(p_broadcast_shape);
    113     o.device(d) =
    114         (icfo.slice(cell.icfo_o_offsets(), cell.cell_extents()) + o_peep)
    115             .sigmoid();
    116   } else {
    117     o.device(d) =
    118         icfo.slice(cell.icfo_o_offsets(), cell.cell_extents()).sigmoid();
    119   }
    121   // h = o .* co
    122   h.device(d) = o * co;
    123 }
    125 template <typename Device, typename T, bool USE_CUBLAS>
    126 void LSTMBlockCellBpropWithEigen(
    127     const LSTMBlockCell& cell, OpKernelContext* ctx, const Device& d,
    128     bool use_peephole, typename TTypes<T>::ConstMatrix x,
    129     typename TTypes<T>::ConstMatrix cs_prev,
    130     typename TTypes<T>::ConstMatrix h_prev, typename TTypes<T>::ConstMatrix w,
    131     typename TTypes<T>::ConstVec wci, typename TTypes<T>::ConstVec wcf,
    132     typename TTypes<T>::ConstVec wco, typename TTypes<T>::ConstVec b,
    133     typename TTypes<T>::ConstMatrix i, typename TTypes<T>::ConstMatrix cs,
    134     typename TTypes<T>::ConstMatrix f, typename TTypes<T>::ConstMatrix o,
    135     typename TTypes<T>::ConstMatrix ci, typename TTypes<T>::ConstMatrix co,
    136     typename TTypes<T>::ConstMatrix cs_grad,
    137     typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_,
    138     typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,
    139     typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,
    140     typename TTypes<T>::Matrix dicfo, typename TTypes<T>::Matrix cs_prev_grad,
    141     typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad,
    142     typename TTypes<T>::Vec wco_grad) {
    143   // do[t] = sigm'(o[t]) .* dh[t] .* co[t]
    144   do_.device(d) = o * (o.constant(T(1)) - o) * h_grad * co;
    146   // dcs[t] += tanh'(cs[t]) .* dh[t] .* o[t] + dcs[t + 1] .* f[t + 1]
    147   dcs.device(d) = (co.constant(T(1)) - co * co) * h_grad * o + cs_grad;
    149   Eigen::array<Eigen::DenseIndex, 2> p_shape({1, cell.cell_size()});
    150   Eigen::array<Eigen::DenseIndex, 2> p_broadcast_shape({cell.batch_size(), 1});
    151   if (use_peephole) {
    152     dcs.device(d) =
    153         dcs + do_ * wco.reshape(p_shape).broadcast(p_broadcast_shape);
    154   }
    156   // dci[t] = tanh'(ci[t]) dcs[t] i[t]
    157   dci.device(d) = (ci.constant(T(1)) - ci * ci) * dcs * i;
    159   // df[t] = sigm'(f[t]) dcs[t] cs[t - 1]
    160   df.device(d) = f * (f.constant(T(1)) - f) * dcs * cs_prev;
    162   // di[t] = sigm'(i[t]) dcs[t] ci[t]
    163   di.device(d) = i * (i.constant(T(1)) - i) * dcs * ci;
    165   dicfo.slice(cell.icfo_i_offsets(), cell.cell_extents()).device(d) = di;
    166   dicfo.slice(cell.icfo_c_offsets(), cell.cell_extents()).device(d) = dci;
    167   dicfo.slice(cell.icfo_f_offsets(), cell.cell_extents()).device(d) = df;
    168   dicfo.slice(cell.icfo_o_offsets(), cell.cell_extents()).device(d) = do_;
    170   cs_prev_grad.device(d) = dcs * f;
    171   if (use_peephole) {
    172     cs_prev_grad.device(d) =
    173         cs_prev_grad + di * wci.reshape(p_shape).broadcast(p_broadcast_shape) +
    174         df * wcf.reshape(p_shape).broadcast(p_broadcast_shape);
    175     wci_grad.device(d) = (di * cs_prev).sum(Eigen::array<int, 1>({0}));
    176     wcf_grad.device(d) = (df * cs_prev).sum(Eigen::array<int, 1>({0}));
    177     wco_grad.device(d) = (do_ * cs).sum(Eigen::array<int, 1>({0}));
    178   }
    179 }
    181 #define DEFINE_CPU_SPECS(T)                                                   \
    182   template <>                                                                 \
    183   void LSTMBlockCellFprop<CPUDevice, T, false /* USE_CUBLAS */>::operator()(  \
    184       OpKernelContext* ctx, const CPUDevice& d, const float forget_bias,      \
    185       const float cell_clip, bool use_peephole,                               \
    186       typename TTypes<T>::ConstMatrix x,                                      \
    187       typename TTypes<T>::ConstMatrix cs_prev,                                \
    188       typename TTypes<T>::ConstMatrix h_prev,                                 \
    189       typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci,    \
    190       typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco,     \
    191       typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh,          \
    192       typename TTypes<T>::Matrix i, typename TTypes<T>::Matrix cs,            \
    193       typename TTypes<T>::Matrix f, typename TTypes<T>::Matrix o,             \
    194       typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co,           \
    195       typename TTypes<T>::Matrix icfo, typename TTypes<T>::Matrix h) {        \
    196     LSTMBlockCellFpropWithEigen<T>(                                           \
    197         *this, ctx, d, forget_bias, cell_clip, use_peephole, x, cs_prev,      \
    198         h_prev, w, wci, wcf, wco, b, xh, i, cs, f, o, ci, co, icfo, h);       \
    199   }                                                                           \
    200   template <>                                                                 \
    201   void LSTMBlockCellBprop<CPUDevice, T, false /* USE_CUBLAS */>::operator()(  \
    202       OpKernelContext* ctx, const CPUDevice& d, bool use_peephole,            \
    203       typename TTypes<T>::ConstMatrix x,                                      \
    204       typename TTypes<T>::ConstMatrix cs_prev,                                \
    205       typename TTypes<T>::ConstMatrix h_prev,                                 \
    206       typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci,    \
    207       typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco,     \
    208       typename TTypes<T>::ConstVec b, typename TTypes<T>::ConstMatrix i,      \
    209       typename TTypes<T>::ConstMatrix cs, typename TTypes<T>::ConstMatrix f,  \
    210       typename TTypes<T>::ConstMatrix o, typename TTypes<T>::ConstMatrix ci,  \
    211       typename TTypes<T>::ConstMatrix co,                                     \
    212       typename TTypes<T>::ConstMatrix cs_grad,                                \
    213       typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_, \
    214       typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,         \
    215       typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,           \
    216       typename TTypes<T>::Matrix dicfo,                                       \
    217       typename TTypes<T>::Matrix cs_prev_grad,                                \
    218       typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad,     \
    219       typename TTypes<T>::Vec wco_grad) {                                     \
    220     LSTMBlockCellBpropWithEigen<CPUDevice, T, false /* USE_CUBLAS */>(        \
    221         *this, ctx, d, use_peephole, x, cs_prev, h_prev, w, wci, wcf, wco, b, \
    222         i, cs, f, o, ci, co, cs_grad, h_grad, do_, dcs, dci, df, di, dicfo,   \
    223         cs_prev_grad, wci_grad, wcf_grad, wco_grad);                          \
    224   }                                                                           \
    225   template struct LSTMBlockCellFprop<CPUDevice, T, false /* USE_CUBLAS */>;   \
    226   template struct LSTMBlockCellBprop<CPUDevice, T, false /* USE_CUBLAS */>;
    228 DEFINE_CPU_SPECS(float);
    229 DEFINE_CPU_SPECS(Eigen::half);
    230 #undef DEFINE_CPU_SPECS
    232 }  // namespace functor
    234 template <typename Device, typename T, bool USE_CUBLAS>
    235 class LSTMBlockCellOp : public OpKernel {
    236  public:
    237   explicit LSTMBlockCellOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    238     OP_REQUIRES_OK(ctx, ctx->GetAttr("forget_bias", &forget_bias_));
    239     OP_REQUIRES_OK(ctx, ctx->GetAttr("cell_clip", &cell_clip_));
    240     OP_REQUIRES_OK(ctx, ctx->GetAttr("use_peephole", &use_peephole_));
    241   }
    243   void Compute(OpKernelContext* ctx) override {
    244     const Tensor* x_tensor = nullptr;
    245     OP_REQUIRES_OK(ctx, ctx->input("x", &x_tensor));
    247     const Tensor* cs_prev_tensor = nullptr;
    248     OP_REQUIRES_OK(ctx, ctx->input("cs_prev", &cs_prev_tensor));
    250     const Tensor* h_prev_tensor = nullptr;
    251     OP_REQUIRES_OK(ctx, ctx->input("h_prev", &h_prev_tensor));
    253     const Tensor* w_tensor = nullptr;
    254     OP_REQUIRES_OK(ctx, ctx->input("w", &w_tensor));
    256     const Tensor* wci_tensor = nullptr;
    257     OP_REQUIRES_OK(ctx, ctx->input("wci", &wci_tensor));
    259     const Tensor* wcf_tensor = nullptr;
    260     OP_REQUIRES_OK(ctx, ctx->input("wcf", &wcf_tensor));
    262     const Tensor* wco_tensor = nullptr;
    263     OP_REQUIRES_OK(ctx, ctx->input("wco", &wco_tensor));
    265     const Tensor* b_tensor = nullptr;
    266     OP_REQUIRES_OK(ctx, ctx->input("b", &b_tensor));
    268     const int64 batch_size = x_tensor->dim_size(0);
    269     const int64 input_size = x_tensor->dim_size(1);
    270     const int64 cell_size = cs_prev_tensor->dim_size(1);
    272     // Sanity checks for our input shapes.
    273     OP_REQUIRES(ctx, cs_prev_tensor->dim_size(0) == batch_size,
    274                 errors::InvalidArgument("cs_prev.dims(0) != batch_size: ",
    275                                         cs_prev_tensor->dim_size(0), " vs. ",
    276                                         batch_size));
    277     OP_REQUIRES(ctx, cs_prev_tensor->dim_size(1) == cell_size,
    278                 errors::InvalidArgument("cs_prev.dims(1) != cell_size: ",
    279                                         cs_prev_tensor->dim_size(1), " vs. ",
    280                                         cell_size));
    282     OP_REQUIRES(ctx, h_prev_tensor->dim_size(0) == batch_size,
    283                 errors::InvalidArgument("h_prev.dims(0) != batch_size: ",
    284                                         h_prev_tensor->dim_size(0), " vs. ",
    285                                         batch_size));
    286     OP_REQUIRES(ctx, h_prev_tensor->dim_size(1) == cell_size,
    287                 errors::InvalidArgument(
    288                     "h_prev.dims(1) != cell_size: ", h_prev_tensor->dim_size(1),
    289                     " vs. ", cell_size));
    291     OP_REQUIRES(ctx, w_tensor->dim_size(0) == input_size + cell_size,
    292                 errors::InvalidArgument(
    293                     "w.dim_size(0) != input_size + cell_size: ",
    294                     w_tensor->dim_size(0), " vs. ", input_size + cell_size));
    295     OP_REQUIRES(ctx, w_tensor->dim_size(1) == cell_size * 4,
    296                 errors::InvalidArgument(
    297                     "w.dim_size(1) != cell_size * 4: ", w_tensor->dim_size(1),
    298                     " vs. ", cell_size * 4));
    300     OP_REQUIRES(ctx, b_tensor->dim_size(0) == cell_size * 4,
    301                 errors::InvalidArgument(
    302                     "b.dim_size(0) != cell_size * 4: ", b_tensor->dim_size(0),
    303                     " vs. ", cell_size * 4));
    305     // Allocate our output tensors.
    306     Tensor* i_tensor = nullptr;
    307     OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
    308                             {"h_prev"}, "i",
    309                             TensorShape({batch_size, cell_size}), &i_tensor));
    311     Tensor* cs_tensor = nullptr;
    312     OP_REQUIRES_OK(
    313         ctx, ctx->allocate_output("cs", TensorShape({batch_size, cell_size}),
    314                                   &cs_tensor));
    316     Tensor* f_tensor = nullptr;
    317     OP_REQUIRES_OK(
    318         ctx, ctx->allocate_output("f", TensorShape({batch_size, cell_size}),
    319                                   &f_tensor));
    321     Tensor* o_tensor = nullptr;
    322     OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
    323                             {"cs_prev"}, "o",
    324                             TensorShape({batch_size, cell_size}), &o_tensor));
    326     Tensor* ci_tensor = nullptr;
    327     OP_REQUIRES_OK(
    328         ctx, ctx->allocate_output("ci", TensorShape({batch_size, cell_size}),
    329                                   &ci_tensor));
    331     Tensor* co_tensor = nullptr;
    332     OP_REQUIRES_OK(
    333         ctx, ctx->allocate_output("co", TensorShape({batch_size, cell_size}),
    334                                   &co_tensor));
    336     Tensor* h_tensor = nullptr;
    337     OP_REQUIRES_OK(
    338         ctx, ctx->allocate_output("h", TensorShape({batch_size, cell_size}),
    339                                   &h_tensor));
    341     // Allocate our temp tensors.
    342     Tensor xh_tensor;
    343     OP_REQUIRES_OK(ctx, ctx->allocate_temp(
    344                             DataTypeToEnum<T>::v(),
    345                             TensorShape({batch_size, input_size + cell_size}),
    346                             &xh_tensor));
    348     Tensor icfo_tensor;
    349     OP_REQUIRES_OK(ctx,
    350                    ctx->allocate_temp(DataTypeToEnum<T>::v(),
    351                                       TensorShape({batch_size, cell_size * 4}),
    352                                       &icfo_tensor));
    354     const Device& device = ctx->eigen_device<Device>();
    356     functor::LSTMBlockCellFprop<Device, T, USE_CUBLAS>(batch_size, input_size,
    357                                                        cell_size)(
    358         ctx, device, forget_bias_, cell_clip_, use_peephole_,
    359         x_tensor->matrix<T>(), cs_prev_tensor->matrix<T>(),
    360         h_prev_tensor->matrix<T>(), w_tensor->matrix<T>(), wci_tensor->vec<T>(),
    361         wcf_tensor->vec<T>(), wco_tensor->vec<T>(), b_tensor->vec<T>(),
    362         xh_tensor.matrix<T>(), i_tensor->matrix<T>(), cs_tensor->matrix<T>(),
    363         f_tensor->matrix<T>(), o_tensor->matrix<T>(), ci_tensor->matrix<T>(),
    364         co_tensor->matrix<T>(), icfo_tensor.matrix<T>(), h_tensor->matrix<T>());
    365   }
    367  private:
    368   float forget_bias_;
    369   float cell_clip_;
    370   bool use_peephole_;
    371 };
    373 #define REGISTER_KERNEL(T)                                             \
    374   REGISTER_KERNEL_BUILDER(                                             \
    375       Name("LSTMBlockCell").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
    376       LSTMBlockCellOp<CPUDevice, T, false>);
    377 REGISTER_KERNEL(float);
    378 REGISTER_KERNEL(Eigen::half);
    379 #undef REGISTER_KERNEL
    381 #if GOOGLE_CUDA
    382 namespace functor {
    383 #define DECLARE_GPU_SPEC(T)                                                \
    384   template <>                                                              \
    385   void LSTMBlockCellFprop<GPUDevice, T, true>::operator()(                 \
    386       OpKernelContext* ctx, const GPUDevice& d, const float forget_bias,   \
    387       const float cell_clip, bool use_peephole,                            \
    388       typename TTypes<T>::ConstMatrix x,                                   \
    389       typename TTypes<T>::ConstMatrix cs_prev,                             \
    390       typename TTypes<T>::ConstMatrix h_prev,                              \
    391       typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci, \
    392       typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco,  \
    393       typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh,       \
    394       typename TTypes<T>::Matrix i, typename TTypes<T>::Matrix cs,         \
    395       typename TTypes<T>::Matrix f, typename TTypes<T>::Matrix o,          \
    396       typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co,        \
    397       typename TTypes<T>::Matrix icfo, typename TTypes<T>::Matrix h);      \
    398                                                                            \
    399   extern template struct LSTMBlockCellFprop<GPUDevice, T, true>;
    401 DECLARE_GPU_SPEC(float);
    402 DECLARE_GPU_SPEC(Eigen::half);
    403 #undef DECLARE_GPU_SPEC
    404 }  // end namespace functor
    406 #define REGISTER_GPU_KERNEL(T)                                         \
    407   REGISTER_KERNEL_BUILDER(                                             \
    408       Name("LSTMBlockCell").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
    409       LSTMBlockCellOp<GPUDevice, T, true>);
    411 REGISTER_GPU_KERNEL(float);
    412 REGISTER_GPU_KERNEL(Eigen::half);
    413 // REGISTER_GPU_KERNEL(double);
    414 #undef REGISTER_GPU_KERNEL
    415 #endif  // GOOGLE_CUDA
    417 template <typename Device, typename T, bool USE_CUBLAS>
    418 class LSTMBlockCellGradOp : public OpKernel {
    419  public:
    420   explicit LSTMBlockCellGradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    421     OP_REQUIRES_OK(ctx, ctx->GetAttr("use_peephole", &use_peephole_));
    422   }
    424   void Compute(OpKernelContext* ctx) override {
    425     const Tensor* x_tensor = nullptr;
    426     OP_REQUIRES_OK(ctx, ctx->input("x", &x_tensor));
    428     const Tensor* cs_prev_tensor = nullptr;
    429     OP_REQUIRES_OK(ctx, ctx->input("cs_prev", &cs_prev_tensor));
    431     const Tensor* h_prev_tensor = nullptr;
    432     OP_REQUIRES_OK(ctx, ctx->input("h_prev", &h_prev_tensor));
    434     const Tensor* w_tensor = nullptr;
    435     OP_REQUIRES_OK(ctx, ctx->input("w", &w_tensor));
    437     const Tensor* wci_tensor = nullptr;
    438     OP_REQUIRES_OK(ctx, ctx->input("wci", &wci_tensor));
    440     const Tensor* wcf_tensor = nullptr;
    441     OP_REQUIRES_OK(ctx, ctx->input("wcf", &wcf_tensor));
    443     const Tensor* wco_tensor = nullptr;
    444     OP_REQUIRES_OK(ctx, ctx->input("wco", &wco_tensor));
    446     const Tensor* b_tensor = nullptr;
    447     OP_REQUIRES_OK(ctx, ctx->input("b", &b_tensor));
    449     const Tensor* i_tensor = nullptr;
    450     OP_REQUIRES_OK(ctx, ctx->input("i", &i_tensor));
    452     const Tensor* cs_tensor = nullptr;
    453     OP_REQUIRES_OK(ctx, ctx->input("cs", &cs_tensor));
    455     const Tensor* f_tensor = nullptr;
    456     OP_REQUIRES_OK(ctx, ctx->input("f", &f_tensor));
    458     const Tensor* o_tensor = nullptr;
    459     OP_REQUIRES_OK(ctx, ctx->input("o", &o_tensor));
    461     const Tensor* ci_tensor = nullptr;
    462     OP_REQUIRES_OK(ctx, ctx->input("ci", &ci_tensor));
    464     const Tensor* co_tensor = nullptr;
    465     OP_REQUIRES_OK(ctx, ctx->input("co", &co_tensor));
    467     const Tensor* cs_grad_tensor = nullptr;
    468     OP_REQUIRES_OK(ctx, ctx->input("cs_grad", &cs_grad_tensor));
    470     const Tensor* h_grad_tensor = nullptr;
    471     OP_REQUIRES_OK(ctx, ctx->input("h_grad", &h_grad_tensor));
    473     const int64 batch_size = x_tensor->dim_size(0);
    474     const int64 input_size = x_tensor->dim_size(1);
    475     const int64 cell_size = cs_prev_tensor->dim_size(1);
    477     // Sanity checks for our input shapes.
    478     OP_REQUIRES(ctx, cs_prev_tensor->dim_size(0) == batch_size,
    479                 errors::InvalidArgument("cs_prev.dims(0) != batch_size: ",
    480                                         cs_prev_tensor->dim_size(0), " vs. ",
    481                                         batch_size));
    482     OP_REQUIRES(ctx, cs_prev_tensor->dim_size(1) == cell_size,
    483                 errors::InvalidArgument("cs_prev.dims(1) != cell_size: ",
    484                                         cs_prev_tensor->dim_size(1), " vs. ",
    485                                         cell_size));
    487     OP_REQUIRES(ctx, h_prev_tensor->dim_size(0) == batch_size,
    488                 errors::InvalidArgument("h_prev.dims(0) != batch_size: ",
    489                                         h_prev_tensor->dim_size(0), " vs. ",
    490                                         batch_size));
    491     OP_REQUIRES(ctx, h_prev_tensor->dim_size(1) == cell_size,
    492                 errors::InvalidArgument(
    493                     "h_prev.dims(1) != cell_size: ", h_prev_tensor->dim_size(1),
    494                     " vs. ", cell_size));
    496     OP_REQUIRES(ctx, w_tensor->dim_size(0) == input_size + cell_size,
    497                 errors::InvalidArgument(
    498                     "w.dim_size(0) != input_size + cell_size: ",
    499                     w_tensor->dim_size(0), " vs. ", input_size + cell_size));
    500     OP_REQUIRES(ctx, w_tensor->dim_size(1) == cell_size * 4,
    501                 errors::InvalidArgument(
    502                     "w.dim_size(1) != cell_size * 4: ", w_tensor->dim_size(1),
    503                     " vs. ", cell_size * 4));
    505     OP_REQUIRES(ctx, b_tensor->dim_size(0) == cell_size * 4,
    506                 errors::InvalidArgument(
    507                     "b.dim_size(0) != cell_size * 4: ", b_tensor->dim_size(0),
    508                     " vs. ", cell_size * 4));
    510     OP_REQUIRES(ctx, i_tensor->dim_size(0) == batch_size,
    511                 errors::InvalidArgument(
    512                     "i.dim_size(0) != batch_size: ", i_tensor->dim_size(0),
    513                     " vs. ", batch_size));
    514     OP_REQUIRES(ctx, i_tensor->dim_size(1) == cell_size,
    515                 errors::InvalidArgument(
    516                     "i.dim_size(1) != cell_size: ", i_tensor->dim_size(1),
    517                     " vs. ", cell_size));
    519     OP_REQUIRES(ctx, cs_tensor->dim_size(0) == batch_size,
    520                 errors::InvalidArgument(
    521                     "cs.dim_size(0) != batch_size: ", cs_tensor->dim_size(0),
    522                     " vs. ", batch_size));
    523     OP_REQUIRES(ctx, cs_tensor->dim_size(1) == cell_size,
    524                 errors::InvalidArgument(
    525                     "cs.dim_size(1) != cell_size: ", cs_tensor->dim_size(1),
    526                     " vs. ", cell_size));
    528     OP_REQUIRES(ctx, f_tensor->dim_size(0) == batch_size,
    529                 errors::InvalidArgument(
    530                     "f.dim_size(0) != batch_size: ", f_tensor->dim_size(0),
    531                     " vs. ", batch_size));
    532     OP_REQUIRES(ctx, f_tensor->dim_size(1) == cell_size,
    533                 errors::InvalidArgument(
    534                     "i.dim_size(1) != cell_size: ", f_tensor->dim_size(1),
    535                     " vs. ", cell_size));
    537     OP_REQUIRES(ctx, o_tensor->dim_size(0) == batch_size,
    538                 errors::InvalidArgument(
    539                     "o.dim_size(0) != batch_size: ", o_tensor->dim_size(0),
    540                     " vs. ", batch_size));
    541     OP_REQUIRES(ctx, o_tensor->dim_size(1) == cell_size,
    542                 errors::InvalidArgument(
    543                     "o.dim_size(1) != cell_size: ", o_tensor->dim_size(1),
    544                     " vs. ", cell_size));
    546     OP_REQUIRES(ctx, ci_tensor->dim_size(0) == batch_size,
    547                 errors::InvalidArgument(
    548                     "ci.dim_size(0) != batch_size: ", ci_tensor->dim_size(0),
    549                     " vs. ", batch_size));
    550     OP_REQUIRES(ctx, ci_tensor->dim_size(1) == cell_size,
    551                 errors::InvalidArgument(
    552                     "ci.dim_size(1) != cell_size: ", ci_tensor->dim_size(1),
    553                     " vs. ", cell_size));
    555     OP_REQUIRES(ctx, co_tensor->dim_size(0) == batch_size,
    556                 errors::InvalidArgument(
    557                     "co.dim_size(0) != batch_size: ", co_tensor->dim_size(0),
    558                     " vs. ", batch_size));
    559     OP_REQUIRES(ctx, co_tensor->dim_size(1) == cell_size,
    560                 errors::InvalidArgument(
    561                     "co.dim_size(1) != cell_size: ", co_tensor->dim_size(1),
    562                     " vs. ", cell_size));
    564     OP_REQUIRES(ctx, cs_grad_tensor->dim_size(0) == batch_size,
    565                 errors::InvalidArgument(
    566                     "cs_grad_tensor.dims(0) != batch_size: ",
    567                     cs_grad_tensor->dim_size(0), " vs. ", batch_size));
    568     OP_REQUIRES(ctx, cs_grad_tensor->dim_size(1) == cell_size,
    569                 errors::InvalidArgument("cs_grad_tensor.dims(1) != cell_size: ",
    570                                         cs_grad_tensor->dim_size(1), " vs. ",
    571                                         cell_size));
    573     OP_REQUIRES(ctx, h_grad_tensor->dim_size(0) == batch_size,
    574                 errors::InvalidArgument("h_grad_tensor.dims(0) != batch_size: ",
    575                                         h_grad_tensor->dim_size(0), " vs. ",
    576                                         batch_size));
    577     OP_REQUIRES(ctx, h_grad_tensor->dim_size(1) == cell_size,
    578                 errors::InvalidArgument("h_grad_tensor.dims(1) != cell_size: ",
    579                                         h_grad_tensor->dim_size(1), " vs. ",
    580                                         cell_size));
    582     // Allocate our output tensors.
    583     Tensor* cs_prev_grad_tensor = nullptr;
    584     OP_REQUIRES_OK(
    585         ctx, ctx->forward_input_or_allocate_output(
    586                  {"cs_grad"}, "cs_prev_grad",
    587                  TensorShape({batch_size, cell_size}), &cs_prev_grad_tensor));
    589     Tensor* dicfo_tensor = nullptr;
    590     OP_REQUIRES_OK(ctx, ctx->allocate_output(
    591                             "dicfo", TensorShape({batch_size, cell_size * 4}),
    592                             &dicfo_tensor));
    594     Tensor* wci_grad_tensor = nullptr;
    595     OP_REQUIRES_OK(
    596         ctx, ctx->forward_input_or_allocate_output(
    597                  {"wci"}, "wci_grad", wci_tensor->shape(), &wci_grad_tensor));
    599     Tensor* wcf_grad_tensor = nullptr;
    600     OP_REQUIRES_OK(
    601         ctx, ctx->forward_input_or_allocate_output(
    602                  {"wcf"}, "wcf_grad", wcf_tensor->shape(), &wcf_grad_tensor));
    604     Tensor* wco_grad_tensor = nullptr;
    605     OP_REQUIRES_OK(
    606         ctx, ctx->forward_input_or_allocate_output(
    607                  {"wco"}, "wco_grad", wco_tensor->shape(), &wco_grad_tensor));
    609     // Allocate our temp tensors.
    610     Tensor do_tensor;
    611     OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
    612                                            TensorShape({batch_size, cell_size}),
    613                                            &do_tensor));
    615     Tensor dcs_tensor;
    616     OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
    617                                            TensorShape({batch_size, cell_size}),
    618                                            &dcs_tensor));
    620     Tensor dci_tensor;
    621     OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
    622                                            TensorShape({batch_size, cell_size}),
    623                                            &dci_tensor));
    625     Tensor df_tensor;
    626     OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
    627                                            TensorShape({batch_size, cell_size}),
    628                                            &df_tensor));
    630     Tensor di_tensor;
    631     OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
    632                                            TensorShape({batch_size, cell_size}),
    633                                            &di_tensor));
    635     const Device& device = ctx->eigen_device<Device>();
    637     functor::TensorZero<Device, T>()(device, wci_grad_tensor->flat<T>());
    638     functor::TensorZero<Device, T>()(device, wcf_grad_tensor->flat<T>());
    639     functor::TensorZero<Device, T>()(device, wco_grad_tensor->flat<T>());
    641     functor::LSTMBlockCellBprop<Device, T, USE_CUBLAS>(batch_size, input_size,
    642                                                        cell_size)(
    643         ctx, device, use_peephole_, x_tensor->matrix<T>(),
    644         cs_prev_tensor->matrix<T>(), h_prev_tensor->matrix<T>(),
    645         w_tensor->matrix<T>(), wci_tensor->vec<T>(), wcf_tensor->vec<T>(),
    646         wco_tensor->vec<T>(), b_tensor->vec<T>(), i_tensor->matrix<T>(),
    647         cs_tensor->matrix<T>(), f_tensor->matrix<T>(), o_tensor->matrix<T>(),
    648         ci_tensor->matrix<T>(), co_tensor->matrix<T>(),
    649         cs_grad_tensor->matrix<T>(), h_grad_tensor->matrix<T>(),
    650         do_tensor.matrix<T>(), dcs_tensor.matrix<T>(), dci_tensor.matrix<T>(),
    651         df_tensor.matrix<T>(), di_tensor.matrix<T>(), dicfo_tensor->matrix<T>(),
    652         cs_prev_grad_tensor->matrix<T>(), wci_grad_tensor->vec<T>(),
    653         wcf_grad_tensor->vec<T>(), wco_grad_tensor->vec<T>());
    654   }
    656  protected:
    657   bool use_peephole_;
    658 };
    660 #define REGISTER_KERNEL(T)                                                 \
    661   REGISTER_KERNEL_BUILDER(                                                 \
    662       Name("LSTMBlockCellGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
    663       LSTMBlockCellGradOp<CPUDevice, T, false>);
    664 REGISTER_KERNEL(float);
    665 REGISTER_KERNEL(Eigen::half);
    666 #undef REGISTER_KERNEL
    #if GOOGLE_CUDA
    namespace functor {

    // Declares the explicit specialization of
    // LSTMBlockCellBprop<GPUDevice, DT, true>::operator() and marks the struct
    // as `extern template`, so it is not implicitly instantiated in this
    // translation unit. The definition is expected to live elsewhere (presumably
    // in the CUDA source for these kernels).
    #define DECLARE_LSTM_BLOCK_CELL_BPROP_GPU(DT)                      \
      template <>                                                      \
      void LSTMBlockCellBprop<GPUDevice, DT, true>::operator()(        \
          OpKernelContext* ctx,                                        \
          const GPUDevice& d,                                          \
          bool use_peephole,                                           \
          typename TTypes<DT>::ConstMatrix x,                          \
          typename TTypes<DT>::ConstMatrix cs_prev,                    \
          typename TTypes<DT>::ConstMatrix h_prev,                     \
          typename TTypes<DT>::ConstMatrix w,                          \
          typename TTypes<DT>::ConstVec wci,                           \
          typename TTypes<DT>::ConstVec wcf,                           \
          typename TTypes<DT>::ConstVec wco,                           \
          typename TTypes<DT>::ConstVec b,                             \
          typename TTypes<DT>::ConstMatrix i,                          \
          typename TTypes<DT>::ConstMatrix cs,                         \
          typename TTypes<DT>::ConstMatrix f,                          \
          typename TTypes<DT>::ConstMatrix o,                          \
          typename TTypes<DT>::ConstMatrix ci,                         \
          typename TTypes<DT>::ConstMatrix co,                         \
          typename TTypes<DT>::ConstMatrix cs_grad,                    \
          typename TTypes<DT>::ConstMatrix h_grad,                     \
          typename TTypes<DT>::Matrix do_,                             \
          typename TTypes<DT>::Matrix dcs,                             \
          typename TTypes<DT>::Matrix dci,                             \
          typename TTypes<DT>::Matrix df,                              \
          typename TTypes<DT>::Matrix di,                              \
          typename TTypes<DT>::Matrix dicfo,                           \
          typename TTypes<DT>::Matrix cs_prev_grad,                    \
          typename TTypes<DT>::Vec wci_grad,                           \
          typename TTypes<DT>::Vec wcf_grad,                           \
          typename TTypes<DT>::Vec wco_grad);                          \
                                                                       \
      extern template struct LSTMBlockCellBprop<GPUDevice, DT,         \
                                                true /* USE_CUBLAS */>;

    DECLARE_LSTM_BLOCK_CELL_BPROP_GPU(float);
    DECLARE_LSTM_BLOCK_CELL_BPROP_GPU(Eigen::half);
    // DECLARE_LSTM_BLOCK_CELL_BPROP_GPU(double);

    #undef DECLARE_LSTM_BLOCK_CELL_BPROP_GPU

    }  // namespace functor

    // Registers the GPU kernel for the "LSTMBlockCellGrad" op, once for each
    // supported element type; the GPU variant passes `true` as the third
    // template argument of LSTMBlockCellGradOp.
    #define REGISTER_LSTM_BLOCK_CELL_GRAD_GPU(DT)             \
      REGISTER_KERNEL_BUILDER(Name("LSTMBlockCellGrad")       \
                                  .Device(DEVICE_GPU)         \
                                  .TypeConstraint<DT>("T"),   \
                              LSTMBlockCellGradOp<GPUDevice, DT, true>);

    REGISTER_LSTM_BLOCK_CELL_GRAD_GPU(float);
    REGISTER_LSTM_BLOCK_CELL_GRAD_GPU(Eigen::half);
    // REGISTER_LSTM_BLOCK_CELL_GRAD_GPU(double);

    #undef REGISTER_LSTM_BLOCK_CELL_GRAD_GPU
    #endif  // GOOGLE_CUDA
    712 namespace {
    714 // This helper class can be used to access timeslices of a 3D tensor. If a slice
    715 // happens to be unaligned (usually because both batch size and number of cells
    716 // are odd - this isn't common) this involves overhead, since data needs to be
    717 // copied. However, if all slices are aligned, the bits aren't copied. In the
    718 // cases where copying is needed, the outputs have to be recopied back.
    719 // At the end of each time step you should call FinishTimeStep which does this,
    720 // and also allows for reuse of temporary tensors.
    721 template <typename Device, typename T>
    722 class SliceHelper {
    723  public:
    724   explicit SliceHelper(OpKernelContext* ctx)
    725       : ctx_(ctx), device_(ctx_->eigen_device<Device>()) {}
    727   ~SliceHelper() {
    728     CHECK(copy_out_.empty());
    729     for (const auto& entry : pool_) {
    730       CHECK(!entry.second.second);  // nothing is in use
    731     }
    732   }
    734   // Slice through an input tensor. This may copy unaligned slices, but no
    735   // copying back will be done at the end.
    736   const Tensor InputSlice(const Tensor& t, int pos, const string& name) {
    737     Tensor res = UnalignedSlice(t, pos);
    738     if (res.IsAligned()) {
    739       return res;
    740     } else {
    741       return AlignTensor(res, name);
    742     }
    743   }
    745   // Slice through an output tensor. This may copy unaligned slices, and
    746   // schedule copying back on destruction.
    747   Tensor OutputSlice(Tensor* t, int pos, const string& name) {
    748     Tensor res = UnalignedSlice(*t, pos);
    749     if (res.IsAligned()) {
    750       return res;
    751     } else {
    752       Tensor aligned = AlignTensor(res, name);
    753       copy_out_.emplace_back(res, aligned);
    754       return aligned;
    755     }
    756   }
    758   void FinishTimeStep() {
    759     for (const auto& p : copy_out_) {
    760       const Tensor& aligned = p.second;
    761       Tensor original = p.first;
    762       // Copy from aligned back to original.
    763       functor::TensorCopyToUnaligned<Device, T>()(device_, aligned.flat<T>(),
    764                                                   original.unaligned_flat<T>());
    765     }
    766     copy_out_.clear();
    767     // Mark all entries as not in use.
    768     for (auto& entry : pool_) {
    769       entry.second.second = false;
    770     }
    771   }
    773  private:
    774   // Return a slice at position 'pos'. Result may be unaligned. The resulting
    775   // tensor always shares data with the source tensor.
    776   Tensor UnalignedSlice(const Tensor& t, int pos) const {
    777     Tensor res;
    778     // CHECK should never fail here, since the number of elements must match
    779     CHECK(res.CopyFrom(t.Slice(pos, pos + 1), {t.dim_size(1), t.dim_size(2)}));
    780     return res;
    781   }
    783   // Assumes input is not aligned, creates a temporary aligned tensor of the
    784   // same shape and copies the original tensor's content into it.
    785   Tensor AlignTensor(const Tensor& t, const string& name) {
    786     VLOG(1) << "AlignTensor called for " << name << ", shape "
    787             << t.shape().DebugString()
    788             << ". This is unnecessary copying. Consider using shapes with even "
    789             << "sizes";
    790     Tensor aligned;
    791     auto found = pool_.find(name);
    792     if (found != pool_.end()) {  // found in pool
    793       CHECK(!found->second.second) << "Tensor " << name << " is in use";
    794       found->second.second = true;  // mark in use
    795       aligned = found->second.first;
    796       CHECK(aligned.shape().IsSameSize(t.shape()));
    797       CHECK_EQ(aligned.dtype(), t.dtype());
    798     } else {  // allocate a new temporary tensor
    799       TF_CHECK_OK(ctx_->allocate_temp(t.dtype(), t.shape(), &aligned));
    800       pool_.emplace(name, std::make_pair(aligned, true));
    801     }
    802     functor::TensorCopyUnaligned<Device, T>()(device_, t.unaligned_flat<T>(),
    803                                               aligned.flat<T>());
    804     return aligned;
    805   }
    807   // Tensors to be copied.
    808   std::vector<std::pair<Tensor, const Tensor>> copy_out_;
    809   // A pool of pre-allocated temporary tensors, with an indicator for whether
    810   // it's in use.
    811   std::map<string, std::pair<Tensor, bool>> pool_;
    812   // Op context
    813   OpKernelContext* ctx_ = nullptr;
    814   // Device
    815   const Device& device_;
    816 };
    818 }  // namespace
    820 template <typename Device, typename T, bool USE_CUBLAS>
    821 class BlockLSTMOp : public OpKernel {
    822  public:
    823   explicit BlockLSTMOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    824     OP_REQUIRES_OK(ctx, ctx->GetAttr("forget_bias", &forget_bias_));
    825     OP_REQUIRES_OK(ctx, ctx->GetAttr("cell_clip", &cell_clip_));
    826     OP_REQUIRES_OK(ctx, ctx->GetAttr("use_peephole", &use_peephole_));
    827   }
    829   void Compute(OpKernelContext* ctx) override {
    830     const Tensor* seq_len_max_tensor = nullptr;
    831     OP_REQUIRES_OK(ctx, ctx->input("seq_len_max", &seq_len_max_tensor));
    833     const Tensor* x;
    834     OP_REQUIRES_OK(ctx, ctx->input("x", &x));
    835     OP_REQUIRES(ctx, x->dims() == 3, errors::InvalidArgument("x must be 3D"));
    836     const int64 timelen = x->dim_size(0);
    837     const int64 batch_size = x->dim_size(1);
    838     const int64 input_size = x->dim_size(2);
    840     const Tensor* cs_prev_tensor = nullptr;
    841     OP_REQUIRES_OK(ctx, ctx->input("cs_prev", &cs_prev_tensor));
    842     OP_REQUIRES(ctx, cs_prev_tensor->dims() == 2,
    843                 errors::InvalidArgument("cs_prev must be 2D"));
    844     OP_REQUIRES(ctx, cs_prev_tensor->dim_size(0) == batch_size,
    845                 errors::InvalidArgument("cs_prev.dims(0) != batch_size: ",
    846                                         cs_prev_tensor->dim_size(0), " vs. ",
    847                                         batch_size));
    848     const int64 cell_size = cs_prev_tensor->dim_size(1);
    850     if (batch_size * input_size % 2 == 1) {
    851       LOG(WARNING) << "BlockLSTMOp is inefficient when both batch_size and "
    852                    << "input_size are odd. You are using: batch_size="
    853                    << batch_size << ", input_size=" << input_size;
    854     }
    855     if (batch_size * cell_size % 2 == 1) {
    856       LOG(WARNING) << "BlockLSTMOp is inefficient when both batch_size and "
    857                    << "cell_size are odd. You are using: batch_size="
    858                    << batch_size << ", cell_size=" << cell_size;
    859     }
    861     const Tensor* h_prev_tensor = nullptr;
    862     OP_REQUIRES_OK(ctx, ctx->input("h_prev", &h_prev_tensor));
    863     OP_REQUIRES(ctx, h_prev_tensor->dims() == 2,
    864                 errors::InvalidArgument("h_prev must be 2D"));
    865     OP_REQUIRES(ctx, h_prev_tensor->dim_size(0) == batch_size,
    866                 errors::InvalidArgument("h_prev.dims(0) != batch_size: ",
    867                                         h_prev_tensor->dim_size(0), " vs. ",
    868                                         batch_size));
    869     OP_REQUIRES(ctx, h_prev_tensor->dim_size(1) == cell_size,
    870                 errors::InvalidArgument(
    871                     "h_prev.dims(1) != cell_size: ", h_prev_tensor->dim_size(1),
    872                     " vs. ", cell_size));
    874     const Tensor* w_tensor = nullptr;
    875     OP_REQUIRES_OK(ctx, ctx->input("w", &w_tensor));
    876     OP_REQUIRES(ctx, w_tensor->dims() == 2,
    877                 errors::InvalidArgument("w must be 2D"));
    878     OP_REQUIRES(ctx, w_tensor->dim_size(0) == input_size + cell_size,
    879                 errors::InvalidArgument(
    880                     "w.dim_size(0) != input_size + cell_size: ",
    881                     w_tensor->dim_size(0), " vs. ", input_size + cell_size));
    882     OP_REQUIRES(ctx, w_tensor->dim_size(1) == cell_size * 4,
    883                 errors::InvalidArgument(
    884                     "w.dim_size(1) != cell_size * 4: ", w_tensor->dim_size(1),
    885                     " vs. ", cell_size * 4));
    887     const Tensor* wci_tensor = nullptr;
    888     OP_REQUIRES_OK(ctx, ctx->input("wci", &wci_tensor));
    889     OP_REQUIRES(ctx, wci_tensor->dims() == 1,
    890                 errors::InvalidArgument("wci must be 1D"));
    891     OP_REQUIRES(ctx, wci_tensor->dim_size(0) == cell_size,
    892                 errors::InvalidArgument(
    893                     "wci.dim_size(0) != cell_size: ", wci_tensor->dim_size(0),
    894                     " vs. ", cell_size));
    896     const Tensor* wcf_tensor = nullptr;
    897     OP_REQUIRES_OK(ctx, ctx->input("wcf", &wcf_tensor));
    898     OP_REQUIRES(ctx, wcf_tensor->dims() == 1,
    899                 errors::InvalidArgument("wcf must be 1D"));
    900     OP_REQUIRES(ctx, wcf_tensor->dim_size(0) == cell_size,
    901                 errors::InvalidArgument(
    902                     "wcf.dim_size(0) != cell_size: ", wcf_tensor->dim_size(0),
    903                     " vs. ", cell_size));
    905     const Tensor* wco_tensor = nullptr;
    906     OP_REQUIRES_OK(ctx, ctx->input("wco", &wco_tensor));
    907     OP_REQUIRES(ctx, wco_tensor->dims() == 1,
    908                 errors::InvalidArgument("wco must be 1D"));
    909     OP_REQUIRES(ctx, wco_tensor->dim_size(0) == cell_size,
    910                 errors::InvalidArgument(
    911                     "wco.dim_size(0) != cell_size: ", wco_tensor->dim_size(0),
    912                     " vs. ", cell_size));
    914     const Tensor* b_tensor = nullptr;
    915     OP_REQUIRES_OK(ctx, ctx->input("b", &b_tensor));
    916     OP_REQUIRES(ctx, b_tensor->dims() == 1,
    917                 errors::InvalidArgument("b must be 1D"));
    918     OP_REQUIRES(ctx, b_tensor->dim_size(0) == cell_size * 4,
    919                 errors::InvalidArgument(
    920                     "b.dim_size(0) != cell_size * 4: ", b_tensor->dim_size(0),
    921                     " vs. ", cell_size * 4));
    923     TensorShape batch_cell_shape({timelen, batch_size, cell_size});
    924     Tensor* i_out;
    925     OP_REQUIRES_OK(ctx, ctx->allocate_output("i", batch_cell_shape, &i_out));
    927     Tensor* cs_out;
    928     OP_REQUIRES_OK(ctx, ctx->allocate_output("cs", batch_cell_shape, &cs_out));
    930     Tensor* f_out;
    931     OP_REQUIRES_OK(ctx, ctx->allocate_output("f", batch_cell_shape, &f_out));
    933     Tensor* o_out;
    934     OP_REQUIRES_OK(ctx, ctx->allocate_output("o", batch_cell_shape, &o_out));
    936     Tensor* ci_out;
    937     OP_REQUIRES_OK(ctx, ctx->allocate_output("ci", batch_cell_shape, &ci_out));
    939     Tensor* co_out;
    940     OP_REQUIRES_OK(ctx, ctx->allocate_output("co", batch_cell_shape, &co_out));
    942     Tensor* h_out;
    943     OP_REQUIRES_OK(ctx, ctx->allocate_output("h", batch_cell_shape, &h_out));
    945     Tensor xh_tensor;
    946     OP_REQUIRES_OK(ctx, ctx->allocate_temp(
    947                             DataTypeToEnum<T>::v(),
    948                             TensorShape({batch_size, input_size + cell_size}),
    949                             &xh_tensor));
    951     Tensor icfo_tensor;
    952     OP_REQUIRES_OK(ctx,
    953                    ctx->allocate_temp(DataTypeToEnum<T>::v(),
    954                                       TensorShape({batch_size, cell_size * 4}),
    955                                       &icfo_tensor));
    957     const Device& device = ctx->eigen_device<Device>();
    959     const int64 seq_len_max = seq_len_max_tensor->scalar<int64>()();
    960     SliceHelper<Device, T> slicer(ctx);
    961     for (int64 t = 0; t < seq_len_max; ++t) {
    962       const Tensor x_tensor = slicer.InputSlice(*x, t, "x");
    963       const Tensor& cs_prev_tensor2 =
    964           t == 0 ? *cs_prev_tensor
    965                  : slicer.OutputSlice(cs_out, t - 1, "cs_prev");
    966       const Tensor& h_prev_tensor2 =
    967           t == 0 ? *h_prev_tensor : slicer.OutputSlice(h_out, t - 1, "h_prev");
    969       Tensor i_tensor = slicer.OutputSlice(i_out, t, "i_out");
    970       Tensor cs_tensor = slicer.OutputSlice(cs_out, t, "cs_out");
    971       Tensor f_tensor = slicer.OutputSlice(f_out, t, "f_out");
    972       Tensor o_tensor = slicer.OutputSlice(o_out, t, "o_out");
    973       Tensor ci_tensor = slicer.OutputSlice(ci_out, t, "ci_out");
    974       Tensor co_tensor = slicer.OutputSlice(co_out, t, "co_out");
    975       Tensor h_tensor = slicer.OutputSlice(h_out, t, "h_out");
    977       functor::LSTMBlockCellFprop<Device, T, USE_CUBLAS>(batch_size, input_size,
    978                                                          cell_size)(
    979           ctx, device, forget_bias_, cell_clip_, use_peephole_,
    980           x_tensor.matrix<T>(), cs_prev_tensor2.matrix<T>(),
    981           h_prev_tensor2.matrix<T>(), w_tensor->matrix<T>(),
    982           wci_tensor->vec<T>(), wcf_tensor->vec<T>(), wco_tensor->vec<T>(),
    983           b_tensor->vec<T>(), xh_tensor.matrix<T>(), i_tensor.matrix<T>(),
    984           cs_tensor.matrix<T>(), f_tensor.matrix<T>(), o_tensor.matrix<T>(),
    985           ci_tensor.matrix<T>(), co_tensor.matrix<T>(), icfo_tensor.matrix<T>(),
    986           h_tensor.matrix<T>());
    987       slicer.FinishTimeStep();
    988     }
    990     if (seq_len_max < timelen) {
    991       Tensor cs_tensor = cs_out->Slice(seq_len_max, timelen);
    992       Tensor h_tensor = h_out->Slice(seq_len_max, timelen);
    994       functor::TensorUnalignedZero<Device, T>()(device,
    995                                                 cs_tensor.unaligned_flat<T>());
    996       functor::TensorUnalignedZero<Device, T>()(device,
    997                                                 h_tensor.unaligned_flat<T>());
    998     }
    999   }
   1001  private:
   1002   float forget_bias_;
   1003   float cell_clip_;
   1004   bool use_peephole_;
   1005 };
   1007 #define REGISTER_KERNEL(T)                                         \
   1008   REGISTER_KERNEL_BUILDER(                                         \
   1009       Name("BlockLSTM").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
   1010       BlockLSTMOp<CPUDevice, T, false>);
   1011 REGISTER_KERNEL(float);
   1012 REGISTER_KERNEL(Eigen::half);
   1013 #undef REGISTER_KERNEL
   #if GOOGLE_CUDA
   namespace functor {

   // Declares the explicit GPU specializations of TensorZero::operator() and
   // TensorUnalignedZero::operator() for element type DT, and marks both
   // structs as `extern template` so they are not implicitly instantiated in
   // this translation unit. The definitions are expected to live elsewhere
   // (presumably in the CUDA source for these kernels).
   #define DECLARE_TENSOR_ZERO_GPU(DT)                                  \
     template <>                                                        \
     void TensorZero<GPUDevice, DT>::operator()(                        \
         const GPUDevice& d,                                            \
         typename TTypes<DT>::Flat t);                                  \
                                                                        \
     extern template struct TensorZero<GPUDevice, DT>;                  \
                                                                        \
     template <>                                                        \
     void TensorUnalignedZero<GPUDevice, DT>::operator()(               \
         const GPUDevice& d,                                            \
         typename TTypes<DT>::UnalignedFlat t);                         \
                                                                        \
     extern template struct TensorUnalignedZero<GPUDevice, DT>;

   DECLARE_TENSOR_ZERO_GPU(float);
   DECLARE_TENSOR_ZERO_GPU(Eigen::half);
   // DECLARE_TENSOR_ZERO_GPU(double);

   #undef DECLARE_TENSOR_ZERO_GPU

   }  // end namespace functor

   // Registers the GPU kernel for the "BlockLSTM" op, once for each supported
   // element type. The "seq_len_max" input is pinned to host memory; the
   // kernel reads its value directly inside Compute().
   #define REGISTER_BLOCK_LSTM_GPU(DT)                        \
     REGISTER_KERNEL_BUILDER(Name("BlockLSTM")                \
                                 .Device(DEVICE_GPU)          \
                                 .HostMemory("seq_len_max")   \
                                 .TypeConstraint<DT>("T"),    \
                             BlockLSTMOp<GPUDevice, DT, true>);

   REGISTER_BLOCK_LSTM_GPU(float);
   REGISTER_BLOCK_LSTM_GPU(Eigen::half);
   // REGISTER_BLOCK_LSTM_GPU(double);

   #undef REGISTER_BLOCK_LSTM_GPU
   #endif  // GOOGLE_CUDA
   1049 template <typename Device, typename T, bool USE_CUBLAS>
   1050 class BlockLSTMGradOp : public OpKernel {
   1051  public:
   1052   explicit BlockLSTMGradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
   1053     OP_REQUIRES_OK(ctx, ctx->GetAttr("use_peephole", &use_peephole_));
   1054   }
   1056   void Compute(OpKernelContext* ctx) override {
   1057     const Tensor* seq_len_max_tensor = nullptr;
   1058     OP_REQUIRES_OK(ctx, ctx->input("seq_len_max", &seq_len_max_tensor));
   1060     const Tensor* x;
   1061     OP_REQUIRES_OK(ctx, ctx->input("x", &x));
   1062     OP_REQUIRES(ctx, x->dims() == 3, errors::InvalidArgument("x must be 3D"));
   1063     const int64 timelen = x->dim_size(0);
   1064     const int64 batch_size = x->dim_size(1);
   1065     const int64 input_size = x->dim_size(2);
   1067     const Tensor* cs_prev_tensor = nullptr;
   1068     OP_REQUIRES_OK(ctx, ctx->input("cs_prev", &cs_prev_tensor));
   1070     const Tensor* h_prev_tensor = nullptr;
   1071     OP_REQUIRES_OK(ctx, ctx->input("h_prev", &h_prev_tensor));
   1073     const Tensor* w_tensor = nullptr;
   1074     OP_REQUIRES_OK(ctx, ctx->input("w", &w_tensor));
   1075     const int64 cell_size = w_tensor->dim_size(1) / 4;
   1076     OP_REQUIRES(ctx, input_size + cell_size == w_tensor->dim_size(0),
   1077                 errors::InvalidArgument(
   1078                     "w matrix rows don't match: ", input_size + cell_size,
   1079                     " vs. ", w_tensor->dim_size(0)));
   1081     const Tensor* wci_tensor = nullptr;
   1082     OP_REQUIRES_OK(ctx, ctx->input("wci", &wci_tensor));
   1084     const Tensor* wcf_tensor = nullptr;
   1085     OP_REQUIRES_OK(ctx, ctx->input("wcf", &wcf_tensor));
   1087     const Tensor* wco_tensor = nullptr;
   1088     OP_REQUIRES_OK(ctx, ctx->input("wco", &wco_tensor));
   1090     const Tensor* b_tensor = nullptr;
   1091     OP_REQUIRES_OK(ctx, ctx->input("b", &b_tensor));
   1092     OP_REQUIRES(
   1093         ctx, cell_size == b_tensor->dim_size(0) / 4,
   1094         errors::InvalidArgument("w and b cell_size don't match: ", cell_size,
   1095                                 " vs. ", b_tensor->dim_size(0)));
   1097     const Tensor* i_out = nullptr;
   1098     OP_REQUIRES_OK(ctx, ctx->input("i", &i_out));
   1100     const Tensor* cs_out = nullptr;
   1101     OP_REQUIRES_OK(ctx, ctx->input("cs", &cs_out));
   1103     const Tensor* f_out = nullptr;
   1104     OP_REQUIRES_OK(ctx, ctx->input("f", &f_out));
   1106     const Tensor* o_out = nullptr;
   1107     OP_REQUIRES_OK(ctx, ctx->input("o", &o_out));
   1109     const Tensor* ci_out = nullptr;
   1110     OP_REQUIRES_OK(ctx, ctx->input("ci", &ci_out));
   1112     const Tensor* co_out = nullptr;
   1113     OP_REQUIRES_OK(ctx, ctx->input("co", &co_out));
   1115     const Tensor* h_out = nullptr;
   1116     OP_REQUIRES_OK(ctx, ctx->input("h", &h_out));
   1118     const Tensor* cs_grad = nullptr;
   1119     OP_REQUIRES_OK(ctx, ctx->input("cs_grad", &cs_grad));
   1121     const Tensor* h_grad = nullptr;
   1122     OP_REQUIRES_OK(ctx, ctx->input("h_grad", &h_grad));
   1124     TensorShape batch_input_shape({timelen, batch_size, input_size});
   1125     Tensor* x_grad;
   1126     OP_REQUIRES_OK(ctx,
   1127                    ctx->allocate_output("x_grad", batch_input_shape, &x_grad));
   1129     Tensor* cs_prev_grad_tensor = nullptr;
   1130     OP_REQUIRES_OK(ctx,
   1131                    ctx->allocate_output("cs_prev_grad", cs_prev_tensor->shape(),
   1132                                         &cs_prev_grad_tensor));
   1134     Tensor* h_prev_grad_tensor = nullptr;
   1135     OP_REQUIRES_OK(ctx,
   1136                    ctx->allocate_output("h_prev_grad", h_prev_tensor->shape(),
   1137                                         &h_prev_grad_tensor));
   1139     Tensor* w_grad_tensor = nullptr;
   1140     OP_REQUIRES_OK(
   1141         ctx, ctx->allocate_output("w_grad", w_tensor->shape(), &w_grad_tensor));
   1143     Tensor* wci_grad_tensor = nullptr;
   1144     OP_REQUIRES_OK(ctx, ctx->allocate_output("wci_grad", wci_tensor->shape(),
   1145                                              &wci_grad_tensor));
   1147     Tensor* wcf_grad_tensor = nullptr;
   1148     OP_REQUIRES_OK(ctx, ctx->allocate_output("wcf_grad", wcf_tensor->shape(),
   1149                                              &wcf_grad_tensor));
   1151     Tensor* wco_grad_tensor = nullptr;
   1152     OP_REQUIRES_OK(ctx, ctx->allocate_output("wco_grad", wco_tensor->shape(),
   1153                                              &wco_grad_tensor));
   1155     Tensor* b_grad_tensor = nullptr;
   1156     OP_REQUIRES_OK(
   1157         ctx, ctx->allocate_output("b_grad", b_tensor->shape(), &b_grad_tensor));
   1159     TensorShape batch_cell_shape({batch_size, cell_size});
   1161     Tensor xh_tensor;
   1162     OP_REQUIRES_OK(ctx, ctx->allocate_temp(
   1163                             DataTypeToEnum<T>::v(),
   1164                             TensorShape({batch_size, input_size + cell_size}),
   1165                             &xh_tensor));
   1167     Tensor xh_grad_tensor;
   1168     OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
   1169                                            xh_tensor.shape(), &xh_grad_tensor));
   1171     Tensor do_tensor;
   1172     OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
   1173                                            batch_cell_shape, &do_tensor));
   1175     Tensor dcs_tensor;
   1176     OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
   1177                                            batch_cell_shape, &dcs_tensor));
   1179     Tensor dci_tensor;
   1180     OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
   1181                                            batch_cell_shape, &dci_tensor));
   1183     Tensor df_tensor;
   1184     OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
   1185                                            batch_cell_shape, &df_tensor));
   1187     Tensor di_tensor;
   1188     OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
   1189                                            batch_cell_shape, &di_tensor));
   1191     Tensor dicfo_tensor;
   1192     OP_REQUIRES_OK(ctx,
   1193                    ctx->allocate_temp(DataTypeToEnum<T>::v(),
   1194                                       TensorShape({batch_size, cell_size * 4}),
   1195                                       &dicfo_tensor));
   1197     Tensor cs_grad_tensor;
   1198     OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
   1199                                            batch_cell_shape, &cs_grad_tensor));
   1201     Tensor h_grad_tensor;
   1202     OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
   1203                                            batch_cell_shape, &h_grad_tensor));
   1205     const Device& device = ctx->eigen_device<Device>();
   1207     functor::TensorZero<Device, T>()(device, cs_grad_tensor.flat<T>());
   1208     functor::TensorZero<Device, T>()(device, cs_prev_grad_tensor->flat<T>());
   1209     functor::TensorZero<Device, T>()(device, h_grad_tensor.flat<T>());
   1210     functor::TensorZero<Device, T>()(device, h_prev_grad_tensor->flat<T>());
   1211     functor::TensorZero<Device, T>()(device, w_grad_tensor->flat<T>());
   1212     functor::TensorZero<Device, T>()(device, wci_grad_tensor->flat<T>());
   1213     functor::TensorZero<Device, T>()(device, wcf_grad_tensor->flat<T>());
   1214     functor::TensorZero<Device, T>()(device, wco_grad_tensor->flat<T>());
   1215     functor::TensorZero<Device, T>()(device, b_grad_tensor->flat<T>());
   1217     const int64 seq_len_max = seq_len_max_tensor->scalar<int64>()();
   1218     SliceHelper<Device, T> slicer(ctx);
   1219     for (int64 t = seq_len_max - 1; t >= 0; --t) {
   1220       const Tensor& x_tensor = slicer.InputSlice(*x, t, "x");
   1221       const Tensor& cs_prev_tensor2 =
   1222           t == 0 ? *cs_prev_tensor
   1223                  : slicer.InputSlice(*cs_out, t - 1, "cs_prev");
   1224       const Tensor& h_prev_tensor2 =
   1225           t == 0 ? *h_prev_tensor : slicer.InputSlice(*h_out, t - 1, "h_prev");
   1226       const Tensor& i_tensor = slicer.InputSlice(*i_out, t, "i_out");
   1227       const Tensor& cs_tensor = slicer.InputSlice(*cs_out, t, "cs_out");
   1228       const Tensor& f_tensor = slicer.InputSlice(*f_out, t, "f_out");
   1229       const Tensor& o_tensor = slicer.InputSlice(*o_out, t, "o_out");
   1230       const Tensor& ci_tensor = slicer.InputSlice(*ci_out, t, "ci_out");
   1231       const Tensor& co_tensor = slicer.InputSlice(*co_out, t, "co_out");
   1233       // Grab previous CS grad.
   1234       const Tensor& const_cs_prev_grad_tensor = *cs_prev_grad_tensor;
   1235       const Tensor const_cs_grad_slice =
   1236           slicer.InputSlice(*cs_grad, t, "cs_grad");
   1237       functor::TensorAdd<Device, T>()(
   1238           device, const_cs_prev_grad_tensor.flat<T>(),
   1239           const_cs_grad_slice.flat<T>(), cs_grad_tensor.flat<T>());
   1241       // Combine previous h grad and h grad coming on top.
   1242       const Tensor& const_h_prev_grad_tensor = *h_prev_grad_tensor;
   1243       const Tensor const_h_grad_slice = slicer.InputSlice(*h_grad, t, "h_grad");
   1244       functor::TensorAdd<Device, T>()(
   1245           device, const_h_prev_grad_tensor.flat<T>(),
   1246           const_h_grad_slice.flat<T>(), h_grad_tensor.flat<T>());
   1248       const Tensor& const_cs_grad_tensor = cs_grad_tensor;
   1249       const Tensor& const_h_grad_tensor = h_grad_tensor;
   1251       Tensor x_grad_tensor = slicer.OutputSlice(x_grad, t, "x_grad");
   1252       functor::BlockLSTMBprop<Device, T, USE_CUBLAS>(batch_size, input_size,
   1253                                                      cell_size)(
   1254           ctx, device, use_peephole_, x_tensor.matrix<T>(),
   1255           cs_prev_tensor2.matrix<T>(), h_prev_tensor2.matrix<T>(),
   1256           w_tensor->matrix<T>(), wci_tensor->vec<T>(), wcf_tensor->vec<T>(),
   1257           wco_tensor->vec<T>(), b_tensor->vec<T>(), xh_tensor.matrix<T>(),
   1258           i_tensor.matrix<T>(), cs_tensor.matrix<T>(), f_tensor.matrix<T>(),
   1259           o_tensor.matrix<T>(), ci_tensor.matrix<T>(), co_tensor.matrix<T>(),
   1260           const_cs_grad_tensor.matrix<T>(), const_h_grad_tensor.matrix<T>(),
   1261           do_tensor.matrix<T>(), dcs_tensor.matrix<T>(), dci_tensor.matrix<T>(),
   1262           df_tensor.matrix<T>(), di_tensor.matrix<T>(),
   1263           dicfo_tensor.matrix<T>(), cs_prev_grad_tensor->matrix<T>(),
   1264           h_prev_grad_tensor->matrix<T>(), xh_grad_tensor.matrix<T>(),
   1265           x_grad_tensor.matrix<T>(), w_grad_tensor->matrix<T>(),
   1266           wci_grad_tensor->vec<T>(), wcf_grad_tensor->vec<T>(),
   1267           wco_grad_tensor->vec<T>(), b_grad_tensor->vec<T>());
   1268       slicer.FinishTimeStep();
   1269     }
   1271     if (seq_len_max < timelen) {
   1272       Tensor x_grad_tensor = x_grad->Slice(seq_len_max, timelen);
   1273       functor::TensorUnalignedZero<Device, T>()(
   1274           device, x_grad_tensor.unaligned_flat<T>());
   1275     }
   1276   }
   1278  private:
   1279   bool use_peephole_;
   1280 };
   1282 #define REGISTER_KERNEL(T)                                             \
   1283   REGISTER_KERNEL_BUILDER(                                             \
   1284       Name("BlockLSTMGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
   1285       BlockLSTMGradOp<CPUDevice, T, false>);
   1286 REGISTER_KERNEL(float);
   1287 REGISTER_KERNEL(Eigen::half);
   1288 #undef REGISTER_KERNEL
   1290 #if GOOGLE_CUDA
   1291 namespace functor {
   1292 #define DECLARE_GPU_SPEC(T)                                                    \
   1293   template <>                                                                  \
   1294   void TensorCopy<GPUDevice, T>::operator()(const GPUDevice& d,                \
   1295                                             typename TTypes<T>::ConstFlat src, \
   1296                                             typename TTypes<T>::Flat dst);     \
   1297                                                                                \
   1298   template <>                                                                  \
   1299   void TensorCopyUnaligned<GPUDevice, T>::operator()(                          \
   1300       const GPUDevice& d, typename TTypes<T>::UnalignedConstFlat src,          \
   1301       typename TTypes<T>::Flat dst);                                           \
   1302                                                                                \
   1303   template <>                                                                  \
   1304   void TensorCopyToUnaligned<GPUDevice, T>::operator()(                        \
   1305       const GPUDevice& d, typename TTypes<T>::ConstFlat src,                   \
   1306       typename TTypes<T>::UnalignedFlat dst);                                  \
   1307                                                                                \
   1308   template <>                                                                  \
   1309   void TensorAdd<GPUDevice, T>::operator()(                                    \
   1310       const GPUDevice& d, typename TTypes<T>::ConstFlat a,                     \
   1311       typename TTypes<T>::ConstFlat b, typename TTypes<T>::Flat c);            \
   1312                                                                                \
   1313   template <>                                                                  \
   1314   void BlockLSTMBprop<GPUDevice, T, true>::operator()(                         \
   1315       OpKernelContext* ctx, const GPUDevice& d, bool use_peephole,             \
   1316       typename TTypes<T>::ConstMatrix x,                                       \
   1317       typename TTypes<T>::ConstMatrix cs_prev,                                 \
   1318       typename TTypes<T>::ConstMatrix h_prev,                                  \
   1319       typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci,     \
   1320       typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco,      \
   1321       typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh,           \
   1322       typename TTypes<T>::ConstMatrix i, typename TTypes<T>::ConstMatrix cs,   \
   1323       typename TTypes<T>::ConstMatrix f, typename TTypes<T>::ConstMatrix o,    \
   1324       typename TTypes<T>::ConstMatrix ci, typename TTypes<T>::ConstMatrix co,  \
   1325       typename TTypes<T>::ConstMatrix cs_grad,                                 \
   1326       typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_,  \
   1327       typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,          \
   1328       typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,            \
   1329       typename TTypes<T>::Matrix dicfo,                                        \
   1330       typename TTypes<T>::Matrix cs_prev_grad,                                 \
   1331       typename TTypes<T>::Matrix h_prev_grad,                                  \
   1332       typename TTypes<T>::Matrix xh_grad, typename TTypes<T>::Matrix x_grad,   \
   1333       typename TTypes<T>::Matrix w_grad, typename TTypes<T>::Vec wci_grad,     \
   1334       typename TTypes<T>::Vec wcf_grad, typename TTypes<T>::Vec wco_grad,      \
   1335       typename TTypes<T>::Vec b_grad);                                         \
   1336                                                                                \
   1337   extern template struct TensorCopy<GPUDevice, T>;                             \
   1338   extern template struct TensorAdd<GPUDevice, T>;                              \
   1339   extern template struct BlockLSTMBprop<GPUDevice, T, true>;
   1341 DECLARE_GPU_SPEC(float);
   1342 DECLARE_GPU_SPEC(Eigen::half);
   1343 // DECLARE_GPU_SPEC(double);
   1344 #undef DECLARE_GPU_SPEC
   1345 }  // end namespace functor
   1347 #define REGISTER_GPU_KERNEL(T)                           \
   1348   REGISTER_KERNEL_BUILDER(Name("BlockLSTMGrad")          \
   1349                               .Device(DEVICE_GPU)        \
   1350                               .HostMemory("seq_len_max") \
   1351                               .TypeConstraint<T>("T"),   \
   1352                           BlockLSTMGradOp<GPUDevice, T, true>);
   1354 REGISTER_GPU_KERNEL(float);
   1355 REGISTER_GPU_KERNEL(Eigen::half);
   1356 // REGISTER_GPU_KERNEL(double);
   1357 #undef REGISTER_GPU_KERNEL
   1358 #endif  // GOOGLE_CUDA
   1360 }  // end namespace tensorflow