Home | History | Annotate | Download | only in stream_executor
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // Implements the StreamExecutor interface by passing through to its
     17 // implementation_ value (in pointer-to-implementation style), which
     18 // implements StreamExecutorInterface.
     19 
     20 #include "tensorflow/stream_executor/stream_executor_pimpl.h"
     21 
     22 #include <atomic>
     23 #include <utility>
     24 
     25 #include "tensorflow/stream_executor/blas.h"
     26 #include "tensorflow/stream_executor/fft.h"
     27 #include "tensorflow/stream_executor/lib/env.h"
     28 #include "tensorflow/stream_executor/lib/error.h"
     29 #include "tensorflow/stream_executor/lib/notification.h"
     30 #include "tensorflow/stream_executor/lib/stacktrace.h"
     31 #include "tensorflow/stream_executor/lib/str_util.h"
     32 #include "tensorflow/stream_executor/lib/stringprintf.h"
     33 #include "tensorflow/stream_executor/lib/threadpool.h"
     34 #include "tensorflow/stream_executor/platform/port.h"
     35 #include "tensorflow/stream_executor/rng.h"
     36 #include "tensorflow/stream_executor/stream_executor_internal.h"
     37 
namespace {
// When true, every device allocation is recorded (with its call-site stack
// trace) and any memory still live at executor destruction is logged.
bool FLAGS_check_device_leaks = false;
}  // namespace
     41 
     42 namespace perftools {
     43 namespace gputools {
     44 namespace {
     45 
     46 string StackTraceIfVLOG10() {
     47   if (VLOG_IS_ON(10)) {
     48     return port::StrCat(" ", port::CurrentStackTrace(), "\n");
     49   } else {
     50     return "";
     51   }
     52 }
     53 
     54 // Make sure the executor is done with its work; we know (because this isn't
     55 // publicly visible) that all enqueued work is quick.
     56 void BlockOnThreadExecutor(port::ThreadPool *executor) {
     57   port::Notification n;
     58   executor->Schedule([&n]() { n.Notify(); });
     59   n.WaitForNotification();
     60 }
     61 
// Creates the platform-specific StreamExecutorInterface for |platform_kind|,
// configured with |plugin_config|. LOG(FATAL)s (does not return) if no
// factory is registered for the requested platform.
internal::StreamExecutorInterface *StreamExecutorImplementationFromPlatformKind(
    PlatformKind platform_kind, const PluginConfig &plugin_config) {
  // Note: we use this factory-assignment-in-switch pattern instead of just
  // invoking the callable in case linkage is messed up -- instead of invoking a
  // nullptr std::function (due to failed registration) we give a nice
  // LOG(FATAL) message.
  internal::StreamExecutorFactory factory;
  switch (platform_kind) {
    case PlatformKind::kCuda:
      factory = *internal::MakeCUDAExecutorImplementation();
      break;
    case PlatformKind::kOpenCL:
      factory = *internal::MakeOpenCLExecutorImplementation();
      break;
    case PlatformKind::kHost:
      factory = internal::MakeHostExecutorImplementation;
      break;
    default:
      factory = nullptr;
  }
  if (factory == nullptr) {
    LOG(FATAL)
        << "cannot create StreamExecutor implementation for platform kind: "
        << PlatformKindString(platform_kind);
  }
  return factory(plugin_config);
}
     89 
// Process-wide source of trace correlation ids (see ScopedTracer below).
std::atomic_int_fast64_t correlation_id_generator(0);
     91 
     92 }  // namespace
     93 
// RAII tracer: on construction fires the "Begin" callback on every registered
// TraceListener (when tracing is enabled on the executor), and on destruction
// fires the matching "Complete" callback with the operation's result pointer.
// Begin/Complete pairs share a correlation id so listeners can match them.
template <typename BeginCallT, typename CompleteCallT,
          typename ReturnT, typename... BeginArgsT>
class ScopedTracer {
 public:
  ScopedTracer(StreamExecutor *stream_exec, BeginCallT begin_call,
               CompleteCallT complete_call, const ReturnT *result,
               BeginArgsT... begin_args)
      : stream_exec_(stream_exec),
        complete_call_(complete_call),
        result_(result) {
    if (stream_exec_->tracing_enabled_) {
      // NOTE(review): fetch_add returns the pre-increment value, so the
      // trailing "- 1" makes the first correlation id -1 — confirm intended.
      correlation_id_ =
          correlation_id_generator.fetch_add(1, std::memory_order_relaxed) - 1;
      Trace(begin_call, begin_args...);
    }
  }

  ~ScopedTracer() {
    if (stream_exec_->tracing_enabled_) {
      Trace(complete_call_, result_);
    }
  }

 private:
  // Invokes |callback| on each registered listener, passing the shared
  // correlation id first.
  template <typename CallbackT, typename... TraceArgsT>
  void Trace(CallbackT callback, TraceArgsT... args) {
    {
      // Instance tracers held in a block to limit the lock lifetime.
      tf_shared_lock lock{stream_exec_->mu_};
      for (TraceListener *listener : stream_exec_->listeners_) {
        (listener->*callback)(correlation_id_,
                              std::forward<TraceArgsT>(args)...);
      }
    }
  }

  StreamExecutor *stream_exec_;  // Not owned.
  CompleteCallT complete_call_;
  const ReturnT* result_;
  int64 correlation_id_;
};
    135 
// Factory helper that deduces ScopedTracer's template arguments from the
// callback/result/begin-argument types (pre-C++17 CTAD workaround).
template <typename BeginCallT, typename CompleteCallT, typename ReturnT,
          typename... BeginArgsT>
ScopedTracer<BeginCallT, CompleteCallT, ReturnT, BeginArgsT...>
MakeScopedTracer(StreamExecutor *stream_exec, BeginCallT begin_call,
                 CompleteCallT complete_call, ReturnT *result,
                 BeginArgsT... begin_args) {
  return ScopedTracer<BeginCallT, CompleteCallT, ReturnT, BeginArgsT...>(
      stream_exec, begin_call, complete_call, result,
      std::forward<BeginArgsT>(begin_args)...);
}
    146 
// Declares a local tracer that fires TraceListener::LOC##Begin immediately
// and LOC##Complete at scope exit; __VA_ARGS__ supplies the result pointer
// followed by the Begin-callback arguments.
#define SCOPED_TRACE(LOC, ...)                                      \
  auto tracer = MakeScopedTracer(this, &LOC ## Begin,               \
                                 &LOC ## Complete, ## __VA_ARGS__);
    150 
// Class-level mutex; LINKER_INITIALIZED makes it usable before dynamic
// initializers run.
/* static */ mutex StreamExecutor::static_mu_{LINKER_INITIALIZED};
    152 
// Legacy constructor: builds the implementation from a PlatformKind via the
// registered factory. platform_ stays null; device_ordinal_ is set by Init().
StreamExecutor::StreamExecutor(PlatformKind platform_kind,
                               const PluginConfig &plugin_config)
    : platform_(nullptr),
      implementation_(StreamExecutorImplementationFromPlatformKind(
          platform_kind, plugin_config)),
      platform_kind_(platform_kind),
      device_ordinal_(-1),
      background_threads_(new port::ThreadPool(
          port::Env::Default(), "stream_executor", kNumBackgroundThreads)),
      live_stream_count_(0),
      tracing_enabled_(false) {
  CheckPlatformKindIsValid(platform_kind);
}
    166 
    167 StreamExecutor::StreamExecutor(
    168     const Platform *platform,
    169     std::unique_ptr<internal::StreamExecutorInterface> implementation)
    170     : platform_(platform),
    171       implementation_(std::move(implementation)),
    172       device_ordinal_(-1),
    173       background_threads_(new port::ThreadPool(
    174           port::Env::Default(), "stream_executor", kNumBackgroundThreads)),
    175       live_stream_count_(0),
    176       tracing_enabled_(false) {
    177   if (port::Lowercase(platform_->Name()) == "cuda") {
    178     platform_kind_ = PlatformKind::kCuda;
    179   } else if (port::Lowercase(platform_->Name()) == "opencl") {
    180     platform_kind_ = PlatformKind::kOpenCL;
    181   } else if (port::Lowercase(platform_->Name()) == "host") {
    182     platform_kind_ = PlatformKind::kHost;
    183   }
    184 }
    185 
// Drains background work, warns about still-live streams, and (when leak
// checking is enabled) reports any device memory never deallocated.
StreamExecutor::~StreamExecutor() {
  BlockOnThreadExecutor(background_threads_.get());

  if (live_stream_count_.load() != 0) {
    LOG(WARNING) << "Not all streams were deallocated at executor destruction "
                 << "time. This may lead to unexpected/bad behavior - "
                 << "especially if any stream is still active!";
  }

  if (FLAGS_check_device_leaks) {
    // NOTE(review): mem_allocs_ is read without holding mu_ here; assumes no
    // concurrent allocation during destruction — confirm callers guarantee it.
    for (auto it : mem_allocs_) {
      LOG(INFO) << "Memory alloced at executor exit: addr: "
                << port::Printf("%p", it.first)
                << ", bytes: " << it.second.bytes << ", trace: \n"
                << it.second.stack_trace;
    }
  }
}
    204 
// Initializes the underlying device implementation for |device_ordinal|.
port::Status StreamExecutor::Init(int device_ordinal,
                                  DeviceOptions device_options) {
  device_ordinal_ = device_ordinal;
  return implementation_->Init(device_ordinal, std::move(device_options));
}
    210 
// Convenience overload: initializes device 0 with default options.
port::Status StreamExecutor::Init() {
  return Init(0, DeviceOptions::Default());
}
    214 
// Loads the kernel described by |spec| into |kernel|; forwards to the
// platform implementation.
bool StreamExecutor::GetKernel(const MultiKernelLoaderSpec &spec,
                               KernelBase *kernel) {
  return implementation_->GetKernel(spec, kernel);
}
    219 
// Unloads |kernel| from the device; forwards to the platform implementation.
void StreamExecutor::UnloadKernel(const KernelBase *kernel) {
  implementation_->UnloadKernel(kernel);
}
    223 
// Frees device memory, drops its leak-check record, and resets |mem| so it
// no longer refers to device storage.
void StreamExecutor::Deallocate(DeviceMemoryBase *mem) {
  VLOG(1) << "Called StreamExecutor::Deallocate(mem=" << mem->opaque()
          << ") mem->size()=" << mem->size() << StackTraceIfVLOG10();

  if (mem->opaque() != nullptr) {
    EraseAllocRecord(mem->opaque());
  }
  implementation_->Deallocate(mem);
  mem->Reset(nullptr, 0);
}
    234 
// Copies the current allocation records (under a shared lock) into
// |records_out|.
void StreamExecutor::GetMemAllocs(std::map<void *, AllocRecord> *records_out) {
  tf_shared_lock lock{mu_};
  *records_out = mem_allocs_;
}
    239 
// Queries whether this device can enable peer access to |other|'s device.
bool StreamExecutor::CanEnablePeerAccessTo(StreamExecutor *other) {
  return implementation_->CanEnablePeerAccessTo(other->implementation_.get());
}
    243 
// Enables peer access from this device to |other|'s device.
port::Status StreamExecutor::EnablePeerAccessTo(StreamExecutor *other) {
  return implementation_->EnablePeerAccessTo(other->implementation_.get());
}
    247 
// Returns the device's current shared-memory bank configuration.
SharedMemoryConfig StreamExecutor::GetDeviceSharedMemoryConfig() {
  return implementation_->GetDeviceSharedMemoryConfig();
}
    251 
    252 port::Status StreamExecutor::SetDeviceSharedMemoryConfig(
    253     SharedMemoryConfig config) {
    254   if (config != SharedMemoryConfig::kDefault &&
    255       config != SharedMemoryConfig::kFourByte &&
    256       config != SharedMemoryConfig::kEightByte) {
    257     string error_msg = port::Printf(
    258         "Invalid shared memory config specified: %d", static_cast<int>(config));
    259     LOG(ERROR) << error_msg;
    260     return port::Status{port::error::INVALID_ARGUMENT, error_msg};
    261   }
    262   return implementation_->SetDeviceSharedMemoryConfig(config);
    263 }
    264 
// Lazily builds and caches the device description; the cache is guarded by
// mu_ so only one description is ever populated.
const DeviceDescription &StreamExecutor::GetDeviceDescription() const {
  mutex_lock lock{mu_};
  if (device_description_ != nullptr) {
    return *device_description_;
  }

  device_description_.reset(PopulateDeviceDescription());
  return *device_description_;
}
    274 
// Forwards to the platform implementation.
int64 StreamExecutor::GetDeviceLoad() const {
  return implementation_->GetDeviceLoad();
}
    278 
// Forwards to the platform implementation.
int StreamExecutor::PlatformDeviceCount() const {
  return implementation_->PlatformDeviceCount();
}
    282 
// Whether the platform provides a BLAS plugin (see AsBlas()).
bool StreamExecutor::SupportsBlas() const {
  return implementation_->SupportsBlas();
}
    286 
// Whether the platform provides an RNG plugin (see AsRng()).
bool StreamExecutor::SupportsRng() const {
  return implementation_->SupportsRng();
}
    290 
// Whether the platform provides a DNN plugin (see AsDnn()).
bool StreamExecutor::SupportsDnn() const {
  return implementation_->SupportsDnn();
}
    294 
// Fills |out_algorithms| with the forward-convolution algorithms supported
// for this device's CUDA compute capability. Returns false if no DNN plugin
// is available.
bool StreamExecutor::GetConvolveAlgorithms(
    bool with_winograd_nonfused,
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  int cc_major, cc_minor;
  GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor);
  return dnn_support->GetConvolveAlgorithms(with_winograd_nonfused, cc_major,
                                            cc_minor, out_algorithms);
}
    307 
// As GetConvolveAlgorithms, but for the backward-data convolution pass.
bool StreamExecutor::GetConvolveBackwardDataAlgorithms(
    bool with_winograd_nonfused,
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  int cc_major, cc_minor;
  GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor);
  return dnn_support->GetConvolveBackwardDataAlgorithms(
      with_winograd_nonfused, cc_major, cc_minor, out_algorithms);
}
    320 
// As GetConvolveAlgorithms, but for the backward-filter convolution pass.
bool StreamExecutor::GetConvolveBackwardFilterAlgorithms(
    bool with_winograd_nonfused,
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  int cc_major, cc_minor;
  GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor);
  return dnn_support->GetConvolveBackwardFilterAlgorithms(
      with_winograd_nonfused, cc_major, cc_minor, out_algorithms);
}
    333 
// Fills |out_algorithms| with the GEMM algorithms the BLAS plugin supports.
// Returns false if no BLAS plugin is available.
bool StreamExecutor::GetBlasGemmAlgorithms(
    std::vector<blas::AlgorithmType> *out_algorithms) {
  blas::BlasSupport *blas_support = AsBlas();
  if (!blas_support) {
    return false;
  }
  return blas_support->GetBlasGemmAlgorithms(out_algorithms);
}
    342 
// Creates an RNN descriptor via the DNN plugin; returns UNKNOWN status when
// no DNN implementation is available.
port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
StreamExecutor::createRnnDescriptor(
    int num_layers, int hidden_size, int input_size,
    dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
    dnn::RnnMode rnn_mode, dnn::DataType data_type, float dropout, uint64 seed,
    ScratchAllocator *state_allocator) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Fail to find the dnn implementation.");
  }
  return dnn_support->createRnnDescriptor(
      num_layers, hidden_size, input_size, input_mode, direction_mode, rnn_mode,
      data_type, dropout, seed, state_allocator);
}
    358 
// Creates an RNN sequence tensor descriptor via the DNN plugin; returns
// UNKNOWN status when no DNN implementation is available.
port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
StreamExecutor::createRnnSequenceTensorDescriptor(int seq_length,
                                                  int batch_size, int data_size,
                                                  dnn::DataType data_type) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Fail to find the dnn implementation.");
  }
  return dnn_support->createRnnSequenceTensorDescriptor(seq_length, batch_size,
                                                        data_size, data_type);
}
    371 
// Creates an RNN state tensor descriptor via the DNN plugin; returns UNKNOWN
// status when no DNN implementation is available.
port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
StreamExecutor::createRnnStateTensorDescriptor(int num_layer, int batch_size,
                                               int data_size,
                                               dnn::DataType data_type) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Fail to find the dnn implementation.");
  }
  return dnn_support->createRnnStateTensorDescriptor(num_layer, batch_size,
                                                     data_size, data_type);
}
    384 
    385 dnn::DnnSupport *StreamExecutor::AsDnn() {
    386   mutex_lock lock{mu_};
    387   if (dnn_ != nullptr) {
    388     return dnn_.get();
    389   }
    390 
    391   dnn_.reset(implementation_->CreateDnn());
    392   return dnn_.get();
    393 }
    394 
    395 blas::BlasSupport *StreamExecutor::AsBlas() {
    396   mutex_lock lock{mu_};
    397   if (blas_ != nullptr) {
    398     return blas_.get();
    399   }
    400 
    401   blas_.reset(implementation_->CreateBlas());
    402   return blas_.get();
    403 }
    404 
    405 fft::FftSupport *StreamExecutor::AsFft() {
    406   mutex_lock lock{mu_};
    407   if (fft_ != nullptr) {
    408     return fft_.get();
    409   }
    410 
    411   fft_.reset(implementation_->CreateFft());
    412   return fft_.get();
    413 }
    414 
    415 rng::RngSupport *StreamExecutor::AsRng() {
    416   mutex_lock lock{mu_};
    417   if (rng_ != nullptr) {
    418     return rng_.get();
    419   }
    420 
    421   rng_.reset(implementation_->CreateRng());
    422   return rng_.get();
    423 }
    424 
// Enqueues a kernel launch on |stream|, notifying trace listeners first.
bool StreamExecutor::Launch(Stream *stream, const ThreadDim &thread_dims,
                            const BlockDim &block_dims,
                            const KernelBase &kernel,
                            const KernelArgsArrayBase &args) {
  SubmitTrace(&TraceListener::LaunchSubmit, stream, thread_dims, block_dims,
              kernel, args);

  return implementation_->Launch(stream, thread_dims, block_dims, kernel, args);
}
    434 
// Blocks the calling host thread until |stream|'s enqueued work completes;
// the call is bracketed by Begin/Complete trace callbacks.
port::Status StreamExecutor::BlockHostUntilDone(Stream *stream) {
  port::Status result;
  SCOPED_TRACE(TraceListener::BlockHostUntilDone, &result, stream);

  result = implementation_->BlockHostUntilDone(stream);
  return result;
}
    442 
// Allocates |size| bytes of device memory and records the allocation for
// leak checking. Returns the opaque device pointer (may be null on failure).
void *StreamExecutor::Allocate(uint64 size) {
  void *buf = implementation_->Allocate(size);
  VLOG(1) << "Called StreamExecutor::Allocate(size=" << size << ") returns "
          << buf << StackTraceIfVLOG10();
  CreateAllocRecord(buf, size);

  return buf;
}
    451 
// Looks up a device symbol by name, returning its address and size.
bool StreamExecutor::GetSymbol(const string &symbol_name, void **mem,
                               size_t *bytes) {
  return implementation_->GetSymbol(symbol_name, mem, bytes);
}
    456 
// Allocates |size| bytes of host memory suitable for device transfers.
void *StreamExecutor::HostMemoryAllocate(uint64 size) {
  void *buffer = implementation_->HostMemoryAllocate(size);
  VLOG(1) << "Called StreamExecutor::HostMemoryAllocate(size=" << size
          << ") returns " << buffer << StackTraceIfVLOG10();
  return buffer;
}
    463 
// Frees host memory previously returned by HostMemoryAllocate.
void StreamExecutor::HostMemoryDeallocate(void *location) {
  VLOG(1) << "Called StreamExecutor::HostMemoryDeallocate(location=" << location
          << ")" << StackTraceIfVLOG10();

  return implementation_->HostMemoryDeallocate(location);
}
    470 
// Registers an existing host buffer with the device (e.g. page-locks it).
// Null or zero-sized input is warned about but still forwarded.
bool StreamExecutor::HostMemoryRegister(void *location, uint64 size) {
  VLOG(1) << "Called StreamExecutor::HostMemoryRegister(location=" << location
          << ", size=" << size << ")" << StackTraceIfVLOG10();
  if (location == nullptr || size == 0) {
    LOG(WARNING) << "attempting to register null or zero-sized memory: "
                 << location << "; size " << size;
  }
  return implementation_->HostMemoryRegister(location, size);
}
    480 
// Undoes a prior HostMemoryRegister for |location|.
bool StreamExecutor::HostMemoryUnregister(void *location) {
  VLOG(1) << "Called StreamExecutor::HostMemoryUnregister(location=" << location
          << ")" << StackTraceIfVLOG10();
  return implementation_->HostMemoryUnregister(location);
}
    486 
// Blocks until all outstanding device activity and background-thread work
// has completed. Returns the device-synchronization result.
bool StreamExecutor::SynchronizeAllActivity() {
  VLOG(1) << "Called StreamExecutor::SynchronizeAllActivity()"
          << StackTraceIfVLOG10();
  bool ok = implementation_->SynchronizeAllActivity();

  // This should all be quick and infallible work, so we can perform the
  // synchronization even in the case of failure.
  BlockOnThreadExecutor(background_threads_.get());

  return ok;
}
    498 
// Synchronously zero-fills |size| bytes of device memory at |location|.
bool StreamExecutor::SynchronousMemZero(DeviceMemoryBase *location,
                                        uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemZero(location=" << location
          << ", size=" << size << ")" << StackTraceIfVLOG10();

  return implementation_->SynchronousMemZero(location, size);
}
    506 
// Synchronously fills |size| bytes of device memory with |value|.
bool StreamExecutor::SynchronousMemSet(DeviceMemoryBase *location, int value,
                                       uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemSet(location=" << location
          << ", value=" << value << ", size=" << size << ")"
          << StackTraceIfVLOG10();

  return implementation_->SynchronousMemSet(location, value, size);
}
    515 
// Synchronous host-to-device copy. Errors are logged and collapsed to false.
bool StreamExecutor::SynchronousMemcpy(DeviceMemoryBase *device_dst,
                                       const void *host_src, uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(device_dst="
          << device_dst->opaque() << ", host_src=" << host_src
          << ", size=" << size << ") H2D" << StackTraceIfVLOG10();

  // Tracing overloaded methods is very difficult due to issues with type
  // inference on template args. Since use of these overloaded methods is
  // discouraged anyway, this isn't a huge deal.
  port::Status status =
      implementation_->SynchronousMemcpy(device_dst, host_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}
    532 
// Synchronous device-to-host copy. Errors are logged and collapsed to false.
bool StreamExecutor::SynchronousMemcpy(void *host_dst,
                                       const DeviceMemoryBase &device_src,
                                       uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(host_dst=" << host_dst
          << ", device_src=" << device_src.opaque() << ", size=" << size
          << ") D2H" << StackTraceIfVLOG10();

  port::Status status =
      implementation_->SynchronousMemcpy(host_dst, device_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}
    547 
// Synchronous device-to-device copy. Errors are logged and collapsed to
// false.
bool StreamExecutor::SynchronousMemcpy(DeviceMemoryBase *device_dst,
                                       const DeviceMemoryBase &device_src,
                                       uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(device_dst="
          << device_dst->opaque() << ", device_src=" << device_src.opaque()
          << ", size=" << size << ") D2D" << StackTraceIfVLOG10();

  port::Status status = implementation_->SynchronousMemcpyDeviceToDevice(
      device_dst, device_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}
    562 
// Status-returning synchronous device-to-host copy; failures are wrapped in
// an INTERNAL status describing the endpoints and size. Traced via
// SCOPED_TRACE so listeners see both the call and its result.
port::Status StreamExecutor::SynchronousMemcpyD2H(
    const DeviceMemoryBase &device_src, int64 size, void *host_dst) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpyD2H(device_src="
          << device_src.opaque() << ", size=" << size
          << ", host_dst=" << host_dst << ")" << StackTraceIfVLOG10();

  port::Status result;
  SCOPED_TRACE(TraceListener::SynchronousMemcpyD2H, &result, device_src, size,
               host_dst);

  result = implementation_->SynchronousMemcpy(host_dst, device_src, size);
  if (!result.ok()) {
    result = port::Status{port::error::INTERNAL,
                          port::Printf("failed to synchronously memcpy "
                                       "device-to-host: device %p to host %p "
                                       "size %lld: %s",
                                       device_src.opaque(), host_dst, size,
                                       result.ToString().c_str())};
  }

  return result;
}
    585 
// Status-returning synchronous host-to-device copy; failures are wrapped in
// an INTERNAL status describing the endpoints and size. Traced via
// SCOPED_TRACE so listeners see both the call and its result.
port::Status StreamExecutor::SynchronousMemcpyH2D(
    const void *host_src, int64 size, DeviceMemoryBase *device_dst) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpyH2D(host_src=" << host_src
          << ", size=" << size << ", device_dst" << device_dst->opaque() << ")"
          << StackTraceIfVLOG10();

  port::Status result;
  SCOPED_TRACE(TraceListener::SynchronousMemcpyH2D, &result, host_src, size,
               device_dst);

  result = implementation_->SynchronousMemcpy(device_dst, host_src, size);
  if (!result.ok()) {
    result = port::Status{
        port::error::INTERNAL,
        port::Printf("failed to synchronously memcpy host-to-device: host "
                     "%p to device %p size %lld: %s",
                     host_src, device_dst->opaque(), size,
                     result.ToString().c_str())};
  }

  return result;
}
    608 
// Enqueues an asynchronous device-to-host copy on |stream|.
bool StreamExecutor::Memcpy(Stream *stream, void *host_dst,
                            const DeviceMemoryBase &device_src, uint64 size) {
  return implementation_->Memcpy(stream, host_dst, device_src, size);
}
    613 
// Enqueues an asynchronous host-to-device copy on |stream|.
bool StreamExecutor::Memcpy(Stream *stream, DeviceMemoryBase *device_dst,
                            const void *host_src, uint64 size) {
  return implementation_->Memcpy(stream, device_dst, host_src, size);
}
    618 
// Enqueues an asynchronous device-to-device copy on |stream|.
bool StreamExecutor::MemcpyDeviceToDevice(Stream *stream,
                                          DeviceMemoryBase *device_dst,
                                          const DeviceMemoryBase &device_src,
                                          uint64 size) {
  return implementation_->MemcpyDeviceToDevice(stream, device_dst, device_src,
                                               size);
}
    626 
// Enqueues an asynchronous zero-fill of device memory on |stream|.
bool StreamExecutor::MemZero(Stream *stream, DeviceMemoryBase *location,
                             uint64 size) {
  return implementation_->MemZero(stream, location, size);
}
    631 
// Enqueues an asynchronous 32-bit pattern fill; |size| must be a multiple
// of 4 bytes (checked fatally).
bool StreamExecutor::Memset32(Stream *stream, DeviceMemoryBase *location,
                              uint32 pattern, uint64 size) {
  CHECK_EQ(0, size % 4)
      << "need 32-bit multiple size to fill with 32-bit pattern";
  return implementation_->Memset32(stream, location, pattern, size);
}
    638 
// Enqueues |callback| to run on the host when |stream| reaches this point.
bool StreamExecutor::HostCallback(Stream *stream,
                                  std::function<void()> callback) {
  return implementation_->HostCallback(stream, std::move(callback));
}
    643 
// Forwards to the platform implementation.
port::Status StreamExecutor::AllocateEvent(Event *event) {
  return implementation_->AllocateEvent(event);
}
    647 
// Forwards to the platform implementation.
port::Status StreamExecutor::DeallocateEvent(Event *event) {
  return implementation_->DeallocateEvent(event);
}
    651 
// Records |event| on |stream|; forwards to the platform implementation.
port::Status StreamExecutor::RecordEvent(Stream *stream, Event *event) {
  return implementation_->RecordEvent(stream, event);
}
    655 
// Makes |stream| wait for |event|; forwards to the platform implementation.
port::Status StreamExecutor::WaitForEvent(Stream *stream, Event *event) {
  return implementation_->WaitForEvent(stream, event);
}
    659 
// Non-blocking query of |event|'s status.
Event::Status StreamExecutor::PollForEventStatus(Event *event) {
  return implementation_->PollForEventStatus(event);
}
    663 
// Allocates platform resources for |stream|, tracking the live-stream count.
// The count is incremented optimistically and rolled back on failure.
bool StreamExecutor::AllocateStream(Stream *stream) {
  live_stream_count_.fetch_add(1, std::memory_order_relaxed);
  if (!implementation_->AllocateStream(stream)) {
    // fetch_sub returns the pre-decrement value.
    auto count = live_stream_count_.fetch_sub(1);
    CHECK_GE(count, 0) << "live stream count should not dip below zero";
    LOG(INFO) << "failed to allocate stream; live stream count: " << count;
    return false;
  }

  return true;
}
    675 
// Releases |stream|'s platform resources and decrements the live count.
void StreamExecutor::DeallocateStream(Stream *stream) {
  implementation_->DeallocateStream(stream);
  CHECK_GE(live_stream_count_.fetch_sub(1), 0)
      << "live stream count should not dip below zero";
}
    681 
// Makes |dependent| wait until |other|'s currently-enqueued work completes.
bool StreamExecutor::CreateStreamDependency(Stream *dependent, Stream *other) {
  return implementation_->CreateStreamDependency(dependent, other);
}
    685 
// Forwards to the platform implementation.
bool StreamExecutor::AllocateTimer(Timer *timer) {
  return implementation_->AllocateTimer(timer);
}
    689 
// Forwards to the platform implementation.
void StreamExecutor::DeallocateTimer(Timer *timer) {
  return implementation_->DeallocateTimer(timer);
}
    693 
// Starts |timer| on |stream|; forwards to the platform implementation.
bool StreamExecutor::StartTimer(Stream *stream, Timer *timer) {
  return implementation_->StartTimer(stream, timer);
}
    697 
// Stops |timer| on |stream|; forwards to the platform implementation.
bool StreamExecutor::StopTimer(Stream *stream, Timer *timer) {
  return implementation_->StopTimer(stream, timer);
}
    701 
// Builds a fresh device description; ownership passes to the caller (cached
// by GetDeviceDescription above).
DeviceDescription *StreamExecutor::PopulateDeviceDescription() const {
  return implementation_->PopulateDeviceDescription();
}
    705 
// Reports free/total device memory in bytes via the out-parameters.
bool StreamExecutor::DeviceMemoryUsage(int64 *free, int64 *total) const {
  return implementation_->DeviceMemoryUsage(free, total);
}
    709 
// Schedules |task| on the executor's internal thread pool; the destructor
// drains this pool before tearing anything down.
void StreamExecutor::EnqueueOnBackgroundThread(std::function<void()> task) {
  background_threads_->Schedule(std::move(task));
}
    713 
    714 void StreamExecutor::CreateAllocRecord(void *opaque, uint64 bytes) {
    715   if (FLAGS_check_device_leaks && opaque != nullptr && bytes != 0) {
    716     mutex_lock lock{mu_};
    717     mem_allocs_[opaque] = AllocRecord{
    718         bytes, ""};
    719   }
    720 }
    721 
// Removes the leak-check record for |opaque|; logs (but otherwise ignores)
// deallocations of pointers that were never recorded.
void StreamExecutor::EraseAllocRecord(void *opaque) {
  if (FLAGS_check_device_leaks && opaque != nullptr) {
    mutex_lock lock{mu_};
    if (mem_allocs_.find(opaque) == mem_allocs_.end()) {
      LOG(ERROR) << "Deallocating unknown pointer: "
                 << port::Printf("0x%p", opaque);
    } else {
      mem_allocs_.erase(opaque);
    }
  }
}
    733 
// Toggles trace dispatch. NOTE(review): tracing_enabled_ is written without
// holding mu_; in-flight ScopedTracers may observe the change mid-operation.
void StreamExecutor::EnableTracing(bool enabled) { tracing_enabled_ = enabled; }
    735 
// Adds |listener| to the trace-dispatch set (duplicates are logged and
// ignored) and registers it with the platform implementation as well.
void StreamExecutor::RegisterTraceListener(TraceListener *listener) {
  {
    mutex_lock lock{mu_};
    if (listeners_.find(listener) != listeners_.end()) {
      LOG(INFO) << "Attempt to register already-registered listener, "
                << listener;
    } else {
      listeners_.insert(listener);
    }
  }

  implementation_->RegisterTraceListener(listener);
}
    749 
// Removes |listener| from the trace-dispatch set; returns false (without
// touching the implementation) if it was never registered.
bool StreamExecutor::UnregisterTraceListener(TraceListener *listener) {
  {
    mutex_lock lock{mu_};
    if (listeners_.find(listener) == listeners_.end()) {
      LOG(INFO) << "Attempt to unregister unknown listener, " << listener;
      return false;
    }
    listeners_.erase(listener);
  }

  implementation_->UnregisterTraceListener(listener);
  return true;
}
    763 
// Fans |trace_call| out to every registered listener. Listeners are iterated
// under a shared lock so concurrent traces do not serialize each other.
template <typename TraceCallT, typename... ArgsT>
void StreamExecutor::SubmitTrace(TraceCallT trace_call, ArgsT &&... args) {
  if (tracing_enabled_) {
    {
      // instance tracers held in a block to limit the lock lifetime.
      tf_shared_lock lock{mu_};
      for (TraceListener *listener : listeners_) {
        (listener->*trace_call)(std::forward<ArgsT>(args)...);
      }
    }
  }
}
    776 
// Exposes the underlying platform-specific executor (non-owning pointer).
internal::StreamExecutorInterface *StreamExecutor::implementation() {
  return implementation_->GetUnderlyingExecutor();
}
    780 
    781 }  // namespace gputools
    782 }  // namespace perftools
    783