/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/allocation_tracker.h"

#include <utility>

#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

StatusOr<GlobalDataHandle> AllocationTracker::Register(
    std::unique_ptr<ShapedBuffer> shaped_buffer, const string& tag) {
  tensorflow::mutex_lock lock(mutex_);
  VLOG(2) << "Register";
  return RegisterInternal(std::move(shaped_buffer), tag);
}

StatusOr<GlobalDataHandle> AllocationTracker::RegisterInternal(
    std::unique_ptr<ShapedBuffer> shaped_buffer, const string& tag) {
  VLOG(2) << "RegisterInternal("
          << "tag: \"" << tag << "\" "
          << "shaped_buffer: " << *shaped_buffer;
  if (shaped_buffer->platform() != backend_->platform()) {
    return InvalidArgument(
        "AllocationTracker for platform %s cannot register buffer from "
        "platform %s",
        backend_->platform()->Name().c_str(),
        shaped_buffer->platform()->Name().c_str());
  }

  int64 handle = next_handle_++;
  std::vector<ShapeIndex> shape_indices;
  ShapeUtil::ForEachSubshape(shaped_buffer->on_device_shape(),
                             [this, &shape_indices](const Shape& /*subshape*/,
                                                    const ShapeIndex& index) {
                               shape_indices.push_back(index);
                             });
  for (const ShapeIndex& index : shape_indices) {
    AddAllocationOrIncrementRefCount(shaped_buffer->buffer(index),
                                     shaped_buffer->device_ordinal());
  }
  GlobalDataHandle result;
  result.set_handle(handle);

  handle_to_shaped_buffer_[handle] = std::move(shaped_buffer);

  VLOG(2) << "handle: " << handle;

  return result;
}

tensorflow::Status AllocationTracker::Unregister(const GlobalDataHandle& data) {
  tensorflow::mutex_lock lock(mutex_);
  VLOG(2) << "Unregister("
          << "handle: " << data.handle() << ")";
  TF_ASSIGN_OR_RETURN(ShapedBuffer * shaped_buffer, ResolveInternal(data));
  std::vector<ShapeIndex> shape_indices;
  ShapeUtil::ForEachSubshape(shaped_buffer->on_device_shape(),
                             [this, &shape_indices](const Shape& /*subshape*/,
                                                    const ShapeIndex& index) {
                               shape_indices.push_back(index);
                             });
  for (const ShapeIndex& index : shape_indices) {
    TF_RETURN_IF_ERROR(DecrementRefCount(shaped_buffer->buffer(index),
                                         shaped_buffer->device_ordinal()));
  }

  // Keep a nullptr as a tombstone for unregistered handles. This enables
  // better error messages. That is, "handle has been deallocated" versus
  // "handle does not exist".
  handle_to_shaped_buffer_.at(data.handle()).reset();

  return tensorflow::Status::OK();
}

StatusOr<std::vector<GlobalDataHandle>> AllocationTracker::DeconstructTuple(
    const GlobalDataHandle& data) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(ShapedBuffer * shaped_buffer, ResolveInternal(data));
  if (!ShapeUtil::IsTuple(shaped_buffer->on_host_shape())) {
    return InvalidArgument("global data handle %lld is not a tuple",
                           data.handle());
  }
  // If the on-host representation is a tuple, then the on-device one should be
  // as well.
  TF_RET_CHECK(ShapeUtil::IsTuple(shaped_buffer->on_device_shape()));

  if (ShapeUtil::IsNestedTuple(shaped_buffer->on_device_shape())) {
    return Unimplemented("deconstructing nested tuples not yet supported");
  }

  std::vector<GlobalDataHandle> element_handles;
  for (int i = 0;
       i < ShapeUtil::TupleElementCount(shaped_buffer->on_device_shape());
       ++i) {
    auto element_buffer = MakeUnique<ShapedBuffer>(
        ShapeUtil::GetTupleElementShape(shaped_buffer->on_host_shape(), i),
        ShapeUtil::GetTupleElementShape(shaped_buffer->on_device_shape(), i),
        shaped_buffer->platform(), shaped_buffer->device_ordinal());
    element_buffer->set_buffer(shaped_buffer->buffer(/*index=*/{i}),
                               /*index=*/{});
    TF_ASSIGN_OR_RETURN(
        GlobalDataHandle element_handle,
        RegisterInternal(std::move(element_buffer), "deconstructed tuple"));

    element_handles.push_back(element_handle);
  }
  return std::move(element_handles);
}

StatusOr<const ShapedBuffer*> AllocationTracker::Resolve(
    const GlobalDataHandle& data) {
  tensorflow::mutex_lock lock(mutex_);
  return AllocationTracker::ResolveInternal(data);
}

StatusOr<ShapedBuffer*> AllocationTracker::ResolveInternal(
    const GlobalDataHandle& data) {
  VLOG(2) << "resolve:" << data.handle();
  auto it = handle_to_shaped_buffer_.find(data.handle());
  if (it == handle_to_shaped_buffer_.end()) {
    return NotFound("no allocation record for global data handle: %lld",
                    data.handle());
  }
  ShapedBuffer* shaped_buffer = it->second.get();

  if (shaped_buffer == nullptr) {
    return InvalidArgument("global data handle %lld was previously deallocated",
                           data.handle());
  }

  return shaped_buffer;
}

void AllocationTracker::AddAllocationOrIncrementRefCount(
    perftools::gputools::DeviceMemoryBase device_memory, int device_ordinal) {
  AllocationMap& allocation_map = opaque_to_allocation_map_[device_ordinal];
  auto it = allocation_map.find(device_memory.opaque());
  if (it == allocation_map.end()) {
    allocation_map[device_memory.opaque()] = {device_memory, device_ordinal,
                                              /*ref_count=*/1};
  } else {
    it->second.ref_count++;
  }
}

Status AllocationTracker::DecrementRefCount(
    perftools::gputools::DeviceMemoryBase device_memory, int device_ordinal) {
  AllocationMap& allocation_map = opaque_to_allocation_map_[device_ordinal];
  auto it = allocation_map.find(device_memory.opaque());
  TF_RET_CHECK(it != allocation_map.end());
  Allocation& allocation = it->second;
  TF_RET_CHECK(allocation.ref_count >= 1);
  if (allocation.ref_count == 1) {
    TF_RETURN_IF_ERROR(backend_->memory_allocator()->Deallocate(
        device_ordinal, &device_memory));
    allocation_map.erase(it);
  } else {
    allocation.ref_count--;
  }
  return tensorflow::Status::OK();
}

}  // namespace xla
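
// Illustrative usage sketch of the ref-counting behavior implemented above.
// This is a hedged example, not part of the tracker's API surface: `backend`
// is assumed to be a valid Backend*, `tuple_buffer` an already-populated
// std::unique_ptr<ShapedBuffer> holding a (non-nested) tuple for that
// backend, and the enclosing function is assumed to return a Status-like
// type so the TF_* macros apply.
//
//   AllocationTracker tracker(backend);
//
//   // Registering the tuple gives every sub-buffer a ref count of 1.
//   TF_ASSIGN_OR_RETURN(GlobalDataHandle tuple_handle,
//                       tracker.Register(std::move(tuple_buffer), "example"));
//
//   // Each element handle aliases the tuple's element buffers, so
//   // RegisterInternal bumps those buffers' ref counts to 2.
//   TF_ASSIGN_OR_RETURN(std::vector<GlobalDataHandle> elements,
//                       tracker.DeconstructTuple(tuple_handle));
//
//   // Unregistering the tuple decrements all of its sub-buffer ref counts;
//   // the element buffers survive (count 1) while the top-level tuple index
//   // buffer, now at count 0, is deallocated.
//   TF_RETURN_IF_ERROR(tracker.Unregister(tuple_handle));
//
//   // Unregistering each element releases the remaining device memory.
//   for (const GlobalDataHandle& element : elements) {
//     TF_RETURN_IF_ERROR(tracker.Unregister(element));
//   }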