/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Interfaces for platform-dependent implementations to satisfy. These are
// delegated to from the StreamExecutor in pointer-to-implementation style; i.e.
// the StreamExecutor is just a husk that delegates calls to the
// platform-specific objects which implement the interfaces defined here.
20 21 #ifndef TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_ 22 #define TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_ 23 24 #include <functional> 25 #include <map> 26 #include <memory> 27 #include <utility> 28 #include <vector> 29 30 #include "tensorflow/stream_executor/device_description.h" 31 #include "tensorflow/stream_executor/device_memory.h" 32 #include "tensorflow/stream_executor/device_options.h" 33 #include "tensorflow/stream_executor/dnn.h" 34 #include "tensorflow/stream_executor/event.h" 35 #include "tensorflow/stream_executor/kernel.h" 36 #include "tensorflow/stream_executor/kernel_cache_config.h" 37 #include "tensorflow/stream_executor/kernel_spec.h" 38 #include "tensorflow/stream_executor/launch_dim.h" 39 #include "tensorflow/stream_executor/lib/status.h" 40 #include "tensorflow/stream_executor/lib/statusor.h" 41 #include "tensorflow/stream_executor/platform.h" 42 #include "tensorflow/stream_executor/platform/port.h" 43 #include "tensorflow/stream_executor/plugin_registry.h" 44 #include "tensorflow/stream_executor/shared_memory_config.h" 45 #include "tensorflow/stream_executor/trace_listener.h" 46 #include "tensorflow/stream_executor/lib/inlined_vector.h" 47 48 namespace perftools { 49 namespace gputools { 50 51 class Stream; 52 class Timer; 53 54 namespace internal { 55 56 // Platform-dependent interface class for the generic Events interface, in 57 // the PIMPL style. 58 class EventInterface { 59 public: 60 EventInterface() {} 61 virtual ~EventInterface() {} 62 63 private: 64 SE_DISALLOW_COPY_AND_ASSIGN(EventInterface); 65 }; 66 67 // Pointer-to-implementation object type (i.e. the KernelBase class delegates to 68 // this interface) with virtual destruction. This class exists for the 69 // platform-dependent code to hang any kernel data/resource info/functionality 70 // off of. 71 class KernelInterface { 72 public: 73 // Default constructor for the abstract interface. 
74 KernelInterface() {} 75 76 // Default destructor for the abstract interface. 77 virtual ~KernelInterface() {} 78 79 // Returns the number of formal parameters that this kernel accepts. 80 virtual unsigned Arity() const = 0; 81 82 // Sets the preferred cache configuration. 83 virtual void SetPreferredCacheConfig(KernelCacheConfig config) = 0; 84 85 // Gets the preferred cache configuration. 86 virtual KernelCacheConfig GetPreferredCacheConfig() const = 0; 87 88 private: 89 SE_DISALLOW_COPY_AND_ASSIGN(KernelInterface); 90 }; 91 92 // Pointer-to-implementation object type (i.e. the Stream class delegates to 93 // this interface) with virtual destruction. This class exists for the 94 // platform-dependent code to hang any kernel data/resource info/functionality 95 // off of. 96 class StreamInterface { 97 public: 98 // Default constructor for the abstract interface. 99 StreamInterface() {} 100 101 // Default destructor for the abstract interface. 102 virtual ~StreamInterface() {} 103 104 // Returns the CUDA stream associated with this platform's stream 105 // implementation. 106 // 107 // WARNING: checks that the underlying platform is, in fact, CUDA, causing a 108 // fatal error if it is not. This hack is made available solely for use from 109 // distbelief code, which temporarily has strong ties to CUDA as a platform. 110 virtual void *CudaStreamHack() { return nullptr; } 111 112 // See the above comment on CudaStreamHack -- this further breaks abstraction 113 // for Eigen within distbelief, which has strong ties to CUDA as a platform, 114 // and a historical attachment to a programming model which takes a 115 // stream-slot rather than a stream-value. 116 virtual void **CudaStreamMemberHack() { return nullptr; } 117 118 private: 119 SE_DISALLOW_COPY_AND_ASSIGN(StreamInterface); 120 }; 121 122 // Pointer-to-implementation object type (i.e. the Timer class delegates to 123 // this interface) with virtual destruction. 
This class exists for the 124 // platform-dependent code to hang any timer data/resource info/functionality 125 // off of. 126 class TimerInterface { 127 public: 128 // Default constructor for the abstract interface. 129 TimerInterface() {} 130 131 // Default destructor for the abstract interface. 132 virtual ~TimerInterface() {} 133 134 // Returns the number of microseconds elapsed in a completed timer. 135 virtual uint64 Microseconds() const = 0; 136 137 // Returns the number of nanoseconds elapsed in a completed timer. 138 virtual uint64 Nanoseconds() const = 0; 139 140 private: 141 SE_DISALLOW_COPY_AND_ASSIGN(TimerInterface); 142 }; 143 144 // Interface for the different StreamExecutor platforms (i.e. CUDA, OpenCL). 145 // 146 // Various platforms will provide an implementation that satisfy this interface. 147 class StreamExecutorInterface { 148 public: 149 // Default constructor for the abstract interface. 150 StreamExecutorInterface() {} 151 152 // Default destructor for the abstract interface. 153 virtual ~StreamExecutorInterface() {} 154 155 // Returns the (transitively) wrapped executor if this executor is 156 // wrapping another executor; otherwise, returns this. 157 virtual StreamExecutorInterface *GetUnderlyingExecutor() { return this; } 158 159 // See the StreamExecutor interface for comments on the same-named methods. 160 virtual port::Status Init(int device_ordinal, 161 DeviceOptions device_options) = 0; 162 163 virtual bool GetKernel(const MultiKernelLoaderSpec &spec, 164 KernelBase *kernel) { 165 return false; 166 } 167 virtual bool Launch(Stream *stream, const ThreadDim &thread_dims, 168 const BlockDim &block_dims, const KernelBase &k, 169 const KernelArgsArrayBase &args) { 170 return false; 171 } 172 // Releases any state associated with the kernel. 
173 virtual void UnloadKernel(const KernelBase *kernel) {} 174 virtual void *Allocate(uint64 size) = 0; 175 virtual void *AllocateSubBuffer(DeviceMemoryBase *parent, uint64 offset, 176 uint64 size) = 0; 177 virtual void Deallocate(DeviceMemoryBase *mem) = 0; 178 virtual void *HostMemoryAllocate(uint64 size) = 0; 179 virtual void HostMemoryDeallocate(void *mem) = 0; 180 virtual bool HostMemoryRegister(void *mem, uint64 size) = 0; 181 virtual bool HostMemoryUnregister(void *mem) = 0; 182 virtual bool SynchronizeAllActivity() = 0; 183 virtual bool SynchronousMemZero(DeviceMemoryBase *location, uint64 size) = 0; 184 virtual bool SynchronousMemSet(DeviceMemoryBase *location, int value, 185 uint64 size) = 0; 186 virtual port::Status SynchronousMemcpy(DeviceMemoryBase *gpu_dst, 187 const void *host_src, uint64 size) = 0; 188 virtual port::Status SynchronousMemcpy(void *host_dst, 189 const DeviceMemoryBase &gpu_src, 190 uint64 size) = 0; 191 virtual port::Status SynchronousMemcpyDeviceToDevice( 192 DeviceMemoryBase *gpu_dst, const DeviceMemoryBase &gpu_src, 193 uint64 size) = 0; 194 virtual bool MemZero(Stream *stream, DeviceMemoryBase *location, 195 uint64 size) = 0; 196 virtual bool Memset(Stream *stream, DeviceMemoryBase *location, uint8 pattern, 197 uint64 size) { 198 return false; 199 } 200 virtual bool Memset32(Stream *stream, DeviceMemoryBase *location, 201 uint32 pattern, uint64 size) = 0; 202 virtual bool Memcpy(Stream *stream, void *host_dst, 203 const DeviceMemoryBase &gpu_src, uint64 size) = 0; 204 virtual bool Memcpy(Stream *stream, DeviceMemoryBase *gpu_dst, 205 const void *host_src, uint64 size) = 0; 206 virtual bool MemcpyDeviceToDevice(Stream *stream, DeviceMemoryBase *gpu_dst, 207 const DeviceMemoryBase &host_src, 208 uint64 size) = 0; 209 virtual bool HostCallback(Stream *stream, std::function<void()> callback) = 0; 210 virtual port::Status AllocateEvent(Event *event) = 0; 211 virtual port::Status DeallocateEvent(Event *event) = 0; 212 virtual 
port::Status RecordEvent(Stream *stream, Event *event) = 0; 213 virtual port::Status WaitForEvent(Stream *stream, Event *event) = 0; 214 virtual Event::Status PollForEventStatus(Event *event) = 0; 215 virtual bool AllocateStream(Stream *stream) = 0; 216 virtual void DeallocateStream(Stream *stream) = 0; 217 virtual bool CreateStreamDependency(Stream *dependent, Stream *other) = 0; 218 virtual bool AllocateTimer(Timer *timer) = 0; 219 virtual void DeallocateTimer(Timer *timer) = 0; 220 virtual bool StartTimer(Stream *stream, Timer *timer) = 0; 221 virtual bool StopTimer(Stream *stream, Timer *timer) = 0; 222 virtual port::Status BlockHostUntilDone(Stream *stream) = 0; 223 virtual int PlatformDeviceCount() = 0; 224 virtual port::Status EnablePeerAccessTo(StreamExecutorInterface *other) = 0; 225 virtual bool CanEnablePeerAccessTo(StreamExecutorInterface *other) = 0; 226 virtual SharedMemoryConfig GetDeviceSharedMemoryConfig() = 0; 227 virtual port::Status SetDeviceSharedMemoryConfig( 228 SharedMemoryConfig config) = 0; 229 230 virtual int64 GetDeviceLoad() { return -1; } 231 232 virtual bool DeviceMemoryUsage(int64 *free, int64 *total) const { 233 return false; 234 } 235 236 // Retrieves device pointer and size for a symbol. The device pointer is 237 // stored at mem, and the size is stored at size. Either mem or bytes can be 238 // null, however, both of them cannot be null at the same time. To use 239 // constant memory in CUDA, GetSymbol has to be used. Returns true if symbol 240 // is found. 241 virtual bool GetSymbol(const string& symbol_name, void **mem, size_t *bytes) { 242 return false; 243 } 244 245 // Creates a new DeviceDescription object. Ownership is transferred to the 246 // caller. 247 virtual DeviceDescription *PopulateDeviceDescription() const = 0; 248 249 // Attempts to register the provided TraceListener with the device-specific 250 // Executor implementation. 
When this is called, the PIMPL interface has 251 // already taken ownership of the object and is managing the generic tracing 252 // events. The device-specific implementation must determine if the passed 253 // listener is of a type appropriate for it to trace during registration (and 254 // before dispatching events to it). 255 // Returns true if the listener was successfully registered, false otherwise. 256 // Does not take ownership of listener. 257 virtual bool RegisterTraceListener(TraceListener* listener) { return false; } 258 259 // Unregisters the specified listener from the device-specific Executor. 260 // Returns true if the listener was successfully registered, false otherwise. 261 virtual bool UnregisterTraceListener(TraceListener* listener) { 262 return false; 263 } 264 265 // Returns whether this StreamExecutor has BLAS support for its underlying 266 // platform. 267 virtual bool SupportsBlas() const { return false; } 268 269 // Creates a new BlasSupport object, ownership is transferred to the caller. 270 // If SupportsBlas() is false, this will always return null. 271 // 272 // If SupportsBlas() is true, this may return null, for example, if the BLAS 273 // initialization fails. 274 virtual blas::BlasSupport *CreateBlas() { return nullptr; } 275 276 // Returns whether this StreamExecutor has FFT support for its underlying 277 // platform. 278 virtual bool SupportsFft() const { return false; } 279 280 // Creates a new fft::FftSupport object, ownership is transferred to the 281 // caller. 282 // If SupportsFft() is false, this will always return null. 283 // 284 // If SupportsFft() is true, this may return null, for example, if the FFT 285 // initialization fails. 286 virtual fft::FftSupport *CreateFft() { return nullptr; } 287 288 // Returns whether this StreamExecutor has Random Number Generation support 289 // for 290 // its underlying platform. 
291 virtual bool SupportsRng() const { return false; } 292 293 // Returns whether this StreamExecutor has neural net support for its 294 // underlying 295 // platform. 296 virtual bool SupportsDnn() const { return false; } 297 298 // Creates a new RngSupport object, ownership is transferred to the caller. 299 // If SupportsRng() is false, this will always return null. 300 // 301 // If SupportsRng() is true, this may return null, for example, if the RNG 302 // initialization fails. 303 virtual rng::RngSupport *CreateRng() { return nullptr; } 304 305 // Creates a new DnnSupport object, ownership is transferred to the caller. 306 // If SupportsDnn() is false, this will always return null. 307 // 308 // If SupportsDnn() is true, this may return null, for example, if the DNN 309 // initialization fails. 310 virtual dnn::DnnSupport *CreateDnn() { return nullptr; } 311 312 // Each call creates a new instance of the platform-specific implementation of 313 // the corresponding interface type. 314 virtual std::unique_ptr<EventInterface> CreateEventImplementation() = 0; 315 virtual std::unique_ptr<KernelInterface> CreateKernelImplementation() = 0; 316 virtual std::unique_ptr<StreamInterface> GetStreamImplementation() = 0; 317 virtual std::unique_ptr<TimerInterface> GetTimerImplementation() = 0; 318 319 // Returns the CUDA context associated with this StreamExecutor platform 320 // implementation. 321 // 322 // WARNING: checks that the underlying platform is, in fact, CUDA, causing a 323 // fatal error if it is not. This hack is made available solely for use from 324 // distbelief code, which temporarily has strong ties to CUDA as a platform. 
325 virtual void *CudaContextHack() { return nullptr; } 326 327 private: 328 SE_DISALLOW_COPY_AND_ASSIGN(StreamExecutorInterface); 329 }; 330 331 using StreamExecutorFactory = 332 std::function<StreamExecutorInterface *(const PluginConfig &)>; 333 using EventFactory = std::function<EventInterface *(StreamExecutor *)>; 334 using StreamFactory = std::function<StreamInterface *(StreamExecutor *)>; 335 using TimerFactory = std::function<TimerInterface *(StreamExecutor *)>; 336 using KernelFactory = std::function<KernelInterface*()>; 337 338 StreamExecutorFactory* MakeCUDAExecutorImplementation(); 339 340 StreamExecutorFactory* MakeOpenCLExecutorImplementation(); 341 342 extern StreamExecutorFactory MakeHostExecutorImplementation; 343 344 345 } // namespace internal 346 } // namespace gputools 347 } // namespace perftools 348 349 #endif // TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_ 350