1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // Implementation of HostExecutor class [of those methods not defined in the 17 // class declaration]. 18 #include "tensorflow/stream_executor/host/host_gpu_executor.h" 19 20 #include <string.h> 21 22 #include "tensorflow/core/platform/profile_utils/cpu_utils.h" 23 #include "tensorflow/stream_executor/host/host_platform_id.h" 24 #include "tensorflow/stream_executor/host/host_stream.h" 25 #include "tensorflow/stream_executor/host/host_timer.h" 26 #include "tensorflow/stream_executor/lib/statusor.h" 27 #include "tensorflow/stream_executor/plugin_registry.h" 28 29 bool FLAGS_stream_executor_cpu_real_clock_rate = false; 30 31 namespace perftools { 32 namespace gputools { 33 namespace host { 34 35 HostStream *AsHostStream(Stream *stream) { 36 DCHECK(stream != nullptr); 37 return dynamic_cast<HostStream *>(stream->implementation()); 38 } 39 40 HostExecutor::HostExecutor(const PluginConfig &plugin_config) 41 : plugin_config_(plugin_config) {} 42 43 HostExecutor::~HostExecutor() {} 44 45 void *HostExecutor::Allocate(uint64 size) { return new char[size]; } 46 47 void *HostExecutor::AllocateSubBuffer(DeviceMemoryBase *parent, 48 uint64 offset_bytes, uint64 size_bytes) { 49 return reinterpret_cast<char *>(parent->opaque()) + offset_bytes; 50 } 51 52 void HostExecutor::Deallocate(DeviceMemoryBase *mem) { 53 if (!mem->is_sub_buffer()) { 54 delete[] static_cast<char *>(mem->opaque()); 55 } 56 } 57 58 bool HostExecutor::SynchronousMemZero(DeviceMemoryBase *location, uint64 size) { 59 memset(location->opaque(), 0, size); 60 return true; 61 } 62 63 bool HostExecutor::SynchronousMemSet(DeviceMemoryBase *location, int value, 64 uint64 size) { 65 memset(location->opaque(), value, size); 66 return true; 67 } 68 69 bool HostExecutor::Memcpy(Stream *stream, void *host_dst, 70 const DeviceMemoryBase &gpu_src, uint64 size) { 71 // Enqueue the [asynchronous] memcpy on the stream (HostStream) associated 72 // with the HostExecutor. 73 void *src_mem = const_cast<void *>(gpu_src.opaque()); 74 AsHostStream(stream)->EnqueueTask( 75 [host_dst, src_mem, size]() { memcpy(host_dst, src_mem, size); }); 76 return true; 77 } 78 79 bool HostExecutor::Memcpy(Stream *stream, DeviceMemoryBase *gpu_dst, 80 const void *host_src, uint64 size) { 81 void *dst_mem = gpu_dst->opaque(); 82 // Enqueue the [asynchronous] memcpy on the stream (HostStream) associated 83 // with the HostExecutor. 84 AsHostStream(stream)->EnqueueTask( 85 [dst_mem, host_src, size]() { memcpy(dst_mem, host_src, size); }); 86 return true; 87 } 88 89 bool HostExecutor::MemcpyDeviceToDevice(Stream *stream, 90 DeviceMemoryBase *gpu_dst, 91 const DeviceMemoryBase &gpu_src, 92 uint64 size) { 93 void *dst_mem = gpu_dst->opaque(); 94 void *src_mem = const_cast<void *>(gpu_src.opaque()); 95 // Enqueue this [asynchronous] "device-to-device" (i.e., host-to-host, given 96 // the nature of the HostExecutor) memcpy on the stream (HostStream) 97 // associated with the HostExecutor. 98 AsHostStream(stream)->EnqueueTask( 99 [src_mem, dst_mem, size]() { memcpy(src_mem, dst_mem, size); }); 100 return true; 101 } 102 103 bool HostExecutor::MemZero(Stream *stream, DeviceMemoryBase *location, 104 uint64 size) { 105 void *gpu_mem = location->opaque(); 106 // Enqueue the [asynchronous] memzero on the stream (HostStream) associated 107 // with the HostExecutor. 108 AsHostStream(stream)->EnqueueTask( 109 [gpu_mem, size]() { memset(gpu_mem, 0, size); }); 110 return true; 111 } 112 113 bool HostExecutor::Memset(Stream *stream, DeviceMemoryBase *location, 114 uint8 pattern, uint64 size) { 115 void *gpu_mem = location->opaque(); 116 // Enqueue the [asynchronous] memzero on the stream (HostStream) associated 117 // with the HostExecutor. 118 AsHostStream(stream)->EnqueueTask( 119 [gpu_mem, size, pattern]() { memset(gpu_mem, pattern, size); }); 120 return true; 121 } 122 123 bool HostExecutor::Memset32(Stream *stream, DeviceMemoryBase *location, 124 uint32 pattern, uint64 size) { 125 void *gpu_mem = location->opaque(); 126 // Enqueue the [asynchronous] memzero on the stream (HostStream) associated 127 // with the HostExecutor. 128 AsHostStream(stream)->EnqueueTask( 129 [gpu_mem, size, pattern]() { memset(gpu_mem, pattern, size); }); 130 return true; 131 } 132 133 port::Status HostExecutor::SynchronousMemcpy(DeviceMemoryBase *gpu_dst, 134 const void *host_src, 135 uint64 size) { 136 memcpy(gpu_dst->opaque(), host_src, size); 137 return port::Status::OK(); 138 } 139 140 port::Status HostExecutor::SynchronousMemcpy(void *host_dst, 141 const DeviceMemoryBase &gpu_src, 142 uint64 size) { 143 memcpy(host_dst, gpu_src.opaque(), size); 144 return port::Status::OK(); 145 } 146 147 port::Status HostExecutor::SynchronousMemcpyDeviceToDevice( 148 DeviceMemoryBase *gpu_dst, const DeviceMemoryBase &gpu_src, uint64 size) { 149 memcpy(gpu_dst->opaque(), gpu_src.opaque(), size); 150 return port::Status::OK(); 151 } 152 153 bool HostExecutor::HostCallback(Stream *stream, 154 std::function<void()> callback) { 155 AsHostStream(stream)->EnqueueTask(callback); 156 return true; 157 } 158 159 bool HostExecutor::AllocateStream(Stream *stream) { return true; } 160 161 void HostExecutor::DeallocateStream(Stream *stream) {} 162 163 bool HostExecutor::CreateStreamDependency(Stream *dependent, Stream *other) { 164 AsHostStream(dependent)->EnqueueTask( 165 [other]() { SE_CHECK_OK(other->BlockHostUntilDone()); }); 166 AsHostStream(dependent)->BlockUntilDone(); 167 return true; 168 } 169 170 bool HostExecutor::StartTimer(Stream *stream, Timer *timer) { 171 dynamic_cast<HostTimer *>(timer->implementation())->Start(stream); 172 return true; 173 } 174 175 bool HostExecutor::StopTimer(Stream *stream, Timer *timer) { 176 dynamic_cast<HostTimer *>(timer->implementation())->Stop(stream); 177 return true; 178 } 179 180 port::Status HostExecutor::BlockHostUntilDone(Stream *stream) { 181 AsHostStream(stream)->BlockUntilDone(); 182 return port::Status::OK(); 183 } 184 185 DeviceDescription *HostExecutor::PopulateDeviceDescription() const { 186 internal::DeviceDescriptionBuilder builder; 187 188 builder.set_device_address_bits(64); 189 190 // TODO(rspringer): How to report a value that's based in reality but that 191 // doesn't result in thrashing or other badness? 4GiB chosen arbitrarily. 192 builder.set_device_memory_size(static_cast<uint64>(4) * 1024 * 1024 * 1024); 193 194 float cycle_counter_frequency = 1e9; 195 if (FLAGS_stream_executor_cpu_real_clock_rate) { 196 cycle_counter_frequency = static_cast<float>( 197 tensorflow::profile_utils::CpuUtils::GetCycleCounterFrequency()); 198 } 199 builder.set_clock_rate_ghz(cycle_counter_frequency / 1e9); 200 201 auto built = builder.Build(); 202 return built.release(); 203 } 204 205 bool HostExecutor::SupportsBlas() const { 206 return PluginRegistry::Instance() 207 ->GetFactory<PluginRegistry::BlasFactory>(kHostPlatformId, 208 plugin_config_.blas()) 209 .ok(); 210 } 211 212 blas::BlasSupport *HostExecutor::CreateBlas() { 213 PluginRegistry *registry = PluginRegistry::Instance(); 214 port::StatusOr<PluginRegistry::BlasFactory> status = 215 registry->GetFactory<PluginRegistry::BlasFactory>(kHostPlatformId, 216 plugin_config_.blas()); 217 if (!status.ok()) { 218 LOG(ERROR) << "Unable to retrieve BLAS factory: " 219 << status.status().error_message(); 220 return nullptr; 221 } 222 223 return status.ValueOrDie()(this); 224 } 225 226 bool HostExecutor::SupportsFft() const { 227 return PluginRegistry::Instance() 228 ->GetFactory<PluginRegistry::FftFactory>(kHostPlatformId, 229 plugin_config_.fft()) 230 .ok(); 231 } 232 233 fft::FftSupport *HostExecutor::CreateFft() { 234 PluginRegistry *registry = PluginRegistry::Instance(); 235 port::StatusOr<PluginRegistry::FftFactory> status = 236 registry->GetFactory<PluginRegistry::FftFactory>(kHostPlatformId, 237 plugin_config_.fft()); 238 if (!status.ok()) { 239 LOG(ERROR) << "Unable to retrieve FFT factory: " 240 << status.status().error_message(); 241 return nullptr; 242 } 243 244 return status.ValueOrDie()(this); 245 } 246 247 bool HostExecutor::SupportsRng() const { 248 return PluginRegistry::Instance() 249 ->GetFactory<PluginRegistry::RngFactory>(kHostPlatformId, 250 plugin_config_.rng()) 251 .ok(); 252 } 253 254 rng::RngSupport *HostExecutor::CreateRng() { 255 PluginRegistry *registry = PluginRegistry::Instance(); 256 port::StatusOr<PluginRegistry::RngFactory> status = 257 registry->GetFactory<PluginRegistry::RngFactory>(kHostPlatformId, 258 plugin_config_.rng()); 259 if (!status.ok()) { 260 LOG(ERROR) << "Unable to retrieve RNG factory: " 261 << status.status().error_message(); 262 return nullptr; 263 } 264 265 return status.ValueOrDie()(this); 266 } 267 268 } // namespace host 269 } // namespace gputools 270 } // namespace perftools 271