/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements the StreamExecutor interface by passing through to its
// implementation_ value (in pointer-to-implementation style), which
// implements StreamExecutorInterface.

#include "tensorflow/stream_executor/stream_executor_pimpl.h"

#include <atomic>
#include <utility>

#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/fft.h"
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/lib/error.h"
#include "tensorflow/stream_executor/lib/notification.h"
#include "tensorflow/stream_executor/lib/stacktrace.h"
#include "tensorflow/stream_executor/lib/str_util.h"
#include "tensorflow/stream_executor/lib/stringprintf.h"
#include "tensorflow/stream_executor/lib/threadpool.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/rng.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"

namespace {
// When true, the destructor logs any device allocations that are still live
// at executor teardown (see CreateAllocRecord/EraseAllocRecord below).
bool FLAGS_check_device_leaks = false;
}  // namespace

namespace perftools {
namespace gputools {
namespace {

// Returns " <stack trace>\n" for embedding in VLOG(1) messages when verbose
// logging at level 10 is enabled; otherwise returns the empty string so the
// log lines stay compact.
string StackTraceIfVLOG10() {
  if (VLOG_IS_ON(10)) {
    return port::StrCat(" ", port::CurrentStackTrace(), "\n");
  } else {
    return "";
  }
}

// Make sure the executor is done with its work; we know (because this isn't
// publicly visible) that all enqueued work is quick.
void BlockOnThreadExecutor(port::ThreadPool *executor) {
  port::Notification n;
  executor->Schedule([&n]() { n.Notify(); });
  n.WaitForNotification();
}

// Creates the platform-specific StreamExecutorInterface for the given
// platform kind via the registered factory, or LOG(FATAL)s if no factory is
// registered for that kind.
internal::StreamExecutorInterface *StreamExecutorImplementationFromPlatformKind(
    PlatformKind platform_kind, const PluginConfig &plugin_config) {
  // Note: we use this factory-assignment-in-switch pattern instead of just
  // invoking the callable in case linkage is messed up -- instead of invoking a
  // nullptr std::function (due to failed registration) we give a nice
  // LOG(FATAL) message.
  internal::StreamExecutorFactory factory;
  switch (platform_kind) {
    case PlatformKind::kCuda:
      factory = *internal::MakeCUDAExecutorImplementation();
      break;
    case PlatformKind::kOpenCL:
      factory = *internal::MakeOpenCLExecutorImplementation();
      break;
    case PlatformKind::kHost:
      factory = internal::MakeHostExecutorImplementation;
      break;
    default:
      factory = nullptr;
  }
  if (factory == nullptr) {
    LOG(FATAL)
        << "cannot create StreamExecutor implementation for platform kind: "
        << PlatformKindString(platform_kind);
  }
  return factory(plugin_config);
}

// Process-wide source of correlation ids used to pair Begin/Complete trace
// callbacks emitted by ScopedTracer instances.
std::atomic_int_fast64_t correlation_id_generator(0);

}  // namespace

// RAII helper that, when tracing is enabled on the executor, invokes the
// "begin" trace callback on all registered listeners at construction and the
// "complete" callback (with the operation's result) at destruction. Both
// callbacks share a correlation id so listeners can match the pair.
template <typename BeginCallT, typename CompleteCallT,
          typename ReturnT, typename... BeginArgsT>
class ScopedTracer {
 public:
  ScopedTracer(StreamExecutor *stream_exec, BeginCallT begin_call,
               CompleteCallT complete_call, const ReturnT *result,
               BeginArgsT... begin_args)
      : stream_exec_(stream_exec),
        complete_call_(complete_call),
        result_(result) {
    if (stream_exec_->tracing_enabled_) {
      // fetch_add returns the pre-increment value, so the "- 1" yields the id
      // one below the freshly reserved value.
      correlation_id_ =
          correlation_id_generator.fetch_add(1, std::memory_order_relaxed) - 1;
      Trace(begin_call, begin_args...);
    }
  }

  ~ScopedTracer() {
    if (stream_exec_->tracing_enabled_) {
      Trace(complete_call_, result_);
    }
  }

 private:
  // Invokes `callback` on every registered TraceListener while holding the
  // executor's mutex in shared mode.
  template <typename CallbackT, typename... TraceArgsT>
  void Trace(CallbackT callback, TraceArgsT... args) {
    {
      // Instance tracers held in a block to limit the lock lifetime.
      tf_shared_lock lock{stream_exec_->mu_};
      for (TraceListener *listener : stream_exec_->listeners_) {
        (listener->*callback)(correlation_id_,
                              std::forward<TraceArgsT>(args)...);
      }
    }
  }

  StreamExecutor *stream_exec_;
  CompleteCallT complete_call_;
  const ReturnT* result_;
  int64 correlation_id_;
};

// Factory that deduces the ScopedTracer template arguments from the call
// site; used via the SCOPED_TRACE macro below.
template <typename BeginCallT, typename CompleteCallT, typename ReturnT,
          typename... BeginArgsT>
ScopedTracer<BeginCallT, CompleteCallT, ReturnT, BeginArgsT...>
MakeScopedTracer(StreamExecutor *stream_exec, BeginCallT begin_call,
                 CompleteCallT complete_call, ReturnT *result,
                 BeginArgsT... begin_args) {
  return ScopedTracer<BeginCallT, CompleteCallT, ReturnT, BeginArgsT...>(
      stream_exec, begin_call, complete_call, result,
      std::forward<BeginArgsT>(begin_args)...);
}

// Declares a local `tracer` that fires &LOC##Begin now and &LOC##Complete at
// scope exit. `, ##__VA_ARGS__` is the GNU comma-elision extension for when
// no extra arguments are passed.
#define SCOPED_TRACE(LOC, ...) \
  auto tracer = MakeScopedTracer(this, &LOC ## Begin, \
                                 &LOC ## Complete, ## __VA_ARGS__);

/* static */ mutex StreamExecutor::static_mu_{LINKER_INITIALIZED};

StreamExecutor::StreamExecutor(PlatformKind platform_kind,
                               const PluginConfig &plugin_config)
    : platform_(nullptr),
      implementation_(StreamExecutorImplementationFromPlatformKind(
          platform_kind, plugin_config)),
      platform_kind_(platform_kind),
      device_ordinal_(-1),
      background_threads_(new port::ThreadPool(
          port::Env::Default(), "stream_executor", kNumBackgroundThreads)),
      live_stream_count_(0),
      tracing_enabled_(false) {
  CheckPlatformKindIsValid(platform_kind);
}

StreamExecutor::StreamExecutor(
    const Platform *platform,
    std::unique_ptr<internal::StreamExecutorInterface> implementation)
    : platform_(platform),
      implementation_(std::move(implementation)),
      device_ordinal_(-1),
      background_threads_(new port::ThreadPool(
          port::Env::Default(), "stream_executor", kNumBackgroundThreads)),
      live_stream_count_(0),
      tracing_enabled_(false) {
  // Derive platform_kind_ from the platform's name; kinds other than the
  // three below leave platform_kind_ unset here.
  if (port::Lowercase(platform_->Name()) == "cuda") {
    platform_kind_ = PlatformKind::kCuda;
  } else if (port::Lowercase(platform_->Name()) == "opencl") {
    platform_kind_ = PlatformKind::kOpenCL;
  } else if (port::Lowercase(platform_->Name()) == "host") {
    platform_kind_ = PlatformKind::kHost;
  }
}

StreamExecutor::~StreamExecutor() {
  // Drain any work still queued on the background thread pool before tearing
  // anything down.
  BlockOnThreadExecutor(background_threads_.get());

  if (live_stream_count_.load() != 0) {
    LOG(WARNING) << "Not all streams were deallocated at executor destruction "
                 << "time. This may lead to unexpected/bad behavior - "
                 << "especially if any stream is still active!";
  }

  if (FLAGS_check_device_leaks) {
    // Report device allocations that were never Deallocate()d.
    for (auto it : mem_allocs_) {
      LOG(INFO) << "Memory alloced at executor exit: addr: "
                << port::Printf("%p", it.first)
                << ", bytes: " << it.second.bytes << ", trace: \n"
                << it.second.stack_trace;
    }
  }
}

port::Status StreamExecutor::Init(int device_ordinal,
                                  DeviceOptions device_options) {
  device_ordinal_ = device_ordinal;
  return implementation_->Init(device_ordinal, std::move(device_options));
}

port::Status StreamExecutor::Init() {
  return Init(0, DeviceOptions::Default());
}

bool StreamExecutor::GetKernel(const MultiKernelLoaderSpec &spec,
                               KernelBase *kernel) {
  return implementation_->GetKernel(spec, kernel);
}

void StreamExecutor::UnloadKernel(const KernelBase *kernel) {
  implementation_->UnloadKernel(kernel);
}

void StreamExecutor::Deallocate(DeviceMemoryBase *mem) {
  VLOG(1) << "Called StreamExecutor::Deallocate(mem=" << mem->opaque()
          << ") mem->size()=" << mem->size() << StackTraceIfVLOG10();

  if (mem->opaque() != nullptr) {
    EraseAllocRecord(mem->opaque());
  }
  implementation_->Deallocate(mem);
  // Leave the handle in a well-defined empty state.
  mem->Reset(nullptr, 0);
}

void StreamExecutor::GetMemAllocs(std::map<void *, AllocRecord> *records_out) {
  tf_shared_lock lock{mu_};
  *records_out = mem_allocs_;
}

bool StreamExecutor::CanEnablePeerAccessTo(StreamExecutor *other) {
  return implementation_->CanEnablePeerAccessTo(other->implementation_.get());
}

port::Status StreamExecutor::EnablePeerAccessTo(StreamExecutor *other) {
  return implementation_->EnablePeerAccessTo(other->implementation_.get());
}

SharedMemoryConfig StreamExecutor::GetDeviceSharedMemoryConfig() {
  return implementation_->GetDeviceSharedMemoryConfig();
}

port::Status StreamExecutor::SetDeviceSharedMemoryConfig(
    SharedMemoryConfig config) {
  // Validate before delegating so unknown enum values produce a clear error.
  if (config != SharedMemoryConfig::kDefault &&
      config != SharedMemoryConfig::kFourByte &&
      config != SharedMemoryConfig::kEightByte) {
    string error_msg = port::Printf(
        "Invalid shared memory config specified: %d", static_cast<int>(config));
    LOG(ERROR) << error_msg;
    return port::Status{port::error::INVALID_ARGUMENT, error_msg};
  }
  return implementation_->SetDeviceSharedMemoryConfig(config);
}

const DeviceDescription &StreamExecutor::GetDeviceDescription() const {
  // Lazily populated once under the lock and cached thereafter.
  mutex_lock lock{mu_};
  if (device_description_ != nullptr) {
    return *device_description_;
  }

  device_description_.reset(PopulateDeviceDescription());
  return *device_description_;
}

int64 StreamExecutor::GetDeviceLoad() const {
  return implementation_->GetDeviceLoad();
}

int StreamExecutor::PlatformDeviceCount() const {
  return implementation_->PlatformDeviceCount();
}

bool StreamExecutor::SupportsBlas() const {
  return implementation_->SupportsBlas();
}

bool StreamExecutor::SupportsRng() const {
  return implementation_->SupportsRng();
}

bool StreamExecutor::SupportsDnn() const {
  return implementation_->SupportsDnn();
}

bool StreamExecutor::GetConvolveAlgorithms(
    bool with_winograd_nonfused,
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  int cc_major, cc_minor;
  GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor);
  return dnn_support->GetConvolveAlgorithms(with_winograd_nonfused, cc_major,
                                            cc_minor, out_algorithms);
}

bool StreamExecutor::GetConvolveBackwardDataAlgorithms(
    bool with_winograd_nonfused,
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  int cc_major, cc_minor;
  GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor);
  return dnn_support->GetConvolveBackwardDataAlgorithms(
      with_winograd_nonfused, cc_major, cc_minor, out_algorithms);
}

bool StreamExecutor::GetConvolveBackwardFilterAlgorithms(
    bool with_winograd_nonfused,
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  int cc_major, cc_minor;
  GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor);
  return dnn_support->GetConvolveBackwardFilterAlgorithms(
      with_winograd_nonfused, cc_major, cc_minor, out_algorithms);
}

bool StreamExecutor::GetBlasGemmAlgorithms(
    std::vector<blas::AlgorithmType> *out_algorithms) {
  blas::BlasSupport *blas_support = AsBlas();
  if (!blas_support) {
    return false;
  }
  return blas_support->GetBlasGemmAlgorithms(out_algorithms);
}

port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
StreamExecutor::createRnnDescriptor(
    int num_layers, int hidden_size, int input_size,
    dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
    dnn::RnnMode rnn_mode, dnn::DataType data_type, float dropout, uint64 seed,
    ScratchAllocator *state_allocator) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Fail to find the dnn implementation.");
  }
  return dnn_support->createRnnDescriptor(
      num_layers, hidden_size, input_size, input_mode, direction_mode, rnn_mode,
      data_type, dropout, seed, state_allocator);
}

port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
StreamExecutor::createRnnSequenceTensorDescriptor(int seq_length,
                                                  int batch_size, int data_size,
                                                  dnn::DataType data_type) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Fail to find the dnn implementation.");
  }
  return dnn_support->createRnnSequenceTensorDescriptor(seq_length, batch_size,
                                                        data_size, data_type);
}

port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
StreamExecutor::createRnnStateTensorDescriptor(int num_layer, int batch_size,
                                               int data_size,
                                               dnn::DataType data_type) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Fail to find the dnn implementation.");
  }
  return dnn_support->createRnnStateTensorDescriptor(num_layer, batch_size,
                                                     data_size, data_type);
}

// Lazily creates (under mu_) and caches the DNN support plugin; returns
// nullptr if the implementation cannot provide one.
dnn::DnnSupport *StreamExecutor::AsDnn() {
  mutex_lock lock{mu_};
  if (dnn_ != nullptr) {
    return dnn_.get();
  }

  dnn_.reset(implementation_->CreateDnn());
  return dnn_.get();
}

// Lazily creates and caches the BLAS support plugin; same pattern as AsDnn().
blas::BlasSupport *StreamExecutor::AsBlas() {
  mutex_lock lock{mu_};
  if (blas_ != nullptr) {
    return blas_.get();
  }

  blas_.reset(implementation_->CreateBlas());
  return blas_.get();
}

// Lazily creates and caches the FFT support plugin; same pattern as AsDnn().
fft::FftSupport *StreamExecutor::AsFft() {
  mutex_lock lock{mu_};
  if (fft_ != nullptr) {
    return fft_.get();
  }

  fft_.reset(implementation_->CreateFft());
  return fft_.get();
}

// Lazily creates and caches the RNG support plugin; same pattern as AsDnn().
rng::RngSupport *StreamExecutor::AsRng() {
  mutex_lock lock{mu_};
  if (rng_ != nullptr) {
    return rng_.get();
  }

  rng_.reset(implementation_->CreateRng());
  return rng_.get();
}

bool StreamExecutor::Launch(Stream *stream, const ThreadDim &thread_dims,
                            const BlockDim &block_dims,
                            const KernelBase &kernel,
                            const KernelArgsArrayBase &args) {
  SubmitTrace(&TraceListener::LaunchSubmit, stream, thread_dims, block_dims,
              kernel, args);

  return implementation_->Launch(stream, thread_dims, block_dims, kernel, args);
}

port::Status StreamExecutor::BlockHostUntilDone(Stream *stream) {
  port::Status result;
  SCOPED_TRACE(TraceListener::BlockHostUntilDone, &result, stream);

  result = implementation_->BlockHostUntilDone(stream);
  return result;
}

void *StreamExecutor::Allocate(uint64 size) {
  void *buf = implementation_->Allocate(size);
  VLOG(1) << "Called StreamExecutor::Allocate(size=" << size << ") returns "
          << buf << StackTraceIfVLOG10();
  // Track the allocation for leak checking (no-op unless enabled).
  CreateAllocRecord(buf, size);

  return buf;
}

bool StreamExecutor::GetSymbol(const string &symbol_name, void **mem,
                               size_t *bytes) {
  return implementation_->GetSymbol(symbol_name, mem, bytes);
}

void *StreamExecutor::HostMemoryAllocate(uint64 size) {
  void *buffer = implementation_->HostMemoryAllocate(size);
  VLOG(1) << "Called StreamExecutor::HostMemoryAllocate(size=" << size
          << ") returns " << buffer << StackTraceIfVLOG10();
  return buffer;
}

void StreamExecutor::HostMemoryDeallocate(void *location) {
  VLOG(1) << "Called StreamExecutor::HostMemoryDeallocate(location=" << location
          << ")" << StackTraceIfVLOG10();

  return implementation_->HostMemoryDeallocate(location);
}

bool StreamExecutor::HostMemoryRegister(void *location, uint64 size) {
  VLOG(1) << "Called StreamExecutor::HostMemoryRegister(location=" << location
          << ", size=" << size << ")" << StackTraceIfVLOG10();
  // Warn (but still delegate) on degenerate registrations.
  if (location == nullptr || size == 0) {
    LOG(WARNING) << "attempting to register null or zero-sized memory: "
                 << location << "; size " << size;
  }
  return implementation_->HostMemoryRegister(location, size);
}

bool StreamExecutor::HostMemoryUnregister(void *location) {
  VLOG(1) << "Called StreamExecutor::HostMemoryUnregister(location=" << location
          << ")" << StackTraceIfVLOG10();
  return implementation_->HostMemoryUnregister(location);
}

bool StreamExecutor::SynchronizeAllActivity() {
  VLOG(1) << "Called StreamExecutor::SynchronizeAllActivity()"
          << StackTraceIfVLOG10();
  bool ok = implementation_->SynchronizeAllActivity();

  // This should all be quick and infallible work, so we can perform the
  // synchronization even in the case of failure.
  BlockOnThreadExecutor(background_threads_.get());

  return ok;
}

bool StreamExecutor::SynchronousMemZero(DeviceMemoryBase *location,
                                        uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemZero(location=" << location
          << ", size=" << size << ")" << StackTraceIfVLOG10();

  return implementation_->SynchronousMemZero(location, size);
}

bool StreamExecutor::SynchronousMemSet(DeviceMemoryBase *location, int value,
                                       uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemSet(location=" << location
          << ", value=" << value << ", size=" << size << ")"
          << StackTraceIfVLOG10();

  return implementation_->SynchronousMemSet(location, value, size);
}

bool StreamExecutor::SynchronousMemcpy(DeviceMemoryBase *device_dst,
                                       const void *host_src, uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(device_dst="
          << device_dst->opaque() << ", host_src=" << host_src
          << ", size=" << size << ") H2D" << StackTraceIfVLOG10();

  // Tracing overloaded methods is very difficult due to issues with type
  // inference on template args. Since use of these overloaded methods is
  // discouraged anyway, this isn't a huge deal.
  port::Status status =
      implementation_->SynchronousMemcpy(device_dst, host_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}

bool StreamExecutor::SynchronousMemcpy(void *host_dst,
                                       const DeviceMemoryBase &device_src,
                                       uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(host_dst=" << host_dst
          << ", device_src=" << device_src.opaque() << ", size=" << size
          << ") D2H" << StackTraceIfVLOG10();

  port::Status status =
      implementation_->SynchronousMemcpy(host_dst, device_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}

bool StreamExecutor::SynchronousMemcpy(DeviceMemoryBase *device_dst,
                                       const DeviceMemoryBase &device_src,
                                       uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(device_dst="
          << device_dst->opaque() << ", device_src=" << device_src.opaque()
          << ", size=" << size << ") D2D" << StackTraceIfVLOG10();

  port::Status status = implementation_->SynchronousMemcpyDeviceToDevice(
      device_dst, device_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}

port::Status StreamExecutor::SynchronousMemcpyD2H(
    const DeviceMemoryBase &device_src, int64 size, void *host_dst) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpyD2H(device_src="
          << device_src.opaque() << ", size=" << size
          << ", host_dst=" << host_dst << ")" << StackTraceIfVLOG10();

  port::Status result;
  SCOPED_TRACE(TraceListener::SynchronousMemcpyD2H, &result, device_src, size,
               host_dst);

  result = implementation_->SynchronousMemcpy(host_dst, device_src, size);
  if (!result.ok()) {
    // Wrap the implementation's error with the transfer endpoints/size.
    result = port::Status{port::error::INTERNAL,
                          port::Printf("failed to synchronously memcpy "
                                       "device-to-host: device %p to host %p "
                                       "size %lld: %s",
                                       device_src.opaque(), host_dst, size,
                                       result.ToString().c_str())};
  }

  return result;
}

port::Status StreamExecutor::SynchronousMemcpyH2D(
    const void *host_src, int64 size, DeviceMemoryBase *device_dst) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpyH2D(host_src=" << host_src
          << ", size=" << size << ", device_dst" << device_dst->opaque() << ")"
          << StackTraceIfVLOG10();

  port::Status result;
  SCOPED_TRACE(TraceListener::SynchronousMemcpyH2D, &result, host_src, size,
               device_dst);

  result = implementation_->SynchronousMemcpy(device_dst, host_src, size);
  if (!result.ok()) {
    // Wrap the implementation's error with the transfer endpoints/size.
    result = port::Status{
        port::error::INTERNAL,
        port::Printf("failed to synchronously memcpy host-to-device: host "
                     "%p to device %p size %lld: %s",
                     host_src, device_dst->opaque(), size,
                     result.ToString().c_str())};
  }

  return result;
}

bool StreamExecutor::Memcpy(Stream *stream, void *host_dst,
                            const DeviceMemoryBase &device_src, uint64 size) {
  return implementation_->Memcpy(stream, host_dst, device_src, size);
}

bool StreamExecutor::Memcpy(Stream *stream, DeviceMemoryBase *device_dst,
                            const void *host_src, uint64 size) {
  return implementation_->Memcpy(stream, device_dst, host_src, size);
}

bool StreamExecutor::MemcpyDeviceToDevice(Stream *stream,
                                          DeviceMemoryBase *device_dst,
                                          const DeviceMemoryBase &device_src,
                                          uint64 size) {
  return implementation_->MemcpyDeviceToDevice(stream, device_dst, device_src,
                                               size);
}

bool StreamExecutor::MemZero(Stream *stream, DeviceMemoryBase *location,
                             uint64 size) {
  return implementation_->MemZero(stream, location, size);
}

bool StreamExecutor::Memset32(Stream *stream, DeviceMemoryBase *location,
                              uint32 pattern, uint64 size) {
  CHECK_EQ(0, size % 4)
      << "need 32-bit multiple size to fill with 32-bit pattern";
  return implementation_->Memset32(stream, location, pattern, size);
}

bool StreamExecutor::HostCallback(Stream *stream,
                                  std::function<void()> callback) {
  return implementation_->HostCallback(stream, std::move(callback));
}

port::Status StreamExecutor::AllocateEvent(Event *event) {
  return implementation_->AllocateEvent(event);
}

port::Status StreamExecutor::DeallocateEvent(Event *event) {
  return implementation_->DeallocateEvent(event);
}

port::Status StreamExecutor::RecordEvent(Stream *stream, Event *event) {
  return implementation_->RecordEvent(stream, event);
}

port::Status StreamExecutor::WaitForEvent(Stream *stream, Event *event) {
  return implementation_->WaitForEvent(stream, event);
}

Event::Status StreamExecutor::PollForEventStatus(Event *event) {
  return implementation_->PollForEventStatus(event);
}

bool StreamExecutor::AllocateStream(Stream *stream) {
  // Optimistically bump the live count; roll it back if allocation fails.
  live_stream_count_.fetch_add(1, std::memory_order_relaxed);
  if (!implementation_->AllocateStream(stream)) {
    // NOTE: fetch_sub returns the value *before* the decrement.
    auto count = live_stream_count_.fetch_sub(1);
    CHECK_GE(count, 0) << "live stream count should not dip below zero";
    LOG(INFO) << "failed to allocate stream; live stream count: " << count;
    return false;
  }

  return true;
}

void StreamExecutor::DeallocateStream(Stream *stream) {
  implementation_->DeallocateStream(stream);
  CHECK_GE(live_stream_count_.fetch_sub(1), 0)
      << "live stream count should not dip below zero";
}

bool StreamExecutor::CreateStreamDependency(Stream *dependent, Stream *other) {
  return implementation_->CreateStreamDependency(dependent, other);
}

bool StreamExecutor::AllocateTimer(Timer *timer) {
  return implementation_->AllocateTimer(timer);
}

void StreamExecutor::DeallocateTimer(Timer *timer) {
  return implementation_->DeallocateTimer(timer);
}

bool StreamExecutor::StartTimer(Stream *stream, Timer *timer) {
  return implementation_->StartTimer(stream, timer);
}

bool StreamExecutor::StopTimer(Stream *stream, Timer *timer) {
  return implementation_->StopTimer(stream, timer);
}

DeviceDescription *StreamExecutor::PopulateDeviceDescription() const {
  return implementation_->PopulateDeviceDescription();
}

bool StreamExecutor::DeviceMemoryUsage(int64 *free, int64 *total) const {
  return implementation_->DeviceMemoryUsage(free, total);
}

void StreamExecutor::EnqueueOnBackgroundThread(std::function<void()> task) {
  background_threads_->Schedule(std::move(task));
}

// Records a device allocation in mem_allocs_ for leak checking; only active
// when FLAGS_check_device_leaks is set.
void StreamExecutor::CreateAllocRecord(void *opaque, uint64 bytes) {
  if (FLAGS_check_device_leaks && opaque != nullptr && bytes != 0) {
    mutex_lock lock{mu_};
    mem_allocs_[opaque] = AllocRecord{
        bytes, ""};
  }
}

// Removes a previously recorded allocation; logs if the pointer was never
// recorded (double free or untracked allocation).
void StreamExecutor::EraseAllocRecord(void *opaque) {
  if (FLAGS_check_device_leaks && opaque != nullptr) {
    mutex_lock lock{mu_};
    if (mem_allocs_.find(opaque) == mem_allocs_.end()) {
      LOG(ERROR) << "Deallocating unknown pointer: "
                 << port::Printf("0x%p", opaque);
    } else {
      mem_allocs_.erase(opaque);
    }
  }
}

void StreamExecutor::EnableTracing(bool enabled) { tracing_enabled_ = enabled; }

void StreamExecutor::RegisterTraceListener(TraceListener *listener) {
  {
    mutex_lock lock{mu_};
    if (listeners_.find(listener) != listeners_.end()) {
      LOG(INFO) << "Attempt to register already-registered listener, "
                << listener;
    } else {
      listeners_.insert(listener);
    }
  }

  // Also let the platform implementation observe the listener.
  implementation_->RegisterTraceListener(listener);
}

bool StreamExecutor::UnregisterTraceListener(TraceListener *listener) {
  {
    mutex_lock lock{mu_};
    if (listeners_.find(listener) == listeners_.end()) {
      LOG(INFO) << "Attempt to unregister unknown listener, " << listener;
      return false;
    }
    listeners_.erase(listener);
  }

  implementation_->UnregisterTraceListener(listener);
  return true;
}

// Invokes `trace_call` on all registered listeners (shared lock held) when
// tracing is enabled; used for non-scoped, single-shot trace events.
template <typename TraceCallT, typename... ArgsT>
void StreamExecutor::SubmitTrace(TraceCallT trace_call, ArgsT &&... args) {
  if (tracing_enabled_) {
    {
      // instance tracers held in a block to limit the lock lifetime.
      tf_shared_lock lock{mu_};
      for (TraceListener *listener : listeners_) {
        (listener->*trace_call)(std::forward<ArgsT>(args)...);
      }
    }
  }
}

internal::StreamExecutorInterface *StreamExecutor::implementation() {
  return implementation_->GetUnderlyingExecutor();
}

}  // namespace gputools
}  // namespace perftools