/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements the StreamExecutor interface by passing through to its
// implementation_ value (in pointer-to-implementation style), which
// implements StreamExecutorInterface.

#include "tensorflow/stream_executor/stream_executor_pimpl.h"

#include <atomic>
#include <utility>

#include "absl/strings/str_cat.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/fft.h"
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/lib/error.h"
#include "tensorflow/stream_executor/lib/notification.h"
#include "tensorflow/stream_executor/lib/stacktrace.h"
#include "tensorflow/stream_executor/lib/str_util.h"
#include "tensorflow/stream_executor/lib/stringprintf.h"
#include "tensorflow/stream_executor/lib/threadpool.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/rng.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"

namespace {
bool FLAGS_check_device_leaks = false;
}  // namespace

namespace stream_executor {
namespace {

string StackTraceIfVLOG10() {
  if (VLOG_IS_ON(10)) {
    return absl::StrCat(" ", port::CurrentStackTrace(), "\n");
  } else {
    return "";
  }
}

// Make sure the executor is done with its work; we know (because this isn't
// publicly visible) that all enqueued work is quick.
void BlockOnThreadExecutor(port::ThreadPool *executor) {
  port::Notification n;
  executor->Schedule([&n]() { n.Notify(); });
  n.WaitForNotification();
}

internal::StreamExecutorInterface *StreamExecutorImplementationFromPlatformKind(
    PlatformKind platform_kind, const PluginConfig &plugin_config) {
  // Note: we use this factory-assignment-in-switch pattern instead of just
  // invoking the callable in case linkage is messed up -- instead of invoking
  // a nullptr std::function (due to failed registration) we give a nice
  // LOG(FATAL) message.
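  // (Each Make*ExecutorImplementation() call below returns a pointer to a
  // process-wide factory slot that the matching backend fills in at
  // registration time; if that backend was never linked in, the slot holds
  // an empty std::function and the nullptr check below fires.)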
  internal::StreamExecutorFactory factory;
  switch (platform_kind) {
    case PlatformKind::kCuda:
      factory = *internal::MakeCUDAExecutorImplementation();
      break;
    case PlatformKind::kROCm:
      factory = *internal::MakeROCMExecutorImplementation();
      break;
    case PlatformKind::kOpenCL:
      factory = *internal::MakeOpenCLExecutorImplementation();
      break;
    case PlatformKind::kHost:
      factory = internal::MakeHostExecutorImplementation;
      break;
    default:
      factory = nullptr;
  }
  if (factory == nullptr) {
    LOG(FATAL)
        << "cannot create StreamExecutor implementation for platform kind: "
        << PlatformKindString(platform_kind);
  }
  return factory(plugin_config);
}

std::atomic_int_fast64_t correlation_id_generator(0);

}  // namespace

// RAII helper that notifies all registered TraceListeners of a traced call:
// the Begin callback fires in the constructor and the Complete callback
// (carrying the call's result) fires in the destructor, both stamped with the
// same correlation id.
template <typename BeginCallT, typename CompleteCallT, typename ReturnT,
          typename... BeginArgsT>
class ScopedTracer {
 public:
  ScopedTracer(StreamExecutor *stream_exec, BeginCallT begin_call,
               CompleteCallT complete_call, const ReturnT *result,
               BeginArgsT... begin_args)
      : stream_exec_(stream_exec),
        complete_call_(complete_call),
        result_(result) {
    if (stream_exec_->tracing_enabled_) {
      correlation_id_ =
          correlation_id_generator.fetch_add(1, std::memory_order_relaxed) - 1;
      Trace(begin_call, begin_args...);
    }
  }

  ~ScopedTracer() {
    if (stream_exec_->tracing_enabled_) {
      Trace(complete_call_, result_);
    }
  }

 private:
  template <typename CallbackT, typename... TraceArgsT>
  void Trace(CallbackT callback, TraceArgsT... args) {
    {
      // Instance tracers held in a block to limit the lock lifetime.
      tf_shared_lock lock{stream_exec_->mu_};
      for (TraceListener *listener : stream_exec_->listeners_) {
        (listener->*callback)(correlation_id_,
                              std::forward<TraceArgsT>(args)...);
      }
    }
  }

  StreamExecutor *stream_exec_;
  CompleteCallT complete_call_;
  const ReturnT *result_;
  int64 correlation_id_;
};

template <typename BeginCallT, typename CompleteCallT, typename ReturnT,
          typename... BeginArgsT>
ScopedTracer<BeginCallT, CompleteCallT, ReturnT, BeginArgsT...>
MakeScopedTracer(StreamExecutor *stream_exec, BeginCallT begin_call,
                 CompleteCallT complete_call, ReturnT *result,
                 BeginArgsT... begin_args) {
  return ScopedTracer<BeginCallT, CompleteCallT, ReturnT, BeginArgsT...>(
      stream_exec, begin_call, complete_call, result,
      std::forward<BeginArgsT>(begin_args)...);
}

#define SCOPED_TRACE(LOC, ...)                                      \
  auto tracer = MakeScopedTracer(this, &LOC##Begin, &LOC##Complete, \
                                 ##__VA_ARGS__);

/* static */ mutex StreamExecutor::static_mu_{LINKER_INITIALIZED};

StreamExecutor::StreamExecutor(PlatformKind platform_kind,
                               const PluginConfig &plugin_config)
    : platform_(nullptr),
      implementation_(StreamExecutorImplementationFromPlatformKind(
          platform_kind, plugin_config)),
      platform_kind_(platform_kind),
      device_ordinal_(-1),
      background_threads_(new port::ThreadPool(
          port::Env::Default(), "stream_executor", kNumBackgroundThreads)),
      live_stream_count_(0),
      tracing_enabled_(false) {
  CheckPlatformKindIsValid(platform_kind);
}
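// An illustrative way to exercise the per-device limit read below (the
// command line is hypothetical):
//
//   TF_PER_DEVICE_MEMORY_LIMIT_MB=4096 ./my_tf_program
//
// With this set, Allocate() refuses requests once tracked allocations on the
// device reach ~4 GiB and returns nullptr instead.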
// Get per-device memory limit in bytes. Returns 0 if the
// TF_PER_DEVICE_MEMORY_LIMIT_MB environment variable is not set.
static int64 GetMemoryLimitBytes() {
  int64 value;
  SE_CHECK_OK(tensorflow::ReadInt64FromEnvVar("TF_PER_DEVICE_MEMORY_LIMIT_MB",
                                              0, &value));
  return value * (1ll << 20);
}

StreamExecutor::StreamExecutor(
    const Platform *platform,
    std::unique_ptr<internal::StreamExecutorInterface> implementation)
    : platform_(platform),
      implementation_(std::move(implementation)),
      device_ordinal_(-1),
      background_threads_(new port::ThreadPool(
          port::Env::Default(), "stream_executor", kNumBackgroundThreads)),
      live_stream_count_(0),
      tracing_enabled_(false),
      mem_alloc_bytes_(0),
      memory_limit_bytes_(GetMemoryLimitBytes()) {
  if (port::Lowercase(platform_->Name()) == "cuda") {
    platform_kind_ = PlatformKind::kCuda;
  } else if (port::Lowercase(platform_->Name()) == "rocm") {
    platform_kind_ = PlatformKind::kROCm;
  } else if (port::Lowercase(platform_->Name()) == "opencl") {
    platform_kind_ = PlatformKind::kOpenCL;
  } else if (port::Lowercase(platform_->Name()) == "host") {
    platform_kind_ = PlatformKind::kHost;
  } else {
    platform_kind_ = PlatformKind::kInvalid;
  }
}

StreamExecutor::~StreamExecutor() {
  BlockOnThreadExecutor(background_threads_.get());

  if (live_stream_count_.load() != 0) {
    LOG(WARNING) << "Not all streams were deallocated at executor destruction "
                 << "time. This may lead to unexpected/bad behavior - "
                 << "especially if any stream is still active!";
  }

  if (FLAGS_check_device_leaks) {
    for (auto it : mem_allocs_) {
      LOG(INFO) << "Memory allocated at executor exit: addr: "
                << port::Printf("%p", it.first)
                << ", bytes: " << it.second.bytes << ", trace: \n"
                << it.second.stack_trace;
    }
  }
}

port::Status StreamExecutor::Init(int device_ordinal,
                                  DeviceOptions device_options) {
  device_ordinal_ = device_ordinal;
  return implementation_->Init(device_ordinal, std::move(device_options));
}

port::Status StreamExecutor::Init() {
  return Init(0, DeviceOptions::Default());
}

bool StreamExecutor::GetKernel(const MultiKernelLoaderSpec &spec,
                               KernelBase *kernel) {
  return implementation_->GetKernel(spec, kernel);
}

void StreamExecutor::UnloadKernel(const KernelBase *kernel) {
  implementation_->UnloadKernel(kernel);
}

bool StreamExecutor::LoadModule(const MultiModuleLoaderSpec &spec,
                                ModuleHandle *module_handle) {
  return implementation_->LoadModule(spec, module_handle);
}

bool StreamExecutor::UnloadModule(ModuleHandle module_handle) {
  return implementation_->UnloadModule(module_handle);
}

void StreamExecutor::Deallocate(DeviceMemoryBase *mem) {
  VLOG(1) << "Called StreamExecutor::Deallocate(mem=" << mem->opaque()
          << ") mem->size()=" << mem->size() << StackTraceIfVLOG10();

  if (mem->opaque() != nullptr) {
    EraseAllocRecord(mem->opaque());
  }
  implementation_->Deallocate(mem);
  mem->Reset(nullptr, 0);
}

void StreamExecutor::GetMemAllocs(std::map<void *, AllocRecord> *records_out) {
  tf_shared_lock lock(mu_);
  *records_out = mem_allocs_;
}

bool StreamExecutor::CanEnablePeerAccessTo(StreamExecutor *other) {
  return implementation_->CanEnablePeerAccessTo(other->implementation_.get());
}

port::Status StreamExecutor::EnablePeerAccessTo(StreamExecutor *other) {
  return implementation_->EnablePeerAccessTo(other->implementation_.get());
}
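// Illustrative peer-access setup between two executors on the same platform
// (a sketch; `exec0` and `exec1` are hypothetical):
//
//   if (exec0->CanEnablePeerAccessTo(exec1)) {
//     SE_CHECK_OK(exec0->EnablePeerAccessTo(exec1));
//   }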
SharedMemoryConfig StreamExecutor::GetDeviceSharedMemoryConfig() {
  return implementation_->GetDeviceSharedMemoryConfig();
}

port::Status StreamExecutor::SetDeviceSharedMemoryConfig(
    SharedMemoryConfig config) {
  if (config != SharedMemoryConfig::kDefault &&
      config != SharedMemoryConfig::kFourByte &&
      config != SharedMemoryConfig::kEightByte) {
    string error_msg = port::Printf(
        "Invalid shared memory config specified: %d", static_cast<int>(config));
    LOG(ERROR) << error_msg;
    return port::Status(port::error::INVALID_ARGUMENT, error_msg);
  }
  return implementation_->SetDeviceSharedMemoryConfig(config);
}

const DeviceDescription &StreamExecutor::GetDeviceDescription() const {
  mutex_lock lock(mu_);
  if (device_description_ != nullptr) {
    return *device_description_;
  }

  device_description_.reset(PopulateDeviceDescription());
  return *device_description_;
}

int64 StreamExecutor::GetDeviceLoad() const {
  return implementation_->GetDeviceLoad();
}

int StreamExecutor::PlatformDeviceCount() const {
  return implementation_->PlatformDeviceCount();
}

bool StreamExecutor::SupportsBlas() const {
  return implementation_->SupportsBlas();
}

bool StreamExecutor::SupportsRng() const {
  return implementation_->SupportsRng();
}

bool StreamExecutor::SupportsDnn() const {
  return implementation_->SupportsDnn();
}

bool StreamExecutor::GetConvolveAlgorithms(
    bool with_winograd_nonfused,
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  int cc_major, cc_minor;
  GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor);
  return dnn_support->GetConvolveAlgorithms(with_winograd_nonfused, cc_major,
                                            cc_minor, out_algorithms);
}

bool StreamExecutor::GetRnnAlgorithms(
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  return dnn_support->GetRnnAlgorithms(out_algorithms);
}

bool StreamExecutor::GetConvolveBackwardDataAlgorithms(
    bool with_winograd_nonfused,
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  int cc_major, cc_minor;
  GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor);
  return dnn_support->GetConvolveBackwardDataAlgorithms(
      with_winograd_nonfused, cc_major, cc_minor, out_algorithms);
}

bool StreamExecutor::GetConvolveBackwardFilterAlgorithms(
    bool with_winograd_nonfused,
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  int cc_major, cc_minor;
  GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor);
  return dnn_support->GetConvolveBackwardFilterAlgorithms(
      with_winograd_nonfused, cc_major, cc_minor, out_algorithms);
}

bool StreamExecutor::GetBlasGemmAlgorithms(
    std::vector<blas::AlgorithmType> *out_algorithms) {
  blas::BlasSupport *blas_support = AsBlas();
  if (!blas_support) {
    return false;
  }
  return blas_support->GetBlasGemmAlgorithms(out_algorithms);
}
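// Typical autotuning flow against the enumeration accessors above (an
// illustrative sketch; `exec` and the benchmarking step are hypothetical):
//
//   std::vector<dnn::AlgorithmDesc> algorithms;
//   if (exec->GetConvolveAlgorithms(/*with_winograd_nonfused=*/true,
//                                   &algorithms)) {
//     // Time each candidate and pick the fastest AlgorithmDesc.
//   }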
port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
StreamExecutor::createRnnDescriptor(
    int num_layers, int hidden_size, int input_size, int batch_size,
    dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
    dnn::RnnMode rnn_mode, dnn::DataType data_type,
    const dnn::AlgorithmConfig &algorithm_config, float dropout, uint64 seed,
    ScratchAllocator *state_allocator) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Failed to find the DNN implementation.");
  }
  return dnn_support->createRnnDescriptor(
      num_layers, hidden_size, input_size, batch_size, input_mode,
      direction_mode, rnn_mode, data_type, algorithm_config, dropout, seed,
      state_allocator);
}

port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
StreamExecutor::createRnnSequenceTensorDescriptor(int max_seq_length,
                                                  int batch_size, int data_size,
                                                  dnn::DataType data_type) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Failed to find the DNN implementation.");
  }
  return dnn_support->createRnnSequenceTensorDescriptor(
      max_seq_length, batch_size, data_size, data_type);
}

port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
StreamExecutor::createRnnSequenceTensorDescriptor(
    int max_seq_length, int batch_size, int data_size,
    const absl::Span<const int> &seq_lengths, bool time_major,
    dnn::DataType data_type) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Failed to find the DNN implementation.");
  }
  return dnn_support->createRnnSequenceTensorDescriptor(
      max_seq_length, batch_size, data_size, seq_lengths, time_major,
      data_type);
}

port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
StreamExecutor::createRnnStateTensorDescriptor(int num_layer, int batch_size,
                                               int data_size,
                                               dnn::DataType data_type) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Failed to find the DNN implementation.");
  }
  return dnn_support->createRnnStateTensorDescriptor(num_layer, batch_size,
                                                     data_size, data_type);
}

dnn::DnnSupport *StreamExecutor::AsDnn() {
  mutex_lock lock(mu_);
  if (dnn_ != nullptr) {
    return dnn_.get();
  }

  dnn_.reset(implementation_->CreateDnn());
  return dnn_.get();
}

blas::BlasSupport *StreamExecutor::AsBlas() {
  mutex_lock lock(mu_);
  if (blas_ != nullptr) {
    return blas_.get();
  }

  blas_.reset(implementation_->CreateBlas());
  return blas_.get();
}

fft::FftSupport *StreamExecutor::AsFft() {
  mutex_lock lock(mu_);
  if (fft_ != nullptr) {
    return fft_.get();
  }

  fft_.reset(implementation_->CreateFft());
  return fft_.get();
}

rng::RngSupport *StreamExecutor::AsRng() {
  mutex_lock lock(mu_);
  if (rng_ != nullptr) {
    return rng_.get();
  }

  rng_.reset(implementation_->CreateRng());
  return rng_.get();
}

bool StreamExecutor::Launch(Stream *stream, const ThreadDim &thread_dims,
                            const BlockDim &block_dims,
                            const KernelBase &kernel,
                            const KernelArgsArrayBase &args) {
  SubmitTrace(&TraceListener::LaunchSubmit, stream, thread_dims, block_dims,
              kernel, args);

  return implementation_->Launch(stream, thread_dims, block_dims, kernel,
                                 args);
}
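// Note on SCOPED_TRACE below: the macro (defined above) pastes "Begin" and
// "Complete" onto the listener method name, so the next function expands to a
// ScopedTracer that notifies TraceListener::BlockHostUntilDoneBegin on entry
// and TraceListener::BlockHostUntilDoneComplete (with the resulting Status)
// on scope exit.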
port::Status StreamExecutor::BlockHostUntilDone(Stream *stream) {
  port::Status result;
  SCOPED_TRACE(TraceListener::BlockHostUntilDone, &result, stream);

  result = implementation_->BlockHostUntilDone(stream);
  return result;
}

port::Status StreamExecutor::GetStatus(Stream *stream) {
  return implementation_->GetStatus(stream);
}

void *StreamExecutor::Allocate(uint64 size) {
  if (memory_limit_bytes_ > 0 &&
      mem_alloc_bytes_ + size > memory_limit_bytes_) {
    LOG(WARNING) << "Not enough memory to allocate " << size << " on device "
                 << device_ordinal_
                 << " within provided limit. [used=" << mem_alloc_bytes_
                 << ", limit=" << memory_limit_bytes_ << "]";
    return nullptr;
  }
  void *buf = implementation_->Allocate(size);
  VLOG(1) << "Called StreamExecutor::Allocate(size=" << size << ") returns "
          << buf << StackTraceIfVLOG10();
  CreateAllocRecord(buf, size);

  return buf;
}

port::StatusOr<DeviceMemoryBase> StreamExecutor::GetUntypedSymbol(
    const string &symbol_name, ModuleHandle module_handle) {
  // If we fail to get the symbol, opaque/bytes are unchanged. Initialize them
  // to nullptr/0 for consistency with DeviceMemory semantics.
  void *opaque = nullptr;
  size_t bytes = 0;
  if (GetSymbol(symbol_name, module_handle, &opaque, &bytes)) {
    return DeviceMemoryBase(opaque, bytes);
  }

  if (static_cast<bool>(module_handle)) {
    return port::Status(
        port::error::NOT_FOUND,
        absl::StrCat("Check if module containing symbol ", symbol_name,
                     " is loaded (module_handle = ",
                     reinterpret_cast<uintptr_t>(module_handle.id()), ")"));
  } else {
    return port::Status(
        port::error::NOT_FOUND,
        absl::StrCat("Check if kernel using the symbol is loaded: ",
                     symbol_name));
  }
}

bool StreamExecutor::GetSymbol(const string &symbol_name,
                               ModuleHandle module_handle, void **mem,
                               size_t *bytes) {
  return implementation_->GetSymbol(symbol_name, module_handle, mem, bytes);
}

void *StreamExecutor::UnifiedMemoryAllocate(uint64 bytes) {
  void *buffer = implementation_->UnifiedMemoryAllocate(bytes);
  VLOG(1) << "Called StreamExecutor::UnifiedMemoryAllocate(size=" << bytes
          << ") returns " << buffer << StackTraceIfVLOG10();
  return buffer;
}

void StreamExecutor::UnifiedMemoryDeallocate(void *location) {
  VLOG(1) << "Called StreamExecutor::UnifiedMemoryDeallocate(location="
          << location << ")" << StackTraceIfVLOG10();

  return implementation_->UnifiedMemoryDeallocate(location);
}

void *StreamExecutor::HostMemoryAllocate(uint64 size) {
  void *buffer = implementation_->HostMemoryAllocate(size);
  VLOG(1) << "Called StreamExecutor::HostMemoryAllocate(size=" << size
          << ") returns " << buffer << StackTraceIfVLOG10();
  return buffer;
}

void StreamExecutor::HostMemoryDeallocate(void *location) {
  VLOG(1) << "Called StreamExecutor::HostMemoryDeallocate(location=" << location
          << ")" << StackTraceIfVLOG10();

  return implementation_->HostMemoryDeallocate(location);
}

bool StreamExecutor::HostMemoryRegister(void *location, uint64 size) {
  VLOG(1) << "Called StreamExecutor::HostMemoryRegister(location=" << location
          << ", size=" << size << ")" << StackTraceIfVLOG10();
  if (location == nullptr || size == 0) {
    LOG(WARNING) << "attempting to register null or zero-sized memory: "
                 << location << "; size " << size;
  }
  return implementation_->HostMemoryRegister(location, size);
}
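// HostMemoryRegister/HostMemoryUnregister pin and unpin caller-owned host
// memory so the device can DMA it directly. An illustrative pairing (names
// hypothetical):
//
//   exec->HostMemoryRegister(user_buf, buf_size);  // pin for device copies
//   ... enqueue Memcpy calls that touch user_buf ...
//   exec->HostMemoryUnregister(user_buf);          // unpin before freeing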
bool StreamExecutor::HostMemoryUnregister(void *location) {
  VLOG(1) << "Called StreamExecutor::HostMemoryUnregister(location=" << location
          << ")" << StackTraceIfVLOG10();
  return implementation_->HostMemoryUnregister(location);
}

bool StreamExecutor::SynchronizeAllActivity() {
  VLOG(1) << "Called StreamExecutor::SynchronizeAllActivity()"
          << StackTraceIfVLOG10();
  bool ok = implementation_->SynchronizeAllActivity();

  // This should all be quick and infallible work, so we can perform the
  // synchronization even in the case of failure.
  BlockOnThreadExecutor(background_threads_.get());

  return ok;
}

bool StreamExecutor::SynchronousMemZero(DeviceMemoryBase *location,
                                        uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemZero(location=" << location
          << ", size=" << size << ")" << StackTraceIfVLOG10();

  return implementation_->SynchronousMemZero(location, size);
}

bool StreamExecutor::SynchronousMemSet(DeviceMemoryBase *location, int value,
                                       uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemSet(location=" << location
          << ", value=" << value << ", size=" << size << ")"
          << StackTraceIfVLOG10();

  return implementation_->SynchronousMemSet(location, value, size);
}
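// The three bool-returning SynchronousMemcpy overloads below log any failure
// and collapse the detailed Status to a bool; callers that need the error
// itself should prefer SynchronousMemcpyH2D/SynchronousMemcpyD2H further down.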
bool StreamExecutor::SynchronousMemcpy(DeviceMemoryBase *device_dst,
                                       const void *host_src, uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(device_dst="
          << device_dst->opaque() << ", host_src=" << host_src
          << ", size=" << size << ") H2D" << StackTraceIfVLOG10();

  // Tracing overloaded methods is very difficult due to issues with type
  // inference on template args. Since use of these overloaded methods is
  // discouraged anyway, this isn't a huge deal.
  port::Status status =
      implementation_->SynchronousMemcpy(device_dst, host_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}

bool StreamExecutor::SynchronousMemcpy(void *host_dst,
                                       const DeviceMemoryBase &device_src,
                                       uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(host_dst=" << host_dst
          << ", device_src=" << device_src.opaque() << ", size=" << size
          << ") D2H" << StackTraceIfVLOG10();

  port::Status status =
      implementation_->SynchronousMemcpy(host_dst, device_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}

bool StreamExecutor::SynchronousMemcpy(DeviceMemoryBase *device_dst,
                                       const DeviceMemoryBase &device_src,
                                       uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(device_dst="
          << device_dst->opaque() << ", device_src=" << device_src.opaque()
          << ", size=" << size << ") D2D" << StackTraceIfVLOG10();

  port::Status status = implementation_->SynchronousMemcpyDeviceToDevice(
      device_dst, device_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}

port::Status StreamExecutor::SynchronousMemcpyD2H(
    const DeviceMemoryBase &device_src, int64 size, void *host_dst) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpyD2H(device_src="
          << device_src.opaque() << ", size=" << size
          << ", host_dst=" << host_dst << ")" << StackTraceIfVLOG10();

  port::Status result;
  SCOPED_TRACE(TraceListener::SynchronousMemcpyD2H, &result, device_src, size,
               host_dst);

  result = implementation_->SynchronousMemcpy(host_dst, device_src, size);
  if (!result.ok()) {
    result = port::Status(port::error::INTERNAL,
                          port::Printf("failed to synchronously memcpy "
                                       "device-to-host: device %p to host %p "
                                       "size %lld: %s",
                                       device_src.opaque(), host_dst, size,
                                       result.ToString().c_str()));
  }

  return result;
}

port::Status StreamExecutor::SynchronousMemcpyH2D(
    const void *host_src, int64 size, DeviceMemoryBase *device_dst) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpyH2D(host_src=" << host_src
          << ", size=" << size << ", device_dst=" << device_dst->opaque() << ")"
          << StackTraceIfVLOG10();

  port::Status result;
  SCOPED_TRACE(TraceListener::SynchronousMemcpyH2D, &result, host_src, size,
               device_dst);

  result = implementation_->SynchronousMemcpy(device_dst, host_src, size);
  if (!result.ok()) {
    result = port::Status(
        port::error::INTERNAL,
        port::Printf("failed to synchronously memcpy host-to-device: host "
                     "%p to device %p size %lld: %s",
                     host_src, device_dst->opaque(), size,
                     result.ToString().c_str()));
  }

  return result;
}

bool StreamExecutor::Memcpy(Stream *stream, void *host_dst,
                            const DeviceMemoryBase &device_src, uint64 size) {
  return implementation_->Memcpy(stream, host_dst, device_src, size);
}

bool StreamExecutor::Memcpy(Stream *stream, DeviceMemoryBase *device_dst,
                            const void *host_src, uint64 size) {
  return implementation_->Memcpy(stream, device_dst, host_src, size);
}
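// The Memcpy overloads above are asynchronous: completion is ordered only
// with respect to `stream`. An illustrative host-side pattern:
//
//   exec->Memcpy(stream, host_dst, device_src, size);  // enqueued, not done
//   SE_CHECK_OK(exec->BlockHostUntilDone(stream));     // host_dst now valid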
bool StreamExecutor::MemcpyDeviceToDevice(Stream *stream,
                                          DeviceMemoryBase *device_dst,
                                          const DeviceMemoryBase &device_src,
                                          uint64 size) {
  return implementation_->MemcpyDeviceToDevice(stream, device_dst, device_src,
                                               size);
}

bool StreamExecutor::MemZero(Stream *stream, DeviceMemoryBase *location,
                             uint64 size) {
  return implementation_->MemZero(stream, location, size);
}

bool StreamExecutor::Memset32(Stream *stream, DeviceMemoryBase *location,
                              uint32 pattern, uint64 size) {
  CHECK_EQ(0, size % 4)
      << "need 32-bit multiple size to fill with 32-bit pattern";
  return implementation_->Memset32(stream, location, pattern, size);
}

bool StreamExecutor::HostCallback(Stream *stream,
                                  std::function<void()> callback) {
  return implementation_->HostCallback(stream, std::move(callback));
}

bool StreamExecutor::HostCallback(Stream *stream,
                                  std::function<port::Status()> callback) {
  return implementation_->HostCallback(stream, std::move(callback));
}

port::Status StreamExecutor::AllocateEvent(Event *event) {
  return implementation_->AllocateEvent(event);
}

port::Status StreamExecutor::DeallocateEvent(Event *event) {
  return implementation_->DeallocateEvent(event);
}

port::Status StreamExecutor::RecordEvent(Stream *stream, Event *event) {
  return implementation_->RecordEvent(stream, event);
}

port::Status StreamExecutor::WaitForEvent(Stream *stream, Event *event) {
  return implementation_->WaitForEvent(stream, event);
}

Event::Status StreamExecutor::PollForEventStatus(Event *event) {
  return implementation_->PollForEventStatus(event);
}

bool StreamExecutor::AllocateStream(Stream *stream) {
  live_stream_count_.fetch_add(1, std::memory_order_relaxed);
  if (!implementation_->AllocateStream(stream)) {
    auto count = live_stream_count_.fetch_sub(1);
    CHECK_GE(count, 0) << "live stream count should not dip below zero";
    LOG(INFO) << "failed to allocate stream; live stream count: " << count;
    return false;
  }

  return true;
}

void StreamExecutor::DeallocateStream(Stream *stream) {
  implementation_->DeallocateStream(stream);
  CHECK_GE(live_stream_count_.fetch_sub(1), 0)
      << "live stream count should not dip below zero";
}

bool StreamExecutor::CreateStreamDependency(Stream *dependent, Stream *other) {
  return implementation_->CreateStreamDependency(dependent, other);
}

bool StreamExecutor::AllocateTimer(Timer *timer) {
  return implementation_->AllocateTimer(timer);
}

void StreamExecutor::DeallocateTimer(Timer *timer) {
  return implementation_->DeallocateTimer(timer);
}

bool StreamExecutor::StartTimer(Stream *stream, Timer *timer) {
  return implementation_->StartTimer(stream, timer);
}

bool StreamExecutor::StopTimer(Stream *stream, Timer *timer) {
  return implementation_->StopTimer(stream, timer);
}

DeviceDescription *StreamExecutor::PopulateDeviceDescription() const {
  return implementation_->PopulateDeviceDescription();
}

bool StreamExecutor::DeviceMemoryUsage(int64 *free, int64 *total) const {
  return implementation_->DeviceMemoryUsage(free, total);
}

void StreamExecutor::EnqueueOnBackgroundThread(std::function<void()> task) {
  background_threads_->Schedule(std::move(task));
}

void StreamExecutor::CreateAllocRecord(void *opaque, uint64 bytes) {
  if (FLAGS_check_device_leaks && opaque != nullptr && bytes != 0) {
    mutex_lock lock(mu_);
    mem_allocs_[opaque] = AllocRecord{bytes, ""};
    mem_alloc_bytes_ += bytes;
  }
}
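// EraseAllocRecord below is the counterpart of CreateAllocRecord above; both
// are no-ops unless FLAGS_check_device_leaks (top of this file) is enabled,
// in which case unmatched records are reported in ~StreamExecutor().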
void StreamExecutor::EraseAllocRecord(void *opaque) {
  if (FLAGS_check_device_leaks && opaque != nullptr) {
    mutex_lock lock(mu_);
    if (mem_allocs_.find(opaque) == mem_allocs_.end()) {
      LOG(ERROR) << "Deallocating unknown pointer: "
                 << port::Printf("%p", opaque);
    } else {
      mem_alloc_bytes_ -= mem_allocs_[opaque].bytes;
      mem_allocs_.erase(opaque);
    }
  }
}

void StreamExecutor::EnableTracing(bool enabled) { tracing_enabled_ = enabled; }

void StreamExecutor::RegisterTraceListener(TraceListener *listener) {
  {
    mutex_lock lock(mu_);
    if (listeners_.find(listener) != listeners_.end()) {
      LOG(INFO) << "Attempt to register already-registered listener, "
                << listener;
    } else {
      listeners_.insert(listener);
    }
  }

  implementation_->RegisterTraceListener(listener);
}

bool StreamExecutor::UnregisterTraceListener(TraceListener *listener) {
  {
    mutex_lock lock(mu_);
    if (listeners_.find(listener) == listeners_.end()) {
      LOG(INFO) << "Attempt to unregister unknown listener, " << listener;
      return false;
    }
    listeners_.erase(listener);
  }

  implementation_->UnregisterTraceListener(listener);
  return true;
}

absl::optional<AllocatorStats> StreamExecutor::GetAllocatorStats() {
  return implementation_->GetAllocatorStats();
}

template <typename TraceCallT, typename... ArgsT>
void StreamExecutor::SubmitTrace(TraceCallT trace_call, ArgsT &&... args) {
  if (tracing_enabled_) {
    {
      // Instance tracers held in a block to limit the lock lifetime.
      tf_shared_lock lock(mu_);
      for (TraceListener *listener : listeners_) {
        (listener->*trace_call)(std::forward<ArgsT>(args)...);
      }
    }
  }
}

internal::StreamExecutorInterface *StreamExecutor::implementation() {
  return implementation_->GetUnderlyingExecutor();
}

}  // namespace stream_executor