Home | History | Annotate | Download | only in stream_executor
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
// Interfaces for platform-dependent implementations to satisfy. These are
// delegated to from the StreamExecutor in pointer-to-implementation style; i.e.
// the StreamExecutor is just a husk that delegates calls to the
// platform-specific objects which implement the interfaces defined here.
     20 
     21 #ifndef TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
     22 #define TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
     23 
     24 #include <functional>
     25 #include <map>
     26 #include <memory>
     27 #include <utility>
     28 #include <vector>
     29 
     30 #include "absl/types/optional.h"
     31 #include "tensorflow/stream_executor/allocator_stats.h"
     32 #include "tensorflow/stream_executor/device_description.h"
     33 #include "tensorflow/stream_executor/device_memory.h"
     34 #include "tensorflow/stream_executor/device_options.h"
     35 #include "tensorflow/stream_executor/dnn.h"
     36 #include "tensorflow/stream_executor/event.h"
     37 #include "tensorflow/stream_executor/kernel.h"
     38 #include "tensorflow/stream_executor/kernel_cache_config.h"
     39 #include "tensorflow/stream_executor/kernel_spec.h"
     40 #include "tensorflow/stream_executor/launch_dim.h"
     41 #include "tensorflow/stream_executor/lib/status.h"
     42 #include "tensorflow/stream_executor/lib/statusor.h"
     43 #include "tensorflow/stream_executor/module_spec.h"
     44 #include "tensorflow/stream_executor/platform.h"
     45 #include "tensorflow/stream_executor/platform/port.h"
     46 #include "tensorflow/stream_executor/plugin_registry.h"
     47 #include "tensorflow/stream_executor/shared_memory_config.h"
     48 #include "tensorflow/stream_executor/trace_listener.h"
     49 
     50 namespace stream_executor {
     51 
// Forward declarations: this header refers to Stream and Timer only through
// pointers, so their full definitions are not needed here.
class Stream;
class Timer;
     54 
     55 // An opaque handle to a loaded module.
     56 //
     57 // An instance of this is returned from StreamExecutor::GetModule.
     58 class ModuleHandle {
     59  public:
     60   /*implicit*/ ModuleHandle(void *id = nullptr) : id_(id) {}
     61 
     62   // A ModuleHandle with id() == nullptr is an invalid module handle, akin to a
     63   // null pointer.
     64   void *id() const { return id_; }
     65 
     66   explicit operator bool() const { return id() != nullptr; }
     67 
     68  private:
     69   void *id_;
     70 };
     71 
     72 namespace internal {
     73 
     74 // Platform-dependent interface class for the generic Events interface, in
     75 // the PIMPL style.
     76 class EventInterface {
     77  public:
     78   EventInterface() {}
     79   virtual ~EventInterface() {}
     80 
     81  private:
     82   SE_DISALLOW_COPY_AND_ASSIGN(EventInterface);
     83 };
     84 
     85 // Pointer-to-implementation object type (i.e. the KernelBase class delegates to
     86 // this interface) with virtual destruction. This class exists for the
     87 // platform-dependent code to hang any kernel data/resource info/functionality
     88 // off of.
     89 class KernelInterface {
     90  public:
     91   // Default constructor for the abstract interface.
     92   KernelInterface() {}
     93 
     94   // Default destructor for the abstract interface.
     95   virtual ~KernelInterface() {}
     96 
     97   // Returns the number of formal parameters that this kernel accepts.
     98   virtual unsigned Arity() const = 0;
     99 
    100   // Sets the preferred cache configuration.
    101   virtual void SetPreferredCacheConfig(KernelCacheConfig config) = 0;
    102 
    103   // Gets the preferred cache configuration.
    104   virtual KernelCacheConfig GetPreferredCacheConfig() const = 0;
    105 
    106  private:
    107   SE_DISALLOW_COPY_AND_ASSIGN(KernelInterface);
    108 };
    109 
    110 // Pointer-to-implementation object type (i.e. the Stream class delegates to
    111 // this interface) with virtual destruction. This class exists for the
    112 // platform-dependent code to hang any kernel data/resource info/functionality
    113 // off of.
    114 class StreamInterface {
    115  public:
    116   // Default constructor for the abstract interface.
    117   StreamInterface() {}
    118 
    119   // Default destructor for the abstract interface.
    120   virtual ~StreamInterface() {}
    121 
    122   // Returns the GPU stream associated with this platform's stream
    123   // implementation.
    124   //
    125   // WARNING: checks that the underlying platform is, in fact, CUDA or ROCm,
    126   // causing a fatal error if it is not. This hack is made available solely for
    127   // use from distbelief code, which temporarily has strong ties to CUDA or
    128   // ROCm as a platform.
    129   virtual void *GpuStreamHack() { return nullptr; }
    130 
    131   // See the above comment on GpuStreamHack -- this further breaks abstraction
    132   // for Eigen within distbelief, which has strong ties to CUDA or ROCm as a
    133   // platform, and a historical attachment to a programming model which takes a
    134   // stream-slot rather than a stream-value.
    135   virtual void **GpuStreamMemberHack() { return nullptr; }
    136 
    137  private:
    138   SE_DISALLOW_COPY_AND_ASSIGN(StreamInterface);
    139 };
    140 
    141 // Pointer-to-implementation object type (i.e. the Timer class delegates to
    142 // this interface) with virtual destruction. This class exists for the
    143 // platform-dependent code to hang any timer data/resource info/functionality
    144 // off of.
    145 class TimerInterface {
    146  public:
    147   // Default constructor for the abstract interface.
    148   TimerInterface() {}
    149 
    150   // Default destructor for the abstract interface.
    151   virtual ~TimerInterface() {}
    152 
    153   // Returns the number of microseconds elapsed in a completed timer.
    154   virtual uint64 Microseconds() const = 0;
    155 
    156   // Returns the number of nanoseconds elapsed in a completed timer.
    157   virtual uint64 Nanoseconds() const = 0;
    158 
    159  private:
    160   SE_DISALLOW_COPY_AND_ASSIGN(TimerInterface);
    161 };
    162 
    163 // Interface for the different StreamExecutor platforms (i.e. CUDA, OpenCL).
    164 //
    165 // Various platforms will provide an implementation that satisfy this interface.
    166 class StreamExecutorInterface {
    167  public:
    168   // Default constructor for the abstract interface.
    169   StreamExecutorInterface() {}
    170 
    171   // Default destructor for the abstract interface.
    172   virtual ~StreamExecutorInterface() {}
    173 
    174   // Returns the (transitively) wrapped executor if this executor is
    175   // wrapping another executor; otherwise, returns this.
    176   virtual StreamExecutorInterface *GetUnderlyingExecutor() { return this; }
    177 
    178   // See the StreamExecutor interface for comments on the same-named methods.
    179   virtual port::Status Init(int device_ordinal,
    180                             DeviceOptions device_options) = 0;
    181 
    182   virtual bool GetKernel(const MultiKernelLoaderSpec &spec,
    183                          KernelBase *kernel) {
    184     return false;
    185   }
    186   virtual bool LoadModule(const MultiModuleLoaderSpec &spec,
    187                           ModuleHandle *module_handle) {
    188     return false;
    189   }
    190   virtual bool UnloadModule(ModuleHandle module_handle) { return false; }
    191   virtual bool Launch(Stream *stream, const ThreadDim &thread_dims,
    192                       const BlockDim &block_dims, const KernelBase &k,
    193                       const KernelArgsArrayBase &args) {
    194     return false;
    195   }
    196   // Releases any state associated with the kernel.
    197   virtual void UnloadKernel(const KernelBase *kernel) {}
    198   virtual void *Allocate(uint64 size) = 0;
    199   virtual void *AllocateSubBuffer(DeviceMemoryBase *parent, uint64 offset,
    200                                   uint64 size) = 0;
    201   virtual void Deallocate(DeviceMemoryBase *mem) = 0;
    202   // Allocates unified memory space of the given size, if supported.
    203   // See
    204   // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#um-unified-memory-programming-hd
    205   // for more details on unified memory.
    206   virtual void *UnifiedMemoryAllocate(uint64 size) { return nullptr; }
    207 
    208   // Deallocates unified memory space previously allocated with
    209   // UnifiedMemoryAllocate.
    210   virtual void UnifiedMemoryDeallocate(void *mem) {}
    211   virtual void *HostMemoryAllocate(uint64 size) = 0;
    212   virtual void HostMemoryDeallocate(void *mem) = 0;
    213   virtual bool HostMemoryRegister(void *mem, uint64 size) = 0;
    214   virtual bool HostMemoryUnregister(void *mem) = 0;
    215   virtual bool SynchronizeAllActivity() = 0;
    216   virtual bool SynchronousMemZero(DeviceMemoryBase *location, uint64 size) = 0;
    217   virtual bool SynchronousMemSet(DeviceMemoryBase *location, int value,
    218                                  uint64 size) = 0;
    219   virtual port::Status SynchronousMemcpy(DeviceMemoryBase *gpu_dst,
    220                                          const void *host_src, uint64 size) = 0;
    221   virtual port::Status SynchronousMemcpy(void *host_dst,
    222                                          const DeviceMemoryBase &gpu_src,
    223                                          uint64 size) = 0;
    224   virtual port::Status SynchronousMemcpyDeviceToDevice(
    225       DeviceMemoryBase *gpu_dst, const DeviceMemoryBase &gpu_src,
    226       uint64 size) = 0;
    227   virtual bool MemZero(Stream *stream, DeviceMemoryBase *location,
    228                        uint64 size) = 0;
    229   virtual bool Memset(Stream *stream, DeviceMemoryBase *location, uint8 pattern,
    230                       uint64 size) {
    231     return false;
    232   }
    233   virtual bool Memset32(Stream *stream, DeviceMemoryBase *location,
    234                         uint32 pattern, uint64 size) = 0;
    235   virtual bool Memcpy(Stream *stream, void *host_dst,
    236                       const DeviceMemoryBase &gpu_src, uint64 size) = 0;
    237   virtual bool Memcpy(Stream *stream, DeviceMemoryBase *gpu_dst,
    238                       const void *host_src, uint64 size) = 0;
    239   virtual bool MemcpyDeviceToDevice(Stream *stream, DeviceMemoryBase *gpu_dst,
    240                                     const DeviceMemoryBase &gpu_src,
    241                                     uint64 size) = 0;
    242   virtual bool HostCallback(Stream *stream, std::function<void()> callback);
    243   virtual bool HostCallback(Stream *stream,
    244                             std::function<port::Status()> callback) = 0;
    245   virtual port::Status AllocateEvent(Event *event) = 0;
    246   virtual port::Status DeallocateEvent(Event *event) = 0;
    247   virtual port::Status RecordEvent(Stream *stream, Event *event) = 0;
    248   virtual port::Status WaitForEvent(Stream *stream, Event *event) = 0;
    249   virtual Event::Status PollForEventStatus(Event *event) = 0;
    250   virtual bool AllocateStream(Stream *stream) = 0;
    251   virtual void DeallocateStream(Stream *stream) = 0;
    252   virtual bool CreateStreamDependency(Stream *dependent, Stream *other) = 0;
    253   virtual bool AllocateTimer(Timer *timer) = 0;
    254   virtual void DeallocateTimer(Timer *timer) = 0;
    255   virtual bool StartTimer(Stream *stream, Timer *timer) = 0;
    256   virtual bool StopTimer(Stream *stream, Timer *timer) = 0;
    257   virtual port::Status BlockHostUntilDone(Stream *stream) = 0;
    258   virtual port::Status GetStatus(Stream *stream) {
    259     return port::Status(port::error::UNIMPLEMENTED,
    260                         "GetStatus is not supported on this executor.");
    261   }
    262   virtual int PlatformDeviceCount() = 0;
    263   virtual port::Status EnablePeerAccessTo(StreamExecutorInterface *other) = 0;
    264   virtual bool CanEnablePeerAccessTo(StreamExecutorInterface *other) = 0;
    265   virtual SharedMemoryConfig GetDeviceSharedMemoryConfig() = 0;
    266   virtual port::Status SetDeviceSharedMemoryConfig(
    267       SharedMemoryConfig config) = 0;
    268 
    269   virtual int64 GetDeviceLoad() { return -1; }
    270 
    271   virtual bool DeviceMemoryUsage(int64 *free, int64 *total) const {
    272     return false;
    273   }
    274 
    275   // Retrieves device pointer and size for a symbol. The device pointer is
    276   // stored at mem, and the size is stored at size. Either mem or bytes can be
    277   // null, however, both of them cannot be null at the same time. To use
    278   // constant memory in CUDA, GetSymbol has to be used. Returns true if symbol
    279   // is found.
    280   //
    281   // If ModuleHandle is set then we search for `symbol_name` only within the
    282   // module corresponding to `module_handle`.  Otherwise all loaded modules are
    283   // searched.
    284   virtual bool GetSymbol(const string &symbol_name, ModuleHandle module_handle,
    285                          void **mem, size_t *bytes) {
    286     return false;
    287   }
    288 
    289   // Creates a new DeviceDescription object. Ownership is transferred to the
    290   // caller.
    291   virtual DeviceDescription *PopulateDeviceDescription() const = 0;
    292 
    293   // Attempts to register the provided TraceListener with the device-specific
    294   // Executor implementation. When this is called, the PIMPL interface has
    295   // already taken ownership of the object and is managing the generic tracing
    296   // events. The device-specific implementation must determine if the passed
    297   // listener is of a type appropriate for it to trace during registration (and
    298   // before dispatching events to it).
    299   // Returns true if the listener was successfully registered, false otherwise.
    300   // Does not take ownership of listener.
    301   virtual bool RegisterTraceListener(TraceListener* listener) { return false; }
    302 
    303   // Unregisters the specified listener from the device-specific Executor.
    304   // Returns true if the listener was successfully registered, false otherwise.
    305   virtual bool UnregisterTraceListener(TraceListener* listener) {
    306     return false;
    307   }
    308 
    309   // Returns whether this StreamExecutor has BLAS support for its underlying
    310   // platform.
    311   virtual bool SupportsBlas() const { return false; }
    312 
    313   // Creates a new BlasSupport object, ownership is transferred to the caller.
    314   // If SupportsBlas() is false, this will always return null.
    315   //
    316   // If SupportsBlas() is true, this may return null, for example, if the BLAS
    317   // initialization fails.
    318   virtual blas::BlasSupport *CreateBlas() { return nullptr; }
    319 
    320   // Returns whether this StreamExecutor has FFT support for its underlying
    321   // platform.
    322   virtual bool SupportsFft() const { return false; }
    323 
    324   // Creates a new fft::FftSupport object, ownership is transferred to the
    325   // caller.
    326   // If SupportsFft() is false, this will always return null.
    327   //
    328   // If SupportsFft() is true, this may return null, for example, if the FFT
    329   // initialization fails.
    330   virtual fft::FftSupport *CreateFft() { return nullptr; }
    331 
    332   // Returns whether this StreamExecutor has Random Number Generation support
    333   // for
    334   // its underlying platform.
    335   virtual bool SupportsRng() const { return false; }
    336 
    337   // Returns whether this StreamExecutor has neural net support for its
    338   // underlying
    339   // platform.
    340   virtual bool SupportsDnn() const { return false; }
    341 
    342   // Creates a new RngSupport object, ownership is transferred to the caller.
    343   // If SupportsRng() is false, this will always return null.
    344   //
    345   // If SupportsRng() is true, this may return null, for example, if the RNG
    346   // initialization fails.
    347   virtual rng::RngSupport *CreateRng() { return nullptr; }
    348 
    349   // Creates a new DnnSupport object, ownership is transferred to the caller.
    350   // If SupportsDnn() is false, this will always return null.
    351   //
    352   // If SupportsDnn() is true, this may return null, for example, if the DNN
    353   // initialization fails.
    354   virtual dnn::DnnSupport *CreateDnn() { return nullptr; }
    355 
    356   // Each call creates a new instance of the platform-specific implementation of
    357   // the corresponding interface type.
    358   virtual std::unique_ptr<EventInterface> CreateEventImplementation() = 0;
    359   virtual std::unique_ptr<KernelInterface> CreateKernelImplementation() = 0;
    360   virtual std::unique_ptr<StreamInterface> GetStreamImplementation() = 0;
    361   virtual std::unique_ptr<TimerInterface> GetTimerImplementation() = 0;
    362 
    363   // Returns the CUDA or ROCm context associated with this StreamExecutor
    364   // platform implementation.
    365   //
    366   // WARNING: checks that the underlying platform is, in fact, CUDA or ROCm,
    367   // causing a fatal error if it is not. This hack is made available solely for
    368   // use from distbelief code, which temporarily has strong ties to CUDA or ROCm
    369   // as a platform.
    370   virtual void *GpuContextHack() { return nullptr; }
    371 
    372   // Return allocator statistics.
    373   virtual absl::optional<AllocatorStats> GetAllocatorStats() {
    374     return absl::nullopt;
    375   }
    376 
    377  private:
    378   SE_DISALLOW_COPY_AND_ASSIGN(StreamExecutorInterface);
    379 };
    380 
    381 using StreamExecutorFactory =
    382     std::function<StreamExecutorInterface *(const PluginConfig &)>;
    383 using EventFactory = std::function<EventInterface *(StreamExecutor *)>;
    384 using StreamFactory = std::function<StreamInterface *(StreamExecutor *)>;
    385 using TimerFactory = std::function<TimerInterface *(StreamExecutor *)>;
    386 using KernelFactory = std::function<KernelInterface*()>;
    387 
    388 StreamExecutorFactory *MakeCUDAExecutorImplementation();
    389 
    390 StreamExecutorFactory *MakeROCMExecutorImplementation();
    391 
    392 StreamExecutorFactory *MakeOpenCLExecutorImplementation();
    393 
    394 extern StreamExecutorFactory MakeHostExecutorImplementation;
    395 
    396 
    397 }  // namespace internal
    398 }  // namespace stream_executor
    399 
    400 #endif  // TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
    401