/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements the StreamExecutor interface by passing through to its
// implementation_ value (in pointer-to-implementation style), which
// implements StreamExecutorInterface.
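//
// Illustrative usage (a sketch, not part of this file; assumes a Platform*
// obtained elsewhere, e.g. from the MultiPlatformManager):
//
//   StreamExecutor *executor = platform->ExecutorForDevice(0).ValueOrDie();
//   Stream stream(executor);
//   stream.Init();
//   executor->SynchronousMemcpyH2D(host_src, size, &device_dst);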

#include "tensorflow/stream_executor/stream_executor_pimpl.h"

#include <atomic>
#include <utility>

#include "absl/strings/str_cat.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/fft.h"
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/lib/error.h"
#include "tensorflow/stream_executor/lib/notification.h"
#include "tensorflow/stream_executor/lib/stacktrace.h"
#include "tensorflow/stream_executor/lib/str_util.h"
#include "tensorflow/stream_executor/lib/stringprintf.h"
#include "tensorflow/stream_executor/lib/threadpool.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/rng.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"

namespace {
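// When true, device allocations are tracked via CreateAllocRecord below, and
// any blocks still live when the executor is destroyed are logged (see
// ~StreamExecutor).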
bool FLAGS_check_device_leaks = false;
}  // namespace

namespace stream_executor {
namespace {

string StackTraceIfVLOG10() {
  if (VLOG_IS_ON(10)) {
    return absl::StrCat(" ", port::CurrentStackTrace(), "\n");
  } else {
    return "";
  }
}

// Make sure the executor is done with its work; we know (because this isn't
// publicly visible) that all enqueued work is quick.
void BlockOnThreadExecutor(port::ThreadPool *executor) {
  port::Notification n;
  executor->Schedule([&n]() { n.Notify(); });
  n.WaitForNotification();
}

internal::StreamExecutorInterface *StreamExecutorImplementationFromPlatformKind(
    PlatformKind platform_kind, const PluginConfig &plugin_config) {
  // Note: we use this factory-assignment-in-switch pattern instead of
  // invoking the callable directly so that we fail loudly if linkage is
  // broken -- instead of invoking a null std::function (due to failed
  // registration), we emit a clear LOG(FATAL) message.
  internal::StreamExecutorFactory factory;
  switch (platform_kind) {
    case PlatformKind::kCuda:
      factory = *internal::MakeCUDAExecutorImplementation();
      break;
    case PlatformKind::kROCm:
      factory = *internal::MakeROCMExecutorImplementation();
      break;
    case PlatformKind::kOpenCL:
      factory = *internal::MakeOpenCLExecutorImplementation();
      break;
    case PlatformKind::kHost:
      factory = internal::MakeHostExecutorImplementation;
      break;
    default:
      factory = nullptr;
  }
  if (factory == nullptr) {
    LOG(FATAL)
        << "cannot create StreamExecutor implementation for platform kind: "
        << PlatformKindString(platform_kind);
  }
  return factory(plugin_config);
}

std::atomic_int_fast64_t correlation_id_generator(0);

}  // namespace

template <typename BeginCallT, typename CompleteCallT,
          typename ReturnT, typename... BeginArgsT>
class ScopedTracer {
 public:
  ScopedTracer(StreamExecutor *stream_exec, BeginCallT begin_call,
               CompleteCallT complete_call, const ReturnT *result,
               BeginArgsT... begin_args)
      : stream_exec_(stream_exec),
        complete_call_(complete_call),
        result_(result) {
    if (stream_exec_->tracing_enabled_) {
      correlation_id_ =
          correlation_id_generator.fetch_add(1, std::memory_order_relaxed) - 1;
      Trace(begin_call, begin_args...);
    }
  }

  ~ScopedTracer() {
    if (stream_exec_->tracing_enabled_) {
      Trace(complete_call_, result_);
    }
  }

 private:
  template <typename CallbackT, typename... TraceArgsT>
  void Trace(CallbackT callback, TraceArgsT... args) {
    {
      // Instance tracers held in a block to limit the lock lifetime.
      tf_shared_lock lock{stream_exec_->mu_};
      for (TraceListener *listener : stream_exec_->listeners_) {
        (listener->*callback)(correlation_id_,
                              std::forward<TraceArgsT>(args)...);
      }
    }
  }

  StreamExecutor *stream_exec_;
  CompleteCallT complete_call_;
  const ReturnT *result_;
  int64 correlation_id_;
};

template <typename BeginCallT, typename CompleteCallT, typename ReturnT,
          typename... BeginArgsT>
ScopedTracer<BeginCallT, CompleteCallT, ReturnT, BeginArgsT...>
MakeScopedTracer(StreamExecutor *stream_exec, BeginCallT begin_call,
                 CompleteCallT complete_call, ReturnT *result,
                 BeginArgsT... begin_args) {
  return ScopedTracer<BeginCallT, CompleteCallT, ReturnT, BeginArgsT...>(
      stream_exec, begin_call, complete_call, result,
      std::forward<BeginArgsT>(begin_args)...);
}

#define SCOPED_TRACE(LOC, ...)                                      \
  auto tracer = MakeScopedTracer(this, &LOC ## Begin,               \
                                 &LOC ## Complete, ## __VA_ARGS__);
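
// For illustration, a call such as
//   SCOPED_TRACE(TraceListener::BlockHostUntilDone, &result, stream);
// expands to roughly (a sketch; the real expansion is by token pasting):
//   auto tracer = MakeScopedTracer(this, &TraceListener::BlockHostUntilDoneBegin,
//                                  &TraceListener::BlockHostUntilDoneComplete,
//                                  &result, stream);
// so the Begin callback fires at construction, and the Complete callback,
// carrying the final result, fires when the tracer leaves scope.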

/* static */ mutex StreamExecutor::static_mu_{LINKER_INITIALIZED};

StreamExecutor::StreamExecutor(PlatformKind platform_kind,
                               const PluginConfig &plugin_config)
    : platform_(nullptr),
      implementation_(StreamExecutorImplementationFromPlatformKind(
          platform_kind, plugin_config)),
      platform_kind_(platform_kind),
      device_ordinal_(-1),
      background_threads_(new port::ThreadPool(
          port::Env::Default(), "stream_executor", kNumBackgroundThreads)),
      live_stream_count_(0),
      tracing_enabled_(false) {
  CheckPlatformKindIsValid(platform_kind);
}
// Get the per-device memory limit in bytes. Returns 0 if the
// TF_PER_DEVICE_MEMORY_LIMIT_MB environment variable is not set.
static int64 GetMemoryLimitBytes() {
  int64 value;
  SE_CHECK_OK(tensorflow::ReadInt64FromEnvVar("TF_PER_DEVICE_MEMORY_LIMIT_MB",
                                              0, &value));
  return value * (1ll << 20);
}
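
// For example (illustrative): setting TF_PER_DEVICE_MEMORY_LIMIT_MB=4096 caps
// each device at 4096 * 2^20 bytes (4 GiB); Allocate() below then refuses
// requests once tracked usage would exceed that limit.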

StreamExecutor::StreamExecutor(
    const Platform *platform,
    std::unique_ptr<internal::StreamExecutorInterface> implementation)
    : platform_(platform),
      implementation_(std::move(implementation)),
      device_ordinal_(-1),
      background_threads_(new port::ThreadPool(
          port::Env::Default(), "stream_executor", kNumBackgroundThreads)),
      live_stream_count_(0),
      tracing_enabled_(false),
      mem_alloc_bytes_(0),
      memory_limit_bytes_(GetMemoryLimitBytes()) {
  if (port::Lowercase(platform_->Name()) == "cuda") {
    platform_kind_ = PlatformKind::kCuda;
  } else if (port::Lowercase(platform_->Name()) == "rocm") {
    platform_kind_ = PlatformKind::kROCm;
  } else if (port::Lowercase(platform_->Name()) == "opencl") {
    platform_kind_ = PlatformKind::kOpenCL;
  } else if (port::Lowercase(platform_->Name()) == "host") {
    platform_kind_ = PlatformKind::kHost;
  } else {
    platform_kind_ = PlatformKind::kInvalid;
  }
}

StreamExecutor::~StreamExecutor() {
  BlockOnThreadExecutor(background_threads_.get());

  if (live_stream_count_.load() != 0) {
    LOG(WARNING) << "Not all streams were deallocated at executor destruction "
                 << "time. This may lead to unexpected/bad behavior - "
                 << "especially if any stream is still active!";
  }

  if (FLAGS_check_device_leaks) {
    for (auto it : mem_allocs_) {
      LOG(INFO) << "Memory allocated at executor exit: addr: "
                << port::Printf("%p", it.first)
                << ", bytes: " << it.second.bytes << ", trace: \n"
                << it.second.stack_trace;
    }
  }
}

port::Status StreamExecutor::Init(int device_ordinal,
                                  DeviceOptions device_options) {
  device_ordinal_ = device_ordinal;
  return implementation_->Init(device_ordinal, std::move(device_options));
}

port::Status StreamExecutor::Init() {
  return Init(0, DeviceOptions::Default());
}

bool StreamExecutor::GetKernel(const MultiKernelLoaderSpec &spec,
                               KernelBase *kernel) {
  return implementation_->GetKernel(spec, kernel);
}

void StreamExecutor::UnloadKernel(const KernelBase *kernel) {
  implementation_->UnloadKernel(kernel);
}

bool StreamExecutor::LoadModule(const MultiModuleLoaderSpec &spec,
                                ModuleHandle *module_handle) {
  return implementation_->LoadModule(spec, module_handle);
}

bool StreamExecutor::UnloadModule(ModuleHandle module_handle) {
  return implementation_->UnloadModule(module_handle);
}

void StreamExecutor::Deallocate(DeviceMemoryBase *mem) {
  VLOG(1) << "Called StreamExecutor::Deallocate(mem=" << mem->opaque()
          << ") mem->size()=" << mem->size() << StackTraceIfVLOG10();

  if (mem->opaque() != nullptr) {
    EraseAllocRecord(mem->opaque());
  }
  implementation_->Deallocate(mem);
  mem->Reset(nullptr, 0);
}

void StreamExecutor::GetMemAllocs(std::map<void *, AllocRecord> *records_out) {
  tf_shared_lock lock(mu_);
  *records_out = mem_allocs_;
}

bool StreamExecutor::CanEnablePeerAccessTo(StreamExecutor *other) {
  return implementation_->CanEnablePeerAccessTo(other->implementation_.get());
}

port::Status StreamExecutor::EnablePeerAccessTo(StreamExecutor *other) {
  return implementation_->EnablePeerAccessTo(other->implementation_.get());
}

SharedMemoryConfig StreamExecutor::GetDeviceSharedMemoryConfig() {
  return implementation_->GetDeviceSharedMemoryConfig();
}

port::Status StreamExecutor::SetDeviceSharedMemoryConfig(
    SharedMemoryConfig config) {
  if (config != SharedMemoryConfig::kDefault &&
      config != SharedMemoryConfig::kFourByte &&
      config != SharedMemoryConfig::kEightByte) {
    string error_msg = port::Printf(
        "Invalid shared memory config specified: %d", static_cast<int>(config));
    LOG(ERROR) << error_msg;
    return port::Status(port::error::INVALID_ARGUMENT, error_msg);
  }
  return implementation_->SetDeviceSharedMemoryConfig(config);
}

const DeviceDescription &StreamExecutor::GetDeviceDescription() const {
  mutex_lock lock(mu_);
  if (device_description_ != nullptr) {
    return *device_description_;
  }

  device_description_.reset(PopulateDeviceDescription());
  return *device_description_;
}

int64 StreamExecutor::GetDeviceLoad() const {
  return implementation_->GetDeviceLoad();
}

int StreamExecutor::PlatformDeviceCount() const {
  return implementation_->PlatformDeviceCount();
}

bool StreamExecutor::SupportsBlas() const {
  return implementation_->SupportsBlas();
}

bool StreamExecutor::SupportsRng() const {
  return implementation_->SupportsRng();
}

bool StreamExecutor::SupportsDnn() const {
  return implementation_->SupportsDnn();
}

bool StreamExecutor::GetConvolveAlgorithms(
    bool with_winograd_nonfused,
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  int cc_major, cc_minor;
  GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor);
  return dnn_support->GetConvolveAlgorithms(with_winograd_nonfused, cc_major,
                                            cc_minor, out_algorithms);
}

bool StreamExecutor::GetRnnAlgorithms(
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  return dnn_support->GetRnnAlgorithms(out_algorithms);
}

bool StreamExecutor::GetConvolveBackwardDataAlgorithms(
    bool with_winograd_nonfused,
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  int cc_major, cc_minor;
  GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor);
  return dnn_support->GetConvolveBackwardDataAlgorithms(
      with_winograd_nonfused, cc_major, cc_minor, out_algorithms);
}

bool StreamExecutor::GetConvolveBackwardFilterAlgorithms(
    bool with_winograd_nonfused,
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  int cc_major, cc_minor;
  GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor);
  return dnn_support->GetConvolveBackwardFilterAlgorithms(
      with_winograd_nonfused, cc_major, cc_minor, out_algorithms);
}

bool StreamExecutor::GetBlasGemmAlgorithms(
    std::vector<blas::AlgorithmType> *out_algorithms) {
  blas::BlasSupport *blas_support = AsBlas();
  if (!blas_support) {
    return false;
  }
  return blas_support->GetBlasGemmAlgorithms(out_algorithms);
}

port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
StreamExecutor::createRnnDescriptor(
    int num_layers, int hidden_size, int input_size, int batch_size,
    dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
    dnn::RnnMode rnn_mode, dnn::DataType data_type,
    const dnn::AlgorithmConfig &algorithm_config, float dropout, uint64 seed,
    ScratchAllocator *state_allocator) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Failed to find the DNN implementation.");
  }
  return dnn_support->createRnnDescriptor(
      num_layers, hidden_size, input_size, batch_size, input_mode,
      direction_mode, rnn_mode, data_type, algorithm_config, dropout, seed,
      state_allocator);
}

port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
StreamExecutor::createRnnSequenceTensorDescriptor(int max_seq_length,
                                                  int batch_size, int data_size,
                                                  dnn::DataType data_type) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Failed to find the DNN implementation.");
  }
  return dnn_support->createRnnSequenceTensorDescriptor(
      max_seq_length, batch_size, data_size, data_type);
}

port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
StreamExecutor::createRnnSequenceTensorDescriptor(
    int max_seq_length, int batch_size, int data_size,
    const absl::Span<const int> &seq_lengths, bool time_major,
    dnn::DataType data_type) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Failed to find the DNN implementation.");
  }
  return dnn_support->createRnnSequenceTensorDescriptor(
      max_seq_length, batch_size, data_size, seq_lengths, time_major,
      data_type);
}

port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
StreamExecutor::createRnnStateTensorDescriptor(int num_layer, int batch_size,
                                               int data_size,
                                               dnn::DataType data_type) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Failed to find the DNN implementation.");
  }
  return dnn_support->createRnnStateTensorDescriptor(num_layer, batch_size,
                                                     data_size, data_type);
}

dnn::DnnSupport *StreamExecutor::AsDnn() {
  mutex_lock lock(mu_);
  if (dnn_ != nullptr) {
    return dnn_.get();
  }

  dnn_.reset(implementation_->CreateDnn());
  return dnn_.get();
}

blas::BlasSupport *StreamExecutor::AsBlas() {
  mutex_lock lock(mu_);
  if (blas_ != nullptr) {
    return blas_.get();
  }

  blas_.reset(implementation_->CreateBlas());
  return blas_.get();
}

fft::FftSupport *StreamExecutor::AsFft() {
  mutex_lock lock(mu_);
  if (fft_ != nullptr) {
    return fft_.get();
  }

  fft_.reset(implementation_->CreateFft());
  return fft_.get();
}

rng::RngSupport *StreamExecutor::AsRng() {
  mutex_lock lock(mu_);
  if (rng_ != nullptr) {
    return rng_.get();
  }

  rng_.reset(implementation_->CreateRng());
  return rng_.get();
}

bool StreamExecutor::Launch(Stream *stream, const ThreadDim &thread_dims,
                            const BlockDim &block_dims,
                            const KernelBase &kernel,
                            const KernelArgsArrayBase &args) {
  SubmitTrace(&TraceListener::LaunchSubmit, stream, thread_dims, block_dims,
              kernel, args);

  return implementation_->Launch(stream, thread_dims, block_dims, kernel, args);
}

port::Status StreamExecutor::BlockHostUntilDone(Stream *stream) {
  port::Status result;
  SCOPED_TRACE(TraceListener::BlockHostUntilDone, &result, stream);

  result = implementation_->BlockHostUntilDone(stream);
  return result;
}

port::Status StreamExecutor::GetStatus(Stream *stream) {
  return implementation_->GetStatus(stream);
}

void *StreamExecutor::Allocate(uint64 size) {
  if (memory_limit_bytes_ > 0 &&
      mem_alloc_bytes_ + size > memory_limit_bytes_) {
    LOG(WARNING) << "Not enough memory to allocate " << size << " on device "
                 << device_ordinal_
                 << " within provided limit. [used=" << mem_alloc_bytes_
                 << ", limit=" << memory_limit_bytes_ << "]";
    return nullptr;
  }
  void *buf = implementation_->Allocate(size);
  VLOG(1) << "Called StreamExecutor::Allocate(size=" << size << ") returns "
          << buf << StackTraceIfVLOG10();
  CreateAllocRecord(buf, size);

  return buf;
}

port::StatusOr<DeviceMemoryBase> StreamExecutor::GetUntypedSymbol(
    const string &symbol_name, ModuleHandle module_handle) {
  // If the symbol lookup fails, opaque/bytes are left unchanged. Initialize
  // them to nullptr/0 for consistency with DeviceMemory semantics.
  void *opaque = nullptr;
  size_t bytes = 0;
  if (GetSymbol(symbol_name, module_handle, &opaque, &bytes)) {
    return DeviceMemoryBase(opaque, bytes);
  }

  if (static_cast<bool>(module_handle)) {
    return port::Status(
        port::error::NOT_FOUND,
        absl::StrCat("Check if module containing symbol ", symbol_name,
                     " is loaded (module_handle = ",
                     reinterpret_cast<uintptr_t>(module_handle.id()), ")"));
  } else {
    return port::Status(
        port::error::NOT_FOUND,
        absl::StrCat("Check if kernel using the symbol is loaded: ",
                     symbol_name));
  }
}
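
// Illustrative use (hypothetical symbol name and module handle):
//   port::StatusOr<DeviceMemoryBase> sym =
//       executor->GetUntypedSymbol("my_global", module_handle);
//   if (sym.ok()) { /* sym.ValueOrDie() refers to the device global */ }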

bool StreamExecutor::GetSymbol(const string &symbol_name,
                               ModuleHandle module_handle, void **mem,
                               size_t *bytes) {
  return implementation_->GetSymbol(symbol_name, module_handle, mem, bytes);
}

void *StreamExecutor::UnifiedMemoryAllocate(uint64 bytes) {
  void *buffer = implementation_->UnifiedMemoryAllocate(bytes);
  VLOG(1) << "Called StreamExecutor::UnifiedMemoryAllocate(size=" << bytes
          << ") returns " << buffer << StackTraceIfVLOG10();
  return buffer;
}

void StreamExecutor::UnifiedMemoryDeallocate(void *location) {
  VLOG(1) << "Called StreamExecutor::UnifiedMemoryDeallocate(location="
          << location << ")" << StackTraceIfVLOG10();

  return implementation_->UnifiedMemoryDeallocate(location);
}

void *StreamExecutor::HostMemoryAllocate(uint64 size) {
  void *buffer = implementation_->HostMemoryAllocate(size);
  VLOG(1) << "Called StreamExecutor::HostMemoryAllocate(size=" << size
          << ") returns " << buffer << StackTraceIfVLOG10();
  return buffer;
}

void StreamExecutor::HostMemoryDeallocate(void *location) {
  VLOG(1) << "Called StreamExecutor::HostMemoryDeallocate(location=" << location
          << ")" << StackTraceIfVLOG10();

  return implementation_->HostMemoryDeallocate(location);
}

bool StreamExecutor::HostMemoryRegister(void *location, uint64 size) {
  VLOG(1) << "Called StreamExecutor::HostMemoryRegister(location=" << location
          << ", size=" << size << ")" << StackTraceIfVLOG10();
  if (location == nullptr || size == 0) {
    LOG(WARNING) << "attempting to register null or zero-sized memory: "
                 << location << "; size " << size;
  }
  return implementation_->HostMemoryRegister(location, size);
}

bool StreamExecutor::HostMemoryUnregister(void *location) {
  VLOG(1) << "Called StreamExecutor::HostMemoryUnregister(location=" << location
          << ")" << StackTraceIfVLOG10();
  return implementation_->HostMemoryUnregister(location);
}

bool StreamExecutor::SynchronizeAllActivity() {
  VLOG(1) << "Called StreamExecutor::SynchronizeAllActivity()"
          << StackTraceIfVLOG10();
  bool ok = implementation_->SynchronizeAllActivity();

  // This should all be quick and infallible work, so we can perform the
  // synchronization even in the case of failure.
  BlockOnThreadExecutor(background_threads_.get());

  return ok;
}

bool StreamExecutor::SynchronousMemZero(DeviceMemoryBase *location,
                                        uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemZero(location=" << location
          << ", size=" << size << ")" << StackTraceIfVLOG10();

  return implementation_->SynchronousMemZero(location, size);
}

bool StreamExecutor::SynchronousMemSet(DeviceMemoryBase *location, int value,
                                       uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemSet(location=" << location
          << ", value=" << value << ", size=" << size << ")"
          << StackTraceIfVLOG10();

  return implementation_->SynchronousMemSet(location, value, size);
}

bool StreamExecutor::SynchronousMemcpy(DeviceMemoryBase *device_dst,
                                       const void *host_src, uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(device_dst="
          << device_dst->opaque() << ", host_src=" << host_src
          << ", size=" << size << ") H2D" << StackTraceIfVLOG10();

  // Tracing overloaded methods is very difficult due to issues with type
  // inference on template args. Since use of these overloaded methods is
  // discouraged anyway, this isn't a huge deal.
  port::Status status =
      implementation_->SynchronousMemcpy(device_dst, host_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}

bool StreamExecutor::SynchronousMemcpy(void *host_dst,
                                       const DeviceMemoryBase &device_src,
                                       uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(host_dst=" << host_dst
          << ", device_src=" << device_src.opaque() << ", size=" << size
          << ") D2H" << StackTraceIfVLOG10();

  port::Status status =
      implementation_->SynchronousMemcpy(host_dst, device_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}

bool StreamExecutor::SynchronousMemcpy(DeviceMemoryBase *device_dst,
                                       const DeviceMemoryBase &device_src,
                                       uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(device_dst="
          << device_dst->opaque() << ", device_src=" << device_src.opaque()
          << ", size=" << size << ") D2D" << StackTraceIfVLOG10();

  port::Status status = implementation_->SynchronousMemcpyDeviceToDevice(
      device_dst, device_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}

port::Status StreamExecutor::SynchronousMemcpyD2H(
    const DeviceMemoryBase &device_src, int64 size, void *host_dst) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpyD2H(device_src="
          << device_src.opaque() << ", size=" << size
          << ", host_dst=" << host_dst << ")" << StackTraceIfVLOG10();

  port::Status result;
  SCOPED_TRACE(TraceListener::SynchronousMemcpyD2H, &result, device_src, size,
               host_dst);

  result = implementation_->SynchronousMemcpy(host_dst, device_src, size);
  if (!result.ok()) {
    result = port::Status(port::error::INTERNAL,
                          port::Printf("failed to synchronously memcpy "
                                       "device-to-host: device %p to host %p "
                                       "size %lld: %s",
                                       device_src.opaque(), host_dst, size,
                                       result.ToString().c_str()));
  }

  return result;
}

port::Status StreamExecutor::SynchronousMemcpyH2D(
    const void *host_src, int64 size, DeviceMemoryBase *device_dst) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpyH2D(host_src=" << host_src
          << ", size=" << size << ", device_dst=" << device_dst->opaque() << ")"
          << StackTraceIfVLOG10();

  port::Status result;
  SCOPED_TRACE(TraceListener::SynchronousMemcpyH2D, &result, host_src, size,
               device_dst);

  result = implementation_->SynchronousMemcpy(device_dst, host_src, size);
  if (!result.ok()) {
    result = port::Status(
        port::error::INTERNAL,
        port::Printf("failed to synchronously memcpy host-to-device: host "
                     "%p to device %p size %lld: %s",
                     host_src, device_dst->opaque(), size,
                     result.ToString().c_str()));
  }

  return result;
}

bool StreamExecutor::Memcpy(Stream *stream, void *host_dst,
                            const DeviceMemoryBase &device_src, uint64 size) {
  return implementation_->Memcpy(stream, host_dst, device_src, size);
}

bool StreamExecutor::Memcpy(Stream *stream, DeviceMemoryBase *device_dst,
                            const void *host_src, uint64 size) {
  return implementation_->Memcpy(stream, device_dst, host_src, size);
}

bool StreamExecutor::MemcpyDeviceToDevice(Stream *stream,
                                          DeviceMemoryBase *device_dst,
                                          const DeviceMemoryBase &device_src,
                                          uint64 size) {
  return implementation_->MemcpyDeviceToDevice(stream, device_dst, device_src,
                                               size);
}

bool StreamExecutor::MemZero(Stream *stream, DeviceMemoryBase *location,
                             uint64 size) {
  return implementation_->MemZero(stream, location, size);
}

bool StreamExecutor::Memset32(Stream *stream, DeviceMemoryBase *location,
                              uint32 pattern, uint64 size) {
  CHECK_EQ(0, size % 4)
      << "need 32-bit multiple size to fill with 32-bit pattern";
  return implementation_->Memset32(stream, location, pattern, size);
}
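
// Illustrative call (hypothetical stream and buffer): zero-fills 256 bytes as
// 64 32-bit words; size must be a multiple of 4 or the CHECK above fires:
//   executor->Memset32(&stream, &device_buffer, 0u, 256);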

bool StreamExecutor::HostCallback(Stream *stream,
                                  std::function<void()> callback) {
  return implementation_->HostCallback(stream, std::move(callback));
}

bool StreamExecutor::HostCallback(Stream *stream,
                                  std::function<port::Status()> callback) {
  return implementation_->HostCallback(stream, std::move(callback));
}

port::Status StreamExecutor::AllocateEvent(Event *event) {
  return implementation_->AllocateEvent(event);
}

port::Status StreamExecutor::DeallocateEvent(Event *event) {
  return implementation_->DeallocateEvent(event);
}

port::Status StreamExecutor::RecordEvent(Stream *stream, Event *event) {
  return implementation_->RecordEvent(stream, event);
}

port::Status StreamExecutor::WaitForEvent(Stream *stream, Event *event) {
  return implementation_->WaitForEvent(stream, event);
}

Event::Status StreamExecutor::PollForEventStatus(Event *event) {
  return implementation_->PollForEventStatus(event);
}

bool StreamExecutor::AllocateStream(Stream *stream) {
  live_stream_count_.fetch_add(1, std::memory_order_relaxed);
  if (!implementation_->AllocateStream(stream)) {
    auto count = live_stream_count_.fetch_sub(1);
    CHECK_GE(count, 0) << "live stream count should not dip below zero";
    LOG(INFO) << "failed to allocate stream; live stream count: " << count;
    return false;
  }

  return true;
}

void StreamExecutor::DeallocateStream(Stream *stream) {
  implementation_->DeallocateStream(stream);
  CHECK_GE(live_stream_count_.fetch_sub(1), 0)
      << "live stream count should not dip below zero";
}

bool StreamExecutor::CreateStreamDependency(Stream *dependent, Stream *other) {
  return implementation_->CreateStreamDependency(dependent, other);
}

bool StreamExecutor::AllocateTimer(Timer *timer) {
  return implementation_->AllocateTimer(timer);
}

void StreamExecutor::DeallocateTimer(Timer *timer) {
  return implementation_->DeallocateTimer(timer);
}

bool StreamExecutor::StartTimer(Stream *stream, Timer *timer) {
  return implementation_->StartTimer(stream, timer);
}

bool StreamExecutor::StopTimer(Stream *stream, Timer *timer) {
  return implementation_->StopTimer(stream, timer);
}

DeviceDescription *StreamExecutor::PopulateDeviceDescription() const {
  return implementation_->PopulateDeviceDescription();
}

bool StreamExecutor::DeviceMemoryUsage(int64 *free, int64 *total) const {
  return implementation_->DeviceMemoryUsage(free, total);
}

void StreamExecutor::EnqueueOnBackgroundThread(std::function<void()> task) {
  background_threads_->Schedule(std::move(task));
}

void StreamExecutor::CreateAllocRecord(void *opaque, uint64 bytes) {
  if (FLAGS_check_device_leaks && opaque != nullptr && bytes != 0) {
    mutex_lock lock(mu_);
    mem_allocs_[opaque] = AllocRecord{bytes, ""};
    mem_alloc_bytes_ += bytes;
  }
}

void StreamExecutor::EraseAllocRecord(void *opaque) {
  if (FLAGS_check_device_leaks && opaque != nullptr) {
    mutex_lock lock(mu_);
    if (mem_allocs_.find(opaque) == mem_allocs_.end()) {
      LOG(ERROR) << "Deallocating unknown pointer: "
                 << port::Printf("%p", opaque);
    } else {
      mem_alloc_bytes_ -= mem_allocs_[opaque].bytes;
      mem_allocs_.erase(opaque);
    }
  }
}

void StreamExecutor::EnableTracing(bool enabled) { tracing_enabled_ = enabled; }

void StreamExecutor::RegisterTraceListener(TraceListener *listener) {
  {
    mutex_lock lock(mu_);
    if (listeners_.find(listener) != listeners_.end()) {
      LOG(INFO) << "Attempt to register already-registered listener, "
                << listener;
    } else {
      listeners_.insert(listener);
    }
  }

  implementation_->RegisterTraceListener(listener);
}

bool StreamExecutor::UnregisterTraceListener(TraceListener *listener) {
  {
    mutex_lock lock(mu_);
    if (listeners_.find(listener) == listeners_.end()) {
      LOG(INFO) << "Attempt to unregister unknown listener, " << listener;
      return false;
    }
    listeners_.erase(listener);
  }

  implementation_->UnregisterTraceListener(listener);
  return true;
}
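
// Illustrative tracing setup (MyListener is a hypothetical TraceListener
// subclass overriding callbacks such as LaunchSubmit):
//   MyListener listener;
//   executor->RegisterTraceListener(&listener);
//   executor->EnableTracing(true);
//   ...  // launched kernels and memcpys now invoke the listener's callbacks
//   executor->EnableTracing(false);
//   executor->UnregisterTraceListener(&listener);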

absl::optional<AllocatorStats> StreamExecutor::GetAllocatorStats() {
  return implementation_->GetAllocatorStats();
}

template <typename TraceCallT, typename... ArgsT>
void StreamExecutor::SubmitTrace(TraceCallT trace_call, ArgsT &&... args) {
  if (tracing_enabled_) {
    {
      // Instance tracers held in a block to limit the lock lifetime.
      tf_shared_lock lock(mu_);
      for (TraceListener *listener : listeners_) {
        (listener->*trace_call)(std::forward<ArgsT>(args)...);
      }
    }
  }
}

internal::StreamExecutorInterface *StreamExecutor::implementation() {
  return implementation_->GetUnderlyingExecutor();
}

}  // namespace stream_executor