Home | History | Annotate | Download | only in stream_executor
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
// Interfaces for platform-dependent implementations to satisfy. These are
// delegated to from the StreamExecutor in pointer-to-implementation style; i.e.
// the StreamExecutor is just a husk that delegates calls to the
// platform-specific objects which implement the interfaces defined here.
     20 
     21 #ifndef TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
     22 #define TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
     23 
     24 #include <functional>
     25 #include <map>
     26 #include <memory>
     27 #include <utility>
     28 #include <vector>
     29 
     30 #include "tensorflow/stream_executor/device_description.h"
     31 #include "tensorflow/stream_executor/device_memory.h"
     32 #include "tensorflow/stream_executor/device_options.h"
     33 #include "tensorflow/stream_executor/dnn.h"
     34 #include "tensorflow/stream_executor/event.h"
     35 #include "tensorflow/stream_executor/kernel.h"
     36 #include "tensorflow/stream_executor/kernel_cache_config.h"
     37 #include "tensorflow/stream_executor/kernel_spec.h"
     38 #include "tensorflow/stream_executor/launch_dim.h"
     39 #include "tensorflow/stream_executor/lib/status.h"
     40 #include "tensorflow/stream_executor/lib/statusor.h"
     41 #include "tensorflow/stream_executor/platform.h"
     42 #include "tensorflow/stream_executor/platform/port.h"
     43 #include "tensorflow/stream_executor/plugin_registry.h"
     44 #include "tensorflow/stream_executor/shared_memory_config.h"
     45 #include "tensorflow/stream_executor/trace_listener.h"
     46 #include "tensorflow/stream_executor/lib/inlined_vector.h"
     47 
     48 namespace perftools {
     49 namespace gputools {
     50 
     51 class Stream;
     52 class Timer;
     53 
     54 namespace internal {
     55 
     56 // Platform-dependent interface class for the generic Events interface, in
     57 // the PIMPL style.
     58 class EventInterface {
     59  public:
     60   EventInterface() {}
     61   virtual ~EventInterface() {}
     62 
     63  private:
     64   SE_DISALLOW_COPY_AND_ASSIGN(EventInterface);
     65 };
     66 
     67 // Pointer-to-implementation object type (i.e. the KernelBase class delegates to
     68 // this interface) with virtual destruction. This class exists for the
     69 // platform-dependent code to hang any kernel data/resource info/functionality
     70 // off of.
     71 class KernelInterface {
     72  public:
     73   // Default constructor for the abstract interface.
     74   KernelInterface() {}
     75 
     76   // Default destructor for the abstract interface.
     77   virtual ~KernelInterface() {}
     78 
     79   // Returns the number of formal parameters that this kernel accepts.
     80   virtual unsigned Arity() const = 0;
     81 
     82   // Sets the preferred cache configuration.
     83   virtual void SetPreferredCacheConfig(KernelCacheConfig config) = 0;
     84 
     85   // Gets the preferred cache configuration.
     86   virtual KernelCacheConfig GetPreferredCacheConfig() const = 0;
     87 
     88  private:
     89   SE_DISALLOW_COPY_AND_ASSIGN(KernelInterface);
     90 };
     91 
     92 // Pointer-to-implementation object type (i.e. the Stream class delegates to
     93 // this interface) with virtual destruction. This class exists for the
     94 // platform-dependent code to hang any kernel data/resource info/functionality
     95 // off of.
     96 class StreamInterface {
     97  public:
     98   // Default constructor for the abstract interface.
     99   StreamInterface() {}
    100 
    101   // Default destructor for the abstract interface.
    102   virtual ~StreamInterface() {}
    103 
    104   // Returns the CUDA stream associated with this platform's stream
    105   // implementation.
    106   //
    107   // WARNING: checks that the underlying platform is, in fact, CUDA, causing a
    108   // fatal error if it is not. This hack is made available solely for use from
    109   // distbelief code, which temporarily has strong ties to CUDA as a platform.
    110   virtual void *CudaStreamHack() { return nullptr; }
    111 
    112   // See the above comment on CudaStreamHack -- this further breaks abstraction
    113   // for Eigen within distbelief, which has strong ties to CUDA as a platform,
    114   // and a historical attachment to a programming model which takes a
    115   // stream-slot rather than a stream-value.
    116   virtual void **CudaStreamMemberHack() { return nullptr; }
    117 
    118  private:
    119   SE_DISALLOW_COPY_AND_ASSIGN(StreamInterface);
    120 };
    121 
    122 // Pointer-to-implementation object type (i.e. the Timer class delegates to
    123 // this interface) with virtual destruction. This class exists for the
    124 // platform-dependent code to hang any timer data/resource info/functionality
    125 // off of.
    126 class TimerInterface {
    127  public:
    128   // Default constructor for the abstract interface.
    129   TimerInterface() {}
    130 
    131   // Default destructor for the abstract interface.
    132   virtual ~TimerInterface() {}
    133 
    134   // Returns the number of microseconds elapsed in a completed timer.
    135   virtual uint64 Microseconds() const = 0;
    136 
    137   // Returns the number of nanoseconds elapsed in a completed timer.
    138   virtual uint64 Nanoseconds() const = 0;
    139 
    140  private:
    141   SE_DISALLOW_COPY_AND_ASSIGN(TimerInterface);
    142 };
    143 
    144 // Interface for the different StreamExecutor platforms (i.e. CUDA, OpenCL).
    145 //
    146 // Various platforms will provide an implementation that satisfy this interface.
    147 class StreamExecutorInterface {
    148  public:
    149   // Default constructor for the abstract interface.
    150   StreamExecutorInterface() {}
    151 
    152   // Default destructor for the abstract interface.
    153   virtual ~StreamExecutorInterface() {}
    154 
    155   // Returns the (transitively) wrapped executor if this executor is
    156   // wrapping another executor; otherwise, returns this.
    157   virtual StreamExecutorInterface *GetUnderlyingExecutor() { return this; }
    158 
    159   // See the StreamExecutor interface for comments on the same-named methods.
    160   virtual port::Status Init(int device_ordinal,
    161                             DeviceOptions device_options) = 0;
    162 
    163   virtual bool GetKernel(const MultiKernelLoaderSpec &spec,
    164                          KernelBase *kernel) {
    165     return false;
    166   }
    167   virtual bool Launch(Stream *stream, const ThreadDim &thread_dims,
    168                       const BlockDim &block_dims, const KernelBase &k,
    169                       const KernelArgsArrayBase &args) {
    170     return false;
    171   }
    172   // Releases any state associated with the kernel.
    173   virtual void UnloadKernel(const KernelBase *kernel) {}
    174   virtual void *Allocate(uint64 size) = 0;
    175   virtual void *AllocateSubBuffer(DeviceMemoryBase *parent, uint64 offset,
    176                                   uint64 size) = 0;
    177   virtual void Deallocate(DeviceMemoryBase *mem) = 0;
    178   virtual void *HostMemoryAllocate(uint64 size) = 0;
    179   virtual void HostMemoryDeallocate(void *mem) = 0;
    180   virtual bool HostMemoryRegister(void *mem, uint64 size) = 0;
    181   virtual bool HostMemoryUnregister(void *mem) = 0;
    182   virtual bool SynchronizeAllActivity() = 0;
    183   virtual bool SynchronousMemZero(DeviceMemoryBase *location, uint64 size) = 0;
    184   virtual bool SynchronousMemSet(DeviceMemoryBase *location, int value,
    185                                  uint64 size) = 0;
    186   virtual port::Status SynchronousMemcpy(DeviceMemoryBase *gpu_dst,
    187                                          const void *host_src, uint64 size) = 0;
    188   virtual port::Status SynchronousMemcpy(void *host_dst,
    189                                          const DeviceMemoryBase &gpu_src,
    190                                          uint64 size) = 0;
    191   virtual port::Status SynchronousMemcpyDeviceToDevice(
    192       DeviceMemoryBase *gpu_dst, const DeviceMemoryBase &gpu_src,
    193       uint64 size) = 0;
    194   virtual bool MemZero(Stream *stream, DeviceMemoryBase *location,
    195                        uint64 size) = 0;
    196   virtual bool Memset(Stream *stream, DeviceMemoryBase *location, uint8 pattern,
    197                       uint64 size) {
    198     return false;
    199   }
    200   virtual bool Memset32(Stream *stream, DeviceMemoryBase *location,
    201                         uint32 pattern, uint64 size) = 0;
    202   virtual bool Memcpy(Stream *stream, void *host_dst,
    203                       const DeviceMemoryBase &gpu_src, uint64 size) = 0;
    204   virtual bool Memcpy(Stream *stream, DeviceMemoryBase *gpu_dst,
    205                       const void *host_src, uint64 size) = 0;
    206   virtual bool MemcpyDeviceToDevice(Stream *stream, DeviceMemoryBase *gpu_dst,
    207                                     const DeviceMemoryBase &host_src,
    208                                     uint64 size) = 0;
    209   virtual bool HostCallback(Stream *stream, std::function<void()> callback) = 0;
    210   virtual port::Status AllocateEvent(Event *event) = 0;
    211   virtual port::Status DeallocateEvent(Event *event) = 0;
    212   virtual port::Status RecordEvent(Stream *stream, Event *event) = 0;
    213   virtual port::Status WaitForEvent(Stream *stream, Event *event) = 0;
    214   virtual Event::Status PollForEventStatus(Event *event) = 0;
    215   virtual bool AllocateStream(Stream *stream) = 0;
    216   virtual void DeallocateStream(Stream *stream) = 0;
    217   virtual bool CreateStreamDependency(Stream *dependent, Stream *other) = 0;
    218   virtual bool AllocateTimer(Timer *timer) = 0;
    219   virtual void DeallocateTimer(Timer *timer) = 0;
    220   virtual bool StartTimer(Stream *stream, Timer *timer) = 0;
    221   virtual bool StopTimer(Stream *stream, Timer *timer) = 0;
    222   virtual port::Status BlockHostUntilDone(Stream *stream) = 0;
    223   virtual int PlatformDeviceCount() = 0;
    224   virtual port::Status EnablePeerAccessTo(StreamExecutorInterface *other) = 0;
    225   virtual bool CanEnablePeerAccessTo(StreamExecutorInterface *other) = 0;
    226   virtual SharedMemoryConfig GetDeviceSharedMemoryConfig() = 0;
    227   virtual port::Status SetDeviceSharedMemoryConfig(
    228       SharedMemoryConfig config) = 0;
    229 
    230   virtual int64 GetDeviceLoad() { return -1; }
    231 
    232   virtual bool DeviceMemoryUsage(int64 *free, int64 *total) const {
    233     return false;
    234   }
    235 
    236   // Retrieves device pointer and size for a symbol. The device pointer is
    237   // stored at mem, and the size is stored at size. Either mem or bytes can be
    238   // null, however, both of them cannot be null at the same time. To use
    239   // constant memory in CUDA, GetSymbol has to be used. Returns true if symbol
    240   // is found.
    241   virtual bool GetSymbol(const string& symbol_name, void **mem, size_t *bytes) {
    242     return false;
    243   }
    244 
    245   // Creates a new DeviceDescription object. Ownership is transferred to the
    246   // caller.
    247   virtual DeviceDescription *PopulateDeviceDescription() const = 0;
    248 
    249   // Attempts to register the provided TraceListener with the device-specific
    250   // Executor implementation. When this is called, the PIMPL interface has
    251   // already taken ownership of the object and is managing the generic tracing
    252   // events. The device-specific implementation must determine if the passed
    253   // listener is of a type appropriate for it to trace during registration (and
    254   // before dispatching events to it).
    255   // Returns true if the listener was successfully registered, false otherwise.
    256   // Does not take ownership of listener.
    257   virtual bool RegisterTraceListener(TraceListener* listener) { return false; }
    258 
    259   // Unregisters the specified listener from the device-specific Executor.
    260   // Returns true if the listener was successfully registered, false otherwise.
    261   virtual bool UnregisterTraceListener(TraceListener* listener) {
    262     return false;
    263   }
    264 
    265   // Returns whether this StreamExecutor has BLAS support for its underlying
    266   // platform.
    267   virtual bool SupportsBlas() const { return false; }
    268 
    269   // Creates a new BlasSupport object, ownership is transferred to the caller.
    270   // If SupportsBlas() is false, this will always return null.
    271   //
    272   // If SupportsBlas() is true, this may return null, for example, if the BLAS
    273   // initialization fails.
    274   virtual blas::BlasSupport *CreateBlas() { return nullptr; }
    275 
    276   // Returns whether this StreamExecutor has FFT support for its underlying
    277   // platform.
    278   virtual bool SupportsFft() const { return false; }
    279 
    280   // Creates a new fft::FftSupport object, ownership is transferred to the
    281   // caller.
    282   // If SupportsFft() is false, this will always return null.
    283   //
    284   // If SupportsFft() is true, this may return null, for example, if the FFT
    285   // initialization fails.
    286   virtual fft::FftSupport *CreateFft() { return nullptr; }
    287 
    288   // Returns whether this StreamExecutor has Random Number Generation support
    289   // for
    290   // its underlying platform.
    291   virtual bool SupportsRng() const { return false; }
    292 
    293   // Returns whether this StreamExecutor has neural net support for its
    294   // underlying
    295   // platform.
    296   virtual bool SupportsDnn() const { return false; }
    297 
    298   // Creates a new RngSupport object, ownership is transferred to the caller.
    299   // If SupportsRng() is false, this will always return null.
    300   //
    301   // If SupportsRng() is true, this may return null, for example, if the RNG
    302   // initialization fails.
    303   virtual rng::RngSupport *CreateRng() { return nullptr; }
    304 
    305   // Creates a new DnnSupport object, ownership is transferred to the caller.
    306   // If SupportsDnn() is false, this will always return null.
    307   //
    308   // If SupportsDnn() is true, this may return null, for example, if the DNN
    309   // initialization fails.
    310   virtual dnn::DnnSupport *CreateDnn() { return nullptr; }
    311 
    312   // Each call creates a new instance of the platform-specific implementation of
    313   // the corresponding interface type.
    314   virtual std::unique_ptr<EventInterface> CreateEventImplementation() = 0;
    315   virtual std::unique_ptr<KernelInterface> CreateKernelImplementation() = 0;
    316   virtual std::unique_ptr<StreamInterface> GetStreamImplementation() = 0;
    317   virtual std::unique_ptr<TimerInterface> GetTimerImplementation() = 0;
    318 
    319   // Returns the CUDA context associated with this StreamExecutor platform
    320   // implementation.
    321   //
    322   // WARNING: checks that the underlying platform is, in fact, CUDA, causing a
    323   // fatal error if it is not. This hack is made available solely for use from
    324   // distbelief code, which temporarily has strong ties to CUDA as a platform.
    325   virtual void *CudaContextHack() { return nullptr; }
    326 
    327  private:
    328   SE_DISALLOW_COPY_AND_ASSIGN(StreamExecutorInterface);
    329 };
    330 
    331 using StreamExecutorFactory =
    332     std::function<StreamExecutorInterface *(const PluginConfig &)>;
    333 using EventFactory = std::function<EventInterface *(StreamExecutor *)>;
    334 using StreamFactory = std::function<StreamInterface *(StreamExecutor *)>;
    335 using TimerFactory = std::function<TimerInterface *(StreamExecutor *)>;
    336 using KernelFactory = std::function<KernelInterface*()>;
    337 
    338 StreamExecutorFactory* MakeCUDAExecutorImplementation();
    339 
    340 StreamExecutorFactory* MakeOpenCLExecutorImplementation();
    341 
    342 extern StreamExecutorFactory MakeHostExecutorImplementation;
    343 
    344 
    345 }  // namespace internal
    346 }  // namespace gputools
    347 }  // namespace perftools
    348 
    349 #endif  // TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
    350