      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #define EIGEN_USE_THREADS
     16 
     17 #include <stddef.h>
     18 #include <atomic>
     19 #include <cmath>
     20 #include <functional>
     21 #include <limits>
     22 #include <string>
     23 #include <unordered_set>
     24 
     25 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     26 #include "tensorflow/core/framework/device_base.h"
     27 #include "tensorflow/core/framework/kernel_def_builder.h"
     28 #include "tensorflow/core/framework/op.h"
     29 #include "tensorflow/core/framework/op_def_builder.h"
     30 #include "tensorflow/core/framework/op_kernel.h"
     31 #include "tensorflow/core/framework/register_types.h"
     32 #include "tensorflow/core/framework/tensor.h"
     33 #include "tensorflow/core/framework/tensor_shape.h"
     34 #include "tensorflow/core/framework/tensor_types.h"
     35 #include "tensorflow/core/framework/types.h"
     36 #include "tensorflow/core/lib/core/errors.h"
     37 #include "tensorflow/core/lib/core/status.h"
     38 #include "tensorflow/core/lib/core/stringpiece.h"
     39 #include "tensorflow/core/lib/gtl/inlined_vector.h"
     40 #include "tensorflow/core/lib/hash/hash.h"
     41 #include "tensorflow/core/lib/strings/stringprintf.h"
     42 #include "tensorflow/core/platform/fingerprint.h"
     43 #include "tensorflow/core/platform/mutex.h"
     44 #include "tensorflow/core/platform/types.h"
     45 #include "tensorflow/core/util/env_var.h"
     46 
     47 #if GOOGLE_CUDA
     48 #include "tensorflow/core/platform/stream_executor.h"
     49 #include "tensorflow/core/util/stream_executor_util.h"
     50 #endif  // GOOGLE_CUDA
     51 
     52 /*
     53  * This module implements ops that fuse a multi-layer multi-step RNN/LSTM model
     54  * using the underlying Cudnn library.
     55  *
     56  * The Cudnn RNN library exposes an opaque parameter buffer whose layout and
     57  * format are unspecified, and which very likely cannot be reused across
     58  * different GPUs if saved. Users therefore need to first query the size of the
     59  * opaque parameter buffer and convert it to and from its canonical form, while
     60  * each actual training step is carried out with the opaque parameter buffer.
     61  *
     62  * Similar to many other ops, the forward op has two flavors: training and
     63  * inference. When training is specified, additional data is produced in
     64  * reserve_space for the backward pass, which incurs a performance penalty.
     65  *
     66  * In addition to the actual data and reserve_space, Cudnn also needs extra
     67  * memory as a temporary workspace. Memory passed to and from stream-executor
     68  * is managed through ScratchAllocator. In general, stream-executor is
     69  * responsible for creating memory of the proper size, while TensorFlow is
     70  * responsible for keeping that memory alive long enough and recycling it
     71  * afterwards.
     72  *
     73  */
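        // A rough sketch of the expected op-level workflow (all of these ops are
        // registered in this file; graph construction is normally done from a
        // higher-level wrapper whose exact API is outside the scope of this file):
        //
        //   1. CudnnRNNParamsSize        - query the size of the opaque parameter
        //                                  buffer for the given model.
        //   2. CudnnRNNCanonicalToParams - pack canonical weights and biases into
        //                                  the opaque parameter buffer.
        //   3. CudnnRNN                  - forward pass; with is_training=true it
        //                                  also produces reserve_space.
        //   4. CudnnRNNBackprop          - backward pass; consumes reserve_space.
        //   5. CudnnRNNParamsToCanonical - unpack the opaque buffer back into
        //                                  canonical form (e.g. for checkpointing).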
     74 namespace tensorflow {
     75 
     76 using CPUDevice = Eigen::ThreadPoolDevice;
     77 
     78 #if GOOGLE_CUDA
     79 
     80 using GPUDevice = Eigen::GpuDevice;
     81 
     82 template <typename Device, typename T, typename Index>
     83 class CudnnRNNParamsSizeOp;
     84 
     85 template <typename Device, typename T>
     86 class CudnnRNNParamsToCanonical;
     87 
     88 template <typename Device, typename T>
     89 class CudnnRNNCanonicalToParams;
     90 
     91 template <typename Device, typename T>
     92 class CudnnRNNForwardOp;
     93 
     94 template <typename Device, typename T>
     95 class CudnnRNNBackwardOp;
     96 
     97 enum class TFRNNInputMode {
     98   kRNNLinearInput = 0,
     99   kRNNSkipInput = 1,
    100   kAutoSelect = 9999999
    101 };
    102 
    103 namespace {
    104 using perftools::gputools::DeviceMemory;
    105 using perftools::gputools::DeviceMemoryBase;
    106 using perftools::gputools::ScratchAllocator;
    107 using perftools::gputools::dnn::RnnDirectionMode;
    108 using perftools::gputools::dnn::RnnInputMode;
    109 using perftools::gputools::dnn::RnnMode;
    110 using perftools::gputools::dnn::ToDataType;
    111 using perftools::gputools::port::StatusOr;
    112 
    113 Status ParseRNNMode(const string& str, RnnMode* rnn_mode) {
    114   if (str == "rnn_relu") {
    115     *rnn_mode = RnnMode::kRnnRelu;
    116     return Status::OK();
    117   } else if (str == "rnn_tanh") {
    118     *rnn_mode = RnnMode::kRnnTanh;
    119     return Status::OK();
    120   } else if (str == "lstm") {
    121     *rnn_mode = RnnMode::kRnnLstm;
    122     return Status::OK();
    123   } else if (str == "gru") {
    124     *rnn_mode = RnnMode::kRnnGru;
    125     return Status::OK();
    126   }
    127   return errors::InvalidArgument("Invalid RNN mode: ", str);
    128 }
    129 
    130 Status ParseTFRNNInputMode(const string& str, TFRNNInputMode* rnn_input_mode) {
    131   if (str == "linear_input") {
    132     *rnn_input_mode = TFRNNInputMode::kRNNLinearInput;
    133     return Status::OK();
    134   } else if (str == "skip_input") {
    135     *rnn_input_mode = TFRNNInputMode::kRNNSkipInput;
    136     return Status::OK();
    137   } else if (str == "auto_select") {
    138     *rnn_input_mode = TFRNNInputMode::kAutoSelect;
    139     return Status::OK();
    140   }
    141   return errors::InvalidArgument("Invalid RNN input mode: ", str);
    142 }
    143 
    144 Status ParseRNNDirectionMode(const string& str,
    145                              RnnDirectionMode* rnn_dir_mode) {
    146   if (str == "unidirectional") {
    147     *rnn_dir_mode = RnnDirectionMode::kRnnUnidirectional;
    148     return Status::OK();
    149   } else if (str == "bidirectional") {
    150     *rnn_dir_mode = RnnDirectionMode::kRnnBidirectional;
    151     return Status::OK();
    152   }
    153   return errors::InvalidArgument("Invalid RNN direction mode: ", str);
    154 }
    155 
    156 Status ToRNNInputMode(TFRNNInputMode tf_input_mode, int num_units,
    157                       int input_size, RnnInputMode* input_mode) {
    158   switch (tf_input_mode) {
    159     case TFRNNInputMode::kRNNLinearInput:
    160       *input_mode = RnnInputMode::kRnnLinearSkip;
    161       break;
    162     case TFRNNInputMode::kRNNSkipInput:
    163       *input_mode = RnnInputMode::kRnnSkipInput;
    164       break;
    165     case TFRNNInputMode::kAutoSelect:
    166       *input_mode = (input_size == num_units) ? RnnInputMode::kRnnSkipInput
    167                                               : RnnInputMode::kRnnLinearSkip;
    168       break;
    169     default:
    170       return errors::InvalidArgument("Invalid TF input mode: ",
    171                                      static_cast<int>(tf_input_mode));
    172   }
    173   return Status::OK();
    174 }
    175 
    176 // TODO(zhengxq): Merge those into stream_executor_util.h.
    177 template <typename T>
    178 const DeviceMemory<T> AsDeviceMemory(const Tensor* tensor) {
    179   return DeviceMemory<T>::MakeFromByteSize(
    180       const_cast<T*>(tensor->template flat<T>().data()),
    181       tensor->template flat<T>().size() * sizeof(T));
    182 }
    183 
    184 template <typename T>
    185 DeviceMemory<T> AsDeviceMemory(Tensor* tensor) {
    186   return DeviceMemory<T>::MakeFromByteSize(
    187       tensor->template flat<T>().data(),
    188       tensor->template flat<T>().size() * sizeof(T));
    189 }
    190 
    191 template <typename U, typename T>
    192 DeviceMemory<U> CastDeviceMemory(Tensor* tensor) {
    193   return DeviceMemory<U>::MakeFromByteSize(
    194       tensor->template flat<T>().data(),
    195       tensor->template flat<T>().size() * sizeof(T));
    196 }
    197 
    198 DeviceMemoryBase SliceDeviceMemory(const DeviceMemoryBase& device_memory,
    199                                    int64 offset, int64 size) {
    200   const void* base_ptr = device_memory.opaque();
    201   void* offset_ptr =
    202       const_cast<char*>(reinterpret_cast<const char*>(base_ptr) + offset);
    203   CHECK(offset + size <= device_memory.size())
    204       << "The slice is not within the region of DeviceMemory.";
    205   return DeviceMemoryBase(offset_ptr, size);
    206 }
    207 
    208 inline Status FromExecutorStatus(const perftools::gputools::port::Status& s) {
    209   return s.ok() ? Status::OK()
    210                 : Status(static_cast<tensorflow::error::Code>(
    211                              static_cast<int>(s.code())),
    212                          s.error_message());
    213 }
    214 
    215 template <typename T>
    216 inline Status FromExecutorStatus(
    217     const perftools::gputools::port::StatusOr<T>& s) {
    218   return FromExecutorStatus(s.status());
    219 }
    220 
    221 inline perftools::gputools::port::Status ToExecutorStatus(const Status& s) {
    222   return s.ok() ? perftools::gputools::port::Status::OK()
    223                 : perftools::gputools::port::Status(
    224                       static_cast<perftools::gputools::port::error::Code>(
    225                           static_cast<int>(s.code())),
    226                       s.error_message());
    227 }
    228 
    229 // A helper to allocate temporary scratch memory for Cudnn RNN models. It takes
    230 // ownership of the underlying memory. The expectation is that the memory stays
    231 // alive for the span of the Cudnn RNN call itself.
    232 class CudnnRNNWorkspaceAllocator : public ScratchAllocator {
    233  public:
    234   ~CudnnRNNWorkspaceAllocator() override {}
    235   explicit CudnnRNNWorkspaceAllocator(OpKernelContext* context)
    236       : context_(context) {}
    237   int64 GetMemoryLimitInBytes(perftools::gputools::Stream* stream) override {
    238     return std::numeric_limits<int64>::max();
    239   }
    240   StatusOr<DeviceMemory<uint8>> AllocateBytes(
    241       perftools::gputools::Stream* stream, int64 byte_size) override {
    242     Tensor temporary_memory;
    243     Status allocation_status(context_->allocate_temp(
    244         DT_UINT8, TensorShape({byte_size}), &temporary_memory));
    245     if (!allocation_status.ok()) {
    246       return ToExecutorStatus(allocation_status);
    247     }
    248     // Hold a reference to the allocated tensors until the allocator itself is
    249     // destroyed.
    250     allocated_tensors_.push_back(temporary_memory);
    251     total_byte_size_ += byte_size;
    252     return StatusOr<DeviceMemory<uint8>>(
    253         AsDeviceMemory<uint8>(&temporary_memory));
    254   }
    255   int64 TotalByteSize() { return total_byte_size_; }
    256 
    257  private:
    258   int64 total_byte_size_ = 0;
    259   OpKernelContext* context_;  // not owned
    260   std::vector<Tensor> allocated_tensors_;
    261 };
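        // How the kernels below use this allocator, as a sketch (see the Compute()
        // methods of CudnnRNNForwardOp and CudnnRNNBackwardOp for the real code):
        //
        //   CudnnRNNWorkspaceAllocator workspace_allocator(context);
        //   stream->ThenRnnForward(..., &reserve_space_allocator,
        //                          &workspace_allocator);
        //
        // The backing tensors stay alive until the allocator is destroyed at the
        // end of Compute(), which satisfies the lifetime expectation stated above.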
    262 
    263 // A helper to allocate reserve-space memory for Cudnn RNN models. The tensors
    264 // are allocated as a kernel output, and will be fed into the backward pass.
    265 // The memory is expected to remain alive until the backward pass has finished
    266 // using it.
    267 template <typename T>
    268 class CudnnRNNReserveSpaceAllocator : public ScratchAllocator {
    269  public:
    270   ~CudnnRNNReserveSpaceAllocator() override {}
    271   CudnnRNNReserveSpaceAllocator(OpKernelContext* context, int output_index)
    272       : context_(context), output_index_(output_index) {}
    273   int64 GetMemoryLimitInBytes(perftools::gputools::Stream* stream) override {
    274     return std::numeric_limits<int64>::max();
    275   }
    276   StatusOr<DeviceMemory<uint8>> AllocateBytes(
    277       perftools::gputools::Stream* stream, int64 byte_size) override {
    278     CHECK(total_byte_size_ == 0)
    279         << "Reserve space allocator can only be called once";
    280     int64 allocate_count =
    281         Eigen::divup(byte_size, static_cast<int64>(sizeof(T)));
    282 
    283     Tensor* temporary_memory = nullptr;
    284     Status allocation_status(context_->allocate_output(
    285         output_index_, TensorShape({allocate_count}), &temporary_memory));
    286     if (!allocation_status.ok()) {
    287       return ToExecutorStatus(allocation_status);
    288     }
    289     total_byte_size_ += byte_size;
    290     auto memory_uint8 = DeviceMemory<uint8>::MakeFromByteSize(
    291         temporary_memory->template flat<T>().data(),
    292         temporary_memory->template flat<T>().size() * sizeof(T));
    293     return StatusOr<DeviceMemory<uint8>>(memory_uint8);
    294   }
    295   int64 TotalByteSize() { return total_byte_size_; }
    296 
    297  private:
    298   int64 total_byte_size_ = 0;
    299   OpKernelContext* context_;  // not owned
    300   int output_index_;
    301 };
    302 
    303 // A helper to allocate persistent memory for Cudnn RNN models, which is
    304 // expected to live between kernel invocations.
    305 // This class is not thread-safe.
    306 class CudnnRNNPersistentSpaceAllocator : public ScratchAllocator {
    307  public:
    308   explicit CudnnRNNPersistentSpaceAllocator(OpKernelContext* context)
    309       : context_(context) {}
    310 
    311   ~CudnnRNNPersistentSpaceAllocator() override {}
    312 
    313   int64 GetMemoryLimitInBytes(perftools::gputools::Stream* stream) override {
    314     return std::numeric_limits<int64>::max();
    315   }
    316 
    317   StatusOr<DeviceMemory<uint8>> AllocateBytes(
    318       perftools::gputools::Stream* stream, int64 byte_size) override {
    319     if (total_byte_size_ != 0) {
    320       return Status(error::FAILED_PRECONDITION,
    321                     "Persistent space allocator can only be called once");
    322     }
    323 
    324     Status allocation_status = context_->allocate_persistent(
    325         DT_UINT8, TensorShape({byte_size}), &handle_, nullptr);
    326     if (!allocation_status.ok()) {
    327       return ToExecutorStatus(allocation_status);
    328     }
    329     total_byte_size_ += byte_size;
    330     return AsDeviceMemory<uint8>(handle_.AccessTensor(context_));
    331   }
    332   int64 TotalByteSize() { return total_byte_size_; }
    333 
    334  private:
    335   int64 total_byte_size_ = 0;
    336   PersistentTensor handle_;
    337   OpKernelContext* context_;  // not owned
    338 };
    339 
    340 struct CudnnModelTypes {
    341   RnnMode rnn_mode;
    342   TFRNNInputMode rnn_input_mode;
    343   RnnDirectionMode rnn_direction_mode;
    344   bool HasInputC() const {
    345     // For Cudnn 5.0, only LSTM has input-c. All other models use only input-h.
    346     return rnn_mode == RnnMode::kRnnLstm;
    347   }
    348 };
    349 
    350 // A helper class that collects the shapes to describe an RNN model.
    351 struct CudnnModelShapes {
    352   int num_layers;
    353   int input_size;
    354   int num_units;
    355   int seq_length;
    356   int batch_size;
    357   int dir_count;
    358   TensorShape input_shape;
    359   TensorShape output_shape;
    360   TensorShape hidden_state_shape;
    361   // At present, only fields relevant to the cached RnnDescriptor are compared.
    362   bool IsCompatibleWith(const CudnnModelShapes& rhs) const {
    363     return num_layers == rhs.num_layers && input_size == rhs.input_size &&
    364            num_units == rhs.num_units && dir_count == rhs.dir_count;
    365   }
    366   string RnnDescDebugString() {
    367     return strings::Printf(
    368         "[num_layers, input_size, num_units, dir_count]: [%d, %d, %d, %d]",
    369         num_layers, input_size, num_units, dir_count);
    370   }
    371 };
    372 
    373 // Utility class for using CudnnModelShapes as a hash table key.
    374 struct CudnnModelShapesHasher {
    375   uint64 operator()(const CudnnModelShapes& to_hash) const {
    376     uint64 hash = static_cast<uint64>(to_hash.num_layers);
    377     hash = tensorflow::FingerprintCat64(
    378         hash, static_cast<uint64>(to_hash.input_size));
    379     hash = tensorflow::FingerprintCat64(hash,
    380                                         static_cast<uint64>(to_hash.num_units));
    381     return tensorflow::FingerprintCat64(hash,
    382                                         static_cast<uint64>(to_hash.dir_count));
    383   }
    384 };
    385 
    386 // Utility class for using CudnnModelShapes as a hash table key.
    387 struct CudnnModelShapesComparator {
    388   bool operator()(const CudnnModelShapes& first,
    389                   const CudnnModelShapes& second) const {
    390     return first.IsCompatibleWith(second);
    391   }
    392 };
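        // Together, the hasher and the comparator above let CudnnModelShapes be used
        // as the key of an std::unordered_map. The forward and backward kernels below
        // use exactly this to cache one RnnDescriptor per set of model shapes:
        //
        //   std::unordered_map<CudnnModelShapes, RnnScratchSpace,
        //                      CudnnModelShapesHasher, CudnnModelShapesComparator>
        //       rnn_state_cache_ GUARDED_BY(mu_);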
    393 
    394 // Extracts and checks the forward input tensors, parameters, and shapes from
    395 // the OpKernelContext.
    396 Status ExtractForwardInput(OpKernelContext* context,
    397                            const CudnnModelTypes& model_types,
    398                            const Tensor** input, const Tensor** input_h,
    399                            const Tensor** input_c, const Tensor** params,
    400                            CudnnModelShapes* model_shapes) {
    401   TF_RETURN_IF_ERROR(context->input("input", input));
    402   TF_RETURN_IF_ERROR(context->input("input_h", input_h));
    403   if (model_types.HasInputC()) {
    404     TF_RETURN_IF_ERROR(context->input("input_c", input_c));
    405   }
    406   TF_RETURN_IF_ERROR(context->input("params", params));
    407 
    408   if ((*input)->dims() != 3) {
    409     return errors::InvalidArgument("RNN input must be a 3-D tensor.");
    410   }
    411   model_shapes->seq_length = (*input)->dim_size(0);
    412   model_shapes->batch_size = (*input)->dim_size(1);
    413   model_shapes->input_size = (*input)->dim_size(2);
    414   model_shapes->input_shape = (*input)->shape();
    415   model_shapes->dir_count =
    416       (model_types.rnn_direction_mode == RnnDirectionMode::kRnnBidirectional)
    417           ? 2
    418           : 1;
    419 
    420   if ((*input_h)->dims() != 3) {
    421     return errors::InvalidArgument("RNN input_h must be a 3-D tensor.");
    422   }
    423   model_shapes->num_layers = (*input_h)->dim_size(0) / model_shapes->dir_count;
    424   model_shapes->num_units = (*input_h)->dim_size(2);
    425 
    426   model_shapes->hidden_state_shape =
    427       TensorShape({model_shapes->dir_count * model_shapes->num_layers,
    428                    model_shapes->batch_size, model_shapes->num_units});
    429   if ((*input_h)->shape() != model_shapes->hidden_state_shape) {
    430     return errors::InvalidArgument(
    431         "Invalid input_h shape: ", (*input_h)->shape().DebugString(), " ",
    432         model_shapes->hidden_state_shape.DebugString());
    433   }
    434   if (model_types.HasInputC()) {
    435     if ((*input_h)->shape() != (*input_c)->shape()) {
    436       return errors::InvalidArgument(
    437           "input_h and input_c must have the same shape: ",
    438           (*input_h)->shape().DebugString(), " ",
    439           (*input_c)->shape().DebugString());
    440     }
    441   }
    442   model_shapes->output_shape =
    443       TensorShape({model_shapes->seq_length, model_shapes->batch_size,
    444                    model_shapes->dir_count * model_shapes->num_units});
    445   return Status::OK();
    446 }
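        // Illustrative shapes, derived from the checks above: a 2-layer bidirectional
        // model with seq_length=10, batch_size=32, input_size=128 and num_units=256
        // has dir_count=2, so
        //   input:            [10, 32, 128]
        //   input_h, input_c: [ 4, 32, 256]   (num_layers * dir_count = 4)
        //   output:           [10, 32, 512]   (dir_count * num_units = 512)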
    447 
    448 using perftools::gputools::dnn::RnnDescriptor;
    449 
    450 template <typename T>
    451 void RestoreParams(const OpInputList params_input,
    452                    const std::vector<RnnDescriptor::ParamsRegion>& params,
    453                    DeviceMemoryBase* data_dst,
    454                    perftools::gputools::Stream* stream) {
    455   int num_params = params.size();
    456   CHECK(params_input.size() == num_params)
    457       << "Number of params mismatch. Expected " << params_input.size()
    458       << ", got " << num_params;
    459   for (int i = 0; i < params.size(); i++) {
    460     int64 size_in_bytes = params[i].size;
    461     int64 size = size_in_bytes / sizeof(T);
    462     CHECK(size == params_input[i].NumElements())
    463         << "Params size mismatch. Expected " << size << ", got "
    464         << params_input[i].NumElements();
    465     auto data_src_ptr = StreamExecutorUtil::AsDeviceMemory<T>(params_input[i]);
    466     DeviceMemoryBase data_dst_ptr =
    467         SliceDeviceMemory(*data_dst, params[i].offset, size_in_bytes);
    468     stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes);
    469   }
    470 }
    471 
    472 }  // namespace
    473 
    474 // Note: all of the following kernels depend on an RnnDescriptor instance,
    475 // which according to the official Cudnn documentation should be kept around
    476 // and reused across all Cudnn kernels in the same model.
    477 // In TensorFlow, we don't pass the reference across different OpKernels;
    478 // rather, we recreate it separately in each OpKernel, which does not cause
    479 // issues: CudnnDropoutDescriptor keeps a reference to memory for the
    480 // random number generator state. During recreation, this state is lost.
    481 // However, only forward-pass Cudnn APIs make use of that state.
    482 
    483 // A common base class for RNN kernels. It extracts common attributes and
    484 // performs common shape validations.
    485 class CudnnRNNKernelCommon : public OpKernel {
    486  protected:
    487   explicit CudnnRNNKernelCommon(OpKernelConstruction* context)
    488       : OpKernel(context) {
    489     OP_REQUIRES_OK(context, context->GetAttr("dropout", &dropout_));
    490     OP_REQUIRES_OK(context, context->GetAttr("seed", &seed_));
    491     OP_REQUIRES_OK(context, context->GetAttr("seed2", &seed2_));
    492     string str;
    493     OP_REQUIRES_OK(context, context->GetAttr("rnn_mode", &str));
    494     OP_REQUIRES_OK(context, ParseRNNMode(str, &model_types_.rnn_mode));
    495     OP_REQUIRES_OK(context, context->GetAttr("input_mode", &str));
    496     OP_REQUIRES_OK(context,
    497                    ParseTFRNNInputMode(str, &model_types_.rnn_input_mode));
    498     OP_REQUIRES_OK(context, context->GetAttr("direction", &str));
    499     OP_REQUIRES_OK(
    500         context, ParseRNNDirectionMode(str, &model_types_.rnn_direction_mode));
    501     // Reset the CudnnRnnDescriptor and the related random number generator
    502     // state in every Compute() call.
    503     OP_REQUIRES_OK(context, ReadBoolFromEnvVar("TF_CUDNN_RESET_RND_GEN_STATE",
    504                                                false, &reset_rnd_gen_state_));
    505   }
    506 
    507   bool HasInputC() const { return model_types_.HasInputC(); }
    508   RnnMode rnn_mode() const { return model_types_.rnn_mode; }
    509   TFRNNInputMode rnn_input_mode() const { return model_types_.rnn_input_mode; }
    510   RnnDirectionMode rnn_direction_mode() const {
    511     return model_types_.rnn_direction_mode;
    512   }
    513   CudnnModelTypes model_types() const { return model_types_; }
    514   float dropout() const { return dropout_; }
    515   uint64 seed() { return (static_cast<uint64>(seed_) << 32) | seed2_; }
    516   bool ResetRndGenState() { return reset_rnd_gen_state_; }
    517 
    518   template <typename T>
    519   Status ExtractCudnnRNNParamsInfo(OpKernelContext* context,
    520                                    std::unique_ptr<RnnDescriptor>* rnn_desc) {
    521     const Tensor* num_layers_t = nullptr;
    522     TF_RETURN_IF_ERROR(context->input("num_layers", &num_layers_t));
    523     if (!TensorShapeUtils::IsScalar(num_layers_t->shape())) {
    524       return errors::InvalidArgument("num_layers is not a scalar");
    525     }
    526     int num_layers = num_layers_t->scalar<int>()();
    527     const Tensor* num_units_t = nullptr;
    528     TF_RETURN_IF_ERROR(context->input("num_units", &num_units_t));
    529     if (!TensorShapeUtils::IsScalar(num_units_t->shape())) {
    530       return errors::InvalidArgument("num_units is not a scalar");
    531     }
    532     int num_units = num_units_t->scalar<int>()();
    533     const Tensor* input_size_t = nullptr;
    534     TF_RETURN_IF_ERROR(context->input("input_size", &input_size_t));
    535     if (!TensorShapeUtils::IsScalar(input_size_t->shape())) {
    536       return errors::InvalidArgument("input_size is not a scalar");
    537     }
    538     int input_size = input_size_t->scalar<int>()();
    539 
    540     RnnInputMode input_mode;
    541     TF_RETURN_IF_ERROR(
    542         ToRNNInputMode(rnn_input_mode(), num_units, input_size, &input_mode));
    543 
    544     auto* stream = context->op_device_context()->stream();
    545     // ExtractCudnnRNNParamsInfo is only called by op_kernels that do not
    546     // require a random number generator, so set state_allocator to nullptr.
    547     auto rnn_desc_s = stream->parent()->createRnnDescriptor(
    548         num_layers, num_units, input_size, input_mode, rnn_direction_mode(),
    549         rnn_mode(), ToDataType<T>::value, dropout(), seed(),
    550         nullptr /* state_allocator */);
    551     if (!rnn_desc_s.ok()) {
    552       return FromExecutorStatus(rnn_desc_s);
    553     }
    554     *rnn_desc = rnn_desc_s.ConsumeValueOrDie();
    555     return Status::OK();
    556   }
    557 
    558  private:
    559   int seed_;
    560   int seed2_;
    561   float dropout_;
    562   bool reset_rnd_gen_state_;
    563 
    564   CudnnModelTypes model_types_;
    565 };
    566 
    567 // A class that returns the size of the opaque parameter buffer. The user should
    568 // use it to create the actual parameter buffer for training. However, the
    569 // opaque buffer should not be used for saving and restoring.
    570 template <typename T, typename Index>
    571 class CudnnRNNParamsSizeOp<GPUDevice, T, Index> : public CudnnRNNKernelCommon {
    572  public:
    573   typedef GPUDevice Device;
    574   explicit CudnnRNNParamsSizeOp(OpKernelConstruction* context)
    575       : CudnnRNNKernelCommon(context) {}
    576 
    577   void Compute(OpKernelContext* context) override {
    578     std::unique_ptr<RnnDescriptor> rnn_desc;
    579     OP_REQUIRES_OK(context, ExtractCudnnRNNParamsInfo<T>(context, &rnn_desc));
    580     int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes();
    581     CHECK(params_size_in_bytes % sizeof(T) == 0)
    582         << "params_size_in_bytes must be multiple of element size";
    583     int64 params_size = params_size_in_bytes / sizeof(T);
    584 
    585     Tensor* output_t = nullptr;
    586     OP_REQUIRES_OK(context, context->allocate_output(0, {1}, &output_t));
    587     *output_t->template flat<Index>().data() = params_size;
    588   }
    589 };
    590 
    591 #define REGISTER_GPU(T)                                    \
    592   REGISTER_KERNEL_BUILDER(Name("CudnnRNNParamsSize")       \
    593                               .Device(DEVICE_GPU)          \
    594                               .HostMemory("num_layers")    \
    595                               .HostMemory("num_units")     \
    596                               .HostMemory("input_size")    \
    597                               .HostMemory("params_size")   \
    598                               .TypeConstraint<T>("T")      \
    599                               .TypeConstraint<int32>("S"), \
    600                           CudnnRNNParamsSizeOp<GPUDevice, T, int32>);
    601 
    602 TF_CALL_half(REGISTER_GPU);
    603 TF_CALL_float(REGISTER_GPU);
    604 TF_CALL_double(REGISTER_GPU);
    605 #undef REGISTER_GPU
    606 
    607 // Convert weight and bias params from a platform-specific layout to the
    608 // canonical form.
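        // As a concrete illustration: for a single-layer, unidirectional LSTM (where
        // Cudnn reports 8 weight regions and 8 bias regions per layer), outputs 0-3
        // are [num_units, input_size] input-to-gate matrices, outputs 4-7 are
        // [num_units, num_units] recurrent matrices, and outputs 8-15 are [num_units]
        // bias vectors. This illustrates the loops below rather than exhaustively
        // describing the Cudnn layout.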
    609 template <typename T>
    610 class CudnnRNNParamsToCanonical<GPUDevice, T> : public CudnnRNNKernelCommon {
    611  public:
    612   typedef GPUDevice Device;
    613   explicit CudnnRNNParamsToCanonical(OpKernelConstruction* context)
    614       : CudnnRNNKernelCommon(context) {
    615     OP_REQUIRES_OK(context, context->GetAttr("num_params", &num_params_));
    616   }
    617 
    618   void Compute(OpKernelContext* context) override {
    619     const Tensor& input = context->input(3);
    620     auto input_ptr = StreamExecutorUtil::AsDeviceMemory<T>(input);
    621     auto* stream = context->op_device_context()->stream();
    622 
    623     std::unique_ptr<RnnDescriptor> rnn_desc;
    624     OP_REQUIRES_OK(context, ExtractCudnnRNNParamsInfo<T>(context, &rnn_desc));
    625     int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes();
    626     CHECK(params_size_in_bytes % sizeof(T) == 0)
    627         << "params_size_in_bytes must be multiple of element size";
    628 
    629     const Tensor* num_units_t = nullptr;
    630     OP_REQUIRES_OK(context, context->input("num_units", &num_units_t));
    631     CHECK(TensorShapeUtils::IsScalar(num_units_t->shape()))
    632         << "num_units is not a scalar";
    633     int num_units = num_units_t->scalar<int>()();
    634 
    635     const Tensor* input_size_t = nullptr;
    636     OP_REQUIRES_OK(context, context->input("input_size", &input_size_t));
    637     CHECK(TensorShapeUtils::IsScalar(input_size_t->shape()))
    638         << "input_size is not a scalar";
    639     int input_size = input_size_t->scalar<int>()();
    640 
    641     const Tensor* num_layers_t = nullptr;
    642     OP_REQUIRES_OK(context, context->input("num_layers", &num_layers_t));
    643     CHECK(TensorShapeUtils::IsScalar(num_layers_t->shape()))
    644         << "num_layers is not a scalar";
    645     int num_layers = num_layers_t->scalar<int>()();
    646     int num_dirs = 1;
    647     if (rnn_direction_mode() == RnnDirectionMode::kRnnBidirectional) {
    648       num_dirs = 2;
    649     }
    650     const int num_params_per_layer = num_params_ / num_layers / num_dirs;
    651     // Number of params applied to the inputs. The rest are applied to the
    652     // recurrent hidden states.
    653     const int num_params_input_state = num_params_per_layer / 2;
    654     CHECK(num_params_ % (num_layers * num_dirs) == 0)
    655         << "Number of params is not a multiple of num_layers * num_dirs.";
    656     CHECK(num_params_per_layer % 2 == 0)
    657         << "Number of params per layer is not an even number.";
    658 
    659     CHECK(num_params_ == rnn_desc->ParamsWeightRegions().size())
    660         << "Number of params mismatch. Expected " << num_params_ << ", got "
    661         << rnn_desc->ParamsWeightRegions().size();
    662     for (int i = 0; i < rnn_desc->ParamsWeightRegions().size(); i++) {
    663       int64 size_in_bytes = rnn_desc->ParamsWeightRegions()[i].size;
    664       int64 size = size_in_bytes / sizeof(T);
    665       const int layer_idx = i / num_params_per_layer;
    666       const int index_within_layer = i % num_params_per_layer;
    667       int width = 0, height = num_units;
    668       // In the CuDNN layout, each layer has num_params_per_layer params, with
    669       // the first half (num_params_input_state of them) applied to the inputs,
    670       // and the second half to the recurrent hidden states.
    671       bool apply_on_input_state = index_within_layer < num_params_input_state;
    672       if (rnn_direction_mode() == RnnDirectionMode::kRnnUnidirectional) {
    673         if (layer_idx == 0 && apply_on_input_state) {
    674           width = input_size;
    675         } else {
    676           width = num_units;
    677         }
    678       } else {
    679         if (apply_on_input_state) {
    680           if (layer_idx <= 1) {
    681           // First forward or backward layer.
    682             width = input_size;
    683           } else {
    684             // For subsequent layers, the cell input is the concatenation of
    685             // the outputs of the prior layer's two directions.
    686             width = 2 * num_units;
    687           }
    688         } else {
    689           width = num_units;
    690         }
    691       }
    692       CHECK(size == width * height) << "Params size mismatch. Expected "
    693                                     << width * height << ", got " << size;
    694       Tensor* output = nullptr;
    695       OP_REQUIRES_OK(context, context->allocate_output(
    696                                   i, TensorShape({height, width}), &output));
    697       DeviceMemoryBase data_src_ptr = SliceDeviceMemory(
    698           input_ptr, rnn_desc->ParamsWeightRegions()[i].offset, size_in_bytes);
    699       auto data_dst_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*output);
    700       stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes);
    701     }
    702 
    703     OP_REQUIRES(context, num_params_ == rnn_desc->ParamsBiasRegions().size(),
    704                 errors::InvalidArgument("Number of params mismatch. Expected ",
    705                                         num_params_, ", got ",
    706                                         rnn_desc->ParamsBiasRegions().size()));
    707     for (int i = 0; i < rnn_desc->ParamsBiasRegions().size(); i++) {
    708       int64 size_in_bytes = rnn_desc->ParamsBiasRegions()[i].size;
    709       int64 size = size_in_bytes / sizeof(T);
    710       OP_REQUIRES(context, size == num_units,
    711                   errors::InvalidArgument("Params size mismatch. Expected ",
    712                                           num_units, ", got ", size));
    713 
    714       Tensor* output = nullptr;
    715       OP_REQUIRES_OK(context,
    716                      context->allocate_output(num_params_ + i,
    717                                               TensorShape({size}), &output));
    718       DeviceMemoryBase data_src_ptr = SliceDeviceMemory(
    719           input_ptr, rnn_desc->ParamsBiasRegions()[i].offset, size_in_bytes);
    720       auto data_dst_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*output);
    721       stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes);
    722     }
    723   }
    724 
    725  private:
    726   int num_params_;
    727 };
    728 
    729 #define REGISTER_GPU(T)                                     \
    730   REGISTER_KERNEL_BUILDER(Name("CudnnRNNParamsToCanonical") \
    731                               .Device(DEVICE_GPU)           \
    732                               .HostMemory("num_layers")     \
    733                               .HostMemory("num_units")      \
    734                               .HostMemory("input_size")     \
    735                               .TypeConstraint<T>("T"),      \
    736                           CudnnRNNParamsToCanonical<GPUDevice, T>);
    737 TF_CALL_half(REGISTER_GPU);
    738 TF_CALL_float(REGISTER_GPU);
    739 TF_CALL_double(REGISTER_GPU);
    740 #undef REGISTER_GPU
    741 
    742 // Convert weight and bias params from the canonical form to a
    743 // platform-specific layout.
    744 template <typename T>
    745 class CudnnRNNCanonicalToParams<GPUDevice, T> : public CudnnRNNKernelCommon {
    746  public:
    747   typedef GPUDevice Device;
    748   explicit CudnnRNNCanonicalToParams(OpKernelConstruction* context)
    749       : CudnnRNNKernelCommon(context) {}
    750 
    751   void Compute(OpKernelContext* context) override {
    752     std::unique_ptr<RnnDescriptor> rnn_desc;
    753     OP_REQUIRES_OK(context, ExtractCudnnRNNParamsInfo<T>(context, &rnn_desc));
    754     int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes();
    755     CHECK(params_size_in_bytes % sizeof(T) == 0)
    756         << "params_size_in_bytes must be multiple of element size";
    757     Tensor* output = nullptr;
    758     int params_size = params_size_in_bytes / sizeof(T);
    759     OP_REQUIRES_OK(context,
    760                    context->allocate_output(0, {params_size}, &output));
    761     auto output_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*output);
    762     auto* stream = context->op_device_context()->stream();
    763 
    764     OpInputList weights;
    765     OP_REQUIRES_OK(context, context->input_list("weights", &weights));
    766     RestoreParams<T>(weights, rnn_desc->ParamsWeightRegions(), &output_ptr,
    767                      stream);
    768 
    769     OpInputList biases;
    770     OP_REQUIRES_OK(context, context->input_list("biases", &biases));
    771     RestoreParams<T>(biases, rnn_desc->ParamsBiasRegions(), &output_ptr,
    772                      stream);
    773   }
    774 };
    775 
    776 #define REGISTER_GPU(T)                                     \
    777   REGISTER_KERNEL_BUILDER(Name("CudnnRNNCanonicalToParams") \
    778                               .Device(DEVICE_GPU)           \
    779                               .HostMemory("num_layers")     \
    780                               .HostMemory("num_units")      \
    781                               .HostMemory("input_size")     \
    782                               .TypeConstraint<T>("T"),      \
    783                           CudnnRNNCanonicalToParams<GPUDevice, T>);
    784 TF_CALL_half(REGISTER_GPU);
    785 TF_CALL_float(REGISTER_GPU);
    786 TF_CALL_double(REGISTER_GPU);
    787 #undef REGISTER_GPU
    788 
    789 // Pointers to RNN scratch space for a specific set of shape parameters (used as
    790 // a hash table value in CudnnRNNForwardOp and CudnnRNNBackwardOp).
    791 struct RnnScratchSpace {
    792   std::unique_ptr<RnnDescriptor> rnn_desc;
    793   std::unique_ptr<CudnnRNNPersistentSpaceAllocator> dropout_state_allocator;
    794 };
    795 
    796 // Run the forward operation of the RNN model.
    797 template <typename T>
    798 class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
    799  public:
    800   typedef GPUDevice Device;
    801   explicit CudnnRNNForwardOp(OpKernelConstruction* context)
    802       : CudnnRNNKernelCommon(context) {
    803     OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
    804   }
    805 
    806   void Compute(OpKernelContext* context) override {
    807     const Tensor* input = nullptr;
    808     const Tensor* input_h = nullptr;
    809     const Tensor* input_c = nullptr;
    810     const Tensor* params = nullptr;
    811     CudnnModelShapes model_shapes;
    812     OP_REQUIRES_OK(context,
    813                    ExtractForwardInput(context, model_types(), &input, &input_h,
    814                                        &input_c, &params, &model_shapes));
    815     const auto& input_shape = model_shapes.input_shape;
    816     const auto& hidden_state_shape = model_shapes.hidden_state_shape;
    817     const auto& output_shape = model_shapes.output_shape;
    818 
    819     Tensor* output = nullptr;
    820     OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
    821     Tensor* output_h = nullptr;
    822     OP_REQUIRES_OK(context,
    823                    context->allocate_output(1, hidden_state_shape, &output_h));
    824     Tensor* output_c = nullptr;
    825     if (HasInputC()) {
    826       // Only LSTM uses input_c and output_c. So for all other models, we only
    827       // need to create dummy outputs.
    828       OP_REQUIRES_OK(
    829           context, context->allocate_output(2, hidden_state_shape, &output_c));
    830     } else {
    831       OP_REQUIRES_OK(context, context->allocate_output(2, {}, &output_c));
    832     }
    833 
    834     auto* stream = context->op_device_context()->stream();
    835     auto* executor = stream->parent();
    836     RnnInputMode input_mode;
    837     OP_REQUIRES_OK(context,
    838                    ToRNNInputMode(rnn_input_mode(), model_shapes.num_units,
    839                                   model_shapes.input_size, &input_mode));
    840     auto data_type = ToDataType<T>::value;
    841 
    842     auto input_desc_s = executor->createRnnSequenceTensorDescriptor(
    843         input_shape.dim_size(0), input_shape.dim_size(1),
    844         input_shape.dim_size(2), data_type);
    845     OP_REQUIRES_OK(context, FromExecutorStatus(input_desc_s));
    846     auto input_desc = input_desc_s.ConsumeValueOrDie();
    847 
    848     auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor(
    849         hidden_state_shape.dim_size(0), hidden_state_shape.dim_size(1),
    850         hidden_state_shape.dim_size(2), data_type);
    851     OP_REQUIRES_OK(context, FromExecutorStatus(hidden_state_desc_s));
    852     auto hidden_state_desc = hidden_state_desc_s.ConsumeValueOrDie();
    853 
    854     auto output_desc_s = executor->createRnnSequenceTensorDescriptor(
    855         output_shape.dim_size(0), output_shape.dim_size(1),
    856         output_shape.dim_size(2), data_type);
    857     OP_REQUIRES_OK(context, FromExecutorStatus(output_desc_s));
    858     auto output_desc = output_desc_s.ConsumeValueOrDie();
    859 
    860     auto input_data = AsDeviceMemory<T>(input);
    861     auto input_h_data = AsDeviceMemory<T>(input_h);
    862     DeviceMemory<T> input_c_data;
    863     if (HasInputC()) {
    864       input_c_data = AsDeviceMemory<T>(input_c);
    865     }
    866     auto params_data = AsDeviceMemory<T>(params);
    867     auto output_data = AsDeviceMemory<T>(output);
    868     auto output_h_data = AsDeviceMemory<T>(output_h);
    869     DeviceMemory<T> output_c_data;
    870     if (HasInputC()) {
    871       output_c_data = AsDeviceMemory<T>(output_c);
    872     }
    873 
    874     // Creates a memory callback for the reserve_space. The memory lives in the
    875     // output of this kernel, and will be fed into the backward pass when
    876     // needed.
    877     CudnnRNNReserveSpaceAllocator<T> reserve_space_allocator(context, 3);
    878     if (!is_training_) {
    879       Tensor* dummy_reserve_space = nullptr;
    880       OP_REQUIRES_OK(context,
    881                      context->allocate_output(3, {}, &dummy_reserve_space));
    882     }
    883     // Creates a memory callback for the workspace. The memory lives only until
    884     // the end of this kernel call.
    885     CudnnRNNWorkspaceAllocator workspace_allocator(context);
    886     bool launch_status = false;
    887     {
    888       mutex_lock l(mu_);
    889       RnnScratchSpace& rnn_state = rnn_state_cache_[model_shapes];
    890       if (rnn_state.rnn_desc == nullptr || ResetRndGenState()) {
    891         CudnnRNNPersistentSpaceAllocator* dropout_state_allocator =
    892             new CudnnRNNPersistentSpaceAllocator(context);
    893         rnn_state.dropout_state_allocator.reset(dropout_state_allocator);
    894         auto rnn_desc_s = executor->createRnnDescriptor(
    895             model_shapes.num_layers, model_shapes.num_units,
    896             model_shapes.input_size, input_mode, rnn_direction_mode(),
    897             rnn_mode(), data_type, dropout(), seed(), dropout_state_allocator);
    898         OP_REQUIRES_OK(context, FromExecutorStatus(rnn_desc_s));
    899         rnn_state.rnn_desc = std::move(rnn_desc_s.ConsumeValueOrDie());
    900       }
    901       launch_status =
    902           stream
    903               ->ThenRnnForward(*rnn_state.rnn_desc, *input_desc, input_data,
    904                                *hidden_state_desc, input_h_data,
    905                                *hidden_state_desc, input_c_data, params_data,
    906                                *output_desc, &output_data, *hidden_state_desc,
    907                                &output_h_data, *hidden_state_desc,
    908                                &output_c_data, is_training_,
    909                                &reserve_space_allocator, &workspace_allocator)
    910               .ok();
    911     }
    912     OP_REQUIRES(context, launch_status,
    913                 errors::Internal("Failed to call ThenRnnForward"));
    914   }
    915 
    916  private:
    917   mutex mu_;
    918   bool is_training_;
    919   std::unordered_map<CudnnModelShapes, RnnScratchSpace, CudnnModelShapesHasher,
    920                      CudnnModelShapesComparator>
    921       rnn_state_cache_ GUARDED_BY(mu_);
    922 };
    923 
    924 #define REGISTER_GPU(T)                                           \
    925   REGISTER_KERNEL_BUILDER(                                        \
    926       Name("CudnnRNN").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
    927       CudnnRNNForwardOp<GPUDevice, T>);
    928 
    929 TF_CALL_half(REGISTER_GPU);
    930 TF_CALL_float(REGISTER_GPU);
    931 TF_CALL_double(REGISTER_GPU);
    932 #undef REGISTER_GPU
    933 
    934 // Run the backward operation of the RNN model.
    935 template <typename T>
    936 class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
    937  public:
    938   typedef GPUDevice Device;
    939 
    940   explicit CudnnRNNBackwardOp(OpKernelConstruction* context)
    941       : CudnnRNNKernelCommon(context) {}
    942 
    943   void Compute(OpKernelContext* context) override {
    944     const Tensor* input = nullptr;
    945     const Tensor* input_h = nullptr;
    946     const Tensor* input_c = nullptr;
    947     const Tensor* params = nullptr;
    948     CudnnModelShapes model_shapes;
    949     OP_REQUIRES_OK(context,
    950                    ExtractForwardInput(context, model_types(), &input, &input_h,
    951                                        &input_c, &params, &model_shapes));
    952 
    953     const auto& input_shape = model_shapes.input_shape;
    954     const auto& hidden_state_shape = model_shapes.hidden_state_shape;
    955     const auto& output_shape = model_shapes.output_shape;
    956 
    957     auto data_type = ToDataType<T>::value;
    958     const Tensor* output = nullptr;
    959     OP_REQUIRES_OK(context, context->input("output", &output));
    960     OP_REQUIRES(context, output_shape == output->shape(),
    961                 errors::InvalidArgument(
    962                     "Invalid output shape: ",
    963                     output->shape().DebugString(), " ",
    964                     output_shape.DebugString()));
    965     const Tensor* output_h = nullptr;
    966     OP_REQUIRES_OK(context, context->input("output_h", &output_h));
    967     OP_REQUIRES(context, output_h->shape() == hidden_state_shape,
    968                 errors::InvalidArgument(
    969                     "Invalid output_h shape: ", output_h->shape().DebugString(),
    970                     " ", hidden_state_shape.DebugString()));
    971     const Tensor* output_c = nullptr;
    972     if (HasInputC()) {
    973       // Only LSTM uses input_c and output_c. For all other models, output_c is
    974       // a dummy tensor and is not read here.
    975       OP_REQUIRES_OK(context, context->input("output_c", &output_c));
    976       OP_REQUIRES(context, output_c->shape() == hidden_state_shape,
    977                   errors::InvalidArgument("Invalid output_c shape: ",
    978                                           output_c->shape().DebugString(), " ",
    979                                           hidden_state_shape.DebugString()));
    980     }
    981 
    982     const Tensor* output_backprop = nullptr;
    983     OP_REQUIRES_OK(context,
    984                    context->input("output_backprop", &output_backprop));
    985     OP_REQUIRES(context, output_backprop->shape() == output_shape,
    986                 errors::InvalidArgument("Invalid output_backprop shapes: ",
    987                                         output_backprop->shape().DebugString(),
    988                                         " ", output_shape.DebugString()));
    989 
    990     const Tensor* output_h_backprop = nullptr;
    991     OP_REQUIRES_OK(context,
    992                    context->input("output_h_backprop", &output_h_backprop));
    993     OP_REQUIRES(
    994         context, output_h_backprop->shape() == hidden_state_shape,
    995         errors::InvalidArgument("Invalid output_h_backprop shapes: ",
    996                                 output_h_backprop->shape().DebugString(), " ",
    997                                 hidden_state_shape.DebugString()));
    998     const Tensor* output_c_backprop = nullptr;
    999     if (HasInputC()) {
   1000       OP_REQUIRES_OK(context,
   1001                      context->input("output_c_backprop", &output_c_backprop));
   1002       OP_REQUIRES(
   1003           context, output_c_backprop->shape() == hidden_state_shape,
   1004           errors::InvalidArgument("Invalid output_c_backprop shapes: ",
   1005                                   output_c_backprop->shape().DebugString(), " ",
   1006                                   hidden_state_shape.DebugString()));
   1007     }
   1008     const Tensor* reserve_space_const = nullptr;
   1009     // This is the same "reserve_space" created by the forward op.
   1010     // It can also be modified by this backward operation.
   1011     OP_REQUIRES_OK(context,
   1012                    context->input("reserve_space", &reserve_space_const));
   1013     // Cudnn needs the reserve space to be writeable. This is fine because its
   1014     // contents are opaque.
   1015     Tensor* reserve_space = const_cast<Tensor*>(reserve_space_const);
   1016 
   1017     Tensor* input_backprop = nullptr;
   1018     OP_REQUIRES_OK(
   1019         context, context->allocate_output(0, input->shape(), &input_backprop));
   1020     Tensor* input_h_backprop = nullptr;
   1021     OP_REQUIRES_OK(context, context->allocate_output(1, input_h->shape(),
   1022                                                      &input_h_backprop));
   1023     Tensor* input_c_backprop = nullptr;
   1024     if (HasInputC()) {
   1025       OP_REQUIRES_OK(context, context->allocate_output(2, input_c->shape(),
   1026                                                        &input_c_backprop));
   1027     } else {
   1028       OP_REQUIRES_OK(context,
   1029                      context->allocate_output(2, {}, &input_c_backprop));
   1030     }
   1031     Tensor* params_backprop = nullptr;
   1032     OP_REQUIRES_OK(context, context->allocate_output(3, params->shape(),
   1033                                                      &params_backprop));
   1034 
   1035     auto* stream = context->op_device_context()->stream();
   1036     auto* executor = stream->parent();
   1037     RnnInputMode input_mode;
   1038     OP_REQUIRES_OK(context,
   1039                    ToRNNInputMode(rnn_input_mode(), model_shapes.num_units,
   1040                                   model_shapes.input_size, &input_mode));
   1041 
   1042     auto input_desc_s = executor->createRnnSequenceTensorDescriptor(
   1043         input_shape.dim_size(0), input_shape.dim_size(1),
   1044         input_shape.dim_size(2), data_type);
   1045     OP_REQUIRES_OK(context, FromExecutorStatus(input_desc_s));
   1046     auto input_desc = input_desc_s.ConsumeValueOrDie();
   1047 
   1048     auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor(
   1049         hidden_state_shape.dim_size(0), hidden_state_shape.dim_size(1),
   1050         hidden_state_shape.dim_size(2), data_type);
   1051     OP_REQUIRES_OK(context, FromExecutorStatus(hidden_state_desc_s));
   1052     auto hidden_state_desc = hidden_state_desc_s.ConsumeValueOrDie();
   1053 
   1054     auto output_desc_s = executor->createRnnSequenceTensorDescriptor(
   1055         output_shape.dim_size(0), output_shape.dim_size(1),
   1056         output_shape.dim_size(2), data_type);
   1057     OP_REQUIRES_OK(context, FromExecutorStatus(output_desc_s));
   1058     auto output_desc = output_desc_s.ConsumeValueOrDie();
   1059 
   1060     auto input_data = AsDeviceMemory<T>(input);
   1061     auto input_h_data = AsDeviceMemory<T>(input_h);
   1062     DeviceMemory<T> input_c_data;
   1063     if (HasInputC()) {
   1064       input_c_data = AsDeviceMemory<T>(input_c);
   1065     }
   1066     auto params_data = AsDeviceMemory<T>(params);
   1067     auto output_data = AsDeviceMemory<T>(output);
   1068     auto output_h_data = AsDeviceMemory<T>(output_h);
   1069     DeviceMemory<T> output_c_data;
   1070     if (HasInputC()) {
   1071       output_c_data = AsDeviceMemory<T>(output_c);
   1072     }
   1073     auto output_backprop_data = AsDeviceMemory<T>(output_backprop);
   1074     auto output_h_backprop_data = AsDeviceMemory<T>(output_h_backprop);
   1075     DeviceMemory<T> output_c_backprop_data;
   1076     if (HasInputC()) {
   1077       output_c_backprop_data = AsDeviceMemory<T>(output_c_backprop);
   1078     }
   1079     auto input_backprop_data = AsDeviceMemory<T>(input_backprop);
   1080     auto input_h_backprop_data = AsDeviceMemory<T>(input_h_backprop);
   1081     DeviceMemory<T> input_c_backprop_data;
   1082     if (HasInputC()) {
   1083       input_c_backprop_data = AsDeviceMemory<T>(input_c_backprop);
   1084     }
   1085     auto params_backprop_data = AsDeviceMemory<T>(params_backprop);
   1086     auto reserve_space_uint8 = CastDeviceMemory<uint8, T>(reserve_space);
   1087     // Creates a memory callback for the workspace. The memory lives only until
   1088     // the end of this kernel call.
   1089     CudnnRNNWorkspaceAllocator workspace_allocator(context);
   1090     bool launch_status = false;
   1091     {
   1092       mutex_lock l(mu_);
   1093       RnnScratchSpace& rnn_state = rnn_state_cache_[model_shapes];
   1094       if (rnn_state.rnn_desc == nullptr || ResetRndGenState()) {
   1095         CudnnRNNPersistentSpaceAllocator* dropout_state_allocator =
   1096             new CudnnRNNPersistentSpaceAllocator(context);
   1097         rnn_state.dropout_state_allocator.reset(dropout_state_allocator);
   1098         auto rnn_desc_s = executor->createRnnDescriptor(
   1099             model_shapes.num_layers, model_shapes.num_units,
   1100             model_shapes.input_size, input_mode, rnn_direction_mode(),
   1101             rnn_mode(), data_type, dropout(), seed(), dropout_state_allocator);
   1102         OP_REQUIRES_OK(context, FromExecutorStatus(rnn_desc_s));
   1103         rnn_state.rnn_desc = std::move(rnn_desc_s.ConsumeValueOrDie());
   1104       }
   1105       launch_status =
   1106           stream
   1107               ->ThenRnnBackward(*rnn_state.rnn_desc, *input_desc, input_data,
   1108                                 *hidden_state_desc, input_h_data,
   1109                                 *hidden_state_desc, input_c_data, params_data,
   1110                                 *output_desc, output_data, *hidden_state_desc,
   1111                                 output_h_data, *hidden_state_desc,
   1112                                 output_c_data, output_backprop_data,
   1113                                 output_h_backprop_data, output_c_backprop_data,
   1114                                 &input_backprop_data, &input_h_backprop_data,
   1115                                 &input_c_backprop_data, &params_backprop_data,
   1116                                 &reserve_space_uint8, &workspace_allocator)
   1117               .ok();
   1118     }
   1119     OP_REQUIRES(context, launch_status,
   1120                 errors::Internal("Failed to call ThenRnnBackward"));
   1121   }
   1122 
   1123  private:
   1124   mutex mu_;
   1125   std::unordered_map<CudnnModelShapes, RnnScratchSpace, CudnnModelShapesHasher,
   1126                      CudnnModelShapesComparator>
   1127       rnn_state_cache_ GUARDED_BY(mu_);
   1128 };
   1129 
   1130 #define REGISTER_GPU(T)                                                   \
   1131   REGISTER_KERNEL_BUILDER(                                                \
   1132       Name("CudnnRNNBackprop").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
   1133       CudnnRNNBackwardOp<GPUDevice, T>);
   1134 
   1135 TF_CALL_half(REGISTER_GPU);
   1136 TF_CALL_float(REGISTER_GPU);
   1137 TF_CALL_double(REGISTER_GPU);
   1138 #undef REGISTER_GPU
   1139 
   1140 // TODO(zhengxq): Add the conversion of Cudnn RNN Params from and to
   1141 // its canonical form.
   1142 
   1143 #endif  // GOOGLE_CUDA
   1144 
   1145 }  // namespace tensorflow
   1146