Home | History | Annotate | Download | only in host
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // Implementation of HostExecutor class [of those methods not defined in the
     17 // class declaration].
     18 #include "tensorflow/stream_executor/host/host_gpu_executor.h"
     19 
     20 #include <string.h>
     21 
     22 #include "tensorflow/core/platform/profile_utils/cpu_utils.h"
     23 #include "tensorflow/stream_executor/host/host_platform_id.h"
     24 #include "tensorflow/stream_executor/host/host_stream.h"
     25 #include "tensorflow/stream_executor/host/host_timer.h"
     26 #include "tensorflow/stream_executor/lib/statusor.h"
     27 #include "tensorflow/stream_executor/plugin_registry.h"
     28 
     29 bool FLAGS_stream_executor_cpu_real_clock_rate = false;
     30 
     31 namespace perftools {
     32 namespace gputools {
     33 namespace host {
     34 
     35 HostStream *AsHostStream(Stream *stream) {
     36   DCHECK(stream != nullptr);
     37   return dynamic_cast<HostStream *>(stream->implementation());
     38 }
     39 
     40 HostExecutor::HostExecutor(const PluginConfig &plugin_config)
     41     : plugin_config_(plugin_config) {}
     42 
     43 HostExecutor::~HostExecutor() {}
     44 
     45 void *HostExecutor::Allocate(uint64 size) { return new char[size]; }
     46 
     47 void *HostExecutor::AllocateSubBuffer(DeviceMemoryBase *parent,
     48                                       uint64 offset_bytes, uint64 size_bytes) {
     49   return reinterpret_cast<char *>(parent->opaque()) + offset_bytes;
     50 }
     51 
     52 void HostExecutor::Deallocate(DeviceMemoryBase *mem) {
     53   if (!mem->is_sub_buffer()) {
     54     delete[] static_cast<char *>(mem->opaque());
     55   }
     56 }
     57 
     58 bool HostExecutor::SynchronousMemZero(DeviceMemoryBase *location, uint64 size) {
     59   memset(location->opaque(), 0, size);
     60   return true;
     61 }
     62 
     63 bool HostExecutor::SynchronousMemSet(DeviceMemoryBase *location, int value,
     64                                      uint64 size) {
     65   memset(location->opaque(), value, size);
     66   return true;
     67 }
     68 
     69 bool HostExecutor::Memcpy(Stream *stream, void *host_dst,
     70                           const DeviceMemoryBase &gpu_src, uint64 size) {
     71   // Enqueue the [asynchronous] memcpy on the stream (HostStream) associated
     72   // with the HostExecutor.
     73   void *src_mem = const_cast<void *>(gpu_src.opaque());
     74   AsHostStream(stream)->EnqueueTask(
     75       [host_dst, src_mem, size]() { memcpy(host_dst, src_mem, size); });
     76   return true;
     77 }
     78 
     79 bool HostExecutor::Memcpy(Stream *stream, DeviceMemoryBase *gpu_dst,
     80                           const void *host_src, uint64 size) {
     81   void *dst_mem = gpu_dst->opaque();
     82   // Enqueue the [asynchronous] memcpy on the stream (HostStream) associated
     83   // with the HostExecutor.
     84   AsHostStream(stream)->EnqueueTask(
     85       [dst_mem, host_src, size]() { memcpy(dst_mem, host_src, size); });
     86   return true;
     87 }
     88 
     89 bool HostExecutor::MemcpyDeviceToDevice(Stream *stream,
     90                                         DeviceMemoryBase *gpu_dst,
     91                                         const DeviceMemoryBase &gpu_src,
     92                                         uint64 size) {
     93   void *dst_mem = gpu_dst->opaque();
     94   void *src_mem = const_cast<void *>(gpu_src.opaque());
     95   // Enqueue this [asynchronous] "device-to-device" (i.e., host-to-host, given
     96   // the nature of the HostExecutor) memcpy  on the stream (HostStream)
     97   // associated with the HostExecutor.
     98   AsHostStream(stream)->EnqueueTask(
     99       [src_mem, dst_mem, size]() { memcpy(src_mem, dst_mem, size); });
    100   return true;
    101 }
    102 
    103 bool HostExecutor::MemZero(Stream *stream, DeviceMemoryBase *location,
    104                            uint64 size) {
    105   void *gpu_mem = location->opaque();
    106   // Enqueue the [asynchronous] memzero on the stream (HostStream) associated
    107   // with the HostExecutor.
    108   AsHostStream(stream)->EnqueueTask(
    109       [gpu_mem, size]() { memset(gpu_mem, 0, size); });
    110   return true;
    111 }
    112 
    113 bool HostExecutor::Memset(Stream *stream, DeviceMemoryBase *location,
    114                           uint8 pattern, uint64 size) {
    115   void *gpu_mem = location->opaque();
    116   // Enqueue the [asynchronous] memzero on the stream (HostStream) associated
    117   // with the HostExecutor.
    118   AsHostStream(stream)->EnqueueTask(
    119       [gpu_mem, size, pattern]() { memset(gpu_mem, pattern, size); });
    120   return true;
    121 }
    122 
    123 bool HostExecutor::Memset32(Stream *stream, DeviceMemoryBase *location,
    124                             uint32 pattern, uint64 size) {
    125   void *gpu_mem = location->opaque();
    126   // Enqueue the [asynchronous] memzero on the stream (HostStream) associated
    127   // with the HostExecutor.
    128   AsHostStream(stream)->EnqueueTask(
    129       [gpu_mem, size, pattern]() { memset(gpu_mem, pattern, size); });
    130   return true;
    131 }
    132 
    133 port::Status HostExecutor::SynchronousMemcpy(DeviceMemoryBase *gpu_dst,
    134                                              const void *host_src,
    135                                              uint64 size) {
    136   memcpy(gpu_dst->opaque(), host_src, size);
    137   return port::Status::OK();
    138 }
    139 
    140 port::Status HostExecutor::SynchronousMemcpy(void *host_dst,
    141                                              const DeviceMemoryBase &gpu_src,
    142                                              uint64 size) {
    143   memcpy(host_dst, gpu_src.opaque(), size);
    144   return port::Status::OK();
    145 }
    146 
    147 port::Status HostExecutor::SynchronousMemcpyDeviceToDevice(
    148     DeviceMemoryBase *gpu_dst, const DeviceMemoryBase &gpu_src, uint64 size) {
    149   memcpy(gpu_dst->opaque(), gpu_src.opaque(), size);
    150   return port::Status::OK();
    151 }
    152 
    153 bool HostExecutor::HostCallback(Stream *stream,
    154                                 std::function<void()> callback) {
    155   AsHostStream(stream)->EnqueueTask(callback);
    156   return true;
    157 }
    158 
    159 bool HostExecutor::AllocateStream(Stream *stream) { return true; }
    160 
    161 void HostExecutor::DeallocateStream(Stream *stream) {}
    162 
    163 bool HostExecutor::CreateStreamDependency(Stream *dependent, Stream *other) {
    164   AsHostStream(dependent)->EnqueueTask(
    165       [other]() { SE_CHECK_OK(other->BlockHostUntilDone()); });
    166   AsHostStream(dependent)->BlockUntilDone();
    167   return true;
    168 }
    169 
    170 bool HostExecutor::StartTimer(Stream *stream, Timer *timer) {
    171   dynamic_cast<HostTimer *>(timer->implementation())->Start(stream);
    172   return true;
    173 }
    174 
    175 bool HostExecutor::StopTimer(Stream *stream, Timer *timer) {
    176   dynamic_cast<HostTimer *>(timer->implementation())->Stop(stream);
    177   return true;
    178 }
    179 
    180 port::Status HostExecutor::BlockHostUntilDone(Stream *stream) {
    181   AsHostStream(stream)->BlockUntilDone();
    182   return port::Status::OK();
    183 }
    184 
    185 DeviceDescription *HostExecutor::PopulateDeviceDescription() const {
    186   internal::DeviceDescriptionBuilder builder;
    187 
    188   builder.set_device_address_bits(64);
    189 
    190   // TODO(rspringer): How to report a value that's based in reality but that
    191   // doesn't result in thrashing or other badness? 4GiB chosen arbitrarily.
    192   builder.set_device_memory_size(static_cast<uint64>(4) * 1024 * 1024 * 1024);
    193 
    194   float cycle_counter_frequency = 1e9;
    195   if (FLAGS_stream_executor_cpu_real_clock_rate) {
    196     cycle_counter_frequency = static_cast<float>(
    197         tensorflow::profile_utils::CpuUtils::GetCycleCounterFrequency());
    198   }
    199   builder.set_clock_rate_ghz(cycle_counter_frequency / 1e9);
    200 
    201   auto built = builder.Build();
    202   return built.release();
    203 }
    204 
    205 bool HostExecutor::SupportsBlas() const {
    206   return PluginRegistry::Instance()
    207       ->GetFactory<PluginRegistry::BlasFactory>(kHostPlatformId,
    208                                                 plugin_config_.blas())
    209       .ok();
    210 }
    211 
    212 blas::BlasSupport *HostExecutor::CreateBlas() {
    213   PluginRegistry *registry = PluginRegistry::Instance();
    214   port::StatusOr<PluginRegistry::BlasFactory> status =
    215       registry->GetFactory<PluginRegistry::BlasFactory>(kHostPlatformId,
    216                                                         plugin_config_.blas());
    217   if (!status.ok()) {
    218     LOG(ERROR) << "Unable to retrieve BLAS factory: "
    219                << status.status().error_message();
    220     return nullptr;
    221   }
    222 
    223   return status.ValueOrDie()(this);
    224 }
    225 
    226 bool HostExecutor::SupportsFft() const {
    227   return PluginRegistry::Instance()
    228       ->GetFactory<PluginRegistry::FftFactory>(kHostPlatformId,
    229                                                plugin_config_.fft())
    230       .ok();
    231 }
    232 
    233 fft::FftSupport *HostExecutor::CreateFft() {
    234   PluginRegistry *registry = PluginRegistry::Instance();
    235   port::StatusOr<PluginRegistry::FftFactory> status =
    236       registry->GetFactory<PluginRegistry::FftFactory>(kHostPlatformId,
    237                                                        plugin_config_.fft());
    238   if (!status.ok()) {
    239     LOG(ERROR) << "Unable to retrieve FFT factory: "
    240                << status.status().error_message();
    241     return nullptr;
    242   }
    243 
    244   return status.ValueOrDie()(this);
    245 }
    246 
    247 bool HostExecutor::SupportsRng() const {
    248   return PluginRegistry::Instance()
    249       ->GetFactory<PluginRegistry::RngFactory>(kHostPlatformId,
    250                                                plugin_config_.rng())
    251       .ok();
    252 }
    253 
    254 rng::RngSupport *HostExecutor::CreateRng() {
    255   PluginRegistry *registry = PluginRegistry::Instance();
    256   port::StatusOr<PluginRegistry::RngFactory> status =
    257       registry->GetFactory<PluginRegistry::RngFactory>(kHostPlatformId,
    258                                                        plugin_config_.rng());
    259   if (!status.ok()) {
    260     LOG(ERROR) << "Unable to retrieve RNG factory: "
    261                << status.status().error_message();
    262     return nullptr;
    263   }
    264 
    265   return status.ValueOrDie()(this);
    266 }
    267 
    268 }  // namespace host
    269 }  // namespace gputools
    270 }  // namespace perftools
    271