/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/allocation_tracker.h"

#include <utility>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

StatusOr<GlobalDataHandle> AllocationTracker::Register(
    ScopedShapedBuffer shaped_buffer, const string& tag) {
  tensorflow::mutex_lock lock(mutex_);
  VLOG(2) << "Register";
  std::vector<ScopedShapedBuffer> replicated_buffers;
  replicated_buffers.emplace_back(std::move(shaped_buffer));
  return RegisterInternal(std::move(replicated_buffers), tag);
}

StatusOr<GlobalDataHandle> AllocationTracker::RegisterReplicatedBuffers(
    std::vector<ScopedShapedBuffer> replicated_buffers, const string& tag) {
  tensorflow::mutex_lock lock(mutex_);
  VLOG(2) << "RegisterReplicatedBuffers";
  return RegisterInternal(std::move(replicated_buffers), tag);
}
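
// Illustrative usage sketch (not part of this file): a caller holding an
// AllocationTracker would typically register a buffer it has just produced
// and hand the resulting GlobalDataHandle back to the client. The `tracker`
// and `buffer` names below are hypothetical.
//
//   ScopedShapedBuffer buffer = ...;  // e.g. the result of an execution
//   TF_ASSIGN_OR_RETURN(GlobalDataHandle handle,
//                       tracker->Register(std::move(buffer), "my result"));
//   // ... later, when the client releases the data:
//   TF_RETURN_IF_ERROR(tracker->Unregister(handle));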

// ReleaseIfScopedShapedBuffer lets RegisterInternal<ShapedBufferTy>(b) call
// b.release() if b is a ScopedShapedBuffer, or otherwise pass b through
// unmodified.
static ShapedBuffer ReleaseIfScopedShapedBuffer(ShapedBuffer b) { return b; }
static ShapedBuffer ReleaseIfScopedShapedBuffer(ScopedShapedBuffer b) {
  return b.release();
}
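
// The pair of overloads above relies on ordinary overload resolution: when
// RegisterInternal is instantiated with ScopedShapedBuffer, the second
// overload is selected and ownership is released into a plain ShapedBuffer;
// when it is instantiated with ShapedBuffer, the first overload passes the
// buffer through unchanged.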

template <typename ShapedBufferTy>
StatusOr<GlobalDataHandle> AllocationTracker::RegisterInternal(
    std::vector<ShapedBufferTy> replicated_buffers, const string& tag) {
  static_assert(std::is_same<ShapedBufferTy, ShapedBuffer>::value ||
                    std::is_same<ShapedBufferTy, ScopedShapedBuffer>::value,
                "ShapedBufferTy must be ShapedBuffer or ScopedShapedBuffer.");
  VLOG(2) << "RegisterInternal("
          << "tag: \"" << tag << "\" with " << replicated_buffers.size()
          << " shaped_buffers)";
  for (const auto& shaped_buffer : replicated_buffers) {
    VLOG(2) << "shaped_buffer:" << shaped_buffer;
    if (shaped_buffer.platform() != backend_->platform()) {
      return InvalidArgument(
          "AllocationTracker for platform %s cannot register buffer from "
          "platform %s",
          backend_->platform()->Name(), shaped_buffer.platform()->Name());
    }
  }

  int64 handle = next_handle_++;
  for (auto& shaped_buffer : replicated_buffers) {
    std::vector<ShapeIndex> shape_indices;
    ShapeUtil::ForEachSubshape(
        shaped_buffer.on_device_shape(),
        [&](const Shape& /*subshape*/, const ShapeIndex& index) {
          shape_indices.push_back(index);
        });
    // Add shaped_buffer's buffers to opaque_to_allocation_map_, which owns
    // them.
    for (const ShapeIndex& index : shape_indices) {
      AddAllocationOrIncrementRefCount(shaped_buffer.buffer(index),
                                       shaped_buffer.device_ordinal());
    }
    // If ShapedBufferTy is ScopedShapedBuffer, release the ScopedShapedBuffer
    // into a regular ShapedBuffer, which is stored in
    // handle_to_shaped_buffers_.
    handle_to_shaped_buffers_[handle].emplace_back(
        absl::make_unique<ShapedBuffer>(
            ReleaseIfScopedShapedBuffer(std::move(shaped_buffer))));
  }

  GlobalDataHandle result;
  result.set_handle(handle);
  VLOG(2) << "handle: " << handle;
  return result;
}
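
// Note: registration goes through AddAllocationOrIncrementRefCount, so
// registering two handles that alias the same device memory (as
// DeconstructTuple does below) only bumps the per-buffer reference count;
// the memory itself is freed when the last referencing handle is
// unregistered.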

Status AllocationTracker::Unregister(const GlobalDataHandle& data) {
  tensorflow::mutex_lock lock(mutex_);
  VLOG(2) << "Unregister("
          << "handle: " << data.handle() << ")";
  TF_ASSIGN_OR_RETURN(std::vector<const ShapedBuffer*> replicated_buffers,
                      ResolveInternal(data));
  for (const auto& shaped_buffer : replicated_buffers) {
    std::vector<ShapeIndex> shape_indices;
    ShapeUtil::ForEachSubshape(
        shaped_buffer->on_device_shape(),
        [&shape_indices](const Shape& /*subshape*/, const ShapeIndex& index) {
          shape_indices.push_back(index);
        });
    for (const ShapeIndex& index : shape_indices) {
      TF_RETURN_IF_ERROR(DecrementRefCount(shaped_buffer->buffer(index),
                                           shaped_buffer->device_ordinal()));
    }
  }
  // Keep a nullptr as a tombstone for unregistered handles. This enables
  // better error messages. That is, "handle has been deallocated" versus
  // "handle does not exist".
  auto it = handle_to_shaped_buffers_.find(data.handle());
  if (it == handle_to_shaped_buffers_.end()) {
    return NotFound("no allocation record for global data handle: %d",
                    data.handle());
  }
  for (auto& shaped_buffer : it->second) {
    shaped_buffer.reset();
  }
  return Status::OK();
}
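
// Illustrative consequence of the tombstone above (the `tracker`, `handle`,
// and `unknown_handle` names are hypothetical):
//
//   TF_RETURN_IF_ERROR(tracker->Unregister(handle));
//   auto gone = tracker->Resolve(handle);         // InvalidArgument:
//                                                 // "previously deallocated"
//   auto missing = tracker->Resolve(unknown_handle);  // NotFound:
//                                                     // "no allocation record"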

StatusOr<std::vector<GlobalDataHandle>> AllocationTracker::DeconstructTuple(
    const GlobalDataHandle& data) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(std::vector<const ShapedBuffer*> replicated_buffers,
                      ResolveInternal(data));
  // We only need to care about replica id 0 here, since the GlobalDataHandle is
  // the same for all buffers across replicas.
  const ShapedBuffer* shaped_buffer = replicated_buffers[0];
  if (!shaped_buffer->on_host_shape().IsTuple()) {
    return InvalidArgument("global data handle %d is not a tuple",
                           data.handle());
  }
  // If the on-host representation is a tuple, then the on-device one should be
  // as well.
  TF_RET_CHECK(shaped_buffer->on_device_shape().IsTuple());

  if (ShapeUtil::IsNestedTuple(shaped_buffer->on_device_shape())) {
    return Unimplemented("Deconstructing nested tuples is not implemented.");
  }

  std::vector<GlobalDataHandle> element_handles;
  for (int i = 0;
       i < ShapeUtil::TupleElementCount(shaped_buffer->on_device_shape());
       ++i) {
    auto element_buffer = ShapedBuffer(
        ShapeUtil::GetTupleElementShape(shaped_buffer->on_host_shape(), i),
        ShapeUtil::GetTupleElementShape(shaped_buffer->on_device_shape(), i),
        shaped_buffer->platform(), shaped_buffer->device_ordinal());
    element_buffer.set_buffer(shaped_buffer->buffer(/*index=*/{i}),
                              /*index=*/{});
    std::vector<ShapedBuffer> replicated_buffers;
    replicated_buffers.push_back(std::move(element_buffer));
    TF_ASSIGN_OR_RETURN(
        GlobalDataHandle element_handle,
        RegisterInternal(std::move(replicated_buffers), "deconstructed tuple"));

    element_handles.push_back(element_handle);
  }
  return std::move(element_handles);
}
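
// Illustrative sketch (hypothetical `tracker` and `tuple_handle`): splitting a
// registered, non-nested tuple into one handle per element. Each element
// handle aliases the tuple's existing device memory, so this relies on the
// reference counting noted above rather than copying buffers.
//
//   TF_ASSIGN_OR_RETURN(std::vector<GlobalDataHandle> elements,
//                       tracker->DeconstructTuple(tuple_handle));
//   // elements[i] now resolves to the buffer of tuple element i.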

StatusOr<std::vector<const ShapedBuffer*>> AllocationTracker::Resolve(
    const GlobalDataHandle& data) const {
  tensorflow::mutex_lock lock(mutex_);
  return AllocationTracker::ResolveInternal(data);
}

StatusOr<const ShapedBuffer*> AllocationTracker::ResolveForReplica(
    const GlobalDataHandle& data, int replica_id) const {
  tensorflow::mutex_lock lock(mutex_);
  TF_ASSIGN_OR_RETURN(std::vector<const ShapedBuffer*> replicated_buffers,
                      ResolveInternal(data));
  if (replica_id >= replicated_buffers.size()) {
    return InvalidArgument(
        "Requesting buffer for replica %d, but found buffers only for %lu "
        "replicas.",
        replica_id, replicated_buffers.size());
  }
  return replicated_buffers[replica_id];
}

StatusOr<std::vector<const ShapedBuffer*>> AllocationTracker::ResolveInternal(
    const GlobalDataHandle& data) const {
  VLOG(2) << "resolve:" << data.handle();
  auto it = handle_to_shaped_buffers_.find(data.handle());
  if (it == handle_to_shaped_buffers_.end()) {
    return NotFound("no allocation record for global data handle: %d",
                    data.handle());
  }
  std::vector<const ShapedBuffer*> replicated_buffers;
  for (const auto& shaped_buffer : it->second) {
    if (shaped_buffer == nullptr) {
      return InvalidArgument("global data handle %d was previously deallocated",
                             data.handle());
    }
    replicated_buffers.push_back(shaped_buffer.get());
  }

  return replicated_buffers;
}

void AllocationTracker::AddAllocationOrIncrementRefCount(
    se::DeviceMemoryBase device_memory, int device_ordinal) {
  AllocationMap& allocation_map = opaque_to_allocation_map_[device_ordinal];
  auto it = allocation_map.find(device_memory.opaque());
  if (it == allocation_map.end()) {
    allocation_map[device_memory.opaque()] = {
        OwningDeviceMemory(device_memory, device_ordinal,
                           backend_->memory_allocator()),
        /*ref_count=*/1};
  } else {
    it->second.ref_count++;
  }
}

Status AllocationTracker::DecrementRefCount(se::DeviceMemoryBase device_memory,
                                            int device_ordinal) {
  AllocationMap& allocation_map = opaque_to_allocation_map_[device_ordinal];
  auto it = allocation_map.find(device_memory.opaque());
  TF_RET_CHECK(it != allocation_map.end());
  Allocation& allocation = it->second;
  TF_RET_CHECK(allocation.ref_count >= 1);
  if (allocation.ref_count == 1) {
    allocation.device_memory.Free();
    allocation_map.erase(it);
  } else {
    allocation.ref_count--;
  }
  return Status::OK();
}
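
// Together, AddAllocationOrIncrementRefCount and DecrementRefCount maintain a
// per-(device ordinal, opaque device pointer) reference count: each registered
// handle contributes one reference per leaf buffer, Unregister removes it, and
// the underlying device memory is freed via OwningDeviceMemory::Free() only
// when the count reaches zero.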

}  // namespace xla