/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow/contrib/rnn/kernels/gru_ops.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

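// Forward pass of a single fused GRU cell step. A sketch of the computation
// performed by GRUBlockCellFprop (see gru_ops.h for the functor definition),
// with .* denoting elementwise products:
//
//   x_h_prev = [x, h_prev]                    (concatenation)
//   [r_bar, u_bar] = x_h_prev * w_ru + b_ru
//   r = sigmoid(r_bar),  u = sigmoid(u_bar)
//   x_h_prevr = [x, h_prev .* r]
//   c = tanh(x_h_prevr * w_c + b_c)
//   h = u .* h_prev + (1 - u) .* c
//
// x_h_prev, x_h_prevr, and r_u_bar are materialized below as scratch tensors.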
template <typename Device, typename T, bool USE_CUBLAS>
class GRUCellBlockOp : public OpKernel {
 public:
  explicit GRUCellBlockOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  // TODO(gitegaurav) Replace the input checks with a smarter function.
  void Compute(OpKernelContext* ctx) override {
    // Grab the input tensors.
    const Tensor* x_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("x", &x_tensor));

    const Tensor* h_prev_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("h_prev", &h_prev_tensor));

    const Tensor* w_ru_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("w_ru", &w_ru_tensor));

    const Tensor* w_c_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("w_c", &w_c_tensor));

    const Tensor* b_ru_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("b_ru", &b_ru_tensor));

    const Tensor* b_c_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("b_c", &b_c_tensor));

    const int64 batch_size = x_tensor->dim_size(0);
    const int64 input_size = x_tensor->dim_size(1);
    const int64 cell_size = h_prev_tensor->dim_size(1);

    // Sanity checks for input shapes.

    // Shape of 'h_prev' must be [batch_size, cell_size].
    OP_REQUIRES(ctx, h_prev_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument("h_prev.dims(0) != batch_size: ",
                                        h_prev_tensor->dim_size(0), " vs. ",
                                        batch_size));
    OP_REQUIRES(ctx, h_prev_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "h_prev.dims(1) != cell_size: ", h_prev_tensor->dim_size(1),
                    " vs. ", cell_size));

    // Shape of 'w_ru' must be [input_size + cell_size, 2 * cell_size].
    OP_REQUIRES(ctx, w_ru_tensor->dim_size(0) == input_size + cell_size,
                errors::InvalidArgument(
                    "w_ru.dim_size(0) != input_size + cell_size: ",
                    w_ru_tensor->dim_size(0), " vs. ", input_size + cell_size));

    OP_REQUIRES(ctx, w_ru_tensor->dim_size(1) == cell_size * 2,
                errors::InvalidArgument("w_ru.dim_size(1) != cell_size * 2: ",
                                        w_ru_tensor->dim_size(1), " vs. ",
                                        cell_size * 2));

    // Shape of 'w_c' must be [input_size + cell_size, cell_size].
    OP_REQUIRES(ctx, w_c_tensor->dim_size(0) == input_size + cell_size,
                errors::InvalidArgument(
                    "w_c.dim_size(0) != input_size + cell_size: ",
                    w_c_tensor->dim_size(0), " vs. ", input_size + cell_size));

    OP_REQUIRES(ctx, w_c_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "w_c.dim_size(1) != cell_size: ", w_c_tensor->dim_size(1),
                    " vs. ", cell_size));

    // Shape of 'b_ru' must be [2 * cell_size].
    OP_REQUIRES(ctx, b_ru_tensor->dim_size(0) == cell_size * 2,
                errors::InvalidArgument("b_ru.dim_size(0) != cell_size * 2: ",
                                        b_ru_tensor->dim_size(0), " vs. ",
                                        cell_size * 2));

    OP_REQUIRES(ctx, b_ru_tensor->dims() == 1,
                errors::InvalidArgument("Rank of b_ru must be 1: ",
                                        b_ru_tensor->dims(), " vs. 1"));
    // Shape of 'b_c' must be [cell_size].
    OP_REQUIRES(ctx, b_c_tensor->dim_size(0) == cell_size,
                errors::InvalidArgument(
                    "b_c.dim_size(0) != cell_size: ", b_c_tensor->dim_size(0),
                    " vs. ", cell_size));
    OP_REQUIRES(ctx, b_c_tensor->dims() == 1,
                errors::InvalidArgument("Rank of b_c must be 1: ",
                                        b_c_tensor->dims(), " vs. 1"));

    // Create output tensors.
    Tensor* r_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output("r", TensorShape({batch_size, cell_size}),
                                  &r_tensor));

    Tensor* u_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output("u", TensorShape({batch_size, cell_size}),
                                  &u_tensor));

    Tensor* c_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output("c", TensorShape({batch_size, cell_size}),
                                  &c_tensor));

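    // "h" may alias "h_prev": forward_input_or_allocate_output lets the
    // runtime reuse the input buffer for the output when that buffer is not
    // referenced elsewhere, saving one allocation per step.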
    Tensor* h_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                            {"h_prev"}, "h",
                            TensorShape({batch_size, cell_size}), &h_tensor));

    // Allocate temp tensors.
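    // x_h_prev and x_h_prevr hold the concatenated GEMM inputs [x, h_prev]
    // and [x, h_prev .* r]; r_u_bar holds the pre-activation gate values.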
    Tensor x_h_prev_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(
                            DataTypeToEnum<T>::v(),
                            TensorShape({batch_size, input_size + cell_size}),
                            &x_h_prev_tensor));

    Tensor x_h_prevr_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(
                            DataTypeToEnum<T>::v(),
                            TensorShape({batch_size, input_size + cell_size}),
                            &x_h_prevr_tensor));

    Tensor r_u_bar_tensor;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                      TensorShape({batch_size, 2 * cell_size}),
                                      &r_u_bar_tensor));

    const Device& device = ctx->eigen_device<Device>();

    functor::GRUBlockCellFprop<Device, T, USE_CUBLAS>(batch_size, input_size,
                                                      cell_size)(
        ctx, device, x_tensor->matrix<T>(), h_prev_tensor->matrix<T>(),
        w_ru_tensor->matrix<T>(), w_c_tensor->matrix<T>(),
        b_ru_tensor->vec<T>(), b_c_tensor->vec<T>(), r_u_bar_tensor.matrix<T>(),
        r_tensor->matrix<T>(), u_tensor->matrix<T>(), c_tensor->matrix<T>(),
        h_tensor->matrix<T>(), x_h_prev_tensor.matrix<T>(),
        x_h_prevr_tensor.matrix<T>());
  }
};

// Register the block GRU cell kernel for CPU.
#define REGISTER_KERNEL(T)                                            \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("GRUBlockCell").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      GRUCellBlockOp<CPUDevice, T, false>);

REGISTER_KERNEL(float);
#undef REGISTER_KERNEL

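// Gradient of a single GRU block cell step, consuming the r, u, and c
// activations saved by the forward op. A sketch of the computation performed
// by GRUBlockCellBprop (see gru_ops.h for the functor definition), with .*
// denoting elementwise products:
//
//   d_c_bar = d_h .* (1 - u) .* (1 - c^2)
//   d_u_bar = d_h .* (h_prev - c) .* u .* (1 - u)
//   [d_x_component_2, d_h_prevr] = d_c_bar * w_c^T
//   d_r_bar = (d_h_prevr .* h_prev) .* r .* (1 - r)
//   [d_x_component_1, d_h_prev_component_1] = [d_r_bar, d_u_bar] * w_ru^T
//   d_x = d_x_component_1 + d_x_component_2
//   d_h_prev = d_h_prev_component_1 + d_h_prevr .* r + d_h .* u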
template <typename Device, typename T, bool USE_CUBLAS>
class GRUBlockCellGradOp : public OpKernel {
 public:
  explicit GRUBlockCellGradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    // Grab the input tensors.
    const Tensor* x_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("x", &x_tensor));

    const Tensor* h_prev_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("h_prev", &h_prev_tensor));

    const Tensor* w_ru_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("w_ru", &w_ru_tensor));

    const Tensor* w_c_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("w_c", &w_c_tensor));

    const Tensor* b_ru_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("b_ru", &b_ru_tensor));

    const Tensor* b_c_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("b_c", &b_c_tensor));

    const Tensor* r_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("r", &r_tensor));

    const Tensor* u_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("u", &u_tensor));

    const Tensor* c_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("c", &c_tensor));

    const Tensor* d_h_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("d_h", &d_h_tensor));

    const int64 batch_size = x_tensor->dim_size(0);
    const int64 input_size = x_tensor->dim_size(1);
    const int64 cell_size = h_prev_tensor->dim_size(1);

    // Sanity checks for input shapes.

    // Shape of 'h_prev' must be [batch_size, cell_size].
    OP_REQUIRES(ctx, h_prev_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument("h_prev.dims(0) != batch_size: ",
                                        h_prev_tensor->dim_size(0), " vs. ",
                                        batch_size));
    OP_REQUIRES(ctx, h_prev_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "h_prev.dims(1) != cell_size: ", h_prev_tensor->dim_size(1),
                    " vs. ", cell_size));

    // Shape of 'w_ru' must be [input_size + cell_size, 2 * cell_size].
    OP_REQUIRES(ctx, w_ru_tensor->dim_size(0) == input_size + cell_size,
                errors::InvalidArgument(
                    "w_ru.dim_size(0) != input_size + cell_size: ",
                    w_ru_tensor->dim_size(0), " vs. ", input_size + cell_size));

    OP_REQUIRES(ctx, w_ru_tensor->dim_size(1) == cell_size * 2,
                errors::InvalidArgument("w_ru.dim_size(1) != cell_size * 2: ",
                                        w_ru_tensor->dim_size(1), " vs. ",
                                        cell_size * 2));

    // Shape of 'w_c' must be [input_size + cell_size, cell_size].
    OP_REQUIRES(ctx, w_c_tensor->dim_size(0) == input_size + cell_size,
                errors::InvalidArgument(
                    "w_c.dim_size(0) != input_size + cell_size: ",
                    w_c_tensor->dim_size(0), " vs. ", input_size + cell_size));

    OP_REQUIRES(ctx, w_c_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "w_c.dim_size(1) != cell_size: ", w_c_tensor->dim_size(1),
                    " vs. ", cell_size));

    // Shape of 'b_ru' must be [2 * cell_size].
    OP_REQUIRES(ctx, b_ru_tensor->dim_size(0) == cell_size * 2,
                errors::InvalidArgument("b_ru.dim_size(0) != cell_size * 2: ",
                                        b_ru_tensor->dim_size(0), " vs. ",
                                        cell_size * 2));

    OP_REQUIRES(ctx, b_ru_tensor->dims() == 1,
                errors::InvalidArgument("Rank of b_ru must be 1: ",
                                        b_ru_tensor->dims(), " vs. 1"));

    // Shape of 'b_c' must be [cell_size].
    OP_REQUIRES(ctx, b_c_tensor->dim_size(0) == cell_size,
                errors::InvalidArgument(
                    "b_c.dim_size(0) != cell_size: ", b_c_tensor->dim_size(0),
                    " vs. ", cell_size));

    OP_REQUIRES(ctx, b_c_tensor->dims() == 1,
                errors::InvalidArgument("Rank of b_c must be 1: ",
                                        b_c_tensor->dims(), " vs. 1"));

    // Shape of 'r' must be [batch_size, cell_size].
    OP_REQUIRES(ctx, r_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument(
                    "r.dims(0) != batch_size: ", r_tensor->dim_size(0), " vs. ",
                    batch_size));
    OP_REQUIRES(ctx, r_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "r.dims(1) != cell_size: ", r_tensor->dim_size(1), " vs. ",
                    cell_size));

    // Shape of 'u' must be [batch_size, cell_size].
    OP_REQUIRES(ctx, u_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument(
                    "u.dims(0) != batch_size: ", u_tensor->dim_size(0), " vs. ",
                    batch_size));
    OP_REQUIRES(ctx, u_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "u.dims(1) != cell_size: ", u_tensor->dim_size(1), " vs. ",
                    cell_size));

    // Shape of 'c' must be [batch_size, cell_size].
    OP_REQUIRES(ctx, c_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument(
                    "c.dims(0) != batch_size: ", c_tensor->dim_size(0), " vs. ",
                    batch_size));
    OP_REQUIRES(ctx, c_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "c.dims(1) != cell_size: ", c_tensor->dim_size(1), " vs. ",
                    cell_size));

    // Shape of 'd_h' must be [batch_size, cell_size].
    OP_REQUIRES(ctx, d_h_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument(
                    "d_h.dims(0) != batch_size: ", d_h_tensor->dim_size(0),
                    " vs. ", batch_size));
    OP_REQUIRES(ctx, d_h_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "d_h.dims(1) != cell_size: ", d_h_tensor->dim_size(1),
                    " vs. ", cell_size));

    // Create output tensors.
    Tensor* d_x_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                            {"x"}, "d_x", TensorShape({batch_size, input_size}),
                            &d_x_tensor));

    Tensor* d_h_prev_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->forward_input_or_allocate_output(
                 {"h_prev"}, "d_h_prev", TensorShape({batch_size, cell_size}),
                 &d_h_prev_tensor));

    Tensor* d_c_bar_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(
                            "d_c_bar", TensorShape({batch_size, cell_size}),
                            &d_c_bar_tensor));

    Tensor* d_r_bar_u_bar_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output("d_r_bar_u_bar",
                                  TensorShape({batch_size, 2 * cell_size}),
                                  &d_r_bar_u_bar_tensor));

    // Allocate temp tensors.
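    // Scratch for the per-gate gradients; d_r_bar and d_u_bar feed the
    // d_r_bar_u_bar output that drives the backward GEMM through w_ru.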
    Tensor d_r_bar_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           TensorShape({batch_size, cell_size}),
                                           &d_r_bar_tensor));

    Tensor d_u_bar_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           TensorShape({batch_size, cell_size}),
                                           &d_u_bar_tensor));

    Tensor d_h_prevr_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           TensorShape({batch_size, cell_size}),
                                           &d_h_prevr_tensor));

    Tensor d_x_component_1_h_prev_component_1;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(
                            DataTypeToEnum<T>::v(),
                            TensorShape({batch_size, input_size + cell_size}),
                            &d_x_component_1_h_prev_component_1));

    Tensor d_x_component_2_h_prevr;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(
                            DataTypeToEnum<T>::v(),
                            TensorShape({batch_size, input_size + cell_size}),
                            &d_x_component_2_h_prevr));

    const Device& device = ctx->eigen_device<Device>();

    functor::GRUBlockCellBprop<Device, T, USE_CUBLAS>(batch_size, input_size,
                                                      cell_size)(
        ctx, device, x_tensor->matrix<T>(), h_prev_tensor->matrix<T>(),
        w_ru_tensor->matrix<T>(), w_c_tensor->matrix<T>(),
        b_ru_tensor->vec<T>(), b_c_tensor->vec<T>(), r_tensor->matrix<T>(),
        u_tensor->matrix<T>(), c_tensor->matrix<T>(), d_h_tensor->matrix<T>(),
        d_x_tensor->matrix<T>(), d_h_prev_tensor->matrix<T>(),
        d_c_bar_tensor->matrix<T>(), d_r_bar_u_bar_tensor->matrix<T>(),
        d_r_bar_tensor.matrix<T>(), d_u_bar_tensor.matrix<T>(),
        d_h_prevr_tensor.matrix<T>(),
        d_x_component_1_h_prev_component_1.matrix<T>(),
        d_x_component_2_h_prevr.matrix<T>());
  }
};

// Register the gradient kernel for CPU.
#define REGISTER_KERNEL(T)                                                \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("GRUBlockCellGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      GRUBlockCellGradOp<CPUDevice, T, false>);

REGISTER_KERNEL(float);
#undef REGISTER_KERNEL

// GPU support.
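// The GPU kernels are instantiated with USE_CUBLAS = true, so the matrix
// products inside the functors are expected to go through cuBLAS rather than
// Eigen contractions.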
#if GOOGLE_CUDA
#define EIGEN_USE_GPU

// Forward declare the GPU Fprop functor.
namespace functor {
#define DECLARE_GPU_SPEC(T)                                                   \
  template <>                                                                 \
  void GRUBlockCellFprop<GPUDevice, T, true>::operator()(                     \
      OpKernelContext* ctx, const GPUDevice& d,                               \
      typename TTypes<T>::ConstMatrix x,                                      \
      typename TTypes<T>::ConstMatrix h_prev,                                 \
      typename TTypes<T>::ConstMatrix w_ru,                                   \
      typename TTypes<T>::ConstMatrix w_c, typename TTypes<T>::ConstVec b_ru, \
      typename TTypes<T>::ConstVec b_c, typename TTypes<T>::Matrix r_u_bar,   \
      typename TTypes<T>::Matrix r, typename TTypes<T>::Matrix u,             \
      typename TTypes<T>::Matrix c, typename TTypes<T>::Matrix h,             \
      typename TTypes<T>::Matrix x_h_prev,                                    \
      typename TTypes<T>::Matrix x_h_prevr);                                  \
  extern template struct GRUBlockCellFprop<GPUDevice, T, true>;

DECLARE_GPU_SPEC(float);
#undef DECLARE_GPU_SPEC
}  // end namespace functor

// Register the block GRU cell kernel for GPU.
#define REGISTER_GPU_KERNEL(T)                                        \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("GRUBlockCell").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      GRUCellBlockOp<GPUDevice, T, true>);

REGISTER_GPU_KERNEL(float);
#undef REGISTER_GPU_KERNEL

// Forward declare the GPU Bprop functor.
namespace functor {
#define DECLARE_GPU_SPEC(T)                                                    \
  template <>                                                                  \
  void GRUBlockCellBprop<GPUDevice, T, true>::operator()(                      \
      OpKernelContext* ctx, const GPUDevice& d,                                \
      typename TTypes<T>::ConstMatrix x, typename TTypes<T>::ConstMatrix h,    \
      typename TTypes<T>::ConstMatrix w_ru,                                    \
      typename TTypes<T>::ConstMatrix w_c, typename TTypes<T>::ConstVec b_ru,  \
      typename TTypes<T>::ConstVec b_c, typename TTypes<T>::ConstMatrix r,     \
      typename TTypes<T>::ConstMatrix u, typename TTypes<T>::ConstMatrix c,    \
      typename TTypes<T>::ConstMatrix d_h, typename TTypes<T>::Matrix d_x,     \
      typename TTypes<T>::Matrix d_h_prev, typename TTypes<T>::Matrix d_c_bar, \
      typename TTypes<T>::Matrix d_r_bar_u_bar,                                \
      typename TTypes<T>::Matrix d_r_bar, typename TTypes<T>::Matrix d_u_bar,  \
      typename TTypes<T>::Matrix d_h_prevr,                                    \
      typename TTypes<T>::Matrix d_x_comp1_h_prev_comp1,                       \
      typename TTypes<T>::Matrix d_x_comp2_and_h_prevr);                       \
  extern template struct GRUBlockCellBprop<GPUDevice, T, true>;

DECLARE_GPU_SPEC(float);
#undef DECLARE_GPU_SPEC
}  // end namespace functor

// Register the gradient kernel for GPU.
#define REGISTER_GPU_KERNEL(T)                                            \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("GRUBlockCellGrad").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      GRUBlockCellGradOp<GPUDevice, T, true>);

REGISTER_GPU_KERNEL(float);
#undef REGISTER_GPU_KERNEL
#endif  // GOOGLE_CUDA

}  // end namespace tensorflow