/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/allocation_tracker.h"

#include <utility>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

StatusOr<GlobalDataHandle> AllocationTracker::Register(
    ScopedShapedBuffer shaped_buffer, const string& tag) {
  tensorflow::mutex_lock lock(mutex_);
  VLOG(2) << "Register";
  std::vector<ScopedShapedBuffer> replicated_buffers;
  replicated_buffers.emplace_back(std::move(shaped_buffer));
  return RegisterInternal(std::move(replicated_buffers), tag);
}

StatusOr<GlobalDataHandle> AllocationTracker::RegisterReplicatedBuffers(
    std::vector<ScopedShapedBuffer> replicated_buffers, const string& tag) {
  tensorflow::mutex_lock lock(mutex_);
  VLOG(2) << "RegisterReplicatedBuffers";
  return RegisterInternal(std::move(replicated_buffers), tag);
}

// ReleaseIfScopedShapedBuffer lets RegisterInternal<ShapedBufferTy>(b) call
// b.release() if b is a ScopedShapedBuffer, or otherwise pass b through
// unmodified.
static ShapedBuffer ReleaseIfScopedShapedBuffer(ShapedBuffer b) { return b; }
static ShapedBuffer ReleaseIfScopedShapedBuffer(ScopedShapedBuffer b) {
  return b.release();
}

template <typename ShapedBufferTy>
StatusOr<GlobalDataHandle> AllocationTracker::RegisterInternal(
    std::vector<ShapedBufferTy> replicated_buffers, const string& tag) {
  static_assert(std::is_same<ShapedBufferTy, ShapedBuffer>::value ||
                    std::is_same<ShapedBufferTy, ScopedShapedBuffer>::value,
                "ShapedBufferTy must be ShapedBuffer or ScopedShapedBuffer.");
  VLOG(2) << "RegisterInternal("
          << "tag: \"" << tag << "\" with " << replicated_buffers.size()
          << " shaped_buffers.";
  for (const auto& shaped_buffer : replicated_buffers) {
    VLOG(2) << "shaped_buffer:" << shaped_buffer;
    if (shaped_buffer.platform() != backend_->platform()) {
      return InvalidArgument(
          "AllocationTracker for platform %s cannot register buffer from "
          "platform %s",
          backend_->platform()->Name(), shaped_buffer.platform()->Name());
    }
  }

  int64 handle = next_handle_++;
  for (auto& shaped_buffer : replicated_buffers) {
    std::vector<ShapeIndex> shape_indices;
    ShapeUtil::ForEachSubshape(
        shaped_buffer.on_device_shape(),
        [&](const Shape& /*subshape*/, const ShapeIndex& index) {
          shape_indices.push_back(index);
        });
    // Add shaped_buffer's buffers to opaque_to_allocation_map_, which owns
    // them.
    for (const ShapeIndex& index : shape_indices) {
      AddAllocationOrIncrementRefCount(shaped_buffer.buffer(index),
                                       shaped_buffer.device_ordinal());
    }
    // If ShapedBufferTy is ScopedShapedBuffer, release the ScopedShapedBuffer
    // into a regular ShapedBuffer, which is stored in
    // handle_to_shaped_buffers_.
    handle_to_shaped_buffers_[handle].emplace_back(
        absl::make_unique<ShapedBuffer>(
            ReleaseIfScopedShapedBuffer(std::move(shaped_buffer))));
  }

  GlobalDataHandle result;
  result.set_handle(handle);
  VLOG(2) << "handle: " << handle;
  return result;
}

Status AllocationTracker::Unregister(const GlobalDataHandle& data) {
  tensorflow::mutex_lock lock(mutex_);
  VLOG(2) << "Unregister("
          << "handle: " << data.handle() << ")";
  TF_ASSIGN_OR_RETURN(std::vector<const ShapedBuffer*> replicated_buffers,
                      ResolveInternal(data));
  for (const auto& shaped_buffer : replicated_buffers) {
    std::vector<ShapeIndex> shape_indices;
    ShapeUtil::ForEachSubshape(
        shaped_buffer->on_device_shape(),
        [&shape_indices](const Shape& /*subshape*/, const ShapeIndex& index) {
          shape_indices.push_back(index);
        });
    for (const ShapeIndex& index : shape_indices) {
      TF_RETURN_IF_ERROR(DecrementRefCount(shaped_buffer->buffer(index),
                                           shaped_buffer->device_ordinal()));
    }
  }
  // Keep a nullptr as a tombstone for unregistered handles. This enables
  // better error messages. That is, "handle has been deallocated" versus
  // "handle does not exist".
  auto it = handle_to_shaped_buffers_.find(data.handle());
  if (it == handle_to_shaped_buffers_.end()) {
    return NotFound("no allocation record for global data handle: %d",
                    data.handle());
  }
  for (auto& shaped_buffer : it->second) {
    shaped_buffer.reset();
  }
  return Status::OK();
}

StatusOr<std::vector<GlobalDataHandle>> AllocationTracker::DeconstructTuple(
    const GlobalDataHandle& data) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(std::vector<const ShapedBuffer*> replicated_buffers,
                      ResolveInternal(data));
  // We only need to care about replica id 0 here, since the GlobalDataHandle
  // is the same for all buffers across replicas.
  const ShapedBuffer* shaped_buffer = replicated_buffers[0];
  if (!shaped_buffer->on_host_shape().IsTuple()) {
    return InvalidArgument("global data handle %d is not a tuple",
                           data.handle());
  }
  // If the on-host representation is a tuple, then the on-device one should be
  // as well.
  TF_RET_CHECK(shaped_buffer->on_device_shape().IsTuple());

  if (ShapeUtil::IsNestedTuple(shaped_buffer->on_device_shape())) {
    return Unimplemented("Deconstructing nested tuples is not implemented.");
  }

  std::vector<GlobalDataHandle> element_handles;
  for (int i = 0;
       i < ShapeUtil::TupleElementCount(shaped_buffer->on_device_shape());
       ++i) {
    auto element_buffer = ShapedBuffer(
        ShapeUtil::GetTupleElementShape(shaped_buffer->on_host_shape(), i),
        ShapeUtil::GetTupleElementShape(shaped_buffer->on_device_shape(), i),
        shaped_buffer->platform(), shaped_buffer->device_ordinal());
    element_buffer.set_buffer(shaped_buffer->buffer(/*index=*/{i}),
                              /*index=*/{});
    std::vector<ShapedBuffer> replicated_buffers;
    replicated_buffers.push_back(std::move(element_buffer));
    TF_ASSIGN_OR_RETURN(
        GlobalDataHandle element_handle,
        RegisterInternal(std::move(replicated_buffers), "deconstructed tuple"));

    element_handles.push_back(element_handle);
  }
  return std::move(element_handles);
}

StatusOr<std::vector<const ShapedBuffer*>> AllocationTracker::Resolve(
    const GlobalDataHandle& data) const {
  tensorflow::mutex_lock lock(mutex_);
  return AllocationTracker::ResolveInternal(data);
}

StatusOr<const ShapedBuffer*> AllocationTracker::ResolveForReplica(
    const GlobalDataHandle& data, int replica_id) const {
  tensorflow::mutex_lock lock(mutex_);
  TF_ASSIGN_OR_RETURN(std::vector<const ShapedBuffer*> replicated_buffers,
                      ResolveInternal(data));
  if (replica_id >= replicated_buffers.size()) {
    return InvalidArgument(
        "Requesting buffer for replica %d, but found buffers only for %lu "
        "replicas.",
        replica_id, replicated_buffers.size());
  }
  return replicated_buffers[replica_id];
}

StatusOr<std::vector<const ShapedBuffer*>> AllocationTracker::ResolveInternal(
    const GlobalDataHandle& data) const {
  VLOG(2) << "resolve:" << data.handle();
  auto it = handle_to_shaped_buffers_.find(data.handle());
  if (it == handle_to_shaped_buffers_.end()) {
    return NotFound("no allocation record for global data handle: %d",
                    data.handle());
  }
  std::vector<const ShapedBuffer*> replicated_buffers;
  for (const auto& shaped_buffer : it->second) {
    if (shaped_buffer == nullptr) {
      return InvalidArgument("global data handle %d was previously deallocated",
                             data.handle());
    }
    replicated_buffers.push_back(shaped_buffer.get());
  }

  return replicated_buffers;
}

void AllocationTracker::AddAllocationOrIncrementRefCount(
    se::DeviceMemoryBase device_memory, int device_ordinal) {
  AllocationMap& allocation_map = opaque_to_allocation_map_[device_ordinal];
  auto it = allocation_map.find(device_memory.opaque());
  if (it == allocation_map.end()) {
    allocation_map[device_memory.opaque()] = {
        OwningDeviceMemory(device_memory, device_ordinal,
                           backend_->memory_allocator()),
        /*ref_count=*/1};
  } else {
    it->second.ref_count++;
  }
}

Status AllocationTracker::DecrementRefCount(se::DeviceMemoryBase device_memory,
                                            int device_ordinal) {
  AllocationMap& allocation_map = opaque_to_allocation_map_[device_ordinal];
  auto it = allocation_map.find(device_memory.opaque());
  TF_RET_CHECK(it != allocation_map.end());
  Allocation& allocation = it->second;
  TF_RET_CHECK(allocation.ref_count >= 1);
  if (allocation.ref_count == 1) {
    allocation.device_memory.Free();
    allocation_map.erase(it);
  } else {
    allocation.ref_count--;
  }
  return Status::OK();
}

}  // namespace xla