/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ALLOCATION_TRACKER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_ALLOCATION_TRACKER_H_

#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>

#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

// Tracks allocations for the XLA service; allocations can be registered
// with shape/device/tag and resolved from a handle for later use.
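//
// Rough usage sketch (hypothetical caller code; backend and shaped_buffer are
// assumed to be produced elsewhere, e.g. by the execution service):
//
//   AllocationTracker tracker(backend);
//   TF_ASSIGN_OR_RETURN(GlobalDataHandle handle,
//                       tracker.Register(std::move(shaped_buffer), "sample"));
//   TF_ASSIGN_OR_RETURN(const ShapedBuffer* buffer, tracker.Resolve(handle));
//   // ... read from or write to *buffer ...
//   TF_RETURN_IF_ERROR(tracker.Unregister(handle));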
class AllocationTracker {
 public:
  // The backend's allocator is used for deallocating memory when allocations
  // are deregistered. All registered allocations must be on the same platform
  // as that allocator.
  explicit AllocationTracker(Backend* backend)
      : backend_(backend), next_handle_(1) {}

  // Registers a shaped buffer of device memory, and returns a corresponding
  // handle that can be used for talking to XLA clients.
  StatusOr<GlobalDataHandle> Register(
      std::unique_ptr<ShapedBuffer> shaped_buffer, const string& tag);

  // Unregisters the allocation for the given data handle.
  Status Unregister(const GlobalDataHandle& data);

  // Returns a vector of global data handles that point to the tuple elements.
  StatusOr<std::vector<GlobalDataHandle>> DeconstructTuple(
      const GlobalDataHandle& data);
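  //
  // For example, DeconstructTuple might be used as follows (a hypothetical
  // sketch; tuple_handle is assumed to refer to previously registered
  // tuple-shaped data):
  //
  //   TF_ASSIGN_OR_RETURN(std::vector<GlobalDataHandle> elements,
  //                       tracker.DeconstructTuple(tuple_handle));
  //   // Each elements[i] can now be resolved or unregistered independently
  //   // of the original tuple handle.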

  // Resolves a handle from an XLA client to a shaped buffer, or returns an
  // error status if the handle was not found or its buffer was deallocated.
  StatusOr<const ShapedBuffer*> Resolve(const GlobalDataHandle& data);

 private:
  // Data structure encapsulating a single memory allocation on the device.
  struct Allocation {
    // The pointer to this allocation.
    perftools::gputools::DeviceMemoryBase device_memory;

    // The device that the memory is allocated on.
    int device_ordinal;

    // The number of times this memory allocation is referred to by registered
    // data handles.
    int ref_count;
  };
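  // For example (an illustrative scenario, not drawn from this header): if two
  // registered data handles alias the same device buffer, its ref_count is 2,
  // and the memory is deallocated only once the second handle is unregistered.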

  // Internal helper which resolves the given GlobalDataHandle to a
  // ShapedBuffer.
  StatusOr<ShapedBuffer*> ResolveInternal(const GlobalDataHandle& data)
      EXCLUSIVE_LOCKS_REQUIRED(mutex_);

  // Internal helper which registers a shaped buffer.
  StatusOr<GlobalDataHandle> RegisterInternal(
      std::unique_ptr<ShapedBuffer> shaped_buffer, const string& tag)
      EXCLUSIVE_LOCKS_REQUIRED(mutex_);

  // Adds the given device address to the allocation tracker, or, if it is
  // already present, increments its reference count.
  void AddAllocationOrIncrementRefCount(
      perftools::gputools::DeviceMemoryBase device_memory, int device_ordinal)
      EXCLUSIVE_LOCKS_REQUIRED(mutex_);

  // Decrements the reference count of the given device memory and, once it
  // reaches zero, deallocates the memory.
  Status DecrementRefCount(perftools::gputools::DeviceMemoryBase device_memory,
                           int device_ordinal) EXCLUSIVE_LOCKS_REQUIRED(mutex_);
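  //
  // Taken together, Unregister is expected to work roughly as follows (a
  // sketch of the intended flow, not a copy of the implementation):
  //
  //   tensorflow::mutex_lock lock(mutex_);
  //   TF_ASSIGN_OR_RETURN(ShapedBuffer* shaped_buffer, ResolveInternal(data));
  //   // For each device buffer held by *shaped_buffer, call
  //   // DecrementRefCount; the backend's allocator frees the memory once the
  //   // count reaches zero.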

  // A map from device memory opaque value to allocation. One such map is
  // maintained per device ordinal.
  using AllocationMap = tensorflow::gtl::FlatMap<const void*, Allocation>;

  tensorflow::mutex mutex_;

  // Backend to use with this tracker. The backend supplies the memory
  // allocator to use when deallocating memory.
  Backend* backend_;

  // The next handle to assign to an allocation; guarded by the same mutex as
  // the maps below since they are mutated together.
  int64 next_handle_ GUARDED_BY(mutex_);

  // A map from device ordinal to AllocationMap.
  tensorflow::gtl::FlatMap<int, AllocationMap> opaque_to_allocation_map_
      GUARDED_BY(mutex_);

  // A map from data handle to ShapedBuffer.
  tensorflow::gtl::FlatMap<int64, std::unique_ptr<ShapedBuffer>>
      handle_to_shaped_buffer_ GUARDED_BY(mutex_);

  TF_DISALLOW_COPY_AND_ASSIGN(AllocationTracker);
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_ALLOCATION_TRACKER_H_