Home | History | Annotate | Download | only in stream_executor
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
// Temporary device memory allocations provide scratch space required by an
// operation that is about to be enqueued onto a stream.
     18 //
     19 //    std::unique_ptr<TemporaryDeviceMemory<float>> temporary_memory =
     20 //        stream.AllocateTemporaryArray<float>(1024).ConsumeValueOrDie();
     21 //    // ... enqueue stuff onto the stream using the temporary memory ...
     22 //    // Note that the memory is accessible via
     23 //    // temporary_memory->device_memory() and similar.
     24 //
     25 //    // Finalize the temporary memory. The underlying device memory may
     26 //    // be released any time after this program point, as another thread may
     27 //    // call Stream::BlockHostUntilDone, causing synchronization. This
     28 //    // finalization also happens automatically for the user if the unique_ptr
     29 //    // goes out of scope.
//    temporary_memory->Finalize();
     31 //
     32 // WARNING: do NOT hold onto the device memory associated with temporary_memory
     33 // after finalization. If temporary_memory->device_memory() is used after the
     34 // temporary memory is finalized, it will cause a DCHECK failure.
     35 //
     36 // Note that standard usage takes advantage of the type-safe wrapper,
     37 // TemporaryDeviceMemory<T>, defined below.
     38 //
     39 // Also see tests for executable sample usage.
     40 
     41 #ifndef TENSORFLOW_STREAM_EXECUTOR_TEMPORARY_DEVICE_MEMORY_H_
     42 #define TENSORFLOW_STREAM_EXECUTOR_TEMPORARY_DEVICE_MEMORY_H_
     43 
     44 #include "tensorflow/stream_executor/device_memory.h"
     45 
     46 namespace stream_executor {
     47 
     48 class Stream;
     49 namespace internal {
     50 class TemporaryMemoryManager;
     51 }
     52 
     53 // Untyped base class (analogous to a void*) for temporary device memory
     54 // allocations associated with a stream.
     55 class TemporaryDeviceMemoryBase {
     56  public:
     57   // Marks the temporary memory as finalized if it is not already marked as
     58   // such.
     59   ~TemporaryDeviceMemoryBase();
     60 
     61   // Precondition: !IsFinalized()
     62   DeviceMemoryBase* mutable_device_memory();
     63 
     64   // Precondition: !IsFinalized()
     65   const DeviceMemoryBase& device_memory() const;
     66 
     67   // "Finalizes" this temporary memory, making it acceptable to release at the
     68   // next stream synchronization point -- the device memory can be reclaimed at
     69   // any time after the temporary memory is marked as finalized (e.g. if a
     70   // separate thread is calls Stream::BlockHostUntilDone). This may only be
     71   // called once -- see the precondition below.
     72   //
     73   // Precondition: !IsFinalized()
     74   void Finalize();
     75 
     76   // Returns true iff the temporary memory is finalized (that is, the user is
     77   // done referring to the temporary device memory, and thus it can be released
     78   // at the next stream synchronization point).
     79   bool IsFinalized() const;
     80 
     81   // Returns true iff the temporary memory is still allocated.
     82   //
     83   // Note: this is a polling call, no guarantee is made that the temporary
     84   // memory is still allocated after the call has completed.
     85   bool IsAllocated() const;
     86 
     87  private:
     88   friend class internal::TemporaryMemoryManager;
     89   friend class TemporaryDeviceMemoryTest;
     90 
     91   // Note: construction DCHECKs that the memory is known-allocated in the
     92   // stream's temporary-allocation-manager.
     93   TemporaryDeviceMemoryBase(Stream* parent, DeviceMemoryBase device_memory,
     94                             uint64 allocation_generation);
     95 
     96   // The device memory region that has allocated.
     97   DeviceMemoryBase device_memory_;
     98 
     99   // The generation counter value for the temporary memory record in the
    100   // temporary memory manager.
    101   uint64 allocation_generation_;
    102 
    103   // The stream that this temporary memory was allocated for.
    104   Stream* parent_;
    105 };
    106 
    107 // Type-safe wrapper around the base type (which is analogous to a void*).
    108 template <typename T>
    109 class TemporaryDeviceMemory : public TemporaryDeviceMemoryBase {
    110  public:
    111   // Type-safe wrapper around TemporaryDeviceMemoryBase::mutable_device_memory.
    112   DeviceMemory<T>* mutable_device_memory() {
    113     StaticSlicingAssertionDummy();
    114     return reinterpret_cast<DeviceMemory<T>*>(
    115         TemporaryDeviceMemoryBase::mutable_device_memory());
    116   }
    117 
    118   // Type-safe wrapper around TemporaryDeviceMemoryBase::device_memory.
    119   const DeviceMemory<T>& device_memory() const {
    120     StaticSlicingAssertionDummy();
    121     return reinterpret_cast<const DeviceMemory<T>&>(
    122         TemporaryDeviceMemoryBase::device_memory());
    123   }
    124 
    125  private:
    126   static void StaticSlicingAssertionDummy() {
    127     static_assert(
    128         sizeof(TemporaryDeviceMemory) == sizeof(TemporaryDeviceMemoryBase),
    129         "derived class is simply a wrapper, no members may be added due to "
    130         "slicing");
    131   }
    132 };
    133 
    134 }  // namespace stream_executor
    135 
    136 #endif  // TENSORFLOW_STREAM_EXECUTOR_TEMPORARY_DEVICE_MEMORY_H_
    137