/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/allocation_tracker.h"

#include <utility>

#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

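// Registers a ShapedBuffer under a fresh GlobalDataHandle. Takes the tracker
// lock, then delegates to RegisterInternal.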
StatusOr<GlobalDataHandle> AllocationTracker::Register(
    std::unique_ptr<ShapedBuffer> shaped_buffer, const string& tag) {
  tensorflow::mutex_lock lock(mutex_);
  VLOG(2) << "Register";
  return RegisterInternal(std::move(shaped_buffer), tag);
}

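// Assigns the next handle, bumps the reference count of every device buffer
// reachable from the ShapedBuffer's on-device shape, and records the buffer in
// handle_to_shaped_buffer_. Assumes the caller already holds mutex_.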
StatusOr<GlobalDataHandle> AllocationTracker::RegisterInternal(
    std::unique_ptr<ShapedBuffer> shaped_buffer, const string& tag) {
  VLOG(2) << "RegisterInternal("
          << "tag: \"" << tag << "\" "
          << "shaped_buffer: " << *shaped_buffer << ")";
  if (shaped_buffer->platform() != backend_->platform()) {
    return InvalidArgument(
        "AllocationTracker for platform %s cannot register buffer from "
        "platform %s",
        backend_->platform()->Name().c_str(),
        shaped_buffer->platform()->Name().c_str());
  }

  int64 handle = next_handle_++;
  std::vector<ShapeIndex> shape_indices;
  ShapeUtil::ForEachSubshape(shaped_buffer->on_device_shape(),
                             [this, &shape_indices](const Shape& /*subshape*/,
                                                    const ShapeIndex& index) {
                               shape_indices.push_back(index);
                             });
  for (const ShapeIndex& index : shape_indices) {
    AddAllocationOrIncrementRefCount(shaped_buffer->buffer(index),
                                     shaped_buffer->device_ordinal());
  }
  GlobalDataHandle result;
  result.set_handle(handle);

  handle_to_shaped_buffer_[handle] = std::move(shaped_buffer);

  VLOG(2) << "handle: " << handle;

  return result;
}

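// Releases one reference on every device buffer owned by the handle and
// replaces the map entry with a nullptr tombstone, so a later resolution can
// report "previously deallocated" rather than "not found".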
tensorflow::Status AllocationTracker::Unregister(const GlobalDataHandle& data) {
  tensorflow::mutex_lock lock(mutex_);
  VLOG(2) << "Unregister("
          << "handle: " << data.handle() << ")";
  TF_ASSIGN_OR_RETURN(ShapedBuffer * shaped_buffer, ResolveInternal(data));
  std::vector<ShapeIndex> shape_indices;
  ShapeUtil::ForEachSubshape(shaped_buffer->on_device_shape(),
                             [this, &shape_indices](const Shape& /*subshape*/,
                                                    const ShapeIndex& index) {
                               shape_indices.push_back(index);
                             });
  for (const ShapeIndex& index : shape_indices) {
    TF_RETURN_IF_ERROR(DecrementRefCount(shaped_buffer->buffer(index),
                                         shaped_buffer->device_ordinal()));
  }

  // Keep a nullptr as a tombstone for unregistered handles. This enables
  // better error messages. That is, "handle has been deallocated" versus
  // "handle does not exist".
  handle_to_shaped_buffer_.at(data.handle()).reset();

  return tensorflow::Status::OK();
}

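// Splits a (non-nested) tuple handle into one handle per top-level element.
// Each element handle aliases the existing device memory; registering it only
// increments the reference counts of the underlying buffers.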
StatusOr<std::vector<GlobalDataHandle>> AllocationTracker::DeconstructTuple(
    const GlobalDataHandle& data) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(ShapedBuffer * shaped_buffer, ResolveInternal(data));
  if (!ShapeUtil::IsTuple(shaped_buffer->on_host_shape())) {
    return InvalidArgument("global data handle %lld is not a tuple",
                           data.handle());
  }
  // If the on-host representation is a tuple, then the on-device one should be
  // as well.
  TF_RET_CHECK(ShapeUtil::IsTuple(shaped_buffer->on_device_shape()));

  if (ShapeUtil::IsNestedTuple(shaped_buffer->on_device_shape())) {
    return Unimplemented("deconstructing nested tuples not yet supported");
  }

  std::vector<GlobalDataHandle> element_handles;
  for (int i = 0;
       i < ShapeUtil::TupleElementCount(shaped_buffer->on_device_shape());
       ++i) {
    auto element_buffer = MakeUnique<ShapedBuffer>(
        ShapeUtil::GetTupleElementShape(shaped_buffer->on_host_shape(), i),
        ShapeUtil::GetTupleElementShape(shaped_buffer->on_device_shape(), i),
        shaped_buffer->platform(), shaped_buffer->device_ordinal());
    element_buffer->set_buffer(shaped_buffer->buffer(/*index=*/{i}),
                               /*index=*/{});
    TF_ASSIGN_OR_RETURN(
        GlobalDataHandle element_handle,
        RegisterInternal(std::move(element_buffer), "deconstructed tuple"));

    element_handles.push_back(element_handle);
  }
  return std::move(element_handles);
}

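// Thread-safe lookup of the ShapedBuffer backing a GlobalDataHandle.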
StatusOr<const ShapedBuffer*> AllocationTracker::Resolve(
    const GlobalDataHandle& data) {
  tensorflow::mutex_lock lock(mutex_);
  return ResolveInternal(data);
}

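// Looks up the handle without taking mutex_. Distinguishes handles that were
// never registered (NotFound) from handles whose buffer has since been
// deallocated (InvalidArgument, via the nullptr tombstone).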
StatusOr<ShapedBuffer*> AllocationTracker::ResolveInternal(
    const GlobalDataHandle& data) {
  VLOG(2) << "resolve:" << data.handle();
  auto it = handle_to_shaped_buffer_.find(data.handle());
  if (it == handle_to_shaped_buffer_.end()) {
    return NotFound("no allocation record for global data handle: %lld",
                    data.handle());
  }
  ShapedBuffer* shaped_buffer = it->second.get();

  if (shaped_buffer == nullptr) {
    return InvalidArgument("global data handle %lld was previously deallocated",
                           data.handle());
  }

  return shaped_buffer;
}

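// Inserts the device memory block into the per-device allocation map with a
// reference count of one, or increments the count if it is already tracked.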
void AllocationTracker::AddAllocationOrIncrementRefCount(
    perftools::gputools::DeviceMemoryBase device_memory, int device_ordinal) {
  AllocationMap& allocation_map = opaque_to_allocation_map_[device_ordinal];
  auto it = allocation_map.find(device_memory.opaque());
  if (it == allocation_map.end()) {
    allocation_map[device_memory.opaque()] = {device_memory, device_ordinal,
                                              /*ref_count=*/1};
  } else {
    it->second.ref_count++;
  }
}

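// Decrements the reference count of a tracked device memory block, returning
// the memory to the backend allocator and dropping the map entry when the
// count reaches zero.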
Status AllocationTracker::DecrementRefCount(
    perftools::gputools::DeviceMemoryBase device_memory, int device_ordinal) {
  AllocationMap& allocation_map = opaque_to_allocation_map_[device_ordinal];
  auto it = allocation_map.find(device_memory.opaque());
  TF_RET_CHECK(it != allocation_map.end());
  Allocation& allocation = it->second;
  TF_RET_CHECK(allocation.ref_count >= 1);
  if (allocation.ref_count == 1) {
    TF_RETURN_IF_ERROR(backend_->memory_allocator()->Deallocate(
        device_ordinal, &device_memory));
    allocation_map.erase(it);
  } else {
    allocation.ref_count--;
  }
  return tensorflow::Status::OK();
}

}  // namespace xla