/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_TRANSFER_MANAGER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_TRANSFER_MANAGER_H_

#include <map>
#include <set>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

// The TransferManager interface lets backends provide platform-specific
// mechanisms for constructing literals from given device memory handles.
// This lets each platform customize how literals are transferred to/from the
// device in terms of padding, leading dimension, etc.
class TransferManager {
 public:
  virtual ~TransferManager() {}

  // Returns the ID of the platform that this transfer manager acts on.
  virtual se::Platform::Id PlatformId() const = 0;

  // Returns the shape of the on-device representation for the given shape on
  // the host. This is intended for use with ShapedBuffer where buffers are
  // pre-allocated by the host, e.g. TransferLiteralToDevice, without the user
  // needing to consider device-specific behaviors.
  virtual Shape HostShapeToDeviceShape(const Shape& host_shape) const {
    return host_shape;
  }

  // Base class for specifying platform-specific transfer metadata that can be
  // used to tell the underlying implementation to apply specific optimizations
  // to a transfer. Actual metadata passed to supported transfer methods should
  // subclass this class.
  class TransferMetadata {
   public:
    virtual ~TransferMetadata() = 0;
  };
  // Returns a literal containing the data held in the given ShapedBuffer
  // using the provided stream. This operation is performed synchronously,
  // without waiting for any other operation on the stream to complete.
  //
  // This function should be avoided in favor of the asynchronous version below.
  //
  // Optionally, the caller can specify platform-specific transfer metadata
  // that tells the implementation to do something special.
  virtual StatusOr<Literal> TransferLiteralFromDevice(
      se::Stream* stream, const ShapedBuffer& device_buffer,
      const TransferMetadata* transfer_metadata);
  StatusOr<Literal> TransferLiteralFromDevice(
      se::Stream* stream, const ShapedBuffer& device_buffer) {
    return TransferLiteralFromDevice(stream, device_buffer, nullptr);
  }
  virtual Status TransferLiteralFromDevice(
      se::Stream* stream, const ShapedBuffer& device_buffer,
      const MutableBorrowingLiteral& literal,
      const TransferMetadata* transfer_metadata);
  Status TransferLiteralFromDevice(se::Stream* stream,
                                   const ShapedBuffer& device_buffer,
                                   const MutableBorrowingLiteral& literal) {
    return TransferLiteralFromDevice(stream, device_buffer, literal, nullptr);
  }

  // Begins transferring a literal containing the data held in the given
  // ShapedBuffer using the provided stream.
  //
  // This operation is performed asynchronously on the given stream. It returns
  // once the transfer is enqueued. 'done' is invoked with the result when
  // complete.
  //
  // device_buffer is held by reference and must remain live at least until
  // done() is invoked.
  //
  // Optionally, the caller can specify platform-specific transfer metadata
  // that tells the implementation to do something special.
  virtual void TransferLiteralFromDevice(
      se::Stream* stream, const ShapedBuffer& device_buffer,
      MutableBorrowingLiteral literal, std::function<void(Status)> done,
      const TransferMetadata* transfer_metadata) = 0;
  void TransferLiteralFromDevice(se::Stream* stream,
                                 const ShapedBuffer& device_buffer,
                                 MutableBorrowingLiteral literal,
                                 std::function<void(Status)> done) {
    return TransferLiteralFromDevice(stream, device_buffer, literal, done,
                                     nullptr);
  }
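
  // Illustrative usage sketch (assumes a valid `stream`, a populated
  // `device_buffer`, and a `transfer_manager` pointer; the variable names are
  // hypothetical caller-side names, not part of this API):
  //
  //   // Synchronous read-back: blocks until the literal is populated.
  //   StatusOr<Literal> result =
  //       transfer_manager->TransferLiteralFromDevice(stream, device_buffer);
  //
  //   // Asynchronous read-back: returns once the copy is enqueued; `done` is
  //   // invoked with the transfer's status when the copy completes.
  //   Literal host_literal(device_buffer.on_host_shape());
  //   transfer_manager->TransferLiteralFromDevice(
  //       stream, device_buffer, MutableBorrowingLiteral(&host_literal),
  //       [](Status s) { /* inspect s */ });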

  // Transfers the given literal into the previously allocated device memory
  // represented by the given ShapedBuffer using the given stream. The shape
  // of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible,
  // but need not have the same layout.
  //
  // This operation is performed synchronously, without waiting for any other
  // operation on the stream to complete. This function should be avoided in
  // favor of the asynchronous version below.
  //
  // Optionally, the caller can specify platform-specific transfer metadata
  // that tells the implementation to do something special.
  virtual Status TransferLiteralToDevice(
      se::Stream* stream, const LiteralSlice& literal,
      const ShapedBuffer& device_buffer,
      const TransferMetadata* transfer_metadata);
  Status TransferLiteralToDevice(se::Stream* stream,
                                 const LiteralSlice& literal,
                                 const ShapedBuffer& device_buffer) {
    return TransferLiteralToDevice(stream, literal, device_buffer, nullptr);
  }

  // Transfers the given literal into the previously allocated device memory
  // represented by the given ShapedBuffer using the given stream. The shape
  // of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible,
  // but need not have the same layout.
  //
  // This operation is performed asynchronously on the given stream. It returns
  // once the transfer is enqueued, and may return before the transfer has
  // completed.
  //
  // The caller may free the data structures 'literal' and 'device_buffer'
  // immediately after this function returns; however, their constituent buffers
  // on both host and device must remain valid until the enqueued transfer has
  // completed on 'stream'.
  //
  // Optionally, the caller can specify platform-specific transfer metadata
  // that tells the implementation to do something special.
  virtual Status TransferLiteralToDeviceAsync(
      se::Stream* stream, const LiteralSlice& literal,
      const ShapedBuffer& device_buffer,
      const TransferMetadata* transfer_metadata) = 0;
  Status TransferLiteralToDeviceAsync(se::Stream* stream,
                                      const LiteralSlice& literal,
                                      const ShapedBuffer& device_buffer) {
    return TransferLiteralToDeviceAsync(stream, literal, device_buffer,
                                        nullptr);
  }
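
  // Illustrative usage sketch (assumes a valid `stream` and a `device_buffer`
  // allocated with a compatible shape; names are hypothetical):
  //
  //   Literal literal = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f});
  //   TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDeviceAsync(
  //       stream, literal, device_buffer));
  //   // literal's buffer must stay valid until the transfer completes, e.g.
  //   // until stream->BlockHostUntilDone() returns.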

  // Convenience methods for transferring an array to or from device memory at
  // a known address. This avoids having to construct a ShapedBuffer just to
  // transfer an array.
  //
  // Optionally, the caller can specify platform-specific transfer metadata
  // that tells the implementation to do something special.
  Status TransferArrayToDevice(
      se::Stream* stream, const LiteralSlice& literal,
      const se::DeviceMemoryBase& dest,
      const TransferMetadata* transfer_metadata = nullptr);
  void TransferArrayFromDevice(
      se::Stream* stream, const Shape& shape,
      const se::DeviceMemoryBase& source,
      const MutableBorrowingLiteral& literal, std::function<void(Status)> done,
      const TransferMetadata* transfer_metadata = nullptr);

  Status TransferArrayToDeviceAsync(
      se::Stream* stream, const LiteralSlice& literal,
      const se::DeviceMemoryBase& dest,
      const TransferMetadata* transfer_metadata = nullptr);
  StatusOr<Literal> TransferArrayFromDevice(
      se::Stream* stream, const Shape& shape,
      const se::DeviceMemoryBase& source,
      const TransferMetadata* transfer_metadata = nullptr);
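
  // Illustrative usage sketch (assumes a valid `stream` and a device
  // allocation `dest` of sufficient size; names are hypothetical):
  //
  //   Literal array = LiteralUtil::CreateR2<float>({{1.0f, 2.0f},
  //                                                 {3.0f, 4.0f}});
  //   TF_RETURN_IF_ERROR(
  //       transfer_manager->TransferArrayToDevice(stream, array, dest));
  //   TF_ASSIGN_OR_RETURN(
  //       Literal round_trip,
  //       transfer_manager->TransferArrayFromDevice(stream, array.shape(),
  //                                                 dest));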

  // Transfers the given literal into the Infeed interface of the device,
  // using the given executor.
  virtual Status TransferLiteralToInfeed(se::StreamExecutor* executor,
                                         const LiteralSlice& literal) = 0;

  // Transfers the given literal from the Outfeed interface of the device,
  // using the given executor.
  virtual Status TransferLiteralFromOutfeed(
      se::StreamExecutor* executor, const Shape& literal_shape,
      MutableBorrowingLiteral literal) = 0;
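
  // Illustrative usage sketch (assumes a valid `executor`; `result_shape` and
  // the other names are hypothetical):
  //
  //   Literal input = LiteralUtil::CreateR0<int32>(42);
  //   TF_RETURN_IF_ERROR(
  //       transfer_manager->TransferLiteralToInfeed(executor, input));
  //
  //   Literal output(result_shape);
  //   TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralFromOutfeed(
  //       executor, result_shape, MutableBorrowingLiteral(&output)));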

  // Resets the devices associated with this transfer manager.
  virtual Status ResetDevices(
      absl::Span<se::StreamExecutor* const> executor) = 0;

  // Given an allocated ShapedBuffer, constructs the tuple index table(s) in
  // each buffer of the given ShapedBuffer corresponding to tuple shapes. If the
  // ShapedBuffer is array-shaped, this method does nothing.
  Status WriteTupleIndexTables(se::Stream* stream,
                               const ShapedBuffer& device_buffer);
  Status WriteTupleIndexTablesAsync(se::Stream* stream,
                                    const ShapedBuffer& device_buffer);

  // Writes a tuple index buffer for the root of 'device_buffer', which must
  // be a tuple. Unlike WriteTupleIndexTables, this writes only the root buffer
  // rather than all subbuffers. This method is always asynchronous.
  Status WriteRootTupleIndexTable(se::Stream* stream,
                                  const ShapedBuffer& device_buffer);

  // Determines the byte size requirement for the given shape on the underlying
  // architecture. This will be used to allocate an appropriately sized memory
  // region for a host-to-device transfer.
  virtual int64 GetByteSizeRequirement(const Shape& shape) const = 0;

  // Allocates a ScopedShapedBuffer which can hold data with the given on-host
  // shape. The on-device shape may be different, as indicated by
  // HostShapeToDeviceShape.
  StatusOr<ScopedShapedBuffer> AllocateScopedShapedBuffer(
      const Shape& on_host_shape, DeviceMemoryAllocator* allocator,
      int device_ordinal);
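
  // Illustrative end-to-end sketch (assumes a valid `stream` and `allocator`;
  // names are hypothetical):
  //
  //   Literal literal = LiteralUtil::CreateR1<float>({1.0f, 2.0f});
  //   TF_ASSIGN_OR_RETURN(
  //       ScopedShapedBuffer buffer,
  //       transfer_manager->AllocateScopedShapedBuffer(
  //           literal.shape(), allocator, /*device_ordinal=*/0));
  //   TF_RETURN_IF_ERROR(
  //       transfer_manager->WriteTupleIndexTables(stream, buffer));
  //   TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDeviceAsync(
  //       stream, literal, buffer));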

  // The given ShapedBuffer holds a handle to allocated memory, but it is not,
  // in the general case, legal to immediately copy or access that allocated
  // memory because queued operations on the device may alias that memory.
  // Memory ordering is enforced by the Stream's happens-before relationship,
  // which allows eager deallocation and reallocation of buffers host-side even
  // if the device hasn't finished with them.
  //
  // In certain cases, it can be known that a ShapedBuffer does not have any
  // conflicting accesses on the device and thus is eligible to be accessed at
  // any time from the host.
  //
  // This function returns true if device_buffer can be accessed immediately
  // without waiting for the Stream's previously enqueued items. This only
  // returns true if all subbuffers in device_buffer can be accessed
  // immediately.
  virtual bool CanShapedBufferBeAccessedNow(
      se::StreamExecutor* executor, const ShapedBuffer& device_buffer) const {
    return false;
  }
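
  // Illustrative caller-side check (a sketch; names are hypothetical):
  //
  //   if (transfer_manager->CanShapedBufferBeAccessedNow(stream->parent(),
  //                                                      device_buffer)) {
  //     // Safe to read back immediately.
  //   } else {
  //     TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
  //   }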

  /////
  // The TransferManager class also serves as a point to register objects for
  // the various platforms.

  // Registers the TransferManager singleton for the platform kind. This is
  // assumed to be a singleton, so no ownership is transferred.
  //
  // Precondition: a platform kind must not be registered more than once.
  typedef std::unique_ptr<TransferManager> (*TransferManagerCreationFunction)();
  static void RegisterTransferManager(
      se::Platform::Id platform_id,
      TransferManagerCreationFunction transfer_manager);

  // Returns the transfer manager singleton pointer if it is available for the
  // given platform, or an error status if it is not.
  static StatusOr<TransferManager*> GetForPlatform(
      const se::Platform* platform);
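
  // Illustrative registration sketch (typically placed in a backend's .cc
  // file; `MyTransferManager` and `kMyPlatformId` are hypothetical):
  //
  //   static std::unique_ptr<TransferManager> CreateMyTransferManager() {
  //     return absl::make_unique<MyTransferManager>();
  //   }
  //   static bool InitModule() {
  //     TransferManager::RegisterTransferManager(kMyPlatformId,
  //                                              &CreateMyTransferManager);
  //     return true;
  //   }
  //   static bool module_initialized = InitModule();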

 protected:
  // Transfers a memory block of the given size from the device source into the
  // 'destination' buffer.
  //
  // size is the size to transfer to destination in bytes.
  virtual Status TransferBufferFromDevice(se::Stream* stream,
                                          const se::DeviceMemoryBase& source,
                                          int64 size, void* destination);

  // Transfers a memory block of the given size from the 'source' buffer to the
  // given destination on the device.
  //
  // size is the size to transfer from source in bytes.
  virtual Status TransferBufferToDevice(se::Stream* stream, int64 size,
                                        const void* source,
                                        se::DeviceMemoryBase* destination);
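
  // Illustrative override sketch for a subclass (a minimal approach, assuming
  // the platform supports plain memcpy-style copies via se::Stream;
  // `MyTransferManager` is hypothetical):
  //
  //   Status MyTransferManager::TransferBufferFromDevice(
  //       se::Stream* stream, const se::DeviceMemoryBase& source, int64 size,
  //       void* destination) {
  //     stream->ThenMemcpy(destination, source, size);
  //     return Status::OK();
  //   }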

  // Writes the given device-memory pointers in 'elements' to the given region
  // to construct a tuple index table in the platform-specific tuple
  // representation.
  virtual Status WriteSingleTupleIndexTable(
      se::Stream* stream, absl::Span<const se::DeviceMemoryBase> elements,
      const Shape& shape, se::DeviceMemoryBase* region) = 0;

 private:
  // The mutex that guards the platform-to-transfer manager map.
  static tensorflow::mutex platform_transfer_manager_mutex_;

  // State kept for each kind of TransferManager.  Registration functions
  // set up creation_function, and then we use that to lazily create
  // "manager" the first time GetForPlatform is invoked for a particular id.
  struct State {
    std::unique_ptr<TransferManager> manager;
    TransferManagerCreationFunction creation_function = nullptr;
  };

  // Map from platform kind to transfer manager singleton.
  static std::map<se::Platform::Id, State>* GetPlatformTransferManagers();
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_TRANSFER_MANAGER_H_