/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS

#include <stddef.h>
#include <atomic>
#include <cmath>
#include <functional>
#include <limits>
#include <string>
#include <unordered_map>
#include <unordered_set>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/env_var.h"

#if GOOGLE_CUDA
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/util/stream_executor_util.h"
#endif  // GOOGLE_CUDA

/*
 * This module implements ops that fuse a multi-layer multi-step RNN/LSTM model
 * using the underlying Cudnn library.
 *
 * The Cudnn RNN library exposes an opaque parameter buffer with an unknown
 * layout and format, and a saved buffer very likely cannot be reused across
 * different GPUs. Users therefore need to first query the size of the opaque
 * parameter buffer and convert it to and from its canonical form, while each
 * actual training step is carried out with the opaque parameter buffer
 * itself.
 *
 * Similar to many other ops, the forward op has two flavors: training and
 * inference. When training is specified, additional data in reserve_space
 * will be produced for the backward pass, so there is a performance penalty.
 *
 * In addition to the actual data and reserve_space, Cudnn also needs more
 * memory as a temporary workspace. The memory exchanged with stream-executor
 * is managed through ScratchAllocator: in general, stream-executor is
 * responsible for requesting memory of the proper size, and TensorFlow is
 * responsible for keeping that memory alive long enough and recycling it
 * afterwards.
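 *
 * As a rough, purely illustrative sketch of how the ops defined in this
 * module are meant to be composed (the op names are those registered below;
 * shapes and the pseudo-calls are only indicative, not a real API):
 *
 *   params_size = CudnnRNNParamsSize(num_layers, num_units, input_size)
 *   params = <opaque buffer holding params_size elements>
 *   weights, biases = CudnnRNNParamsToCanonical(..., params)  // for saving
 *   params = CudnnRNNCanonicalToParams(weights, biases, ...)  // for restoring
 *   output, output_h, output_c, reserve_space =
 *       CudnnRNN(input, input_h, input_c, params, is_training=true)
 *   ..._backprop = CudnnRNNBackprop(..., reserve_space, ...)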
 *
 */
namespace tensorflow {

using CPUDevice = Eigen::ThreadPoolDevice;

#if GOOGLE_CUDA

using GPUDevice = Eigen::GpuDevice;

template <typename Device, typename T, typename Index>
class CudnnRNNParamsSizeOp;

template <typename Device, typename T>
class CudnnRNNParamsToCanonical;

template <typename Device, typename T>
class CudnnRNNCanonicalToParams;

template <typename Device, typename T>
class CudnnRNNForwardOp;

template <typename Device, typename T>
class CudnnRNNBackwardOp;

enum class TFRNNInputMode {
  kRNNLinearInput = 0,
  kRNNSkipInput = 1,
  kAutoSelect = 9999999
};

namespace {
using perftools::gputools::DeviceMemory;
using perftools::gputools::DeviceMemoryBase;
using perftools::gputools::ScratchAllocator;
using perftools::gputools::dnn::RnnDirectionMode;
using perftools::gputools::dnn::RnnInputMode;
using perftools::gputools::dnn::RnnMode;
using perftools::gputools::dnn::ToDataType;
using perftools::gputools::port::StatusOr;

Status ParseRNNMode(const string& str, RnnMode* rnn_mode) {
  if (str == "rnn_relu") {
    *rnn_mode = RnnMode::kRnnRelu;
    return Status::OK();
  } else if (str == "rnn_tanh") {
    *rnn_mode = RnnMode::kRnnTanh;
    return Status::OK();
  } else if (str == "lstm") {
    *rnn_mode = RnnMode::kRnnLstm;
    return Status::OK();
  } else if (str == "gru") {
    *rnn_mode = RnnMode::kRnnGru;
    return Status::OK();
  }
  return errors::InvalidArgument("Invalid RNN mode: ", str);
}

Status ParseTFRNNInputMode(const string& str, TFRNNInputMode* rnn_input_mode) {
  if (str == "linear_input") {
    *rnn_input_mode = TFRNNInputMode::kRNNLinearInput;
    return Status::OK();
  } else if (str == "skip_input") {
    *rnn_input_mode = TFRNNInputMode::kRNNSkipInput;
    return Status::OK();
  } else if (str == "auto_select") {
    *rnn_input_mode = TFRNNInputMode::kAutoSelect;
    return Status::OK();
  }
  return errors::InvalidArgument("Invalid RNN input mode: ", str);
}

Status ParseRNNDirectionMode(const string& str,
                             RnnDirectionMode* rnn_dir_mode) {
  if (str == "unidirectional") {
    *rnn_dir_mode = RnnDirectionMode::kRnnUnidirectional;
    return Status::OK();
  } else if (str == "bidirectional") {
    *rnn_dir_mode = RnnDirectionMode::kRnnBidirectional;
    return Status::OK();
  }
  return errors::InvalidArgument("Invalid RNN direction mode: ", str);
}

Status ToRNNInputMode(TFRNNInputMode tf_input_mode, int num_units,
                      int input_size, RnnInputMode* input_mode) {
  switch (tf_input_mode) {
    case TFRNNInputMode::kRNNLinearInput:
      *input_mode = RnnInputMode::kRnnLinearSkip;
      break;
    case TFRNNInputMode::kRNNSkipInput:
      *input_mode = RnnInputMode::kRnnSkipInput;
      break;
    case TFRNNInputMode::kAutoSelect:
      *input_mode = (input_size == num_units) ? RnnInputMode::kRnnSkipInput
                                              : RnnInputMode::kRnnLinearSkip;
      break;
    default:
      return errors::InvalidArgument("Invalid TF input mode: ",
                                     static_cast<int>(tf_input_mode));
  }
  return Status::OK();
}
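
// For illustration only (the numeric values are hypothetical), ToRNNInputMode
// above maps auto_select to the skip-input mode exactly when the input width
// already matches the hidden size:
//
//   RnnInputMode mode;
//   TF_CHECK_OK(ToRNNInputMode(TFRNNInputMode::kAutoSelect, /*num_units=*/128,
//                              /*input_size=*/128, &mode));
//   // mode == RnnInputMode::kRnnSkipInput
//   TF_CHECK_OK(ToRNNInputMode(TFRNNInputMode::kAutoSelect, /*num_units=*/128,
//                              /*input_size=*/64, &mode));
//   // mode == RnnInputMode::kRnnLinearSkip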

// TODO(zhengxq): Merge those into stream_executor_util.h.
template <typename T>
const DeviceMemory<T> AsDeviceMemory(const Tensor* tensor) {
  return DeviceMemory<T>::MakeFromByteSize(
      const_cast<T*>(tensor->template flat<T>().data()),
      tensor->template flat<T>().size() * sizeof(T));
}

template <typename T>
DeviceMemory<T> AsDeviceMemory(Tensor* tensor) {
  return DeviceMemory<T>::MakeFromByteSize(
      tensor->template flat<T>().data(),
      tensor->template flat<T>().size() * sizeof(T));
}

template <typename U, typename T>
DeviceMemory<U> CastDeviceMemory(Tensor* tensor) {
  return DeviceMemory<U>::MakeFromByteSize(
      tensor->template flat<T>().data(),
      tensor->template flat<T>().size() * sizeof(T));
}

DeviceMemoryBase SliceDeviceMemory(const DeviceMemoryBase& device_memory,
                                   int64 offset, int64 size) {
  const void* base_ptr = device_memory.opaque();
  void* offset_ptr =
      const_cast<char*>(reinterpret_cast<const char*>(base_ptr) + offset);
  CHECK(offset + size <= device_memory.size())
      << "The slice is not within the region of DeviceMemory.";
  return DeviceMemoryBase(offset_ptr, size);
}

inline Status FromExecutorStatus(const perftools::gputools::port::Status& s) {
  return s.ok() ? Status::OK()
                : Status(static_cast<tensorflow::error::Code>(
                             static_cast<int>(s.code())),
                         s.error_message());
}

template <typename T>
inline Status FromExecutorStatus(
    const perftools::gputools::port::StatusOr<T>& s) {
  return FromExecutorStatus(s.status());
}

inline perftools::gputools::port::Status ToExecutorStatus(const Status& s) {
  return s.ok() ? perftools::gputools::port::Status::OK()
                : perftools::gputools::port::Status(
                      static_cast<perftools::gputools::port::error::Code>(
                          static_cast<int>(s.code())),
                      s.error_message());
}

// A helper to allocate temporary scratch memory for Cudnn RNN models. It takes
// ownership of the underlying memory. The expectation is that the memory
// should remain alive for the duration of the Cudnn RNN call itself.
class CudnnRNNWorkspaceAllocator : public ScratchAllocator {
 public:
  ~CudnnRNNWorkspaceAllocator() override {}
  explicit CudnnRNNWorkspaceAllocator(OpKernelContext* context)
      : context_(context) {}
  int64 GetMemoryLimitInBytes(perftools::gputools::Stream* stream) override {
    return std::numeric_limits<int64>::max();
  }
  StatusOr<DeviceMemory<uint8>> AllocateBytes(
      perftools::gputools::Stream* stream, int64 byte_size) override {
    Tensor temporary_memory;
    Status allocation_status(context_->allocate_temp(
        DT_UINT8, TensorShape({byte_size}), &temporary_memory));
    if (!allocation_status.ok()) {
      return ToExecutorStatus(allocation_status);
    }
    // Hold a reference to the allocated tensors until the allocator is
    // destroyed.
    allocated_tensors_.push_back(temporary_memory);
    total_byte_size_ += byte_size;
    return StatusOr<DeviceMemory<uint8>>(
        AsDeviceMemory<uint8>(&temporary_memory));
  }
  int64 TotalByteSize() { return total_byte_size_; }

 private:
  int64 total_byte_size_ = 0;
  OpKernelContext* context_;  // not owned
  std::vector<Tensor> allocated_tensors_;
};
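
// For reference, the AsDeviceMemory / SliceDeviceMemory helpers above combine
// into the copy pattern used by RestoreParams and CudnnRNNParamsToCanonical
// further below. A rough sketch (`params_buffer`, `dst_tensor`, `offset`, and
// `size_in_bytes` are hypothetical placeholders):
//
//   DeviceMemoryBase src =
//       SliceDeviceMemory(params_buffer, offset, size_in_bytes);
//   auto dst = AsDeviceMemory<float>(&dst_tensor);
//   stream->ThenMemcpy(&dst, src, size_in_bytes);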

// A helper to allocate reserve-space memory for Cudnn RNN models. The tensors
// are allocated as a kernel output, and will be fed into the backward pass.
// The memory is expected to remain alive until the backward pass has
// finished.
template <typename T>
class CudnnRNNReserveSpaceAllocator : public ScratchAllocator {
 public:
  ~CudnnRNNReserveSpaceAllocator() override {}
  CudnnRNNReserveSpaceAllocator(OpKernelContext* context, int output_index)
      : context_(context), output_index_(output_index) {}
  int64 GetMemoryLimitInBytes(perftools::gputools::Stream* stream) override {
    return std::numeric_limits<int64>::max();
  }
  StatusOr<DeviceMemory<uint8>> AllocateBytes(
      perftools::gputools::Stream* stream, int64 byte_size) override {
    CHECK(total_byte_size_ == 0)
        << "Reserve space allocator can only be called once";
    int64 allocate_count =
        Eigen::divup(byte_size, static_cast<int64>(sizeof(T)));

    Tensor* temporary_memory = nullptr;
    Status allocation_status(context_->allocate_output(
        output_index_, TensorShape({allocate_count}), &temporary_memory));
    if (!allocation_status.ok()) {
      return ToExecutorStatus(allocation_status);
    }
    total_byte_size_ += byte_size;
    auto memory_uint8 = DeviceMemory<uint8>::MakeFromByteSize(
        temporary_memory->template flat<T>().data(),
        temporary_memory->template flat<T>().size() * sizeof(T));
    return StatusOr<DeviceMemory<uint8>>(memory_uint8);
  }
  int64 TotalByteSize() { return total_byte_size_; }

 private:
  int64 total_byte_size_ = 0;
  OpKernelContext* context_;  // not owned
  int output_index_;
};

// A helper to allocate persistent memory for Cudnn RNN models, which is
// expected to live between kernel invocations.
// This class is not thread-safe.
class CudnnRNNPersistentSpaceAllocator : public ScratchAllocator {
 public:
  explicit CudnnRNNPersistentSpaceAllocator(OpKernelContext* context)
      : context_(context) {}

  ~CudnnRNNPersistentSpaceAllocator() override {}

  int64 GetMemoryLimitInBytes(perftools::gputools::Stream* stream) override {
    return std::numeric_limits<int64>::max();
  }

  StatusOr<DeviceMemory<uint8>> AllocateBytes(
      perftools::gputools::Stream* stream, int64 byte_size) override {
    if (total_byte_size_ != 0) {
      return Status(error::FAILED_PRECONDITION,
                    "Persistent space allocator can only be called once");
    }

    Status allocation_status = context_->allocate_persistent(
        DT_UINT8, TensorShape({byte_size}), &handle_, nullptr);
    if (!allocation_status.ok()) {
      return ToExecutorStatus(allocation_status);
    }
    total_byte_size_ += byte_size;
    return AsDeviceMemory<uint8>(handle_.AccessTensor(context_));
  }
  int64 TotalByteSize() { return total_byte_size_; }

 private:
  int64 total_byte_size_ = 0;
  PersistentTensor handle_;
  OpKernelContext* context_;  // not owned
};

struct CudnnModelTypes {
  RnnMode rnn_mode;
  TFRNNInputMode rnn_input_mode;
  RnnDirectionMode rnn_direction_mode;
  bool HasInputC() const {
    // For Cudnn 5.0, only LSTM has input-c. All other models use only input-h.
    return rnn_mode == RnnMode::kRnnLstm;
  }
};

// A helper class that collects the shapes describing an RNN model.
struct CudnnModelShapes {
  int num_layers;
  int input_size;
  int num_units;
  int seq_length;
  int batch_size;
  int dir_count;
  TensorShape input_shape;
  TensorShape output_shape;
  TensorShape hidden_state_shape;
  // At present, only the fields relevant to the cached RnnDescriptor are
  // compared.
  bool IsCompatibleWith(const CudnnModelShapes& rhs) const {
    return num_layers == rhs.num_layers && input_size == rhs.input_size &&
           num_units == rhs.num_units && dir_count == rhs.dir_count;
  }
  string RnnDescDebugString() {
    return strings::Printf(
        "[num_layers, input_size, num_units, dir_count]: [%d, %d, %d, %d]",
        num_layers, input_size, num_units, dir_count);
  }
};

// Utility class for using CudnnModelShapes as a hash table key.
struct CudnnModelShapesHasher {
  uint64 operator()(const CudnnModelShapes& to_hash) const {
    uint64 hash = static_cast<uint64>(to_hash.num_layers);
    hash = tensorflow::FingerprintCat64(
        hash, static_cast<uint64>(to_hash.input_size));
    hash = tensorflow::FingerprintCat64(
        hash, static_cast<uint64>(to_hash.num_units));
    return tensorflow::FingerprintCat64(
        hash, static_cast<uint64>(to_hash.dir_count));
  }
};

// Utility class for using CudnnModelShapes as a hash table key.
struct CudnnModelShapesComparator {
  bool operator()(const CudnnModelShapes& first,
                  const CudnnModelShapes& second) const {
    return first.IsCompatibleWith(second);
  }
};
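
// The two helpers above exist so that CudnnModelShapes can be used as the key
// of the RnnDescriptor caches declared in CudnnRNNForwardOp and
// CudnnRNNBackwardOp below; the map type, shown here only for reference, is:
//
//   std::unordered_map<CudnnModelShapes, RnnScratchSpace,
//                      CudnnModelShapesHasher, CudnnModelShapesComparator>
//       rnn_state_cache_;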

// Extracts and checks the forward input tensors, parameters, and shapes from
// the OpKernelContext.
Status ExtractForwardInput(OpKernelContext* context,
                           const CudnnModelTypes& model_types,
                           const Tensor** input, const Tensor** input_h,
                           const Tensor** input_c, const Tensor** params,
                           CudnnModelShapes* model_shapes) {
  TF_RETURN_IF_ERROR(context->input("input", input));
  TF_RETURN_IF_ERROR(context->input("input_h", input_h));
  if (model_types.HasInputC()) {
    TF_RETURN_IF_ERROR(context->input("input_c", input_c));
  }
  TF_RETURN_IF_ERROR(context->input("params", params));

  if ((*input)->dims() != 3) {
    return errors::InvalidArgument("RNN input must be a 3-D tensor.");
  }
  model_shapes->seq_length = (*input)->dim_size(0);
  model_shapes->batch_size = (*input)->dim_size(1);
  model_shapes->input_size = (*input)->dim_size(2);
  model_shapes->input_shape = (*input)->shape();
  model_shapes->dir_count =
      (model_types.rnn_direction_mode == RnnDirectionMode::kRnnBidirectional)
          ? 2
          : 1;

  if ((*input_h)->dims() != 3) {
    return errors::InvalidArgument("RNN input_h must be a 3-D tensor.");
  }
  model_shapes->num_layers = (*input_h)->dim_size(0) / model_shapes->dir_count;
  model_shapes->num_units = (*input_h)->dim_size(2);

  model_shapes->hidden_state_shape =
      TensorShape({model_shapes->dir_count * model_shapes->num_layers,
                   model_shapes->batch_size, model_shapes->num_units});
  if ((*input_h)->shape() != model_shapes->hidden_state_shape) {
    return errors::InvalidArgument(
        "Invalid input_h shape: ", (*input_h)->shape().DebugString(), " ",
        model_shapes->hidden_state_shape.DebugString());
  }
  if (model_types.HasInputC()) {
    if ((*input_h)->shape() != (*input_c)->shape()) {
      return errors::InvalidArgument(
          "input_h and input_c must have the same shape: ",
          (*input_h)->shape().DebugString(), " ",
          (*input_c)->shape().DebugString());
    }
  }
  model_shapes->output_shape =
      TensorShape({model_shapes->seq_length, model_shapes->batch_size,
                   model_shapes->dir_count * model_shapes->num_units});
  return Status::OK();
}

using perftools::gputools::dnn::RnnDescriptor;

template <typename T>
void RestoreParams(const OpInputList params_input,
                   const std::vector<RnnDescriptor::ParamsRegion>& params,
                   DeviceMemoryBase* data_dst,
                   perftools::gputools::Stream* stream) {
  int num_params = params.size();
  CHECK(params_input.size() == num_params)
      << "Number of params mismatch. Expected " << params_input.size()
      << ", got " << num_params;
  for (int i = 0; i < params.size(); i++) {
    int64 size_in_bytes = params[i].size;
    int64 size = size_in_bytes / sizeof(T);
    CHECK(size == params_input[i].NumElements())
        << "Params size mismatch. Expected " << size << ", got "
        << params_input[i].NumElements();
    auto data_src_ptr = StreamExecutorUtil::AsDeviceMemory<T>(params_input[i]);
    DeviceMemoryBase data_dst_ptr =
        SliceDeviceMemory(*data_dst, params[i].offset, size_in_bytes);
    stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes);
  }
}

}  // namespace

// Note: all the kernels below depend on an RnnDescriptor instance, which,
// according to the official Cudnn documentation, should be kept around and
// reused across all Cudnn kernels in the same model.
// In TensorFlow, we don't pass the reference across different OpKernels;
// instead, we recreate it separately in each OpKernel. This does not cause
// problems: CudnnDropoutDescriptor keeps a reference to memory holding the
// random number generator state, and that state is lost during recreation,
// but only forward-pass Cudnn APIs make use of it.

// A common base class for RNN kernels. It extracts common attributes and
// performs shape validations.
class CudnnRNNKernelCommon : public OpKernel {
 protected:
  explicit CudnnRNNKernelCommon(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("dropout", &dropout_));
    OP_REQUIRES_OK(context, context->GetAttr("seed", &seed_));
    OP_REQUIRES_OK(context, context->GetAttr("seed2", &seed2_));
    string str;
    OP_REQUIRES_OK(context, context->GetAttr("rnn_mode", &str));
    OP_REQUIRES_OK(context, ParseRNNMode(str, &model_types_.rnn_mode));
    OP_REQUIRES_OK(context, context->GetAttr("input_mode", &str));
    OP_REQUIRES_OK(context,
                   ParseTFRNNInputMode(str, &model_types_.rnn_input_mode));
    OP_REQUIRES_OK(context, context->GetAttr("direction", &str));
    OP_REQUIRES_OK(
        context, ParseRNNDirectionMode(str, &model_types_.rnn_direction_mode));
    // If TF_CUDNN_RESET_RND_GEN_STATE is set, the CudnnRnnDescriptor and the
    // related random number generator state are reset in every Compute()
    // call.
    OP_REQUIRES_OK(context, ReadBoolFromEnvVar("TF_CUDNN_RESET_RND_GEN_STATE",
                                               false, &reset_rnd_gen_state_));
  }

  bool HasInputC() const { return model_types_.HasInputC(); }
  RnnMode rnn_mode() const { return model_types_.rnn_mode; }
  TFRNNInputMode rnn_input_mode() const { return model_types_.rnn_input_mode; }
  RnnDirectionMode rnn_direction_mode() const {
    return model_types_.rnn_direction_mode;
  }
  CudnnModelTypes model_types() const { return model_types_; }
  float dropout() const { return dropout_; }
  uint64 seed() { return (static_cast<uint64>(seed_) << 32) | seed2_; }
  bool ResetRndGenState() { return reset_rnd_gen_state_; }

  template <typename T>
  Status ExtractCudnnRNNParamsInfo(OpKernelContext* context,
                                   std::unique_ptr<RnnDescriptor>* rnn_desc) {
    const Tensor* num_layers_t = nullptr;
    TF_RETURN_IF_ERROR(context->input("num_layers", &num_layers_t));
    if (!TensorShapeUtils::IsScalar(num_layers_t->shape())) {
      return errors::InvalidArgument("num_layers is not a scalar");
    }
    int num_layers = num_layers_t->scalar<int>()();
    const Tensor* num_units_t = nullptr;
    TF_RETURN_IF_ERROR(context->input("num_units", &num_units_t));
    if (!TensorShapeUtils::IsScalar(num_units_t->shape())) {
      return errors::InvalidArgument("num_units is not a scalar");
    }
    int num_units = num_units_t->scalar<int>()();
    const Tensor* input_size_t = nullptr;
    TF_RETURN_IF_ERROR(context->input("input_size", &input_size_t));
    if (!TensorShapeUtils::IsScalar(input_size_t->shape())) {
      return errors::InvalidArgument("input_size is not a scalar");
    }
    int input_size = input_size_t->scalar<int>()();

    RnnInputMode input_mode;
    TF_RETURN_IF_ERROR(
        ToRNNInputMode(rnn_input_mode(), num_units, input_size, &input_mode));

    auto* stream = context->op_device_context()->stream();
    // ExtractCudnnRNNParamsInfo is only called by op kernels that do not
    // require a random number generator, so state_allocator is set to
    // nullptr.
    auto rnn_desc_s = stream->parent()->createRnnDescriptor(
        num_layers, num_units, input_size, input_mode, rnn_direction_mode(),
        rnn_mode(), ToDataType<T>::value, dropout(), seed(),
        nullptr /* state_allocator */);
    if (!rnn_desc_s.ok()) {
      return FromExecutorStatus(rnn_desc_s);
    }
    *rnn_desc = rnn_desc_s.ConsumeValueOrDie();
    return Status::OK();
  }

 private:
  int seed_;
  int seed2_;
  float dropout_;
  bool reset_rnd_gen_state_;

  CudnnModelTypes model_types_;
};

// A class that returns the size of the opaque parameter buffer. The user
// should use it to create the actual parameter buffer for training. However,
// it should not be used for saving and restoring.
template <typename T, typename Index>
class CudnnRNNParamsSizeOp<GPUDevice, T, Index> : public CudnnRNNKernelCommon {
 public:
  typedef GPUDevice Device;
  explicit CudnnRNNParamsSizeOp(OpKernelConstruction* context)
      : CudnnRNNKernelCommon(context) {}

  void Compute(OpKernelContext* context) override {
    std::unique_ptr<RnnDescriptor> rnn_desc;
    OP_REQUIRES_OK(context, ExtractCudnnRNNParamsInfo<T>(context, &rnn_desc));
    int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes();
    CHECK(params_size_in_bytes % sizeof(T) == 0)
        << "params_size_in_bytes must be multiple of element size";
    int64 params_size = params_size_in_bytes / sizeof(T);

    Tensor* output_t = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, {1}, &output_t));
    *output_t->template flat<Index>().data() = params_size;
  }
};

#define REGISTER_GPU(T)                                    \
  REGISTER_KERNEL_BUILDER(Name("CudnnRNNParamsSize")       \
                              .Device(DEVICE_GPU)          \
                              .HostMemory("num_layers")    \
                              .HostMemory("num_units")     \
                              .HostMemory("input_size")    \
                              .HostMemory("params_size")   \
                              .TypeConstraint<T>("T")      \
                              .TypeConstraint<int32>("S"), \
                          CudnnRNNParamsSizeOp<GPUDevice, T, int32>);

TF_CALL_half(REGISTER_GPU);
TF_CALL_float(REGISTER_GPU);
TF_CALL_double(REGISTER_GPU);
#undef REGISTER_GPU

// Convert weight and bias params from a platform-specific layout to the
// canonical form.
template <typename T>
class CudnnRNNParamsToCanonical<GPUDevice, T> : public CudnnRNNKernelCommon {
 public:
  typedef GPUDevice Device;
  explicit CudnnRNNParamsToCanonical(OpKernelConstruction* context)
      : CudnnRNNKernelCommon(context) {
    OP_REQUIRES_OK(context, context->GetAttr("num_params", &num_params_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(3);
    auto input_ptr = StreamExecutorUtil::AsDeviceMemory<T>(input);
    auto* stream = context->op_device_context()->stream();

    std::unique_ptr<RnnDescriptor> rnn_desc;
    OP_REQUIRES_OK(context, ExtractCudnnRNNParamsInfo<T>(context, &rnn_desc));
    int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes();
    CHECK(params_size_in_bytes % sizeof(T) == 0)
        << "params_size_in_bytes must be multiple of element size";

    const Tensor* num_units_t = nullptr;
    OP_REQUIRES_OK(context, context->input("num_units", &num_units_t));
    CHECK(TensorShapeUtils::IsScalar(num_units_t->shape()))
        << "num_units is not a scalar";
    int num_units = num_units_t->scalar<int>()();

    const Tensor* input_size_t = nullptr;
    OP_REQUIRES_OK(context, context->input("input_size", &input_size_t));
    CHECK(TensorShapeUtils::IsScalar(input_size_t->shape()))
        << "input_size is not a scalar";
    int input_size = input_size_t->scalar<int>()();

    const Tensor* num_layers_t = nullptr;
    OP_REQUIRES_OK(context, context->input("num_layers", &num_layers_t));
    CHECK(TensorShapeUtils::IsScalar(num_layers_t->shape()))
        << "num_layers is not a scalar";
    int num_layers = num_layers_t->scalar<int>()();
    int num_dirs = 1;
    if (rnn_direction_mode() == RnnDirectionMode::kRnnBidirectional) {
      num_dirs = 2;
    }
    const int num_params_per_layer = num_params_ / num_layers / num_dirs;
    // Number of params applied to the inputs. The rest are applied to the
    // recurrent hidden states.
    const int num_params_input_state = num_params_per_layer / 2;
    CHECK(num_params_ % (num_layers * num_dirs) == 0)
        << "Number of params is not a multiple of num_layers * num_dirs.";
    CHECK(num_params_per_layer % 2 == 0)
        << "Number of params per layer is not an even number.";

    CHECK(num_params_ == rnn_desc->ParamsWeightRegions().size())
        << "Number of params mismatch. Expected " << num_params_ << ", got "
        << rnn_desc->ParamsWeightRegions().size();
    for (int i = 0; i < rnn_desc->ParamsWeightRegions().size(); i++) {
      int64 size_in_bytes = rnn_desc->ParamsWeightRegions()[i].size;
      int64 size = size_in_bytes / sizeof(T);
      const int layer_idx = i / num_params_per_layer;
      const int index_within_layer = i % num_params_per_layer;
      int width = 0, height = num_units;
      // In the Cudnn layout, each layer has num_params_per_layer params: the
      // first half (i.e. num_params_input_state of them) are applied to the
      // inputs, and the second half to the recurrent hidden states.
      bool apply_on_input_state = index_within_layer < num_params_input_state;
      if (rnn_direction_mode() == RnnDirectionMode::kRnnUnidirectional) {
        if (layer_idx == 0 && apply_on_input_state) {
          width = input_size;
        } else {
          width = num_units;
        }
      } else {
        if (apply_on_input_state) {
          if (layer_idx <= 1) {
            // First forward or backward layer.
            width = input_size;
          } else {
            // For subsequent layers, the cell input is the concatenated
            // output of the previous layer's two directions.
            width = 2 * num_units;
          }
        } else {
          width = num_units;
        }
      }
      CHECK(size == width * height) << "Params size mismatch. Expected "
                                    << width * height << ", got " << size;
      Tensor* output = nullptr;
      OP_REQUIRES_OK(context, context->allocate_output(
                                  i, TensorShape({height, width}), &output));
      DeviceMemoryBase data_src_ptr = SliceDeviceMemory(
          input_ptr, rnn_desc->ParamsWeightRegions()[i].offset, size_in_bytes);
      auto data_dst_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*output);
      stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes);
    }

    OP_REQUIRES(context, num_params_ == rnn_desc->ParamsBiasRegions().size(),
                errors::InvalidArgument("Number of params mismatch. Expected ",
                                        num_params_, ", got ",
                                        rnn_desc->ParamsBiasRegions().size()));
    for (int i = 0; i < rnn_desc->ParamsBiasRegions().size(); i++) {
      int64 size_in_bytes = rnn_desc->ParamsBiasRegions()[i].size;
      int64 size = size_in_bytes / sizeof(T);
      OP_REQUIRES(context, size == num_units,
                  errors::InvalidArgument("Params size mismatch. Expected ",
                                          num_units, ", got ", size));

      Tensor* output = nullptr;
      OP_REQUIRES_OK(context,
                     context->allocate_output(num_params_ + i,
                                              TensorShape({size}), &output));
      DeviceMemoryBase data_src_ptr = SliceDeviceMemory(
          input_ptr, rnn_desc->ParamsBiasRegions()[i].offset, size_in_bytes);
      auto data_dst_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*output);
      stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes);
    }
  }

 private:
  int num_params_;
};

#define REGISTER_GPU(T)                                     \
  REGISTER_KERNEL_BUILDER(Name("CudnnRNNParamsToCanonical") \
                              .Device(DEVICE_GPU)           \
                              .HostMemory("num_layers")     \
                              .HostMemory("num_units")      \
                              .HostMemory("input_size")     \
                              .TypeConstraint<T>("T"),      \
                          CudnnRNNParamsToCanonical<GPUDevice, T>);
TF_CALL_half(REGISTER_GPU);
TF_CALL_float(REGISTER_GPU);
TF_CALL_double(REGISTER_GPU);
#undef REGISTER_GPU
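
// As a purely illustrative example of the layout handled by
// CudnnRNNParamsToCanonical above: a single-layer, unidirectional model with
// num_params = 8 (the LSTM case) has num_params_per_layer = 8, so weight
// regions 0..3 are applied to the layer input (each of shape
// [num_units, input_size] for layer 0) and regions 4..7 to the recurrent
// hidden state (each of shape [num_units, num_units]); the matching 8 bias
// regions each hold num_units elements.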

// Convert weight and bias params from the canonical form to a
// platform-specific layout.
template <typename T>
class CudnnRNNCanonicalToParams<GPUDevice, T> : public CudnnRNNKernelCommon {
 public:
  typedef GPUDevice Device;
  explicit CudnnRNNCanonicalToParams(OpKernelConstruction* context)
      : CudnnRNNKernelCommon(context) {}

  void Compute(OpKernelContext* context) override {
    std::unique_ptr<RnnDescriptor> rnn_desc;
    OP_REQUIRES_OK(context, ExtractCudnnRNNParamsInfo<T>(context, &rnn_desc));
    int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes();
    CHECK(params_size_in_bytes % sizeof(T) == 0)
        << "params_size_in_bytes must be multiple of element size";
    Tensor* output = nullptr;
    int params_size = params_size_in_bytes / sizeof(T);
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, {params_size}, &output));
    auto output_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*output);
    auto* stream = context->op_device_context()->stream();

    OpInputList weights;
    OP_REQUIRES_OK(context, context->input_list("weights", &weights));
    RestoreParams<T>(weights, rnn_desc->ParamsWeightRegions(), &output_ptr,
                     stream);

    OpInputList biases;
    OP_REQUIRES_OK(context, context->input_list("biases", &biases));
    RestoreParams<T>(biases, rnn_desc->ParamsBiasRegions(), &output_ptr,
                     stream);
  }
};

#define REGISTER_GPU(T)                                     \
  REGISTER_KERNEL_BUILDER(Name("CudnnRNNCanonicalToParams") \
                              .Device(DEVICE_GPU)           \
                              .HostMemory("num_layers")     \
                              .HostMemory("num_units")      \
                              .HostMemory("input_size")     \
                              .TypeConstraint<T>("T"),      \
                          CudnnRNNCanonicalToParams<GPUDevice, T>);
TF_CALL_half(REGISTER_GPU);
TF_CALL_float(REGISTER_GPU);
TF_CALL_double(REGISTER_GPU);
#undef REGISTER_GPU

// Pointers to RNN scratch space for a specific set of shape parameters (used
// as a hash table value in CudnnRNNForwardOp and CudnnRNNBackwardOp).
struct RnnScratchSpace {
  std::unique_ptr<RnnDescriptor> rnn_desc;
  std::unique_ptr<CudnnRNNPersistentSpaceAllocator> dropout_state_allocator;
};

// Run the forward operation of the RNN model.
template <typename T>
class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
 public:
  typedef GPUDevice Device;
  explicit CudnnRNNForwardOp(OpKernelConstruction* context)
      : CudnnRNNKernelCommon(context) {
    OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor* input = nullptr;
    const Tensor* input_h = nullptr;
    const Tensor* input_c = nullptr;
    const Tensor* params = nullptr;
    CudnnModelShapes model_shapes;
    OP_REQUIRES_OK(context,
                   ExtractForwardInput(context, model_types(), &input,
                                       &input_h, &input_c, &params,
                                       &model_shapes));
    const auto& input_shape = model_shapes.input_shape;
    const auto& hidden_state_shape = model_shapes.hidden_state_shape;
    const auto& output_shape = model_shapes.output_shape;

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, output_shape, &output));
    Tensor* output_h = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(1, hidden_state_shape, &output_h));
    Tensor* output_c = nullptr;
    if (HasInputC()) {
      // Only LSTM uses input_c and output_c. So for all other models, we only
      // need to create dummy outputs.
      OP_REQUIRES_OK(
          context, context->allocate_output(2, hidden_state_shape, &output_c));
    } else {
      OP_REQUIRES_OK(context, context->allocate_output(2, {}, &output_c));
    }

    auto* stream = context->op_device_context()->stream();
    auto* executor = stream->parent();
    RnnInputMode input_mode;
    OP_REQUIRES_OK(context,
                   ToRNNInputMode(rnn_input_mode(), model_shapes.num_units,
                                  model_shapes.input_size, &input_mode));
    auto data_type = ToDataType<T>::value;

    auto input_desc_s = executor->createRnnSequenceTensorDescriptor(
        input_shape.dim_size(0), input_shape.dim_size(1),
        input_shape.dim_size(2), data_type);
    OP_REQUIRES_OK(context, FromExecutorStatus(input_desc_s));
    auto input_desc = input_desc_s.ConsumeValueOrDie();

    auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor(
        hidden_state_shape.dim_size(0), hidden_state_shape.dim_size(1),
        hidden_state_shape.dim_size(2), data_type);
    OP_REQUIRES_OK(context, FromExecutorStatus(hidden_state_desc_s));
    auto hidden_state_desc = hidden_state_desc_s.ConsumeValueOrDie();

    auto output_desc_s = executor->createRnnSequenceTensorDescriptor(
        output_shape.dim_size(0), output_shape.dim_size(1),
        output_shape.dim_size(2), data_type);
    OP_REQUIRES_OK(context, FromExecutorStatus(output_desc_s));
    auto output_desc = output_desc_s.ConsumeValueOrDie();

    auto input_data = AsDeviceMemory<T>(input);
    auto input_h_data = AsDeviceMemory<T>(input_h);
    DeviceMemory<T> input_c_data;
    if (HasInputC()) {
      input_c_data = AsDeviceMemory<T>(input_c);
    }
    auto params_data = AsDeviceMemory<T>(params);
    auto output_data = AsDeviceMemory<T>(output);
    auto output_h_data = AsDeviceMemory<T>(output_h);
    DeviceMemory<T> output_c_data;
    if (HasInputC()) {
      output_c_data = AsDeviceMemory<T>(output_c);
    }

    // Creates a memory callback for the reserve_space. The memory lives in an
    // output of this kernel and will be fed into the backward pass when
    // needed.
    CudnnRNNReserveSpaceAllocator<T> reserve_space_allocator(context, 3);
    if (!is_training_) {
      Tensor* dummy_reserve_space = nullptr;
      OP_REQUIRES_OK(context,
                     context->allocate_output(3, {}, &dummy_reserve_space));
    }
    // Creates a memory callback for the workspace. The memory lives until the
    // end of this kernel call.
    CudnnRNNWorkspaceAllocator workspace_allocator(context);
    bool launch_status = false;
    {
      mutex_lock l(mu_);
      RnnScratchSpace& rnn_state = rnn_state_cache_[model_shapes];
      if (rnn_state.rnn_desc == nullptr || ResetRndGenState()) {
        CudnnRNNPersistentSpaceAllocator* dropout_state_allocator =
            new CudnnRNNPersistentSpaceAllocator(context);
        rnn_state.dropout_state_allocator.reset(dropout_state_allocator);
        auto rnn_desc_s = executor->createRnnDescriptor(
            model_shapes.num_layers, model_shapes.num_units,
            model_shapes.input_size, input_mode, rnn_direction_mode(),
            rnn_mode(), data_type, dropout(), seed(), dropout_state_allocator);
        OP_REQUIRES_OK(context, FromExecutorStatus(rnn_desc_s));
        rnn_state.rnn_desc = std::move(rnn_desc_s.ConsumeValueOrDie());
      }
      launch_status =
          stream
              ->ThenRnnForward(*rnn_state.rnn_desc, *input_desc, input_data,
                               *hidden_state_desc, input_h_data,
                               *hidden_state_desc, input_c_data, params_data,
                               *output_desc, &output_data, *hidden_state_desc,
                               &output_h_data, *hidden_state_desc,
                               &output_c_data, is_training_,
                               &reserve_space_allocator, &workspace_allocator)
              .ok();
    }
    OP_REQUIRES(context, launch_status,
                errors::Internal("Failed to call ThenRnnForward"));
  }

 private:
  mutex mu_;
  bool is_training_;
  std::unordered_map<CudnnModelShapes, RnnScratchSpace, CudnnModelShapesHasher,
                     CudnnModelShapesComparator>
      rnn_state_cache_ GUARDED_BY(mu_);
};

#define REGISTER_GPU(T)                                           \
  REGISTER_KERNEL_BUILDER(                                        \
      Name("CudnnRNN").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      CudnnRNNForwardOp<GPUDevice, T>);

TF_CALL_half(REGISTER_GPU);
TF_CALL_float(REGISTER_GPU);
TF_CALL_double(REGISTER_GPU);
#undef REGISTER_GPU

// Run the backward operation of the RNN model.
template <typename T>
class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
 public:
  typedef GPUDevice Device;

  explicit CudnnRNNBackwardOp(OpKernelConstruction* context)
      : CudnnRNNKernelCommon(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor* input = nullptr;
    const Tensor* input_h = nullptr;
    const Tensor* input_c = nullptr;
    const Tensor* params = nullptr;
    CudnnModelShapes model_shapes;
    OP_REQUIRES_OK(context,
                   ExtractForwardInput(context, model_types(), &input,
                                       &input_h, &input_c, &params,
                                       &model_shapes));

    const auto& input_shape = model_shapes.input_shape;
    const auto& hidden_state_shape = model_shapes.hidden_state_shape;
    const auto& output_shape = model_shapes.output_shape;

    auto data_type = ToDataType<T>::value;
    const Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->input("output", &output));
    OP_REQUIRES(context, output_shape == output->shape(),
                errors::InvalidArgument(
                    "Invalid output shape: ", output->shape().DebugString(),
                    " ", output_shape.DebugString()));
    const Tensor* output_h = nullptr;
    OP_REQUIRES_OK(context, context->input("output_h", &output_h));
    OP_REQUIRES(context, output_h->shape() == hidden_state_shape,
                errors::InvalidArgument(
                    "Invalid output_h shape: ", output_h->shape().DebugString(),
                    " ", hidden_state_shape.DebugString()));
    const Tensor* output_c = nullptr;
    if (HasInputC()) {
      // Only LSTM uses input_c and output_c. So for all other models, we only
      // need to create dummy outputs.
      OP_REQUIRES_OK(context, context->input("output_c", &output_c));
      OP_REQUIRES(context, output_c->shape() == hidden_state_shape,
                  errors::InvalidArgument("Invalid output_c shape: ",
                                          output_c->shape().DebugString(), " ",
                                          hidden_state_shape.DebugString()));
    }

    const Tensor* output_backprop = nullptr;
    OP_REQUIRES_OK(context,
                   context->input("output_backprop", &output_backprop));
    OP_REQUIRES(context, output_backprop->shape() == output_shape,
                errors::InvalidArgument("Invalid output_backprop shape: ",
                                        output_backprop->shape().DebugString(),
                                        " ", output_shape.DebugString()));

    const Tensor* output_h_backprop = nullptr;
    OP_REQUIRES_OK(context,
                   context->input("output_h_backprop", &output_h_backprop));
    OP_REQUIRES(
        context, output_h_backprop->shape() == hidden_state_shape,
        errors::InvalidArgument("Invalid output_h_backprop shape: ",
                                output_h_backprop->shape().DebugString(), " ",
                                hidden_state_shape.DebugString()));
    const Tensor* output_c_backprop = nullptr;
    if (HasInputC()) {
      OP_REQUIRES_OK(context,
                     context->input("output_c_backprop", &output_c_backprop));
      OP_REQUIRES(
          context, output_c_backprop->shape() == hidden_state_shape,
          errors::InvalidArgument("Invalid output_c_backprop shape: ",
                                  output_c_backprop->shape().DebugString(), " ",
                                  hidden_state_shape.DebugString()));
    }
    const Tensor* reserve_space_const = nullptr;
    // This is the same "reserve_space" created by the forward op.
    // It can also be modified by this backward operation.
    OP_REQUIRES_OK(context,
                   context->input("reserve_space", &reserve_space_const));
    // Cudnn needs the reserve space to be writeable. This is fine because its
    // contents are opaque.
    Tensor* reserve_space = const_cast<Tensor*>(reserve_space_const);

    Tensor* input_backprop = nullptr;
    OP_REQUIRES_OK(
        context, context->allocate_output(0, input->shape(), &input_backprop));
    Tensor* input_h_backprop = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(1, input_h->shape(),
                                                     &input_h_backprop));
    Tensor* input_c_backprop = nullptr;
    if (HasInputC()) {
      OP_REQUIRES_OK(context, context->allocate_output(2, input_c->shape(),
                                                       &input_c_backprop));
    } else {
      OP_REQUIRES_OK(context,
                     context->allocate_output(2, {}, &input_c_backprop));
    }
    Tensor* params_backprop = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(3, params->shape(),
                                                     &params_backprop));

    auto* stream = context->op_device_context()->stream();
    auto* executor = stream->parent();
    RnnInputMode input_mode;
    OP_REQUIRES_OK(context,
                   ToRNNInputMode(rnn_input_mode(), model_shapes.num_units,
                                  model_shapes.input_size, &input_mode));

    auto input_desc_s = executor->createRnnSequenceTensorDescriptor(
        input_shape.dim_size(0), input_shape.dim_size(1),
        input_shape.dim_size(2), data_type);
    OP_REQUIRES_OK(context, FromExecutorStatus(input_desc_s));
    auto input_desc = input_desc_s.ConsumeValueOrDie();

    auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor(
        hidden_state_shape.dim_size(0), hidden_state_shape.dim_size(1),
        hidden_state_shape.dim_size(2), data_type);
    OP_REQUIRES_OK(context, FromExecutorStatus(hidden_state_desc_s));
    auto hidden_state_desc = hidden_state_desc_s.ConsumeValueOrDie();

    auto output_desc_s = executor->createRnnSequenceTensorDescriptor(
        output_shape.dim_size(0), output_shape.dim_size(1),
        output_shape.dim_size(2), data_type);
    OP_REQUIRES_OK(context, FromExecutorStatus(output_desc_s));
    auto output_desc = output_desc_s.ConsumeValueOrDie();

    auto input_data = AsDeviceMemory<T>(input);
    auto input_h_data = AsDeviceMemory<T>(input_h);
    DeviceMemory<T> input_c_data;
    if (HasInputC()) {
      input_c_data = AsDeviceMemory<T>(input_c);
    }
    auto params_data = AsDeviceMemory<T>(params);
    auto output_data = AsDeviceMemory<T>(output);
    auto output_h_data = AsDeviceMemory<T>(output_h);
    DeviceMemory<T> output_c_data;
    if (HasInputC()) {
      output_c_data = AsDeviceMemory<T>(output_c);
    }
    auto output_backprop_data = AsDeviceMemory<T>(output_backprop);
    auto output_h_backprop_data = AsDeviceMemory<T>(output_h_backprop);
    DeviceMemory<T> output_c_backprop_data;
    if (HasInputC()) {
      output_c_backprop_data = AsDeviceMemory<T>(output_c_backprop);
    }
    auto input_backprop_data = AsDeviceMemory<T>(input_backprop);
    auto input_h_backprop_data = AsDeviceMemory<T>(input_h_backprop);
    DeviceMemory<T> input_c_backprop_data;
    if (HasInputC()) {
      input_c_backprop_data = AsDeviceMemory<T>(input_c_backprop);
    }
    auto params_backprop_data = AsDeviceMemory<T>(params_backprop);
    auto reserve_space_uint8 = CastDeviceMemory<uint8, T>(reserve_space);
    // Creates a memory callback for the workspace. The memory lives until the
    // end of this kernel call.
    CudnnRNNWorkspaceAllocator workspace_allocator(context);
    bool launch_status = false;
    {
      mutex_lock l(mu_);
      RnnScratchSpace& rnn_state = rnn_state_cache_[model_shapes];
      if (rnn_state.rnn_desc == nullptr || ResetRndGenState()) {
        CudnnRNNPersistentSpaceAllocator* dropout_state_allocator =
            new CudnnRNNPersistentSpaceAllocator(context);
        rnn_state.dropout_state_allocator.reset(dropout_state_allocator);
        auto rnn_desc_s = executor->createRnnDescriptor(
            model_shapes.num_layers, model_shapes.num_units,
            model_shapes.input_size, input_mode, rnn_direction_mode(),
            rnn_mode(), data_type, dropout(), seed(), dropout_state_allocator);
        OP_REQUIRES_OK(context, FromExecutorStatus(rnn_desc_s));
        rnn_state.rnn_desc = std::move(rnn_desc_s.ConsumeValueOrDie());
      }
      launch_status =
          stream
              ->ThenRnnBackward(*rnn_state.rnn_desc, *input_desc, input_data,
                                *hidden_state_desc, input_h_data,
                                *hidden_state_desc, input_c_data, params_data,
                                *output_desc, output_data, *hidden_state_desc,
                                output_h_data, *hidden_state_desc,
                                output_c_data, output_backprop_data,
                                output_h_backprop_data, output_c_backprop_data,
                                &input_backprop_data, &input_h_backprop_data,
                                &input_c_backprop_data, &params_backprop_data,
                                &reserve_space_uint8, &workspace_allocator)
              .ok();
    }
    OP_REQUIRES(context, launch_status,
                errors::Internal("Failed to call ThenRnnBackward"));
  }

 private:
  mutex mu_;
  std::unordered_map<CudnnModelShapes, RnnScratchSpace, CudnnModelShapesHasher,
                     CudnnModelShapesComparator>
      rnn_state_cache_ GUARDED_BY(mu_);
};

#define REGISTER_GPU(T)                                                   \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("CudnnRNNBackprop").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      CudnnRNNBackwardOp<GPUDevice, T>);

TF_CALL_half(REGISTER_GPU);
TF_CALL_float(REGISTER_GPU);
TF_CALL_double(REGISTER_GPU);
#undef REGISTER_GPU

// TODO(zhengxq): Add the conversion of Cudnn RNN Params from and to
// its canonical form.

#endif  // GOOGLE_CUDA

}  // namespace tensorflow