      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #define EIGEN_USE_THREADS
     16 
     17 #include <stddef.h>
     18 #include <atomic>
     19 #include <cmath>
     20 #include <functional>
     21 #include <limits>
     22 #include <string>
     23 #include <unordered_set>
     24 
     25 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     26 #include "tensorflow/core/framework/device_base.h"
     27 #include "tensorflow/core/framework/kernel_def_builder.h"
     28 #include "tensorflow/core/framework/op.h"
     29 #include "tensorflow/core/framework/op_def_builder.h"
     30 #include "tensorflow/core/framework/op_kernel.h"
     31 #include "tensorflow/core/framework/register_types.h"
     32 #include "tensorflow/core/framework/tensor.h"
     33 #include "tensorflow/core/framework/tensor_shape.h"
     34 #include "tensorflow/core/framework/tensor_types.h"
     35 #include "tensorflow/core/framework/types.h"
     36 #include "tensorflow/core/kernels/gpu_utils.h"
     37 #include "tensorflow/core/lib/core/errors.h"
     38 #include "tensorflow/core/lib/core/status.h"
     39 #include "tensorflow/core/lib/core/stringpiece.h"
     40 #include "tensorflow/core/lib/gtl/inlined_vector.h"
     41 #include "tensorflow/core/lib/hash/hash.h"
     42 #include "tensorflow/core/lib/strings/stringprintf.h"
     43 #include "tensorflow/core/platform/fingerprint.h"
     44 #include "tensorflow/core/platform/mutex.h"
     45 #include "tensorflow/core/platform/types.h"
     46 #include "tensorflow/core/util/env_var.h"
     47 #include "tensorflow/core/util/use_cudnn.h"
     48 
     49 #if GOOGLE_CUDA
     50 #include "tensorflow/core/platform/stream_executor.h"
     51 #include "tensorflow/core/util/stream_executor_util.h"
     52 #endif  // GOOGLE_CUDA
     53 
      54 /*
      55  * This module implements ops that fuse a multi-layer, multi-step RNN/LSTM
      56  * model using the underlying Cudnn library.
      57  *
      58  * The Cudnn RNN library exposes an opaque parameter buffer with an unknown
      59  * layout and format, and a saved buffer most likely cannot be reused across
      60  * different GPUs. Users therefore need to first query the size of the opaque
      61  * parameter buffer and convert it to and from its canonical form, while each
      62  * actual training step is carried out with the opaque parameter buffer.
      63  *
      64  * Similar to many other ops, the forward op has two flavors: training and
      65  * inference. When training is specified, additional data is produced in
      66  * reserve_space for the backward pass, which incurs a performance penalty.
      67  *
      68  * In addition to the actual data and reserve_space, Cudnn also needs extra
      69  * memory as a temporary workspace. Memory passed to and from stream-executor
      70  * is managed through ScratchAllocator. In general, stream-executor is
      71  * responsible for allocating memory of the proper size, and TensorFlow is
      72  * responsible for keeping the memory alive long enough and recycling it
      73  * afterwards.
      74  *
      75  */
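/*
 * Illustrative usage sketch (assumes the CudnnRNNParamsSize,
 * CudnnRNNCanonicalToParams, and CudnnRNNParamsToCanonical ops registered
 * below; the training-step op itself is elided):
 *
 *   params_size     = CudnnRNNParamsSize(num_layers, num_units, input_size)
 *   params          = CudnnRNNCanonicalToParams(weights, biases)  // opaque
 *   ... run forward/backward steps with the opaque `params` buffer ...
 *   weights, biases = CudnnRNNParamsToCanonical(params)  // e.g. for saving
 */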
     76 namespace tensorflow {
     77 
     78 using CPUDevice = Eigen::ThreadPoolDevice;
     79 
     80 #if GOOGLE_CUDA
     81 
     82 using GPUDevice = Eigen::GpuDevice;
     83 using se::Stream;
     84 using se::StreamExecutor;
     85 using se::dnn::RnnDescriptor;
     86 
     87 template <typename Device, typename T, typename Index>
     88 class CudnnRNNParamsSizeOp;
     89 
     90 template <typename Device, typename T>
     91 class CudnnRNNParamsToCanonical;
     92 
     93 template <typename Device, typename T>
     94 class CudnnRNNCanonicalToParams;
     95 
     96 template <typename Device, typename T>
     97 class CudnnRNNForwardOp;
     98 
     99 template <typename Device, typename T>
    100 class CudnnRNNBackwardOp;
    101 
    102 template <typename Device, typename T>
    103 class CudnnRNNForwardOpV2;
    104 
    105 template <typename Device, typename T>
    106 class CudnnRNNBackwardOpV2;
    107 
    108 template <typename Device, typename T>
    109 class CudnnRNNForwardOpV3;
    110 
    111 template <typename Device, typename T>
    112 class CudnnRNNBackwardOpV3;
    113 
    114 enum class TFRNNInputMode {
    115   kRNNLinearInput = 0,
    116   kRNNSkipInput = 1,
    117   kAutoSelect = 9999999
    118 };
    119 
    120 namespace {
    121 using se::DeviceMemory;
    122 using se::DeviceMemoryBase;
    123 using se::ScratchAllocator;
    124 using se::dnn::AlgorithmConfig;
    125 using se::dnn::AlgorithmDesc;
    126 using se::dnn::ProfileResult;
    127 using se::dnn::RnnDirectionMode;
    128 using se::dnn::RnnInputMode;
    129 using se::dnn::RnnMode;
    130 using se::dnn::RnnSequenceTensorDescriptor;
    131 using se::dnn::RnnStateTensorDescriptor;
    132 using se::dnn::ToDataType;
    133 using se::port::StatusOr;
    134 
    135 uint64 HashList(const std::vector<int>& list) {
    136   if (list.empty()) {
    137     return 0;
    138   }
    139   uint64 hash_code = list[0];
    140   for (int i = 1; i < list.size(); i++) {
    141     hash_code = Hash64Combine(hash_code, list[i]);
    142   }
    143   return hash_code;
    144 }
    145 
     146 // Encapsulates all the shape information that is used in both the forward
     147 // and backward RNN operations.
    148 class CudnnRnnParameters {
    149  public:
    150   CudnnRnnParameters(int num_layers, int input_size, int num_units,
    151                      int max_seq_length, int batch_size, int dir_count,
    152                      bool has_dropout, bool is_training, RnnMode rnn_mode,
    153                      TFRNNInputMode rnn_input_mode, DataType dtype)
    154       : num_layers_(num_layers),
    155         input_size_(input_size),
    156         num_units_(num_units),
    157         seq_length_(max_seq_length),
    158         batch_size_(batch_size),
    159         dir_count_(dir_count),
    160         has_dropout_(has_dropout),
    161         is_training_(is_training),
    162         rnn_mode_(rnn_mode),
    163         rnn_input_mode_(rnn_input_mode),
    164         dtype_(dtype) {
    165     hash_code_ =
    166         HashList({num_layers, input_size, num_units, max_seq_length, batch_size,
    167                   dir_count, static_cast<int>(has_dropout),
    168                   static_cast<int>(is_training), static_cast<int>(rnn_mode),
    169                   static_cast<int>(rnn_input_mode), dtype});
    170   }
    171 
    172   bool operator==(const CudnnRnnParameters& other) const {
    173     return this->get_data_as_tuple() == other.get_data_as_tuple();
    174   }
    175 
    176   bool operator!=(const CudnnRnnParameters& other) const {
    177     return !(*this == other);
    178   }
    179   uint64 hash() const { return hash_code_; }
    180 
    181   string ToString() const {
    182     std::vector<string> fields = {
    183         std::to_string(num_layers_),
    184         std::to_string(input_size_),
    185         std::to_string(num_units_),
    186         std::to_string(seq_length_),
    187         std::to_string(batch_size_),
    188         std::to_string(dir_count_),
    189         std::to_string(has_dropout_),
    190         std::to_string(is_training_),
    191         std::to_string(static_cast<int>(rnn_mode_)),
    192         std::to_string(static_cast<int>(rnn_input_mode_)),
    193         std::to_string(static_cast<int>(dtype_))};
    194     return str_util::Join(fields, ", ");
    195   }
    196 
    197  private:
    198   using ParameterDataType = std::tuple<int, int, int, int, int, int, bool, bool,
    199                                        RnnMode, TFRNNInputMode, DataType>;
    200 
    201   ParameterDataType get_data_as_tuple() const {
    202     return std::make_tuple(num_layers_, input_size_, num_units_, seq_length_,
    203                            batch_size_, dir_count_, has_dropout_, is_training_,
    204                            rnn_mode_, rnn_input_mode_, dtype_);
    205   }
    206 
    207   const int num_layers_;
    208   const int input_size_;
    209   const int num_units_;
    210   const int seq_length_;
    211   const int batch_size_;
    212   const int dir_count_;
    213   const bool has_dropout_;
    214   const bool is_training_;
    215   const RnnMode rnn_mode_;
    216   const TFRNNInputMode rnn_input_mode_;
    217   const DataType dtype_;
    218   uint64 hash_code_;
    219 };
    220 
    221 struct RnnAutoTuneGroup {
    222   static string name() { return "Rnn"; }
    223 };
    224 
    225 using AutoTuneRnnConfigMap =
    226     AutoTuneSingleton<RnnAutoTuneGroup, CudnnRnnParameters, AlgorithmConfig>;
    227 
    228 Status ParseRNNMode(const string& str, RnnMode* rnn_mode) {
    229   if (str == "rnn_relu") {
    230     *rnn_mode = RnnMode::kRnnRelu;
    231     return Status::OK();
    232   } else if (str == "rnn_tanh") {
    233     *rnn_mode = RnnMode::kRnnTanh;
    234     return Status::OK();
    235   } else if (str == "lstm") {
    236     *rnn_mode = RnnMode::kRnnLstm;
    237     return Status::OK();
    238   } else if (str == "gru") {
    239     *rnn_mode = RnnMode::kRnnGru;
    240     return Status::OK();
    241   }
    242   return errors::InvalidArgument("Invalid RNN mode: ", str);
    243 }
    244 
    245 Status ParseTFRNNInputMode(const string& str, TFRNNInputMode* rnn_input_mode) {
    246   if (str == "linear_input") {
    247     *rnn_input_mode = TFRNNInputMode::kRNNLinearInput;
    248     return Status::OK();
    249   } else if (str == "skip_input") {
    250     *rnn_input_mode = TFRNNInputMode::kRNNSkipInput;
    251     return Status::OK();
    252   } else if (str == "auto_select") {
    253     *rnn_input_mode = TFRNNInputMode::kAutoSelect;
    254     return Status::OK();
    255   }
    256   return errors::InvalidArgument("Invalid RNN input mode: ", str);
    257 }
    258 
    259 Status ParseRNNDirectionMode(const string& str,
    260                              RnnDirectionMode* rnn_dir_mode) {
    261   if (str == "unidirectional") {
    262     *rnn_dir_mode = RnnDirectionMode::kRnnUnidirectional;
    263     return Status::OK();
    264   } else if (str == "bidirectional") {
    265     *rnn_dir_mode = RnnDirectionMode::kRnnBidirectional;
    266     return Status::OK();
    267   }
    268   return errors::InvalidArgument("Invalid RNN direction mode: ", str);
    269 }
    270 
    271 Status ToRNNInputMode(TFRNNInputMode tf_input_mode, int num_units,
    272                       int input_size, RnnInputMode* input_mode) {
    273   switch (tf_input_mode) {
    274     case TFRNNInputMode::kRNNLinearInput:
    275       *input_mode = RnnInputMode::kRnnLinearSkip;
    276       break;
    277     case TFRNNInputMode::kRNNSkipInput:
    278       *input_mode = RnnInputMode::kRnnSkipInput;
    279       break;
    280     case TFRNNInputMode::kAutoSelect:
    281       *input_mode = (input_size == num_units) ? RnnInputMode::kRnnSkipInput
    282                                               : RnnInputMode::kRnnLinearSkip;
    283       break;
    284     default:
    285       return errors::InvalidArgument("Invalid TF input mode: ",
    286                                      static_cast<int>(tf_input_mode));
    287   }
    288   return Status::OK();
    289 }
    290 
    291 // TODO(zhengxq): Merge those into stream_executor_util.h.
    292 template <typename T>
    293 const DeviceMemory<T> AsDeviceMemory(const Tensor* tensor) {
    294   return DeviceMemory<T>::MakeFromByteSize(
    295       const_cast<T*>(tensor->template flat<T>().data()),
    296       tensor->template flat<T>().size() * sizeof(T));
    297 }
    298 
    299 template <typename T>
    300 DeviceMemory<T> AsDeviceMemory(Tensor* tensor) {
    301   return DeviceMemory<T>::MakeFromByteSize(
    302       tensor->template flat<T>().data(),
    303       tensor->template flat<T>().size() * sizeof(T));
    304 }
    305 
    306 template <typename U, typename T>
    307 DeviceMemory<U> CastDeviceMemory(Tensor* tensor) {
    308   return DeviceMemory<U>::MakeFromByteSize(
    309       tensor->template flat<T>().data(),
    310       tensor->template flat<T>().size() * sizeof(T));
    311 }
    312 
    313 DeviceMemoryBase SliceDeviceMemory(const DeviceMemoryBase& device_memory,
    314                                    int64 offset, int64 size) {
    315   const void* base_ptr = device_memory.opaque();
    316   void* offset_ptr =
    317       const_cast<char*>(reinterpret_cast<const char*>(base_ptr) + offset);
    318   CHECK(offset + size <= device_memory.size())
    319       << "The slice is not within the region of DeviceMemory.";
    320   return DeviceMemoryBase(offset_ptr, size);
    321 }
    322 
    323 inline Status FromExecutorStatus(const se::port::Status& s) {
    324   return s.ok() ? Status::OK()
    325                 : Status(static_cast<error::Code>(static_cast<int>(s.code())),
    326                          s.error_message());
    327 }
    328 
    329 template <typename T>
    330 inline Status FromExecutorStatus(const se::port::StatusOr<T>& s) {
    331   return FromExecutorStatus(s.status());
    332 }
    333 
    334 inline se::port::Status ToExecutorStatus(const Status& s) {
    335   return s.ok() ? se::port::Status::OK()
    336                 : se::port::Status(static_cast<se::port::error::Code>(
    337                                        static_cast<int>(s.code())),
    338                                    s.error_message());
    339 }
    340 
    341 template <typename>
    342 struct ToTFDataType;
    343 
    344 template <>
    345 struct ToTFDataType<Eigen::half> : std::integral_constant<DataType, DT_HALF> {};
    346 
    347 template <>
    348 struct ToTFDataType<float> : std::integral_constant<DataType, DT_FLOAT> {};
    349 
    350 template <>
    351 struct ToTFDataType<double> : std::integral_constant<DataType, DT_DOUBLE> {};
    352 
    353 template <>
    354 struct ToTFDataType<uint8> : std::integral_constant<DataType, DT_UINT8> {};
    355 
     356 // A helper to allocate temporary scratch memory for Cudnn RNN models. It
     357 // takes ownership of the underlying memory. The expectation is that the
     358 // memory should stay alive for the span of the Cudnn RNN operation itself.
    359 template <typename T>
    360 class CudnnRnnAllocatorInTemp : public ScratchAllocator {
    361  public:
    362   ~CudnnRnnAllocatorInTemp() override = default;
    363 
    364   explicit CudnnRnnAllocatorInTemp(OpKernelContext* context)
    365       : context_(context) {}
    366   int64 GetMemoryLimitInBytes(Stream* stream) override {
    367     return std::numeric_limits<int64>::max();
    368   }
    369 
    370   StatusOr<DeviceMemory<uint8>> AllocateBytes(Stream* stream,
    371                                               int64 byte_size) override {
    372     Tensor temporary_memory;
    373     const DataType tf_data_type = ToTFDataType<T>::value;
    374     int64 allocate_count =
    375         Eigen::divup(byte_size, static_cast<int64>(sizeof(T)));
    376     Status allocation_status(context_->allocate_temp(
    377         tf_data_type, TensorShape({allocate_count}), &temporary_memory));
    378     if (!allocation_status.ok()) {
    379       return ToExecutorStatus(allocation_status);
    380     }
     381     // Hold a reference to the allocated tensor until the end of the
     382     // allocator's lifetime.
    383     allocated_tensors_.push_back(temporary_memory);
    384     total_byte_size_ += byte_size;
    385     return DeviceMemory<uint8>::MakeFromByteSize(
    386         temporary_memory.template flat<T>().data(),
    387         temporary_memory.template flat<T>().size() * sizeof(T));
    388   }
    389 
    390   int64 TotalByteSize() const { return total_byte_size_; }
    391 
    392   Tensor get_allocated_tensor(int index) const {
    393     return allocated_tensors_[index];
    394   }
    395 
    396  private:
    397   int64 total_byte_size_ = 0;
    398   OpKernelContext* context_;  // not owned
    399   std::vector<Tensor> allocated_tensors_;
    400 };
    401 
     402 // A helper to allocate memory for Cudnn RNN models as a kernel output. It is
     403 // used by the forward-pass kernel to feed its reserve space to the backward
     404 // pass. The memory is expected to stay alive at least until the backward
     405 // pass has finished.
    406 template <typename T>
    407 class CudnnRnnAllocatorInOutput : public ScratchAllocator {
    408  public:
    409   ~CudnnRnnAllocatorInOutput() override {}
    410   CudnnRnnAllocatorInOutput(OpKernelContext* context, int output_index)
    411       : context_(context), output_index_(output_index) {}
    412   int64 GetMemoryLimitInBytes(Stream* stream) override {
    413     return std::numeric_limits<int64>::max();
    414   }
    415   StatusOr<DeviceMemory<uint8>> AllocateBytes(Stream* stream,
    416                                               int64 byte_size) override {
    417     CHECK(total_byte_size_ == 0)
    418         << "Reserve space allocator can only be called once";
    419     int64 allocate_count =
    420         Eigen::divup(byte_size, static_cast<int64>(sizeof(T)));
    421 
    422     Tensor* temporary_memory = nullptr;
    423     Status allocation_status(context_->allocate_output(
    424         output_index_, TensorShape({allocate_count}), &temporary_memory));
    425     if (!allocation_status.ok()) {
    426       return ToExecutorStatus(allocation_status);
    427     }
    428     total_byte_size_ += byte_size;
    429     auto memory_uint8 = DeviceMemory<uint8>::MakeFromByteSize(
    430         temporary_memory->template flat<T>().data(),
    431         temporary_memory->template flat<T>().size() * sizeof(T));
    432     return StatusOr<DeviceMemory<uint8>>(memory_uint8);
    433   }
    434   int64 TotalByteSize() { return total_byte_size_; }
    435 
    436  private:
    437   int64 total_byte_size_ = 0;
    438   OpKernelContext* context_;  // not owned
    439   int output_index_;
    440 };
    441 
    442 // A helper to allocate persistent memory for Cudnn RNN models, which is
    443 // expected to live between kernel invocations.
    444 // This class is not thread-safe.
    445 class CudnnRNNPersistentSpaceAllocator : public ScratchAllocator {
    446  public:
    447   explicit CudnnRNNPersistentSpaceAllocator(OpKernelContext* context)
    448       : context_(context) {}
    449 
    450   ~CudnnRNNPersistentSpaceAllocator() override {}
    451 
    452   int64 GetMemoryLimitInBytes(Stream* stream) override {
    453     return std::numeric_limits<int64>::max();
    454   }
    455 
    456   StatusOr<DeviceMemory<uint8>> AllocateBytes(Stream* stream,
    457                                               int64 byte_size) override {
    458     if (total_byte_size_ != 0) {
    459       return Status(error::FAILED_PRECONDITION,
    460                     "Persistent space allocator can only be called once");
    461     }
    462 
    463     Status allocation_status = context_->allocate_persistent(
    464         DT_UINT8, TensorShape({byte_size}), &handle_, nullptr);
    465     if (!allocation_status.ok()) {
    466       return ToExecutorStatus(allocation_status);
    467     }
    468     total_byte_size_ += byte_size;
    469     return AsDeviceMemory<uint8>(handle_.AccessTensor(context_));
    470   }
    471   int64 TotalByteSize() { return total_byte_size_; }
    472 
    473  private:
    474   int64 total_byte_size_ = 0;
    475   PersistentTensor handle_;
    476   OpKernelContext* context_;  // not owned
    477 };
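// Summary of how the three allocators above are intended to be used (an
// illustrative note; see the class comments for details):
//   * CudnnRnnAllocatorInTemp<T>: workspace that only needs to live for the
//     duration of a single kernel call.
//   * CudnnRnnAllocatorInOutput<T>: reserve space produced by the forward pass
//     as a kernel output and consumed later by the backward pass.
//   * CudnnRNNPersistentSpaceAllocator: dropout RNG state that must persist
//     across kernel invocations.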
    478 
    479 struct CudnnModelTypes {
    480   RnnMode rnn_mode;
    481   TFRNNInputMode rnn_input_mode;
    482   RnnDirectionMode rnn_direction_mode;
    483   bool HasInputC() const {
    484     // For Cudnn 5.0, only LSTM has input-c. All other models use only
    485     // input-h.
    486     return rnn_mode == RnnMode::kRnnLstm;
    487   }
    488 
    489   string DebugString() const {
    490     return strings::Printf(
    491         "[rnn_mode, rnn_input_mode, rnn_direction_mode]: %d, %d, %d ",
    492         static_cast<int>(rnn_mode), static_cast<int>(rnn_input_mode),
    493         static_cast<int>(rnn_direction_mode));
    494   }
    495 };
    496 
     497 // A helper class that collects the shapes describing an RNN model.
    498 struct CudnnRnnModelShapes {
    499   int num_layers;
    500   int input_size;
    501   int num_units;
    502   int dir_count;
    503   int max_seq_length;
    504   int batch_size;
    505   TensorShape input_shape;
    506   TensorShape output_shape;
    507   TensorShape hidden_state_shape;
     508   // At present, only the fields relevant to the cached RnnDescriptor are compared.
    509   bool IsCompatibleWith(const CudnnRnnModelShapes& rhs) const {
    510     return num_layers == rhs.num_layers && input_size == rhs.input_size &&
    511            num_units == rhs.num_units && dir_count == rhs.dir_count;
    512   }
    513   string DebugString() const {
    514     return strings::Printf(
    515         "[num_layers, input_size, num_units, dir_count, max_seq_length, "
    516         "batch_size]: [%d, %d, %d, %d, %d, %d] ",
    517         num_layers, input_size, num_units, dir_count, max_seq_length,
    518         batch_size);
    519   }
    520 };
    521 
     522 // Utility class for hashing a CudnnRnnModelShapes and AlgorithmDesc pair used
     523 // as a hash table key.
    524 struct CudnnRnnConfigHasher {
    525   uint64 operator()(
    526       const std::pair<CudnnRnnModelShapes, absl::optional<AlgorithmDesc>>&
    527           to_hash) const {
    528     auto& shapes = to_hash.first;
    529     auto& algo_desc = to_hash.second;
    530 
    531     uint64 hash =
    532         HashList({shapes.num_layers, shapes.input_size, shapes.num_units,
    533                   shapes.dir_count, shapes.batch_size});
    534     if (algo_desc.has_value()) {
    535       hash = Hash64Combine(hash, algo_desc->hash());
    536     }
    537     return hash;
    538   }
    539 };
    540 
     541 // Utility class for comparing CudnnRnnModelShapes and AlgorithmDesc pairs used
     542 // as hash table keys.
    543 struct CudnnRnnConfigComparator {
    544   bool operator()(
    545       const std::pair<CudnnRnnModelShapes, absl::optional<AlgorithmDesc>>& lhs,
    546       const std::pair<CudnnRnnModelShapes, absl::optional<AlgorithmDesc>>& rhs)
    547       const {
    548     return lhs.first.IsCompatibleWith(rhs.first) && lhs.second == rhs.second;
    549   }
    550 };
    551 
    552 // Pointers to RNN scratch space for a specific set of shape parameters (used as
    553 // a hash table value in CudnnRNNForwardOp and CudnnRNNBackwardOp).
    554 struct RnnScratchSpace {
    555   std::unique_ptr<RnnDescriptor> rnn_desc;
    556   std::unique_ptr<CudnnRNNPersistentSpaceAllocator> dropout_state_allocator;
    557 };
    558 
     559 // Extracts and checks the forward input tensors, parameters, and shapes from
     560 // the OpKernelContext.
    561 Status ExtractForwardInput(OpKernelContext* context,
    562                            const CudnnModelTypes& model_types, bool time_major,
    563                            const Tensor** input, const Tensor** input_h,
    564                            const Tensor** input_c, const Tensor** params,
    565                            CudnnRnnModelShapes* model_shapes) {
    566   TF_RETURN_IF_ERROR(context->input("input", input));
    567   TF_RETURN_IF_ERROR(context->input("input_h", input_h));
    568   if (model_types.HasInputC()) {
    569     TF_RETURN_IF_ERROR(context->input("input_c", input_c));
    570   }
    571   TF_RETURN_IF_ERROR(context->input("params", params));
    572 
    573   if ((*input)->dims() != 3) {
    574     return errors::InvalidArgument("RNN input must be a 3-D vector.");
    575   }
    576   if (time_major) {
    577     model_shapes->max_seq_length = (*input)->dim_size(0);
    578     model_shapes->batch_size = (*input)->dim_size(1);
    579   } else {
    580     model_shapes->max_seq_length = (*input)->dim_size(1);
    581     model_shapes->batch_size = (*input)->dim_size(0);
    582   }
    583   model_shapes->input_size = (*input)->dim_size(2);
    584   model_shapes->input_shape = (*input)->shape();
    585   model_shapes->dir_count =
    586       (model_types.rnn_direction_mode == RnnDirectionMode::kRnnBidirectional)
    587           ? 2
    588           : 1;
    589 
    590   if ((*input_h)->dims() != 3) {
    591     return errors::InvalidArgument("RNN input_h must be a 3-D vector.");
    592   }
    593   if (time_major) {
    594     model_shapes->num_layers =
    595         (*input_h)->dim_size(0) / model_shapes->dir_count;
    596   } else {
    597     model_shapes->num_layers =
    598         (*input_h)->dim_size(1) / model_shapes->dir_count;
    599   }
    600   model_shapes->num_units = (*input_h)->dim_size(2);
    601 
    602   if (time_major) {
    603     model_shapes->hidden_state_shape =
    604         TensorShape({model_shapes->dir_count * model_shapes->num_layers,
    605                      model_shapes->batch_size, model_shapes->num_units});
    606   } else {
    607     model_shapes->hidden_state_shape =
    608         TensorShape({model_shapes->batch_size,
    609                      model_shapes->dir_count * model_shapes->num_layers,
    610                      model_shapes->num_units});
    611   }
    612   if ((*input_h)->shape() != model_shapes->hidden_state_shape) {
    613     return errors::InvalidArgument(
    614         "Invalid input_h shape: ", (*input_h)->shape().DebugString(), " ",
    615         model_shapes->hidden_state_shape.DebugString());
    616   }
    617   if (model_types.HasInputC()) {
    618     if ((*input_h)->shape() != (*input_c)->shape()) {
    619       return errors::InvalidArgument(
    620           "input_h and input_c must have the same shape: ",
    621           (*input_h)->shape().DebugString(), " ",
    622           (*input_c)->shape().DebugString());
    623     }
    624   }
    625   if (time_major) {
    626     model_shapes->output_shape =
    627         TensorShape({model_shapes->max_seq_length, model_shapes->batch_size,
    628                      model_shapes->dir_count * model_shapes->num_units});
    629   } else {
    630     model_shapes->output_shape =
    631         TensorShape({model_shapes->batch_size, model_shapes->max_seq_length,
    632                      model_shapes->dir_count * model_shapes->num_units});
    633   }
    634   return Status::OK();
    635 }
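// Shape summary derived from the checks above (illustrative note):
//   time_major == true:
//     input:      [max_seq_length, batch_size, input_size]
//     input_h/_c: [dir_count * num_layers, batch_size, num_units]
//     output:     [max_seq_length, batch_size, dir_count * num_units]
//   time_major == false:
//     input:      [batch_size, max_seq_length, input_size]
//     input_h/_c: [batch_size, dir_count * num_layers, num_units]
//     output:     [batch_size, max_seq_length, dir_count * num_units]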
    636 
     637 // Overloaded variant that additionally extracts the sequence_lengths input.
    638 Status ExtractForwardInput(OpKernelContext* context,
    639                            const CudnnModelTypes& model_types, bool time_major,
    640                            const Tensor** input, const Tensor** input_h,
    641                            const Tensor** input_c, const Tensor** params,
    642                            const Tensor** sequence_lengths,
    643                            CudnnRnnModelShapes* model_shapes) {
    644   TF_RETURN_IF_ERROR(context->input("sequence_lengths", sequence_lengths));
    645   return ExtractForwardInput(context, model_types, time_major, input, input_h,
    646                              input_c, params, model_shapes);
    647 }
    648 
    649 template <typename T>
    650 Status CreateForwardAndBackwardIODescriptors(
    651     OpKernelContext* context, const CudnnRnnModelShapes& model_shapes,
    652     std::unique_ptr<RnnSequenceTensorDescriptor>* input_desc,
    653     std::unique_ptr<RnnStateTensorDescriptor>* state_desc,
    654     std::unique_ptr<RnnSequenceTensorDescriptor>* output_desc,
    655     const absl::Span<const int>& seq_lengths, bool time_major) {
    656   StreamExecutor* executor = context->op_device_context()->stream()->parent();
    657   se::dnn::DataType data_type = ToDataType<T>::value;
    658 
    659   const TensorShape& input_shape = model_shapes.input_shape;
    660   const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
    661   const TensorShape& output_shape = model_shapes.output_shape;
    662 
    663   DCHECK_EQ(input_shape.dims(), 3);
    664   if (seq_lengths.data() != nullptr) {
    665     if (time_major) {
    666       auto input_desc_s = executor->createRnnSequenceTensorDescriptor(
    667           input_shape.dim_size(0), input_shape.dim_size(1),
    668           input_shape.dim_size(2), seq_lengths, time_major, data_type);
    669       TF_RETURN_IF_ERROR(input_desc_s.status());
    670       *input_desc = input_desc_s.ConsumeValueOrDie();
    671     } else {
    672       auto input_desc_s = executor->createRnnSequenceTensorDescriptor(
    673           input_shape.dim_size(1), input_shape.dim_size(0),
    674           input_shape.dim_size(2), seq_lengths, time_major, data_type);
    675       TF_RETURN_IF_ERROR(input_desc_s.status());
    676       *input_desc = input_desc_s.ConsumeValueOrDie();
    677     }
    678   } else {
    679     auto input_desc_s = executor->createRnnSequenceTensorDescriptor(
    680         input_shape.dim_size(0), input_shape.dim_size(1),
    681         input_shape.dim_size(2), data_type);
    682     TF_RETURN_IF_ERROR(input_desc_s.status());
    683     *input_desc = input_desc_s.ConsumeValueOrDie();
    684   }
    685 
    686   DCHECK_EQ(hidden_state_shape.dims(), 3);
    687   if (time_major) {
    688     auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor(
    689         hidden_state_shape.dim_size(0), hidden_state_shape.dim_size(1),
    690         hidden_state_shape.dim_size(2), data_type);
    691     TF_RETURN_IF_ERROR(hidden_state_desc_s.status());
    692     *state_desc = hidden_state_desc_s.ConsumeValueOrDie();
    693   } else {
    694     auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor(
    695         hidden_state_shape.dim_size(1), hidden_state_shape.dim_size(0),
    696         hidden_state_shape.dim_size(2), data_type);
    697     TF_RETURN_IF_ERROR(hidden_state_desc_s.status());
    698     *state_desc = hidden_state_desc_s.ConsumeValueOrDie();
    699   }
    700 
    701   DCHECK_EQ(output_shape.dims(), 3);
    702   if (seq_lengths.data() != nullptr) {
    703     if (time_major) {
    704       auto output_desc_s = executor->createRnnSequenceTensorDescriptor(
    705           output_shape.dim_size(0), output_shape.dim_size(1),
    706           output_shape.dim_size(2), seq_lengths, time_major, data_type);
    707       TF_RETURN_IF_ERROR(output_desc_s.status());
    708       *output_desc = output_desc_s.ConsumeValueOrDie();
    709     } else {
    710       auto output_desc_s = executor->createRnnSequenceTensorDescriptor(
    711           output_shape.dim_size(1), output_shape.dim_size(0),
    712           output_shape.dim_size(2), seq_lengths, time_major, data_type);
    713       TF_RETURN_IF_ERROR(output_desc_s.status());
    714       *output_desc = output_desc_s.ConsumeValueOrDie();
    715     }
    716   } else {
    717     auto output_desc_s = executor->createRnnSequenceTensorDescriptor(
    718         output_shape.dim_size(0), output_shape.dim_size(1),
    719         output_shape.dim_size(2), data_type);
    720     TF_RETURN_IF_ERROR(output_desc_s.status());
    721     *output_desc = output_desc_s.ConsumeValueOrDie();
    722   }
    723 
    724   return Status::OK();
    725 }
    726 
    727 template <typename T>
    728 Status DoForward(OpKernelContext* context, const RnnDescriptor& rnn_desc,
    729                  const CudnnModelTypes& model_types,
    730                  const CudnnRnnModelShapes& model_shapes,
    731                  /* forward inputs */
    732                  const Tensor* input, const Tensor* input_h,
    733                  const Tensor* input_c, const Tensor* params,
    734                  const bool is_training,
    735                  /* forward outputs, outputs of the function */
    736                  Tensor* output, Tensor* output_h, Tensor* output_c,
    737                  const Tensor* sequence_lengths, bool time_major,
    738                  ScratchAllocator* reserve_space_allocator,
    739                  ScratchAllocator* workspace_allocator,
    740                  ProfileResult* output_profile_result) {
    741   std::unique_ptr<RnnSequenceTensorDescriptor> input_desc;
    742   std::unique_ptr<RnnStateTensorDescriptor> state_desc;
    743   std::unique_ptr<RnnSequenceTensorDescriptor> output_desc;
    744 
    745   absl::Span<const int> seq_lengths;
    746   if (sequence_lengths != nullptr) {
    747     seq_lengths = absl::Span<const int>(
    748         sequence_lengths->template flat<int>().data(), model_shapes.batch_size);
    749   }
    750   TF_RETURN_IF_ERROR(CreateForwardAndBackwardIODescriptors<T>(
    751       context, model_shapes, &input_desc, &state_desc, &output_desc,
    752       seq_lengths, time_major));
    753 
    754   auto input_data = AsDeviceMemory<T>(input);
    755   auto input_h_data = AsDeviceMemory<T>(input_h);
    756   DeviceMemory<T> input_c_data;
    757   if (model_types.HasInputC()) {
    758     input_c_data = AsDeviceMemory<T>(input_c);
    759   }
    760 
    761   auto params_data = AsDeviceMemory<T>(params);
    762   auto output_data = AsDeviceMemory<T>(output);
    763   auto output_h_data = AsDeviceMemory<T>(output_h);
    764   DeviceMemory<T> output_c_data;
    765   if (model_types.HasInputC()) {
    766     output_c_data = AsDeviceMemory<T>(output_c);
    767   }
    768 
    769   Stream* stream = context->op_device_context()->stream();
    770   bool launch_success =
    771       stream
    772           ->ThenRnnForward(rnn_desc, *input_desc, input_data, *state_desc,
    773                            input_h_data, *state_desc, input_c_data, params_data,
    774                            *output_desc, &output_data, *state_desc,
    775                            &output_h_data, *state_desc, &output_c_data,
    776                            is_training, reserve_space_allocator,
    777                            workspace_allocator, output_profile_result)
    778           .ok();
    779   return launch_success
    780              ? Status::OK()
    781              : errors::Internal(
    782                    "Failed to call ThenRnnForward with model config: ",
    783                    model_types.DebugString(), ", ", model_shapes.DebugString());
    784 }
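// Minimal usage sketch for DoForward (assumptions: a training-mode kernel that
// exposes its reserve space at a hypothetical output index
// `kReserveSpaceOutputIndex`, with `rnn_desc`, `model_shapes`, and the
// input/output tensors already in scope):
//
//   CudnnRnnAllocatorInOutput<T> reserve_space_allocator(
//       context, kReserveSpaceOutputIndex);
//   CudnnRnnAllocatorInTemp<uint8> workspace_allocator(context);
//   OP_REQUIRES_OK(
//       context,
//       DoForward<T>(context, *rnn_desc, model_types(), model_shapes, input,
//                    input_h, input_c, params, /*is_training=*/true, output,
//                    output_h, output_c, /*sequence_lengths=*/nullptr,
//                    /*time_major=*/true, &reserve_space_allocator,
//                    &workspace_allocator, /*output_profile_result=*/nullptr));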
    785 
    786 template <typename T>
    787 Status DoBackward(
    788     OpKernelContext* context, const RnnDescriptor& rnn_desc,
    789     const CudnnModelTypes& model_types, const CudnnRnnModelShapes& model_shapes,
    790     /* forward inputs */
    791     const Tensor* input, const Tensor* input_h, const Tensor* input_c,
    792     const Tensor* params,
    793     /* forward outputs */
    794     const Tensor* output, const Tensor* output_h, const Tensor* output_c,
    795     /* backprop inputs */
    796     const Tensor* output_backprop, const Tensor* output_h_backprop,
    797     const Tensor* output_c_backprop, const Tensor* reserve_space,
    798     /* backprop outputs, output of the function */
    799     Tensor* input_backprop, Tensor* input_h_backprop, Tensor* input_c_backprop,
    800     Tensor* params_backprop, const Tensor* sequence_lengths, bool time_major,
    801     ScratchAllocator* workspace_allocator,
    802     ProfileResult* output_profile_result) {
    803   std::unique_ptr<RnnSequenceTensorDescriptor> input_desc;
    804   std::unique_ptr<RnnStateTensorDescriptor> state_desc;
    805   std::unique_ptr<RnnSequenceTensorDescriptor> output_desc;
    806 
    807   absl::Span<const int> seq_lengths;
    808   if (sequence_lengths != nullptr) {
    809     seq_lengths = absl::Span<const int>(
    810         sequence_lengths->template flat<int>().data(), model_shapes.batch_size);
    811   }
    812   TF_RETURN_IF_ERROR(CreateForwardAndBackwardIODescriptors<T>(
    813       context, model_shapes, &input_desc, &state_desc, &output_desc,
    814       seq_lengths, time_major));
    815 
    816   auto input_data = AsDeviceMemory<T>(input);
    817   auto input_h_data = AsDeviceMemory<T>(input_h);
    818   DeviceMemory<T> input_c_data;
    819   if (model_types.HasInputC()) {
    820     input_c_data = AsDeviceMemory<T>(input_c);
    821   }
    822   auto params_data = AsDeviceMemory<T>(params);
    823   auto output_data = AsDeviceMemory<T>(output);
    824   auto output_h_data = AsDeviceMemory<T>(output_h);
    825   DeviceMemory<T> output_c_data;
    826   if (model_types.HasInputC()) {
    827     output_c_data = AsDeviceMemory<T>(output_c);
    828   }
    829   auto output_backprop_data = AsDeviceMemory<T>(output_backprop);
    830   auto output_h_backprop_data = AsDeviceMemory<T>(output_h_backprop);
    831   DeviceMemory<T> output_c_backprop_data;
    832   if (model_types.HasInputC()) {
    833     output_c_backprop_data = AsDeviceMemory<T>(output_c_backprop);
    834   }
    835   auto input_backprop_data = AsDeviceMemory<T>(input_backprop);
    836   auto input_h_backprop_data = AsDeviceMemory<T>(input_h_backprop);
    837   DeviceMemory<T> input_c_backprop_data;
    838   if (model_types.HasInputC()) {
    839     input_c_backprop_data = AsDeviceMemory<T>(input_c_backprop);
    840   }
    841   auto params_backprop_data = AsDeviceMemory<T>(params_backprop);
    842   auto reserve_space_uint8 =
    843       CastDeviceMemory<uint8, T>(const_cast<Tensor*>(reserve_space));
    844 
     845   // The workspace memory is obtained through the workspace allocator and lives
     846   // until the end of this kernel call.
    847   Stream* stream = context->op_device_context()->stream();
    848   bool launch_success =
    849       stream
    850           ->ThenRnnBackward(rnn_desc, *input_desc, input_data, *state_desc,
    851                             input_h_data, *state_desc, input_c_data,
    852                             params_data, *output_desc, output_data, *state_desc,
    853                             output_h_data, *state_desc, output_c_data,
    854                             output_backprop_data, output_h_backprop_data,
    855                             output_c_backprop_data, &input_backprop_data,
    856                             &input_h_backprop_data, &input_c_backprop_data,
    857                             &params_backprop_data, &reserve_space_uint8,
    858                             workspace_allocator, output_profile_result)
    859           .ok();
    860   return launch_success
    861              ? Status::OK()
    862              : errors::Internal(
    863                    "Failed to call ThenRnnBackward with model config: ",
    864                    model_types.DebugString(), ", ", model_shapes.DebugString());
    865 }
    866 
    867 template <typename T>
    868 void RestoreParams(const OpInputList params_input,
    869                    const std::vector<RnnDescriptor::ParamsRegion>& params,
    870                    DeviceMemoryBase* data_dst, Stream* stream) {
    871   int num_params = params.size();
    872   CHECK(params_input.size() == num_params)
    873       << "Number of params mismatch. Expected " << params_input.size()
    874       << ", got " << num_params;
    875   for (int i = 0; i < params.size(); i++) {
    876     int64 size_in_bytes = params[i].size;
    877     int64 size = size_in_bytes / sizeof(T);
    878     CHECK(size == params_input[i].NumElements())
    879         << "Params size mismatch. Expected " << size << ", got "
    880         << params_input[i].NumElements();
    881     auto data_src_ptr = StreamExecutorUtil::AsDeviceMemory<T>(params_input[i]);
    882     DeviceMemoryBase data_dst_ptr =
    883         SliceDeviceMemory(*data_dst, params[i].offset, size_in_bytes);
    884     stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes);
    885   }
    886 }
    887 
    888 }  // namespace
    889 
     890 // Note: all of the following kernels depend on an RnnDescriptor instance,
     891 // which, according to the official Cudnn documentation, should be kept around
     892 // and reused across all Cudnn kernels in the same model.
     893 // In TensorFlow we don't pass the reference across different OpKernels;
     894 // rather, we recreate it separately in each OpKernel, which does not cause
     895 // issues: CudnnDropoutDescriptor keeps a reference to memory holding the
     896 // random number generator state, and that state is lost during recreation,
     897 // but only the forward-pass Cudnn APIs make use of the state.
    898 
     899 // A common base class for RNN kernels. It extracts common attributes and
     900 // performs shape validations.
    901 class CudnnRNNKernelCommon : public OpKernel {
    902  protected:
    903   explicit CudnnRNNKernelCommon(OpKernelConstruction* context)
    904       : OpKernel(context) {
    905     OP_REQUIRES_OK(context, context->GetAttr("dropout", &dropout_));
    906     OP_REQUIRES_OK(context, context->GetAttr("seed", &seed_));
    907     OP_REQUIRES_OK(context, context->GetAttr("seed2", &seed2_));
    908     string str;
    909     OP_REQUIRES_OK(context, context->GetAttr("rnn_mode", &str));
    910     OP_REQUIRES_OK(context, ParseRNNMode(str, &model_types_.rnn_mode));
    911     OP_REQUIRES_OK(context, context->GetAttr("input_mode", &str));
    912     OP_REQUIRES_OK(context,
    913                    ParseTFRNNInputMode(str, &model_types_.rnn_input_mode));
    914     OP_REQUIRES_OK(context, context->GetAttr("direction", &str));
    915     OP_REQUIRES_OK(
    916         context, ParseRNNDirectionMode(str, &model_types_.rnn_direction_mode));
     917     // Reset the CudnnRnnDescriptor and related random number generator state in
     918     // every Compute() call.
    919     OP_REQUIRES_OK(context, ReadBoolFromEnvVar("TF_CUDNN_RESET_RND_GEN_STATE",
    920                                                false, &reset_rnd_gen_state_));
    921   }
    922 
    923   bool HasInputC() const { return model_types_.HasInputC(); }
    924   RnnMode rnn_mode() const { return model_types_.rnn_mode; }
    925   TFRNNInputMode rnn_input_mode() const { return model_types_.rnn_input_mode; }
    926   RnnDirectionMode rnn_direction_mode() const {
    927     return model_types_.rnn_direction_mode;
    928   }
    929   const CudnnModelTypes& model_types() const { return model_types_; }
    930   float dropout() const { return dropout_; }
    931   uint64 seed() { return (static_cast<uint64>(seed_) << 32) | seed2_; }
    932   bool ResetRndGenState() { return reset_rnd_gen_state_; }
    933 
    934   template <typename T>
    935   Status ExtractCudnnRNNParamsInfo(OpKernelContext* context,
    936                                    std::unique_ptr<RnnDescriptor>* rnn_desc) {
    937     const Tensor* num_layers_t = nullptr;
    938     TF_RETURN_IF_ERROR(context->input("num_layers", &num_layers_t));
    939     if (!TensorShapeUtils::IsScalar(num_layers_t->shape())) {
    940       return errors::InvalidArgument("num_layers is not a scalar");
    941     }
    942     int num_layers = num_layers_t->scalar<int>()();
    943     const Tensor* num_units_t = nullptr;
    944     TF_RETURN_IF_ERROR(context->input("num_units", &num_units_t));
    945     if (!TensorShapeUtils::IsScalar(num_units_t->shape())) {
    946       return errors::InvalidArgument("num_units is not a scalar");
    947     }
    948     int num_units = num_units_t->scalar<int>()();
    949     const Tensor* input_size_t = nullptr;
    950     TF_RETURN_IF_ERROR(context->input("input_size", &input_size_t));
    951     if (!TensorShapeUtils::IsScalar(input_size_t->shape())) {
    952       return errors::InvalidArgument("input_size is not a scalar");
    953     }
    954     int input_size = input_size_t->scalar<int>()();
    955 
    956     RnnInputMode input_mode;
    957     TF_RETURN_IF_ERROR(
    958         ToRNNInputMode(rnn_input_mode(), num_units, input_size, &input_mode));
    959 
    960     Stream* stream = context->op_device_context()->stream();
     961     // ExtractCudnnRNNParamsInfo is only called by op kernels that do not require
     962     // a random number generator, so state_allocator is set to nullptr.
    963     const AlgorithmConfig algo_config;
    964     auto rnn_desc_s = stream->parent()->createRnnDescriptor(
    965         num_layers, num_units, input_size, /*batch_size=*/0, input_mode,
    966         rnn_direction_mode(), rnn_mode(), ToDataType<T>::value, algo_config,
    967         dropout(), seed(), /* state_allocator=*/nullptr);
    968     if (!rnn_desc_s.ok()) {
    969       return FromExecutorStatus(rnn_desc_s);
    970     }
    971     *rnn_desc = rnn_desc_s.ConsumeValueOrDie();
    972     return Status::OK();
    973   }
    974 
    975   template <typename T>
    976   Status CreateRnnDescriptor(OpKernelContext* context,
    977                              const CudnnRnnModelShapes& model_shapes,
    978                              const RnnInputMode& input_mode,
    979                              const AlgorithmConfig& algo_config,
    980                              ScratchAllocator* dropout_state_allocator,
    981                              std::unique_ptr<RnnDescriptor>* rnn_desc) {
    982     StreamExecutor* executor = context->op_device_context()->stream()->parent();
    983     se::dnn::DataType data_type = ToDataType<T>::value;
    984     auto rnn_desc_s = executor->createRnnDescriptor(
    985         model_shapes.num_layers, model_shapes.num_units,
    986         model_shapes.input_size, model_shapes.batch_size, input_mode,
    987         rnn_direction_mode(), rnn_mode(), data_type, algo_config, dropout(),
    988         seed(), dropout_state_allocator);
    989     TF_RETURN_IF_ERROR(rnn_desc_s.status());
    990 
    991     *rnn_desc = rnn_desc_s.ConsumeValueOrDie();
    992     return Status::OK();
    993   }
    994 
    995   using RnnStateCache = gtl::FlatMap<
    996       std::pair<CudnnRnnModelShapes, absl::optional<AlgorithmDesc>>,
    997       RnnScratchSpace, CudnnRnnConfigHasher, CudnnRnnConfigComparator>;
    998   // Returns a raw rnn descriptor pointer. The cache owns the rnn descriptor and
    999   // should outlive the returned pointer.
   1000   template <typename T>
   1001   Status GetCachedRnnDescriptor(OpKernelContext* context,
   1002                                 const CudnnRnnModelShapes& model_shapes,
   1003                                 const RnnInputMode& input_mode,
   1004                                 const AlgorithmConfig& algo_config,
   1005                                 RnnStateCache* cache,
   1006                                 RnnDescriptor** rnn_desc) {
   1007     auto key = std::make_pair(model_shapes, algo_config.algorithm());
   1008     RnnScratchSpace& rnn_state = (*cache)[key];
   1009     if (rnn_state.rnn_desc == nullptr || ResetRndGenState()) {
   1010       CudnnRNNPersistentSpaceAllocator* dropout_state_allocator =
   1011           new CudnnRNNPersistentSpaceAllocator(context);
   1012       rnn_state.dropout_state_allocator.reset(dropout_state_allocator);
   1013       Status status =
   1014           CreateRnnDescriptor<T>(context, model_shapes, input_mode, algo_config,
   1015                                  dropout_state_allocator, &rnn_state.rnn_desc);
   1016       TF_RETURN_IF_ERROR(status);
   1017     }
   1018     *rnn_desc = rnn_state.rnn_desc.get();
   1019     return Status::OK();
   1020   }
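  // Minimal usage sketch for GetCachedRnnDescriptor (assumptions: a derived
  // kernel with a hypothetical `rnn_state_cache_` member of type RnnStateCache
  // guarded by a hypothetical mutex `mu_`, and previously computed
  // `model_shapes` and `input_mode`):
  //
  //   RnnDescriptor* rnn_desc = nullptr;
  //   {
  //     mutex_lock l(mu_);
  //     OP_REQUIRES_OK(context, GetCachedRnnDescriptor<T>(
  //                                 context, model_shapes, input_mode,
  //                                 AlgorithmConfig(), &rnn_state_cache_,
  //                                 &rnn_desc));
  //   }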
   1021 
   1022  private:
   1023   int seed_;
   1024   int seed2_;
   1025   float dropout_;
   1026   bool reset_rnd_gen_state_;
   1027 
   1028   CudnnModelTypes model_types_;
   1029 };
   1030 
    1031 // A class that returns the size of the opaque parameter buffer. The user
    1032 // should use it to create the actual parameter buffer for training. However,
    1033 // it should not be used for saving and restoring.
   1034 template <typename T, typename Index>
   1035 class CudnnRNNParamsSizeOp<GPUDevice, T, Index> : public CudnnRNNKernelCommon {
   1036  public:
   1037   explicit CudnnRNNParamsSizeOp(OpKernelConstruction* context)
   1038       : CudnnRNNKernelCommon(context) {}
   1039 
   1040   void Compute(OpKernelContext* context) override {
   1041     std::unique_ptr<RnnDescriptor> rnn_desc;
   1042     OP_REQUIRES_OK(context, ExtractCudnnRNNParamsInfo<T>(context, &rnn_desc));
   1043     int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes();
   1044     CHECK(params_size_in_bytes % sizeof(T) == 0)
   1045         << "params_size_in_bytes must be multiple of element size";
   1046     int64 params_size = params_size_in_bytes / sizeof(T);
   1047 
   1048     Tensor* output_t = nullptr;
   1049     OP_REQUIRES_OK(context, context->allocate_output(0, {1}, &output_t));
   1050     *output_t->template flat<Index>().data() = params_size;
   1051   }
   1052 };
   1053 
   1054 #define REGISTER_GPU(T)                                    \
   1055   REGISTER_KERNEL_BUILDER(Name("CudnnRNNParamsSize")       \
   1056                               .Device(DEVICE_GPU)          \
   1057                               .HostMemory("num_layers")    \
   1058                               .HostMemory("num_units")     \
   1059                               .HostMemory("input_size")    \
   1060                               .HostMemory("params_size")   \
   1061                               .TypeConstraint<T>("T")      \
   1062                               .TypeConstraint<int32>("S"), \
   1063                           CudnnRNNParamsSizeOp<GPUDevice, T, int32>);
   1064 
   1065 TF_CALL_half(REGISTER_GPU);
   1066 TF_CALL_float(REGISTER_GPU);
   1067 TF_CALL_double(REGISTER_GPU);
   1068 #undef REGISTER_GPU
   1069 
   1070 // Convert weight and bias params from a platform-specific layout to the
   1071 // canonical form.
   1072 template <typename T>
   1073 class CudnnRNNParamsToCanonical<GPUDevice, T> : public CudnnRNNKernelCommon {
   1074  public:
   1075   explicit CudnnRNNParamsToCanonical(OpKernelConstruction* context)
   1076       : CudnnRNNKernelCommon(context) {
   1077     OP_REQUIRES_OK(context, context->GetAttr("num_params", &num_params_));
   1078   }
   1079 
   1080   void Compute(OpKernelContext* context) override {
   1081     const Tensor& input = context->input(3);
   1082     auto input_ptr = StreamExecutorUtil::AsDeviceMemory<T>(input);
   1083     Stream* stream = context->op_device_context()->stream();
   1084 
   1085     std::unique_ptr<RnnDescriptor> rnn_desc;
   1086     OP_REQUIRES_OK(context, ExtractCudnnRNNParamsInfo<T>(context, &rnn_desc));
   1087     int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes();
   1088     CHECK(params_size_in_bytes % sizeof(T) == 0)
   1089         << "params_size_in_bytes must be multiple of element size";
   1090 
   1091     const Tensor* num_units_t = nullptr;
   1092     OP_REQUIRES_OK(context, context->input("num_units", &num_units_t));
   1093     CHECK(TensorShapeUtils::IsScalar(num_units_t->shape()))
   1094         << "num_units is not a scalar";
   1095     int num_units = num_units_t->scalar<int>()();
   1096 
   1097     const Tensor* input_size_t = nullptr;
   1098     OP_REQUIRES_OK(context, context->input("input_size", &input_size_t));
   1099     CHECK(TensorShapeUtils::IsScalar(input_size_t->shape()))
   1100         << "input_size is not a scalar";
   1101     int input_size = input_size_t->scalar<int>()();
   1102 
   1103     const Tensor* num_layers_t = nullptr;
   1104     OP_REQUIRES_OK(context, context->input("num_layers", &num_layers_t));
   1105     CHECK(TensorShapeUtils::IsScalar(num_layers_t->shape()))
   1106         << "num_layers is not a scalar";
   1107     int num_layers = num_layers_t->scalar<int>()();
   1108     int num_dirs = 1;
   1109     if (rnn_direction_mode() == RnnDirectionMode::kRnnBidirectional) {
   1110       num_dirs = 2;
   1111     }
   1112     const int num_params_per_layer = num_params_ / num_layers / num_dirs;
    1113     // Number of params applied to the layer inputs. The rest are applied to
    1114     // the recurrent hidden states.
   1115     const int num_params_input_state = num_params_per_layer / 2;
   1116     CHECK(num_params_ % (num_layers * num_dirs) == 0)
   1117         << "Number of params is not a multiple of num_layers * num_dirs.";
   1118     CHECK(num_params_per_layer % 2 == 0)
   1119         << "Number of params per layer is not a even number.";
   1120 
   1121     CHECK(num_params_ == rnn_desc->ParamsWeightRegions().size())
   1122         << "Number of params mismatch. Expected " << num_params_ << ", got "
   1123         << rnn_desc->ParamsWeightRegions().size();
   1124     for (int i = 0; i < rnn_desc->ParamsWeightRegions().size(); i++) {
   1125       int64 size_in_bytes = rnn_desc->ParamsWeightRegions()[i].size;
   1126       int64 size = size_in_bytes / sizeof(T);
   1127       const int layer_idx = i / num_params_per_layer;
   1128       const int index_within_layer = i % num_params_per_layer;
   1129       int width = 0, height = num_units;
    1130       // In the Cudnn layout, each layer has num_params_per_layer params, with
    1131       // the first half (num_params_input_state params) applied to the layer
    1132       // inputs, and the second half to the recurrent hidden states.
   1133       bool apply_on_input_state = index_within_layer < num_params_input_state;
   1134       if (rnn_direction_mode() == RnnDirectionMode::kRnnUnidirectional) {
   1135         if (layer_idx == 0 && apply_on_input_state) {
   1136           width = input_size;
   1137         } else {
   1138           width = num_units;
   1139         }
   1140       } else {
   1141         if (apply_on_input_state) {
   1142           if (layer_idx <= 1) {
    1143             // First forward or backward layer.
   1144             width = input_size;
   1145           } else {
    1146             // For the following layers, the cell input is the concatenated
    1147             // output of the prior layer's two directions.
   1148             width = 2 * num_units;
   1149           }
   1150         } else {
   1151           width = num_units;
   1152         }
   1153       }
   1154       CHECK(size == width * height) << "Params size mismatch. Expected "
   1155                                     << width * height << ", got " << size;
   1156       Tensor* output = nullptr;
   1157       OP_REQUIRES_OK(context, context->allocate_output(
   1158                                   i, TensorShape({height, width}), &output));
   1159       DeviceMemoryBase data_src_ptr = SliceDeviceMemory(
   1160           input_ptr, rnn_desc->ParamsWeightRegions()[i].offset, size_in_bytes);
   1161       auto data_dst_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*output);
   1162       stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes);
   1163     }
   1164 
   1165     OP_REQUIRES(context, num_params_ == rnn_desc->ParamsBiasRegions().size(),
   1166                 errors::InvalidArgument("Number of params mismatch. Expected ",
   1167                                         num_params_, ", got ",
   1168                                         rnn_desc->ParamsBiasRegions().size()));
   1169     for (int i = 0; i < rnn_desc->ParamsBiasRegions().size(); i++) {
   1170       int64 size_in_bytes = rnn_desc->ParamsBiasRegions()[i].size;
   1171       int64 size = size_in_bytes / sizeof(T);
   1172       OP_REQUIRES(context, size == num_units,
   1173                   errors::InvalidArgument("Params size mismatch. Expected ",
   1174                                           num_units, ", got ", size));
   1175 
   1176       Tensor* output = nullptr;
   1177       OP_REQUIRES_OK(context,
   1178                      context->allocate_output(num_params_ + i,
   1179                                               TensorShape({size}), &output));
   1180       DeviceMemoryBase data_src_ptr = SliceDeviceMemory(
   1181           input_ptr, rnn_desc->ParamsBiasRegions()[i].offset, size_in_bytes);
   1182       auto data_dst_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*output);
   1183       stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes);
   1184     }
   1185   }
   1186 
   1187  private:
   1188   int num_params_;
   1189 };
   1190 
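        // num_layers, num_units and input_size are scalar inputs read on the host
        // to size the opaque parameter buffer, hence the HostMemory constraints
        // below.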
   1191 #define REGISTER_GPU(T)                                     \
   1192   REGISTER_KERNEL_BUILDER(Name("CudnnRNNParamsToCanonical") \
   1193                               .Device(DEVICE_GPU)           \
   1194                               .HostMemory("num_layers")     \
   1195                               .HostMemory("num_units")      \
   1196                               .HostMemory("input_size")     \
   1197                               .TypeConstraint<T>("T"),      \
   1198                           CudnnRNNParamsToCanonical<GPUDevice, T>);
   1199 TF_CALL_half(REGISTER_GPU);
   1200 TF_CALL_float(REGISTER_GPU);
   1201 TF_CALL_double(REGISTER_GPU);
   1202 #undef REGISTER_GPU
   1203 
   1204 // Convert weight and bias params from the canonical form to a
   1205 // platform-specific layout.
   1206 template <typename T>
   1207 class CudnnRNNCanonicalToParams<GPUDevice, T> : public CudnnRNNKernelCommon {
   1208  public:
   1209   explicit CudnnRNNCanonicalToParams(OpKernelConstruction* context)
   1210       : CudnnRNNKernelCommon(context) {}
   1211 
   1212   void Compute(OpKernelContext* context) override {
   1213     std::unique_ptr<RnnDescriptor> rnn_desc;
   1214     OP_REQUIRES_OK(context, ExtractCudnnRNNParamsInfo<T>(context, &rnn_desc));
   1215     int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes();
   1216     CHECK(params_size_in_bytes % sizeof(T) == 0)
   1217         << "params_size_in_bytes must be multiple of element size";
   1218     Tensor* output = nullptr;
   1219     int params_size = params_size_in_bytes / sizeof(T);
   1220     OP_REQUIRES_OK(context,
   1221                    context->allocate_output(0, {params_size}, &output));
   1222     auto output_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*output);
   1223     Stream* stream = context->op_device_context()->stream();
   1224 
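            // RestoreParams copies each canonical weight/bias tensor into its
            // corresponding region (by offset) of the opaque cuDNN parameter buffer.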
   1225     OpInputList weights;
   1226     OP_REQUIRES_OK(context, context->input_list("weights", &weights));
   1227     RestoreParams<T>(weights, rnn_desc->ParamsWeightRegions(), &output_ptr,
   1228                      stream);
   1229 
   1230     OpInputList biases;
   1231     OP_REQUIRES_OK(context, context->input_list("biases", &biases));
   1232     RestoreParams<T>(biases, rnn_desc->ParamsBiasRegions(), &output_ptr,
   1233                      stream);
   1234   }
   1235 };
   1236 
   1237 #define REGISTER_GPU(T)                                     \
   1238   REGISTER_KERNEL_BUILDER(Name("CudnnRNNCanonicalToParams") \
   1239                               .Device(DEVICE_GPU)           \
   1240                               .HostMemory("num_layers")     \
   1241                               .HostMemory("num_units")      \
   1242                               .HostMemory("input_size")     \
   1243                               .TypeConstraint<T>("T"),      \
   1244                           CudnnRNNCanonicalToParams<GPUDevice, T>);
   1245 TF_CALL_half(REGISTER_GPU);
   1246 TF_CALL_float(REGISTER_GPU);
   1247 TF_CALL_double(REGISTER_GPU);
   1248 #undef REGISTER_GPU
   1249 
   1250 // Run the forward operation of the RNN model.
   1251 template <typename T>
   1252 class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
   1253  public:
   1254   explicit CudnnRNNForwardOp(OpKernelConstruction* context)
   1255       : CudnnRNNKernelCommon(context) {
   1256     OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
   1257 
   1258     // Read debug env variables.
   1259     is_debug_mode_ = DebugCudnnRnn();
   1260     debug_cudnn_rnn_algo_ = DebugCudnnRnnAlgo();
   1261     debug_use_tensor_ops_ = DebugCudnnRnnUseTensorOps();
   1262   }
   1263 
   1264   void Compute(OpKernelContext* context) override {
   1265     AlgorithmConfig algo_config;
   1266     ComputeAndReturnAlgorithm(context, &algo_config, /*var_seq_lengths=*/false,
   1267                               /*time_major=*/true);
   1268   }
   1269 
   1270  protected:
   1271   virtual void ComputeAndReturnAlgorithm(OpKernelContext* context,
   1272                                          AlgorithmConfig* output_algo_config,
   1273                                          bool var_seq_lengths,
   1274                                          bool time_major) {
   1275     CHECK_NE(output_algo_config, nullptr);
   1276 
   1277     const Tensor* input = nullptr;
   1278     const Tensor* input_h = nullptr;
   1279     const Tensor* input_c = nullptr;
   1280     const Tensor* params = nullptr;
   1281     const Tensor* sequence_lengths = nullptr;
   1282     CudnnRnnModelShapes model_shapes;
   1283     if (var_seq_lengths) {
   1284       OP_REQUIRES_OK(context,
   1285                      ExtractForwardInput(context, model_types(), time_major,
   1286                                          &input, &input_h, &input_c, &params,
   1287                                          &sequence_lengths, &model_shapes));
   1288     } else {
   1289       OP_REQUIRES_OK(context, ExtractForwardInput(
   1290                                   context, model_types(), time_major, &input,
   1291                                   &input_h, &input_c, &params, &model_shapes));
   1292     }
   1293     RnnInputMode input_mode;
   1294     OP_REQUIRES_OK(context,
   1295                    ToRNNInputMode(rnn_input_mode(), model_shapes.num_units,
   1296                                   model_shapes.input_size, &input_mode));
   1297 
   1298     Tensor* output = nullptr;
   1299     Tensor* output_h = nullptr;
   1300     Tensor* output_c = nullptr;
   1301     OP_REQUIRES_OK(context, AllocateOutputs(context, model_shapes, &output,
   1302                                             &output_h, &output_c));
   1303 
   1304     // Creates a memory callback for the reserve_space. The memory lives in
   1305     // output 3 of this kernel and will be fed into the backward pass when
   1306     // needed.
   1307     CudnnRnnAllocatorInOutput<T> reserve_space_allocator(context, 3);
   1308     // Creates a memory callback for the workspace. The memory only lives for
   1309     // the duration of this kernel call.
   1310     CudnnRnnAllocatorInTemp<uint8> workspace_allocator(context);
   1311 
   1312     if (is_debug_mode_) {
   1313       AlgorithmDesc algo_desc(debug_cudnn_rnn_algo_, debug_use_tensor_ops_);
   1314       output_algo_config->set_algorithm(algo_desc);
   1315     } else {
   1316       OP_REQUIRES_OK(context,
   1317                      MaybeAutoTune(context, model_shapes, input_mode, input,
   1318                                    input_h, input_c, params, output, output_h,
   1319                                    output_c, output_algo_config));
   1320     }
   1321 
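            // The RnnDescriptor is cached in rnn_state_cache_ and reused across
            // steps; mu_ serializes access to the cache.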
   1322     Status launch_status;
   1323     {
   1324       mutex_lock l(mu_);
   1325       RnnDescriptor* rnn_desc_ptr = nullptr;
   1326       OP_REQUIRES_OK(
   1327           context, GetCachedRnnDescriptor<T>(context, model_shapes, input_mode,
   1328                                              *output_algo_config,
   1329                                              &rnn_state_cache_, &rnn_desc_ptr));
   1330       launch_status = DoForward<T>(
   1331           context, *rnn_desc_ptr, model_types(), model_shapes, input, input_h,
   1332           input_c, params, is_training_, output, output_h, output_c,
   1333           sequence_lengths, time_major, &reserve_space_allocator,
   1334           &workspace_allocator, /*output_profile_result=*/nullptr);
   1335     }
   1336     OP_REQUIRES_OK(context, launch_status);
   1337   }
   1338 
   1339  protected:
   1340   virtual Status MaybeAutoTune(OpKernelContext* context,
   1341                                const CudnnRnnModelShapes& model_shapes,
   1342                                const RnnInputMode& input_mode,
   1343                                const Tensor* input, const Tensor* input_h,
   1344                                const Tensor* input_c, const Tensor* params,
   1345                                Tensor* output, Tensor* output_h,
   1346                                Tensor* output_c,
   1347                                AlgorithmConfig* best_algo_config) {
   1348     CHECK_NE(best_algo_config, nullptr);
   1349     *best_algo_config = AlgorithmConfig();
   1350     return Status::OK();
   1351   }
   1352 
   1353   bool is_training() const { return is_training_; }
   1354   bool is_debug_mode_;
   1355   bool debug_use_tensor_ops_;
   1356   int64 debug_cudnn_rnn_algo_;
   1357 
   1358  private:
   1359   Status AllocateOutputs(OpKernelContext* context,
   1360                          const CudnnRnnModelShapes& model_shapes,
   1361                          Tensor** output, Tensor** output_h,
   1362                          Tensor** output_c) {
   1363     const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
   1364     const TensorShape& output_shape = model_shapes.output_shape;
   1365 
   1366     TF_RETURN_IF_ERROR(context->allocate_output(0, output_shape, output));
   1367     TF_RETURN_IF_ERROR(
   1368         context->allocate_output(1, hidden_state_shape, output_h));
   1369     if (HasInputC()) {
   1370       TF_RETURN_IF_ERROR(
   1371           context->allocate_output(2, hidden_state_shape, output_c));
   1372     } else {
   1373       // Only LSTM uses input_c and output_c. So for all other models, we only
   1374       // need to create dummy outputs.
   1375       TF_RETURN_IF_ERROR(context->allocate_output(2, {}, output_c));
   1376     }
   1377     if (!is_training_) {
   1378       Tensor* dummy_reserve_space = nullptr;
   1379       TF_RETURN_IF_ERROR(context->allocate_output(3, {}, &dummy_reserve_space));
   1380     }
   1381     return Status::OK();
   1382   }
   1383 
   1384   mutex mu_;
   1385   bool is_training_;
   1386   RnnStateCache rnn_state_cache_ GUARDED_BY(mu_);
   1387 };
   1388 
   1389 #define REGISTER_GPU(T)                                           \
   1390   REGISTER_KERNEL_BUILDER(                                        \
   1391       Name("CudnnRNN").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
   1392       CudnnRNNForwardOp<GPUDevice, T>);
   1393 
   1394 TF_CALL_half(REGISTER_GPU);
   1395 TF_CALL_float(REGISTER_GPU);
   1396 TF_CALL_double(REGISTER_GPU);
   1397 #undef REGISTER_GPU
   1398 
   1399 template <typename T>
   1400 class CudnnRNNForwardOpV2<GPUDevice, T>
   1401     : public CudnnRNNForwardOp<GPUDevice, T> {
   1402  private:
   1403   using CudnnRNNForwardOp<GPUDevice, T>::is_training;
   1404   using CudnnRNNKernelCommon::CreateRnnDescriptor;
   1405   using CudnnRNNKernelCommon::dropout;
   1406   using CudnnRNNKernelCommon::HasInputC;
   1407   using CudnnRNNKernelCommon::model_types;
   1408 
   1409  public:
   1410   explicit CudnnRNNForwardOpV2(OpKernelConstruction* context)
   1411       : CudnnRNNForwardOp<GPUDevice, T>(context) {}
   1412 
   1413   void Compute(OpKernelContext* context) override {
   1414     AlgorithmConfig best_algo_config;
   1415     CudnnRNNForwardOp<GPUDevice, T>::ComputeAndReturnAlgorithm(
   1416         context, &best_algo_config, /*var_seq_lengths=*/false,
   1417         /*time_major=*/true);
   1418     if (!context->status().ok()) {
   1419       return;
   1420     }
   1421 
   1422     Tensor* output_host_reserved = nullptr;
   1423     // output_host_reserved stores opaque info used for backprop when running
   1424     // in training mode. At present it is a serialization of the best
   1425     // AlgorithmDesc picked during the RNN forward-pass autotune:
   1426     //   int8 algorithm_id
   1427     //   int8 use_tensor_op
   1428     // If autotune is not enabled, algorithm_id is
   1429     // stream_executor::dnn::kDefaultAlgorithm and use_tensor_op is false. If
   1430     // running in inference mode, output_host_reserved is currently left
   1431     // unpopulated.
   1432     if (is_training()) {
   1433       OP_REQUIRES_OK(context, context->allocate_output(4, TensorShape({2}),
   1434                                                        &output_host_reserved));
   1435       auto output_host_reserved_int8 = output_host_reserved->vec<int8>();
   1436       output_host_reserved_int8(0) = best_algo_config.algorithm()->algo_id();
   1437       output_host_reserved_int8(1) =
   1438           best_algo_config.algorithm()->tensor_ops_enabled();
   1439     } else {
   1440       OP_REQUIRES_OK(context,
   1441                      context->allocate_output(4, {}, &output_host_reserved));
   1442     }
   1443   }
   1444 
   1445  protected:
   1446   Status MaybeAutoTune(OpKernelContext* context,
   1447                        const CudnnRnnModelShapes& model_shapes,
   1448                        const RnnInputMode& input_mode, const Tensor* input,
   1449                        const Tensor* input_h, const Tensor* input_c,
   1450                        const Tensor* params, Tensor* output, Tensor* output_h,
   1451                        Tensor* output_c,
   1452                        AlgorithmConfig* algo_config) override {
   1453     CHECK_NE(algo_config, nullptr);
   1454     if (!CudnnRnnUseAutotune() || this->is_debug_mode_) {
   1455       *algo_config = AlgorithmConfig();
   1456       return Status::OK();
   1457     }
   1458 
   1459     std::vector<AlgorithmDesc> algorithms;
   1460     auto* stream = context->op_device_context()->stream();
   1461     CHECK(stream->parent()->GetRnnAlgorithms(&algorithms));
   1462     if (algorithms.empty()) {
   1463       LOG(WARNING) << "No Rnn algorithm found";
   1464       return Status::OK();
   1465     }
   1466 
   1467     const auto& modeltypes = model_types();
   1468     CudnnRnnParameters rnn_params(
   1469         model_shapes.num_layers, model_shapes.input_size,
   1470         model_shapes.num_units, model_shapes.max_seq_length,
   1471         model_shapes.batch_size, model_shapes.dir_count,
   1472         /*has_dropout=*/std::abs(dropout()) > 1e-8, is_training(),
   1473         modeltypes.rnn_mode, modeltypes.rnn_input_mode, input->dtype());
   1474 
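            // Autotune results are memoized per RNN configuration (rnn_params), so
            // the profiling below only runs the first time a configuration is seen.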
   1475     if (AutoTuneRnnConfigMap::GetInstance()->Find(rnn_params, algo_config)) {
   1476       VLOG(1) << "Using existing best Cudnn RNN algorithm "
   1477               << "(algo, tensor_op_enabled) = ("
   1478               << algo_config->algorithm()->algo_id() << ", "
   1479               << algo_config->algorithm()->tensor_ops_enabled() << ").";
   1480       return Status::OK();
   1481     }
   1482 
   1483     // Create temp tensors used when profiling the backprop pass.
   1484     auto data_type = input->dtype();
   1485     Tensor output_backprop;
   1486     Tensor output_h_backprop;
   1487     Tensor output_c_backprop;
   1488     Tensor input_backprop;
   1489     Tensor input_h_backprop;
   1490     Tensor input_c_backprop;
   1491     Tensor params_backprop;
   1492     if (is_training()) {
   1493       TF_RETURN_IF_ERROR(context->allocate_temp(
   1494           data_type, model_shapes.output_shape, &output_backprop));
   1495       TF_RETURN_IF_ERROR(context->allocate_temp(
   1496           data_type, model_shapes.hidden_state_shape, &output_h_backprop));
   1497 
   1498       TF_RETURN_IF_ERROR(
   1499           context->allocate_temp(data_type, params->shape(), &params_backprop));
   1500       TF_RETURN_IF_ERROR(context->allocate_temp(
   1501           data_type, model_shapes.input_shape, &input_backprop));
   1502       TF_RETURN_IF_ERROR(context->allocate_temp(
   1503           data_type, model_shapes.hidden_state_shape, &input_h_backprop));
   1504       if (HasInputC()) {
   1505         TF_RETURN_IF_ERROR(context->allocate_temp(
   1506             data_type, model_shapes.hidden_state_shape, &output_c_backprop));
   1507         TF_RETURN_IF_ERROR(context->allocate_temp(
   1508             data_type, model_shapes.hidden_state_shape, &input_c_backprop));
   1509       }
   1510     }
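            // Profile each candidate algorithm end to end: always a forward pass,
            // plus a backward pass when training. The algorithm with the lowest
            // total elapsed time wins.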
   1511     ProfileResult best_result;
   1512     for (auto& algo : algorithms) {
   1513       VLOG(1) << "Profile Cudnn RNN algorithm (algo, tensor_op_enabled) =  ("
   1514               << algo.algo_id() << ", " << algo.tensor_ops_enabled() << ").";
   1515       Status status;
   1516       ProfileResult final_profile_result;
   1517 
   1518       ProfileResult fwd_profile_result;
   1519       ProfileResult bak_profile_result;
   1520 
   1521       // RnnDescriptor is algorithm-dependent, thus not reusable.
   1522       std::unique_ptr<RnnDescriptor> rnn_desc;
   1523       // Use a temp scratch allocator for the random number generator.
   1524       CudnnRnnAllocatorInTemp<uint8> dropout_state_allocator(context);
   1525       if (!this->template CreateRnnDescriptor<T>(
   1526                    context, model_shapes, input_mode, AlgorithmConfig(algo),
   1527                    &dropout_state_allocator, &rnn_desc)
   1528                .ok()) {
   1529         continue;
   1530       }
   1531 
   1532       // Again, use a temp scratch allocator during profiling.
   1533       CudnnRnnAllocatorInTemp<T> reserve_space_allocator(context);
   1534       CudnnRnnAllocatorInTemp<uint8> workspace_allocator(context);
   1535       status = DoForward<T>(context, *rnn_desc, model_types(), model_shapes,
   1536                             input, input_h, input_c, params, is_training(),
   1537                             output, output_h, output_c, nullptr, true,
   1538                             &reserve_space_allocator, &workspace_allocator,
   1539                             &fwd_profile_result);
   1540       if (!status.ok()) {
   1541         continue;
   1542       }
   1543 
   1544       if (is_training()) {
   1545         // Get reserve space from the forward pass.
   1546         Tensor reserve_space = reserve_space_allocator.get_allocated_tensor(0);
   1547         status = DoBackward<T>(
   1548             context, *rnn_desc, model_types(), model_shapes, input, input_h,
   1549             input_c, params, output, output_h, output_c, &output_backprop,
   1550             &output_h_backprop, &output_c_backprop, &reserve_space,
   1551             &input_backprop, &input_h_backprop, &input_c_backprop,
   1552             &params_backprop, nullptr, true, &workspace_allocator,
   1553             &bak_profile_result);
   1554         if (!status.ok()) {
   1555           continue;
   1556         }
   1557         final_profile_result.set_elapsed_time_in_ms(
   1558             fwd_profile_result.elapsed_time_in_ms() +
   1559             bak_profile_result.elapsed_time_in_ms());
   1560       } else {
   1561         final_profile_result = fwd_profile_result;
   1562       }
   1563 
   1564       auto total_time = final_profile_result.elapsed_time_in_ms();
   1565       VLOG(1) << "Cudnn RNN algorithm (algo, tensor_op_enabled) =  ("
   1566               << algo.algo_id() << ", " << algo.tensor_ops_enabled() << ")"
   1567               << " run time: " << total_time << " ms.";
   1568       if (total_time < best_result.elapsed_time_in_ms()) {
   1569         best_result.set_elapsed_time_in_ms(total_time);
   1570         best_result.set_algorithm(algo);
   1571       }
   1572     }
   1573 
   1574     if (!best_result.is_valid()) {
   1575       return Status(error::Code::INTERNAL, "No algorithm worked!");
   1576     }
   1577     algo_config->set_algorithm(best_result.algorithm());
   1578     VLOG(1) << "Best Cudnn RNN algorithm (algo, tensor_op_enabled) =  ("
   1579             << best_result.algorithm().algo_id() << ", "
   1580             << best_result.algorithm().tensor_ops_enabled() << ").";
   1581     AutoTuneRnnConfigMap::GetInstance()->Insert(rnn_params, *algo_config);
   1582     return Status::OK();
   1583   }
   1584 };
   1585 
   1586 #define REGISTER_GPU(T)                                    \
   1587   REGISTER_KERNEL_BUILDER(Name("CudnnRNNV2")               \
   1588                               .Device(DEVICE_GPU)          \
   1589                               .HostMemory("host_reserved") \
   1590                               .TypeConstraint<T>("T"),     \
   1591                           CudnnRNNForwardOpV2<GPUDevice, T>);
   1592 
   1593 TF_CALL_half(REGISTER_GPU);
   1594 TF_CALL_float(REGISTER_GPU);
   1595 TF_CALL_double(REGISTER_GPU);
   1596 #undef REGISTER_GPU
   1597 
   1598 template <typename T>
   1599 class CudnnRNNForwardOpV3<GPUDevice, T>
   1600     : public CudnnRNNForwardOp<GPUDevice, T> {
   1601  private:
   1602   using CudnnRNNForwardOp<GPUDevice, T>::is_training;
   1603   using CudnnRNNKernelCommon::CreateRnnDescriptor;
   1604   using CudnnRNNKernelCommon::dropout;
   1605   using CudnnRNNKernelCommon::HasInputC;
   1606   using CudnnRNNKernelCommon::model_types;
   1607   bool time_major_;
   1608 
   1609  protected:
   1610   bool time_major() { return time_major_; }
   1611 
   1612  public:
   1613   explicit CudnnRNNForwardOpV3(OpKernelConstruction* context)
   1614       : CudnnRNNForwardOp<GPUDevice, T>(context) {
   1615     OP_REQUIRES_OK(context, context->GetAttr("time_major", &time_major_));
   1616   }
   1617 
   1618   void Compute(OpKernelContext* context) override {
   1619     AlgorithmConfig best_algo_config;
   1620     CudnnRNNForwardOp<GPUDevice, T>::ComputeAndReturnAlgorithm(
   1621         context, &best_algo_config, /*var_seq_lengths=*/true,
   1622         /*time_major=*/time_major());
   1623     if (!context->status().ok()) {
   1624       return;
   1625     }
   1626 
   1627     Tensor* output_host_reserved = nullptr;
   1628     // TODO: The V3 op currently uses only the default (standard) algorithm
   1629     // to process batches with variable sequence lengths, and the inputs must
   1630     // be padded. Autotune is not supported yet.
   1631     OP_REQUIRES_OK(context,
   1632                    context->allocate_output(4, {}, &output_host_reserved));
   1633   }
   1634 };
   1635 
   1636 #define REGISTER_GPU(T)                                       \
   1637   REGISTER_KERNEL_BUILDER(Name("CudnnRNNV3")                  \
   1638                               .Device(DEVICE_GPU)             \
   1639                               .HostMemory("sequence_lengths") \
   1640                               .HostMemory("host_reserved")    \
   1641                               .TypeConstraint<T>("T"),        \
   1642                           CudnnRNNForwardOpV3<GPUDevice, T>);
   1643 
   1644 TF_CALL_half(REGISTER_GPU);
   1645 TF_CALL_float(REGISTER_GPU);
   1646 TF_CALL_double(REGISTER_GPU);
   1647 #undef REGISTER_GPU
   1648 
   1649 // Run the backward operation of the RNN model.
   1650 template <typename T>
   1651 class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
   1652  public:
   1653   explicit CudnnRNNBackwardOp(OpKernelConstruction* context)
   1654       : CudnnRNNKernelCommon(context) {}
   1655 
   1656   void Compute(OpKernelContext* context) override {
   1657     ComputeImpl(context, false, true);
   1658   }
   1659 
   1660  protected:
   1661   virtual void ComputeImpl(OpKernelContext* context, bool var_seq_lengths,
   1662                            bool time_major) {
   1663     const Tensor* input = nullptr;
   1664     const Tensor* input_h = nullptr;
   1665     const Tensor* input_c = nullptr;
   1666     const Tensor* params = nullptr;
   1667     const Tensor* sequence_lengths = nullptr;
   1668     CudnnRnnModelShapes model_shapes;
   1669     if (var_seq_lengths) {
   1670       OP_REQUIRES_OK(context,
   1671                      ExtractForwardInput(context, model_types(), time_major,
   1672                                          &input, &input_h, &input_c, &params,
   1673                                          &sequence_lengths, &model_shapes));
   1674     } else {
   1675       OP_REQUIRES_OK(context, ExtractForwardInput(
   1676                                   context, model_types(), time_major, &input,
   1677                                   &input_h, &input_c, &params, &model_shapes));
   1678     }
   1679     RnnInputMode input_mode;
   1680     OP_REQUIRES_OK(context,
   1681                    ToRNNInputMode(rnn_input_mode(), model_shapes.num_units,
   1682                                   model_shapes.input_size, &input_mode));
   1683 
   1684     const Tensor* output = nullptr;
   1685     const Tensor* output_h = nullptr;
   1686     const Tensor* output_c = nullptr;
   1687     const Tensor* output_backprop = nullptr;
   1688     const Tensor* output_h_backprop = nullptr;
   1689     const Tensor* output_c_backprop = nullptr;
   1690     const Tensor* reserve_space = nullptr;
   1691     OP_REQUIRES_OK(context,
   1692                    ExtractBackwardInputs(context, model_shapes, model_types(),
   1693                                          &output, &output_h, &output_c,
   1694                                          &output_backprop, &output_h_backprop,
   1695                                          &output_c_backprop, &reserve_space));
   1696 
   1697     Tensor* input_backprop = nullptr;
   1698     Tensor* input_h_backprop = nullptr;
   1699     Tensor* input_c_backprop = nullptr;
   1700     Tensor* params_backprop = nullptr;
   1701     OP_REQUIRES_OK(context,
   1702                    AllocateOutputs(context, model_shapes, params->shape(),
   1703                                    &input_backprop, &input_h_backprop,
   1704                                    &input_c_backprop, &params_backprop));
   1705 
   1706     // Creates a memory callback for the workspace. The memory only lives for
   1707     // the duration of this kernel call.
   1708     CudnnRnnAllocatorInTemp<uint8> workspace_allocator(context);
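            // GetAlgorithm is overridden by the V2 backward op to replay the
            // algorithm recorded by the forward pass; the base implementation
            // falls back to the default AlgorithmConfig.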
   1709     AlgorithmConfig algo_config;
   1710     OP_REQUIRES_OK(context, GetAlgorithm(context, &algo_config));
   1711     Status launch_status;
   1712     {
   1713       mutex_lock l(mu_);
   1714       RnnDescriptor* rnn_desc_ptr = nullptr;
   1715       OP_REQUIRES_OK(
   1716           context, GetCachedRnnDescriptor<T>(context, model_shapes, input_mode,
   1717                                              algo_config, &rnn_state_cache_,
   1718                                              &rnn_desc_ptr));
   1719       launch_status = DoBackward<T>(
   1720           context, *rnn_desc_ptr, model_types(), model_shapes, input, input_h,
   1721           input_c, params, output, output_h, output_c, output_backprop,
   1722           output_h_backprop, output_c_backprop, reserve_space, input_backprop,
   1723           input_h_backprop, input_c_backprop, params_backprop, sequence_lengths,
   1724           time_major, &workspace_allocator,
   1725           /*output_profile_result=*/nullptr);
   1726     }
   1727     OP_REQUIRES_OK(context, launch_status);
   1728   }
   1729 
   1730  protected:
   1731   virtual Status GetAlgorithm(OpKernelContext* context,
   1732                               AlgorithmConfig* algo_config) {
   1733     CHECK_NE(algo_config, nullptr);
   1734     *algo_config = AlgorithmConfig();
   1735     return Status::OK();
   1736   }
   1737 
   1738  private:
   1739   mutex mu_;
   1740   RnnStateCache rnn_state_cache_ GUARDED_BY(mu_);
   1741 
   1742   Status ExtractBackwardInputs(
   1743       OpKernelContext* context, const CudnnRnnModelShapes& model_shapes,
   1744       const CudnnModelTypes& model_types, const Tensor** output,
   1745       const Tensor** output_h, const Tensor** output_c,
   1746       const Tensor** output_backprop, const Tensor** output_h_backprop,
   1747       const Tensor** output_c_backprop, const Tensor** reserve_space) {
   1748     TF_RETURN_IF_ERROR(context->input("output", output));
   1749     TF_RETURN_IF_ERROR(context->input("output_backprop", output_backprop));
   1750     TF_RETURN_IF_ERROR(context->input("output_h", output_h));
   1751     TF_RETURN_IF_ERROR(context->input("output_h_backprop", output_h_backprop));
   1752     if (model_types.HasInputC()) {
   1753       TF_RETURN_IF_ERROR(context->input("output_c", output_c));
   1754       TF_RETURN_IF_ERROR(
   1755           context->input("output_c_backprop", output_c_backprop));
   1756     }
   1757     TF_RETURN_IF_ERROR(context->input("reserve_space", reserve_space));
   1758     const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
   1759     const TensorShape& output_shape = model_shapes.output_shape;
   1760 
   1761     if (output_shape != (*output)->shape()) {
   1762       return errors::InvalidArgument(
   1763           "Invalid output shape: ", (*output)->shape().DebugString(), " ",
   1764           output_shape.DebugString());
   1765     }
   1766     if (hidden_state_shape != (*output_h)->shape()) {
   1767       return errors::InvalidArgument(
   1768           "Invalid output_h shape: ", (*output_h)->shape().DebugString(), " ",
   1769           hidden_state_shape.DebugString());
   1770     }
   1771 
   1772     if (output_shape != (*output_backprop)->shape()) {
   1773       return errors::InvalidArgument("Invalid output_backprop shape: ",
   1774                                      (*output_backprop)->shape().DebugString(),
   1775                                      " ", output_shape.DebugString());
   1776     }
   1777     if (hidden_state_shape != (*output_h_backprop)->shape()) {
   1778       return errors::InvalidArgument(
   1779           "Invalid output_h_backprop shape: ",
   1780           (*output_h_backprop)->shape().DebugString(), " ",
   1781           hidden_state_shape.DebugString());
   1782     }
   1783 
   1784     if (model_types.HasInputC()) {
   1785       if (hidden_state_shape != (*output_c)->shape()) {
   1786         return errors::InvalidArgument(
   1787             "Invalid output_c shape: ", (*output_c)->shape().DebugString(), " ",
   1788             hidden_state_shape.DebugString());
   1789       }
   1790       if (hidden_state_shape != (*output_c_backprop)->shape()) {
   1791         return errors::InvalidArgument(
   1792             "Invalid output_c_backprop shape: ",
   1793             (*output_c_backprop)->shape().DebugString(), " ",
   1794             hidden_state_shape.DebugString());
   1795       }
   1796     }
   1797     return Status::OK();
   1798   }
   1799 
   1800   Status AllocateOutputs(OpKernelContext* context,
   1801                          const CudnnRnnModelShapes& model_shapes,
   1802                          const TensorShape& params_shape,
   1803                          Tensor** input_backprop, Tensor** input_h_backprop,
   1804                          Tensor** input_c_backprop, Tensor** params_backprop) {
   1805     const TensorShape& input_shape = model_shapes.input_shape;
   1806     const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
   1807 
   1808     TF_RETURN_IF_ERROR(
   1809         context->allocate_output(0, input_shape, input_backprop));
   1810     TF_RETURN_IF_ERROR(
   1811         context->allocate_output(1, hidden_state_shape, input_h_backprop));
   1812     if (HasInputC()) {
   1813       TF_RETURN_IF_ERROR(
   1814           context->allocate_output(2, hidden_state_shape, input_c_backprop));
   1815     } else {
   1816       // Only LSTM uses input_c and output_c. So for all other models, we only
   1817       // need to create dummy outputs.
   1818       TF_RETURN_IF_ERROR(context->allocate_output(2, {}, input_c_backprop));
   1819     }
   1820     TF_RETURN_IF_ERROR(
   1821         context->allocate_output(3, params_shape, params_backprop));
   1822     return Status::OK();
   1823   }
   1824 };
   1825 
   1826 #define REGISTER_GPU(T)                                                   \
   1827   REGISTER_KERNEL_BUILDER(                                                \
   1828       Name("CudnnRNNBackprop").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
   1829       CudnnRNNBackwardOp<GPUDevice, T>);
   1830 
   1831 TF_CALL_half(REGISTER_GPU);
   1832 TF_CALL_float(REGISTER_GPU);
   1833 TF_CALL_double(REGISTER_GPU);
   1834 #undef REGISTER_GPU
   1835 
   1836 template <typename T>
   1837 class CudnnRNNBackwardOpV2<GPUDevice, T>
   1838     : public CudnnRNNBackwardOp<GPUDevice, T> {
   1839  public:
   1840   explicit CudnnRNNBackwardOpV2(OpKernelConstruction* context)
   1841       : CudnnRNNBackwardOp<GPUDevice, T>(context) {}
   1842 
   1843  protected:
   1844   Status GetAlgorithm(OpKernelContext* context,
   1845                       AlgorithmConfig* algo_config) override {
   1846     CHECK_NE(algo_config, nullptr);
   1847     const Tensor* host_reserved = nullptr;
   1848     TF_RETURN_IF_ERROR(context->input("host_reserved", &host_reserved));
   1849 
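            // host_reserved carries the (algorithm_id, tensor_ops_enabled) pair
            // produced by the CudnnRNNV2 forward op; rebuild the AlgorithmDesc
            // from it.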
   1850     auto host_reserved_int8 = host_reserved->vec<int8>();
   1851     const AlgorithmDesc algo_desc(host_reserved_int8(0), host_reserved_int8(1));
   1852     algo_config->set_algorithm(algo_desc);
   1853     return Status::OK();
   1854   }
   1855 };
   1856 
   1857 #define REGISTER_GPU(T)                                    \
   1858   REGISTER_KERNEL_BUILDER(Name("CudnnRNNBackpropV2")       \
   1859                               .Device(DEVICE_GPU)          \
   1860                               .HostMemory("host_reserved") \
   1861                               .TypeConstraint<T>("T"),     \
   1862                           CudnnRNNBackwardOpV2<GPUDevice, T>);
   1863 
   1864 TF_CALL_half(REGISTER_GPU);
   1865 TF_CALL_float(REGISTER_GPU);
   1866 TF_CALL_double(REGISTER_GPU);
   1867 #undef REGISTER_GPU
   1868 
   1869 template <typename T>
   1870 class CudnnRNNBackwardOpV3<GPUDevice, T>
   1871     : public CudnnRNNBackwardOp<GPUDevice, T> {
   1872  private:
   1873   bool time_major_;
   1874 
   1875  protected:
   1876   bool time_major() { return time_major_; }
   1877 
   1878  public:
   1879   explicit CudnnRNNBackwardOpV3(OpKernelConstruction* context)
   1880       : CudnnRNNBackwardOp<GPUDevice, T>(context) {
   1881     OP_REQUIRES_OK(context, context->GetAttr("time_major", &time_major_));
   1882   }
   1883 
   1884   void Compute(OpKernelContext* context) override {
   1885     CudnnRNNBackwardOp<GPUDevice, T>::ComputeImpl(context, true, time_major());
   1886   }
   1887 };
   1888 
   1889 #define REGISTER_GPU(T)                                       \
   1890   REGISTER_KERNEL_BUILDER(Name("CudnnRNNBackpropV3")          \
   1891                               .Device(DEVICE_GPU)             \
   1892                               .HostMemory("sequence_lengths") \
   1893                               .HostMemory("host_reserved")    \
   1894                               .TypeConstraint<T>("T"),        \
   1895                           CudnnRNNBackwardOpV3<GPUDevice, T>);
   1896 
   1897 TF_CALL_half(REGISTER_GPU);
   1898 TF_CALL_float(REGISTER_GPU);
   1899 TF_CALL_double(REGISTER_GPU);
   1900 #undef REGISTER_GPU
   1901 
   1902 // TODO(zhengxq): Add the conversion of Cudnn RNN Params from and to
   1903 // its canonical form.
   1904 
   1905 #endif  // GOOGLE_CUDA
   1906 
   1907 }  // namespace tensorflow
   1908