      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/stream_executor/cuda/cuda_dnn.h"
     17 
     18 #include <functional>
     19 #include <memory>
     20 
     21 #include "third_party/eigen3/Eigen/Core"
     22 #include "tensorflow/core/util/env_var.h"
     23 #include "tensorflow/stream_executor/cuda/cuda_activation.h"
     24 #include "tensorflow/stream_executor/cuda/cuda_diagnostics.h"
     25 #include "tensorflow/stream_executor/cuda/cuda_driver.h"
     26 #include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
     27 #include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
     28 #include "tensorflow/stream_executor/cuda/cuda_stream.h"
     29 #include "tensorflow/stream_executor/cuda/cuda_timer.h"
     30 #include "tensorflow/stream_executor/dnn.h"
     31 #include "tensorflow/stream_executor/lib/env.h"
     32 #include "tensorflow/stream_executor/lib/error.h"
     33 #include "tensorflow/stream_executor/lib/initialize.h"
     34 #include "tensorflow/stream_executor/lib/strcat.h"
     35 #include "tensorflow/stream_executor/lib/stringpiece.h"
     36 #include "tensorflow/stream_executor/lib/threadpool.h"
     37 #include "tensorflow/stream_executor/platform/logging.h"
     38 #include "tensorflow/stream_executor/plugin_registry.h"
     39 #include "tensorflow/stream_executor/scratch_allocator.h"
     40 #include "tensorflow/stream_executor/stream.h"
     41 #include "tensorflow/stream_executor/stream_executor_pimpl.h"
     42 // clang-format off
     43 #include "cuda/include/cudnn.h"
     44 // clang-format on
     45 
     46 namespace {
     47 
     48 // Converts (via narrowing) a WideT value to a NarrowT value, and checks
     49 // that the value is unchanged by the conversion.
     50 template <typename WideT, typename NarrowT>
     51 NarrowT CheckedNarrowing(const WideT& wide) {
     52   NarrowT narrow = wide;
     53   CHECK_EQ(narrow, wide)
     54       << "checked narrowing failed; values not equal post-conversion";
     55   return narrow;
     56 }
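        // Example (illustration): CheckedNarrowing<int64, int>(v) returns v
        // unchanged, and CHECK-fails if v does not fit in an int; the
        // std::transform calls later in this file use it to convert stride and
        // dimension vectors for cuDNN.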
     57 
     58 // Returns the "Compatibility" version number from the CuDNN version number.
     59 // This is the number that tries to indicate ABI compatibility.
     60 //
     61 // For example, if cudnn_version is 5107, the compatibility version
     62 // number will be 5100.
     63 size_t cudnnCompatibilityVersion(size_t cudnn_version) {
     64   return (cudnn_version / 100) * 100;
     65 }
     66 
     67 }  // namespace
     68 
     69 namespace perftools {
     70 namespace gputools {
     71 
     72 using dnn::BatchDescriptor;
     73 using dnn::FilterDescriptor;
     74 using dnn::ConvolutionDescriptor;
     75 using dnn::PoolingDescriptor;
     76 using dnn::NormalizeDescriptor;
     77 
     78 namespace cuda {
     79 
     80 PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kCuDnnPlugin);
     81 
     82 string ToString(cudnnStatus_t status) {
     83   switch (status) {
     84     case CUDNN_STATUS_SUCCESS:
     85       return "CUDNN_STATUS_SUCCESS";
     86     case CUDNN_STATUS_NOT_INITIALIZED:
     87       return "CUDNN_STATUS_NOT_INITIALIZED";
     88     case CUDNN_STATUS_ALLOC_FAILED:
     89       return "CUDNN_STATUS_ALLOC_FAILED";
     90     case CUDNN_STATUS_BAD_PARAM:
     91       return "CUDNN_STATUS_BAD_PARAM";
     92     case CUDNN_STATUS_INTERNAL_ERROR:
     93       return "CUDNN_STATUS_INTERNAL_ERROR";
     94     case CUDNN_STATUS_INVALID_VALUE:
     95       return "CUDNN_STATUS_INVALID_VALUE";
     96     case CUDNN_STATUS_ARCH_MISMATCH:
     97       return "CUDNN_STATUS_ARCH_MISMATCH";
     98     case CUDNN_STATUS_MAPPING_ERROR:
     99       return "CUDNN_STATUS_MAPPING_ERROR";
    100     case CUDNN_STATUS_EXECUTION_FAILED:
    101       return "CUDNN_STATUS_EXECUTION_FAILED";
    102     case CUDNN_STATUS_NOT_SUPPORTED:
    103       return "CUDNN_STATUS_NOT_SUPPORTED";
    104     case CUDNN_STATUS_LICENSE_ERROR:
    105       return "CUDNN_STATUS_LICENSE_ERROR";
    106     default:
    107       return port::StrCat("<unknown cudnn status: ", static_cast<int>(status),
    108                           ">");
    109   }
    110 }
    111 
    112 template <typename T>
    113 cudnnDataType_t GetCudnnDataType();
    114 
    115 template <>
    116 cudnnDataType_t GetCudnnDataType<double>() {
    117   return CUDNN_DATA_DOUBLE;
    118 }
    119 
    120 template <>
    121 cudnnDataType_t GetCudnnDataType<float>() {
    122   return CUDNN_DATA_FLOAT;
    123 }
    124 
    125 template <>
    126 cudnnDataType_t GetCudnnDataType<Eigen::half>() {
    127   return CUDNN_DATA_HALF;
    128 }
    129 
    130 namespace wrap {
    131 
    132 static port::ThreadPool* InitCudnnThreadpool() {
    133   port::ThreadPool* cudnn_threadpool_;
    134   port::ThreadOptions options;
    135   // TBD(keveman): Conservatively setting the stack size and guard size to 2MB,
    136   // until we can get some guarantees from NVIDIA on the minimum stack space
    137   // they will work with.
    138   options.stack_size = 2 * 1024 * 1024;
    139   options.guard_size = 2 * 1024 * 1024;
    140   cudnn_threadpool_ = new port::ThreadPool(port::Env::Default(), options,
    141                                            "cudnn_threadpool", 1);
    142   CHECK(cudnn_threadpool_);
    143   return cudnn_threadpool_;
    144 }
    145 
    146 static mutex cudnn_threadpool_mu(LINKER_INITIALIZED);
    147 static port::ThreadPool* GetCudaThreadpool() {
    148   mutex_lock lock(cudnn_threadpool_mu);
    149   static port::ThreadPool* cudnn_threadpool = InitCudnnThreadpool();
    150   return cudnn_threadpool;
    151 }
    152 
    153 #define PERFTOOLS_GPUTOOLS_CUDNN_WRAP(__name)                      \
    154   struct WrapperShim__##__name {                                   \
    155     template <typename... Args>                                    \
    156     cudnnStatus_t operator()(CUDAExecutor* parent, Args... args) { \
    157       cuda::ScopedActivateExecutorContext sac{parent};             \
    158       cudnnStatus_t retval = ::__name(args...);                    \
    159       return retval;                                               \
    160     }                                                              \
    161   } __name;
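        // Illustration: PERFTOOLS_GPUTOOLS_CUDNN_WRAP(cudnnCreate) expands to
        // roughly
        //
        //   struct WrapperShim__cudnnCreate {
        //     template <typename... Args>
        //     cudnnStatus_t operator()(CUDAExecutor* parent, Args... args) {
        //       cuda::ScopedActivateExecutorContext sac{parent};
        //       return ::cudnnCreate(args...);
        //     }
        //   } cudnnCreate;
        //
        // so call sites invoke wrap::cudnnCreate(parent, ...) with the parent
        // executor's CUDA context activated for the duration of the call.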
    162 
    163 // clang-format off
    164 #define CUDNN_DNN_ROUTINE_EACH(__macro)                   \
    165   __macro(cudnnBatchNormalizationBackward)                \
    166   __macro(cudnnBatchNormalizationForwardInference)        \
    167   __macro(cudnnBatchNormalizationForwardTraining)         \
    168   __macro(cudnnGetConvolutionNdForwardOutputDim)          \
    169   __macro(cudnnGetConvolutionForwardAlgorithm)            \
    170   __macro(cudnnCreateTensorDescriptor)                    \
    171   __macro(cudnnDestroyTensorDescriptor)                   \
    172   __macro(cudnnCreateFilterDescriptor)                    \
    173   __macro(cudnnSetPoolingNdDescriptor)                    \
    174   __macro(cudnnSetLRNDescriptor)                          \
    175   __macro(cudnnDestroyFilterDescriptor)                   \
    176   __macro(cudnnCreateConvolutionDescriptor)               \
    177   __macro(cudnnCreatePoolingDescriptor)                   \
    178   __macro(cudnnDestroyPoolingDescriptor)                  \
    179   __macro(cudnnCreateLRNDescriptor)                       \
    180   __macro(cudnnDestroyLRNDescriptor)                      \
    181   __macro(cudnnDestroyConvolutionDescriptor)              \
    182   __macro(cudnnCreate)                                    \
    183   __macro(cudnnDestroy)                                   \
    184   __macro(cudnnSetStream)                                 \
    185   __macro(cudnnActivationForward)                         \
    186   __macro(cudnnConvolutionForward)                        \
    187   __macro(cudnnConvolutionBackwardBias)                   \
    188   __macro(cudnnGetConvolutionForwardWorkspaceSize)        \
    189   __macro(cudnnTransformTensor)                           \
    190   __macro(cudnnSetConvolutionNdDescriptor)                \
    191   __macro(cudnnSetTensor4dDescriptor)                     \
    192   __macro(cudnnSetTensorNdDescriptor)                     \
    193   __macro(cudnnSetFilterNdDescriptor)                     \
    194   __macro(cudnnPoolingForward)                            \
    195   __macro(cudnnPoolingBackward)                           \
    196   __macro(cudnnLRNCrossChannelForward)                    \
    197   __macro(cudnnLRNCrossChannelBackward)                   \
    198   __macro(cudnnAddTensor)                                 \
    199   __macro(cudnnConvolutionBackwardData)                   \
    200   __macro(cudnnConvolutionBackwardFilter)
    201 // clang-format on
    202 
    203 CUDNN_DNN_ROUTINE_EACH(PERFTOOLS_GPUTOOLS_CUDNN_WRAP)
    204 
    205 // APIs available after R3:
    206 #if CUDNN_VERSION >= 3000
    207 #define CUDNN_DNN_ROUTINE_EACH_AFTER_R3(__macro)              \
    208   __macro(cudnnGetConvolutionBackwardFilterWorkspaceSize)     \
    209   __macro(cudnnGetConvolutionBackwardDataAlgorithm)           \
    210   __macro(cudnnGetConvolutionBackwardFilterAlgorithm)         \
    211   __macro(cudnnGetConvolutionBackwardDataWorkspaceSize)
    212 CUDNN_DNN_ROUTINE_EACH_AFTER_R3(PERFTOOLS_GPUTOOLS_CUDNN_WRAP)
    213 #undef CUDNN_DNN_ROUTINE_EACH_AFTER_R3
    214 #endif
    215 
    216 // APIs in R3 but not in R5
    217 // clang-format off
    218 #if CUDNN_VERSION >= 3000 && CUDNN_VERSION < 5000
    219 #define CUDNN_DNN_ROUTINE_EACH_R3(__macro)                    \
    220   __macro(cudnnAddTensor_v3)                                  \
    221   __macro(cudnnConvolutionBackwardData_v3)                    \
    222   __macro(cudnnConvolutionBackwardFilter_v3)
    223 // clang-format on
    224 
    225 CUDNN_DNN_ROUTINE_EACH_R3(PERFTOOLS_GPUTOOLS_CUDNN_WRAP)
    226 #undef CUDNN_DNN_ROUTINE_EACH_R3
    227 #endif
    228 
    229 // APIs in R5
    230 // clang-format off
    231 #if CUDNN_VERSION >= 5000
    232 #define CUDNN_DNN_ROUTINE_EACH_R5(__macro)                    \
    233   __macro(cudnnCreateActivationDescriptor)                    \
    234   __macro(cudnnSetActivationDescriptor)                       \
    235   __macro(cudnnGetActivationDescriptor)                       \
    236   __macro(cudnnDestroyActivationDescriptor)                   \
    237   __macro(cudnnCreateDropoutDescriptor)                       \
    238   __macro(cudnnDestroyDropoutDescriptor)                      \
    239   __macro(cudnnSetDropoutDescriptor)                          \
    240   __macro(cudnnDropoutGetStatesSize)                          \
    241   __macro(cudnnCreateRNNDescriptor)                           \
    242   __macro(cudnnDestroyRNNDescriptor)                          \
    243   __macro(cudnnGetRNNParamsSize)                              \
    244   __macro(cudnnGetRNNWorkspaceSize)                           \
    245   __macro(cudnnGetRNNTrainingReserveSize)                     \
    246   __macro(cudnnGetRNNLinLayerMatrixParams)                    \
    247   __macro(cudnnGetRNNLinLayerBiasParams)                      \
    248   __macro(cudnnRNNForwardInference)                           \
    249   __macro(cudnnRNNForwardTraining)                            \
    250   __macro(cudnnRNNBackwardData)                               \
    251   __macro(cudnnRNNBackwardWeights)                            \
    252   __macro(cudnnSetRNNDescriptor)                              \
    253   __macro(cudnnGetFilterNdDescriptor)
    254 
    255 // clang-format on
    256 
    257 CUDNN_DNN_ROUTINE_EACH_R5(PERFTOOLS_GPUTOOLS_CUDNN_WRAP)
    258 #undef CUDNN_DNN_ROUTINE_EACH_R5
    259 #endif
    260 
    261 // APIs in R6
    262 // clang-format off
    263 #if CUDNN_VERSION >= 6000
    264 #define CUDNN_DNN_ROUTINE_EACH_R6(__macro)                    \
    265   __macro(cudnnConvolutionBiasActivationForward)              \
    266   __macro(cudnnSetRNNDescriptor_v6)
    267 
    268 // clang-format on
    269 CUDNN_DNN_ROUTINE_EACH_R6(PERFTOOLS_GPUTOOLS_CUDNN_WRAP)
    270 #undef CUDNN_DNN_ROUTINE_EACH_R6
    271 #endif
    272 
    273 // APIs in R7
    274 // clang-format off
    275 #if CUDNN_VERSION >= 7000
    276 #define CUDNN_DNN_ROUTINE_EACH_R7(__macro)                    \
    277   __macro(cudnnSetConvolutionMathType)
    278 
    279 // clang-format on
    280 CUDNN_DNN_ROUTINE_EACH_R7(PERFTOOLS_GPUTOOLS_CUDNN_WRAP)
    281 #undef CUDNN_DNN_ROUTINE_EACH_R7
    282 #endif
    283 
    284 #undef CUDNN_DNN_ROUTINE_EACH
    285 
    286 }  // namespace wrap
    287 
    288 namespace {
    289 
    290 cudnnHandle_t ToHandle(void* opaque_handle) {
    291   return static_cast<cudnnHandle_t>(opaque_handle);
    292 }
    293 
    294 cudnnConvolutionFwdAlgo_t ToConvForwardAlgo(dnn::AlgorithmDesc algorithm) {
    295   cudnnConvolutionFwdAlgo_t algo =
    296       cudnnConvolutionFwdAlgo_t(algorithm.algo_id());
    297   switch (algo) {
    298     case CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM:
    299     case CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM:
    300     case CUDNN_CONVOLUTION_FWD_ALGO_GEMM:
    301     case CUDNN_CONVOLUTION_FWD_ALGO_DIRECT:
    302     case CUDNN_CONVOLUTION_FWD_ALGO_FFT:
    303     case CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING:
    304 #if CUDNN_VERSION >= 5000
    305     case CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD:
    306 #endif
    307 #if CUDNN_VERSION >= 5100
    308     case CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED:
    309 #endif
    310       return algo;
    311     default:
    312       LOG(FATAL) << "Unsupported Cudnn convolution forward algorithm: "
    313                  << algorithm.algo_id();
    314   }
    315 }
    316 
    317 cudnnConvolutionBwdDataAlgo_t ToConvBackwardDataAlgo(
    318     dnn::AlgorithmDesc algorithm) {
    319   cudnnConvolutionBwdDataAlgo_t algo =
    320       cudnnConvolutionBwdDataAlgo_t(algorithm.algo_id());
    321   switch (algo) {
    322     case CUDNN_CONVOLUTION_BWD_DATA_ALGO_0:
    323     case CUDNN_CONVOLUTION_BWD_DATA_ALGO_1:
    324     case CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT:
    325     case CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING:
    326 #if CUDNN_VERSION >= 5000
    327     case CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD:
    328 #endif
    329 #if CUDNN_VERSION >= 5100
    330     case CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED:
    331 #endif
    332       return algo;
    333     default:
    334       LOG(FATAL)
    335           << "Unsupported Cudnn convolution backward algorithm for data: "
    336           << algorithm.algo_id();
    337   }
    338 }
    339 
    340 cudnnConvolutionBwdFilterAlgo_t ToConvBackwardFilterAlgo(
    341     dnn::AlgorithmDesc algorithm) {
    342   cudnnConvolutionBwdFilterAlgo_t algo =
    343       cudnnConvolutionBwdFilterAlgo_t(algorithm.algo_id());
    344   switch (algo) {
    345     case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0:
    346     case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1:
    347     case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT:
    348     case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3:
    349 #if CUDNN_VERSION >= 5100
    350     // Based on cudnn.h, the following is not implemented.
    351     // case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD:
    352     case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED:
    353 #endif
    354       return algo;
    355     default:
    356       LOG(FATAL)
    357           << "Unsupported Cudnn convolution backward algorithm for filter: "
    358           << algorithm.algo_id();
    359   }
    360 }
    361 
    362 }  // namespace
    363 
    364 CudnnSupport::CudnnSupport(CUDAExecutor* parent)
    365     : parent_(parent), dnn_handle_(nullptr) {}
    366 
    367 CudnnSupport::~CudnnSupport() {
    368   auto status = wrap::cudnnDestroy(parent_, ToHandle(dnn_handle_));
    369   if (status != CUDNN_STATUS_SUCCESS) {
    370     LOG(ERROR) << "could not destroy cudnn handle: " << ToString(status);
    371   }
    372 }
    373 
    374 port::Status CudnnSupport::Init() {
    375   auto status = wrap::cudnnCreate(
    376       parent_, reinterpret_cast<cudnnHandle_t*>(&dnn_handle_));
    377   if (status == CUDNN_STATUS_SUCCESS) {
    378     // Check whether the loaded version of CuDNN matches what the source
    379     // was built with.
    380     size_t loaded_version = ::cudnnGetVersion();
    381     size_t loaded_compat_version = cudnnCompatibilityVersion(loaded_version);
    382     size_t compiled_compat_version = cudnnCompatibilityVersion(CUDNN_VERSION);
    383     bool library_loaded_matches_source =
    384         (loaded_compat_version == compiled_compat_version);
    385     if (!library_loaded_matches_source) {
    386       const string error =
    387           port::StrCat("Loaded runtime CuDNN library: ", loaded_version,
    388                        " (compatibility version ", loaded_compat_version,
    389                        ") but source was compiled with ", CUDNN_VERSION,
    390                        " (compatibility version ", compiled_compat_version,
    391                        ").  If using a binary install, upgrade your CuDNN "
    392                        "library to match.  If building from sources, "
    393                        "make sure the library loaded at runtime matches a "
    394                        "compatible version specified during compile "
    395                        "configuration.");
    396       LOG(ERROR) << error;
    397       return port::Status{port::error::INTERNAL, error};
    398     }
    399 
    400     return port::Status::OK();
    401   }
    402 
    403   LOG(ERROR) << "could not create cudnn handle: " << ToString(status);
    404   if (status == CUDNN_STATUS_NOT_INITIALIZED) {
    405     auto result = cuda::Diagnostician::FindKernelDriverVersion();
    406     if (!result.ok()) {
    407       LOG(ERROR) << "error retrieving driver version: "
    408                  << DriverVersionStatusToString(result);
    409     } else {
    410       const auto& version = result.ValueOrDie();
    411       LOG(ERROR) << "possibly insufficient driver version: "
    412                  << DriverVersionToString(version);
    413       // OS X kernel driver does not report version accurately
    414 #if !defined(__APPLE__)
    415       if (std::get<0>(version) < 340) {
    416         LOG(ERROR)
    417             << "cudnn library is only supported on 340.XX+ driver versions";
    418       }
    419 #endif
    420     }
    421   }
    422 
    423   return port::Status{port::error::INTERNAL,
    424                       port::StrCat("cudnn library could not create a handle: ",
    425                                    ToString(status))};
    426 }
    427 
    428 // Turns a BatchDescriptor structure into a cudnn tensor handle within a scope.
    429 class ScopedTensorDescriptor {
    430  public:
    431   ScopedTensorDescriptor(CUDAExecutor* parent,
    432                          const BatchDescriptor& batch_descriptor,
    433                          cudnnDataType_t elem_type)
    434       : parent_(parent), handle_(nullptr) {
    435     cudnnStatus_t status = wrap::cudnnCreateTensorDescriptor(parent_, &handle_);
    436     if (status != CUDNN_STATUS_SUCCESS) {
    437       LOG(FATAL) << "could not create cudnn tensor descriptor: "
    438                  << ToString(status);
    439     }
    440 
    441     switch (batch_descriptor.layout()) {
    442       case dnn::DataLayout::kBatchYXDepth:
    443       case dnn::DataLayout::kBatchDepthYX: {
    444         const int nd = batch_descriptor.ndims() + 2;
    445         // cuDNN requires the strides and dims to be ordered as BDYX.
    446         std::vector<int64> strides64 =
    447             batch_descriptor.full_strides(dnn::DataLayout::kBatchDepthYX);
    448         std::vector<int64> dims64 =
    449             batch_descriptor.full_dims(dnn::DataLayout::kBatchDepthYX);
    450 
    451         // cuDNN requires arrays of ints.
    452         std::vector<int> strides(nd);
    453         std::vector<int> dims(nd);
    454         std::transform(strides64.cbegin(), strides64.cend(), strides.begin(),
    455                        &CheckedNarrowing<int64, int>);
    456         std::transform(dims64.cbegin(), dims64.cend(), dims.begin(),
    457                        &CheckedNarrowing<int64, int>);
    458         status = wrap::cudnnSetTensorNdDescriptor(
    459             parent_, handle_, elem_type, nd, dims.data(), strides.data());
    460 
    461         if (status != CUDNN_STATUS_SUCCESS) {
    462           LOG(FATAL) << "could not convert BatchDescriptor "
    463                      << batch_descriptor.ToString()
    464                      << " to cudnn tensor descriptor: " << ToString(status);
    465         }
    466       } break;
    467 #if CUDNN_VERSION >= 6000
    468       case dnn::DataLayout::kBatchDepthYX4: {
    469         status = wrap::cudnnSetTensor4dDescriptor(
    470             parent_, handle_, CUDNN_TENSOR_NCHW_VECT_C, elem_type,
    471             batch_descriptor.count(), batch_descriptor.feature_map_count(),
    472             batch_descriptor.height(), batch_descriptor.width());
    473         if (status != CUDNN_STATUS_SUCCESS) {
    474           LOG(FATAL) << "could not convert BatchDescriptor "
    475                      << batch_descriptor.ToString()
    476                      << " to cudnn tensor descriptor: " << ToString(status);
    477         }
    478       } break;
    479 #endif
    480       default:
    481         LOG(FATAL) << "Unsupported tensor format "
    482                    << DataLayoutString(batch_descriptor.layout());
    483         break;
    484     }
    485   }
    486 
    487   ~ScopedTensorDescriptor() {
    488     cudnnStatus_t status = wrap::cudnnDestroyTensorDescriptor(parent_, handle_);
    489     if (status != CUDNN_STATUS_SUCCESS) {
    490       LOG(ERROR) << "could not destroy cudnn tensor descriptor: "
    491                  << ToString(status);
    492     }
    493   }
    494 
    495   cudnnTensorDescriptor_t handle() const { return handle_; }
    496 
    497  private:
    498   CUDAExecutor* parent_;            // Parent executor. Not owned.
    499   cudnnTensorDescriptor_t handle_;  // Owned.
    500 
    501   SE_DISALLOW_COPY_AND_ASSIGN(ScopedTensorDescriptor);
    502 };
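        // Example usage (illustration): callers construct these scoped
        // descriptors on the stack, e.g.
        //   ScopedTensorDescriptor input_nd(parent_, batch_descriptor,
        //                                   CUDNN_DATA_FLOAT);
        // and pass input_nd.handle() to cuDNN calls; the descriptor is destroyed
        // automatically when it goes out of scope.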
    503 
    504 // Turns a FilterDescriptor structure into a cudnn filter handle within a scope.
    505 class ScopedFilterDescriptor {
    506  public:
    507   ScopedFilterDescriptor(CUDAExecutor* parent,
    508                          const FilterDescriptor& filter_descriptor,
    509                          const BatchDescriptor& batch_descriptor,
    510                          cudnnDataType_t elem_type)
    511       : parent_(parent), handle_(nullptr) {
    512     cudnnStatus_t status = wrap::cudnnCreateFilterDescriptor(parent_, &handle_);
    513     if (status != CUDNN_STATUS_SUCCESS) {
    514       LOG(FATAL) << "could not create cudnn filter descriptor: "
    515                  << ToString(status);
    516     }
    517 
    518 #if CUDNN_VERSION >= 5000
    519     // TODO(b/23032134): Even if the filter layout is not supported,
    520     // cudnnSetFilter4DDescriptor_v4 will return CUDNN_STATUS_SUCCESS because it
    521     // does not take layout as an input. Maybe force cuDNN by giving wrong
    522     // inputs intentionally?
    523     cudnnTensorFormat_t format;
    524     switch (filter_descriptor.layout()) {
    525       case dnn::FilterLayout::kOutputInputYX:
    526         format = CUDNN_TENSOR_NCHW;
    527         break;
    528 #if CUDNN_VERSION >= 6000
    529       case dnn::FilterLayout::kOutputInputYX4:
    530         format = CUDNN_TENSOR_NCHW_VECT_C;
    531         break;
    532 #endif
    533       default:
    534         LOG(FATAL) << "Unsupported filter format "
    535                    << FilterLayoutString(filter_descriptor.layout());
    536         break;
    537     }
    538 #endif
    539 
    540     std::vector<int> dims(2 + filter_descriptor.ndims());
    541     dims[0] = filter_descriptor.output_feature_map_count();
    542     dims[1] = filter_descriptor.input_feature_map_count();
    543     const auto& spatial_dims = filter_descriptor.input_filter_dims();
    544     std::copy(spatial_dims.begin(), spatial_dims.end(), dims.begin() + 2);
    545 
    546     status = wrap::cudnnSetFilterNdDescriptor(parent_, handle_, elem_type,
    547 #if CUDNN_VERSION >= 5000
    548                                               format,
    549 #endif
    550                                               dims.size(), dims.data());
    551     if (status != CUDNN_STATUS_SUCCESS) {
    552       LOG(FATAL) << "could not set cudnn filter descriptor: "
    553                  << ToString(status);
    554     }
    555   }
    556 
    557   ~ScopedFilterDescriptor() {
    558     cudnnStatus_t status = wrap::cudnnDestroyFilterDescriptor(parent_, handle_);
    559     if (status != CUDNN_STATUS_SUCCESS) {
    560       LOG(ERROR) << "could not destroy cudnn filter descriptor: "
    561                  << ToString(status);
    562     }
    563   }
    564 
    565   cudnnFilterDescriptor_t handle() const { return handle_; }
    566 
    567  private:
    568   // Parent executor object. Not owned.
    569   CUDAExecutor* parent_;
    570 
    571   // cudnn filter descriptor this object creates. Owned.
    572   cudnnFilterDescriptor_t handle_;
    573 
    574   SE_DISALLOW_COPY_AND_ASSIGN(ScopedFilterDescriptor);
    575 };
    576 
    577 // A helper function to decide whether to enable the TENSOR_OP_MATH math type.
    578 static bool TensorOpMathEnabled() {
    579   static bool is_enabled = [] {
    580     bool is_disabled = false;
    581     TF_CHECK_OK(
    582         tensorflow::ReadBoolFromEnvVar("TF_DISABLE_CUDNN_TENSOR_OP_MATH",
    583                                        /*default_val=*/false, &is_disabled));
    584     return !is_disabled;
    585   }();
    586   return is_enabled;
    587 }
    588 
    589 // A helper function to decide whether to use CUDNN_BATCHNORM_SPATIAL_PERSISTENT
    590 // in batchnorm. This mode can be faster in some tasks because an optimized path
    591 // may be selected for the CUDNN_DATA_FLOAT and CUDNN_DATA_HALF data types on
    592 // GPUs with compute capability 6.0 or higher. The reason we set it to false by
    593 // default is that this mode may use scaled atomic integer reductions, which
    594 // may cause a numerical overflow for certain input data ranges.
    595 // TODO(yangzihao): Use autotune to choose between this mode and
    596 // CUDNN_BATCHNORM_SPATIAL mode.
    597 static bool BatchnormSpatialPersistentEnabled() {
    598   static bool is_enabled = [] {
    599     bool is_enabled = false;
    600     TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar(
    601         "TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT",
    602         /*default_val=*/false, &is_enabled));
    603     return is_enabled;
    604   }();
    605   return is_enabled;
    606 }
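        // Note: both helpers above read their environment variable once, via a
        // static initializer, and cache the result for the lifetime of the
        // process; TF_DISABLE_CUDNN_TENSOR_OP_MATH=1 disables tensor op math and
        // TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT=1 opts in to the persistent
        // batchnorm mode, so either variable must be set before the first query.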
    607 
    608 // Turns a ConvolutionDescriptor structure into a cudnn convolution handle
    609 // within a scope.
    610 class ScopedConvolutionDescriptor {
    611  public:
    612   ScopedConvolutionDescriptor(
    613       CUDAExecutor* parent, const ConvolutionDescriptor& convolution_descriptor,
    614       cudnnDataType_t data_type)
    615       : parent_(parent), handle_(nullptr) {
    616     cudnnStatus_t status =
    617         wrap::cudnnCreateConvolutionDescriptor(parent_, &handle_);
    618     if (status != CUDNN_STATUS_SUCCESS) {
    619       LOG(FATAL) << "could not create cudnn convolution descriptor: "
    620                  << ToString(status);
    621     }
    622     const auto& strides64 = convolution_descriptor.strides();
    623     const auto& padding64 = convolution_descriptor.padding();
    624     const auto& dilations64 = convolution_descriptor.dilations();
    625     if (convolution_descriptor.pad_alignment() ==
    626         dnn::PadAlignment::kTensorFlowPadding) {
    627       LOG(ERROR) << "TensorFlow padding alignment is not supported.";
    628     }
    629 
    630     // cuDNN requires arrays of ints.
    631     std::vector<int> strides(convolution_descriptor.ndims());
    632     std::vector<int> padding(convolution_descriptor.ndims());
    633     std::vector<int> dilations(convolution_descriptor.ndims());
    634     std::transform(strides64.cbegin(), strides64.cend(), strides.begin(),
    635                    &CheckedNarrowing<int64, int>);
    636     std::transform(padding64.cbegin(), padding64.cend(), padding.begin(),
    637                    &CheckedNarrowing<int64, int>);
    638     // TODO(yangzihao): Test with negative dilation to make sure that cudnn
    639     // doesn't crash.
    640     std::transform(dilations64.cbegin(), dilations64.cend(), dilations.begin(),
    641                    &CheckedNarrowing<int64, int>);
    642 
    643     status = wrap::cudnnSetConvolutionNdDescriptor(
    644         parent_, handle_, convolution_descriptor.ndims(), padding.data(),
    645         strides.data(), dilations.data(),
    646         // NOTE(keveman): cuDNN supports convolution and cross correlation.
    647         // However, almost all use cases do cross correlation, so we just
    648         // hard-code it here.
    649         CUDNN_CROSS_CORRELATION, data_type);
    650 
    651     if (status != CUDNN_STATUS_SUCCESS) {
    652       LOG(FATAL) << "could not set cudnn convolution descriptor: "
    653                  << ToString(status);
    654     }
    655     // NOTE(benbarsdell): This only applies if tensor op math is enabled
    656     // and algo selection is set to Default.
    657     this->set_use_tensor_op_math(true);
    658   }
    659 
    660   void set_use_tensor_op_math(bool use_tensor_op_math) {
    661 #if CUDNN_VERSION >= 7000
    662     cudnnMathType_t math_type =
    663         (use_tensor_op_math ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH);
    664     if (TensorOpMathEnabled()) {
    665       cudnnStatus_t status =
    666           wrap::cudnnSetConvolutionMathType(parent_, handle_, math_type);
    667       if (status != CUDNN_STATUS_SUCCESS) {
    668         LOG(FATAL) << "could not set cudnn convolution math type: "
    669                    << ToString(status);
    670       }
    671     }
    672 #endif
    673   }
    674 
    675   ~ScopedConvolutionDescriptor() {
    676     cudnnStatus_t status =
    677         wrap::cudnnDestroyConvolutionDescriptor(parent_, handle_);
    678     if (status != CUDNN_STATUS_SUCCESS) {
    679       LOG(ERROR) << "could not destroy cudnn convolution descriptor: "
    680                  << ToString(status);
    681     }
    682   }
    683 
    684   cudnnConvolutionDescriptor_t handle() const { return handle_; }
    685 
    686  private:
    687   CUDAExecutor* parent_;                 // Parent executor. Not owned.
    688   cudnnConvolutionDescriptor_t handle_;  // Owned.
    689 
    690   SE_DISALLOW_COPY_AND_ASSIGN(ScopedConvolutionDescriptor);
    691 };
    692 
    693 // Turns a PoolingDescriptor structure into a cudnn pooling descriptor handle
    694 // within a scope.
    695 class ScopedPoolingDescriptor {
    696  public:
    697   ScopedPoolingDescriptor(CUDAExecutor* parent,
    698                           const PoolingDescriptor& pooling_descriptor)
    699       : parent_(parent), handle_(nullptr) {
    700     cudnnStatus_t status =
    701         wrap::cudnnCreatePoolingDescriptor(parent_, &handle_);
    702     if (status != CUDNN_STATUS_SUCCESS) {
    703       LOG(FATAL) << "could not create cudnn pooling descriptor: "
    704                  << ToString(status);
    705     }
    706     const std::vector<int64> strides64 = pooling_descriptor.strides();
    707     const std::vector<int64> padding64 = pooling_descriptor.padding();
    708     const std::vector<int64> shape64 = pooling_descriptor.window();
    709 
    710     const int nd = pooling_descriptor.ndims();
    711     std::vector<int> shape(nd);
    712     std::vector<int> padding(nd);
    713     std::vector<int> strides(nd);
    714     std::transform(strides64.cbegin(), strides64.cend(), strides.begin(),
    715                    &CheckedNarrowing<int64, int>);
    716     std::transform(padding64.cbegin(), padding64.cend(), padding.begin(),
    717                    &CheckedNarrowing<int64, int>);
    718     std::transform(shape64.cbegin(), shape64.cend(), shape.begin(),
    719                    &CheckedNarrowing<int64, int>);
    720     bool propagate_nans = pooling_descriptor.propagate_nans();
    721     status = wrap::cudnnSetPoolingNdDescriptor(
    722         parent_, handle_,
    723         (pooling_descriptor.mode() == dnn::PoolingMode::kMaximum
    724              ? CUDNN_POOLING_MAX
    725              : CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING),
    726 #if CUDNN_VERSION >= 5000
    727         propagate_nans ? CUDNN_PROPAGATE_NAN : CUDNN_NOT_PROPAGATE_NAN,
    728 #endif
    729         nd, shape.data(), padding.data(), strides.data());
    730     if (status != CUDNN_STATUS_SUCCESS) {
    731       LOG(FATAL) << "could not set cudnn pooling descriptor: "
    732                  << ToString(status);
    733     }
    734   }
    735   ~ScopedPoolingDescriptor() {
    736     cudnnStatus_t status =
    737         wrap::cudnnDestroyPoolingDescriptor(parent_, handle_);
    738     if (status != CUDNN_STATUS_SUCCESS) {
    739       LOG(ERROR) << "could not destroy cudnn pooling descriptor: "
    740                  << ToString(status);
    741     }
    742   }
    743 
    744   cudnnPoolingDescriptor_t handle() const { return handle_; }
    745 
    746  private:
    747   CUDAExecutor* parent_;             // Parent executor. Not owned.
    748   cudnnPoolingDescriptor_t handle_;  // Owned.
    749 
    750   SE_DISALLOW_COPY_AND_ASSIGN(ScopedPoolingDescriptor);
    751 };
    752 
    753 // Turns a NormalizeDescriptor structure into a cudnn LRN descriptor handle.
    754 class ScopedNormalizeDescriptor {
    755  public:
    756   ScopedNormalizeDescriptor(CUDAExecutor* parent,
    757                             const NormalizeDescriptor& normalize_descriptor)
    758       : parent_(parent), handle_(nullptr) {
    759     cudnnStatus_t status = wrap::cudnnCreateLRNDescriptor(parent_, &handle_);
    760     if (status != CUDNN_STATUS_SUCCESS) {
    761       LOG(FATAL) << "could not create cudnn LRN descriptor: "
    762                  << ToString(status);
    763     }
    764 
    765     // The range specifies that the indices in the closed range
    766     // [i - range, i + range] should be included in the normalization for index
    767     // i. The lrnN value is the total number of elements in the range, so
    768     // lrnN = 2*range + 1.
    769     unsigned lrnN = 2 * normalize_descriptor.range() + 1;
    770 
    771     // Note that SE defines the normalization operation as
    772     //
    773     //  U_i = V_i / ((bias +  alpha      * (sum_j V_j^2)) ^ beta)
    774     //
    775     // but cuDNN defines it as
    776     //
    777     //  U_i = V_i / ((bias + (alpha / n) * (sum_j V_j^2)) ^ beta)
    778     //
    779     // i.e. there is a factor of n difference between the meaning of the alphas
    780     // in the two contexts. The cuDNN alpha is n times the SE alpha.
    781     double lrnAlpha = lrnN * normalize_descriptor.alpha();
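            // For example, with range = 2, lrnN = 5, so an SE alpha of 1e-4
            // corresponds to a cuDNN alpha of 5e-4.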
    782 
    783     double lrnBeta = normalize_descriptor.beta();
    784     double lrnK = normalize_descriptor.bias();
    785     status = wrap::cudnnSetLRNDescriptor(parent_, handle_, lrnN, lrnAlpha,
    786                                          lrnBeta, lrnK);
    787     if (status != CUDNN_STATUS_SUCCESS) {
    788       LOG(FATAL) << "could not set cudnn LRN descriptor: " << ToString(status);
    789     }
    790   }
    791 
    792   ~ScopedNormalizeDescriptor() {
    793     cudnnStatus_t status = wrap::cudnnDestroyLRNDescriptor(parent_, handle_);
    794     if (status != CUDNN_STATUS_SUCCESS) {
    795       LOG(ERROR) << "could not destroy cudnn LRN descriptor: "
    796                  << ToString(status);
    797     }
    798   }
    799 
    800   cudnnLRNDescriptor_t handle() const { return handle_; }
    801 
    802  private:
    803   CUDAExecutor* parent_;         // Parent executor. Not owned.
    804   cudnnLRNDescriptor_t handle_;  // Owned.
    805 
    806   SE_DISALLOW_COPY_AND_ASSIGN(ScopedNormalizeDescriptor);
    807 };
    808 
    809 #if CUDNN_VERSION >= 5000
    810 // Turns an ActivationDescriptor structure into a cudnn activation
    811 // descriptor handle within a scope.
    812 class ScopedActivationDescriptor {
    813  public:
    814   ScopedActivationDescriptor(CUDAExecutor* parent,
    815                              dnn::ActivationMode activation_mode,
    816                              cudnnNanPropagation_t nan_propagation,
    817                              double value_max)
    818       : parent_(parent), handle_(nullptr) {
    819     cudnnStatus_t status =
    820         wrap::cudnnCreateActivationDescriptor(parent_, &handle_);
    821     if (status != CUDNN_STATUS_SUCCESS) {
    822       LOG(FATAL) << "could not create cudnn activation descriptor: "
    823                  << ToString(status);
    824     }
    825 
    826     double relu_ceiling = 0.0;
    827     cudnnActivationMode_t mode;
    828     switch (activation_mode) {
    829       case dnn::ActivationMode::kRelu6:
    830         relu_ceiling = 6.0;
    831         mode = CUDNN_ACTIVATION_CLIPPED_RELU;
    832         break;
    833       case dnn::ActivationMode::kReluX:
    834         relu_ceiling = value_max;
    835         mode = CUDNN_ACTIVATION_CLIPPED_RELU;
    836         break;
    837       case dnn::ActivationMode::kRelu:
    838         mode = CUDNN_ACTIVATION_RELU;
    839         break;
    840       case dnn::ActivationMode::kSigmoid:
    841         mode = CUDNN_ACTIVATION_SIGMOID;
    842         break;
    843       case dnn::ActivationMode::kTanh:
    844         mode = CUDNN_ACTIVATION_TANH;
    845         break;
    846       default:
    847         LOG(FATAL) << "unrecognized activation mode: "
    848                    << static_cast<int>(activation_mode);
    849     }
    850 
    851     status = wrap::cudnnSetActivationDescriptor(parent_, handle_, mode,
    852                                                 nan_propagation, relu_ceiling);
    853     if (status != CUDNN_STATUS_SUCCESS) {
    854       LOG(FATAL) << "could not set cudnn activation descriptor: "
    855                  << ToString(status);
    856     }
    857   }
    858 
    859   ~ScopedActivationDescriptor() {
    860     cudnnStatus_t status =
    861         wrap::cudnnDestroyActivationDescriptor(parent_, handle_);
    862     if (status != CUDNN_STATUS_SUCCESS) {
    863       LOG(ERROR) << "could not destroy cudnn activation descriptor: "
    864                  << ToString(status);
    865     }
    866   }
    867 
    868   cudnnActivationDescriptor_t handle() const { return handle_; }
    869 
    870  private:
    871   CUDAExecutor* parent_;                // Parent executor. Not owned.
    872   cudnnActivationDescriptor_t handle_;  // Owned.
    873 
    874   SE_DISALLOW_COPY_AND_ASSIGN(ScopedActivationDescriptor);
    875 };
    876 #endif
    877 
    878 namespace {
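        // Note: the static_cast conversions below rely on the dnn:: enum values
        // being kept in sync with the corresponding cuDNN enum values (e.g.
        // dnn::DataType::kFloat presumably equals CUDNN_DATA_FLOAT); the switch
        // statements only pass through the enumerators for which that
        // correspondence holds.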
    879 cudnnDataType_t ToCudnnDataType(
    880     dnn::DataType data_type,
    881     dnn::DataLayout data_layout = dnn::DataLayout::kBatchDepthYX) {
    882   switch (data_type) {
    883     case dnn::DataType::kFloat:
    884     case dnn::DataType::kDouble:
    885     case dnn::DataType::kHalf:
    886       return static_cast<cudnnDataType_t>(data_type);
    887 #if CUDNN_VERSION >= 6000
    888     case dnn::DataType::kInt8:
    889       return data_layout == dnn::DataLayout::kBatchDepthYX4 ? CUDNN_DATA_INT8x4
    890                                                             : CUDNN_DATA_INT8;
    891 #endif
    892     default:
    893       LOG(FATAL) << "Invalid DNN data type: " << static_cast<int>(data_type);
    894   }
    895 }
    896 
    897 #if CUDNN_VERSION >= 5000
    898 
    899 cudnnRNNInputMode_t ToCudnnRnnInputMode(dnn::RnnInputMode input_mode) {
    900   switch (input_mode) {
    901     case dnn::RnnInputMode::kRnnLinearSkip:
    902     case dnn::RnnInputMode::kRnnSkipInput:
    903       return static_cast<cudnnRNNInputMode_t>(input_mode);
    904     default:
    905       LOG(FATAL) << "Invalid RNN input mode: " << static_cast<int>(input_mode);
    906   }
    907 }
    908 
    909 cudnnDirectionMode_t ToCudnnRnnDirectionMode(
    910     dnn::RnnDirectionMode direction_mode) {
    911   switch (direction_mode) {
    912     case dnn::RnnDirectionMode::kRnnUnidirectional:
    913     case dnn::RnnDirectionMode::kRnnBidirectional:
    914       return static_cast<cudnnDirectionMode_t>(direction_mode);
    915     default:
    916       LOG(FATAL) << "Invalid RNN direction mode: "
    917                  << static_cast<int>(direction_mode);
    918   }
    919 }
    920 
    921 cudnnRNNMode_t ToCudnnRnnMode(dnn::RnnMode rnn_mode) {
    922   switch (rnn_mode) {
    923     case dnn::RnnMode::kRnnRelu:
    924     case dnn::RnnMode::kRnnTanh:
    925     case dnn::RnnMode::kRnnLstm:
    926     case dnn::RnnMode::kRnnGru:
    927       return static_cast<cudnnRNNMode_t>(rnn_mode);
    928     default:
    929       LOG(FATAL) << "Invalid RNN Mode: " << static_cast<int>(rnn_mode);
    930   }
    931 }
    932 
    933 int CudnnDataTypeToByteSize(cudnnDataType_t data_type) {
    934   switch (data_type) {
    935     case CUDNN_DATA_FLOAT:
    936       return sizeof(float);
    937     case CUDNN_DATA_DOUBLE:
    938       return sizeof(double);
    939     case CUDNN_DATA_HALF:
    940       return sizeof(Eigen::half);
    941     default:
    942       LOG(FATAL) << "Invalid DNN data type: " << static_cast<int>(data_type);
    943   }
    944 }
    945 
    946 #endif  // CUDNN_VERSION
    947 
    948 template <typename Base>
    949 class MixinBase : public Base {};
    950 template <>
    951 class MixinBase<void> {};
    952 
    953 }  // namespace
    954 
    955 #if CUDNN_VERSION >= 5000
    956 
    957 #define CUDNN_RETURN_IF_FAIL(STATUS, ...)                                \
    958   if (!SE_PREDICT_TRUE((STATUS) == CUDNN_STATUS_SUCCESS)) {              \
    959     string error_msg = port::StrCat(ToString(STATUS), " ", __VA_ARGS__); \
    960     SetFailure(port::Status(port::error::UNKNOWN, error_msg));           \
    961     LOG(ERROR) << error_msg;                                             \
    962     return;                                                              \
    963   }
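        // Note: this macro is used inside the descriptor constructors and
        // destructors below, which return void; on failure it records the error
        // via SetFailure() on the enclosing CudnnDescriptorCommon, logs it, and
        // returns early.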
    964 
    965 template <typename Base>
    966 class CudnnDescriptorCommon : public MixinBase<Base> {
    967  public:
    968   bool ok() const { return status_.ok(); }
    969   port::Status Status() const { return status_; }
    970 
    971  protected:
    972   void SetFailure(const port::Status& status) { status_.Update(status); }
    973   port::Status status_;
    974 };
    975 
    976 class CudnnDropoutDescriptor : public CudnnDescriptorCommon<void> {
    977  public:
    978   CudnnDropoutDescriptor(CUDAExecutor* parent, cudnnHandle_t cudnn_handle,
    979                          float dropout, uint64 seed,
    980                          ScratchAllocator* state_allocator)
    981       : parent_(parent), handle_(nullptr) {
    982     cudnnStatus_t status;
    983     status = wrap::cudnnCreateDropoutDescriptor(parent_, &handle_);
    984     CUDNN_RETURN_IF_FAIL(status, "Failed to create dropout descriptor");
    985 
    986     if (dropout == 0.f) {
    987       return;
    988     }
    989 
    990     DeviceMemory<uint8> state_memory;
    991     if (state_allocator) {
    992       size_t state_sizes_in_bytes = 0;
    993       status = wrap::cudnnDropoutGetStatesSize(parent_, cudnn_handle,
    994                                                &state_sizes_in_bytes);
    995       CUDNN_RETURN_IF_FAIL(status, "Failed to query dropout state sizes");
    996 
    997       auto allocated =
    998           state_allocator->AllocateBytes(nullptr, state_sizes_in_bytes);
    999       if (!allocated.ok() ||
   1000           (state_memory = allocated.ValueOrDie()) == nullptr) {
   1001         string error_msg =
   1002             port::StrCat("Failed to allocate Cudnn dropout state memory of ",
   1003                          state_sizes_in_bytes, " bytes.");
   1004         status_ = port::Status(port::error::UNKNOWN, error_msg);
   1005         LOG(ERROR) << error_msg;
   1006         return;
   1007       }
   1008     }
   1009     status = wrap::cudnnSetDropoutDescriptor(parent_, handle_, cudnn_handle,
   1010                                              dropout, state_memory.opaque(),
   1011                                              state_memory.size(), seed);
   1012     CUDNN_RETURN_IF_FAIL(
   1013         status, port::StrCat(
   1014                     "Failed to set dropout descriptor with state memory size: ",
   1015                     state_memory.size(), " bytes."));
   1016   }
   1017 
   1018   ~CudnnDropoutDescriptor() {
   1019     if (handle_) {
   1020       cudnnStatus_t status =
   1021           wrap::cudnnDestroyDropoutDescriptor(parent_, handle_);
   1022       CUDNN_RETURN_IF_FAIL(status, "Failed to destroy Cudnn dropout handle: ");
   1023     }
   1024   }
   1025 
   1026   cudnnDropoutDescriptor_t handle() const {
   1027     if (!ok()) return nullptr;
   1028     return handle_;
   1029   }
   1030 
   1031  private:
   1032   CUDAExecutor* parent_;
   1033   cudnnDropoutDescriptor_t handle_;
   1034   float dropout_;
   1035   uint64 seed_;
   1036   SE_DISALLOW_COPY_AND_ASSIGN(CudnnDropoutDescriptor);
   1037 };
   1038 
   1039 class CudnnRnnParamsDescriptor : public CudnnDescriptorCommon<void> {
   1040  public:
   1041   typedef dnn::RnnDescriptor::ParamsRegion ParamsRegion;
   1042   typedef dnn::RnnDescriptor::ParamsRegions ParamsRegions;
   1043   CudnnRnnParamsDescriptor(CUDAExecutor* parent, cudnnHandle_t cudnn_handle,
   1044                            const CudnnRnnDescriptor& rnn_desc);
   1045   ~CudnnRnnParamsDescriptor() {
   1046     cudnnStatus_t status = wrap::cudnnDestroyFilterDescriptor(parent_, handle_);
   1047     CUDNN_RETURN_IF_FAIL(status, "Failed to destroy RNN filter descriptor");
   1048   }
   1049   cudnnFilterDescriptor_t handle() const {
   1050     if (!ok()) return nullptr;
   1051     return handle_;
   1052   }
   1053   int64 params_size_in_bytes() const { return params_size_in_bytes_; }
   1054   ParamsRegions params_weights() const {
   1055     if (!ok()) return ParamsRegions();
   1056     return weights_;
   1057   }
   1058   ParamsRegions params_biases() const {
   1059     if (!ok()) return ParamsRegions();
   1060     return biases_;
   1061   }
   1062 
   1063  private:
   1064   int GetRegionCountPerLayer() const;
   1065   CUDAExecutor* parent_;
   1066   cudnnFilterDescriptor_t handle_;
   1067   const CudnnRnnDescriptor* rnn_desc_;
   1068   int64 params_size_in_bytes_;
   1069   ParamsRegions weights_;
   1070   ParamsRegions biases_;
   1071   SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnParamsDescriptor);
   1072 };
   1073 
   1074 class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> {
   1075  public:
   1076   CudnnRnnDescriptor(CUDAExecutor* parent, cudnnHandle_t cudnn_handle,
   1077                      int num_layers, int hidden_size, int input_size,
   1078                      cudnnRNNInputMode_t input_mode,
   1079                      cudnnDirectionMode_t direction_mode,
   1080                      cudnnRNNMode_t rnn_mode, cudnnDataType_t data_type,
   1081                      float dropout, uint64 seed,
   1082                      ScratchAllocator* state_allocator)
   1083       : parent_(parent),
   1084         rnn_desc_(nullptr),
   1085         num_layers_(num_layers),
   1086         hidden_size_(hidden_size),
   1087         input_size_(input_size),
   1088         input_mode_(input_mode),
   1089         direction_mode_(direction_mode),
   1090         rnn_mode_(rnn_mode),
   1091         data_type_(data_type) {
   1092     // Create the dropout handle.
   1093     cudnn_dropout_desc_.reset(new CudnnDropoutDescriptor(
   1094         parent, cudnn_handle, dropout, seed, state_allocator));
   1095     if (!cudnn_dropout_desc_->ok()) {
   1096       SetFailure(cudnn_dropout_desc_->Status());
   1097       return;
   1098     }
   1099 
   1100     // Create the RNN handle.
   1101     cudnnStatus_t status = wrap::cudnnCreateRNNDescriptor(parent_, &rnn_desc_);
   1102     CUDNN_RETURN_IF_FAIL(status, "Unable to create RNN descriptor");
   1103 #if CUDNN_VERSION >= 6000
   1104     // TODO: allow the user to choose an algorithm.
   1105     cudnnRNNAlgo_t rnn_algo = CUDNN_RNN_ALGO_STANDARD;
   1106     status = wrap::cudnnSetRNNDescriptor_v6(
   1107         parent, cudnn_handle, rnn_desc_ /*rnnDesc*/, hidden_size /*hiddenSize*/,
   1108         num_layers /*numLayers*/, dropout_handle() /*dropoutDesc*/,
   1109         input_mode /*inputMode*/, direction_mode /*direction*/,
   1110         rnn_mode /*mode*/, rnn_algo /*algo*/, data_type /*dataType*/);
   1111 #else
   1112     status = wrap::cudnnSetRNNDescriptor(
   1113         parent, rnn_desc_ /*rnnDesc*/, hidden_size /*hiddenSize*/,
   1114         num_layers /*numLayers*/, dropout_handle() /*dropoutDesc*/,
   1115         input_mode /*inputMode*/, direction_mode /*direction*/,
   1116         rnn_mode /*mode*/, data_type /*dataType*/);
   1117 #endif
   1118     CUDNN_RETURN_IF_FAIL(status, "Unable to update RNN descriptor");
   1119 
   1120     // Create the params handle.
   1121     cudnn_params_desc_.reset(
   1122         new CudnnRnnParamsDescriptor(parent, cudnn_handle, *this));
   1123     if (!cudnn_params_desc_->ok()) {
   1124       SetFailure(cudnn_params_desc_->Status());
   1125       return;
   1126     }
   1127   }
   1128   ~CudnnRnnDescriptor() override {
   1129     if (rnn_desc_) {
   1130       cudnnStatus_t status =
   1131           wrap::cudnnDestroyRNNDescriptor(parent_, rnn_desc_);
   1132       CUDNN_RETURN_IF_FAIL(status, "Unable to destroy RNN descriptor");
   1133     }
   1134   }
   1135   cudnnRNNDescriptor_t handle() const {
   1136     if (!ok()) return nullptr;
   1137     return rnn_desc_;
   1138   }
   1139   int num_layers() const { return num_layers_; }
   1140   int hidden_size() const { return hidden_size_; }
   1141   int input_size() const { return input_size_; }
   1142   cudnnRNNInputMode_t input_mode() const { return input_mode_; }
   1143   cudnnDirectionMode_t direction_mode() const { return direction_mode_; }
   1144   cudnnRNNMode_t rnn_mode() const { return rnn_mode_; }
   1145   cudnnDataType_t data_type() const { return data_type_; }
   1146   int64 ParamsSizeInBytes() const override {
   1147     return cudnn_params_desc_->params_size_in_bytes();
   1148   }
   1149   cudnnDropoutDescriptor_t dropout_handle() const {
   1150     if (!cudnn_dropout_desc_) return nullptr;
   1151     return cudnn_dropout_desc_->handle();
   1152   }
   1153   cudnnFilterDescriptor_t params_handle() const {
   1154     if (!cudnn_params_desc_) return nullptr;
   1155     return cudnn_params_desc_->handle();
   1156   }
   1157   ParamsRegions ParamsWeightRegions() const override {
   1158     if (!ok()) return ParamsRegions();
   1159     return cudnn_params_desc_->params_weights();
   1160   }
   1161   ParamsRegions ParamsBiasRegions() const override {
   1162     if (!ok()) return ParamsRegions();
   1163     return cudnn_params_desc_->params_biases();
   1164   }
   1165 
   1166  private:
   1167   CUDAExecutor* parent_;
   1168   cudnnRNNDescriptor_t rnn_desc_;
   1169   int num_layers_;
   1170   int hidden_size_;
   1171   int input_size_;
   1172   cudnnRNNInputMode_t input_mode_;
   1173   cudnnDirectionMode_t direction_mode_;
   1174   cudnnRNNMode_t rnn_mode_;
   1175   cudnnDataType_t data_type_;
   1176   std::unique_ptr<CudnnDropoutDescriptor> cudnn_dropout_desc_;
   1177   std::unique_ptr<CudnnRnnParamsDescriptor> cudnn_params_desc_;
   1178   SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnDescriptor);
   1179 };
   1180 
   1181 CudnnRnnParamsDescriptor::CudnnRnnParamsDescriptor(
   1182     CUDAExecutor* parent, cudnnHandle_t cudnn_handle,
   1183     const CudnnRnnDescriptor& rnn_desc)
   1184     : parent_(parent),
   1185       handle_(nullptr),
   1186       rnn_desc_(&rnn_desc),
   1187       params_size_in_bytes_(0) {
   1188   cudnnTensorDescriptor_t input_desc = nullptr;
   1189   {
   1190     // Query the params size.
   1191     auto status = wrap::cudnnCreateTensorDescriptor(parent, &input_desc);
   1192     CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to create tensor descriptor");
   1193     int dims[] = {1, rnn_desc.input_size(), 1};
   1194     int strides[] = {dims[1] * dims[2], dims[2], 1};
   1195     status = wrap::cudnnSetTensorNdDescriptor(
   1196         parent, input_desc /*tensorDesc*/, rnn_desc.data_type() /*dataType*/,
   1197         sizeof(dims) / sizeof(dims[0]) /*nbDims*/, dims /*dimA*/,
   1198         strides /*strideA*/);
   1199     CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to set tensor descriptor");
   1200 
   1201     size_t params_size = 0;
   1202     status = wrap::cudnnGetRNNParamsSize(
   1203         parent, cudnn_handle /*handle*/, rnn_desc.handle() /*rnnDesc*/,
   1204         input_desc /*xDesc*/, &params_size /*sizeInBytes*/,
   1205         rnn_desc.data_type() /*dataType*/);
   1206     CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to get RNN parameter size");
   1207     params_size_in_bytes_ = static_cast<int64>(params_size);
   1208   }
   1209 
   1210   {
   1211     // Create the params descriptor.
   1212     auto status = wrap::cudnnCreateFilterDescriptor(parent, &handle_);
   1213     CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to create RNN filter descriptor");
   1214     int dims[] = {static_cast<int>(params_size_in_bytes_), 1, 1};
   1215     status = wrap::cudnnSetFilterNdDescriptor(
   1216         parent, handle_ /*filterDesc*/, rnn_desc.data_type() /*dataType*/,
   1217         CUDNN_TENSOR_NCHW /*format*/, sizeof(dims) / sizeof(dims[0]) /*nbDims*/,
   1218         dims /*filterDimA*/);
   1219     CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to update RNN filter descriptor");
   1220   }
   1221 
   1222   {
   1223     // Locate the weight and bias regions within the params buffer.
   1224     int region_count_per_layer = GetRegionCountPerLayer();
   1225     cudnnFilterDescriptor_t region_desc_handle = nullptr;
   1226     auto status =
   1227         wrap::cudnnCreateFilterDescriptor(parent, &region_desc_handle);
   1228     CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to create filter descriptor");
   1229     const int layer_count = rnn_desc.direction_mode() == CUDNN_UNIDIRECTIONAL
   1230                                 ? rnn_desc.num_layers()
   1231                                 : 2 * rnn_desc.num_layers();
   1232     for (int layer = 0; layer < layer_count; layer++) {
   1233       for (int region = 0; region < region_count_per_layer; region++) {
   1234         for (int type = 0; type < 2; type++) {
   1235           void* offset = nullptr;
   1236           if (type == 0) {
   1237             status = wrap::cudnnGetRNNLinLayerMatrixParams(
   1238                 parent, cudnn_handle /*handle*/, rnn_desc.handle() /*rnnDesc*/,
   1239                 layer /*layer*/, input_desc /*xDesc*/, handle_ /*wDesc*/,
   1240                 nullptr /*w*/, region /*linLayerID*/,
   1241                 region_desc_handle /*linLayerMatDesc*/,
   1242                 &offset /*linLayerMat*/);
   1243             CUDNN_RETURN_IF_FAIL(
   1244                 status, "Cudnn fails to call cudnnGetRNNLinLayerMatrixParams");
   1245           } else {
   1246             status = wrap::cudnnGetRNNLinLayerBiasParams(
    1247                 parent, cudnn_handle /*handle*/, rnn_desc.handle() /*rnnDesc*/,
   1248                 layer /*layer*/, input_desc /*xDesc*/, handle_ /*wDesc*/,
   1249                 nullptr /*w*/, region /*linLayerID*/,
   1250                 region_desc_handle /*linLayerBiasDesc*/,
   1251                 &offset /*linLayerBias*/);
   1252             CUDNN_RETURN_IF_FAIL(
   1253                 status, "Cudnn fails to call cudnnGetRNNLinLayerBiasParams");
   1254           }
   1255           int dims[] = {1, 1, 1};
   1256           cudnnDataType_t data_type;
   1257           cudnnTensorFormat_t tensor_format;
   1258           int n_dims;
   1259           status = wrap::cudnnGetFilterNdDescriptor(
   1260               parent, region_desc_handle /*filterDesc*/,
   1261               sizeof(dims) / sizeof(dims[0]) /*nbDimsRequested*/,
   1262               &data_type /*dataType*/, &tensor_format /*format*/,
   1263               &n_dims /*nbDims*/, dims /*filterDimA*/);
   1264           CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to get filter description");
   1265           int64 size = dims[0] * dims[1] * dims[2] *
   1266                        CudnnDataTypeToByteSize(rnn_desc.data_type());
    1267           ParamsRegion params_region{reinterpret_cast<int64>(offset), size};
    1268           if (type == 0) {
    1269             weights_.push_back(params_region);
    1270           } else {
    1271             biases_.push_back(params_region);
   1272           }
   1273         }
   1274       }
   1275     }
   1276     status = wrap::cudnnDestroyFilterDescriptor(parent, region_desc_handle);
   1277     CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to destroy filter descriptor");
   1278   }
   1279 
   1280   {
   1281     // Release the dummy input tensor descriptor.
   1282     auto status = wrap::cudnnDestroyTensorDescriptor(parent, input_desc);
   1283     CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to destroy tensor descriptor");
   1284   }
   1285 }
   1286 
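         // Returns the number of weight/bias regions ("linear layers") that cuDNN
         // exposes per pseudo-layer: plain RNN cells (ReLU/tanh) have an input and a
         // recurrent projection (2), LSTM has four gates with two projections each
         // (8), and GRU has three gates with two projections each (6).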
   1287 int CudnnRnnParamsDescriptor::GetRegionCountPerLayer() const {
   1288   auto rnn_mode = rnn_desc_->rnn_mode();
   1289   switch (rnn_mode) {
   1290     case CUDNN_RNN_RELU:
   1291     case CUDNN_RNN_TANH:
   1292       return 2;
   1293     case CUDNN_LSTM:
   1294       return 8;
   1295     case CUDNN_GRU:
   1296       return 6;
   1297     default:
   1298       LOG(FATAL) << "Invalid RNN Mode: " << static_cast<int>(rnn_mode);
   1299   }
   1300 }
   1301 
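         // Wraps the array of per-timestep tensor descriptors that the cuDNN RNN API
         // expects for input and output sequences. Since every step shares the same
         // {batch_size, data_size, 1} shape, a single descriptor is created and then
         // replicated seq_length times in handles_.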
   1302 class CudnnRnnSequenceTensorDescriptor
   1303     : public CudnnDescriptorCommon<dnn::RnnSequenceTensorDescriptor> {
   1304  public:
   1305   CudnnRnnSequenceTensorDescriptor(CUDAExecutor* parent, int seq_length,
   1306                                    int batch_size, int data_size,
   1307                                    cudnnDataType_t data_type)
   1308       : parent_(parent),
   1309         seq_length_(seq_length),
   1310         batch_size_(batch_size),
   1311         data_size_(data_size),
   1312         data_type_(data_type) {
   1313     cudnnTensorDescriptor_t handle = nullptr;
   1314     if (seq_length <= 0) {
   1315       string error_msg =
   1316           port::StrCat("sequence length must be positive: ", seq_length);
   1317       LOG(ERROR) << error_msg;
   1318       SetFailure(port::Status(port::error::UNKNOWN, error_msg));
   1319       return;
   1320     }
   1321     cudnnStatus_t status = wrap::cudnnCreateTensorDescriptor(parent, &handle);
   1322     CUDNN_RETURN_IF_FAIL(status, "Failed to create tensor descriptor");
   1323     int dims[] = {batch_size, data_size, 1};
   1324     int strides[] = {dims[1] * dims[2], dims[2], 1};
   1325     status = wrap::cudnnSetTensorNdDescriptor(
   1326         parent, handle /*tensorDesc*/, data_type /*dataType*/,
   1327         sizeof(dims) / sizeof(dims[0]) /*nbDims*/, dims /*dimA*/,
   1328         strides /*strideA*/);
   1329     CUDNN_RETURN_IF_FAIL(status, "Failed to update tensor descriptor");
   1330     // Replicate handle across the number of steps.
   1331     handles_.assign(seq_length, handle);
   1332   }
   1333 
   1334   ~CudnnRnnSequenceTensorDescriptor() override {
   1335     // Only the first one needs to be destroyed. All others are the same.
   1336     cudnnStatus_t status =
   1337         wrap::cudnnDestroyTensorDescriptor(parent_, handles_[0]);
   1338     CUDNN_RETURN_IF_FAIL(status,
   1339                          "Failed to destroy sequence tensor descriptor");
   1340   }
   1341 
   1342   const cudnnTensorDescriptor_t* handles() const {
   1343     if (!ok()) return nullptr;
   1344     CHECK(!handles_.empty()) << "handles cannot be empty";
   1345     return handles_.data();
   1346   }
   1347 
   1348   int seq_length() const { return seq_length_; }
   1349   int batch_size() const { return batch_size_; }
   1350   int data_size() const { return data_size_; }
   1351 
   1352  private:
   1353   CUDAExecutor* parent_;
   1354   int seq_length_;
   1355   int batch_size_;
   1356   int data_size_;
   1357   cudnnDataType_t data_type_;
   1358   std::vector<cudnnTensorDescriptor_t> handles_;
   1359   SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnSequenceTensorDescriptor);
   1360 };
   1361 
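         // Wraps the single 3-D {num_layers, batch_size, data_size} tensor descriptor
         // used for the hidden and cell state arguments; callers are expected to pass
         // num_layers already multiplied by the direction count (see
         // ExtractAndCheckRnnForward below).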
   1362 class CudnnRnnStateTensorDescriptor
   1363     : public CudnnDescriptorCommon<dnn::RnnStateTensorDescriptor> {
   1364  public:
   1365   CudnnRnnStateTensorDescriptor(CUDAExecutor* parent, int num_layers,
   1366                                 int batch_size, int data_size,
   1367                                 cudnnDataType_t data_type)
   1368       : parent_(parent),
   1369         handle_(nullptr),
   1370         num_layers_(num_layers),
   1371         batch_size_(batch_size),
   1372         data_size_(data_size),
   1373         data_type_(data_type) {
   1374     cudnnStatus_t status = wrap::cudnnCreateTensorDescriptor(parent, &handle_);
   1375     CUDNN_RETURN_IF_FAIL(status, "Failed to create tensor descriptor");
   1376     int dims[] = {num_layers, batch_size, data_size};
   1377     int strides[] = {dims[1] * dims[2], dims[2], 1};
   1378     status = wrap::cudnnSetTensorNdDescriptor(
   1379         parent, handle_ /*tensorDesc*/, data_type /*dataType*/,
   1380         sizeof(dims) / sizeof(dims[0]) /*nbDims*/, dims /*dimA*/,
   1381         strides /*strideA*/);
   1382     CUDNN_RETURN_IF_FAIL(status, "Failed to update tensor descriptor");
   1383   }
   1384 
   1385   ~CudnnRnnStateTensorDescriptor() override {
    1386     if (handle_ != nullptr) {
   1387       cudnnStatus_t status =
   1388           wrap::cudnnDestroyTensorDescriptor(parent_, handle_);
   1389       CUDNN_RETURN_IF_FAIL(status, "Unable to destroy RNN state tensor");
   1390     }
   1391   }
   1392 
   1393   cudnnTensorDescriptor_t handle() const {
   1394     if (!ok()) return nullptr;
   1395     return handle_;
   1396   }
   1397   int num_layers() const { return num_layers_; }
   1398   int batch_size() const { return batch_size_; }
   1399   int data_size() const { return data_size_; }
   1400 
   1401  private:
   1402   CUDAExecutor* parent_;
   1403   cudnnTensorDescriptor_t handle_;
   1404   int num_layers_;
   1405   int batch_size_;
   1406   int data_size_;
   1407   cudnnDataType_t data_type_;
   1408   SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnStateTensorDescriptor);
   1409 };
   1410 
   1411 namespace {
   1412 
   1413 struct RnnModelDims {
   1414   int num_layers = 0;
   1415   int batch_size = 0;
   1416   int seq_length = 0;
   1417   int hidden_size = 0;
   1418   int input_size = 0;
   1419   int dir_count = 0;
   1420 };
   1421 
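         // Populates |model_dims| from the RNN and input descriptors and checks that
         // the hidden/cell state and output descriptors are mutually consistent, e.g.
         // the output feature size must equal hidden_size * dir_count. Returns false
         // (and logs) on any mismatch.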
   1422 template <class T>
   1423 bool ExtractAndCheckRnnForward(
   1424     const CudnnRnnDescriptor& rnn_desc,
   1425     const CudnnRnnSequenceTensorDescriptor& input_desc,
   1426     const DeviceMemory<T>& input_data,
   1427     const CudnnRnnStateTensorDescriptor& input_h_desc,
   1428     const DeviceMemory<T>& input_h_data,
   1429     const CudnnRnnStateTensorDescriptor& input_c_desc,
   1430     const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
   1431     const CudnnRnnSequenceTensorDescriptor& output_desc,
   1432     const DeviceMemory<T>& output_data,
   1433     const CudnnRnnStateTensorDescriptor& output_h_desc,
   1434     const DeviceMemory<T>& output_h_data,
   1435     const CudnnRnnStateTensorDescriptor& output_c_desc,
   1436     const DeviceMemory<T>& output_c_data, RnnModelDims* model_dims) {
   1437   // extract model parameters
   1438   model_dims->num_layers = rnn_desc.num_layers();
   1439   model_dims->batch_size = input_desc.batch_size();
   1440   model_dims->seq_length = input_desc.seq_length();
   1441   model_dims->hidden_size = rnn_desc.hidden_size();
   1442   model_dims->input_size = input_desc.data_size();
   1443   model_dims->dir_count =
   1444       (rnn_desc.direction_mode() == CUDNN_BIDIRECTIONAL) ? 2 : 1;
   1445 
   1446   // check parameters
   1447   if (!(input_h_desc.num_layers() ==
   1448             model_dims->num_layers * model_dims->dir_count &&
   1449         input_h_desc.batch_size() == model_dims->batch_size &&
   1450         input_h_desc.data_size() == model_dims->hidden_size)) {
   1451     LOG(ERROR) << "Invalid input_h shape";
   1452     return false;
   1453   }
   1454   if (!(input_h_desc.num_layers() == input_c_desc.num_layers() &&
   1455         input_h_desc.batch_size() == input_c_desc.batch_size() &&
   1456         input_h_desc.data_size() == input_c_desc.data_size())) {
   1457     LOG(ERROR) << "Invalid input_c shape";
   1458     return false;
   1459   }
   1460   if (!(output_desc.seq_length() == model_dims->seq_length &&
   1461         output_desc.batch_size() == model_dims->batch_size &&
   1462         output_desc.data_size() ==
   1463             model_dims->hidden_size * model_dims->dir_count)) {
   1464     LOG(ERROR) << "Invalid output shape";
   1465     return false;
   1466   }
   1467   if (!(input_h_desc.num_layers() == output_h_desc.num_layers() &&
   1468         input_h_desc.batch_size() == output_h_desc.batch_size() &&
   1469         input_h_desc.data_size() == output_h_desc.data_size())) {
   1470     LOG(ERROR) << "Invalid output_h shape";
   1471     return false;
   1472   }
   1473   if (!(input_h_desc.num_layers() == output_c_desc.num_layers() &&
   1474         input_h_desc.batch_size() == output_c_desc.batch_size() &&
   1475         input_h_desc.data_size() == output_c_desc.data_size())) {
    1476     LOG(ERROR) << "Invalid output_c shape";
   1477     return false;
   1478   }
   1479 
   1480   return true;
   1481 }
   1482 
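         // Returns true iff the params size reported by cudnnGetRNNParamsSize for
         // this input descriptor matches the size recorded when the RNN descriptor
         // was constructed.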
   1483 bool CheckRNNParameterSize(CUDAExecutor* parent, cudnnHandle_t cudnn_handle,
   1484                            const CudnnRnnDescriptor& rnn_desc,
   1485                            const CudnnRnnSequenceTensorDescriptor& input_desc) {
   1486   size_t params_size_in_bytes = 0;
   1487   cudnnStatus_t status = wrap::cudnnGetRNNParamsSize(
   1488       parent, cudnn_handle /*handle*/, rnn_desc.handle() /*rnnDesc*/,
   1489       input_desc.handles()[0] /*xDesc*/, &params_size_in_bytes /*sizeInBytes*/,
   1490       rnn_desc.data_type() /*dataType*/);
   1491   if (status != CUDNN_STATUS_SUCCESS) {
   1492     LOG(ERROR) << "Unable to check RNN param size: " << ToString(status);
   1493     return false;
   1494   }
   1495   return static_cast<int64>(params_size_in_bytes) ==
   1496          rnn_desc.ParamsSizeInBytes();
   1497 }
   1498 
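         // Queries the workspace size required by the RNN for this input and
         // allocates it from |workspace_allocator|; leaves |workspace| empty when no
         // scratch memory is needed.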
   1499 bool CreateRnnWorkspace(Stream* stream, CUDAExecutor* parent,
   1500                         cudnnHandle_t cudnn_handle,
   1501                         const CudnnRnnDescriptor& rnn_desc,
   1502                         const CudnnRnnSequenceTensorDescriptor& input_desc,
   1503                         ScratchAllocator* workspace_allocator,
   1504                         DeviceMemory<uint8>* workspace) {
   1505   // Query the workspace size.
   1506   size_t workspace_size_in_bytes = 0;
   1507   cudnnStatus_t status = wrap::cudnnGetRNNWorkspaceSize(
   1508       parent, cudnn_handle /*handle*/, rnn_desc.handle() /*rnnDesc*/,
   1509       input_desc.seq_length() /*seqLength*/, input_desc.handles() /*xDesc*/,
   1510       &workspace_size_in_bytes /*sizeInBytes*/);
   1511   if (status != CUDNN_STATUS_SUCCESS) {
   1512     LOG(ERROR) << "Unable to query workspace size: " << ToString(status);
   1513     return false;
   1514   }
   1515   // Allocate the workspace.
   1516   if (workspace_size_in_bytes > 0) {
   1517     auto allocated =
   1518         workspace_allocator->AllocateBytes(stream, workspace_size_in_bytes);
   1519     if (!allocated.ok() || (*workspace = allocated.ValueOrDie()) == nullptr) {
   1520       LOG(ERROR) << port::StrCat("Failed to allocate RNN workspace of ",
   1521                                  workspace_size_in_bytes, " bytes.");
   1522       return false;
   1523     }
   1524   } else {
   1525     *workspace = DeviceMemory<uint8>();
   1526   }
   1527   return true;
   1528 }
   1529 
   1530 }  // namespace
   1531 
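         // Shared implementation of the RNN forward pass: validates the descriptor
         // shapes, checks the params buffer size, creates the workspace (and, when
         // training, the reserve space), then dispatches to cudnnRNNForwardInference
         // or cudnnRNNForwardTraining.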
   1532 template <class T>
   1533 bool CudnnSupport::DoRnnForwardImpl(
   1534     Stream* stream, const CudnnRnnDescriptor& rnn_desc,
   1535     const CudnnRnnSequenceTensorDescriptor& input_desc,
   1536     const DeviceMemory<T>& input_data,
   1537     const CudnnRnnStateTensorDescriptor& input_h_desc,
   1538     const DeviceMemory<T>& input_h_data,
   1539     const CudnnRnnStateTensorDescriptor& input_c_desc,
   1540     const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
   1541     const CudnnRnnSequenceTensorDescriptor& output_desc,
   1542     DeviceMemory<T>* output_data,
   1543     const CudnnRnnStateTensorDescriptor& output_h_desc,
   1544     DeviceMemory<T>* output_h_data,
   1545     const CudnnRnnStateTensorDescriptor& output_c_desc,
   1546     DeviceMemory<T>* output_c_data, bool is_training,
   1547     ScratchAllocator* reserve_space_allocator,
   1548     ScratchAllocator* workspace_allocator) {
   1549   // extract model parameters
   1550   RnnModelDims model_dims;
   1551   bool res = ExtractAndCheckRnnForward(
   1552       rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
   1553       input_c_desc, input_c_data, params, output_desc, *output_data,
   1554       output_h_desc, *output_h_data, output_c_desc, *output_c_data,
   1555       &model_dims);
   1556   if (!res) {
   1557     LOG(ERROR) << "Invalid parameters for RNN Model";
   1558     return false;
   1559   }
   1560 
   1561   // check params size
   1562   mutex_lock lock{dnn_handle_mutex_};
   1563 
   1564   if (!CheckRNNParameterSize(parent_, ToHandle(dnn_handle_), rnn_desc,
   1565                              input_desc)) {
    1566     LOG(ERROR) << "Invalid RNN parameter size";
   1567     return false;
   1568   }
   1569 
   1570   // create the workspace
   1571   DeviceMemory<uint8> workspace;
   1572   if (!CreateRnnWorkspace(stream, parent_, ToHandle(dnn_handle_), rnn_desc,
   1573                           input_desc, workspace_allocator, &workspace)) {
   1574     LOG(ERROR) << "Unable to create rnn workspace";
   1575     return false;
   1576   }
   1577 
    1578   // Query the reserve space size and allocate it when training; inference
    1579   // does not need a reserve space.
   1580   DeviceMemory<uint8> reserve_space;
   1581   if (is_training) {
   1582     size_t reserve_space_size_in_bytes = 0;
   1583     cudnnStatus_t status = wrap::cudnnGetRNNTrainingReserveSize(
   1584         parent_, ToHandle(dnn_handle_) /*handle*/,
   1585         rnn_desc.handle() /*rnnDesc*/, model_dims.seq_length /*seqLength*/,
   1586         input_desc.handles() /*xDesc*/,
   1587         &reserve_space_size_in_bytes /*sizeInBytes*/);
   1588     if (status != CUDNN_STATUS_SUCCESS) {
   1589       LOG(ERROR) << "Unable to query reserve space size: " << ToString(status);
   1590       return false;
   1591     }
   1592 
   1593     if (reserve_space_size_in_bytes > 0) {
   1594       auto allocated = reserve_space_allocator->AllocateBytes(
   1595           stream, reserve_space_size_in_bytes);
   1596       if (!allocated.ok() ||
   1597           (reserve_space = allocated.ValueOrDie()) == nullptr) {
   1598         LOG(ERROR) << "Failed to allocate RNN reserve space of "
   1599                    << reserve_space_size_in_bytes << " bytes.";
   1600         return false;
   1601       }
   1602     }
   1603   }
   1604 
   1605   // make the forward call
   1606   if (!is_training) {
   1607     cudnnStatus_t status = wrap::cudnnRNNForwardInference(
   1608         parent_, ToHandle(dnn_handle_) /*handle*/,
   1609         rnn_desc.handle() /*rnnDesc*/, model_dims.seq_length /*seqLength*/,
   1610         input_desc.handles() /*xDesc*/, input_data.opaque() /*x*/,
   1611         input_h_desc.handle() /*hxDesc*/, input_h_data.opaque() /*hx*/,
   1612         input_c_desc.handle() /*cxDesc*/, input_c_data.opaque() /*cx*/,
   1613         rnn_desc.params_handle() /*wDesc*/, params.opaque() /*w*/,
   1614         output_desc.handles() /*yDesc*/, output_data->opaque() /*y*/,
   1615         output_h_desc.handle() /*hyDesc*/, output_h_data->opaque() /*hy*/,
   1616         output_c_desc.handle() /*cyDesc*/, output_c_data->opaque() /*cy*/,
   1617         workspace.opaque() /*workspace*/,
   1618         workspace.size() /*workSpaceSizeInBytes*/);
   1619     if (status != CUDNN_STATUS_SUCCESS) {
   1620       LOG(ERROR) << "Failed to call cudnnRNNForwardInference: "
   1621                  << ToString(status);
   1622       return false;
   1623     }
   1624   } else {
   1625     cudnnStatus_t status = wrap::cudnnRNNForwardTraining(
   1626         parent_, ToHandle(dnn_handle_) /*handle*/,
   1627         rnn_desc.handle() /*rnnDesc*/, model_dims.seq_length /*seqLength*/,
   1628         input_desc.handles() /*xDesc*/, input_data.opaque() /*x*/,
   1629         input_h_desc.handle() /*hxDesc*/, input_h_data.opaque() /*hx*/,
   1630         input_c_desc.handle() /*cxDesc*/, input_c_data.opaque() /*cx*/,
   1631         rnn_desc.params_handle() /*wDesc*/, params.opaque() /*w*/,
   1632         output_desc.handles() /*yDesc*/, output_data->opaque() /*y*/,
   1633         output_h_desc.handle() /*hyDesc*/, output_h_data->opaque() /*hy*/,
   1634         output_c_desc.handle() /*cyDesc*/, output_c_data->opaque() /*cy*/,
   1635         workspace.opaque() /*workspace*/,
   1636         workspace.size() /*workSpaceSizeInBytes*/,
   1637         reserve_space.opaque() /*reserveSpace*/,
   1638         reserve_space.size() /*reserveSpaceSizeInBytes*/);
   1639     if (status != CUDNN_STATUS_SUCCESS) {
    1640       LOG(ERROR) << "Failed to call cudnnRNNForwardTraining: "
   1641                  << ToString(status);
   1642       return false;
   1643     }
   1644   }
   1645 
   1646   return true;
   1647 }
   1648 
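         // Shared implementation of the RNN backward pass: reuses the forward shape
         // checks, calls cudnnRNNBackwardData, and, when a params gradient buffer is
         // provided, zeroes it and calls cudnnRNNBackwardWeights (which accumulates
         // into dw rather than overwriting it).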
   1649 template <class T>
   1650 bool CudnnSupport::DoRnnBackwardImpl(
   1651     Stream* stream, const CudnnRnnDescriptor& rnn_desc,
   1652     const CudnnRnnSequenceTensorDescriptor& input_desc,
   1653     const DeviceMemory<T>& input_data,
   1654     const CudnnRnnStateTensorDescriptor& input_h_desc,
   1655     const DeviceMemory<T>& input_h_data,
   1656     const CudnnRnnStateTensorDescriptor& input_c_desc,
   1657     const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
   1658     const CudnnRnnSequenceTensorDescriptor& output_desc,
   1659     const DeviceMemory<T>& output_data,
   1660     const CudnnRnnStateTensorDescriptor& output_h_desc,
   1661     const DeviceMemory<T>& output_h_data,
   1662     const CudnnRnnStateTensorDescriptor& output_c_desc,
   1663     const DeviceMemory<T>& output_c_data,
   1664     const DeviceMemory<T>& output_backprop_data,
   1665     const DeviceMemory<T>& output_h_backprop_data,
   1666     const DeviceMemory<T>& output_c_backprop_data,
   1667     DeviceMemory<T>* input_backprop_data,
   1668     DeviceMemory<T>* input_h_backprop_data,
   1669     DeviceMemory<T>* input_c_backprop_data,
   1670     DeviceMemory<T>* params_backprop_data,
   1671     DeviceMemory<uint8>* reserve_space_data,
   1672     ScratchAllocator* workspace_allocator) {
   1673   // extract model parameters
   1674   RnnModelDims model_dims;
   1675   bool res = ExtractAndCheckRnnForward(
   1676       rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
   1677       input_c_desc, input_c_data, params, output_desc, output_data,
   1678       output_h_desc, output_h_data, output_c_desc, output_c_data, &model_dims);
   1679   if (!res) {
   1680     LOG(ERROR) << "Invalid parameters for RNN Model";
   1681     return false;
   1682   }
   1683 
   1684   // check params size
   1685   mutex_lock lock{dnn_handle_mutex_};
   1686 
   1687   if (!CheckRNNParameterSize(parent_, ToHandle(dnn_handle_), rnn_desc,
   1688                              input_desc)) {
    1689     LOG(ERROR) << "Invalid RNN parameter size";
   1690     return false;
   1691   }
   1692 
   1693   // create the workspace
   1694   DeviceMemory<uint8> workspace;
   1695   if (!CreateRnnWorkspace(stream, parent_, ToHandle(dnn_handle_), rnn_desc,
   1696                           input_desc, workspace_allocator, &workspace)) {
   1697     LOG(ERROR) << "Unable to create rnn workspace";
   1698     return false;
   1699   }
   1700 
   1701   // make the backward data call
   1702   cudnnStatus_t status = wrap::cudnnRNNBackwardData(
   1703       parent_, ToHandle(dnn_handle_) /*handle*/, rnn_desc.handle() /*rnnDesc*/,
   1704       model_dims.seq_length /*seqLength*/, output_desc.handles() /*yDesc*/,
   1705       output_data.opaque() /*y*/, output_desc.handles() /*dyDesc*/,
   1706       output_backprop_data.opaque() /*dy*/, output_h_desc.handle() /*dhyDesc*/,
   1707       output_h_backprop_data.opaque() /*dhy*/,
   1708       output_c_desc.handle() /*dcyDesc*/,
   1709       output_c_backprop_data.opaque() /*dcy*/,
   1710       rnn_desc.params_handle() /*wDesc*/, params.opaque() /*w*/,
   1711       input_h_desc.handle() /*hxDesc*/, input_h_data.opaque() /*hx*/,
   1712       input_c_desc.handle() /*cxDesc*/, input_c_data.opaque() /*cx*/,
   1713       input_desc.handles() /*dxDesc*/, input_backprop_data->opaque() /*dx*/,
   1714       input_h_desc.handle() /*dhxDesc*/,
   1715       input_h_backprop_data->opaque() /*dhx*/,
   1716       input_c_desc.handle() /*dcxDesc*/,
   1717       input_c_backprop_data->opaque() /*dcx*/, workspace.opaque() /*workspace*/,
   1718       workspace.size() /*workSpaceSizeInBytes*/,
   1719       reserve_space_data->opaque() /*reserveSpace*/,
   1720       reserve_space_data->size() /*reserveSpaceSizeInBytes*/);
   1721   if (status != CUDNN_STATUS_SUCCESS) {
   1722     LOG(ERROR) << "Failed to call cudnnRNNBackwardData: " << ToString(status);
   1723     return false;
   1724   }
   1725 
   1726   if (params_backprop_data != nullptr) {
   1727     // Clear the dw to zeros.
   1728     stream->ThenMemZero(params_backprop_data, params_backprop_data->size());
   1729     // make the backward weight call
   1730     status = wrap::cudnnRNNBackwardWeights(
   1731         parent_, ToHandle(dnn_handle_) /*handle*/,
   1732         rnn_desc.handle() /*rnnDesc*/, model_dims.seq_length /*seqLength*/,
   1733         input_desc.handles() /*xDesc*/, input_data.opaque() /*x*/,
   1734         input_h_desc.handle() /*hxDesc*/, input_h_data.opaque() /*hx*/,
   1735         output_desc.handles() /*yDesc*/, output_data.opaque() /*y*/,
   1736         workspace.opaque() /*workspace*/,
   1737         workspace.size() /*workSpaceSizeInBytes*/,
   1738         rnn_desc.params_handle() /*dwDesc*/,
   1739         params_backprop_data->opaque() /*dw*/,
   1740         reserve_space_data->opaque() /*reserveSpace*/,
   1741         reserve_space_data->size() /*reserveSpaceSizeInBytes*/);
   1742     if (status != CUDNN_STATUS_SUCCESS) {
   1743       LOG(ERROR) << "Failed to call cudnnRNNBackwardWeights: "
   1744                  << ToString(status);
   1745       return false;
   1746     }
   1747   }
   1748 
   1749   return true;
   1750 }
   1751 
   1752 #endif  // CUDNN_VERSION
   1753 
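         // Illustrative usage sketch only (not part of the build); the enum spellings
         // below are assumed from dnn.h and `state_allocator` is a placeholder:
         //
         //   auto rnn_desc_status = dnn_support->createRnnDescriptor(
         //       /*num_layers=*/2, /*hidden_size=*/128, /*input_size=*/64,
         //       dnn::RnnInputMode::kRnnLinearSkip,
         //       dnn::RnnDirectionMode::kRnnUnidirectional, dnn::RnnMode::kRnnLstm,
         //       dnn::DataType::kFloat, /*dropout=*/0.0f, /*seed=*/0,
         //       state_allocator);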
   1754 port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
   1755 CudnnSupport::createRnnDescriptor(int num_layers, int hidden_size,
   1756                                   int input_size, dnn::RnnInputMode input_mode,
   1757                                   dnn::RnnDirectionMode direction_mode,
   1758                                   dnn::RnnMode rnn_mode,
   1759                                   dnn::DataType data_type, float dropout,
   1760                                   uint64 seed,
   1761                                   ScratchAllocator* state_allocator) {
   1762 #if CUDNN_VERSION >= 5000
   1763   mutex_lock lock{dnn_handle_mutex_};
   1764   std::unique_ptr<CudnnRnnDescriptor> rnn_desc(new CudnnRnnDescriptor(
   1765       parent_, ToHandle(dnn_handle_), num_layers, hidden_size, input_size,
   1766       ToCudnnRnnInputMode(input_mode), ToCudnnRnnDirectionMode(direction_mode),
   1767       ToCudnnRnnMode(rnn_mode), ToCudnnDataType(data_type), dropout, seed,
   1768       state_allocator));
   1769   if (!rnn_desc->ok()) {
   1770     return rnn_desc->Status();
   1771   }
   1772   return port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>(
   1773       std::move(rnn_desc));
   1774 #else
   1775   string error_msg =
   1776       port::StrCat("createRnnDescriptor needs at least Cudnn 5.0 to work. ",
   1777                    "Current Cudnn version: ", CUDNN_VERSION, ". ");
   1778   LOG(ERROR) << error_msg;
   1779   return port::Status{port::error::UNIMPLEMENTED, error_msg};
   1780 #endif  // CUDNN_VERSION
   1781 }
   1782 
   1783 port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
   1784 CudnnSupport::createRnnSequenceTensorDescriptor(int seq_length, int batch_size,
   1785                                                 int data_size,
   1786                                                 dnn::DataType data_type) {
   1787 #if CUDNN_VERSION >= 5000
   1788   std::unique_ptr<CudnnRnnSequenceTensorDescriptor> seq_desc(
   1789       new CudnnRnnSequenceTensorDescriptor(parent_, seq_length, batch_size,
   1790                                            data_size,
   1791                                            ToCudnnDataType(data_type)));
   1792   if (!seq_desc->ok()) {
   1793     return seq_desc->Status();
   1794   }
   1795   return port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>(
   1796       std::move(seq_desc));
   1797 #else
   1798   string error_msg = port::StrCat(
   1799       "createRnnSequenceTensorDescriptor needs at least Cudnn 5.0 to work. ",
   1800       "Current Cudnn version: ", CUDNN_VERSION, ". ");
   1801   LOG(ERROR) << error_msg;
   1802   return port::Status{port::error::UNIMPLEMENTED, error_msg};
   1803 #endif  // CUDNN_VERSION
   1804 }
   1805 
   1806 port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
   1807 CudnnSupport::createRnnStateTensorDescriptor(int num_layer, int batch_size,
   1808                                              int data_size,
   1809                                              dnn::DataType data_type) {
   1810 #if CUDNN_VERSION >= 5000
   1811   std::unique_ptr<CudnnRnnStateTensorDescriptor> state_desc(
   1812       new CudnnRnnStateTensorDescriptor(parent_, num_layer, batch_size,
   1813                                         data_size, ToCudnnDataType(data_type)));
   1814   if (!state_desc->ok()) {
   1815     return state_desc->Status();
   1816   }
   1817   return port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>(
   1818       std::move(state_desc));
   1819 #else
   1820   string error_msg = port::StrCat(
   1821       "createRnnStateTensorDescriptor needs at least Cudnn 5.0 to work. ",
   1822       "Current Cudnn version: ", CUDNN_VERSION, ". ");
   1823   LOG(ERROR) << error_msg;
   1824   return port::Status{port::error::UNIMPLEMENTED, error_msg};
   1825 #endif  // CUDNN_VERSION
   1826 }
   1827 
   1828 bool CudnnSupport::DoRnnForward(
   1829     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
   1830     const dnn::RnnSequenceTensorDescriptor& input_desc,
   1831     const DeviceMemory<Eigen::half>& input_data,
   1832     const dnn::RnnStateTensorDescriptor& input_h_desc,
   1833     const DeviceMemory<Eigen::half>& input_h_data,
   1834     const dnn::RnnStateTensorDescriptor& input_c_desc,
   1835     const DeviceMemory<Eigen::half>& input_c_data,
   1836     const DeviceMemory<Eigen::half>& params,
   1837     const dnn::RnnSequenceTensorDescriptor& output_desc,
   1838     DeviceMemory<Eigen::half>* output_data,
   1839     const dnn::RnnStateTensorDescriptor& output_h_desc,
   1840     DeviceMemory<Eigen::half>* output_h_data,
   1841     const dnn::RnnStateTensorDescriptor& output_c_desc,
   1842     DeviceMemory<Eigen::half>* output_c_data, bool is_training,
   1843     ScratchAllocator* reserve_space_allocator,
   1844     ScratchAllocator* workspace_allocator) {
   1845 #if CUDNN_VERSION >= 5000
   1846   const CudnnRnnDescriptor& cudnn_rnn_desc =
   1847       static_cast<const CudnnRnnDescriptor&>(rnn_desc);
   1848   const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
   1849       static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
   1850   const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
   1851       static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
   1852   const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
   1853       static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
   1854   const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
   1855       static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
   1856   const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
   1857       static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
   1858   const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
   1859       static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
   1860 
   1861   return DoRnnForwardImpl<Eigen::half>(
   1862       stream, cudnn_rnn_desc, cudnn_input_desc, input_data, cudnn_input_h_desc,
   1863       input_h_data, cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
   1864       output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
   1865       output_c_data, is_training, reserve_space_allocator, workspace_allocator);
   1866 #else
   1867   return false;
   1868 #endif  // CUDNN_VERSION
   1869 }
   1870 
   1871 bool CudnnSupport::DoRnnForward(
   1872     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
   1873     const dnn::RnnSequenceTensorDescriptor& input_desc,
   1874     const DeviceMemory<float>& input_data,
   1875     const dnn::RnnStateTensorDescriptor& input_h_desc,
   1876     const DeviceMemory<float>& input_h_data,
   1877     const dnn::RnnStateTensorDescriptor& input_c_desc,
   1878     const DeviceMemory<float>& input_c_data, const DeviceMemory<float>& params,
   1879     const dnn::RnnSequenceTensorDescriptor& output_desc,
   1880     DeviceMemory<float>* output_data,
   1881     const dnn::RnnStateTensorDescriptor& output_h_desc,
   1882     DeviceMemory<float>* output_h_data,
   1883     const dnn::RnnStateTensorDescriptor& output_c_desc,
   1884     DeviceMemory<float>* output_c_data, bool is_training,
   1885     ScratchAllocator* reserve_space_allocator,
   1886     ScratchAllocator* workspace_allocator) {
   1887 #if CUDNN_VERSION >= 5000
   1888   const CudnnRnnDescriptor& cudnn_rnn_desc =
   1889       static_cast<const CudnnRnnDescriptor&>(rnn_desc);
   1890   const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
   1891       static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
   1892   const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
   1893       static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
   1894   const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
   1895       static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
   1896   const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
   1897       static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
   1898   const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
   1899       static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
   1900   const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
   1901       static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
   1902 
   1903   return DoRnnForwardImpl<float>(
   1904       stream, cudnn_rnn_desc, cudnn_input_desc, input_data, cudnn_input_h_desc,
   1905       input_h_data, cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
   1906       output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
   1907       output_c_data, is_training, reserve_space_allocator, workspace_allocator);
   1908 #else
   1909   return false;
   1910 #endif  // CUDNN_VERSION
   1911 }
   1912 
   1913 bool CudnnSupport::DoRnnForward(
   1914     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
   1915     const dnn::RnnSequenceTensorDescriptor& input_desc,
   1916     const DeviceMemory<double>& input_data,
   1917     const dnn::RnnStateTensorDescriptor& input_h_desc,
   1918     const DeviceMemory<double>& input_h_data,
   1919     const dnn::RnnStateTensorDescriptor& input_c_desc,
   1920     const DeviceMemory<double>& input_c_data,
   1921     const DeviceMemory<double>& params,
   1922     const dnn::RnnSequenceTensorDescriptor& output_desc,
   1923     DeviceMemory<double>* output_data,
   1924     const dnn::RnnStateTensorDescriptor& output_h_desc,
   1925     DeviceMemory<double>* output_h_data,
   1926     const dnn::RnnStateTensorDescriptor& output_c_desc,
   1927     DeviceMemory<double>* output_c_data, bool is_training,
   1928     ScratchAllocator* reserve_space_allocator,
   1929     ScratchAllocator* workspace_allocator) {
   1930 #if CUDNN_VERSION >= 5000
   1931   const CudnnRnnDescriptor& cudnn_rnn_desc =
   1932       static_cast<const CudnnRnnDescriptor&>(rnn_desc);
   1933   const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
   1934       static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
   1935   const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
   1936       static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
   1937   const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
   1938       static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
   1939   const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
   1940       static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
   1941   const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
   1942       static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
   1943   const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
   1944       static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
   1945 
   1946   return DoRnnForwardImpl<double>(
   1947       stream, cudnn_rnn_desc, cudnn_input_desc, input_data, cudnn_input_h_desc,
   1948       input_h_data, cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
   1949       output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
   1950       output_c_data, is_training, reserve_space_allocator, workspace_allocator);
   1951 #else
   1952   return false;
   1953 #endif  // CUDNN_VERSION
   1954 }
   1955 
   1956 bool CudnnSupport::DoRnnBackward(
   1957     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
   1958     const dnn::RnnSequenceTensorDescriptor& input_desc,
   1959     const DeviceMemory<Eigen::half>& input_data,
   1960     const dnn::RnnStateTensorDescriptor& input_h_desc,
   1961     const DeviceMemory<Eigen::half>& input_h_data,
   1962     const dnn::RnnStateTensorDescriptor& input_c_desc,
   1963     const DeviceMemory<Eigen::half>& input_c_data,
   1964     const DeviceMemory<Eigen::half>& params,
   1965     const dnn::RnnSequenceTensorDescriptor& output_desc,
   1966     const DeviceMemory<Eigen::half>& output_data,
   1967     const dnn::RnnStateTensorDescriptor& output_h_desc,
   1968     const DeviceMemory<Eigen::half>& output_h_data,
   1969     const dnn::RnnStateTensorDescriptor& output_c_desc,
   1970     const DeviceMemory<Eigen::half>& output_c_data,
   1971     const DeviceMemory<Eigen::half>& output_backprop_data,
   1972     const DeviceMemory<Eigen::half>& output_h_backprop_data,
   1973     const DeviceMemory<Eigen::half>& output_c_backprop_data,
   1974     DeviceMemory<Eigen::half>* input_backprop_data,
   1975     DeviceMemory<Eigen::half>* input_h_backprop_data,
   1976     DeviceMemory<Eigen::half>* input_c_backprop_data,
   1977     DeviceMemory<Eigen::half>* params_backprop_data,
   1978     DeviceMemory<uint8>* reserve_space_data,
   1979     ScratchAllocator* workspace_allocator) {
   1980 #if CUDNN_VERSION >= 5000
   1981   const CudnnRnnDescriptor& cudnn_rnn_desc =
   1982       static_cast<const CudnnRnnDescriptor&>(rnn_desc);
   1983   const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
   1984       static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
   1985   const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
   1986       static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
   1987   const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
   1988       static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
   1989   const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
   1990       static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
   1991   const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
   1992       static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
   1993   const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
   1994       static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
   1995 
   1996   return DoRnnBackwardImpl<Eigen::half>(
   1997       stream, cudnn_rnn_desc, cudnn_input_desc, input_data, cudnn_input_h_desc,
   1998       input_h_data, cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
   1999       output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
   2000       output_c_data, output_backprop_data, output_h_backprop_data,
   2001       output_c_backprop_data, input_backprop_data, input_h_backprop_data,
   2002       input_c_backprop_data, params_backprop_data, reserve_space_data,
   2003       workspace_allocator);
   2004 #else
   2005   return false;
   2006 #endif  // CUDNN_VERSION
   2007 }
   2008 
   2009 bool CudnnSupport::DoRnnBackward(
   2010     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
   2011     const dnn::RnnSequenceTensorDescriptor& input_desc,
   2012     const DeviceMemory<float>& input_data,
   2013     const dnn::RnnStateTensorDescriptor& input_h_desc,
   2014     const DeviceMemory<float>& input_h_data,
   2015     const dnn::RnnStateTensorDescriptor& input_c_desc,
   2016     const DeviceMemory<float>& input_c_data, const DeviceMemory<float>& params,
   2017     const dnn::RnnSequenceTensorDescriptor& output_desc,
   2018     const DeviceMemory<float>& output_data,
   2019     const dnn::RnnStateTensorDescriptor& output_h_desc,
   2020     const DeviceMemory<float>& output_h_data,
   2021     const dnn::RnnStateTensorDescriptor& output_c_desc,
   2022     const DeviceMemory<float>& output_c_data,
   2023     const DeviceMemory<float>& output_backprop_data,
   2024     const DeviceMemory<float>& output_h_backprop_data,
   2025     const DeviceMemory<float>& output_c_backprop_data,
   2026     DeviceMemory<float>* input_backprop_data,
   2027     DeviceMemory<float>* input_h_backprop_data,
   2028     DeviceMemory<float>* input_c_backprop_data,
   2029     DeviceMemory<float>* params_backprop_data,
   2030     DeviceMemory<uint8>* reserve_space_data,
   2031     ScratchAllocator* workspace_allocator) {
   2032 #if CUDNN_VERSION >= 5000
   2033   const CudnnRnnDescriptor& cudnn_rnn_desc =
   2034       static_cast<const CudnnRnnDescriptor&>(rnn_desc);
   2035   const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
   2036       static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
   2037   const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
   2038       static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
   2039   const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
   2040       static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
   2041   const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
   2042       static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
   2043   const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
   2044       static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
   2045   const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
   2046       static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
   2047 
   2048   return DoRnnBackwardImpl<float>(
   2049       stream, cudnn_rnn_desc, cudnn_input_desc, input_data, cudnn_input_h_desc,
   2050       input_h_data, cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
   2051       output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
   2052       output_c_data, output_backprop_data, output_h_backprop_data,
   2053       output_c_backprop_data, input_backprop_data, input_h_backprop_data,
   2054       input_c_backprop_data, params_backprop_data, reserve_space_data,
   2055       workspace_allocator);
   2056 #else
   2057   return false;
   2058 #endif  // CUDNN_VERSION
   2059 }
   2060 
   2061 bool CudnnSupport::DoRnnBackward(
   2062     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
   2063     const dnn::RnnSequenceTensorDescriptor& input_desc,
   2064     const DeviceMemory<double>& input_data,
   2065     const dnn::RnnStateTensorDescriptor& input_h_desc,
   2066     const DeviceMemory<double>& input_h_data,
   2067     const dnn::RnnStateTensorDescriptor& input_c_desc,
   2068     const DeviceMemory<double>& input_c_data,
   2069     const DeviceMemory<double>& params,
   2070     const dnn::RnnSequenceTensorDescriptor& output_desc,
   2071     const DeviceMemory<double>& output_data,
   2072     const dnn::RnnStateTensorDescriptor& output_h_desc,
   2073     const DeviceMemory<double>& output_h_data,
   2074     const dnn::RnnStateTensorDescriptor& output_c_desc,
   2075     const DeviceMemory<double>& output_c_data,
   2076     const DeviceMemory<double>& output_backprop_data,
   2077     const DeviceMemory<double>& output_h_backprop_data,
   2078     const DeviceMemory<double>& output_c_backprop_data,
   2079     DeviceMemory<double>* input_backprop_data,
   2080     DeviceMemory<double>* input_h_backprop_data,
   2081     DeviceMemory<double>* input_c_backprop_data,
   2082     DeviceMemory<double>* params_backprop_data,
   2083     DeviceMemory<uint8>* reserve_space_data,
   2084     ScratchAllocator* workspace_allocator) {
   2085 #if CUDNN_VERSION >= 5000
   2086   const CudnnRnnDescriptor& cudnn_rnn_desc =
   2087       static_cast<const CudnnRnnDescriptor&>(rnn_desc);
   2088   const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
   2089       static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
   2090   const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
   2091       static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
   2092   const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
   2093       static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
   2094   const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
   2095       static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
   2096   const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
   2097       static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
   2098   const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
   2099       static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
   2100 
   2101   return DoRnnBackwardImpl<double>(
   2102       stream, cudnn_rnn_desc, cudnn_input_desc, input_data, cudnn_input_h_desc,
   2103       input_h_data, cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
   2104       output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
   2105       output_c_data, output_backprop_data, output_h_backprop_data,
   2106       output_c_backprop_data, input_backprop_data, input_h_backprop_data,
   2107       input_c_backprop_data, params_backprop_data, reserve_space_data,
   2108       workspace_allocator);
   2109 #else
   2110   return false;
   2111 #endif  // CUDNN_VERSION
   2112 }
   2113 
   2114 namespace {
   2115 
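         // Asks cudnn's heuristics to pick a forward convolution algorithm, optionally
         // bounding the workspace by the scratch allocator's memory limit.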
   2116 inline cudnnConvolutionFwdAlgo_t GetCudnnConvolutionForwardAlgo(
   2117     Stream* stream, CUDAExecutor* parent, void* dnn_handle,
   2118     const ScopedTensorDescriptor& input_nd,
   2119     const ScopedFilterDescriptor& filter,
   2120     const ScopedConvolutionDescriptor& conv,
   2121     const ScopedTensorDescriptor& output_nd, bool specify_workspace_limit,
   2122     ScratchAllocator* scratch_allocator) {
   2123   cudnnConvolutionFwdPreference_t preference =
   2124       specify_workspace_limit ? CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT
   2125                               : CUDNN_CONVOLUTION_FWD_NO_WORKSPACE;
   2126   auto memory_limit_bytes =
   2127       scratch_allocator == nullptr
   2128           ? 0
   2129           : scratch_allocator->GetMemoryLimitInBytes(stream);
   2130   if (memory_limit_bytes < 0) {
   2131     memory_limit_bytes = 0;
   2132   }
   2133 
   2134   cudnnConvolutionFwdAlgo_t algo_to_use;
   2135   auto status = wrap::cudnnGetConvolutionForwardAlgorithm(
   2136       parent, ToHandle(dnn_handle), input_nd.handle(), filter.handle(),
   2137       conv.handle(), output_nd.handle(), preference, memory_limit_bytes,
   2138       &algo_to_use);
   2139   CHECK_EQ(status, CUDNN_STATUS_SUCCESS)
   2140       << "Unable to find a suitable algorithm for doing forward convolution";
   2141   return algo_to_use;
   2142 }
   2143 
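         // Resolves the forward convolution algorithm from |algorithm_config| (or the
         // cudnn heuristics for the default case), queries its workspace requirement,
         // and tries to allocate scratch from |scratch_allocator|; on allocation
         // failure it falls back to the configured no-scratch algorithm.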
   2144 dnn::AlgorithmDesc GetCudnnConvolutionForwardAlgorithm(
   2145     Stream* stream, CUDAExecutor* parent, void* dnn_handle,
   2146     const dnn::AlgorithmConfig& algorithm_config, bool is_profiling,
   2147     const ScopedTensorDescriptor& input_nd,
   2148     const ScopedFilterDescriptor& filter,
   2149     const ScopedConvolutionDescriptor& conv,
   2150     const ScopedTensorDescriptor& output_nd,
   2151     ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch) {
   2152   cudnnConvolutionFwdAlgo_t algo;
   2153   bool use_tensor_ops;
   2154   if (algorithm_config.algorithm().is_default()) {
   2155     use_tensor_ops = true;
   2156     algo = GetCudnnConvolutionForwardAlgo(
   2157         stream, parent, dnn_handle, input_nd, filter, conv, output_nd,
   2158         /*specify_workspace_limit=*/scratch_allocator != nullptr,
   2159         scratch_allocator);
   2160   } else {
   2161     use_tensor_ops = algorithm_config.algorithm().tensor_ops_enabled();
   2162     algo = ToConvForwardAlgo(algorithm_config.algorithm());
   2163   }
   2164   size_t size_in_bytes;
   2165   auto status = wrap::cudnnGetConvolutionForwardWorkspaceSize(
   2166       parent, ToHandle(dnn_handle), /*srcDesc=*/input_nd.handle(),
   2167       /*filterDesc=*/filter.handle(), /*convDesc=*/conv.handle(),
   2168       /*destDesc=*/output_nd.handle(), /*algo=*/algo,
   2169       /*sizeInBytes=*/&size_in_bytes);
   2170   int64 size_in_bytes_int64 = size_in_bytes;
   2171   if (TF_PREDICT_FALSE(status != CUDNN_STATUS_SUCCESS)) {
   2172     CHECK(is_profiling) << "Cannot query the size of workspace needed "
   2173                            "for the specified algorithm: "
   2174                         << algorithm_config.algorithm().algo_id() << " "
   2175                         << ToString(status);
   2176     // Silently return when we are profiling.
   2177     return dnn::AlgorithmDesc();
   2178   }
   2179   if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) {
   2180     LOG(WARNING) << "cudnnGetConvolutionForwardWorkspaceSize() returned "
   2181                     "negative sizeInBytes value. This could be a cudnn bug.";
   2182     if (TF_PREDICT_TRUE(is_profiling)) {
   2183       return dnn::AlgorithmDesc();
   2184     }
   2185   } else if (size_in_bytes_int64 > 0) {
   2186     port::StatusOr<DeviceMemory<uint8>> allocated;
   2187     if (TF_PREDICT_TRUE(scratch_allocator)) {
   2188       allocated = scratch_allocator->AllocateBytes(stream, size_in_bytes);
   2189       if (TF_PREDICT_TRUE(allocated.ok())) {
   2190         *scratch = allocated.ValueOrDie();
   2191       } else {
   2192         if (TF_PREDICT_TRUE(is_profiling)) {
   2193           // Silently return when we are profiling.
   2194           return dnn::AlgorithmDesc();
   2195         }
   2196         LOG(WARNING) << allocated.status().error_message();
   2197         // For the int8 case, we fail at this point since the no_scratch
   2198         // algorithm should be set to dnn::kDefaultAlgorithm.
   2199         CHECK(!algorithm_config.algorithm_no_scratch().is_default())
   2200             << "The primary convolution algorithm failed memory allocation, "
   2201                "while a secondary algorithm is not provided.";
   2202       }
   2203     }
   2204     if (TF_PREDICT_FALSE(!allocated.ok())) {
   2205       if (algorithm_config.algorithm_no_scratch().is_default()) {
   2206         use_tensor_ops = true;
   2207         algo = GetCudnnConvolutionForwardAlgo(
   2208             stream, parent, dnn_handle, input_nd, filter, conv, output_nd,
   2209             /*specify_workspace_limit=*/false, nullptr);
   2210       } else {
   2211         use_tensor_ops = algorithm_config.algorithm().tensor_ops_enabled();
   2212         algo = ToConvForwardAlgo(algorithm_config.algorithm_no_scratch());
   2213       }
   2214     }
   2215   }
   2216 
   2217   return dnn::AlgorithmDesc(algo, use_tensor_ops);
   2218 }
   2219 
    2220 // A helper class to read boolean env-vars and choose options for cudnn-related
    2221 // algorithms.
   2222 template <typename EnvVar>
   2223 class CudnnEnvVar {
   2224  public:
   2225   static bool IsEnabled() {
   2226     static bool is_enabled = IsEnabledImpl();
   2227     return is_enabled;
   2228   }
   2229 
   2230  private:
   2231   static bool IsEnabledImpl() {
   2232     const char* tf_env_var_val = getenv(EnvVar::kName);
   2233     if (tf_env_var_val != nullptr) {
   2234       port::StringPiece tf_env_var_val_str(tf_env_var_val);
   2235       if (tf_env_var_val_str == "0") {
   2236         return false;
   2237       }
   2238       return true;
   2239     }
   2240     return EnvVar::kDefaultFlag;
   2241   }
   2242 };
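         // For example (illustrative), CudnnEnvVar<WinogradNonfused>::IsEnabled()
         // returns false when the environment sets TF_ENABLE_WINOGRAD_NONFUSED=0,
         // true when the variable is set to any other value, and
         // WinogradNonfused::kDefaultFlag when it is unset.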
   2243 
    2244 // A helper struct to decide whether to enable the FFT_TILING algorithms for
    2245 // forward convolution. They work fine before cudnn v5.1, but starting with
    2246 // cudnn v5.1 they are turned off by default because some shapes trigger
    2247 // memory corruption with this algorithm.
    2248 // Until NVIDIA fixes the memory corruption bug, users can explicitly
    2249 // enable the algorithm through the env-var "TF_ENABLE_FFT_TILING_FORWARD=1".
   2250 struct FftTilingForward {
   2251   static constexpr const char* kName = "TF_ENABLE_FFT_TILING_FORWARD";
   2252   // TODO(yangzihao): turn the default to True when the memory corruption bug
   2253   // is fixed.
   2254   static constexpr bool kDefaultFlag = CUDNN_VERSION < 5100;
   2255 };
   2256 
    2257 // A helper struct to decide whether to enable the WINOGRAD_NONFUSED algorithms.
    2258 // By default they are turned on; users can explicitly disable them through the
    2259 // env-var "TF_ENABLE_WINOGRAD_NONFUSED=0".
   2260 // https://github.com/tensorflow/tensorflow/pull/4901
   2261 struct WinogradNonfused {
   2262   static constexpr const char* kName = "TF_ENABLE_WINOGRAD_NONFUSED";
    2263   // NVIDIA has fixed the winograd nonfused bug for cudnn v>=7. For
    2264   // cudnn v>=5.1 we have a workaround, and for any lower version we
    2265   // disable it by default.
   2266   static constexpr bool kDefaultFlag = CUDNN_VERSION >= 5100;
   2267 };
   2268 
    2269 // A helper struct to decide whether to use FP32 as the internal compute type
    2270 // for convolution when the input data type is FP16. By default it is turned
    2271 // on; users can explicitly disable it (i.e. use FP16 as the internal compute
    2272 // type) through the env-var "TF_FP16_CONV_USE_FP32_COMPUTE=0".
   2273 struct ConvDoFP32ComputationFP16Input {
   2274   static constexpr const char* kName = "TF_FP16_CONV_USE_FP32_COMPUTE";
   2275   // Using FP16 as the internal compute type for convolution when the input data
   2276   // type is FP16 is only supported on architectures with true fp16 support
   2277   // (compute capability 5.3 and 6.0). Setting this to false in an unsupported
   2278   // architecture will cause internal errors.
   2279   static constexpr bool kDefaultFlag = true;
   2280 };
   2281 
   2282 // A group of helper functions to return the internal compute type for
   2283 // convolutions in cudnn.
   2284 // TODO(yangzihao): Add support for float64.
   2285 template <typename T>
   2286 cudnnDataType_t GetConvComputeType() {
   2287   return CUDNN_DATA_FLOAT;
   2288 }
   2289 
   2290 template <>
   2291 cudnnDataType_t GetConvComputeType<Eigen::half>() {
   2292   if (CudnnEnvVar<ConvDoFP32ComputationFP16Input>::IsEnabled()) {
   2293     return CUDNN_DATA_FLOAT;
   2294   } else {
   2295     return CUDNN_DATA_HALF;
   2296   }
   2297 }
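         // As a sketch of the resulting behavior: GetConvComputeType<float>() (and,
         // until float64 is supported, GetConvComputeType<double>()) yields
         // CUDNN_DATA_FLOAT, while GetConvComputeType<Eigen::half>() yields
         // CUDNN_DATA_FLOAT unless TF_FP16_CONV_USE_FP32_COMPUTE=0 selects
         // CUDNN_DATA_HALF.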
   2298 
   2299 }  // namespace
   2300 
   2301 template <class T>
   2302 bool CudnnSupport::DoConvolveImpl(
   2303     Stream* stream, const BatchDescriptor& batch_descriptor,
   2304     const DeviceMemory<T>& input_data,
   2305     const FilterDescriptor& filter_descriptor,
   2306     const DeviceMemory<T>& filter_data,
   2307     const ConvolutionDescriptor& convolution_descriptor,
   2308     const BatchDescriptor& output_descriptor, DeviceMemory<T>* output_data,
   2309     ScratchAllocator* scratch_allocator,
   2310     const dnn::AlgorithmConfig& algorithm_config,
   2311     dnn::ProfileResult* output_profile_result) {
   2312   cudnnDataType_t cudnn_type = GetCudnnDataType<T>();
   2313   ScopedTensorDescriptor input_nd{parent_, batch_descriptor, cudnn_type};
   2314   ScopedTensorDescriptor output_nd{parent_, output_descriptor, cudnn_type};
   2315   ScopedFilterDescriptor filter{parent_, filter_descriptor, batch_descriptor,
   2316                                 cudnn_type};
   2317   ScopedConvolutionDescriptor conv{parent_, convolution_descriptor,
   2318                                    GetConvComputeType<T>()};
   2319 
   2320   mutex_lock lock{dnn_handle_mutex_};
   2321   auto status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
   2322                                      AsCUDAStreamValue(stream));
   2323   if (status != CUDNN_STATUS_SUCCESS) {
   2324     LOG(FATAL) << "failed to set stream for cudnn handle: " << ToString(status);
   2325   }
   2326   // Alpha is the scaling factor for input.
   2327   float alpha = 1.0;
   2328   // Beta is the scaling factor for output.
   2329   float beta = 0.0;
   2330 
   2331   const bool is_profiling = output_profile_result != nullptr;
   2332   cudnnConvolutionFwdAlgo_t algo;
   2333   bool use_tensor_ops;
   2334   DeviceMemory<uint8> scratch;
   2335 
   2336   // TODO(pauldonnelly): Replace the following code with a call to
   2337   //   GetCudnnConvolutionForwardAlgorithm().
   2338   if (algorithm_config.algorithm().is_default()) {
   2339     // With the default algorithm, use Cudnn's heuristics.
   2340     auto get_algorithm =
   2341         [&](bool specify_limit) SHARED_LOCKS_REQUIRED(dnn_handle_mutex_) {
   2342           cudnnConvolutionFwdPreference_t preference =
   2343               specify_limit ? CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT
   2344                             : CUDNN_CONVOLUTION_FWD_NO_WORKSPACE;
   2345 
   2346           auto memory_limit_bytes =
   2347               scratch_allocator == nullptr
   2348                   ? 0
   2349                   : scratch_allocator->GetMemoryLimitInBytes(stream);
   2350           if (memory_limit_bytes < 0) {
   2351             memory_limit_bytes = 0;
   2352           }
   2353 
   2354           cudnnConvolutionFwdAlgo_t algo_to_use;
   2355           status = wrap::cudnnGetConvolutionForwardAlgorithm(
   2356               parent_, ToHandle(dnn_handle_), input_nd.handle(),
   2357               filter.handle(), conv.handle(), output_nd.handle(),
   2358               /*preference=*/preference,
   2359               /*memoryLimitInBytes=*/memory_limit_bytes,
   2360               /*algo=*/&algo_to_use);
   2361           CHECK_EQ(status, CUDNN_STATUS_SUCCESS)
   2362               << "Unable to find a suitable "
   2363                  "algorithm for doing forward "
   2364                  "convolution";
   2365           return algo_to_use;
   2366         };
   2367 
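             // get_algorithm() asks cuDNN for its preferred algorithm:
             // CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT requests the fastest
             // algorithm whose workspace fits within memory_limit_bytes, while
             // CUDNN_CONVOLUTION_FWD_NO_WORKSPACE restricts the search to
             // algorithms that need no scratch memory at all.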
   2368     algo = get_algorithm(/*specify_limit=*/scratch_allocator != nullptr);
   2369     use_tensor_ops = true;
   2370     if (scratch_allocator != nullptr) {
   2371       size_t size_in_bytes;
   2372       status = wrap::cudnnGetConvolutionForwardWorkspaceSize(
   2373           parent_, ToHandle(dnn_handle_), /*srcDesc=*/input_nd.handle(),
   2374           /*filterDesc=*/filter.handle(), /*convDesc=*/conv.handle(),
   2375           /*destDesc=*/output_nd.handle(), /*algo=*/algo,
   2376           /*sizeInBytes=*/&size_in_bytes);
   2377       int64 size_in_bytes_int64 = size_in_bytes;
   2378       if (status == CUDNN_STATUS_SUCCESS && size_in_bytes_int64 != 0) {
   2379         if (size_in_bytes_int64 > 0) {
   2380           auto allocated =
   2381               scratch_allocator->AllocateBytes(stream, size_in_bytes);
   2382           if (allocated.ok()) {
   2383             scratch = allocated.ValueOrDie();
   2384           } else {
   2385             LOG(WARNING) << allocated.status().error_message();
   2386           }
   2387         } else {
   2388           LOG(WARNING)
   2389               << "cudnnGetConvolutionForwardWorkspaceSize() returned "
   2390                  "negative sizeInBytes value. This could be a cudnn bug.";
   2391         }
   2392       }
   2393     }
   2394 
   2395     // If we didn't allocate any scratch space (perhaps because of failed
   2396     // allocation), we force a switch back to the "no workspace" algorithm.
   2397     if (scratch == nullptr) {
   2398       algo = get_algorithm(/*specify_limit=*/false);
   2399     }
   2400   } else {
   2401     // An algorithm has been specified.
   2402     dnn::AlgorithmDesc algotype = algorithm_config.algorithm();
   2403     algo = ToConvForwardAlgo(algotype);
   2404     use_tensor_ops = algotype.tensor_ops_enabled();
   2405     conv.set_use_tensor_op_math(use_tensor_ops);
   2406     size_t size_in_bytes;
   2407     status = wrap::cudnnGetConvolutionForwardWorkspaceSize(
   2408         parent_, ToHandle(dnn_handle_), /*srcDesc=*/input_nd.handle(),
   2409         /*filterDesc=*/filter.handle(), /*convDesc=*/conv.handle(),
   2410         /*destDesc=*/output_nd.handle(), /*algo=*/algo,
   2411         /*sizeInBytes=*/&size_in_bytes);
   2412     if (status != CUDNN_STATUS_SUCCESS) {
   2413       if (is_profiling) {
   2414         // Silently return when we are profiling.
   2415         return false;
   2416       }
   2417       LOG(FATAL) << "Cannot query the size of workspace needed for the given "
   2418                     "algorithm: "
   2419                  << algorithm_config.algorithm().algo_id();
   2420     }
   2421     int64 size_in_bytes_int64 = size_in_bytes;
   2422     if (size_in_bytes_int64 > 0) {
   2423       if (scratch_allocator == nullptr) {
   2424         LOG(FATAL) << "An allocator must be specified when scratch memory is "
   2425                       "needed";
   2426       }
   2427       auto allocated = scratch_allocator->AllocateBytes(stream, size_in_bytes);
   2428       if (is_profiling && !allocated.ok()) {
   2429         // Silently return when we are profiling.
   2430         return false;
   2431       }
   2432       if (allocated.ok()) {
   2433         scratch = allocated.ValueOrDie();
   2434       } else {
   2435         LOG(WARNING) << allocated.status().error_message();
   2436       }
   2437       if (scratch == nullptr) {
   2438         CHECK(!algorithm_config.algorithm_no_scratch().is_default())
   2439             << "The primary convolution algorithm failed memory allocation, "
    2440                "but no secondary algorithm was provided.";
   2441         dnn::AlgorithmDesc algotype = algorithm_config.algorithm_no_scratch();
   2442         algo = ToConvForwardAlgo(algotype);
   2443         use_tensor_ops = algotype.tensor_ops_enabled();
   2444         conv.set_use_tensor_op_math(use_tensor_ops);
   2445       }
   2446     } else if (size_in_bytes_int64 < 0) {
   2447       LOG(WARNING) << "cudnnGetConvolutionForwardWorkspaceSize() returned "
   2448                       "negative sizeInBytes value. This could be a cudnn bug.";
   2449     }
   2450   }
   2451   std::unique_ptr<CUDATimer> timer;
   2452   if (is_profiling) {
   2453     timer.reset(new CUDATimer(parent_));  // NOLINT
   2454     if (!timer->Init()) {
   2455       return false;
   2456     }
    2457     // The start and stop of the timer should be as close to the cuDNN call as
    2458     // possible. It is still possible for other threads to issue work onto this
    2459     // stream, so multiple profiling measurements may be needed.
   2460     if (!timer->Start(AsCUDAStream(stream))) {
   2461       timer->Destroy();
   2462       return false;
   2463     }
   2464   }
   2465   status = wrap::cudnnConvolutionForward(
   2466       parent_, ToHandle(dnn_handle_),
   2467       /*alpha=*/&alpha, /*srcDesc=*/input_nd.handle(),
   2468       /*srcData=*/input_data.opaque(), /*filterDesc=*/filter.handle(),
   2469       /*filterData=*/filter_data.opaque(), /*convDesc=*/conv.handle(),
   2470       /*algo=*/algo, /*workSpace=*/scratch.opaque(),
   2471       /*workSpaceSizeInBytes=*/scratch.size(), /*beta=*/&beta,
   2472       /*destDesc=*/output_nd.handle(), /*destData=*/output_data->opaque());
   2473 
   2474   if (is_profiling) {
   2475     if (!timer->Stop(AsCUDAStream(stream))) {
   2476       timer->Destroy();
   2477       return false;
   2478     }
   2479     if (status == CUDNN_STATUS_SUCCESS) {
   2480       dnn::AlgorithmDesc algotype(algo, use_tensor_ops);
   2481       output_profile_result->set_algorithm(algotype);
   2482       output_profile_result->set_elapsed_time_in_ms(
   2483           timer->GetElapsedMilliseconds());
   2484     }
   2485     timer->Destroy();
   2486   }
   2487 
   2488   if (status != CUDNN_STATUS_SUCCESS) {
   2489     // Silently return when we are profiling.
   2490     if (!is_profiling) {
   2491       LOG(ERROR) << "failed to enqueue convolution on stream: "
   2492                  << ToString(status);
   2493     }
   2494     return false;
   2495   }
   2496 
   2497   return true;
   2498 }
   2499 
   2500 template <typename Type, typename BiasType, typename ScaleType,
   2501           int cudnn_data_type, int cudnn_compute_type>
   2502 bool CudnnSupport::DoFusedConvolveImpl(
   2503     Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
   2504     const DeviceMemory<Type>& conv_input_data, ScaleType conv_input_scale,
   2505     const dnn::FilterDescriptor& filter_descriptor,
   2506     const DeviceMemory<Type>& filter_data,
   2507     const dnn::ConvolutionDescriptor& convolution_descriptor,
   2508     const DeviceMemory<Type>& side_input_data, ScaleType side_input_scale,
   2509     const dnn::BatchDescriptor& bias_descriptor,
   2510     const DeviceMemory<BiasType>& biases, dnn::ActivationMode activation_mode,
   2511     const dnn::BatchDescriptor& output_descriptor,
   2512     DeviceMemory<Type>* output_data, ScratchAllocator* scratch_allocator,
   2513     const dnn::AlgorithmConfig& algorithm_config,
   2514     dnn::ProfileResult* output_profile_result) {
   2515 #if CUDNN_VERSION < 6000
   2516   LOG(ERROR) << "cudnnConvolutionBiasActivationForward() is only "
   2517                 "supported for cuDNN version >= 6";
   2518   return false;
   2519 #else
   2520   ScopedTensorDescriptor conv_input_nd{
   2521       parent_, conv_input_descriptor,
   2522       static_cast<cudnnDataType_t>(cudnn_data_type)};
   2523   ScopedTensorDescriptor output_nd{
   2524       parent_, output_descriptor,
   2525       static_cast<cudnnDataType_t>(cudnn_data_type)};
   2526   ScopedFilterDescriptor filter{parent_, filter_descriptor,
   2527                                 conv_input_descriptor,
   2528                                 static_cast<cudnnDataType_t>(cudnn_data_type)};
   2529   ScopedTensorDescriptor bias_nd{parent_, bias_descriptor, CUDNN_DATA_FLOAT};
   2530   ScopedConvolutionDescriptor conv{
   2531       parent_, convolution_descriptor,
   2532       static_cast<cudnnDataType_t>(cudnn_compute_type)};
   2533 
   2534   mutex_lock lock{dnn_handle_mutex_};
   2535   auto status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
   2536                                      AsCUDAStreamValue(stream));
   2537   CHECK(status == CUDNN_STATUS_SUCCESS)
   2538       << "failed to set stream for cudnn handle: " << ToString(status);
   2539 
   2540   const bool is_profiling = output_profile_result != nullptr;
   2541   DeviceMemory<uint8> scratch;
   2542   dnn::AlgorithmDesc algotype = GetCudnnConvolutionForwardAlgorithm(
   2543       stream, parent_, dnn_handle_, algorithm_config, is_profiling,
   2544       conv_input_nd, filter, conv, output_nd, scratch_allocator, &scratch);
   2545   if (algotype.is_default()) {
   2546     if (!is_profiling) {
   2547       LOG(ERROR) << "No suitable algorithm found";
   2548     }
   2549     return false;
   2550   }
   2551   auto algo = static_cast<cudnnConvolutionFwdAlgo_t>(algotype.algo_id());
   2552   conv.set_use_tensor_op_math(algotype.tensor_ops_enabled());
   2553 
   2554   if (activation_mode != dnn::ActivationMode::kRelu) {
   2555     LOG(ERROR) << "cudnnConvolutionBiasActivationForward() only supports Relu "
   2556                   "activation.";
   2557     return false;
   2558   }
   2559 
   2560   std::unique_ptr<CUDATimer> timer;
   2561   if (is_profiling) {
   2562     timer.reset(new CUDATimer(parent_));  // NOLINT
   2563     if (!timer->Init()) {
   2564       return false;
   2565     }
    2566     // The start and stop of the timer should be as close to the cuDNN call as
    2567     // possible. It is still possible for other threads to issue work onto this
    2568     // stream, so multiple profiling measurements may be needed.
   2569     if (!timer->Start(AsCUDAStream(stream))) {
   2570       timer->Destroy();
   2571       return false;
   2572     }
   2573   }
    2574   // cuDNN v6 only supports CUDNN_NOT_PROPAGATE_NAN as the reluNanOpt for the
    2575   // activation descriptor. Note that this differs from the NaN propagation
    2576   // behavior of running conv, bias, and relu as separate ops (which by
    2577   // default is CUDNN_PROPAGATE_NAN).
   2578   ScopedActivationDescriptor activation_desc{parent_, activation_mode,
   2579                                              CUDNN_NOT_PROPAGATE_NAN,
   2580                                              output_descriptor.value_max()};
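           // cudnnConvolutionBiasActivationForward() computes, roughly,
           //   y = activation(alpha1 * conv(x, w) + alpha2 * z + bias),
           // so when side_input_scale (alpha2) is zero it is safe to pass the
           // output buffer itself as z below; its contents are scaled away.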
   2581   auto side_input_data_ptr = (side_input_scale == 0) ? output_data->opaque()
   2582                                                      : side_input_data.opaque();
   2583 
   2584   VLOG(2) << "\nconv_input_scale = " << conv_input_scale
   2585           << "\nconv_input_nd.handle() = " << conv_input_nd.handle()
   2586           << "\nconv_input_data.opaque() = " << conv_input_data.opaque()
   2587           << "\nfilter.handle() = " << filter.handle()
   2588           << "\nfilter_data.opaque() = " << filter_data.opaque()
   2589           << "\nconv.handle() = " << conv.handle() << "\nalgo = " << algo
   2590           << "\nscratch.opaque() = " << scratch.opaque()
   2591           << "\nscratch.size() = " << scratch.size()
   2592           << "\nside_input_scale = " << side_input_scale
   2593           << "\noutput_nd.handle() = " << output_nd.handle()
   2594           << "\nside_input_data_ptr = " << side_input_data_ptr
   2595           << "\nbias_nd.handle() = " << bias_nd.handle()
   2596           << "\nbiases.opaque() = " << biases.opaque()
   2597           << "\nactivation_desc.handle() = " << activation_desc.handle()
   2598           << "\noutput_nd.handle() = " << output_nd.handle()
   2599           << "\noutput_data->opaque() = " << output_data->opaque();
   2600 
   2601   status = wrap::cudnnConvolutionBiasActivationForward(
   2602       parent_, ToHandle(dnn_handle_), /*alpha1=*/&conv_input_scale,
   2603       /*srcDesc=*/conv_input_nd.handle(), /*srcData=*/conv_input_data.opaque(),
   2604       /*filterDesc=*/filter.handle(), /*filterData=*/filter_data.opaque(),
   2605       /*convDesc=*/conv.handle(), algo, /*workSpace=*/scratch.opaque(),
   2606       /*workSpaceSizeInBytes=*/scratch.size(), /*alpha2=*/&side_input_scale,
   2607       /*zDesc=*/output_nd.handle(), /*z=*/side_input_data_ptr,
   2608       /*biasDesc=*/bias_nd.handle(), /*bias=*/biases.opaque(),
   2609       /*activationDesc=*/activation_desc.handle(),
   2610       /*destDesc=*/output_nd.handle(), /*destData=*/output_data->opaque());
   2611 
   2612   if (is_profiling) {
   2613     if (!timer->Stop(AsCUDAStream(stream))) {
   2614       timer->Destroy();
   2615       return false;
   2616     }
   2617     if (status == CUDNN_STATUS_SUCCESS) {
   2618       output_profile_result->set_algorithm(algotype);
   2619       output_profile_result->set_elapsed_time_in_ms(
   2620           timer->GetElapsedMilliseconds());
   2621     }
   2622     timer->Destroy();
   2623   }
   2624 
   2625   if (status != CUDNN_STATUS_SUCCESS) {
   2626     // Silently return when we are profiling.
   2627     if (!is_profiling) {
   2628       LOG(ERROR) << "failed to enqueue convolution on stream: "
   2629                  << ToString(status);
   2630     }
   2631     return false;
   2632   }
   2633 
   2634   return true;
   2635 #endif  // CUDNN_VERSION < 6000
   2636 }
   2637 
   2638 bool CudnnSupport::GetConvolveAlgorithms(
   2639     bool with_winograd_nonfused, int cc_major, int cc_minor,
   2640     std::vector<dnn::AlgorithmDesc>* out_algorithms) {
   2641   std::vector<dnn::AlgorithmDesc::Index> algo_types = {
   2642     // clang-format off
   2643     CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
   2644     CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
   2645     CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
   2646     CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
   2647     CUDNN_CONVOLUTION_FWD_ALGO_FFT,
   2648 #if CUDNN_VERSION >= 5000
   2649     CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
   2650 #endif
   2651     // clang-format on
   2652   };
   2653   if (CudnnEnvVar<FftTilingForward>::IsEnabled()) {
   2654     algo_types.push_back(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING);
   2655   }
   2656 #if CUDNN_VERSION >= 5100
   2657   if (CudnnEnvVar<WinogradNonfused>::IsEnabled() && with_winograd_nonfused) {
   2658     algo_types.push_back(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED);
   2659   }
   2660 #endif
   2661 
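           // Each algorithm is reported twice on GPUs with compute capability 7.0+
           // under cuDNN 7+ (and when TensorOpMathEnabled() allows it): once with
           // tensor ops disabled and once enabled, so callers can autotune over
           // both math modes.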
   2662   out_algorithms->clear();
   2663   for (auto i : algo_types) {
   2664     out_algorithms->push_back({i, /*use_tensor_ops=*/false});
   2665     if (cc_major >= 7 && CUDNN_VERSION >= 7000 && TensorOpMathEnabled()) {
   2666       out_algorithms->push_back({i, /*use_tensor_ops=*/true});
   2667     }
   2668   }
   2669   return true;
   2670 }
   2671 
   2672 bool CudnnSupport::GetConvolveBackwardDataAlgorithms(
   2673     bool with_winograd_nonfused, int cc_major, int cc_minor,
   2674     std::vector<dnn::AlgorithmDesc>* out_algorithms) {
   2675   std::vector<dnn::AlgorithmDesc::Index> algo_types = {
   2676     // clang-format off
   2677     CUDNN_CONVOLUTION_BWD_DATA_ALGO_0,
   2678     CUDNN_CONVOLUTION_BWD_DATA_ALGO_1,
   2679     CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT,
   2680     CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING,
   2681 #if CUDNN_VERSION >= 5000
   2682     CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD,
   2683 #endif
   2684     // clang-format on
   2685   };
   2686 #if CUDNN_VERSION >= 5100
   2687   if (CudnnEnvVar<WinogradNonfused>::IsEnabled() && with_winograd_nonfused) {
   2688     algo_types.push_back(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED);
   2689   }
   2690 #endif
   2691 
   2692   out_algorithms->clear();
   2693   for (auto i : algo_types) {
   2694     out_algorithms->push_back({i, /*use_tensor_ops=*/false});
   2695     if (cc_major >= 7 && CUDNN_VERSION >= 7000 && TensorOpMathEnabled()) {
   2696       out_algorithms->push_back({i, /*use_tensor_ops=*/true});
   2697     }
   2698   }
   2699   return true;
   2700 }
   2701 
   2702 bool CudnnSupport::GetConvolveBackwardFilterAlgorithms(
   2703     bool with_winograd_nonfused, int cc_major, int cc_minor,
   2704     std::vector<dnn::AlgorithmDesc>* out_algorithms) {
   2705   std::vector<dnn::AlgorithmDesc::Index> algo_types = {
   2706       // clang-format off
   2707       CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0,
   2708       CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1,
   2709       CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT,
   2710       CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3,
   2711       // Based on cudnn.h, the following is not implemented.
   2712       // CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD,
   2713       // clang-format on
   2714   };
   2715 #if CUDNN_VERSION >= 5100
   2716   if (CudnnEnvVar<WinogradNonfused>::IsEnabled() && with_winograd_nonfused) {
   2717     algo_types.push_back(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED);
   2718   }
   2719 #endif
   2720 
   2721   out_algorithms->clear();
   2722   for (auto i : algo_types) {
   2723     out_algorithms->push_back({i, /*use_tensor_ops=*/false});
   2724     if (cc_major >= 7 && CUDNN_VERSION >= 7000 && TensorOpMathEnabled()) {
   2725       out_algorithms->push_back({i, /*use_tensor_ops=*/true});
   2726     }
   2727   }
   2728   return true;
   2729 }
   2730 
   2731 bool CudnnSupport::DoBatchNormalizationForward(
   2732     Stream* stream, const DeviceMemory<float>& x,
   2733     const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
   2734     const DeviceMemory<float>& estimated_mean,
   2735     const DeviceMemory<float>& estimated_variance,
   2736     const dnn::BatchDescriptor& x_desc,
   2737     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
   2738     DeviceMemory<float>* y, DeviceMemory<float>* batch_mean,
   2739     DeviceMemory<float>* batch_var, DeviceMemory<float>* saved_mean,
   2740     DeviceMemory<float>* saved_inv_var, bool is_training,
   2741     std::function<const DeviceMemory<float>&()> var_to_inv_var,
   2742     std::function<void()> inv_var_to_var) {
   2743   return DoBatchNormalizationForwardImpl<float, float>(
   2744       stream, dnn::DataType::kFloat, dnn::DataType::kFloat, x, scale, offset,
   2745       estimated_mean, estimated_variance, x_desc, scale_offset_desc, epsilon, y,
   2746       batch_mean, batch_var, saved_mean, saved_inv_var, is_training,
   2747       std::move(var_to_inv_var), std::move(inv_var_to_var));
   2748 }
   2749 
   2750 bool CudnnSupport::DoBatchNormalizationForward(
   2751     Stream* stream, const DeviceMemory<Eigen::half>& x,
   2752     const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
   2753     const DeviceMemory<float>& estimated_mean,
   2754     const DeviceMemory<float>& estimated_variance,
   2755     const dnn::BatchDescriptor& x_desc,
   2756     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
   2757     DeviceMemory<Eigen::half>* y, DeviceMemory<float>* batch_mean,
   2758     DeviceMemory<float>* batch_var, DeviceMemory<float>* saved_mean,
   2759     DeviceMemory<float>* saved_inv_var, bool is_training,
   2760     std::function<const DeviceMemory<float>&()> var_to_inv_var,
   2761     std::function<void()> inv_var_to_var) {
   2762   return DoBatchNormalizationForwardImpl<Eigen::half, float>(
   2763       stream, dnn::DataType::kHalf, dnn::DataType::kFloat, x, scale, offset,
   2764       estimated_mean, estimated_variance, x_desc, scale_offset_desc, epsilon, y,
   2765       batch_mean, batch_var, saved_mean, saved_inv_var, is_training,
   2766       std::move(var_to_inv_var), std::move(inv_var_to_var));
   2767 }
   2768 
   2769 template <class T, class U>
   2770 bool CudnnSupport::DoBatchNormalizationForwardImpl(
   2771     Stream* stream, dnn::DataType input_data_type,
   2772     dnn::DataType scale_data_type, const DeviceMemory<T>& x,
   2773     const DeviceMemory<U>& scale, const DeviceMemory<U>& offset,
   2774     const DeviceMemory<U>& estimated_mean,
   2775     const DeviceMemory<U>& estimated_variance,
   2776     const dnn::BatchDescriptor& x_desc,
   2777     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
   2778     DeviceMemory<T>* y, DeviceMemory<U>* batch_mean, DeviceMemory<U>* batch_var,
   2779     DeviceMemory<U>* saved_mean, DeviceMemory<U>* saved_inv_var,
   2780     bool is_training, std::function<const DeviceMemory<U>&()> var_to_inv_var,
   2781     std::function<void()> inv_var_to_var) {
   2782   mutex_lock lock{dnn_handle_mutex_};
   2783   auto status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
   2784                                      AsCUDAStreamValue(stream));
   2785   if (status != CUDNN_STATUS_SUCCESS) {
   2786     LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
   2787     return false;
   2788   }
   2789 
   2790   ScopedTensorDescriptor x_descriptor{parent_, x_desc,
   2791                                       ToCudnnDataType(input_data_type)};
   2792   ScopedTensorDescriptor scale_offset_descriptor{
   2793       parent_, scale_offset_desc, ToCudnnDataType(scale_data_type)};
   2794   cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL;
   2795 #if CUDNN_VERSION >= 7000
   2796   if (BatchnormSpatialPersistentEnabled()) {
   2797     mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
   2798   }
   2799 #endif
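           // CUDNN_BATCHNORM_SPATIAL_PERSISTENT is a faster cuDNN 7 variant of the
           // spatial mode; it is opt-in via BatchnormSpatialPersistentEnabled(),
           // presumably because it is not numerically suitable for every input.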
   2800   float one = 1.0;
   2801   float zero = 0.0;
   2802 
   2803   if (is_training) {
   2804     CHECK_EQ(batch_mean->is_null(), batch_var->is_null())
   2805         << "batch_mean and batch_var must both be null or both be non-null";
   2806 
   2807     void* batch_mean_opaque;
   2808     void* batch_var_opaque;
   2809     if (!batch_mean->is_null() && !batch_var->is_null()) {
   2810       stream->ThenMemZero(batch_mean, batch_mean->size());
   2811       stream->ThenMemZero(batch_var, batch_var->size());
   2812       batch_mean_opaque = batch_mean->opaque();
   2813       batch_var_opaque = batch_var->opaque();
   2814     } else {
   2815       batch_mean_opaque = nullptr;
   2816       batch_var_opaque = nullptr;
   2817     }
   2818 
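             // The exponential average factor passed below is 1.0. cuDNN's running
             // statistics update is roughly
             //   running = (1 - factor) * running + factor * batch,
             // so with factor == 1.0 batch_mean / batch_var end up holding exactly
             // this batch's moments, which is why they are zeroed above.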
   2819     status = wrap::cudnnBatchNormalizationForwardTraining(
   2820         parent_, ToHandle(dnn_handle_), mode, &one, &zero,
   2821         x_descriptor.handle(), x.opaque(), x_descriptor.handle(), y->opaque(),
   2822         scale_offset_descriptor.handle(), scale.opaque(), offset.opaque(), 1.0,
   2823         batch_mean_opaque, batch_var_opaque, epsilon, saved_mean->opaque(),
   2824         saved_inv_var->opaque());
   2825 #if CUDNN_VERSION < 5000
   2826     CHECK(inv_var_to_var);
   2827     inv_var_to_var();
   2828 #endif
   2829   } else {
   2830 #if CUDNN_VERSION < 5000
   2831     CHECK(var_to_inv_var);
   2832     const void* maybe_inv_var = var_to_inv_var().opaque();
   2833 #else
   2834     const void* maybe_inv_var = estimated_variance.opaque();
   2835 #endif
   2836     status = wrap::cudnnBatchNormalizationForwardInference(
   2837         parent_, ToHandle(dnn_handle_), mode, &one, &zero,
   2838         x_descriptor.handle(), x.opaque(), x_descriptor.handle(), y->opaque(),
   2839         scale_offset_descriptor.handle(), scale.opaque(), offset.opaque(),
   2840         estimated_mean.opaque(), maybe_inv_var, epsilon);
   2841   }
   2842   if (status != CUDNN_STATUS_SUCCESS) {
   2843     LOG(ERROR) << "failed to enqueue forward batch normalization on stream: "
   2844                << ToString(status);
   2845     return false;
   2846   }
   2847   return true;
   2848 }
   2849 
   2850 bool CudnnSupport::DoBatchNormalizationBackward(
   2851     Stream* stream, const DeviceMemory<float>& y_backprop,
   2852     const DeviceMemory<float>& x, const DeviceMemory<float>& scale,
   2853     const DeviceMemory<float>& mean, const DeviceMemory<float>& inv_var,
   2854     const dnn::BatchDescriptor& x_desc,
   2855     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
   2856     DeviceMemory<float>* x_backprop, DeviceMemory<float>* scale_backprop,
   2857     DeviceMemory<float>* offset_backprop) {
   2858   return DoBatchNormalizationBackwardImpl(
   2859       stream, CUDNN_DATA_FLOAT, CUDNN_DATA_FLOAT, y_backprop, x, scale, mean,
   2860       inv_var, x_desc, scale_offset_desc, epsilon, x_backprop, scale_backprop,
   2861       offset_backprop);
   2862 }
   2863 
   2864 bool CudnnSupport::DoBatchNormalizationBackward(
   2865     Stream* stream, const DeviceMemory<Eigen::half>& y_backprop,
   2866     const DeviceMemory<Eigen::half>& x, const DeviceMemory<float>& scale,
   2867     const DeviceMemory<float>& mean, const DeviceMemory<float>& inv_var,
   2868     const dnn::BatchDescriptor& x_desc,
   2869     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
   2870     DeviceMemory<Eigen::half>* x_backprop, DeviceMemory<float>* scale_backprop,
   2871     DeviceMemory<float>* offset_backprop) {
   2872   return DoBatchNormalizationBackwardImpl(
   2873       stream, CUDNN_DATA_HALF, CUDNN_DATA_FLOAT, y_backprop, x, scale, mean,
   2874       inv_var, x_desc, scale_offset_desc, epsilon, x_backprop, scale_backprop,
   2875       offset_backprop);
   2876 }
   2877 
   2878 template <class T, class U>
   2879 bool CudnnSupport::DoBatchNormalizationBackwardImpl(
   2880     Stream* stream, int cudnn_input_type, int cudnn_scale_type,
   2881     const DeviceMemory<T>& y_backprop, const DeviceMemory<T>& x,
   2882     const DeviceMemory<U>& scale, const DeviceMemory<U>& mean,
   2883     const DeviceMemory<U>& inv_var, const dnn::BatchDescriptor& x_desc,
   2884     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
   2885     DeviceMemory<T>* x_backprop, DeviceMemory<U>* scale_backprop,
   2886     DeviceMemory<U>* offset_backprop) {
   2887   mutex_lock lock{dnn_handle_mutex_};
   2888   auto status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
   2889                                      AsCUDAStreamValue(stream));
   2890   if (status != CUDNN_STATUS_SUCCESS) {
   2891     LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
   2892     return false;
   2893   }
   2894 
   2895   ScopedTensorDescriptor x_descriptor{
   2896       parent_, x_desc, static_cast<cudnnDataType_t>(cudnn_input_type)};
   2897   ScopedTensorDescriptor scale_offset_descriptor{
   2898       parent_, scale_offset_desc,
   2899       static_cast<cudnnDataType_t>(cudnn_scale_type)};
   2900   cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL;
   2901 #if CUDNN_VERSION >= 7000
   2902   if (BatchnormSpatialPersistentEnabled()) {
   2903     mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
   2904   }
   2905 #endif
   2906   float one = 1.0;
   2907   float zero = 0.0;
   2908 
   2909   status = wrap::cudnnBatchNormalizationBackward(
   2910       parent_, ToHandle(dnn_handle_), mode, &one, &zero, &one, &zero,
   2911       x_descriptor.handle(), x.opaque(), x_descriptor.handle(),
   2912       y_backprop.opaque(), x_descriptor.handle(), x_backprop->opaque(),
   2913       scale_offset_descriptor.handle(), scale.opaque(),
   2914       scale_backprop->opaque(), offset_backprop->opaque(), epsilon,
   2915       mean.opaque(), inv_var.opaque());
   2916   if (status != CUDNN_STATUS_SUCCESS) {
   2917     LOG(ERROR) << "failed to enqueue backward batch normalization on stream: "
   2918                << ToString(status);
   2919     return false;
   2920   }
   2921   return true;
   2922 }
   2923 
   2924 bool CudnnSupport::DoConvolve(
   2925     Stream* stream, const BatchDescriptor& batch_descriptor,
   2926     const DeviceMemory<float>& input_data,
   2927     const FilterDescriptor& filter_descriptor,
   2928     const DeviceMemory<float>& filter_data,
   2929     const ConvolutionDescriptor& convolution_descriptor,
   2930     const BatchDescriptor& output_descriptor, DeviceMemory<float>* output_data,
   2931     ScratchAllocator* scratch_allocator,
   2932     const dnn::AlgorithmConfig& algorithm_config,
   2933     dnn::ProfileResult* output_profile_result) {
   2934   return DoConvolveImpl<float>(
   2935       stream, batch_descriptor, input_data, filter_descriptor, filter_data,
   2936       convolution_descriptor, output_descriptor, output_data, scratch_allocator,
   2937       algorithm_config, output_profile_result);
   2938 }
   2939 
   2940 bool CudnnSupport::DoConvolve(
   2941     Stream* stream, const BatchDescriptor& batch_descriptor,
   2942     const DeviceMemory<double>& input_data,
   2943     const FilterDescriptor& filter_descriptor,
   2944     const DeviceMemory<double>& filter_data,
   2945     const ConvolutionDescriptor& convolution_descriptor,
   2946     const BatchDescriptor& output_descriptor,
   2947     DeviceMemory<double>* output_data) {
   2948   LOG(ERROR) << "double-based DNN not yet implemented";
   2949   return false;
   2950 }
   2951 
   2952 bool CudnnSupport::DoConvolve(
   2953     Stream* stream, const BatchDescriptor& batch_descriptor,
   2954     const DeviceMemory<Eigen::half>& input_data,
   2955     const FilterDescriptor& filter_descriptor,
   2956     const DeviceMemory<Eigen::half>& filter_data,
   2957     const ConvolutionDescriptor& convolution_descriptor,
   2958     const BatchDescriptor& output_descriptor,
   2959     DeviceMemory<Eigen::half>* output_data, ScratchAllocator* scratch_allocator,
   2960     const dnn::AlgorithmConfig& algorithm_config,
   2961     dnn::ProfileResult* output_profile_result) {
   2962   return DoConvolveImpl<Eigen::half>(
   2963       stream, batch_descriptor, input_data, filter_descriptor, filter_data,
   2964       convolution_descriptor, output_descriptor, output_data, scratch_allocator,
   2965       algorithm_config, output_profile_result);
   2966 }
   2967 
   2968 bool CudnnSupport::DoFusedConvolve(
   2969     Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
   2970     const DeviceMemory<double>& conv_input_data, double conv_input_scale,
   2971     const dnn::FilterDescriptor& filter_descriptor,
   2972     const DeviceMemory<double>& filter_data,
   2973     const dnn::ConvolutionDescriptor& convolution_descriptor,
   2974     const DeviceMemory<double>& side_input_data, double side_input_scale,
   2975     const dnn::BatchDescriptor& bias_descriptor,
   2976     const DeviceMemory<double>& biases, dnn::ActivationMode activation_mode,
   2977     const dnn::BatchDescriptor& output_descriptor,
   2978     DeviceMemory<double>* output_data, ScratchAllocator* scratch_allocator,
   2979     const dnn::AlgorithmConfig& algorithm_config,
   2980     dnn::ProfileResult* output_profile_result) {
   2981   return DoFusedConvolveImpl<double, double, double, CUDNN_DATA_DOUBLE,
   2982                              CUDNN_DATA_DOUBLE>(
   2983       stream, conv_input_descriptor, conv_input_data, conv_input_scale,
   2984       filter_descriptor, filter_data, convolution_descriptor, side_input_data,
   2985       side_input_scale, bias_descriptor, biases, activation_mode,
   2986       output_descriptor, output_data, scratch_allocator, algorithm_config,
   2987       output_profile_result);
   2988 }
   2989 
   2990 bool CudnnSupport::DoFusedConvolve(
   2991     Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
   2992     const DeviceMemory<float>& conv_input_data, float conv_input_scale,
   2993     const dnn::FilterDescriptor& filter_descriptor,
   2994     const DeviceMemory<float>& filter_data,
   2995     const dnn::ConvolutionDescriptor& convolution_descriptor,
   2996     const DeviceMemory<float>& side_input_data, float side_input_scale,
   2997     const dnn::BatchDescriptor& bias_descriptor,
   2998     const DeviceMemory<float>& biases, dnn::ActivationMode activation_mode,
   2999     const dnn::BatchDescriptor& output_descriptor,
   3000     DeviceMemory<float>* output_data, ScratchAllocator* scratch_allocator,
   3001     const dnn::AlgorithmConfig& algorithm_config,
   3002     dnn::ProfileResult* output_profile_result) {
   3003   return DoFusedConvolveImpl<float, float, float, CUDNN_DATA_FLOAT,
   3004                              CUDNN_DATA_FLOAT>(
   3005       stream, conv_input_descriptor, conv_input_data, conv_input_scale,
   3006       filter_descriptor, filter_data, convolution_descriptor, side_input_data,
   3007       side_input_scale, bias_descriptor, biases, activation_mode,
   3008       output_descriptor, output_data, scratch_allocator, algorithm_config,
   3009       output_profile_result);
   3010 }
   3011 
   3012 bool CudnnSupport::DoFusedConvolve(
   3013     Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
   3014     const DeviceMemory<Eigen::half>& conv_input_data, float conv_input_scale,
   3015     const dnn::FilterDescriptor& filter_descriptor,
   3016     const DeviceMemory<Eigen::half>& filter_data,
   3017     const dnn::ConvolutionDescriptor& convolution_descriptor,
   3018     const DeviceMemory<Eigen::half>& side_input_data, float side_input_scale,
   3019     const dnn::BatchDescriptor& bias_descriptor,
   3020     const DeviceMemory<Eigen::half>& biases,
   3021     dnn::ActivationMode activation_mode,
   3022     const dnn::BatchDescriptor& output_descriptor,
   3023     DeviceMemory<Eigen::half>* output_data, ScratchAllocator* scratch_allocator,
   3024     const dnn::AlgorithmConfig& algorithm_config,
   3025     dnn::ProfileResult* output_profile_result) {
   3026   return DoFusedConvolveImpl<Eigen::half, Eigen::half, float, CUDNN_DATA_HALF,
   3027                              CUDNN_DATA_FLOAT>(
   3028       stream, conv_input_descriptor, conv_input_data, conv_input_scale,
   3029       filter_descriptor, filter_data, convolution_descriptor, side_input_data,
   3030       side_input_scale, bias_descriptor, biases, activation_mode,
   3031       output_descriptor, output_data, scratch_allocator, algorithm_config,
   3032       output_profile_result);
   3033 }
   3034 
   3035 bool CudnnSupport::DoFusedConvolve(
   3036     Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
   3037     const DeviceMemory<int8>& conv_input_data, float conv_input_scale,
   3038     const dnn::FilterDescriptor& filter_descriptor,
   3039     const DeviceMemory<int8>& filter_data,
   3040     const dnn::ConvolutionDescriptor& convolution_descriptor,
   3041     const DeviceMemory<int8>& side_input_data, float side_input_scale,
   3042     const dnn::BatchDescriptor& bias_descriptor,
   3043     const DeviceMemory<float>& biases, dnn::ActivationMode activation_mode,
   3044     const dnn::BatchDescriptor& output_descriptor,
   3045     DeviceMemory<int8>* output_data, ScratchAllocator* scratch_allocator,
   3046     const dnn::AlgorithmConfig& algorithm_config,
   3047     dnn::ProfileResult* output_profile_result) {
   3048 #if CUDNN_VERSION < 6000
   3049   LOG(WARNING) << "cudnnConvolutionBiasActivationForward() is only "
   3050                   "supported for cuDNN version >= 6";
   3051   return false;
   3052 #else
   3053   int cc_major, cc_minor;
   3054   stream->parent()->GetDeviceDescription().cuda_compute_capability(&cc_major,
   3055                                                                    &cc_minor);
   3056   if (cc_major < 6 || (cc_major == 6 && cc_minor < 1)) {
   3057     LOG(WARNING) << "cudnnConvolutionBiasActivationForward() for int8 is only "
   3058                     "supported on GPUs with compute capability 6.1 or later.";
   3059     return false;
   3060   }
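           // Per the template arguments below, the int8 path feeds cuDNN the
           // vectorized CUDNN_DATA_INT8x4 input/filter type and accumulates in
           // CUDNN_DATA_INT32, while biases and scales stay in float.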
   3061   return DoFusedConvolveImpl<int8, float, float, CUDNN_DATA_INT8x4,
   3062                              CUDNN_DATA_INT32>(
   3063       stream, conv_input_descriptor, conv_input_data, conv_input_scale,
   3064       filter_descriptor, filter_data, convolution_descriptor, side_input_data,
   3065       side_input_scale, bias_descriptor, biases, activation_mode,
   3066       output_descriptor, output_data, scratch_allocator, algorithm_config,
   3067       output_profile_result);
   3068 #endif
   3069 }
   3070 
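         // Converts `backward_output_data` from kBatchYXDepth (NHWC) to
         // kBatchDepthYX (NCHW) when necessary: the data is copied through
         // cudnnTransformTensor() into freshly allocated temporary memory and
         // `output_descriptor` is updated to describe the new layout. Data that
         // is already in kBatchDepthYX layout is returned unchanged.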
    3071 template <class T>
   3072 DeviceMemory<T> CudnnSupport::MaybeTransformLayout(
   3073     Stream* stream,
   3074     BatchDescriptor* output_descriptor,
   3075     DeviceMemory<T> backward_output_data,
   3076     std::unique_ptr<TemporaryDeviceMemory<T>>* transform_scratch) {
   3077   if (output_descriptor->layout() == dnn::DataLayout::kBatchDepthYX) {
   3078     return backward_output_data;
   3079   }
   3080   CHECK(output_descriptor->layout() == dnn::DataLayout::kBatchYXDepth);
   3081   *transform_scratch =
   3082       stream->AllocateTemporaryArray<T>(backward_output_data.ElementCount())
   3083           .ConsumeValueOrDie();
   3084   BatchDescriptor transformed_output_descriptor;
   3085   transformed_output_descriptor.CloneFrom(*output_descriptor);
   3086   transformed_output_descriptor.set_layout(dnn::DataLayout::kBatchDepthYX);
   3087   cudnnDataType_t cudnn_type = GetCudnnDataType<T>();
   3088   ScopedTensorDescriptor orig_out_back_nd{parent_, *output_descriptor,
   3089                                           cudnn_type};
   3090   ScopedTensorDescriptor transformed_out_back_nd{
   3091       parent_, transformed_output_descriptor, cudnn_type};
   3092 
   3093   float alpha = 1.0f;
   3094   float beta = 0.0f;
   3095   auto status = wrap::cudnnTransformTensor(
   3096       parent_, ToHandle(dnn_handle_), &alpha, orig_out_back_nd.handle(),
   3097       backward_output_data.opaque(), &beta, transformed_out_back_nd.handle(),
   3098       (*transform_scratch)->mutable_device_memory()->opaque());
   3099 
   3100   if (status != CUDNN_STATUS_SUCCESS) {
   3101     LOG(FATAL) << "Failed to transform the data layout.";
   3102   }
   3103   output_descriptor->set_layout(dnn::DataLayout::kBatchDepthYX);
   3104   return (*transform_scratch)->device_memory();
   3105 }
   3106 
   3107 bool CudnnSupport::DoTransformTensor(Stream* stream,
   3108                                      const dnn::BatchDescriptor& input_desc,
   3109                                      dnn::DataType input_type,
   3110                                      const DeviceMemoryBase& input_data,
   3111                                      const dnn::BatchDescriptor& output_desc,
   3112                                      dnn::DataType output_type, float scale,
   3113                                      DeviceMemoryBase* output_data) {
   3114   mutex_lock lock{dnn_handle_mutex_};
   3115   float beta = 0.0f;
   3116   ScopedTensorDescriptor input_tensor_desc(
   3117       parent_, input_desc, ToCudnnDataType(input_type, input_desc.layout()));
   3118   ScopedTensorDescriptor output_tensor_desc(
   3119       parent_, output_desc, ToCudnnDataType(output_type, output_desc.layout()));
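           // cudnnTransformTensor() computes
           //   destData = alpha * srcData + beta * destData
           // while converting between the two descriptors' layouts (and, as used
           // here, data types). `scale` plays the role of alpha and beta is zero,
           // so this is a scaled, layout-converting copy.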
   3120   cudnnStatus_t status = wrap::cudnnTransformTensor(
   3121       parent_, ToHandle(dnn_handle_), &scale, input_tensor_desc.handle(),
   3122       input_data.opaque(), &beta, output_tensor_desc.handle(),
   3123       output_data->opaque());
   3124   if (status != CUDNN_STATUS_SUCCESS) {
   3125     LOG(ERROR) << "Could not transform a tensor with layout "
   3126                << input_desc.ToString() << " and data type "
   3127                << static_cast<int>(input_type) << " to another with layout "
   3128                << output_desc.ToString() << " and data type "
   3129                << static_cast<int>(output_type) << ": " << ToString(status);
   3130     return false;
   3131   }
   3132   return true;
   3133 }
   3134 
   3135 template <class T>
   3136 bool CudnnSupport::DoConvolveBackwardDataImpl(
   3137     Stream* stream,
   3138     const FilterDescriptor& filter_descriptor,
   3139     const DeviceMemory<T>& filter_data,
   3140     const BatchDescriptor& output_descriptor_in,
   3141     DeviceMemory<T> backward_output_data,
   3142     const ConvolutionDescriptor& convolution_descriptor,
   3143     const BatchDescriptor& input_descriptor,
   3144     DeviceMemory<T>* backward_input_data, ScratchAllocator* scratch_allocator,
   3145     const dnn::AlgorithmConfig& algorithm_config,
   3146     dnn::ProfileResult* output_profile_result) {
   3147   mutex_lock lock{dnn_handle_mutex_};
   3148   auto status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
   3149                                      AsCUDAStreamValue(stream));
   3150   if (status != CUDNN_STATUS_SUCCESS) {
   3151     LOG(FATAL) << "failed to set stream for cudnn handle: " << ToString(status);
   3152   }
   3153 
   3154   // Alpha is the scaling factor for input.
   3155   float alpha = 1.0;
   3156   // Beta is the scaling factor for output.
   3157   float beta = 0.0;
   3158 
   3159   // TBD(keveman): remove once cuDNN supports kBatchYXDepth for backward pass.
   3160   BatchDescriptor output_descriptor;
   3161   output_descriptor.CloneFrom(output_descriptor_in);
   3162   std::unique_ptr<TemporaryDeviceMemory<T>> transform_scratch;
   3163   backward_output_data = MaybeTransformLayout(
   3164       stream, &output_descriptor, backward_output_data, &transform_scratch);
   3165 
   3166   cudnnDataType_t cudnn_type = GetCudnnDataType<T>();
   3167   ScopedTensorDescriptor out_back_nd{parent_, output_descriptor, cudnn_type};
   3168   ScopedTensorDescriptor in_back_nd{parent_, input_descriptor, cudnn_type};
   3169   ScopedFilterDescriptor filter{parent_, filter_descriptor, input_descriptor,
   3170                                 cudnn_type};
   3171   ScopedConvolutionDescriptor conv{parent_, convolution_descriptor,
   3172                                    GetConvComputeType<T>()};
   3173 
   3174   const bool is_profiling = output_profile_result != nullptr;
   3175   cudnnConvolutionBwdDataAlgo_t algo;
   3176   DeviceMemory<uint8> scratch;
   3177 
   3178   if (algorithm_config.algorithm().is_default()) {
   3179     // With the default algorithm, use Cudnn's heuristics.
   3180     auto get_algorithm = [&](bool specify_limit) SHARED_LOCKS_REQUIRED(
   3181         dnn_handle_mutex_) -> cudnnConvolutionBwdDataAlgo_t {
   3182       cudnnConvolutionBwdDataPreference_t preference =
   3183           specify_limit ? CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT
   3184                         : CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE;
   3185 
   3186       auto memory_limit_bytes =
   3187           scratch_allocator == nullptr
   3188               ? 0
   3189               : scratch_allocator->GetMemoryLimitInBytes(stream);
   3190       if (memory_limit_bytes < 0) {
   3191         memory_limit_bytes = 0;
   3192       }
   3193       cudnnConvolutionBwdDataAlgo_t algo_to_use;
   3194       cudnnStatus_t status = wrap::cudnnGetConvolutionBackwardDataAlgorithm(
   3195           parent_, ToHandle(dnn_handle_),
   3196           /*filterDesc=*/filter.handle(),
   3197           /*diffDesc=*/out_back_nd.handle(),
   3198           /*convDesc=*/conv.handle(),
   3199           /*gradDesc=*/in_back_nd.handle(),
   3200           /*preference=*/preference,
   3201           /*memoryLimitInBytes=*/memory_limit_bytes,
   3202           /*algo=*/&algo_to_use);
   3203       CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << "Unable to find a suitable "
   3204                                                 "algorithm for doing backward "
   3205                                                 "data convolution";
   3206       return algo_to_use;
   3207     };
   3208 
   3209     algo = get_algorithm(/*specify_limit=*/scratch_allocator != nullptr);
   3210 
   3211     if (scratch_allocator != nullptr) {
   3212       size_t size_in_bytes;
   3213       status = wrap::cudnnGetConvolutionBackwardDataWorkspaceSize(
   3214           parent_, ToHandle(dnn_handle_),
   3215           /*filterDesc=*/filter.handle(),
   3216           /*diffDesc=*/out_back_nd.handle(),
   3217           /*convDesc=*/conv.handle(),
   3218           /*gradDesc=*/in_back_nd.handle(),
   3219           /*algo=*/algo,
   3220           /*sizeInBytes=*/&size_in_bytes);
   3221       int64 size_in_bytes_int64 = size_in_bytes;
   3222       if (status == CUDNN_STATUS_SUCCESS && size_in_bytes_int64 != 0) {
   3223         if (size_in_bytes_int64 > 0) {
   3224           auto allocated =
   3225               scratch_allocator->AllocateBytes(stream, size_in_bytes);
   3226           if (allocated.ok()) {
   3227             scratch = allocated.ValueOrDie();
   3228           } else {
   3229             LOG(WARNING) << allocated.status().error_message();
   3230           }
   3231         } else {
   3232           LOG(WARNING)
   3233               << "cudnnGetConvolutionBackwardDataWorkspaceSize() returned "
   3234                  "negative sizeInBytes value. This could be a cudnn bug.";
   3235         }
   3236       }
   3237     }
   3238 
   3239     // If we didn't allocate any scratch space (perhaps because of failed
   3240     // allocation), we force a switch back to the "no workspace" algorithm.
   3241     if (scratch == nullptr) {
   3242       algo = get_algorithm(/*specify_limit=*/false);
   3243     }
   3244   } else {
   3245     // An algorithm has been specified.
   3246     dnn::AlgorithmDesc algotype = algorithm_config.algorithm();
   3247     algo = ToConvBackwardDataAlgo(algotype);
   3248     conv.set_use_tensor_op_math(algotype.tensor_ops_enabled());
   3249     size_t size_in_bytes;
   3250     status = wrap::cudnnGetConvolutionBackwardDataWorkspaceSize(
   3251         parent_, ToHandle(dnn_handle_),
   3252         /*filterDesc=*/filter.handle(),
   3253         /*diffDesc=*/out_back_nd.handle(),
   3254         /*convDesc=*/conv.handle(),
   3255         /*gradDesc=*/in_back_nd.handle(),
   3256         /*algo=*/algo,
   3257         /*sizeInBytes=*/&size_in_bytes);
   3258     if (status != CUDNN_STATUS_SUCCESS) {
   3259       if (is_profiling) {
   3260         // Silently return when we are profiling.
   3261         return false;
   3262       }
   3263       LOG(FATAL) << "Cannot query the size of workspace needed for the given "
   3264                     "algorithm: "
   3265                  << algorithm_config.algorithm().algo_id();
   3266     }
   3267     int64 size_in_bytes_int64 = size_in_bytes;
   3268     if (size_in_bytes_int64 > 0) {
   3269       if (scratch_allocator == nullptr) {
   3270         LOG(FATAL) << "An allocator must be specified when scratch memory is "
   3271                       "needed";
   3272       }
   3273       auto allocated = scratch_allocator->AllocateBytes(stream, size_in_bytes);
   3274       if (is_profiling && !allocated.ok()) {
   3275         // Silently return when we are profiling.
   3276         return false;
   3277       }
   3278       if (allocated.ok()) {
   3279         scratch = allocated.ValueOrDie();
   3280       } else {
   3281         LOG(WARNING) << allocated.status().error_message();
   3282       }
   3283       if (scratch == nullptr) {
   3284         CHECK(!algorithm_config.algorithm_no_scratch().is_default())
   3285             << "The primary convolution algorithm failed memory allocation, "
    3286                "but no secondary algorithm was provided.";
   3287         dnn::AlgorithmDesc algotype = algorithm_config.algorithm_no_scratch();
   3288         algo = ToConvBackwardDataAlgo(algotype);
   3289         conv.set_use_tensor_op_math(algotype.tensor_ops_enabled());
   3290       }
   3291     } else if (size_in_bytes_int64 < 0) {
   3292       LOG(WARNING) << "cudnnGetConvolutionBackwardDataWorkspaceSize() returned "
   3293                       "negative sizeInBytes value. This could be a cudnn bug.";
   3294     }
   3295   }
   3296 
   3297   std::unique_ptr<CUDATimer> timer;
   3298   if (is_profiling) {
   3299     timer.reset(new CUDATimer(parent_));  // NOLINT
   3300     timer->Init();
    3301     // The start and stop of the timer should be as close to the cuDNN call as
    3302     // possible. It is still possible for other threads to issue work onto this
    3303     // stream, so multiple profiling measurements may be needed.
   3304     timer->Start(AsCUDAStream(stream));
   3305   }
   3306 
   3307 #if CUDNN_VERSION >= 5000
   3308   status = wrap::cudnnConvolutionBackwardData(
   3309 #else
   3310   status = wrap::cudnnConvolutionBackwardData_v3(
   3311 #endif
   3312       parent_, ToHandle(dnn_handle_),
   3313       /*alpha=*/&alpha,
   3314       /*filterDesc=*/filter.handle(),
   3315       /*filterData=*/filter_data.opaque(),
   3316       /*diffDesc=*/out_back_nd.handle(),
   3317       /*diffData=*/backward_output_data.opaque(),
   3318       /*convDesc=*/conv.handle(),
   3319       /*algo=*/algo,
   3320       /*workSpace=*/scratch.opaque(),
   3321       /*workSpaceSizeInBytes=*/scratch.size(),
   3322       /*beta=*/&beta,
   3323       /*gradDesc=*/in_back_nd.handle(),
   3324       /*gradData=*/backward_input_data->opaque());
   3325   if (is_profiling) {
   3326     timer->Stop(AsCUDAStream(stream));
   3327     if (status == CUDNN_STATUS_SUCCESS) {
   3328       bool use_tensor_ops = algorithm_config.algorithm().tensor_ops_enabled();
   3329       dnn::AlgorithmDesc algotype(algo, use_tensor_ops);
   3330       output_profile_result->set_algorithm(algotype);
   3331       output_profile_result->set_elapsed_time_in_ms(
   3332           timer->GetElapsedMilliseconds());
   3333     }
   3334     timer->Destroy();
   3335   }
   3336   if (status != CUDNN_STATUS_SUCCESS) {
   3337     // Silently return when we are profiling.
   3338     if (!is_profiling) {
   3339       LOG(ERROR) << "failed to enqueue convolution on stream: "
   3340                  << ToString(status);
   3341     }
   3342     return false;
   3343   }
   3344   return true;
   3345 }
   3346 
   3347 bool CudnnSupport::DoConvolveBackwardData(
   3348     Stream* stream, const FilterDescriptor& filter_descriptor,
   3349     const DeviceMemory<float>& filter_data,
   3350     const BatchDescriptor& output_descriptor_in,
   3351     DeviceMemory<float> backward_output_data,
   3352     const ConvolutionDescriptor& convolution_descriptor,
   3353     const BatchDescriptor& input_descriptor,
   3354     DeviceMemory<float>* backward_input_data,
   3355     ScratchAllocator* scratch_allocator,
   3356     const dnn::AlgorithmConfig& algorithm_config,
   3357     dnn::ProfileResult* output_profile_result) {
   3358   return DoConvolveBackwardDataImpl(stream, filter_descriptor, filter_data,
   3359                                     output_descriptor_in, backward_output_data,
   3360                                     convolution_descriptor, input_descriptor,
   3361                                     backward_input_data, scratch_allocator,
   3362                                     algorithm_config, output_profile_result);
   3363 }
   3364 
   3365 bool CudnnSupport::DoConvolveBackwardData(
   3366     Stream* stream, const FilterDescriptor& filter_descriptor,
   3367     const DeviceMemory<Eigen::half>& filter_data,
   3368     const BatchDescriptor& output_descriptor_in,
   3369     DeviceMemory<Eigen::half> backward_output_data,
   3370     const ConvolutionDescriptor& convolution_descriptor,
   3371     const BatchDescriptor& input_descriptor,
   3372     DeviceMemory<Eigen::half>* backward_input_data,
   3373     ScratchAllocator* scratch_allocator,
   3374     const dnn::AlgorithmConfig& algorithm_config,
   3375     dnn::ProfileResult* output_profile_result) {
   3376   return DoConvolveBackwardDataImpl(stream, filter_descriptor, filter_data,
   3377                                     output_descriptor_in, backward_output_data,
   3378                                     convolution_descriptor, input_descriptor,
   3379                                     backward_input_data, scratch_allocator,
   3380                                     algorithm_config, output_profile_result);
   3381 }
   3382 
   3383 template <class T>
   3384 bool CudnnSupport::DoConvolveBackwardFilterImpl(
   3385     Stream* stream, const dnn::BatchDescriptor& input_descriptor,
   3386     const DeviceMemory<T>& input_data,
   3387     const dnn::BatchDescriptor& output_descriptor_in,
   3388     DeviceMemory<T> backward_output_data,
   3389     const dnn::ConvolutionDescriptor& convolution_descriptor,
   3390     const dnn::FilterDescriptor& filter_descriptor,
   3391     DeviceMemory<T>* backward_filter_data, ScratchAllocator* scratch_allocator,
   3392     const dnn::AlgorithmConfig& algorithm_config,
   3393     dnn::ProfileResult* output_profile_result) {
   3394   mutex_lock lock{dnn_handle_mutex_};
   3395   auto status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
   3396                                      AsCUDAStreamValue(stream));
   3397   if (status != CUDNN_STATUS_SUCCESS) {
   3398     LOG(FATAL) << "failed to set stream for cudnn handle: " << ToString(status);
   3399   }
   3400 
   3401   // Alpha is the scaling factor for input.
   3402   float alpha = 1.0;
   3403   // Beta is the scaling factor for output.
   3404   float beta = 0.0;
   3405 
   3406   // TBD(keveman): remove once cuDNN supports kBatchYXDepth for backward pass.
   3407   BatchDescriptor output_descriptor;
   3408   output_descriptor.CloneFrom(output_descriptor_in);
   3409   std::unique_ptr<TemporaryDeviceMemory<T>> transform_scratch;
   3410   backward_output_data = MaybeTransformLayout(
   3411       stream, &output_descriptor, backward_output_data, &transform_scratch);
   3412 
   3413   cudnnDataType_t cudnn_type = GetCudnnDataType<T>();
   3414   ScopedTensorDescriptor out_back_nd{parent_, output_descriptor, cudnn_type};
   3415   ScopedTensorDescriptor input_nd{parent_, input_descriptor, cudnn_type};
   3416   ScopedFilterDescriptor filter{parent_, filter_descriptor, input_descriptor,
   3417                                 cudnn_type};
   3418   ScopedConvolutionDescriptor conv{parent_, convolution_descriptor,
   3419                                    GetConvComputeType<T>()};
   3420 
   3421   const bool is_profiling = output_profile_result != nullptr;
   3422   cudnnConvolutionBwdFilterAlgo_t algo;
   3423   DeviceMemory<uint8> scratch;
   3424 
   3425   if (algorithm_config.algorithm().is_default()) {
   3426     // With the default algorithm, use Cudnn's heuristics.
   3427 
   3428     // Lambda that retrieves the algorithm.
    3429     // specify_limit is true when we have a scratch allocator and expect its
    3430     // allocation to succeed; otherwise we fall back to the "no workspace" version.
   3431     auto get_algorithm = [&](bool specify_limit) SHARED_LOCKS_REQUIRED(
   3432         dnn_handle_mutex_) {
   3433       cudnnConvolutionBwdFilterPreference_t preference =
   3434           specify_limit ? CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT
   3435                         : CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE;
   3436 
   3437       auto memory_limit_bytes =
   3438           scratch_allocator == nullptr
   3439               ? 0
   3440               : scratch_allocator->GetMemoryLimitInBytes(stream);
   3441       if (memory_limit_bytes < 0) {
   3442         memory_limit_bytes = 0;
   3443       }
   3444 
   3445       cudnnConvolutionBwdFilterAlgo_t algo_to_use;
   3446       cudnnStatus_t status = wrap::cudnnGetConvolutionBackwardFilterAlgorithm(
   3447           parent_, ToHandle(dnn_handle_),
   3448           /*srcDesc=*/input_nd.handle(),
   3449           /*diffDesc=*/out_back_nd.handle(),
   3450           /*convDesc=*/conv.handle(),
   3451           /*gradDesc=*/filter.handle(),
   3452           /*preference=*/preference,
   3453           /*memoryLimitInBytes=*/memory_limit_bytes,
   3454           /*algo=*/&algo_to_use);
   3455       CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << "Unable to find a suitable "
   3456                                                 "algorithm for doing backward "
   3457                                                 "filter convolution";
   3458       return algo_to_use;
   3459     };
   3460 
   3461     algo = get_algorithm(/*specify_limit=*/scratch_allocator != nullptr);
   3462 
   3463     if (scratch_allocator != nullptr) {
   3464       size_t size_in_bytes;
   3465       status = wrap::cudnnGetConvolutionBackwardFilterWorkspaceSize(
   3466           parent_, ToHandle(dnn_handle_), /*srcDesc=*/input_nd.handle(),
   3467           /*diffDesc=*/out_back_nd.handle(), /*convDesc=*/conv.handle(),
   3468           /*gradDesc=*/filter.handle(), /*algo=*/algo,
   3469           /*sizeInBytes=*/&size_in_bytes);
   3470       int64 size_in_bytes_int64 = size_in_bytes;
   3471       if (status == CUDNN_STATUS_SUCCESS && size_in_bytes_int64 != 0) {
   3472         if (size_in_bytes_int64 > 0) {
   3473           auto allocated =
   3474               scratch_allocator->AllocateBytes(stream, size_in_bytes);
   3475           if (allocated.ok()) {
   3476             scratch = allocated.ValueOrDie();
   3477           } else {
   3478             LOG(WARNING) << allocated.status().error_message();
   3479           }
   3480         } else {
   3481           LOG(WARNING)
   3482               << "cudnnGetConvolutionBackwardFilterWorkspaceSize() returned "
   3483                  "negative sizeInBytes value. This could be a cudnn bug.";
   3484         }
   3485       }
   3486     }
   3487 
   3488     // If we didn't allocate any scratch space (perhaps because of failed
   3489     // allocation), we force a switch back to the "no workspace" algorithm.
   3490     if (scratch == nullptr) {
   3491       algo = get_algorithm(/*specify_limit=*/false);
   3492     }
   3493   } else {
   3494     // An algorithm has been specified.
   3495     dnn::AlgorithmDesc algotype = algorithm_config.algorithm();
   3496     algo = ToConvBackwardFilterAlgo(algotype);
   3497     conv.set_use_tensor_op_math(algotype.tensor_ops_enabled());
   3498 
   3499     size_t size_in_bytes;
   3500     status = wrap::cudnnGetConvolutionBackwardFilterWorkspaceSize(
   3501         parent_, ToHandle(dnn_handle_), /*srcDesc=*/input_nd.handle(),
   3502         /*diffDesc=*/out_back_nd.handle(), /*convDesc=*/conv.handle(),
   3503         /*gradDesc=*/filter.handle(), /*algo=*/algo,
   3504         /*sizeInBytes=*/&size_in_bytes);
   3505     if (status != CUDNN_STATUS_SUCCESS) {
   3506       if (is_profiling) {
   3507         // Silently return when we are profiling.
   3508         return false;
   3509       }
   3510       LOG(FATAL) << "Cannot query the size of workspace needed for the given "
   3511                     "algorithm: "
   3512                  << algorithm_config.algorithm().algo_id();
   3513     }
   3514     int64 size_in_bytes_int64 = size_in_bytes;
   3515     if (size_in_bytes_int64 > 0) {
   3516       if (scratch_allocator == nullptr) {
   3517         LOG(FATAL) << "An allocator must be specified when scratch memory is "
   3518                       "needed";
   3519       }
   3520       auto allocated = scratch_allocator->AllocateBytes(stream, size_in_bytes);
   3521       if (is_profiling && !allocated.ok()) {
   3522         // Silently return when we are profiling.
   3523         return false;
   3524       }
   3525       if (allocated.ok()) {
   3526         scratch = allocated.ValueOrDie();
   3527       } else {
   3528         LOG(WARNING) << allocated.status().error_message();
   3529       }
   3530       if (scratch == nullptr) {
   3531         CHECK(!algorithm_config.algorithm_no_scratch().is_default())
    3532             << "The primary convolution algorithm failed memory allocation, "
    3533                "and no fallback algorithm without scratch memory is provided.";
   3534         dnn::AlgorithmDesc algotype = algorithm_config.algorithm_no_scratch();
   3535         algo = ToConvBackwardFilterAlgo(algotype);
   3536         conv.set_use_tensor_op_math(algotype.tensor_ops_enabled());
   3537       }
   3538     } else if (size_in_bytes_int64 < 0) {
   3539       LOG(WARNING)
   3540           << "cudnnGetConvolutionBackwardFilterWorkspaceSize() returned "
   3541              "negative sizeInBytes value. This could be a cudnn bug.";
   3542     }
   3543   }
   3544 
   3545   std::unique_ptr<CUDATimer> timer;
   3546   if (is_profiling) {
   3547     timer.reset(new CUDATimer(parent_));  // NOLINT
   3548     timer->Init();
    3549     // The start and stop of the timer should be as close to the cuDNN call as
    3550     // possible. Other threads may still issue work onto this stream, so it can
    3551     // take multiple profiling measurements to obtain a reliable timing.
   3552     timer->Start(AsCUDAStream(stream));
   3553   }
   3554 
   3555 #if CUDNN_VERSION >= 5000
   3556   status = wrap::cudnnConvolutionBackwardFilter(
   3557 #else
   3558   status = wrap::cudnnConvolutionBackwardFilter_v3(
   3559 #endif
   3560       parent_, ToHandle(dnn_handle_), /*alpha=*/&alpha,
   3561       /*srcDesc=*/input_nd.handle(),
   3562       /*srcData=*/input_data.opaque(),
   3563       /*diffDesc=*/out_back_nd.handle(),
   3564       /*diffData=*/backward_output_data.opaque(),
   3565       /*convDesc=*/conv.handle(),
   3566       /*algo=*/algo,
   3567       /*workSpace=*/scratch.opaque(),
   3568       /*workSpaceSizeInBytes=*/scratch.size(),
   3569       /*beta=*/&beta,
   3570       /*gradDesc=*/filter.handle(),
   3571       /*gradData=*/backward_filter_data->opaque());
   3572 
   3573   if (is_profiling) {
   3574     timer->Stop(AsCUDAStream(stream));
   3575     if (status == CUDNN_STATUS_SUCCESS) {
   3576       bool use_tensor_ops = algorithm_config.algorithm().tensor_ops_enabled();
   3577       dnn::AlgorithmDesc algotype(algo, use_tensor_ops);
   3578       output_profile_result->set_algorithm(algotype);
   3579       output_profile_result->set_elapsed_time_in_ms(
   3580           timer->GetElapsedMilliseconds());
   3581     }
   3582     timer->Destroy();
   3583   }
   3584   if (status != CUDNN_STATUS_SUCCESS) {
   3585     // Silently return when we are profiling.
   3586     if (!is_profiling) {
   3587       LOG(ERROR) << "failed to enqueue convolution on stream: "
   3588                  << ToString(status);
   3589     }
   3590     return false;
   3591   }
   3592   return true;
   3593 }
   3594 
   3595 bool CudnnSupport::DoConvolveBackwardFilter(
   3596     Stream* stream, const dnn::BatchDescriptor& input_descriptor,
   3597     const DeviceMemory<float>& input_data,
   3598     const dnn::BatchDescriptor& output_descriptor_in,
   3599     DeviceMemory<float> backward_output_data,
   3600     const dnn::ConvolutionDescriptor& convolution_descriptor,
   3601     const dnn::FilterDescriptor& filter_descriptor,
   3602     DeviceMemory<float>* backward_filter_data,
   3603     ScratchAllocator* scratch_allocator,
   3604     const dnn::AlgorithmConfig& algorithm_config,
   3605     dnn::ProfileResult* output_profile_result) {
   3606   return DoConvolveBackwardFilterImpl(
   3607       stream, input_descriptor, input_data, output_descriptor_in,
   3608       backward_output_data, convolution_descriptor, filter_descriptor,
   3609       backward_filter_data, scratch_allocator, algorithm_config,
   3610       output_profile_result);
   3611 }
   3612 
   3613 bool CudnnSupport::DoConvolveBackwardFilter(
   3614     Stream* stream, const dnn::BatchDescriptor& input_descriptor,
   3615     const DeviceMemory<Eigen::half>& input_data,
   3616     const dnn::BatchDescriptor& output_descriptor_in,
   3617     DeviceMemory<Eigen::half> backward_output_data,
   3618     const dnn::ConvolutionDescriptor& convolution_descriptor,
   3619     const dnn::FilterDescriptor& filter_descriptor,
   3620     DeviceMemory<Eigen::half>* backward_filter_data,
   3621     ScratchAllocator* scratch_allocator,
   3622     const dnn::AlgorithmConfig& algorithm_config,
   3623     dnn::ProfileResult* output_profile_result) {
   3624   return DoConvolveBackwardFilterImpl(
   3625       stream, input_descriptor, input_data, output_descriptor_in,
   3626       backward_output_data, convolution_descriptor, filter_descriptor,
   3627       backward_filter_data, scratch_allocator, algorithm_config,
   3628       output_profile_result);
   3629 }
   3630 
   3631 template <class T>
   3632 bool CudnnSupport::DoConvolveBackwardBiasImpl(
   3633     Stream* stream, const dnn::BatchDescriptor& input_descriptor,
   3634     const DeviceMemory<T>& input_data,
   3635     const dnn::BatchDescriptor& bias_descriptor,
   3636     DeviceMemory<T>* backward_bias_data) {
   3637   mutex_lock lock{dnn_handle_mutex_};
   3638   auto status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
   3639                                      AsCUDAStreamValue(stream));
   3640   if (status != CUDNN_STATUS_SUCCESS) {
   3641     LOG(FATAL) << "failed to set stream for cudnn handle: " << ToString(status);
   3642   }
   3643 
   3644   cudnnDataType_t cudnn_type = GetCudnnDataType<T>();
   3645   ScopedTensorDescriptor input_nd{parent_, input_descriptor, cudnn_type};
   3646   ScopedTensorDescriptor bias_nd{parent_, bias_descriptor, cudnn_type};
   3647 
   3648   // Alpha is the scaling factor for input.
   3649   float alpha = 1.0;
   3650   // Beta is the scaling factor for output.
   3651   float beta = 0.0;
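           // Added note: cudnnConvolutionBackwardBias reduces the output gradient
           // over the batch and spatial dimensions, roughly
           //   backward_bias = alpha * sum_{n,y,x} input_data + beta * backward_bias,
           // which is why only the two tensor descriptors above are needed.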
   3652 
   3653   status = wrap::cudnnConvolutionBackwardBias(
   3654       parent_, ToHandle(dnn_handle_), &alpha, input_nd.handle(),
   3655       input_data.opaque(), &beta, bias_nd.handle(),
   3656       backward_bias_data->opaque());
   3657   if (status != CUDNN_STATUS_SUCCESS) {
   3658     LOG(ERROR) << "failed to enqueue backward convolution on stream: "
   3659                << ToString(status);
   3660     return false;
   3661   }
   3662   return true;
   3663 }
   3664 
   3665 bool CudnnSupport::DoConvolveBackwardBias(
   3666     Stream* stream, const BatchDescriptor& input_descriptor,
   3667     const DeviceMemory<double>& input_data,
   3668     const BatchDescriptor& bias_descriptor,
   3669     DeviceMemory<double>* backward_bias_data) {
   3670   return DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data,
   3671                                     bias_descriptor, backward_bias_data);
   3672 }
   3673 
   3674 bool CudnnSupport::DoConvolveBackwardBias(
   3675     Stream* stream, const BatchDescriptor& input_descriptor,
   3676     const DeviceMemory<float>& input_data,
   3677     const BatchDescriptor& bias_descriptor,
   3678     DeviceMemory<float>* backward_bias_data) {
   3679   return DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data,
   3680                                     bias_descriptor, backward_bias_data);
   3681 }
   3682 
   3683 bool CudnnSupport::DoConvolveBackwardBias(
   3684     Stream* stream, const BatchDescriptor& input_descriptor,
   3685     const DeviceMemory<Eigen::half>& input_data,
   3686     const BatchDescriptor& bias_descriptor,
   3687     DeviceMemory<Eigen::half>* backward_bias_data) {
   3688   return DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data,
   3689                                     bias_descriptor, backward_bias_data);
   3690 }
   3691 
   3692 bool CudnnSupport::DoMatMul(Stream* stream,
   3693                             const DeviceMemory<float>& input_data,
   3694                             const DeviceMemory<float>& weights,
   3695                             const dnn::BatchDescriptor& input_dimensions,
   3696                             const dnn::BatchDescriptor& output_dimensions,
   3697                             DeviceMemory<float>* output_data) {
   3698   if (input_dimensions.count() != output_dimensions.count()) {
   3699     LOG(ERROR) << "MatMul input and output dimensions are not compatible.";
   3700     return false;
   3701   }
   3702 
    3703   // We do not permute the input or output; instead, we just
   3704   // reinterpret the layout. We are working with row-major matrices
   3705   // and the rows of the input and output correspond to batch, so
   3706   // batch has to be outermost in both the input and output.
   3707   //
   3708   // By adding transposes to the BLAS gemm call we could perhaps make
   3709   // the kYXDepthBatch layout work as well, but there has been no need
   3710   // for that so far.
   3711   if (input_dimensions.layout() != dnn::DataLayout::kBatchYXDepth &&
   3712       input_dimensions.layout() != dnn::DataLayout::kBatchDepthYX) {
   3713     LOG(ERROR) << "Unsupported MatMul input layout.";
   3714     return false;
   3715   }
   3716   if (output_dimensions.layout() != dnn::DataLayout::kBatchYXDepth &&
   3717       output_dimensions.layout() != dnn::DataLayout::kBatchDepthYX) {
   3718     LOG(ERROR) << "Unsupported MatMul output layout.";
   3719     return false;
   3720   }
   3721 
   3722   if (output_dimensions.width() == 1 && output_dimensions.height() == 1) {
   3723     // This is a fast path that also supports the kBatchYXDepth layout.
   3724 
   3725     // The matrices here are in row-major format while BLAS expects
   3726     // column-major, i.e. our matrices are transposed as far as BLAS
   3727     // is concerned. So we need to compute output^T =
   3728     // input^T*weights^T. There is no parameter for transposing the
   3729     // output in BLAS gemm, but instead we can transpose both sides of
   3730     // the equality to see that this is equivalent to
   3731     // output=weights*input. So we only need to swap the order of
   3732     // weights and input in the matrix product to correct for the
   3733     // row-major versus column-major difference.
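             // Worked example (added; shapes are illustrative only): with a batch
             // of N = 32 inputs, input NodesAcrossFeatureMaps() = 1024 and output
             // NodesAcrossFeatureMaps() = 10, the single GEMM below uses
             //   m = 10, n = 32, k = 1024
             // and produces the 10x32 column-major matrix output^T; read back in
             // row-major order, those same bytes are exactly output = input * weights.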
   3734     const float alpha = 1.0f;  // Take the matrix product without scaling it.
   3735     const float beta = 0.0f;   // Ignore the original values in output_data.
   3736     const int64 m = output_dimensions.NodesAcrossFeatureMaps();
   3737     const int64 n = input_dimensions.count();
   3738     const int64 k = input_dimensions.NodesAcrossFeatureMaps();
   3739     stream->ThenBlasGemm(blas::Transpose::kNoTranspose,
   3740                          blas::Transpose::kNoTranspose, m, n, k, alpha, weights,
   3741                          m, input_data, k, beta, output_data, m);
   3742   } else {
   3743     // This is a slower and more complex path that supports output
   3744     // width() * height() > 1, though it only supports the
    3745     // kBatchYXDepth layout. It also supports kBatchDepthYX when the output
    3746     // feature_map_count() == 1, as then there is no difference
   3747     // between the two layouts.
   3748     //
   3749     // The operation here is the same as above, except that we have to
   3750     // do the matrix multiplication for each (y,x) output coordinate
   3751     // separately. We then interpret weights as containing K = width()
    3752     // * height() different matrices, each of which we multiply onto the
    3753     // matrix from input_data, yielding K matrix products. We then
    3754     // combine these together into one matrix by concatenating all the
    3755     // first rows of these matrices, then all the second rows and so
   3756     // on. We can do this with a batched matrix multiplication, where
   3757     // the result is written to a different submatrix of the output
   3758     // for each matrix multiplication.
   3759     //
   3760     // The reason that we only support the kBatchYXDepth output layout
   3761     // is that we have to do something in the depth for each (y,x)
   3762     // coordinate. The kBatchYXDepth layout has the depth information
   3763     // for each point (y,x) in contiguous memory while the
   3764     // kBatchDepthYX layout does not.
   3765     //
   3766     // TODO(broune): Consider a special case for when output depth ==
   3767     // 1, as then possibly this could all be done as one matrix
   3768     // multiplication instead of a batched one, which should be
   3769     // faster. Another possibility would be to add a weights layout
   3770     // parameter and then support kBatchDepthYX for a different
   3771     // weights layout.
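             // Illustrative sketch (added): with K = width() * height() spatial
             // positions and F = feature_map_count(), the loop below sets up K
             // GEMMs. GEMM i multiplies the i-th F x NodesAcrossFeatureMaps()
             // slice of weights by the full input matrix and writes its F x N
             // result into the output with column stride ldc = F * K, which
             // interleaves the K results so that the F depth values for each
             // (y, x) position stay contiguous, as kBatchYXDepth requires.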
   3772     if (output_dimensions.layout() != dnn::DataLayout::kBatchYXDepth &&
   3773         !(output_dimensions.layout() == dnn::DataLayout::kBatchDepthYX &&
   3774           output_dimensions.feature_map_count() == 1)) {
   3775       LOG(ERROR) << "Unsupported MatMul output layout.";
   3776       return false;
   3777     }
   3778 
   3779     const float alpha = 1.0f;  // Take the matrix product without scaling it.
   3780     const float beta = 0.0f;   // Ignore the original values in output_data.
   3781     const uint64 m = output_dimensions.feature_map_count();
   3782     const uint64 n = input_dimensions.count();
   3783     const uint64 k = input_dimensions.NodesAcrossFeatureMaps();
   3784     const int lda = m;
   3785     const int ldb = k;
   3786     const int ldc = output_dimensions.NodesAcrossFeatureMaps();
   3787     const int batch_count = output_dimensions.NodesPerFeatureMap();
   3788 
   3789     std::vector<DeviceMemory<float>> a(batch_count);
   3790     std::vector<DeviceMemory<float>> b(batch_count);
   3791     std::vector<DeviceMemory<float>> c(batch_count);
   3792     for (int i = 0; i < batch_count; ++i) {
   3793       const int weights_offset = i * input_dimensions.NodesAcrossFeatureMaps() *
   3794                                  output_dimensions.feature_map_count();
   3795       a[i] = DeviceMemory<float>::MakeFromByteSize(
   3796           const_cast<float*>(reinterpret_cast<const float*>(weights.opaque())) +
   3797               weights_offset,
   3798           weights.ElementCount() - weights_offset);
   3799 
   3800       b[i] = input_data;
   3801 
   3802       const int output_offset = i * output_dimensions.feature_map_count();
   3803       c[i] = DeviceMemory<float>::MakeFromByteSize(
   3804           const_cast<float*>(
   3805               reinterpret_cast<const float*>(output_data->opaque())) +
   3806               output_offset,
   3807           output_data->ElementCount() - output_offset);
   3808     }
   3809     const auto toPtrs = [](std::vector<DeviceMemory<float>>& v) {
   3810       std::vector<DeviceMemory<float>*> ptrs;
   3811       ptrs.reserve(v.size());
   3812       for (auto& mem : v) {
   3813         ptrs.push_back(&mem);
   3814       }
   3815       return ptrs;
   3816     };
   3817 
   3818     stream->ThenBlasGemmBatched(blas::Transpose::kNoTranspose,
   3819                                 blas::Transpose::kNoTranspose, m, n, k, alpha,
   3820                                 toPtrs(a), lda, toPtrs(b), ldb, beta, toPtrs(c),
   3821                                 ldc, batch_count);
   3822   }
   3823 
   3824   return stream->ok();
   3825 }
   3826 
   3827 bool CudnnSupport::DoBiasAdd(Stream* stream,
   3828                              const DeviceMemory<float>& input_data,
   3829                              const DeviceMemory<float>& biases,
   3830                              const dnn::BatchDescriptor& dimensions,
   3831                              DeviceMemory<float>* output_data) {
   3832   ScopedTensorDescriptor input_descriptor{parent_, dimensions,
   3833                                           CUDNN_DATA_FLOAT};
   3834 
   3835   BatchDescriptor bias_dimensions;
   3836   bias_dimensions.set_count(1)
   3837       .set_feature_map_count(dimensions.feature_map_count())
   3838       .set_height(1)
   3839       .set_width(1)
   3840       .set_layout(dnn::DataLayout::kBatchYXDepth);
   3841   ScopedTensorDescriptor bias_descriptor{parent_, bias_dimensions,
   3842                                          CUDNN_DATA_FLOAT};
   3843 
   3844   // cudnnAddTensor after R3 is in-place, so we need to copy input_data to
   3845   // output_data before doing the addition, unless the input and
   3846   // output are at the same address.
   3847   if (input_data.opaque() != output_data->opaque()) {
   3848     stream->ThenMemcpy(output_data, input_data,
   3849                        dimensions.ElementCount() * sizeof(float));
   3850     if (!stream->ok()) {
   3851       LOG(ERROR)
   3852           << "stream " << stream
   3853           << " could not enqueue a tensor copy as part of bias addition.";
   3854       return false;
   3855     }
   3856   }
   3857 
   3858   mutex_lock lock{dnn_handle_mutex_};
   3859   auto status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
   3860                                      AsCUDAStreamValue(stream));
   3861   if (status != CUDNN_STATUS_SUCCESS) {
   3862     LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
   3863     return false;
   3864   }
   3865 
   3866   const float alpha = 1.0f;
   3867   const float beta = 1.0f;
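           // Added note: with these blending factors, cudnnAddTensor performs,
           // roughly,
           //   output_data = 1.0f * broadcast(biases) + 1.0f * output_data,
           // i.e. the per-feature-map bias is broadcast-added to the copy of
           // input_data placed in output_data above.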
   3868 
   3869 #if CUDNN_VERSION >= 5000
   3870   status = wrap::cudnnAddTensor(
   3871 #else
   3872   status = wrap::cudnnAddTensor_v3(
   3873 #endif
   3874       parent_, ToHandle(dnn_handle_), &alpha, bias_descriptor.handle(),
   3875       biases.opaque(), &beta, input_descriptor.handle(), output_data->opaque());
   3876 
   3877   if (status != CUDNN_STATUS_SUCCESS) {
   3878     LOG(ERROR) << "stream " << stream << " could not enqueue bias addition.";
   3879     return false;
   3880   }
   3881 
   3882   return true;
   3883 }
   3884 
   3885 bool CudnnSupport::DoActivate(Stream* stream,
   3886                               dnn::ActivationMode activation_mode,
   3887                               const dnn::BatchDescriptor& dimensions,
   3888                               const DeviceMemory<float>& input_data,
   3889                               DeviceMemory<float>* output_data,
   3890                               uint64 options) {
   3891   mutex_lock lock{dnn_handle_mutex_};
   3892   auto status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
   3893                                      AsCUDAStreamValue(stream));
   3894   if (status != CUDNN_STATUS_SUCCESS) {
   3895     LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
   3896     return false;
   3897   }
   3898 
   3899 #if CUDNN_VERSION >= 5000
   3900   ScopedActivationDescriptor activation_desc{
   3901       parent_, activation_mode, CUDNN_PROPAGATE_NAN, dimensions.value_max()};
   3902 #else
   3903   cudnnActivationMode_t mode;
   3904   switch (activation_mode) {
   3905     case dnn::ActivationMode::kRelu6:
   3906       // TODO(leary) should probably do a post-pass to clip at 6?
   3907       LOG(WARNING) << "user requested Relu6, but providing Relu instead";
   3908       mode = CUDNN_ACTIVATION_RELU;
   3909       break;
   3910     case dnn::ActivationMode::kReluX:
   3911       // TODO(broune) should probably do a post-pass to clip at X?
   3912       LOG(WARNING) << "user requested ReluX, but providing Relu instead";
   3913       mode = CUDNN_ACTIVATION_RELU;
   3914       break;
   3915     case dnn::ActivationMode::kRelu:
   3916       mode = CUDNN_ACTIVATION_RELU;
   3917       break;
   3918     case dnn::ActivationMode::kSigmoid:
   3919       mode = CUDNN_ACTIVATION_SIGMOID;
   3920       break;
   3921     case dnn::ActivationMode::kTanh:
   3922       mode = CUDNN_ACTIVATION_TANH;
   3923       break;
   3924     default:
   3925       LOG(ERROR) << "unrecognized activation mode: "
   3926                  << static_cast<int>(activation_mode);
   3927       return false;
   3928   }
   3929 #endif
   3930 
   3931   ScopedTensorDescriptor input_nd{parent_, dimensions, CUDNN_DATA_FLOAT};
   3932   // Alpha is the input scaling factor.
   3933   float alpha = 1.0;
   3934   // Beta is the output scaling factor.
   3935   float beta = 0.0;
   3936   status = wrap::cudnnActivationForward(
   3937       parent_, ToHandle(dnn_handle_),
   3938 #if CUDNN_VERSION >= 5000
   3939       activation_desc.handle(),
   3940 #else
   3941       mode,
   3942 #endif
   3943       &alpha, input_nd.handle(), input_data.opaque(), &beta, input_nd.handle(),
   3944       output_data->opaque());
   3945   if (status != CUDNN_STATUS_SUCCESS) {
   3946     LOG(ERROR) << "stream " << stream
   3947                << " could not enqueue activation: " << ToString(status);
   3948     return false;
   3949   }
   3950 
   3951   return true;
   3952 }
   3953 
   3954 bool CudnnSupport::DoPoolForward(
   3955     Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
   3956     const dnn::BatchDescriptor& input_dimensions,
   3957     const DeviceMemory<double>& input_data,
   3958     const dnn::BatchDescriptor& output_dimensions,
   3959     DeviceMemory<double>* output_data) {
   3960   mutex_lock lock{dnn_handle_mutex_};
   3961   auto status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
   3962                                      AsCUDAStreamValue(stream));
   3963   if (status != CUDNN_STATUS_SUCCESS) {
   3964     LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
   3965     return false;
   3966   }
   3967 
   3968   // Alpha is the scaling factor for input.
   3969   double alpha = 1.0;
   3970   // Beta is the scaling factor for output.
   3971   double beta = 0.0;
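           // Added note: as with the other cuDNN calls here, the forward pooling
           // pass blends as
           //   output = alpha * pool(input) + beta * output,
           // so alpha = 1, beta = 0 simply overwrites the output buffer.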
   3972 
   3973   ScopedTensorDescriptor src_desc{parent_, input_dimensions, CUDNN_DATA_DOUBLE};
   3974   ScopedTensorDescriptor dest_desc{parent_, output_dimensions,
   3975                                    CUDNN_DATA_DOUBLE};
   3976   ScopedPoolingDescriptor pooling_desc{parent_, pooling_dimensions};
   3977   status = wrap::cudnnPoolingForward(
   3978       parent_, ToHandle(dnn_handle_), pooling_desc.handle(), &alpha,
   3979       src_desc.handle(), input_data.opaque(), &beta, dest_desc.handle(),
   3980       output_data->opaque());
   3981   if (status != CUDNN_STATUS_SUCCESS) {
   3982     LOG(ERROR) << "failed to enqueue forward pooling on stream: "
   3983                << ToString(status);
   3984     return false;
   3985   }
   3986   return true;
   3987 }
   3988 
   3989 bool CudnnSupport::DoPoolForward(
   3990     Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
   3991     const dnn::BatchDescriptor& input_dimensions,
   3992     const DeviceMemory<float>& input_data,
   3993     const dnn::BatchDescriptor& output_dimensions,
   3994     DeviceMemory<float>* output_data) {
   3995   mutex_lock lock{dnn_handle_mutex_};
   3996   auto status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
   3997                                      AsCUDAStreamValue(stream));
   3998   if (status != CUDNN_STATUS_SUCCESS) {
   3999     LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
   4000     return false;
   4001   }
   4002 
   4003   // Alpha is the scaling factor for input.
   4004   float alpha = 1.0;
   4005   // Beta is the scaling factor for output.
   4006   float beta = 0.0;
   4007 
   4008   ScopedTensorDescriptor src_desc{parent_, input_dimensions, CUDNN_DATA_FLOAT};
   4009   ScopedTensorDescriptor dest_desc{parent_, output_dimensions,
   4010                                    CUDNN_DATA_FLOAT};
   4011   ScopedPoolingDescriptor pooling_desc{parent_, pooling_dimensions};
   4012   status = wrap::cudnnPoolingForward(
   4013       parent_, ToHandle(dnn_handle_), pooling_desc.handle(), &alpha,
   4014       src_desc.handle(), input_data.opaque(), &beta, dest_desc.handle(),
   4015       output_data->opaque());
   4016   if (status != CUDNN_STATUS_SUCCESS) {
   4017     LOG(ERROR) << "failed to enqueue forward pooling on stream: "
   4018                << ToString(status);
   4019     return false;
   4020   }
   4021   return true;
   4022 }
   4023 
   4024 bool CudnnSupport::DoPoolForward(
   4025     Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
   4026     const dnn::BatchDescriptor& input_dimensions,
   4027     const DeviceMemory<Eigen::half>& input_data,
   4028     const dnn::BatchDescriptor& output_dimensions,
   4029     DeviceMemory<Eigen::half>* output_data) {
   4030   mutex_lock lock{dnn_handle_mutex_};
   4031   auto status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
   4032                                      AsCUDAStreamValue(stream));
   4033   if (status != CUDNN_STATUS_SUCCESS) {
   4034     LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
   4035     return false;
   4036   }
   4037 
   4038   // Alpha is the scaling factor for input.
   4039   float alpha = 1.0;
   4040   // Beta is the scaling factor for output.
   4041   float beta = 0.0;
   4042 
   4043   ScopedTensorDescriptor src_desc{parent_, input_dimensions, CUDNN_DATA_HALF};
   4044   ScopedTensorDescriptor dest_desc{parent_, output_dimensions, CUDNN_DATA_HALF};
   4045   ScopedPoolingDescriptor pooling_desc{parent_, pooling_dimensions};
   4046   status = wrap::cudnnPoolingForward(
   4047       parent_, ToHandle(dnn_handle_), pooling_desc.handle(), &alpha,
   4048       src_desc.handle(), input_data.opaque(), &beta, dest_desc.handle(),
   4049       output_data->opaque());
   4050   if (status != CUDNN_STATUS_SUCCESS) {
   4051     LOG(ERROR) << "failed to enqueue forward pooling on stream: "
   4052                << ToString(status);
   4053     return false;
   4054   }
   4055   return true;
   4056 }
   4057 
   4058 bool CudnnSupport::DoPoolBackward(
   4059     Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
   4060     const dnn::BatchDescriptor& input_dimensions,
   4061     const DeviceMemory<double>& input_data,
   4062     const dnn::BatchDescriptor& output_dimensions,
   4063     const DeviceMemory<double>& output_data,
   4064     const DeviceMemory<double>& input_diff_data,
   4065     DeviceMemory<double>* output_diff_data) {
   4066   mutex_lock lock{dnn_handle_mutex_};
   4067   auto status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
   4068                                      AsCUDAStreamValue(stream));
   4069   if (status != CUDNN_STATUS_SUCCESS) {
   4070     LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
   4071     return false;
   4072   }
   4073 
   4074   // Alpha is the scaling factor for input.
   4075   double alpha = 1.0;
   4076   // Beta is the scaling factor for output.
   4077   double beta = 0.0;
   4078 
   4079   ScopedTensorDescriptor src_desc{parent_, input_dimensions, CUDNN_DATA_DOUBLE};
   4080   ScopedTensorDescriptor dest_desc{parent_, output_dimensions,
   4081                                    CUDNN_DATA_DOUBLE};
   4082   ScopedPoolingDescriptor pooling_desc{parent_, pooling_dimensions};
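           // Added note on argument order (parameter names as in recent cuDNN
           // headers): the call below passes
           //   yDesc  / y  = dest_desc / output_data       (forward-pass output)
           //   dyDesc / dy = dest_desc / input_diff_data   (gradient w.r.t. output)
           //   xDesc  / x  = src_desc  / input_data        (forward-pass input)
           //   dxDesc / dx = src_desc  / output_diff_data  (gradient w.r.t. input)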
   4083   status = wrap::cudnnPoolingBackward(
   4084       parent_, ToHandle(dnn_handle_), pooling_desc.handle(), &alpha,
   4085       dest_desc.handle(), output_data.opaque(), dest_desc.handle(),
   4086       input_diff_data.opaque(), src_desc.handle(), input_data.opaque(), &beta,
   4087       src_desc.handle(), output_diff_data->opaque());
   4088   if (status != CUDNN_STATUS_SUCCESS) {
   4089     LOG(ERROR) << "failed to enqueue backward pooling on stream: "
   4090                << ToString(status);
   4091     return false;
   4092   }
   4093   return true;
   4094 }
   4095 
   4096 bool CudnnSupport::DoPoolBackward(
   4097     Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
   4098     const dnn::BatchDescriptor& input_dimensions,
   4099     const DeviceMemory<float>& input_data,
   4100     const dnn::BatchDescriptor& output_dimensions,
   4101     const DeviceMemory<float>& output_data,
   4102     const DeviceMemory<float>& input_diff_data,
   4103     DeviceMemory<float>* output_diff_data) {
   4104   mutex_lock lock{dnn_handle_mutex_};
   4105   auto status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
   4106                                      AsCUDAStreamValue(stream));
   4107   if (status != CUDNN_STATUS_SUCCESS) {
   4108     LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
   4109     return false;
   4110   }
   4111 
   4112   // Alpha is the scaling factor for input.
   4113   float alpha = 1.0;
   4114   // Beta is the scaling factor for output.
   4115   float beta = 0.0;
   4116 
   4117   ScopedTensorDescriptor src_desc{parent_, input_dimensions, CUDNN_DATA_FLOAT};
   4118   ScopedTensorDescriptor dest_desc{parent_, output_dimensions,
   4119                                    CUDNN_DATA_FLOAT};
   4120   ScopedPoolingDescriptor pooling_desc{parent_, pooling_dimensions};
   4121   status = wrap::cudnnPoolingBackward(
   4122       parent_, ToHandle(dnn_handle_), pooling_desc.handle(), &alpha,
   4123       dest_desc.handle(), output_data.opaque(), dest_desc.handle(),
   4124       input_diff_data.opaque(), src_desc.handle(), input_data.opaque(), &beta,
   4125       src_desc.handle(), output_diff_data->opaque());
   4126   if (status != CUDNN_STATUS_SUCCESS) {
   4127     LOG(ERROR) << "failed to enqueue backward pooling on stream: "
   4128                << ToString(status);
   4129     return false;
   4130   }
   4131   return true;
   4132 }
   4133 
   4134 bool CudnnSupport::DoPoolBackward(
   4135     Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
   4136     const dnn::BatchDescriptor& input_dimensions,
   4137     const DeviceMemory<Eigen::half>& input_data,
   4138     const dnn::BatchDescriptor& output_dimensions,
   4139     const DeviceMemory<Eigen::half>& output_data,
   4140     const DeviceMemory<Eigen::half>& input_diff_data,
   4141     DeviceMemory<Eigen::half>* output_diff_data) {
   4142   mutex_lock lock{dnn_handle_mutex_};
   4143   auto status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
   4144                                      AsCUDAStreamValue(stream));
   4145   if (status != CUDNN_STATUS_SUCCESS) {
   4146     LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
   4147     return false;
   4148   }
   4149 
   4150   // Alpha is the scaling factor for input.
   4151   float alpha = 1.0;
   4152   // Beta is the scaling factor for output.
   4153   float beta = 0.0;
   4154 
   4155   ScopedTensorDescriptor src_desc{parent_, input_dimensions, CUDNN_DATA_HALF};
   4156   ScopedTensorDescriptor dest_desc{parent_, output_dimensions, CUDNN_DATA_HALF};
   4157   ScopedPoolingDescriptor pooling_desc{parent_, pooling_dimensions};
   4158   status = wrap::cudnnPoolingBackward(
   4159       parent_, ToHandle(dnn_handle_), pooling_desc.handle(), &alpha,
   4160       dest_desc.handle(), output_data.opaque(), dest_desc.handle(),
   4161       input_diff_data.opaque(), src_desc.handle(), input_data.opaque(), &beta,
   4162       src_desc.handle(), output_diff_data->opaque());
   4163   if (status != CUDNN_STATUS_SUCCESS) {
   4164     LOG(ERROR) << "failed to enqueue backward pooling on stream: "
   4165                << ToString(status);
   4166     return false;
   4167   }
   4168   return true;
   4169 }
   4170 
   4171 bool CudnnSupport::DoNormalize(
   4172     Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
   4173     const DeviceMemory<float>& input_data, DeviceMemory<float>* output_data) {
   4174   LOG(FATAL) << "not yet implemented";  // TODO(leary)
   4175   return false;
   4176 }
   4177 
   4178 bool CudnnSupport::DoNormalizeWithDimensions(
   4179     Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
   4180     const dnn::BatchDescriptor& dimensions,
   4181     const DeviceMemory<float>& input_data, DeviceMemory<float>* output_data) {
   4182   // Check for unsupported modes.
   4183   if (normalize_descriptor.wrap_around()) {
   4184     LOG(ERROR) << "CUDA LRN does not support wrap-around mode";
   4185     return false;
   4186   }
   4187   if (normalize_descriptor.segment_size()) {
   4188     LOG(ERROR) << "CUDA LRN does not support segmentation";
   4189     return false;
   4190   }
   4191 
   4192   // Launch the normalization.
   4193   mutex_lock lock{dnn_handle_mutex_};
   4194   auto status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
   4195                                      AsCUDAStreamValue(stream));
   4196   if (status != CUDNN_STATUS_SUCCESS) {
   4197     LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
   4198     return false;
   4199   }
   4200 
   4201   ScopedTensorDescriptor dims{parent_, dimensions, CUDNN_DATA_FLOAT};
   4202   ScopedNormalizeDescriptor normalize{parent_, normalize_descriptor};
   4203 
   4204   // Alpha is the scaling factor for input.
   4205   float alpha = 1.0f;
   4206   // Beta is the scaling factor for output.
   4207   float beta = 0.0f;
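           // Added note: cuDNN's cross-channel LRN computes, per element, roughly
           //   output = input / (k + a * sum_over_window(input^2))^b
           // across a window of adjacent feature maps; the window size and the
           // k/a/b parameters (distinct from the blending alpha/beta above) come
           // from normalize_descriptor via ScopedNormalizeDescriptor. The exact
           // scaling of `a` follows the cuDNN LRN documentation.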
   4208 
   4209   status = wrap::cudnnLRNCrossChannelForward(
   4210       parent_, ToHandle(dnn_handle_), normalize.handle(),
   4211       CUDNN_LRN_CROSS_CHANNEL_DIM1, &alpha, dims.handle(), input_data.opaque(),
   4212       &beta, dims.handle(), output_data->opaque());
   4213   if (status != CUDNN_STATUS_SUCCESS) {
   4214     LOG(ERROR) << "failed to run cudnnLRNCrossChannelForward";
   4215     return false;
   4216   }
   4217   return true;
   4218 }
   4219 
   4220 bool CudnnSupport::DoNormalizeBackwardWithDimensions(
   4221     Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
   4222     const dnn::BatchDescriptor& dimensions, const DeviceMemory<float>& raw_data,
   4223     const DeviceMemory<float>& normalized_data,
   4224     const DeviceMemory<float>& normalized_variable_gradient,
   4225     DeviceMemory<float>* raw_variable_gradient) {
   4226   // Check for unsupported modes.
   4227   if (normalize_descriptor.wrap_around()) {
   4228     LOG(ERROR) << "CUDA LRN does not support wrap-around mode";
   4229     return false;
   4230   }
   4231   if (normalize_descriptor.segment_size()) {
   4232     LOG(ERROR) << "CUDA LRN does not support segmentation";
   4233     return false;
   4234   }
   4235 
   4236   mutex_lock lock{dnn_handle_mutex_};
   4237   auto status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
   4238                                      AsCUDAStreamValue(stream));
   4239   if (status != CUDNN_STATUS_SUCCESS) {
   4240     LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
   4241     return false;
   4242   }
   4243 
   4244   ScopedTensorDescriptor dims{parent_, dimensions, CUDNN_DATA_FLOAT};
   4245   ScopedNormalizeDescriptor normalize{parent_, normalize_descriptor};
   4246 
   4247   float alpha = 1.0f;
   4248   float beta = 0.0f;
   4249 
   4250   status = wrap::cudnnLRNCrossChannelBackward(
   4251       parent_, ToHandle(dnn_handle_), normalize.handle(),
   4252       CUDNN_LRN_CROSS_CHANNEL_DIM1, &alpha, dims.handle(),
   4253       normalized_data.opaque(), dims.handle(),
   4254       normalized_variable_gradient.opaque(), dims.handle(), raw_data.opaque(),
   4255       &beta, dims.handle(), raw_variable_gradient->opaque());
   4256   if (status != CUDNN_STATUS_SUCCESS) {
   4257     LOG(ERROR) << "failed to run cudnnLRNCrossChannelBackward";
   4258     return false;
   4259   }
   4260   return true;
   4261 }
   4262 
   4263 bool CudnnSupport::DoDepthConcatenate(
   4264     Stream* stream, port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
   4265     port::ArraySlice<const DeviceMemory<float>*> input_data,
   4266     DeviceMemory<float>* output_data) {
   4267   CHECK_EQ(input_dimensions.size(), input_data.size());
   4268 
   4269   for (const auto& dimensions : input_dimensions) {
   4270     if (dimensions.layout() != dnn::DataLayout::kBatchDepthYX) {
   4271       LOG(ERROR) << "CudnnSupport::DoDepthConcatenate currently only "
   4272                     "supports the kBatchDepthYX layout.";
   4273       return false;
   4274     }
   4275   }
   4276 
   4277   if (input_dimensions.empty()) {
   4278     return true;  // Nothing to do.
   4279   }
   4280 
   4281   dnn::BatchDescriptor output_dimensions =
   4282       dnn::BatchDescriptor::DepthConcatenateOutputDescriptor(input_dimensions);
   4283 
   4284   const int64 area = output_dimensions.width() * output_dimensions.height();
   4285   const auto index = [area](int64 batch, int64 depth, int64 yx,
   4286                             int64 max_depth) {
   4287     return (batch * max_depth + depth) * area + yx;
   4288   };
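           // Worked example (added; sizes are illustrative): with max_depth = 8
           // feature maps and a 5x5 spatial extent (area = 25), the element at
           // batch 2, depth 3, spatial position yx = 7 lands at flat index
           //   (2 * 8 + 3) * 25 + 7 = 482
           // in the kBatchDepthYX host buffer.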
   4289 
   4290   std::vector<float> output_host(output_dimensions.ElementCount());
   4291   std::vector<float> tmp;
   4292   int64 depth_sum = 0;
   4293   for (size_t i = 0; i < input_data.size(); ++i) {
   4294     const auto& dimensions = input_dimensions[i];
   4295     tmp.resize(dimensions.ElementCount());
   4296     stream->ThenMemcpyD2H<float>(*input_data[i], &tmp);
   4297     port::Status block_status = stream->BlockHostUntilDone();
   4298     if (!block_status.ok()) {
   4299       LOG(ERROR) << "BlockHostUntilDone failed: " << block_status;
   4300       return false;
   4301     }
   4302 
   4303     for (int64 batch = 0; batch < output_dimensions.count(); ++batch) {
   4304       for (int64 yx = 0; yx < area; ++yx) {
   4305         for (int64 depth = 0; depth < dimensions.feature_map_count(); ++depth) {
    4306           VLOG(2) << output_dimensions.ElementCount() << ' ' << batch << ' '
    4307                   << yx << ' ' << depth;
   4308           output_host[index(batch, depth + depth_sum, yx,
   4309                             output_dimensions.feature_map_count())] =
   4310               tmp[index(batch, depth, yx, dimensions.feature_map_count())];
   4311         }
   4312       }
   4313     }
   4314     depth_sum += dimensions.feature_map_count();
   4315   }
   4316   stream->ThenMemcpyH2D<float>(output_host, output_data);
   4317   return true;
   4318 }
   4319 
   4320 bool CudnnSupport::DoElementwiseOperate(
   4321     Stream* stream, dnn::ElementwiseOperation operation,
   4322     port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
   4323     port::ArraySlice<const DeviceMemory<float>*> input_data,
   4324     const dnn::BatchDescriptor& output_dimensions,
   4325     DeviceMemory<float>* output_data) {
   4326   LOG(FATAL) << "not yet implemented";  // TODO(leary)
   4327   return false;
   4328 }
   4329 
   4330 bool CudnnSupport::DoXYPad(Stream* stream,
   4331                            const dnn::BatchDescriptor& dimensions,
   4332                            const DeviceMemory<float>& input_data,
   4333                            int64 left_pad, int64 right_pad, int64 top_pad,
   4334                            int64 bottom_pad, DeviceMemory<float>* output_data) {
   4335   LOG(FATAL) << "not yet implemented";  // TODO(leary)
   4336   return false;
   4337 }
   4338 
   4339 bool CudnnSupport::DoXYSlice(Stream* stream,
   4340                              const dnn::BatchDescriptor& dimensions,
   4341                              const DeviceMemory<float>& input_data,
   4342                              int64 left_trim, int64 right_trim, int64 top_trim,
   4343                              int64 bottom_trim,
   4344                              DeviceMemory<float>* output_data) {
   4345   LOG(FATAL) << "not yet implemented";  // TODO(leary)
   4346   return false;
   4347 }
   4348 
   4349 bool CudnnSupport::DoMemcpyD2HQuantized(
   4350     Stream* stream, const DeviceMemory<float>& gpu_unquantized_src,
   4351     dnn::QuantizedActivationMode mode, void* host_dst, int64 size) {
   4352   LOG(ERROR) << "quantized memcpy not supported by cuDNN";
   4353   return false;
   4354 }
   4355 
   4356 bool CudnnSupport::DoMemcpyH2DQuantized(
   4357     Stream* stream, const void* host_src, int64 size,
   4358     dnn::QuantizedActivationMode mode,
   4359     DeviceMemory<float>* gpu_unquantized_dst) {
   4360   LOG(ERROR) << "quantized memcpy not supported by cuDNN";
   4361   return false;
   4362 }
   4363 
   4364 bool CudnnSupport::DeriveOutputBatchDescriptor(
   4365     const BatchDescriptor& batch_descriptor,
   4366     const FilterDescriptor& filter_descriptor,
   4367     const dnn::ConvolutionDescriptor& convolution_descriptor,
   4368     dnn::BatchDescriptor* output_batch_descriptor) {
   4369   ScopedTensorDescriptor input_nd{parent_, batch_descriptor, CUDNN_DATA_FLOAT};
   4370   ScopedFilterDescriptor filter{parent_, filter_descriptor, batch_descriptor,
   4371                                 CUDNN_DATA_FLOAT};
   4372   ScopedConvolutionDescriptor conv{parent_, convolution_descriptor,
   4373                                    CUDNN_DATA_FLOAT};
   4374 
   4375   int dn = batch_descriptor.ndims() + 2;
   4376   std::vector<int> dims(dn);  // in BDYX
   4377   auto status = wrap::cudnnGetConvolutionNdForwardOutputDim(
   4378       parent_, conv.handle(), input_nd.handle(), filter.handle(), dn,
   4379       dims.data());
   4380   if (status != CUDNN_STATUS_SUCCESS) {
   4381     LOG(ERROR) << "could not get output tensor for convolution: "
   4382                << ToString(status);
   4383     return false;
   4384   }
   4385 
   4386   output_batch_descriptor->set_count(dims[0])
   4387       .set_feature_map_count(dims[1])
   4388       .set_layout(batch_descriptor.layout());
   4389 
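           // Added note: cudnnGetConvolutionNdForwardOutputDim fills `dims` in
           // batch-major order, e.g. {N, C, H, W} for a 2-D convolution, while
           // dnn::DimIndex counts spatial dimensions from the innermost one;
           // hence the reversed indexing via dims.rbegin()[i] below.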
   4390   for (int i = 0; i < batch_descriptor.ndims(); i++) {
   4391     output_batch_descriptor->set_spatial_dim(static_cast<dnn::DimIndex>(i),
   4392                                              dims.rbegin()[i]);
   4393   }
   4394 
   4395   return true;
   4396 }
   4397 
   4398 }  // namespace cuda
   4399 
   4400 namespace gpu = ::perftools::gputools;
   4401 
   4402 void initialize_cudnn() {
   4403   gpu::port::Status status =
   4404       gpu::PluginRegistry::Instance()
   4405           ->RegisterFactory<gpu::PluginRegistry::DnnFactory>(
   4406               gpu::cuda::kCudaPlatformId, gpu::cuda::kCuDnnPlugin, "cuDNN",
   4407               [](gpu::internal::StreamExecutorInterface*
   4408                      parent) -> gpu::dnn::DnnSupport* {
   4409                 gpu::cuda::CUDAExecutor* cuda_executor =
   4410                     dynamic_cast<gpu::cuda::CUDAExecutor*>(parent);
   4411                 if (cuda_executor == nullptr) {
   4412                   LOG(ERROR)
    4413                       << "Attempting to initialize an instance of the cuDNN "
   4414                       << "support library with a non-CUDA StreamExecutor";
   4415                   return nullptr;
   4416                 }
   4417 
   4418                 gpu::cuda::CudnnSupport* dnn =
   4419                     new gpu::cuda::CudnnSupport(cuda_executor);
   4420                 if (!dnn->Init().ok()) {
   4421                   // Note: Init() will log a more specific error.
   4422                   delete dnn;
   4423                   return nullptr;
   4424                 }
   4425                 return dnn;
   4426               });
   4427 
   4428   if (!status.ok()) {
   4429     LOG(ERROR) << "Unable to register cuDNN factory: "
   4430                << status.error_message();
   4431   }
   4432 
   4433   gpu::PluginRegistry::Instance()->SetDefaultFactory(gpu::cuda::kCudaPlatformId,
   4434                                                      gpu::PluginKind::kDnn,
   4435                                                      gpu::cuda::kCuDnnPlugin);
   4436 }
   4437 
   4438 }  // namespace gputools
   4439 }  // namespace perftools
   4440 
   4441 REGISTER_MODULE_INITIALIZER(register_cudnn,
   4442                             { perftools::gputools::initialize_cudnn(); });
   4443