1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #define EIGEN_USE_THREADS 16 17 #include <stddef.h> 18 #include <atomic> 19 #include <cmath> 20 #include <functional> 21 #include <limits> 22 #include <string> 23 #include <unordered_set> 24 25 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 26 #include "tensorflow/core/framework/device_base.h" 27 #include "tensorflow/core/framework/kernel_def_builder.h" 28 #include "tensorflow/core/framework/op.h" 29 #include "tensorflow/core/framework/op_def_builder.h" 30 #include "tensorflow/core/framework/op_kernel.h" 31 #include "tensorflow/core/framework/register_types.h" 32 #include "tensorflow/core/framework/tensor.h" 33 #include "tensorflow/core/framework/tensor_shape.h" 34 #include "tensorflow/core/framework/tensor_types.h" 35 #include "tensorflow/core/framework/types.h" 36 #include "tensorflow/core/kernels/gpu_utils.h" 37 #include "tensorflow/core/lib/core/errors.h" 38 #include "tensorflow/core/lib/core/status.h" 39 #include "tensorflow/core/lib/core/stringpiece.h" 40 #include "tensorflow/core/lib/gtl/inlined_vector.h" 41 #include "tensorflow/core/lib/hash/hash.h" 42 #include "tensorflow/core/lib/strings/stringprintf.h" 43 #include "tensorflow/core/platform/fingerprint.h" 44 #include "tensorflow/core/platform/mutex.h" 45 #include "tensorflow/core/platform/types.h" 46 
#include "tensorflow/core/util/env_var.h" 47 #include "tensorflow/core/util/use_cudnn.h" 48 49 #if GOOGLE_CUDA 50 #include "tensorflow/core/platform/stream_executor.h" 51 #include "tensorflow/core/util/stream_executor_util.h" 52 #endif // GOOGLE_CUDA 53 54 /* 55 * This module implements ops that fuse a multi-layer multi-step RNN/LSTM model 56 * using the underlying Cudnn library. 57 * 58 * Cudnn RNN library exposes an opaque parameter buffer with unknown layout and 59 * format. And it is very likely that if saved, they cannot be used across 60 * different GPUs. So users need to first query the size of the opaque 61 * parameter buffer, and convert it to and from its canonical forms. But each 62 * actual training step is carried out with the parameter buffer. 63 * 64 * Similar to many other ops, the forward op has two flavors: training and 65 * inference. When training is specified, additional data in reserve_space will 66 * be produced for the backward pass. So there is a performance penalty. 67 * 68 * In addition to the actual data and reserve_space, Cudnn also needs more 69 * memory as temporary workspace. The memory management to and from 70 * stream-executor is done through ScratchAllocator. In general, 71 * stream-executor is responsible for creating the memory of proper size. And 72 * TensorFlow is responsible for making sure the memory is alive long enough 73 * and recycles afterwards. 
74 * 75 */ 76 namespace tensorflow { 77 78 using CPUDevice = Eigen::ThreadPoolDevice; 79 80 #if GOOGLE_CUDA 81 82 using GPUDevice = Eigen::GpuDevice; 83 using se::Stream; 84 using se::StreamExecutor; 85 using se::dnn::RnnDescriptor; 86 87 template <typename Device, typename T, typename Index> 88 class CudnnRNNParamsSizeOp; 89 90 template <typename Device, typename T> 91 class CudnnRNNParamsToCanonical; 92 93 template <typename Device, typename T> 94 class CudnnRNNCanonicalToParams; 95 96 template <typename Device, typename T> 97 class CudnnRNNForwardOp; 98 99 template <typename Device, typename T> 100 class CudnnRNNBackwardOp; 101 102 template <typename Device, typename T> 103 class CudnnRNNForwardOpV2; 104 105 template <typename Device, typename T> 106 class CudnnRNNBackwardOpV2; 107 108 template <typename Device, typename T> 109 class CudnnRNNForwardOpV3; 110 111 template <typename Device, typename T> 112 class CudnnRNNBackwardOpV3; 113 114 enum class TFRNNInputMode { 115 kRNNLinearInput = 0, 116 kRNNSkipInput = 1, 117 kAutoSelect = 9999999 118 }; 119 120 namespace { 121 using se::DeviceMemory; 122 using se::DeviceMemoryBase; 123 using se::ScratchAllocator; 124 using se::dnn::AlgorithmConfig; 125 using se::dnn::AlgorithmDesc; 126 using se::dnn::ProfileResult; 127 using se::dnn::RnnDirectionMode; 128 using se::dnn::RnnInputMode; 129 using se::dnn::RnnMode; 130 using se::dnn::RnnSequenceTensorDescriptor; 131 using se::dnn::RnnStateTensorDescriptor; 132 using se::dnn::ToDataType; 133 using se::port::StatusOr; 134 135 uint64 HashList(const std::vector<int>& list) { 136 if (list.empty()) { 137 return 0; 138 } 139 uint64 hash_code = list[0]; 140 for (int i = 1; i < list.size(); i++) { 141 hash_code = Hash64Combine(hash_code, list[i]); 142 } 143 return hash_code; 144 } 145 146 // Encapsulate all the shape information that is used in both forward and 147 // backward rnn operations. 
class CudnnRnnParameters {
 public:
  // Bundles every model/shape attribute visible here that distinguishes one
  // RNN configuration from another, so the instance can serve as a key in the
  // autotune map (see AutoTuneRnnConfigMap below). The hash is computed once
  // in the constructor.
  CudnnRnnParameters(int num_layers, int input_size, int num_units,
                     int max_seq_length, int batch_size, int dir_count,
                     bool has_dropout, bool is_training, RnnMode rnn_mode,
                     TFRNNInputMode rnn_input_mode, DataType dtype)
      : num_layers_(num_layers),
        input_size_(input_size),
        num_units_(num_units),
        seq_length_(max_seq_length),
        batch_size_(batch_size),
        dir_count_(dir_count),
        has_dropout_(has_dropout),
        is_training_(is_training),
        rnn_mode_(rnn_mode),
        rnn_input_mode_(rnn_input_mode),
        dtype_(dtype) {
    // All fields (enums and bools cast to int) participate in the hash.
    hash_code_ =
        HashList({num_layers, input_size, num_units, max_seq_length, batch_size,
                  dir_count, static_cast<int>(has_dropout),
                  static_cast<int>(is_training), static_cast<int>(rnn_mode),
                  static_cast<int>(rnn_input_mode), dtype});
  }

  // Equality compares the full field tuple, not the (possibly colliding) hash.
  bool operator==(const CudnnRnnParameters& other) const {
    return this->get_data_as_tuple() == other.get_data_as_tuple();
  }

  bool operator!=(const CudnnRnnParameters& other) const {
    return !(*this == other);
  }
  uint64 hash() const { return hash_code_; }

  // Human-readable, comma-separated dump of every field (enums as ints).
  string ToString() const {
    std::vector<string> fields = {
        std::to_string(num_layers_),
        std::to_string(input_size_),
        std::to_string(num_units_),
        std::to_string(seq_length_),
        std::to_string(batch_size_),
        std::to_string(dir_count_),
        std::to_string(has_dropout_),
        std::to_string(is_training_),
        std::to_string(static_cast<int>(rnn_mode_)),
        std::to_string(static_cast<int>(rnn_input_mode_)),
        std::to_string(static_cast<int>(dtype_))};
    return str_util::Join(fields, ", ");
  }

 private:
  using ParameterDataType = std::tuple<int, int, int, int, int, int, bool, bool,
                                       RnnMode, TFRNNInputMode, DataType>;

  // Field tuple used for exact equality comparison.
  ParameterDataType get_data_as_tuple() const {
    return std::make_tuple(num_layers_, input_size_, num_units_, seq_length_,
                           batch_size_, dir_count_, has_dropout_, is_training_,
                           rnn_mode_, rnn_input_mode_, dtype_);
  }

  const int num_layers_;
  const int input_size_;
  const int num_units_;
  const int seq_length_;
  const int batch_size_;
  const int dir_count_;
  const bool has_dropout_;
  const bool is_training_;
  const RnnMode rnn_mode_;
  const TFRNNInputMode rnn_input_mode_;
  const DataType dtype_;
  uint64 hash_code_;  // precomputed in the constructor
};

// Tag type naming the autotune group for RNN algorithm configs.
struct RnnAutoTuneGroup {
  static string name() { return "Rnn"; }
};

// Process-wide cache mapping CudnnRnnParameters -> best AlgorithmConfig.
using AutoTuneRnnConfigMap =
    AutoTuneSingleton<RnnAutoTuneGroup, CudnnRnnParameters, AlgorithmConfig>;

// Parses the string attribute of the op into an RnnMode enum value.
// Returns InvalidArgument for unrecognized strings.
Status ParseRNNMode(const string& str, RnnMode* rnn_mode) {
  if (str == "rnn_relu") {
    *rnn_mode = RnnMode::kRnnRelu;
    return Status::OK();
  } else if (str == "rnn_tanh") {
    *rnn_mode = RnnMode::kRnnTanh;
    return Status::OK();
  } else if (str == "lstm") {
    *rnn_mode = RnnMode::kRnnLstm;
    return Status::OK();
  } else if (str == "gru") {
    *rnn_mode = RnnMode::kRnnGru;
    return Status::OK();
  }
  return errors::InvalidArgument("Invalid RNN mode: ", str);
}

// Parses the string attribute of the op into a TFRNNInputMode enum value.
// Returns InvalidArgument for unrecognized strings.
Status ParseTFRNNInputMode(const string& str, TFRNNInputMode* rnn_input_mode) {
  if (str == "linear_input") {
    *rnn_input_mode = TFRNNInputMode::kRNNLinearInput;
    return Status::OK();
  } else if (str == "skip_input") {
    *rnn_input_mode = TFRNNInputMode::kRNNSkipInput;
    return Status::OK();
  } else if (str == "auto_select") {
    *rnn_input_mode = TFRNNInputMode::kAutoSelect;
    return Status::OK();
  }
  return errors::InvalidArgument("Invalid RNN input mode: ", str);
}

// Parses the string attribute of the op into an RnnDirectionMode enum value.
// Returns InvalidArgument for unrecognized strings.
Status ParseRNNDirectionMode(const string& str,
                             RnnDirectionMode* rnn_dir_mode) {
  if (str == "unidirectional") {
    *rnn_dir_mode = RnnDirectionMode::kRnnUnidirectional;
    return Status::OK();
  } else if (str == "bidirectional") {
    *rnn_dir_mode = RnnDirectionMode::kRnnBidirectional;
    return Status::OK();
  }
  return errors::InvalidArgument("Invalid RNN direction mode: ", str);
}

// Maps a TF-level input mode to the stream-executor RnnInputMode. kAutoSelect
// resolves to skip-input only when input_size == num_units; otherwise it
// falls back to the linear path.
Status ToRNNInputMode(TFRNNInputMode tf_input_mode, int num_units,
                      int input_size, RnnInputMode* input_mode) {
  switch (tf_input_mode) {
    case TFRNNInputMode::kRNNLinearInput:
      *input_mode = RnnInputMode::kRnnLinearSkip;
      break;
    case TFRNNInputMode::kRNNSkipInput:
      *input_mode = RnnInputMode::kRnnSkipInput;
      break;
    case TFRNNInputMode::kAutoSelect:
      *input_mode = (input_size == num_units) ? RnnInputMode::kRnnSkipInput
                                              : RnnInputMode::kRnnLinearSkip;
      break;
    default:
      return errors::InvalidArgument("Invalid TF input mode: ",
                                     static_cast<int>(tf_input_mode));
  }
  return Status::OK();
}

// TODO(zhengxq): Merge those into stream_executor_util.h.

// Wraps a const tensor's flat buffer as stream-executor DeviceMemory without
// copying. The const_cast is required by the DeviceMemory API; callers must
// not write through the result.
template <typename T>
const DeviceMemory<T> AsDeviceMemory(const Tensor* tensor) {
  return DeviceMemory<T>::MakeFromByteSize(
      const_cast<T*>(tensor->template flat<T>().data()),
      tensor->template flat<T>().size() * sizeof(T));
}

// Mutable overload of the above.
template <typename T>
DeviceMemory<T> AsDeviceMemory(Tensor* tensor) {
  return DeviceMemory<T>::MakeFromByteSize(
      tensor->template flat<T>().data(),
      tensor->template flat<T>().size() * sizeof(T));
}

// Reinterprets a tensor of element type T as DeviceMemory<U>; the byte size
// is still computed from T's element count.
template <typename U, typename T>
DeviceMemory<U> CastDeviceMemory(Tensor* tensor) {
  return DeviceMemory<U>::MakeFromByteSize(
      tensor->template flat<T>().data(),
      tensor->template flat<T>().size() * sizeof(T));
}

// Returns a sub-region [offset, offset + size) of `device_memory`.
// CHECK-fails if the slice runs past the end of the region.
DeviceMemoryBase SliceDeviceMemory(const DeviceMemoryBase& device_memory,
                                   int64 offset, int64 size) {
  const void* base_ptr = device_memory.opaque();
  void* offset_ptr =
      const_cast<char*>(reinterpret_cast<const char*>(base_ptr) + offset);
  CHECK(offset + size <= device_memory.size())
      << "The slice is not within the region of DeviceMemory.";
  return DeviceMemoryBase(offset_ptr, size);
}

// Converts a stream-executor Status into a TensorFlow Status, preserving the
// numeric error code and message.
inline Status FromExecutorStatus(const se::port::Status& s) {
  return s.ok() ? Status::OK()
                : Status(static_cast<error::Code>(static_cast<int>(s.code())),
                         s.error_message());
}

// StatusOr overload: converts only the status part; the payload is dropped.
template <typename T>
inline Status FromExecutorStatus(const se::port::StatusOr<T>& s) {
  return FromExecutorStatus(s.status());
}

// Inverse of FromExecutorStatus: TensorFlow Status -> stream-executor Status.
inline se::port::Status ToExecutorStatus(const Status& s) {
  return s.ok() ? se::port::Status::OK()
                : se::port::Status(static_cast<se::port::error::Code>(
                                       static_cast<int>(s.code())),
                                   s.error_message());
}

// Compile-time mapping from a C++ element type to the TF DataType enum.
template <typename>
struct ToTFDataType;

template <>
struct ToTFDataType<Eigen::half> : std::integral_constant<DataType, DT_HALF> {};

template <>
struct ToTFDataType<float> : std::integral_constant<DataType, DT_FLOAT> {};

template <>
struct ToTFDataType<double> : std::integral_constant<DataType, DT_DOUBLE> {};

template <>
struct ToTFDataType<uint8> : std::integral_constant<DataType, DT_UINT8> {};

// A helper to allocate temporary scratch memory for Cudnn RNN models. It
// takes the ownership of the underlying memory. The expectation is that the
// memory should be alive for the span of the Cudnn RNN itself.
359 template <typename T> 360 class CudnnRnnAllocatorInTemp : public ScratchAllocator { 361 public: 362 ~CudnnRnnAllocatorInTemp() override = default; 363 364 explicit CudnnRnnAllocatorInTemp(OpKernelContext* context) 365 : context_(context) {} 366 int64 GetMemoryLimitInBytes(Stream* stream) override { 367 return std::numeric_limits<int64>::max(); 368 } 369 370 StatusOr<DeviceMemory<uint8>> AllocateBytes(Stream* stream, 371 int64 byte_size) override { 372 Tensor temporary_memory; 373 const DataType tf_data_type = ToTFDataType<T>::value; 374 int64 allocate_count = 375 Eigen::divup(byte_size, static_cast<int64>(sizeof(T))); 376 Status allocation_status(context_->allocate_temp( 377 tf_data_type, TensorShape({allocate_count}), &temporary_memory)); 378 if (!allocation_status.ok()) { 379 return ToExecutorStatus(allocation_status); 380 } 381 // Hold the reference of the allocated tensors until the end of the 382 // allocator. 383 allocated_tensors_.push_back(temporary_memory); 384 total_byte_size_ += byte_size; 385 return DeviceMemory<uint8>::MakeFromByteSize( 386 temporary_memory.template flat<T>().data(), 387 temporary_memory.template flat<T>().size() * sizeof(T)); 388 } 389 390 int64 TotalByteSize() const { return total_byte_size_; } 391 392 Tensor get_allocated_tensor(int index) const { 393 return allocated_tensors_[index]; 394 } 395 396 private: 397 int64 total_byte_size_ = 0; 398 OpKernelContext* context_; // not owned 399 std::vector<Tensor> allocated_tensors_; 400 }; 401 402 // A helper to allocate memory for Cudnn RNN models as a kernel output. It is 403 // used by forward pass kernel to feed the output to the backward pass. 404 // The memory is expected to live long enough after the backward pass is 405 // finished. 
406 template <typename T> 407 class CudnnRnnAllocatorInOutput : public ScratchAllocator { 408 public: 409 ~CudnnRnnAllocatorInOutput() override {} 410 CudnnRnnAllocatorInOutput(OpKernelContext* context, int output_index) 411 : context_(context), output_index_(output_index) {} 412 int64 GetMemoryLimitInBytes(Stream* stream) override { 413 return std::numeric_limits<int64>::max(); 414 } 415 StatusOr<DeviceMemory<uint8>> AllocateBytes(Stream* stream, 416 int64 byte_size) override { 417 CHECK(total_byte_size_ == 0) 418 << "Reserve space allocator can only be called once"; 419 int64 allocate_count = 420 Eigen::divup(byte_size, static_cast<int64>(sizeof(T))); 421 422 Tensor* temporary_memory = nullptr; 423 Status allocation_status(context_->allocate_output( 424 output_index_, TensorShape({allocate_count}), &temporary_memory)); 425 if (!allocation_status.ok()) { 426 return ToExecutorStatus(allocation_status); 427 } 428 total_byte_size_ += byte_size; 429 auto memory_uint8 = DeviceMemory<uint8>::MakeFromByteSize( 430 temporary_memory->template flat<T>().data(), 431 temporary_memory->template flat<T>().size() * sizeof(T)); 432 return StatusOr<DeviceMemory<uint8>>(memory_uint8); 433 } 434 int64 TotalByteSize() { return total_byte_size_; } 435 436 private: 437 int64 total_byte_size_ = 0; 438 OpKernelContext* context_; // not owned 439 int output_index_; 440 }; 441 442 // A helper to allocate persistent memory for Cudnn RNN models, which is 443 // expected to live between kernel invocations. 444 // This class is not thread-safe. 
445 class CudnnRNNPersistentSpaceAllocator : public ScratchAllocator { 446 public: 447 explicit CudnnRNNPersistentSpaceAllocator(OpKernelContext* context) 448 : context_(context) {} 449 450 ~CudnnRNNPersistentSpaceAllocator() override {} 451 452 int64 GetMemoryLimitInBytes(Stream* stream) override { 453 return std::numeric_limits<int64>::max(); 454 } 455 456 StatusOr<DeviceMemory<uint8>> AllocateBytes(Stream* stream, 457 int64 byte_size) override { 458 if (total_byte_size_ != 0) { 459 return Status(error::FAILED_PRECONDITION, 460 "Persistent space allocator can only be called once"); 461 } 462 463 Status allocation_status = context_->allocate_persistent( 464 DT_UINT8, TensorShape({byte_size}), &handle_, nullptr); 465 if (!allocation_status.ok()) { 466 return ToExecutorStatus(allocation_status); 467 } 468 total_byte_size_ += byte_size; 469 return AsDeviceMemory<uint8>(handle_.AccessTensor(context_)); 470 } 471 int64 TotalByteSize() { return total_byte_size_; } 472 473 private: 474 int64 total_byte_size_ = 0; 475 PersistentTensor handle_; 476 OpKernelContext* context_; // not owned 477 }; 478 479 struct CudnnModelTypes { 480 RnnMode rnn_mode; 481 TFRNNInputMode rnn_input_mode; 482 RnnDirectionMode rnn_direction_mode; 483 bool HasInputC() const { 484 // For Cudnn 5.0, only LSTM has input-c. All other models use only 485 // input-h. 486 return rnn_mode == RnnMode::kRnnLstm; 487 } 488 489 string DebugString() const { 490 return strings::Printf( 491 "[rnn_mode, rnn_input_mode, rnn_direction_mode]: %d, %d, %d ", 492 static_cast<int>(rnn_mode), static_cast<int>(rnn_input_mode), 493 static_cast<int>(rnn_direction_mode)); 494 } 495 }; 496 497 // A helper class that collects the shapes to describe a RNN model. 
498 struct CudnnRnnModelShapes { 499 int num_layers; 500 int input_size; 501 int num_units; 502 int dir_count; 503 int max_seq_length; 504 int batch_size; 505 TensorShape input_shape; 506 TensorShape output_shape; 507 TensorShape hidden_state_shape; 508 // At present only fields related to cached RnnDescriptor are concerned. 509 bool IsCompatibleWith(const CudnnRnnModelShapes& rhs) const { 510 return num_layers == rhs.num_layers && input_size == rhs.input_size && 511 num_units == rhs.num_units && dir_count == rhs.dir_count; 512 } 513 string DebugString() const { 514 return strings::Printf( 515 "[num_layers, input_size, num_units, dir_count, max_seq_length, " 516 "batch_size]: [%d, %d, %d, %d, %d, %d] ", 517 num_layers, input_size, num_units, dir_count, max_seq_length, 518 batch_size); 519 } 520 }; 521 522 // Utility class for using CudnnRnnConfig and AlgorithmDesc pair a hash table 523 // key. 524 struct CudnnRnnConfigHasher { 525 uint64 operator()( 526 const std::pair<CudnnRnnModelShapes, absl::optional<AlgorithmDesc>>& 527 to_hash) const { 528 auto& shapes = to_hash.first; 529 auto& algo_desc = to_hash.second; 530 531 uint64 hash = 532 HashList({shapes.num_layers, shapes.input_size, shapes.num_units, 533 shapes.dir_count, shapes.batch_size}); 534 if (algo_desc.has_value()) { 535 hash = Hash64Combine(hash, algo_desc->hash()); 536 } 537 return hash; 538 } 539 }; 540 541 // Utility class for using CudnnRnnModelShapes and AlgorithmDesc pair as a hash 542 // table key. 543 struct CudnnRnnConfigComparator { 544 bool operator()( 545 const std::pair<CudnnRnnModelShapes, absl::optional<AlgorithmDesc>>& lhs, 546 const std::pair<CudnnRnnModelShapes, absl::optional<AlgorithmDesc>>& rhs) 547 const { 548 return lhs.first.IsCompatibleWith(rhs.first) && lhs.second == rhs.second; 549 } 550 }; 551 552 // Pointers to RNN scratch space for a specific set of shape parameters (used as 553 // a hash table value in CudnnRNNForwardOp and CudnnRNNBackwardOp). 
struct RnnScratchSpace {
  std::unique_ptr<RnnDescriptor> rnn_desc;
  std::unique_ptr<CudnnRNNPersistentSpaceAllocator> dropout_state_allocator;
};

// Extracts and checks the forward input tensors, parameters, and shapes from
// the OpKernelContext. On success, fills in `model_shapes` (sequence length,
// batch size, layer/unit counts, and the input/hidden/output TensorShapes)
// derived from the "input" and "input_h" tensors; `time_major` controls which
// dimension is treated as time vs. batch.
Status ExtractForwardInput(OpKernelContext* context,
                           const CudnnModelTypes& model_types, bool time_major,
                           const Tensor** input, const Tensor** input_h,
                           const Tensor** input_c, const Tensor** params,
                           CudnnRnnModelShapes* model_shapes) {
  TF_RETURN_IF_ERROR(context->input("input", input));
  TF_RETURN_IF_ERROR(context->input("input_h", input_h));
  // input_c is only consumed by LSTM-style models (see HasInputC()).
  if (model_types.HasInputC()) {
    TF_RETURN_IF_ERROR(context->input("input_c", input_c));
  }
  TF_RETURN_IF_ERROR(context->input("params", params));

  if ((*input)->dims() != 3) {
    return errors::InvalidArgument("RNN input must be a 3-D vector.");
  }
  // Time-major input is [seq, batch, input_size]; batch-major is
  // [batch, seq, input_size].
  if (time_major) {
    model_shapes->max_seq_length = (*input)->dim_size(0);
    model_shapes->batch_size = (*input)->dim_size(1);
  } else {
    model_shapes->max_seq_length = (*input)->dim_size(1);
    model_shapes->batch_size = (*input)->dim_size(0);
  }
  model_shapes->input_size = (*input)->dim_size(2);
  model_shapes->input_shape = (*input)->shape();
  model_shapes->dir_count =
      (model_types.rnn_direction_mode == RnnDirectionMode::kRnnBidirectional)
          ? 2
          : 1;

  if ((*input_h)->dims() != 3) {
    return errors::InvalidArgument("RNN input_h must be a 3-D vector.");
  }
  // The leading hidden-state dimension packs dir_count * num_layers, so the
  // layer count is recovered by dividing it back out.
  if (time_major) {
    model_shapes->num_layers =
        (*input_h)->dim_size(0) / model_shapes->dir_count;
  } else {
    model_shapes->num_layers =
        (*input_h)->dim_size(1) / model_shapes->dir_count;
  }
  model_shapes->num_units = (*input_h)->dim_size(2);

  if (time_major) {
    model_shapes->hidden_state_shape =
        TensorShape({model_shapes->dir_count * model_shapes->num_layers,
                     model_shapes->batch_size, model_shapes->num_units});
  } else {
    model_shapes->hidden_state_shape =
        TensorShape({model_shapes->batch_size,
                     model_shapes->dir_count * model_shapes->num_layers,
                     model_shapes->num_units});
  }
  if ((*input_h)->shape() != model_shapes->hidden_state_shape) {
    return errors::InvalidArgument(
        "Invalid input_h shape: ", (*input_h)->shape().DebugString(), " ",
        model_shapes->hidden_state_shape.DebugString());
  }
  if (model_types.HasInputC()) {
    if ((*input_h)->shape() != (*input_c)->shape()) {
      return errors::InvalidArgument(
          "input_h and input_c must have the same shape: ",
          (*input_h)->shape().DebugString(), " ",
          (*input_c)->shape().DebugString());
    }
  }
  // Output is the input layout with the feature dimension widened to
  // dir_count * num_units.
  if (time_major) {
    model_shapes->output_shape =
        TensorShape({model_shapes->max_seq_length, model_shapes->batch_size,
                     model_shapes->dir_count * model_shapes->num_units});
  } else {
    model_shapes->output_shape =
        TensorShape({model_shapes->batch_size, model_shapes->max_seq_length,
                     model_shapes->dir_count * model_shapes->num_units});
  }
  return Status::OK();
}

// Overloaded function to process the sequence_lengths: reads the
// "sequence_lengths" input, then delegates to the overload above for all the
// remaining extraction and validation.
Status ExtractForwardInput(OpKernelContext* context,
                           const CudnnModelTypes& model_types, bool time_major,
                           const Tensor** input, const Tensor** input_h,
                           const Tensor** input_c, const Tensor** params,
                           const Tensor** sequence_lengths,
                           CudnnRnnModelShapes* model_shapes) {
  TF_RETURN_IF_ERROR(context->input("sequence_lengths", sequence_lengths));
  return ExtractForwardInput(context, model_types, time_major, input, input_h,
                             input_c, params, model_shapes);
}

// Builds the stream-executor input/state/output tensor descriptors for the
// given model shapes. When `seq_lengths` has data, the variable-sequence
// descriptor overload is used; `time_major` decides which of the first two
// shape dimensions is passed as the sequence dimension.
template <typename T>
Status CreateForwardAndBackwardIODescriptors(
    OpKernelContext* context, const CudnnRnnModelShapes& model_shapes,
    std::unique_ptr<RnnSequenceTensorDescriptor>* input_desc,
    std::unique_ptr<RnnStateTensorDescriptor>* state_desc,
    std::unique_ptr<RnnSequenceTensorDescriptor>* output_desc,
    const absl::Span<const int>& seq_lengths, bool time_major) {
  StreamExecutor* executor = context->op_device_context()->stream()->parent();
  se::dnn::DataType data_type = ToDataType<T>::value;

  const TensorShape& input_shape = model_shapes.input_shape;
  const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
  const TensorShape& output_shape = model_shapes.output_shape;

  DCHECK_EQ(input_shape.dims(), 3);
  if (seq_lengths.data() != nullptr) {
    if (time_major) {
      auto input_desc_s = executor->createRnnSequenceTensorDescriptor(
          input_shape.dim_size(0), input_shape.dim_size(1),
          input_shape.dim_size(2), seq_lengths, time_major, data_type);
      TF_RETURN_IF_ERROR(input_desc_s.status());
      *input_desc = input_desc_s.ConsumeValueOrDie();
    } else {
      // Batch-major: dim 1 is the sequence dimension, dim 0 the batch.
      auto input_desc_s = executor->createRnnSequenceTensorDescriptor(
          input_shape.dim_size(1), input_shape.dim_size(0),
          input_shape.dim_size(2), seq_lengths, time_major, data_type);
      TF_RETURN_IF_ERROR(input_desc_s.status());
      *input_desc = input_desc_s.ConsumeValueOrDie();
    }
  } else {
    auto input_desc_s = executor->createRnnSequenceTensorDescriptor(
        input_shape.dim_size(0), input_shape.dim_size(1),
        input_shape.dim_size(2), data_type);
    TF_RETURN_IF_ERROR(input_desc_s.status());
    *input_desc = input_desc_s.ConsumeValueOrDie();
  }

  DCHECK_EQ(hidden_state_shape.dims(), 3);
  if (time_major) {
    auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor(
        hidden_state_shape.dim_size(0), hidden_state_shape.dim_size(1),
        hidden_state_shape.dim_size(2), data_type);
    TF_RETURN_IF_ERROR(hidden_state_desc_s.status());
    *state_desc = hidden_state_desc_s.ConsumeValueOrDie();
  } else {
    auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor(
        hidden_state_shape.dim_size(1), hidden_state_shape.dim_size(0),
        hidden_state_shape.dim_size(2), data_type);
    TF_RETURN_IF_ERROR(hidden_state_desc_s.status());
    *state_desc = hidden_state_desc_s.ConsumeValueOrDie();
  }

  DCHECK_EQ(output_shape.dims(), 3);
  if (seq_lengths.data() != nullptr) {
    if (time_major) {
      auto output_desc_s = executor->createRnnSequenceTensorDescriptor(
          output_shape.dim_size(0), output_shape.dim_size(1),
          output_shape.dim_size(2), seq_lengths, time_major, data_type);
      TF_RETURN_IF_ERROR(output_desc_s.status());
      *output_desc = output_desc_s.ConsumeValueOrDie();
    } else {
      auto output_desc_s = executor->createRnnSequenceTensorDescriptor(
          output_shape.dim_size(1), output_shape.dim_size(0),
          output_shape.dim_size(2), seq_lengths, time_major, data_type);
      TF_RETURN_IF_ERROR(output_desc_s.status());
      *output_desc = output_desc_s.ConsumeValueOrDie();
    }
  } else {
    auto output_desc_s = executor->createRnnSequenceTensorDescriptor(
        output_shape.dim_size(0), output_shape.dim_size(1),
        output_shape.dim_size(2), data_type);
    TF_RETURN_IF_ERROR(output_desc_s.status());
    *output_desc = output_desc_s.ConsumeValueOrDie();
  }

  return Status::OK();
}

// Runs the Cudnn RNN forward pass via Stream::ThenRnnForward. The optional
// input_c/output_c tensors are wired only for models with a cell state (see
// HasInputC()); `sequence_lengths` may be null for fixed-length batches.
// Returns Internal on a failed launch, including the model config in the
// message.
template <typename T>
Status DoForward(OpKernelContext* context, const RnnDescriptor& rnn_desc,
                 const CudnnModelTypes& model_types,
                 const CudnnRnnModelShapes& model_shapes,
                 /* forward inputs */
                 const Tensor* input, const Tensor* input_h,
                 const Tensor* input_c, const Tensor* params,
                 const bool is_training,
                 /* forward outputs, outputs of the function */
                 Tensor* output, Tensor* output_h, Tensor* output_c,
                 const Tensor* sequence_lengths, bool time_major,
                 ScratchAllocator* reserve_space_allocator,
                 ScratchAllocator* workspace_allocator,
                 ProfileResult* output_profile_result) {
  std::unique_ptr<RnnSequenceTensorDescriptor> input_desc;
  std::unique_ptr<RnnStateTensorDescriptor> state_desc;
  std::unique_ptr<RnnSequenceTensorDescriptor> output_desc;

  absl::Span<const int> seq_lengths;
  if (sequence_lengths != nullptr) {
    seq_lengths = absl::Span<const int>(
        sequence_lengths->template flat<int>().data(), model_shapes.batch_size);
  }
  TF_RETURN_IF_ERROR(CreateForwardAndBackwardIODescriptors<T>(
      context, model_shapes, &input_desc, &state_desc, &output_desc,
      seq_lengths, time_major));

  auto input_data = AsDeviceMemory<T>(input);
  auto input_h_data = AsDeviceMemory<T>(input_h);
  // Left empty (null device memory) for models without a cell state.
  DeviceMemory<T> input_c_data;
  if (model_types.HasInputC()) {
    input_c_data = AsDeviceMemory<T>(input_c);
  }

  auto params_data = AsDeviceMemory<T>(params);
  auto output_data = AsDeviceMemory<T>(output);
  auto output_h_data = AsDeviceMemory<T>(output_h);
  DeviceMemory<T> output_c_data;
  if (model_types.HasInputC()) {
    output_c_data = AsDeviceMemory<T>(output_c);
  }

  Stream* stream = context->op_device_context()->stream();
  bool launch_success =
      stream
          ->ThenRnnForward(rnn_desc, *input_desc, input_data, *state_desc,
                           input_h_data, *state_desc, input_c_data, params_data,
                           *output_desc, &output_data, *state_desc,
                           &output_h_data, *state_desc, &output_c_data,
                           is_training, reserve_space_allocator,
                           workspace_allocator, output_profile_result)
          .ok();
  return launch_success
             ? Status::OK()
             : errors::Internal(
                   "Failed to call ThenRnnForward with model config: ",
                   model_types.DebugString(), ", ", model_shapes.DebugString());
}

// Runs the Cudnn RNN backward pass, producing gradients for the input,
// hidden/cell states, and params. (Definition continues beyond this view.)
template <typename T>
Status DoBackward(
    OpKernelContext* context, const RnnDescriptor& rnn_desc,
    const CudnnModelTypes& model_types, const CudnnRnnModelShapes& model_shapes,
    /* forward inputs */
    const Tensor* input, const Tensor* input_h, const Tensor* input_c,
    const Tensor* params,
    /* forward outputs */
    const Tensor* output, const Tensor* output_h, const Tensor* output_c,
    /* backprop inputs */
    const Tensor* output_backprop, const Tensor* output_h_backprop,
    const Tensor* output_c_backprop, const Tensor* reserve_space,
    /* backprop outputs, output of the function */
    Tensor* input_backprop, Tensor* input_h_backprop, Tensor* input_c_backprop,
    Tensor* params_backprop, const Tensor* sequence_lengths, bool time_major,
    ScratchAllocator* workspace_allocator,
    ProfileResult* output_profile_result) {
  std::unique_ptr<RnnSequenceTensorDescriptor> input_desc;
  std::unique_ptr<RnnStateTensorDescriptor> state_desc;
  std::unique_ptr<RnnSequenceTensorDescriptor> output_desc;

  absl::Span<const int> seq_lengths;
  if (sequence_lengths != nullptr) {
    seq_lengths = absl::Span<const int>(
        sequence_lengths->template flat<int>().data(), model_shapes.batch_size);
  }
  TF_RETURN_IF_ERROR(CreateForwardAndBackwardIODescriptors<T>(
      context, model_shapes, &input_desc, &state_desc, &output_desc,
      seq_lengths, time_major));

  auto input_data = AsDeviceMemory<T>(input);
  auto input_h_data = AsDeviceMemory<T>(input_h);
  DeviceMemory<T> input_c_data;
  if (model_types.HasInputC()) {
    input_c_data = AsDeviceMemory<T>(input_c);
  }
  auto params_data = AsDeviceMemory<T>(params);
  auto output_data = AsDeviceMemory<T>(output);
  auto output_h_data = AsDeviceMemory<T>(output_h);
  DeviceMemory<T>
output_c_data; 826 if (model_types.HasInputC()) { 827 output_c_data = AsDeviceMemory<T>(output_c); 828 } 829 auto output_backprop_data = AsDeviceMemory<T>(output_backprop); 830 auto output_h_backprop_data = AsDeviceMemory<T>(output_h_backprop); 831 DeviceMemory<T> output_c_backprop_data; 832 if (model_types.HasInputC()) { 833 output_c_backprop_data = AsDeviceMemory<T>(output_c_backprop); 834 } 835 auto input_backprop_data = AsDeviceMemory<T>(input_backprop); 836 auto input_h_backprop_data = AsDeviceMemory<T>(input_h_backprop); 837 DeviceMemory<T> input_c_backprop_data; 838 if (model_types.HasInputC()) { 839 input_c_backprop_data = AsDeviceMemory<T>(input_c_backprop); 840 } 841 auto params_backprop_data = AsDeviceMemory<T>(params_backprop); 842 auto reserve_space_uint8 = 843 CastDeviceMemory<uint8, T>(const_cast<Tensor*>(reserve_space)); 844 845 // Creates a memory callback for the workspace. The memory lives to the end 846 // of this kernel calls. 847 Stream* stream = context->op_device_context()->stream(); 848 bool launch_success = 849 stream 850 ->ThenRnnBackward(rnn_desc, *input_desc, input_data, *state_desc, 851 input_h_data, *state_desc, input_c_data, 852 params_data, *output_desc, output_data, *state_desc, 853 output_h_data, *state_desc, output_c_data, 854 output_backprop_data, output_h_backprop_data, 855 output_c_backprop_data, &input_backprop_data, 856 &input_h_backprop_data, &input_c_backprop_data, 857 ¶ms_backprop_data, &reserve_space_uint8, 858 workspace_allocator, output_profile_result) 859 .ok(); 860 return launch_success 861 ? 
Status::OK() 862 : errors::Internal( 863 "Failed to call ThenRnnBackward with model config: ", 864 model_types.DebugString(), ", ", model_shapes.DebugString()); 865 } 866 867 template <typename T> 868 void RestoreParams(const OpInputList params_input, 869 const std::vector<RnnDescriptor::ParamsRegion>& params, 870 DeviceMemoryBase* data_dst, Stream* stream) { 871 int num_params = params.size(); 872 CHECK(params_input.size() == num_params) 873 << "Number of params mismatch. Expected " << params_input.size() 874 << ", got " << num_params; 875 for (int i = 0; i < params.size(); i++) { 876 int64 size_in_bytes = params[i].size; 877 int64 size = size_in_bytes / sizeof(T); 878 CHECK(size == params_input[i].NumElements()) 879 << "Params size mismatch. Expected " << size << ", got " 880 << params_input[i].NumElements(); 881 auto data_src_ptr = StreamExecutorUtil::AsDeviceMemory<T>(params_input[i]); 882 DeviceMemoryBase data_dst_ptr = 883 SliceDeviceMemory(*data_dst, params[i].offset, size_in_bytes); 884 stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes); 885 } 886 } 887 888 } // namespace 889 890 // Note: all following kernels depend on a RnnDescriptor instance, which 891 // according to Cudnn official doc should be kept around and reused across all 892 // Cudnn kernels in the same model. 893 // In Tensorflow, we don't pass the reference across different OpKernels, 894 // rather, recreate it separately in each OpKernel, which does no cause issue: 895 // CudnnDropoutDescriptor keeps a reference to a memory for 896 // random number generator state. During recreation, this state is lost. 897 // However, only forward-pass Cudnn APIs make use of the state. 898 899 // A common base class for RNN kernels. It extracts common attributes and 900 // shape validations. 
class CudnnRNNKernelCommon : public OpKernel {
 protected:
  // Parses the attributes shared by all cudnn RNN kernels (dropout, seeds,
  // rnn_mode, input_mode, direction) into model_types_, and reads the
  // TF_CUDNN_RESET_RND_GEN_STATE env var.
  explicit CudnnRNNKernelCommon(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("dropout", &dropout_));
    OP_REQUIRES_OK(context, context->GetAttr("seed", &seed_));
    OP_REQUIRES_OK(context, context->GetAttr("seed2", &seed2_));
    string str;
    OP_REQUIRES_OK(context, context->GetAttr("rnn_mode", &str));
    OP_REQUIRES_OK(context, ParseRNNMode(str, &model_types_.rnn_mode));
    OP_REQUIRES_OK(context, context->GetAttr("input_mode", &str));
    OP_REQUIRES_OK(context,
                   ParseTFRNNInputMode(str, &model_types_.rnn_input_mode));
    OP_REQUIRES_OK(context, context->GetAttr("direction", &str));
    OP_REQUIRES_OK(
        context, ParseRNNDirectionMode(str, &model_types_.rnn_direction_mode));
    // Reset CudnnRnnDescriptor and related random number generate states in
    // every Compute() call.
    OP_REQUIRES_OK(context, ReadBoolFromEnvVar("TF_CUDNN_RESET_RND_GEN_STATE",
                                               false, &reset_rnd_gen_state_));
  }

  // Convenience accessors over the parsed model configuration.
  bool HasInputC() const { return model_types_.HasInputC(); }
  RnnMode rnn_mode() const { return model_types_.rnn_mode; }
  TFRNNInputMode rnn_input_mode() const { return model_types_.rnn_input_mode; }
  RnnDirectionMode rnn_direction_mode() const {
    return model_types_.rnn_direction_mode;
  }
  const CudnnModelTypes& model_types() const { return model_types_; }
  float dropout() const { return dropout_; }
  // Combines the two int32 seed attrs into one 64-bit seed.
  uint64 seed() { return (static_cast<uint64>(seed_) << 32) | seed2_; }
  bool ResetRndGenState() { return reset_rnd_gen_state_; }

  // Reads the scalar num_layers/num_units/input_size inputs and builds a
  // batch-size-0 RnnDescriptor, used only to query parameter-buffer layout.
  template <typename T>
  Status ExtractCudnnRNNParamsInfo(OpKernelContext* context,
                                   std::unique_ptr<RnnDescriptor>* rnn_desc) {
    const Tensor* num_layers_t = nullptr;
    TF_RETURN_IF_ERROR(context->input("num_layers", &num_layers_t));
    if (!TensorShapeUtils::IsScalar(num_layers_t->shape())) {
      return errors::InvalidArgument("num_layers is not a scalar");
    }
    int num_layers = num_layers_t->scalar<int>()();
    const Tensor* num_units_t = nullptr;
    TF_RETURN_IF_ERROR(context->input("num_units", &num_units_t));
    if (!TensorShapeUtils::IsScalar(num_units_t->shape())) {
      return errors::InvalidArgument("num_units is not a scalar");
    }
    int num_units = num_units_t->scalar<int>()();
    const Tensor* input_size_t = nullptr;
    TF_RETURN_IF_ERROR(context->input("input_size", &input_size_t));
    if (!TensorShapeUtils::IsScalar(input_size_t->shape())) {
      return errors::InvalidArgument("input_size is not a scalar");
    }
    int input_size = input_size_t->scalar<int>()();

    RnnInputMode input_mode;
    TF_RETURN_IF_ERROR(
        ToRNNInputMode(rnn_input_mode(), num_units, input_size, &input_mode));

    Stream* stream = context->op_device_context()->stream();
    // ExtractCudnnRNNParamsInfo is only called by op_kernels that do not
    // require random number generator, therefore set state_allocator to
    // nullptr.
    const AlgorithmConfig algo_config;
    auto rnn_desc_s = stream->parent()->createRnnDescriptor(
        num_layers, num_units, input_size, /*batch_size=*/0, input_mode,
        rnn_direction_mode(), rnn_mode(), ToDataType<T>::value, algo_config,
        dropout(), seed(), /* state_allocator=*/nullptr);
    if (!rnn_desc_s.ok()) {
      return FromExecutorStatus(rnn_desc_s);
    }
    *rnn_desc = rnn_desc_s.ConsumeValueOrDie();
    return Status::OK();
  }

  // Creates a full RnnDescriptor for the given shapes/algorithm, with dropout
  // state backed by dropout_state_allocator.
  template <typename T>
  Status CreateRnnDescriptor(OpKernelContext* context,
                             const CudnnRnnModelShapes& model_shapes,
                             const RnnInputMode& input_mode,
                             const AlgorithmConfig& algo_config,
                             ScratchAllocator* dropout_state_allocator,
                             std::unique_ptr<RnnDescriptor>* rnn_desc) {
    StreamExecutor* executor = context->op_device_context()->stream()->parent();
    se::dnn::DataType data_type = ToDataType<T>::value;
    auto rnn_desc_s = executor->createRnnDescriptor(
        model_shapes.num_layers, model_shapes.num_units,
        model_shapes.input_size, model_shapes.batch_size, input_mode,
        rnn_direction_mode(), rnn_mode(), data_type, algo_config, dropout(),
        seed(), dropout_state_allocator);
    TF_RETURN_IF_ERROR(rnn_desc_s.status());

    *rnn_desc = rnn_desc_s.ConsumeValueOrDie();
    return Status::OK();
  }

  // Cache from (model shapes, algorithm) to the RnnDescriptor and its dropout
  // state, so descriptors are reused across Compute() calls.
  using RnnStateCache = gtl::FlatMap<
      std::pair<CudnnRnnModelShapes, absl::optional<AlgorithmDesc>>,
      RnnScratchSpace, CudnnRnnConfigHasher, CudnnRnnConfigComparator>;
  // Returns a raw rnn descriptor pointer. The cache owns the rnn descriptor and
  // should outlive the returned pointer.
  template <typename T>
  Status GetCachedRnnDescriptor(OpKernelContext* context,
                                const CudnnRnnModelShapes& model_shapes,
                                const RnnInputMode& input_mode,
                                const AlgorithmConfig& algo_config,
                                RnnStateCache* cache,
                                RnnDescriptor** rnn_desc) {
    auto key = std::make_pair(model_shapes, algo_config.algorithm());
    RnnScratchSpace& rnn_state = (*cache)[key];
    // (Re)create the descriptor on first use for this key, or whenever the
    // TF_CUDNN_RESET_RND_GEN_STATE env var requests a fresh RNG state.
    if (rnn_state.rnn_desc == nullptr || ResetRndGenState()) {
      CudnnRNNPersistentSpaceAllocator* dropout_state_allocator =
          new CudnnRNNPersistentSpaceAllocator(context);
      rnn_state.dropout_state_allocator.reset(dropout_state_allocator);
      Status status =
          CreateRnnDescriptor<T>(context, model_shapes, input_mode, algo_config,
                                 dropout_state_allocator, &rnn_state.rnn_desc);
      TF_RETURN_IF_ERROR(status);
    }
    *rnn_desc = rnn_state.rnn_desc.get();
    return Status::OK();
  }

 private:
  int seed_;
  int seed2_;
  float dropout_;
  bool reset_rnd_gen_state_;

  CudnnModelTypes model_types_;
};

// A class that returns the size of the opaque parameter buffer. The user should
// use that to create the actual parameter buffer for training. However, it
// should not be used for saving and restoring.
1034 template <typename T, typename Index> 1035 class CudnnRNNParamsSizeOp<GPUDevice, T, Index> : public CudnnRNNKernelCommon { 1036 public: 1037 explicit CudnnRNNParamsSizeOp(OpKernelConstruction* context) 1038 : CudnnRNNKernelCommon(context) {} 1039 1040 void Compute(OpKernelContext* context) override { 1041 std::unique_ptr<RnnDescriptor> rnn_desc; 1042 OP_REQUIRES_OK(context, ExtractCudnnRNNParamsInfo<T>(context, &rnn_desc)); 1043 int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes(); 1044 CHECK(params_size_in_bytes % sizeof(T) == 0) 1045 << "params_size_in_bytes must be multiple of element size"; 1046 int64 params_size = params_size_in_bytes / sizeof(T); 1047 1048 Tensor* output_t = nullptr; 1049 OP_REQUIRES_OK(context, context->allocate_output(0, {1}, &output_t)); 1050 *output_t->template flat<Index>().data() = params_size; 1051 } 1052 }; 1053 1054 #define REGISTER_GPU(T) \ 1055 REGISTER_KERNEL_BUILDER(Name("CudnnRNNParamsSize") \ 1056 .Device(DEVICE_GPU) \ 1057 .HostMemory("num_layers") \ 1058 .HostMemory("num_units") \ 1059 .HostMemory("input_size") \ 1060 .HostMemory("params_size") \ 1061 .TypeConstraint<T>("T") \ 1062 .TypeConstraint<int32>("S"), \ 1063 CudnnRNNParamsSizeOp<GPUDevice, T, int32>); 1064 1065 TF_CALL_half(REGISTER_GPU); 1066 TF_CALL_float(REGISTER_GPU); 1067 TF_CALL_double(REGISTER_GPU); 1068 #undef REGISTER_GPU 1069 1070 // Convert weight and bias params from a platform-specific layout to the 1071 // canonical form. 
1072 template <typename T> 1073 class CudnnRNNParamsToCanonical<GPUDevice, T> : public CudnnRNNKernelCommon { 1074 public: 1075 explicit CudnnRNNParamsToCanonical(OpKernelConstruction* context) 1076 : CudnnRNNKernelCommon(context) { 1077 OP_REQUIRES_OK(context, context->GetAttr("num_params", &num_params_)); 1078 } 1079 1080 void Compute(OpKernelContext* context) override { 1081 const Tensor& input = context->input(3); 1082 auto input_ptr = StreamExecutorUtil::AsDeviceMemory<T>(input); 1083 Stream* stream = context->op_device_context()->stream(); 1084 1085 std::unique_ptr<RnnDescriptor> rnn_desc; 1086 OP_REQUIRES_OK(context, ExtractCudnnRNNParamsInfo<T>(context, &rnn_desc)); 1087 int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes(); 1088 CHECK(params_size_in_bytes % sizeof(T) == 0) 1089 << "params_size_in_bytes must be multiple of element size"; 1090 1091 const Tensor* num_units_t = nullptr; 1092 OP_REQUIRES_OK(context, context->input("num_units", &num_units_t)); 1093 CHECK(TensorShapeUtils::IsScalar(num_units_t->shape())) 1094 << "num_units is not a scalar"; 1095 int num_units = num_units_t->scalar<int>()(); 1096 1097 const Tensor* input_size_t = nullptr; 1098 OP_REQUIRES_OK(context, context->input("input_size", &input_size_t)); 1099 CHECK(TensorShapeUtils::IsScalar(input_size_t->shape())) 1100 << "input_size is not a scalar"; 1101 int input_size = input_size_t->scalar<int>()(); 1102 1103 const Tensor* num_layers_t = nullptr; 1104 OP_REQUIRES_OK(context, context->input("num_layers", &num_layers_t)); 1105 CHECK(TensorShapeUtils::IsScalar(num_layers_t->shape())) 1106 << "num_layers is not a scalar"; 1107 int num_layers = num_layers_t->scalar<int>()(); 1108 int num_dirs = 1; 1109 if (rnn_direction_mode() == RnnDirectionMode::kRnnBidirectional) { 1110 num_dirs = 2; 1111 } 1112 const int num_params_per_layer = num_params_ / num_layers / num_dirs; 1113 // Number of params applied on inputs. The rest are applied on recurrent 1114 // hidden states. 
1115 const int num_params_input_state = num_params_per_layer / 2; 1116 CHECK(num_params_ % (num_layers * num_dirs) == 0) 1117 << "Number of params is not a multiple of num_layers * num_dirs."; 1118 CHECK(num_params_per_layer % 2 == 0) 1119 << "Number of params per layer is not a even number."; 1120 1121 CHECK(num_params_ == rnn_desc->ParamsWeightRegions().size()) 1122 << "Number of params mismatch. Expected " << num_params_ << ", got " 1123 << rnn_desc->ParamsWeightRegions().size(); 1124 for (int i = 0; i < rnn_desc->ParamsWeightRegions().size(); i++) { 1125 int64 size_in_bytes = rnn_desc->ParamsWeightRegions()[i].size; 1126 int64 size = size_in_bytes / sizeof(T); 1127 const int layer_idx = i / num_params_per_layer; 1128 const int index_within_layer = i % num_params_per_layer; 1129 int width = 0, height = num_units; 1130 // In CuDNN layout, each layer has num_params_per_layer params, with the 1131 // first half a.k.a num_params_input_state params applied on the inputs, 1132 // and the second half on the recurrent hidden states. 1133 bool apply_on_input_state = index_within_layer < num_params_input_state; 1134 if (rnn_direction_mode() == RnnDirectionMode::kRnnUnidirectional) { 1135 if (layer_idx == 0 && apply_on_input_state) { 1136 width = input_size; 1137 } else { 1138 width = num_units; 1139 } 1140 } else { 1141 if (apply_on_input_state) { 1142 if (layer_idx <= 1) { 1143 // First fwd or bak layer. 1144 width = input_size; 1145 } else { 1146 // Following layers, cell inputs are concatenated outputs of 1147 // its prior layer. 1148 width = 2 * num_units; 1149 } 1150 } else { 1151 width = num_units; 1152 } 1153 } 1154 CHECK(size == width * height) << "Params size mismatch. 
Expected " 1155 << width * height << ", got " << size; 1156 Tensor* output = nullptr; 1157 OP_REQUIRES_OK(context, context->allocate_output( 1158 i, TensorShape({height, width}), &output)); 1159 DeviceMemoryBase data_src_ptr = SliceDeviceMemory( 1160 input_ptr, rnn_desc->ParamsWeightRegions()[i].offset, size_in_bytes); 1161 auto data_dst_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*output); 1162 stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes); 1163 } 1164 1165 OP_REQUIRES(context, num_params_ == rnn_desc->ParamsBiasRegions().size(), 1166 errors::InvalidArgument("Number of params mismatch. Expected ", 1167 num_params_, ", got ", 1168 rnn_desc->ParamsBiasRegions().size())); 1169 for (int i = 0; i < rnn_desc->ParamsBiasRegions().size(); i++) { 1170 int64 size_in_bytes = rnn_desc->ParamsBiasRegions()[i].size; 1171 int64 size = size_in_bytes / sizeof(T); 1172 OP_REQUIRES(context, size == num_units, 1173 errors::InvalidArgument("Params size mismatch. Expected ", 1174 num_units, ", got ", size)); 1175 1176 Tensor* output = nullptr; 1177 OP_REQUIRES_OK(context, 1178 context->allocate_output(num_params_ + i, 1179 TensorShape({size}), &output)); 1180 DeviceMemoryBase data_src_ptr = SliceDeviceMemory( 1181 input_ptr, rnn_desc->ParamsBiasRegions()[i].offset, size_in_bytes); 1182 auto data_dst_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*output); 1183 stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes); 1184 } 1185 } 1186 1187 private: 1188 int num_params_; 1189 }; 1190 1191 #define REGISTER_GPU(T) \ 1192 REGISTER_KERNEL_BUILDER(Name("CudnnRNNParamsToCanonical") \ 1193 .Device(DEVICE_GPU) \ 1194 .HostMemory("num_layers") \ 1195 .HostMemory("num_units") \ 1196 .HostMemory("input_size") \ 1197 .TypeConstraint<T>("T"), \ 1198 CudnnRNNParamsToCanonical<GPUDevice, T>); 1199 TF_CALL_half(REGISTER_GPU); 1200 TF_CALL_float(REGISTER_GPU); 1201 TF_CALL_double(REGISTER_GPU); 1202 #undef REGISTER_GPU 1203 1204 // Convert weight and bias params from the canonical 
form to a 1205 // platform-specific layout. 1206 template <typename T> 1207 class CudnnRNNCanonicalToParams<GPUDevice, T> : public CudnnRNNKernelCommon { 1208 public: 1209 explicit CudnnRNNCanonicalToParams(OpKernelConstruction* context) 1210 : CudnnRNNKernelCommon(context) {} 1211 1212 void Compute(OpKernelContext* context) override { 1213 std::unique_ptr<RnnDescriptor> rnn_desc; 1214 OP_REQUIRES_OK(context, ExtractCudnnRNNParamsInfo<T>(context, &rnn_desc)); 1215 int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes(); 1216 CHECK(params_size_in_bytes % sizeof(T) == 0) 1217 << "params_size_in_bytes must be multiple of element size"; 1218 Tensor* output = nullptr; 1219 int params_size = params_size_in_bytes / sizeof(T); 1220 OP_REQUIRES_OK(context, 1221 context->allocate_output(0, {params_size}, &output)); 1222 auto output_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*output); 1223 Stream* stream = context->op_device_context()->stream(); 1224 1225 OpInputList weights; 1226 OP_REQUIRES_OK(context, context->input_list("weights", &weights)); 1227 RestoreParams<T>(weights, rnn_desc->ParamsWeightRegions(), &output_ptr, 1228 stream); 1229 1230 OpInputList biases; 1231 OP_REQUIRES_OK(context, context->input_list("biases", &biases)); 1232 RestoreParams<T>(biases, rnn_desc->ParamsBiasRegions(), &output_ptr, 1233 stream); 1234 } 1235 }; 1236 1237 #define REGISTER_GPU(T) \ 1238 REGISTER_KERNEL_BUILDER(Name("CudnnRNNCanonicalToParams") \ 1239 .Device(DEVICE_GPU) \ 1240 .HostMemory("num_layers") \ 1241 .HostMemory("num_units") \ 1242 .HostMemory("input_size") \ 1243 .TypeConstraint<T>("T"), \ 1244 CudnnRNNCanonicalToParams<GPUDevice, T>); 1245 TF_CALL_half(REGISTER_GPU); 1246 TF_CALL_float(REGISTER_GPU); 1247 TF_CALL_double(REGISTER_GPU); 1248 #undef REGISTER_GPU 1249 1250 // Run the forward operation of the RNN model. 
template <typename T>
class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
 public:
  explicit CudnnRNNForwardOp(OpKernelConstruction* context)
      : CudnnRNNKernelCommon(context) {
    OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));

    // Read debug env variables.
    is_debug_mode_ = DebugCudnnRnn();
    debug_cudnn_rnn_algo_ = DebugCudnnRnnAlgo();
    debug_use_tensor_ops_ = DebugCudnnRnnUseTensorOps();
  }

  void Compute(OpKernelContext* context) override {
    AlgorithmConfig algo_config;
    ComputeAndReturnAlgorithm(context, &algo_config, /*var_seq_lengths=*/false,
                              /*time_major=*/true);
  }

 protected:
  // Shared implementation for Compute() in this class and its subclasses:
  // extracts the forward inputs, allocates the outputs, picks an algorithm
  // (debug override or MaybeAutoTune), and launches DoForward under mu_.
  // The chosen algorithm is reported back through output_algo_config.
  virtual void ComputeAndReturnAlgorithm(OpKernelContext* context,
                                         AlgorithmConfig* output_algo_config,
                                         bool var_seq_lengths,
                                         bool time_major) {
    CHECK_NE(output_algo_config, nullptr);

    const Tensor* input = nullptr;
    const Tensor* input_h = nullptr;
    const Tensor* input_c = nullptr;
    const Tensor* params = nullptr;
    const Tensor* sequence_lengths = nullptr;
    CudnnRnnModelShapes model_shapes;
    if (var_seq_lengths) {
      // Variable-sequence-length variant also reads the "sequence_lengths"
      // input.
      OP_REQUIRES_OK(context,
                     ExtractForwardInput(context, model_types(), time_major,
                                         &input, &input_h, &input_c, &params,
                                         &sequence_lengths, &model_shapes));
    } else {
      OP_REQUIRES_OK(context, ExtractForwardInput(
                                  context, model_types(), time_major, &input,
                                  &input_h, &input_c, &params, &model_shapes));
    }
    RnnInputMode input_mode;
    OP_REQUIRES_OK(context,
                   ToRNNInputMode(rnn_input_mode(), model_shapes.num_units,
                                  model_shapes.input_size, &input_mode));

    Tensor* output = nullptr;
    Tensor* output_h = nullptr;
    Tensor* output_c = nullptr;
    OP_REQUIRES_OK(context, AllocateOutputs(context, model_shapes, &output,
                                            &output_h, &output_c));

    // Creates a memory callback for the reserve_space. The memory lives in the
    // output of this kernel. And it will be fed into the backward pass when
    // needed.
    CudnnRnnAllocatorInOutput<T> reserve_space_allocator(context, 3);
    // Creates a memory callback for the workspace. The memory lives to the end
    // of this kernel calls.
    CudnnRnnAllocatorInTemp<uint8> workspace_allocator(context);

    if (is_debug_mode_) {
      // Debug env vars pin the algorithm, bypassing autotune.
      AlgorithmDesc algo_desc(debug_cudnn_rnn_algo_, debug_use_tensor_ops_);
      output_algo_config->set_algorithm(algo_desc);
    } else {
      OP_REQUIRES_OK(context,
                     MaybeAutoTune(context, model_shapes, input_mode, input,
                                   input_h, input_c, params, output, output_h,
                                   output_c, output_algo_config));
    }

    Status launch_status;
    {
      // mu_ guards rnn_state_cache_ and serializes descriptor reuse with the
      // launch that consumes it.
      mutex_lock l(mu_);
      RnnDescriptor* rnn_desc_ptr = nullptr;
      OP_REQUIRES_OK(
          context, GetCachedRnnDescriptor<T>(context, model_shapes, input_mode,
                                             *output_algo_config,
                                             &rnn_state_cache_, &rnn_desc_ptr));
      launch_status = DoForward<T>(
          context, *rnn_desc_ptr, model_types(), model_shapes, input, input_h,
          input_c, params, is_training_, output, output_h, output_c,
          sequence_lengths, time_major, &reserve_space_allocator,
          &workspace_allocator, /*output_profile_result=*/nullptr);
    }
    OP_REQUIRES_OK(context, launch_status);
  }

 protected:
  // Base implementation performs no autotuning and returns a default
  // AlgorithmConfig; CudnnRNNForwardOpV2 overrides this with real autotune.
  virtual Status MaybeAutoTune(OpKernelContext* context,
                               const CudnnRnnModelShapes& model_shapes,
                               const RnnInputMode& input_mode,
                               const Tensor* input, const Tensor* input_h,
                               const Tensor* input_c, const Tensor* params,
                               Tensor* output, Tensor* output_h,
                               Tensor* output_c,
                               AlgorithmConfig* best_algo_config) {
    CHECK_NE(best_algo_config, nullptr);
    *best_algo_config = AlgorithmConfig();
    return Status::OK();
  }

  bool is_training() const { return is_training_; }
  // Debug-mode state, populated from env vars in the constructor.
  bool is_debug_mode_;
  bool debug_use_tensor_ops_;
  int64 debug_cudnn_rnn_algo_;

 private:
  // Allocates outputs 0-2 (output, output_h, output_c) from model_shapes, and
  // a dummy reserve-space output 3 when not training.
  Status AllocateOutputs(OpKernelContext* context,
                         const CudnnRnnModelShapes& model_shapes,
                         Tensor** output, Tensor** output_h,
                         Tensor** output_c) {
    const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
    const TensorShape& output_shape = model_shapes.output_shape;

    TF_RETURN_IF_ERROR(context->allocate_output(0, output_shape, output));
    TF_RETURN_IF_ERROR(
        context->allocate_output(1, hidden_state_shape, output_h));
    if (HasInputC()) {
      TF_RETURN_IF_ERROR(
          context->allocate_output(2, hidden_state_shape, output_c));
    } else {
      // Only LSTM uses input_c and output_c. So for all other models, we only
      // need to create dummy outputs.
      TF_RETURN_IF_ERROR(context->allocate_output(2, {}, output_c));
    }
    if (!is_training_) {
      Tensor* dummy_reserve_space = nullptr;
      TF_RETURN_IF_ERROR(context->allocate_output(3, {}, &dummy_reserve_space));
    }
    return Status::OK();
  }

  mutex mu_;
  bool is_training_;
  RnnStateCache rnn_state_cache_ GUARDED_BY(mu_);
};

#define REGISTER_GPU(T)                                           \
  REGISTER_KERNEL_BUILDER(                                        \
      Name("CudnnRNN").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      CudnnRNNForwardOp<GPUDevice, T>);

TF_CALL_half(REGISTER_GPU);
TF_CALL_float(REGISTER_GPU);
TF_CALL_double(REGISTER_GPU);
#undef REGISTER_GPU

// V2 of the forward op: adds autotuning (MaybeAutoTune override) and a 5th
// "host_reserved" output carrying the chosen algorithm for the backward pass.
template <typename T>
class CudnnRNNForwardOpV2<GPUDevice, T>
    : public CudnnRNNForwardOp<GPUDevice, T> {
 private:
  using CudnnRNNForwardOp<GPUDevice, T>::is_training;
  using CudnnRNNKernelCommon::CreateRnnDescriptor;
  using CudnnRNNKernelCommon::dropout;
  using CudnnRNNKernelCommon::HasInputC;
  using CudnnRNNKernelCommon::model_types;

 public:
  explicit CudnnRNNForwardOpV2(OpKernelConstruction* context)
      : CudnnRNNForwardOp<GPUDevice, T>(context) {}

  void Compute(OpKernelContext* context) override {
// (Continuation of CudnnRNNForwardOpV2<GPUDevice, T>::Compute.)
    AlgorithmConfig best_algo_config;
    CudnnRNNForwardOp<GPUDevice, T>::ComputeAndReturnAlgorithm(
        context, &best_algo_config, /*var_seq_lengths=*/false,
        /*time_major=*/true);
    if (!context->status().ok()) {
      return;
    }

    Tensor* output_host_reserved = nullptr;
    // output_host_reserved stores opaque info used for backprop when running
    // in training mode. At present, it includes a serialization of the best
    // AlgorithmDesc picked during rnn forward pass autotune.
    // int8 algorithm_id
    // int8 use_tensor_op
    // If autotune is not enabled, the algorithm_id is
    // stream_executor::dnn::kDefaultAlgorithm and use_tensor_op is false. If
    // running in inference mode, the output_host_reserved is currently not
    // populated.
    if (is_training()) {
      OP_REQUIRES_OK(context, context->allocate_output(4, TensorShape({2}),
                                                       &output_host_reserved));
      auto output_host_reserved_int8 = output_host_reserved->vec<int8>();
      output_host_reserved_int8(0) = best_algo_config.algorithm()->algo_id();
      output_host_reserved_int8(1) =
          best_algo_config.algorithm()->tensor_ops_enabled();
    } else {
      OP_REQUIRES_OK(context,
                     context->allocate_output(4, {}, &output_host_reserved));
    }
  }

 protected:
  // Autotunes the cudnn RNN algorithm: profiles every algorithm reported by
  // GetRnnAlgorithms (forward, plus backward when training), caches the
  // fastest one in AutoTuneRnnConfigMap keyed by the model configuration, and
  // returns it in algo_config. Returns a default config when autotune is
  // disabled or debug mode pins the algorithm.
  Status MaybeAutoTune(OpKernelContext* context,
                       const CudnnRnnModelShapes& model_shapes,
                       const RnnInputMode& input_mode, const Tensor* input,
                       const Tensor* input_h, const Tensor* input_c,
                       const Tensor* params, Tensor* output, Tensor* output_h,
                       Tensor* output_c,
                       AlgorithmConfig* algo_config) override {
    CHECK_NE(algo_config, nullptr);
    if (!CudnnRnnUseAutotune() || this->is_debug_mode_) {
      *algo_config = AlgorithmConfig();
      return Status::OK();
    }

    std::vector<AlgorithmDesc> algorithms;
    auto* stream = context->op_device_context()->stream();
    CHECK(stream->parent()->GetRnnAlgorithms(&algorithms));
    if (algorithms.empty()) {
      LOG(WARNING) << "No Rnn algorithm found";
      return Status::OK();
    }

    const auto& modeltypes = model_types();
    // Cache key for previously-autotuned results.
    CudnnRnnParameters rnn_params(
        model_shapes.num_layers, model_shapes.input_size,
        model_shapes.num_units, model_shapes.max_seq_length,
        model_shapes.batch_size, model_shapes.dir_count,
        /*has_dropout=*/std::abs(dropout()) > 1e-8, is_training(),
        modeltypes.rnn_mode, modeltypes.rnn_input_mode, input->dtype());

    if (AutoTuneRnnConfigMap::GetInstance()->Find(rnn_params, algo_config)) {
      VLOG(1) << "Using existing best Cudnn RNN algorithm "
              << "(algo, tensor_op_enabled) = ("
              << algo_config->algorithm()->algo_id() << ", "
              << algo_config->algorithm()->tensor_ops_enabled() << ").";
      return Status::OK();
    }

    // Create temp tensors when profiling backprop pass.
    auto data_type = input->dtype();
    Tensor output_backprop;
    Tensor output_h_backprop;
    Tensor output_c_backprop;
    Tensor input_backprop;
    Tensor input_h_backprop;
    Tensor input_c_backprop;
    Tensor params_backprop;
    if (is_training()) {
      TF_RETURN_IF_ERROR(context->allocate_temp(
          data_type, model_shapes.output_shape, &output_backprop));
      TF_RETURN_IF_ERROR(context->allocate_temp(
          data_type, model_shapes.hidden_state_shape, &output_h_backprop));

      TF_RETURN_IF_ERROR(
          context->allocate_temp(data_type, params->shape(), &params_backprop));
      TF_RETURN_IF_ERROR(context->allocate_temp(
          data_type, model_shapes.input_shape, &input_backprop));
      TF_RETURN_IF_ERROR(context->allocate_temp(
          data_type, model_shapes.hidden_state_shape, &input_h_backprop));
      if (HasInputC()) {
        // The c-state temps are only needed for LSTM.
        TF_RETURN_IF_ERROR(context->allocate_temp(
            data_type, model_shapes.hidden_state_shape, &output_c_backprop));
        TF_RETURN_IF_ERROR(context->allocate_temp(
            data_type, model_shapes.hidden_state_shape, &input_c_backprop));
      }
    }
    ProfileResult best_result;
    for (auto& algo : algorithms) {
      VLOG(1) << "Profile Cudnn RNN algorithm (algo, tensor_op_enabled) = ("
              << algo.algo_id() << ", " << algo.tensor_ops_enabled() << ").";
      Status status;
      ProfileResult final_profile_result;

      ProfileResult fwd_profile_result;
      ProfileResult bak_profile_result;

      // RnnDescriptor is algorithm-dependent, thus not reusable.
      std::unique_ptr<RnnDescriptor> rnn_desc;
      // Use a temp scratch allocator for the random num generator.
      CudnnRnnAllocatorInTemp<uint8> dropout_state_allocator(context);
      if (!this->template CreateRnnDescriptor<T>(
                   context, model_shapes, input_mode, AlgorithmConfig(algo),
                   &dropout_state_allocator, &rnn_desc)
               .ok()) {
        // Skip algorithms whose descriptor cannot be created.
        continue;
      }

      // Again use temp scratch allocator during profiling.
      CudnnRnnAllocatorInTemp<T> reserve_space_allocator(context);
      CudnnRnnAllocatorInTemp<uint8> workspace_allocator(context);
      status = DoForward<T>(context, *rnn_desc, model_types(), model_shapes,
                            input, input_h, input_c, params, is_training(),
                            output, output_h, output_c, nullptr, true,
                            &reserve_space_allocator, &workspace_allocator,
                            &fwd_profile_result);
      if (!status.ok()) {
        continue;
      }

      if (is_training()) {
        // Get reserve space from the forward pass.
        Tensor reserve_space = reserve_space_allocator.get_allocated_tensor(0);
        status = DoBackward<T>(
            context, *rnn_desc, model_types(), model_shapes, input, input_h,
            input_c, params, output, output_h, output_c, &output_backprop,
            &output_h_backprop, &output_c_backprop, &reserve_space,
            &input_backprop, &input_h_backprop, &input_c_backprop,
            &params_backprop, nullptr, true, &workspace_allocator,
            &bak_profile_result);
        if (!status.ok()) {
          continue;
        }
        // When training, rank by combined forward + backward time.
        final_profile_result.set_elapsed_time_in_ms(
            fwd_profile_result.elapsed_time_in_ms() +
            bak_profile_result.elapsed_time_in_ms());
      } else {
        final_profile_result = fwd_profile_result;
      }

      auto total_time = final_profile_result.elapsed_time_in_ms();
      VLOG(1) << "Cudnn RNN algorithm (algo, tensor_op_enabled) = ("
              << algo.algo_id() << ", " << algo.tensor_ops_enabled() << ")"
              << " run time: " << total_time << " ms.";
      if (total_time < best_result.elapsed_time_in_ms()) {
        best_result.set_elapsed_time_in_ms(total_time);
        best_result.set_algorithm(algo);
      }
    }

    if (!best_result.is_valid()) {
      return Status(error::Code::INTERNAL, "No algorithm worked!");
    }
    algo_config->set_algorithm(best_result.algorithm());
    VLOG(1) << "Best Cudnn RNN algorithm (algo, tensor_op_enabled) = ("
            << best_result.algorithm().algo_id() << ", "
            << best_result.algorithm().tensor_ops_enabled() << ").";
    AutoTuneRnnConfigMap::GetInstance()->Insert(rnn_params, *algo_config);
    return Status::OK();
  }
};

#define REGISTER_GPU(T)                                    \
  REGISTER_KERNEL_BUILDER(Name("CudnnRNNV2")               \
                              .Device(DEVICE_GPU)          \
                              .HostMemory("host_reserved") \
                              .TypeConstraint<T>("T"),     \
                          CudnnRNNForwardOpV2<GPUDevice, T>);

TF_CALL_half(REGISTER_GPU);
TF_CALL_float(REGISTER_GPU);
TF_CALL_double(REGISTER_GPU);
#undef REGISTER_GPU

template <typename T>
class
CudnnRNNForwardOpV3<GPUDevice, T> 1600 : public CudnnRNNForwardOp<GPUDevice, T> { 1601 private: 1602 using CudnnRNNForwardOp<GPUDevice, T>::is_training; 1603 using CudnnRNNKernelCommon::CreateRnnDescriptor; 1604 using CudnnRNNKernelCommon::dropout; 1605 using CudnnRNNKernelCommon::HasInputC; 1606 using CudnnRNNKernelCommon::model_types; 1607 bool time_major_; 1608 1609 protected: 1610 bool time_major() { return time_major_; } 1611 1612 public: 1613 explicit CudnnRNNForwardOpV3(OpKernelConstruction* context) 1614 : CudnnRNNForwardOp<GPUDevice, T>(context) { 1615 OP_REQUIRES_OK(context, context->GetAttr("time_major", &time_major_)); 1616 } 1617 1618 void Compute(OpKernelContext* context) override { 1619 AlgorithmConfig best_algo_config; 1620 CudnnRNNForwardOp<GPUDevice, T>::ComputeAndReturnAlgorithm( 1621 context, &best_algo_config, /*var_seq_lengths=*/true, 1622 /*time_major=*/time_major()); 1623 if (!context->status().ok()) { 1624 return; 1625 } 1626 1627 Tensor* output_host_reserved = nullptr; 1628 // TODO: Current V3 only uses the default standard algorithm to process 1629 // batches with variable sequences and the inputs should be padded. 1630 // Autotune is not supported yet. 1631 OP_REQUIRES_OK(context, 1632 context->allocate_output(4, {}, &output_host_reserved)); 1633 } 1634 }; 1635 1636 #define REGISTER_GPU(T) \ 1637 REGISTER_KERNEL_BUILDER(Name("CudnnRNNV3") \ 1638 .Device(DEVICE_GPU) \ 1639 .HostMemory("sequence_lengths") \ 1640 .HostMemory("host_reserved") \ 1641 .TypeConstraint<T>("T"), \ 1642 CudnnRNNForwardOpV3<GPUDevice, T>); 1643 1644 TF_CALL_half(REGISTER_GPU); 1645 TF_CALL_float(REGISTER_GPU); 1646 TF_CALL_double(REGISTER_GPU); 1647 #undef REGISTER_GPU 1648 1649 // Run the backward operation of the RNN model. 
// Backward (gradient) op for the Cudnn RNN model. Reconstructs the forward
// inputs and outputs, validates their shapes against the model shapes,
// allocates the gradient outputs, and dispatches DoBackward using a cached,
// mutex-guarded RnnDescriptor. Subclasses override GetAlgorithm to supply the
// algorithm chosen during the forward pass.
template <typename T>
class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
 public:
  explicit CudnnRNNBackwardOp(OpKernelConstruction* context)
      : CudnnRNNKernelCommon(context) {}

  void Compute(OpKernelContext* context) override {
    // V1 op: fixed sequence lengths, time-major layout.
    ComputeImpl(context, false, true);
  }

 protected:
  // Shared implementation for all backward-op variants.
  //
  // Args:
  //   var_seq_lengths: when true, a sequence_lengths input is also extracted
  //     (V3 semantics); otherwise all sequences are assumed full length.
  //   time_major: layout of the input/output tensors.
  // On any failure the context status is set and the method returns early via
  // OP_REQUIRES_OK.
  virtual void ComputeImpl(OpKernelContext* context, bool var_seq_lengths,
                           bool time_major) {
    const Tensor* input = nullptr;
    const Tensor* input_h = nullptr;
    const Tensor* input_c = nullptr;
    const Tensor* params = nullptr;
    const Tensor* sequence_lengths = nullptr;
    CudnnRnnModelShapes model_shapes;
    if (var_seq_lengths) {
      OP_REQUIRES_OK(context,
                     ExtractForwardInput(context, model_types(), time_major,
                                         &input, &input_h, &input_c, &params,
                                         &sequence_lengths, &model_shapes));
    } else {
      OP_REQUIRES_OK(context, ExtractForwardInput(
                                  context, model_types(), time_major, &input,
                                  &input_h, &input_c, &params, &model_shapes));
    }
    RnnInputMode input_mode;
    OP_REQUIRES_OK(context,
                   ToRNNInputMode(rnn_input_mode(), model_shapes.num_units,
                                  model_shapes.input_size, &input_mode));

    const Tensor* output = nullptr;
    const Tensor* output_h = nullptr;
    const Tensor* output_c = nullptr;
    const Tensor* output_backprop = nullptr;
    const Tensor* output_h_backprop = nullptr;
    const Tensor* output_c_backprop = nullptr;
    const Tensor* reserve_space = nullptr;
    OP_REQUIRES_OK(context,
                   ExtractBackwardInputs(context, model_shapes, model_types(),
                                         &output, &output_h, &output_c,
                                         &output_backprop, &output_h_backprop,
                                         &output_c_backprop, &reserve_space));

    Tensor* input_backprop = nullptr;
    Tensor* input_h_backprop = nullptr;
    Tensor* input_c_backprop = nullptr;
    Tensor* params_backprop = nullptr;
    OP_REQUIRES_OK(context,
                   AllocateOutputs(context, model_shapes, params->shape(),
                                   &input_backprop, &input_h_backprop,
                                   &input_c_backprop, &params_backprop));

    // Creates a memory callback for the workspace. The memory lives until the
    // end of this kernel call.
    CudnnRnnAllocatorInTemp<uint8> workspace_allocator(context);
    AlgorithmConfig algo_config;
    OP_REQUIRES_OK(context, GetAlgorithm(context, &algo_config));
    Status launch_status;
    {
      // rnn_state_cache_ is shared across invocations of this kernel, so
      // descriptor lookup and the launch are serialized under mu_.
      mutex_lock l(mu_);
      RnnDescriptor* rnn_desc_ptr = nullptr;
      OP_REQUIRES_OK(
          context, GetCachedRnnDescriptor<T>(context, model_shapes, input_mode,
                                             algo_config, &rnn_state_cache_,
                                             &rnn_desc_ptr));
      launch_status = DoBackward<T>(
          context, *rnn_desc_ptr, model_types(), model_shapes, input, input_h,
          input_c, params, output, output_h, output_c, output_backprop,
          output_h_backprop, output_c_backprop, reserve_space, input_backprop,
          input_h_backprop, input_c_backprop, params_backprop,
          sequence_lengths, time_major, &workspace_allocator,
          /*output_profile_result=*/nullptr);
    }
    OP_REQUIRES_OK(context, launch_status);
  }

 protected:
  // Returns the algorithm to use for the backward pass. The base (V1) op has
  // no host_reserved input, so it always uses the default algorithm;
  // CudnnRNNBackwardOpV2 overrides this to deserialize the algorithm picked
  // by forward-pass autotune.
  virtual Status GetAlgorithm(OpKernelContext* context,
                              AlgorithmConfig* algo_config) {
    CHECK_NE(algo_config, nullptr);
    *algo_config = AlgorithmConfig();
    return Status::OK();
  }

 private:
  mutex mu_;
  // Cache of RnnDescriptors keyed by model configuration; guarded by mu_.
  RnnStateCache rnn_state_cache_ GUARDED_BY(mu_);

  // Reads the forward pass's outputs and the incoming gradients from the op
  // inputs and validates their shapes against model_shapes. output_c and
  // output_c_backprop are only read for models with a cell state (LSTM).
  // Returns InvalidArgument on any shape mismatch.
  Status ExtractBackwardInputs(
      OpKernelContext* context, const CudnnRnnModelShapes& model_shapes,
      const CudnnModelTypes& model_types, const Tensor** output,
      const Tensor** output_h, const Tensor** output_c,
      const Tensor** output_backprop, const Tensor** output_h_backprop,
      const Tensor** output_c_backprop, const Tensor** reserve_space) {
    TF_RETURN_IF_ERROR(context->input("output", output));
    TF_RETURN_IF_ERROR(context->input("output_backprop", output_backprop));
    TF_RETURN_IF_ERROR(context->input("output_h", output_h));
    TF_RETURN_IF_ERROR(context->input("output_h_backprop", output_h_backprop));
    if (model_types.HasInputC()) {
      TF_RETURN_IF_ERROR(context->input("output_c", output_c));
      TF_RETURN_IF_ERROR(
          context->input("output_c_backprop", output_c_backprop));
    }
    TF_RETURN_IF_ERROR(context->input("reserve_space", reserve_space));
    const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
    const TensorShape& output_shape = model_shapes.output_shape;

    if (output_shape != (*output)->shape()) {
      return errors::InvalidArgument(
          "Invalid output shape: ", (*output)->shape().DebugString(), " ",
          output_shape.DebugString());
    }
    if (hidden_state_shape != (*output_h)->shape()) {
      return errors::InvalidArgument(
          "Invalid output_h shape: ", (*output_h)->shape().DebugString(), " ",
          hidden_state_shape.DebugString());
    }

    if (output_shape != (*output_backprop)->shape()) {
      return errors::InvalidArgument("Invalid output_backprop shape: ",
                                     (*output_backprop)->shape().DebugString(),
                                     " ", output_shape.DebugString());
    }
    if (hidden_state_shape != (*output_h_backprop)->shape()) {
      return errors::InvalidArgument(
          "Invalid output_h_backprop shape: ",
          (*output_h_backprop)->shape().DebugString(), " ",
          hidden_state_shape.DebugString());
    }

    if (model_types.HasInputC()) {
      if (hidden_state_shape != (*output_c)->shape()) {
        return errors::InvalidArgument(
            "Invalid output_c shape: ", (*output_c)->shape().DebugString(), " ",
            hidden_state_shape.DebugString());
      }
      if (hidden_state_shape != (*output_c_backprop)->shape()) {
        return errors::InvalidArgument(
            "Invalid output_c_backprop shape: ",
            (*output_c_backprop)->shape().DebugString(), " ",
            hidden_state_shape.DebugString());
      }
    }
    return Status::OK();
  }

  // Allocates the four gradient outputs of the op:
  //   0: input_backprop   (input_shape)
  //   1: input_h_backprop (hidden_state_shape)
  //   2: input_c_backprop (hidden_state_shape for LSTM, dummy otherwise)
  //   3: params_backprop  (params_shape)
  Status AllocateOutputs(OpKernelContext* context,
                         const CudnnRnnModelShapes& model_shapes,
                         const TensorShape& params_shape,
                         Tensor** input_backprop, Tensor** input_h_backprop,
                         Tensor** input_c_backprop, Tensor** params_backprop) {
    const TensorShape& input_shape = model_shapes.input_shape;
    const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;

    TF_RETURN_IF_ERROR(
        context->allocate_output(0, input_shape, input_backprop));
    TF_RETURN_IF_ERROR(
        context->allocate_output(1, hidden_state_shape, input_h_backprop));
    if (HasInputC()) {
      TF_RETURN_IF_ERROR(
          context->allocate_output(2, hidden_state_shape, input_c_backprop));
    } else {
      // Only LSTM uses input_c and output_c. So for all other models, we only
      // need to create dummy outputs.
      TF_RETURN_IF_ERROR(context->allocate_output(2, {}, input_c_backprop));
    }
    TF_RETURN_IF_ERROR(
        context->allocate_output(3, params_shape, params_backprop));
    return Status::OK();
  }
};

#define REGISTER_GPU(T)                                                   \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("CudnnRNNBackprop").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      CudnnRNNBackwardOp<GPUDevice, T>);

TF_CALL_half(REGISTER_GPU);
TF_CALL_float(REGISTER_GPU);
TF_CALL_double(REGISTER_GPU);
#undef REGISTER_GPU

// V2 backward op: identical to the base except that it reads the algorithm
// serialized into the host_reserved input by CudnnRNNForwardOpV2.
template <typename T>
class CudnnRNNBackwardOpV2<GPUDevice, T>
    : public CudnnRNNBackwardOp<GPUDevice, T> {
 public:
  explicit CudnnRNNBackwardOpV2(OpKernelConstruction* context)
      : CudnnRNNBackwardOp<GPUDevice, T>(context) {}

 protected:
  // Deserializes (algorithm_id, tensor_ops_enabled) from the 2-element int8
  // host_reserved vector written by the matching forward op.
  Status GetAlgorithm(OpKernelContext* context,
                      AlgorithmConfig* algo_config) override {
    CHECK_NE(algo_config, nullptr);
    const Tensor* host_reserved = nullptr;
    TF_RETURN_IF_ERROR(context->input("host_reserved", &host_reserved));

    auto host_reserved_int8 = host_reserved->vec<int8>();
    const AlgorithmDesc algo_desc(host_reserved_int8(0), host_reserved_int8(1));
    algo_config->set_algorithm(algo_desc);
    return Status::OK();
  }
};

#define REGISTER_GPU(T)                                    \
  REGISTER_KERNEL_BUILDER(Name("CudnnRNNBackpropV2")       \
                              .Device(DEVICE_GPU)          \
                              .HostMemory("host_reserved") \
                              .TypeConstraint<T>("T"),     \
                          CudnnRNNBackwardOpV2<GPUDevice, T>);

TF_CALL_half(REGISTER_GPU);
TF_CALL_float(REGISTER_GPU);
TF_CALL_double(REGISTER_GPU);
#undef REGISTER_GPU

// V3 backward op: variable sequence lengths plus a "time_major" attribute,
// mirroring CudnnRNNForwardOpV3.
template <typename T>
class CudnnRNNBackwardOpV3<GPUDevice, T>
    : public CudnnRNNBackwardOp<GPUDevice, T> {
 private:
  bool time_major_;  // value of the "time_major" op attribute

 protected:
  bool time_major() { return time_major_; }

 public:
  explicit CudnnRNNBackwardOpV3(OpKernelConstruction* context)
      : CudnnRNNBackwardOp<GPUDevice, T>(context) {
    OP_REQUIRES_OK(context, context->GetAttr("time_major", &time_major_));
  }

  void Compute(OpKernelContext* context) override {
    CudnnRNNBackwardOp<GPUDevice, T>::ComputeImpl(context, true, time_major());
  }
};

#define REGISTER_GPU(T)                                       \
  REGISTER_KERNEL_BUILDER(Name("CudnnRNNBackpropV3")          \
                              .Device(DEVICE_GPU)             \
                              .HostMemory("sequence_lengths") \
                              .HostMemory("host_reserved")    \
                              .TypeConstraint<T>("T"),        \
                          CudnnRNNBackwardOpV3<GPUDevice, T>);

TF_CALL_half(REGISTER_GPU);
TF_CALL_float(REGISTER_GPU);
TF_CALL_double(REGISTER_GPU);
#undef REGISTER_GPU

// TODO(zhengxq): Add the conversion of Cudnn RNN Params from and to
// its canonical form.

#endif  // GOOGLE_CUDA

}  // namespace tensorflow