/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_
#define TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_

#include <atomic>
#include <functional>
#include <map>
#include <memory>
#include <set>
#include <tuple>
#include <vector>

#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/lib/strcat.h"
#include "tensorflow/stream_executor/lib/threadpool.h"
#include "tensorflow/stream_executor/platform.h"
#include "tensorflow/stream_executor/platform/logging.h"
#include "tensorflow/stream_executor/platform/mutex.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/platform/thread_annotations.h"
#include "tensorflow/stream_executor/rng.h"
#include "tensorflow/stream_executor/shared_memory_config.h"
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"
#include "tensorflow/stream_executor/trace_listener.h"

namespace perftools {
namespace gputools {

// Structure used for device memory leak checking.
struct AllocRecord {
  // The requested allocation size of the buffer.
  uint64 bytes;

  // Holds a representation of the stack at the time the associated buffer was
  // allocated. Produced in a form described in
  // //util/symbolize/symbolized_stacktrace.h.
  string stack_trace;
};

// Forward declaration of private friend class.
template <typename BeginCallT, typename CompleteCallT,
          typename ReturnT, typename... BeginArgsT>
class ScopedTracer;

// A StreamExecutor manages a single device, in terms of executing work (kernel
// launches) and memory management (allocation/deallocation, memory copies to
// and from the device). It is conceptually the "handle" for a device -- Stream
// objects, which are used to enqueue work to run on the coprocessor, have a
// StreamExecutor instance as their "parent" object.
//
// StreamExecutor objects have an underlying platform that is specified up
// front; e.g. it is either a CUDA or an OpenCL executor.
//
// Thread-safe after initialization.
// The StreamExecutor interface should not be invoked from a signal handler.
class StreamExecutor {
 public:
  explicit StreamExecutor(PlatformKind kind,
                          const PluginConfig &plugin_config = PluginConfig());

  StreamExecutor(
      const Platform *platform,
      std::unique_ptr<internal::StreamExecutorInterface> implementation);

  ~StreamExecutor();

  port::Status Init();
  port::Status Init(int device_ordinal, DeviceOptions device_options);

  // DEPRECATED: Do not use; use platform() instead.
  // Returns the platform that this StreamExecutor is acting upon.
  PlatformKind platform_kind() const { return platform_kind_; }

  // Returns a reference to the platform that created this executor.
  const Platform *platform() const { return platform_; }

  // Retrieves (loads) a kernel for the platform this StreamExecutor is acting
  // upon, if one exists.
  //
  // Parameters:
  //   spec: The MultiKernelLoaderSpec is usually generated as a compile-time
  //    constant into an appropriate namespace. For example, see
  //    perftools::gputools::executor_sample::kKernelLoaderSpecs, from which a
  //    MultiKernelLoaderSpec is selected.
  //   kernel: Outparam that the kernel is loaded into. A given kernel object
  //    should not be loaded into more than once.
  //
  // If an error occurs, or there is no kernel available for the StreamExecutor
  // platform, false is returned.
  bool GetKernel(const MultiKernelLoaderSpec &spec, KernelBase *kernel);

  // Releases any state associated with the previously loaded kernel.
  void UnloadKernel(const KernelBase *kernel);

  // Synchronously allocates an array on the device of type T with
  // element_count elements.
  template <typename T>
  DeviceMemory<T> AllocateArray(uint64 element_count);

  // As AllocateArray(), but returns a ScopedDeviceMemory<T>.
  template <typename T>
  ScopedDeviceMemory<T> AllocateOwnedArray(uint64 element_count) {
    return ScopedDeviceMemory<T>(this, AllocateArray<T>(element_count));
  }

  // Convenience wrapper that allocates space for a single element of type T in
  // device memory.
  template <typename T>
  DeviceMemory<T> AllocateScalar() {
    return AllocateArray<T>(1);
  }

  // As AllocateScalar(), but returns a ScopedDeviceMemory<T>.
  template <typename T>
  ScopedDeviceMemory<T> AllocateOwnedScalar() {
    return AllocateOwnedArray<T>(1);
  }

  // Synchronously allocates a scalar of type T on the device that is (POD)
  // zero-byte initialized.
  template <typename T>
  DeviceMemory<T> AllocateZeroed();

  // As AllocateZeroed(), but returns a ScopedDeviceMemory<T>.
  template <typename T>
  ScopedDeviceMemory<T> AllocateOwnedZeroed() {
    return ScopedDeviceMemory<T>(this, AllocateZeroed<T>());
  }
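
  // A minimal usage sketch for the allocation helpers above (illustrative
  // only; assumes an already-initialized StreamExecutor* named "executor"):
  //
  //   // 1024 floats, freed explicitly via Deallocate().
  //   DeviceMemory<float> raw = executor->AllocateArray<float>(1024);
  //   if (raw == nullptr) { /* handle allocation failure */ }
  //   executor->Deallocate(&raw);
  //
  //   // Same allocation, but freed automatically when "owned" goes out of
  //   // scope.
  //   ScopedDeviceMemory<float> owned =
  //       executor->AllocateOwnedArray<float>(1024);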

  // Allocate a memory region inside another allocated memory region.
  // Offset and size are specified in terms of T elements.
  // Warning: Do not free a parent buffer before its sub-buffers; this may cause
  // use-after-free issues (the specific behavior is not consistent across
  // platforms).
  //  - Note: OpenCL uses refcounting to manage buffer lifetimes, so use of a
  //    sub-buffer after parent deallocation is expected to be safe. This will
  //    render your code non-platform-portable, however.
  template <typename T>
  DeviceMemory<T> AllocateSubBuffer(DeviceMemory<T> *parent,
                                    uint64 element_offset,
                                    uint64 element_count);

  // As AllocateSubBuffer(), but returns a ScopedDeviceMemory<T>.
  template <typename T>
  ScopedDeviceMemory<T> AllocateOwnedSubBuffer(DeviceMemory<T> *parent,
                                               uint64 element_offset,
                                               uint64 element_count) {
    return ScopedDeviceMemory<T>(
        this, AllocateSubBuffer<T>(parent, element_offset, element_count));
  }
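
  // Illustrative sketch (hypothetical names; assumes "executor" as above and
  // a parent buffer that outlives the sub-buffer):
  //
  //   DeviceMemory<int> parent = executor->AllocateArray<int>(256);
  //   // View of elements [64, 96) of the parent allocation.
  //   DeviceMemory<int> view = executor->AllocateSubBuffer<int>(
  //       &parent, /*element_offset=*/64, /*element_count=*/32);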

  // Finds a symbol and returns device memory allocated to the symbol. The
  // symbol is searched for in any kernels that were previously loaded through
  // GetKernel() before the GetSymbol() call. The caller has to make sure that
  // the type of the symbol and T match.
  // - Note: symbol_name should include its namespace as well. For example,
  //         pass "nms0::symbol" if referring to nms0::symbol.
  template <typename T>
  port::StatusOr<DeviceMemory<T>> GetSymbol(const string &symbol_name);
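
  // Sketch of a symbol lookup (hypothetical symbol name; assumes the kernel
  // defining it was loaded via GetKernel() beforehand):
  //
  //   port::StatusOr<DeviceMemory<float>> sym =
  //       executor->GetSymbol<float>("some_namespace::some_symbol");
  //   if (sym.ok()) {
  //     DeviceMemory<float> mem = sym.ValueOrDie();
  //     // ... use mem ...
  //   }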

  // Deallocate the DeviceMemory previously allocated via this interface.
  // Deallocation of a nullptr-representative value is permitted.
  //
  // Resets the internal contents of mem to be null-representative, but this
  // null-out effect should not be relied upon in client code.
  void Deallocate(DeviceMemoryBase *mem);

  // Retrieves a mapping of active opaque device memory pointer to a string
  // representation of the [allocating thread's] stack at the time the pointer
  // was allocated. Useful for tracking device memory leaks.
  //
  // Note: this will only be populated if the --check_device_leaks flag is
  // activated.
  void GetMemAllocs(std::map<void *, AllocRecord> *records_out);

  // Allocates a region of host memory and registers it with the platform API.
  // Memory allocated in this manner (or allocated and registered with
  // HostMemoryRegister()) is required for use in asynchronous memcpy
  // operations, such as Stream::ThenMemcpy.
  void *HostMemoryAllocate(uint64 bytes);

  // Deallocates a region of host memory allocated by HostMemoryAllocate().
  void HostMemoryDeallocate(void *location);

  // Registers a region of host memory with the platform API. Registered memory
  // (or memory allocated with HostMemoryAllocate) is required for use with
  // asynchronous memcpy operations, such as Stream::ThenMemcpy. This method
  // is used to register memory allocated outside the StreamExecutor;
  // HostMemoryAllocate implicitly registers its allocations and
  // HostMemoryDeallocate implicitly deregisters on deallocation.
  bool HostMemoryRegister(void *location, uint64 size) SE_MUST_USE_RESULT;

  // Unregisters a region of host memory registered with HostMemoryRegister.
  // This should be done before deallocating the region with delete[]/free/etc.
  bool HostMemoryUnregister(void *location) SE_MUST_USE_RESULT;
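
  // Sketch of registering externally-allocated host memory so it can be used
  // with asynchronous copies (illustrative; "executor" and the buffer are
  // assumptions):
  //
  //   std::vector<float> staging(1 << 20);
  //   if (executor->HostMemoryRegister(staging.data(),
  //                                    staging.size() * sizeof(float))) {
  //     // ... enqueue Stream::ThenMemcpy calls that touch "staging" ...
  //     CHECK(executor->HostMemoryUnregister(staging.data()));
  //   }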

  // Synchronizes all activity occurring in the StreamExecutor's context (most
  // likely a whole device).
  bool SynchronizeAllActivity() SE_MUST_USE_RESULT;

  // Blocks the caller while "size" bytes are zeroed out (in POD fashion) at the
  // given location in device memory.
  bool SynchronousMemZero(DeviceMemoryBase *location,
                          uint64 size) SE_MUST_USE_RESULT;

  // Blocks the caller while "size" bytes are initialized to "value" (in POD
  // fashion) at the given location in device memory.
  bool SynchronousMemSet(DeviceMemoryBase *location, int value,
                         uint64 size) SE_MUST_USE_RESULT;

  // [deprecated] Blocks the caller while a data segment of the given size is
  // copied from the host source to the device destination.
  //
  // Deprecation: prefer explicit H2D below, to avoid error-prone API usage.
  bool SynchronousMemcpy(DeviceMemoryBase *device_dst, const void *host_src,
                         uint64 size) SE_MUST_USE_RESULT;

  // [deprecated] Blocks the caller while a data segment of the given size is
  // copied from the device source to the host destination.
  //
  // Deprecation: prefer explicit D2H below, to avoid error-prone API usage.
  bool SynchronousMemcpy(void *host_dst, const DeviceMemoryBase &device_src,
                         uint64 size) SE_MUST_USE_RESULT;

  // Same as SynchronousMemcpy(DeviceMemoryBase*, ...) above.
  port::Status SynchronousMemcpyH2D(const void *host_src, int64 size,
                                    DeviceMemoryBase *device_dst);

  // Alternative interface for memcpying from host to device that takes an
  // array slice. Checks that the destination size can accommodate the host
  // slice size.
  template <class T>
  port::Status SynchronousMemcpyH2D(port::ArraySlice<T> host_src,
                                    DeviceMemoryBase *device_dst) {
    auto host_size = host_src.size() * sizeof(T);
    CHECK(device_dst->size() == 0 || device_dst->size() >= host_size);
    return SynchronousMemcpyH2D(host_src.begin(), host_size, device_dst);
  }

  // Same as SynchronousMemcpy(void*, ...) above.
  port::Status SynchronousMemcpyD2H(const DeviceMemoryBase &device_src,
                                    int64 size, void *host_dst);

  // Alternative interface for memcpying from device to host that takes an
  // array slice. Checks that the destination size can accommodate the host
  // slice size.
  template <typename T>
  port::Status SynchronousMemcpyD2H(const DeviceMemory<T> &device_src,
                                    port::MutableArraySlice<T> host_dst) {
    auto host_size = host_dst.size() * sizeof(T);
    CHECK(device_src.size() == 0 || host_size >= device_src.size());
    return SynchronousMemcpyD2H(device_src, host_size, host_dst.begin());
  }
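
  // Sketch of a synchronous host<->device round trip with the Status-based
  // interfaces above (illustrative; sizes are in bytes):
  //
  //   std::vector<float> host_in(16, 1.0f), host_out(16);
  //   DeviceMemory<float> dev = executor->AllocateArray<float>(16);
  //   port::Status s = executor->SynchronousMemcpyH2D(
  //       host_in.data(), host_in.size() * sizeof(float), &dev);
  //   if (s.ok()) {
  //     s = executor->SynchronousMemcpyD2H(
  //         dev, host_out.size() * sizeof(float), host_out.data());
  //   }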

  // Blocks the caller while a data segment of the given size is copied from the
  // device source to the device destination.
  bool SynchronousMemcpy(DeviceMemoryBase *device_dst,
                         const DeviceMemoryBase &device_src,
                         uint64 size) SE_MUST_USE_RESULT;

  // Enqueues an operation onto stream to zero out size bytes at the given
  // device memory location. Neither stream nor location may be null. Returns
  // whether the operation was successfully enqueued onto the stream.
  bool MemZero(Stream *stream, DeviceMemoryBase *location,
               uint64 size) SE_MUST_USE_RESULT;

  // Enqueues an operation onto stream to set 32-bit patterns starting at
  // location, for byte count given by size. size must be 32-bit aligned
  // (i.e. evenly divisible by 4). Returns whether the operation was
  // successfully enqueued onto the stream.
  bool Memset32(Stream *stream, DeviceMemoryBase *location, uint32 pattern,
                uint64 size) SE_MUST_USE_RESULT;
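
  // Sketch of asynchronous fills enqueued on a stream (illustrative; assumes
  // a valid Stream "stream" bound to this executor):
  //
  //   DeviceMemory<uint32> buf = executor->AllocateArray<uint32>(256);
  //   CHECK(executor->MemZero(&stream, &buf, 256 * sizeof(uint32)));
  //   CHECK(executor->Memset32(&stream, &buf, 0xdeadbeef,
  //                            256 * sizeof(uint32)));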

  // Enables peer access from this StreamExecutor to memory
  // allocated by other, such that launched device code, memcpies, etc. may
  // access it directly.
  //
  // Both this StreamExecutor and other must be backed by the same platform
  // (as in CUDA vs OpenCL) implementation.
  port::Status EnablePeerAccessTo(StreamExecutor *other);

  // Returns whether it's possible to enable peer access from this
  // StreamExecutor to memory allocated by another.
  //
  // Even when this returns true, EnablePeerAccessTo may fail for other reasons;
  // this is more an up-front test as to whether it's expressly forbidden.
  bool CanEnablePeerAccessTo(StreamExecutor *other);
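
  // Sketch of enabling peer access between two executors on the same platform
  // ("exec_a" and "exec_b" are hypothetical):
  //
  //   if (exec_a->CanEnablePeerAccessTo(exec_b)) {
  //     port::Status status = exec_a->EnablePeerAccessTo(exec_b);
  //     if (!status.ok()) {
  //       // Handle or log the failure; peer access is not available.
  //     }
  //   }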

  // Gets the preferred shared memory configuration for the device to which this
  // executor is bound.
  SharedMemoryConfig GetDeviceSharedMemoryConfig();

  // Sets the preferred shared memory configuration for the device to which this
  // executor is bound.
  port::Status SetDeviceSharedMemoryConfig(SharedMemoryConfig config);

  // Obtains metadata about the underlying device.
  // The value is cached on first use.
  const DeviceDescription &GetDeviceDescription() const;

  // If implemented, returns device specific measurement of load
  // (e.g. pending requests).
  int64 GetDeviceLoad() const;

  // Returns the underlying device memory usage information, if it is available.
  // If it is not available (false is returned), free/total may not be
  // initialized.
  //
  // Note: "Free" reflects the amount of free memory on the underlying device,
  // so allocations via other StreamExecutors that have the same underlying
  // device will be reflected in "free".
  bool DeviceMemoryUsage(int64 *free, int64 *total) const;
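
  // Sketch of querying device metadata and memory usage (illustrative):
  //
  //   const DeviceDescription &desc = executor->GetDeviceDescription();
  //   int64 free_bytes = 0, total_bytes = 0;
  //   if (executor->DeviceMemoryUsage(&free_bytes, &total_bytes)) {
  //     // free_bytes/total_bytes now describe the underlying device.
  //   }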

  // The device count reported by this StreamExecutor's platform.
  // Note: on OpenCL we implicitly select platform zero at the moment.
  int PlatformDeviceCount() const;

  // Returns whether the StreamExecutor supports BLAS routines for the platform
  // that underlies this interface.
  bool SupportsBlas() const;

  // Returns whether the StreamExecutor supports FFT routines for the platform
  // that underlies this interface.
  bool SupportsFft() const;

  // Returns whether the StreamExecutor supports RNG routines for the platform
  // that underlies this interface.
  bool SupportsRng() const;

  // Returns whether the StreamExecutor supports neural net routines for the
  // platform that underlies this interface.
  bool SupportsDnn() const;

  // Gets the list of supported algorithms for the forward convolution
  // operation.
  bool GetConvolveAlgorithms(bool with_winograd_nonfused,
                             std::vector<dnn::AlgorithmDesc> *out_algorithms);

  // Gets the list of supported algorithms for the backward convolution on data.
  bool GetConvolveBackwardDataAlgorithms(
      bool with_winograd_nonfused,
      std::vector<dnn::AlgorithmDesc> *out_algorithms);

  // Gets the list of supported algorithms for the backward convolution on the
  // filter.
  bool GetConvolveBackwardFilterAlgorithms(
      bool with_winograd_nonfused,
      std::vector<dnn::AlgorithmDesc> *out_algorithms);

  // Gets the list of supported algorithms for BLAS gemm.
  bool GetBlasGemmAlgorithms(std::vector<blas::AlgorithmType> *out_algorithms);
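
  // Sketch of enumerating supported convolution algorithms (illustrative;
  // only meaningful when SupportsDnn() returns true):
  //
  //   std::vector<dnn::AlgorithmDesc> algorithms;
  //   if (executor->SupportsDnn() &&
  //       executor->GetConvolveAlgorithms(/*with_winograd_nonfused=*/true,
  //                                       &algorithms)) {
  //     // ... choose an algorithm for subsequent convolution calls ...
  //   }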

  // Creates an RNN descriptor based on model shapes and configurations.
  // The caller retains ownership of the returned descriptor.
  port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor(
      int num_layers, int hidden_size, int input_size,
      dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
      dnn::RnnMode rnn_mode, dnn::DataType data_type, float dropout,
      uint64 seed, ScratchAllocator *state_allocator);

  // Creates an RNN sequence descriptor that specifies either the input or
  // output sequence. The caller retains ownership of the returned descriptor.
  port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
  createRnnSequenceTensorDescriptor(int seq_length, int batch_size,
                                    int data_size, dnn::DataType data_type);

  // Creates an RNN state descriptor that specifies the input or hidden state.
  // The caller retains ownership of the returned descriptor.
  port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
  createRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size,
                                 dnn::DataType data_type);

  // Returns the device ordinal that this StreamExecutor was initialized with.
  // Meaningless before initialization.
  int device_ordinal() const { return device_ordinal_; }

  // Returns a borrowed pointer to the underlying StreamExecutor implementation.
  internal::StreamExecutorInterface *implementation();

  // Warning: use Stream::ThenLaunch instead; this method is not for general
  // consumption. However, this is the only way to launch a kernel for which
  // the type signature is only known at runtime; say, if an application
  // supports loading/launching kernels with arbitrary type signatures.
  // In this case, the application is expected to know how to do parameter
  // packing that obeys the contract of the underlying platform implementation.
  //
  // Launches a data parallel kernel with the given thread/block
  // dimensionality and already-packed args/sizes to pass to the underlying
  // platform driver.
  //
  // This is called by Stream::Launch() to delegate to the platform's launch
  // implementation in StreamExecutorInterface::Launch().
  bool Launch(Stream *stream, const ThreadDim &thread_dims,
              const BlockDim &block_dims, const KernelBase &kernel,
              const KernelArgsArrayBase &args);

  // Gets-or-creates (creates with memoization) an FftSupport datatype that can
  // be used to execute FFT routines on the current platform.
  //
  // Ownership and user-facing behavior are the same as for AsBlas() below.
  //
  // Returns null if there was an error initializing the FFT support for the
  // underlying platform.
  fft::FftSupport *AsFft();

  // Gets-or-creates (creates with memoization) a DnnSupport datatype that can
  // be used for neural network routines on the current platform.
  //
  // Ownership and user-facing behavior are the same as for AsBlas() below.
  //
  // Returns null if there was an error initializing the DNN support for the
  // underlying platform.
  dnn::DnnSupport *AsDnn();

  // Turns StreamExecutor operation tracing on or off.
  void EnableTracing(bool enable);

  // Registers a trace listener to receive callbacks for only a single
  // StreamExecutor instance.
  // To register a listener for all executors for a given platform, see
  // Platform::RegisterTraceListener().
  // Does not take ownership of listener.
  void RegisterTraceListener(TraceListener* listener);

  // Removes a TraceListener from this StreamExecutor instance.
  // Returns false (and logs) in cases where the argument listener was not
  // previously registered.
  bool UnregisterTraceListener(TraceListener* listener);
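
  // Sketch of per-executor tracing (illustrative; "MyTraceListener" is a
  // hypothetical TraceListener subclass defined elsewhere):
  //
  //   MyTraceListener listener;
  //   executor->RegisterTraceListener(&listener);
  //   executor->EnableTracing(true);
  //   // ... enqueue work; the listener receives callbacks ...
  //   executor->EnableTracing(false);
  //   CHECK(executor->UnregisterTraceListener(&listener));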

 private:
  template <typename BeginCallT, typename CompleteCallT,
            typename ReturnT, typename... BeginArgsT>
  friend class ScopedTracer;
  friend class Event;
  friend class Stream;
  friend class Timer;
  template <typename... Params>
  friend class TypedKernel;
  template <typename... Args>
  friend struct ThenBlasImpl;

  // Gets-or-creates (creates with memoization) a BlasSupport datatype that can
  // be used to execute BLAS routines on the current platform. This is typically
  // not user-facing, as users will use the Stream::ThenBlas* family of routines
  // to entrain BLAS operations. See blas.h for additional details.
  //
  // Ownership is not transferred to the caller -- ownership is retained by this
  // object for memoization. This BLAS interface is also only expected to be
  // used by a Stream for entraining calls to BLAS functionality.
  //
  // Returns null if there was an error initializing the BLAS support for the
  // underlying platform.
  blas::BlasSupport *AsBlas();

  // Gets-or-creates (creates with memoization) an RngSupport datatype that can
  // be used for random-number-generation routines on the current platform.
  //
  // Ownership and user-facing behavior are the same as for AsBlas() above.
  //
  // Returns null if there was an error initializing the RNG support for the
  // underlying platform.
  rng::RngSupport *AsRng();

  // Causes the host code to synchronously wait for operations entrained onto
  // stream to complete. Effectively a join on the asynchronous device
  // operations enqueued on the stream before this program point.
  port::Status BlockHostUntilDone(Stream *stream);

  // Synchronously allocates size bytes on the underlying platform and returns
  // an opaque void* representing that allocation. In the case of failure,
  // nullptr is returned.
  void *Allocate(uint64 size);

  // Finds and retrieves device memory for the symbol on the underlying
  // platform.
  bool GetSymbol(const string& symbol_name, void **mem, size_t *bytes);

  // Entrains a memcpy operation onto stream, with a host destination location
  // host_dst and a device memory source, with target size size.
  bool Memcpy(Stream *stream, void *host_dst,
              const DeviceMemoryBase &device_src, uint64 size);

  // Entrains a memcpy operation onto stream, with a device destination location
  // and a host memory source, with target size size.
  bool Memcpy(Stream *stream, DeviceMemoryBase *device_dst,
              const void *host_src, uint64 size);

  // Entrains a memcpy operation onto stream, with a device destination location
  // and a device source location, with target size size. Peer access should
  // have been enabled between the StreamExecutors owning the device memory
  // regions.
  bool MemcpyDeviceToDevice(Stream *stream, DeviceMemoryBase *device_dst,
                            const DeviceMemoryBase &device_src, uint64 size);

  // Entrains on a stream a user-specified function to be run on the host.
  // See Stream::ThenDoHostCallback for full details.
  bool HostCallback(Stream *stream, std::function<void()> callback);

  // Performs platform-specific allocation and initialization of an event.
  port::Status AllocateEvent(Event *event);

  // Performs platform-specific deallocation and cleanup of an event.
  port::Status DeallocateEvent(Event *event);

  // Inserts the specified event at the end of the specified stream.
  port::Status RecordEvent(Stream *stream, Event *event);

  // Wait for the specified event at the end of the specified stream.
  port::Status WaitForEvent(Stream *stream, Event *event);

  // Requests the current status of the event from the underlying platform.
  Event::Status PollForEventStatus(Event *event);

  // Allocates stream resources on the underlying platform for subject and
  // initializes its internals.
  bool AllocateStream(Stream *subject);

  // Deallocates stream resources on the underlying platform.
  void DeallocateStream(Stream *subject);

  // Causes dependent to not begin execution until other has finished its
  // last-enqueued work.
  bool CreateStreamDependency(Stream *dependent, Stream *other);

  // Allocates timer resources on the underlying platform for subject and
  // initializes its internals.
  bool AllocateTimer(Timer *subject);

  // Deallocates timer resources on the underlying platform.
  void DeallocateTimer(Timer *subject);

  // Records a start event for an interval timer.
  bool StartTimer(Stream *stream, Timer *timer);

  // Records a stop event for an interval timer.
  bool StopTimer(Stream *stream, Timer *timer);

  // Allocates a new metadata object, appropriately populated, on the heap, with
  // ownership transfer to caller.
  DeviceDescription *PopulateDeviceDescription() const;

  // Adds a task to the port::ThreadPool work queue. These tasks must be
  // fire-and-forget and have no external data or timing dependencies; their
  // execution order and completion time have no guarantees.
  // For an example of an appropriate task, see HostBlas::DoBlasGemmInternal;
  // there, temporary internal buffers are freed using this method.
  void EnqueueOnBackgroundThread(std::function<void()> task);

  // Adds an AllocRecord for 'opaque' of size 'bytes' to the record map, for
  // leak checking. NULL buffer pointers and buffer sizes of 0 will not be
  // tracked.
  void CreateAllocRecord(void *opaque, uint64 size);

  // Removes the AllocRecord keyed by 'opaque' from the record map. NULL
  // pointers will not be erased (as they're not tracked, per above).
  void EraseAllocRecord(void *opaque);

  // Calls the relevant TraceListener routine to begin tracing for the specified
  // asynchronous method.
  template <typename TraceCallT, typename... ArgsT>
  void SubmitTrace(TraceCallT trace_call, ArgsT&&... args);

  // Reader/writer lock for class-static StreamExecutor members.
  static mutex static_mu_;

  // Reader/writer lock for mutable data structures on this StreamExecutor.
  //
  // Mutable so that caching functions (like DeviceDescription, AsBlas, etc.)
  // can acquire the lock on their first (mutating) call as well.
  mutable mutex mu_;

  // Reference to the platform that created this executor.
  const Platform *platform_;

  // Pointer to the platform-specific-interface implementation. This is
  // delegated to by the interface routines in pointer-to-implementation
  // fashion.
  std::unique_ptr<internal::StreamExecutorInterface> implementation_;

  // A mapping of pointer (to device memory) to string representation of the
  // stack (of the allocating thread) at the time at which the pointer was
  // allocated.
  std::map<void *, AllocRecord> mem_allocs_ GUARDED_BY(mu_);

  // Memoized BLAS support object -- we only want to create this once when asked
  // for a BLAS interface.
  std::unique_ptr<blas::BlasSupport> blas_ GUARDED_BY(mu_);

  // Memoized DNN support object -- we only want to create this once when asked
  // for a DNN interface.
  std::unique_ptr<dnn::DnnSupport> dnn_ GUARDED_BY(mu_);

  // Memoized FFT support object -- we only want to create this once when asked
  // for an FFT interface.
  std::unique_ptr<fft::FftSupport> fft_;

  // Memoized RNG support object -- we only want to create this once when asked
  // for an RNG interface.
  std::unique_ptr<rng::RngSupport> rng_ GUARDED_BY(mu_);

  // Slot to cache the owned DeviceDescription for the underlying device
  // once it has been queried from GetDeviceDescription().
  mutable std::unique_ptr<DeviceDescription> device_description_
      GUARDED_BY(mu_);

  // The kind of the underlying platform that is being targeted, as passed
  // during construction.
  //
  // Immutable post-initialization.
  PlatformKind platform_kind_;

  // The device ordinal that this object was initialized with.
  //
  // Immutable post-initialization.
  int device_ordinal_;

  // Executor for handling host callback work that cannot be performed
  // by a host callback thread - for example, cleanup after a host BLAS routine
  // (which may make device API calls). This work cannot block the host
  // callback thread, will be completed asynchronously, and should be treated
  // as fire-and-forget. Assume no ordering guarantees WRT the tasks enqueued
  // here.
  //
  // Immutable post-initialization. Object is thread-safe.
  std::unique_ptr<port::ThreadPool> background_threads_;

  // Counter for the current number of live streams. This is used to check
  // for accidentally-outstanding streams at StreamExecutor teardown time, as
  // well as to indicate leaks (via a large outstanding count being logged) in
  // the case we can't allocate more streams.
  std::atomic_int_fast32_t live_stream_count_;

  // Only one worker thread is needed; little work will be done by the
  // executor.
  static const int kNumBackgroundThreads = 1;

  // Indicates if StreamExecutor operation tracing should be performed.
  bool tracing_enabled_;

  // The set of TraceListeners registered for this StreamExecutor.
  std::set<TraceListener*> listeners_ GUARDED_BY(mu_);

  SE_DISALLOW_COPY_AND_ASSIGN(StreamExecutor);
};

////////////
// Inlines

template <typename T>
inline DeviceMemory<T> StreamExecutor::AllocateArray(uint64 element_count) {
  uint64 bytes = sizeof(T) * element_count;
  void *opaque = Allocate(bytes);
  return DeviceMemory<T>::MakeFromByteSize(opaque, bytes);
}

template <typename T>
inline port::StatusOr<DeviceMemory<T>> StreamExecutor::GetSymbol(
    const string &symbol_name) {
  // If we fail to get the symbol, opaque/bytes are unchanged. Initialize them
  // to be nullptr/0 for consistency with DeviceMemory semantics.
  void *opaque = nullptr;
  size_t bytes = 0;
  if (GetSymbol(symbol_name, &opaque, &bytes)) {
    CHECK_EQ(bytes % sizeof(T), 0);
    return DeviceMemory<T>::MakeFromByteSize(opaque, bytes);
  }
  return port::Status(
      port::error::NOT_FOUND,
      port::StrCat("Check if kernel using the symbol is loaded: ",
                   symbol_name));
}

template <typename ElemT>
ScopedDeviceMemory<ElemT>::ScopedDeviceMemory()
    : wrapped_(DeviceMemoryBase()), parent_(nullptr) {}

template <typename ElemT>
ScopedDeviceMemory<ElemT>::ScopedDeviceMemory(StreamExecutor *parent,
                                              DeviceMemoryBase value)
    : wrapped_(value), parent_(parent) {}

template <typename ElemT>
ScopedDeviceMemory<ElemT>::ScopedDeviceMemory(
    StreamExecutor *parent, std::initializer_list<ElemT> values)
    : ScopedDeviceMemory(parent, parent->AllocateArray<ElemT>(values.size())) {
  if (ptr() != nullptr) {
    std::vector<ElemT> local(values);
    if (!parent->SynchronousMemcpy(ptr(), const_cast<const ElemT *>(&local[0]),
                                   ptr()->size())) {
      Reset(nullptr);
    }
  }
}

template <typename ElemT>
ScopedDeviceMemory<ElemT>::~ScopedDeviceMemory() {
  if (wrapped_ == nullptr) return;
  DCHECK(parent_ != nullptr);
  parent_->Deallocate(&wrapped_);
}

template <typename ElemT>
void ScopedDeviceMemory<ElemT>::Reset(DeviceMemory<ElemT> updated) {
  if (wrapped_ != nullptr) {
    DCHECK(parent_ != nullptr);
    parent_->Deallocate(&wrapped_);
  }
  wrapped_ = updated;
}

template <typename ElemT>
void ScopedDeviceMemory<ElemT>::Reset(std::nullptr_t) {
  if (wrapped_ != nullptr) {
    DCHECK(parent_ != nullptr);
    parent_->Deallocate(&wrapped_);
  }
  wrapped_ = DeviceMemory<ElemT>{};
}

template <typename T>
DeviceMemory<T> StreamExecutor::AllocateZeroed() {
  void *opaque = Allocate(sizeof(T));
  if (opaque == nullptr) {
    return DeviceMemory<T>{};
  }

  DeviceMemory<T> result = DeviceMemory<T>::MakeFromByteSize(opaque, sizeof(T));
  bool ok = SynchronousMemZero(&result, sizeof(T));
  if (!ok) {
    Deallocate(&result);
    return DeviceMemory<T>{};
  }

  return result;
}

template <typename T>
DeviceMemory<T> StreamExecutor::AllocateSubBuffer(DeviceMemory<T> *parent,
                                                  uint64 element_offset,
                                                  uint64 element_count) {
  if (element_offset + element_count > parent->ElementCount()) {
    LOG(ERROR) << "requested sub-buffer allocation (offset + size) is greater "
               << "than parent allocation size: (" << element_offset << " + "
               << element_count << ") vs. (" << parent->ElementCount() << ")";
    return DeviceMemory<T>{};
  }

  void *opaque = implementation_->AllocateSubBuffer(
      parent, sizeof(T) * element_offset, sizeof(T) * element_count);
  if (opaque == nullptr) {
    return DeviceMemory<T>{};
  }
  CreateAllocRecord(opaque, sizeof(T) * element_count);
  return DeviceMemory<T>(DeviceMemoryBase(opaque, sizeof(T) * element_count,
                                          true /* = is_sub_buffer */));
}

template <typename... Params, typename... Args>
inline Stream &Stream::ThenLaunch(ThreadDim thread_dims, BlockDim block_dims,
                                  const TypedKernel<Params...> &kernel,
                                  Args... args) {
  KernelInvocationChecker<std::tuple<Params...>,
                          std::tuple<Args...>>::CheckAllStaticAssert();
  if (ok()) {
    // This is the core that allows type-safe kernel launching.
    // Since the platforms take kernel arguments as tuples of (void *, size),
    // we pack the variadic parameters passed as ...args into the desired
    // tuple form and pass that packed form to the StreamExecutor::Launch()
    // implementation.
    KernelArgsArray<sizeof...(args)> kernel_args;
    kernel.PackParams(&kernel_args, args...);
    DCHECK(parent_ != nullptr);
    bool ok =
        parent_->Launch(this, thread_dims, block_dims, kernel, kernel_args);
    if (!ok) {
      SetError();
      LOG(WARNING) << "parent failed to launch kernel: " << &kernel;
    }
  }
  return *this;
}
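
// A minimal sketch of type-safe kernel launching via Stream::ThenLaunch
// (illustrative only; the kernel name, its parameter pack, and the loader
// spec are assumptions -- see GetKernel() for how kernels are loaded):
//
//   DeviceMemory<float> arg = executor->AllocateArray<float>(1024);
//   TypedKernel<DeviceMemory<float>, int> add_one(executor);
//   // ... populate add_one from a MultiKernelLoaderSpec via GetKernel() ...
//   Stream stream(executor);
//   stream.Init();
//   stream.ThenLaunch(ThreadDim(256), BlockDim(4), add_one, arg, 1024);
//   if (!stream.ok()) {
//     // The launch (or an earlier operation) failed.
//   }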

}  // namespace gputools
}  // namespace perftools

#endif  // TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_