/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Classes for allocating XLA literals in device memory and managing handles
// that refer to them.
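//
// A brief usage sketch (illustrative only; the real call sites live in the
// XRT ops, and reference counting/error handling are elided here):
//
//   XRTTupleAllocation* allocation;
//   TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateAndTransfer(
//       literal, backend, device_ordinal, &allocation));
//   int64 handle;
//   TF_RETURN_IF_ERROR(allocation->Intern(rm, &handle));
//   ...
//   XRTTupleAllocation* looked_up;
//   TF_RETURN_IF_ERROR(XRTTupleAllocation::Lookup(rm, handle, &looked_up));
//   xla::Literal literal_out(looked_up->on_host_shape());
//   TF_RETURN_IF_ERROR(looked_up->ToLiteral(backend, looked_up->device_ordinal(),
//                                           &literal_out));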

#include "tensorflow/compiler/xrt/xrt_state.h"

#include <stdint.h>
#include <map>
#include <memory>
#include <string>
#include <utility>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/stream_executor.h"

namespace tensorflow {

namespace {

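// Tracks the number and total byte size of outstanding device memory
// allocations, keyed by device ordinal, for the VLOG(2) allocation statistics
// reported by XRTBufferAllocation.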
class BufferAllocStats {
 public:
  struct Stats {
    int64 count = 0;
    int64 size = 0;
  };

  Stats ReportAlloc(int64 device, int64 msize) {
    mutex_lock lock(lock_);
    Stats* device_stats = &stats_[device];
    device_stats->count += 1;
    device_stats->size += msize;
    return *device_stats;
  }

  Stats ReportFree(int64 device, int64 msize) {
    mutex_lock lock(lock_);
    Stats* device_stats = &stats_[device];
    device_stats->count -= 1;
    device_stats->size -= msize;
    return *device_stats;
  }

 private:
  mutable mutex lock_;
  std::map<int64, Stats> stats_;
};

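// Name of the ResourceMgr container in which XRTTupleAllocation handles are
// stored.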
const char* kTupleContainer = "tuples";

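// Generates a random, non-negative int64 to use as an allocation handle key.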
int64 get_uid() {
  uint64 unsigned_rand = random::New64() & INT64_MAX;
  return static_cast<int64>(unsigned_rand);
}

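// Returns the process-wide allocation statistics tracker, which is never
// freed.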
BufferAllocStats* GetAllocStats() {
  static BufferAllocStats* stats = new BufferAllocStats();
  return stats;
}

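// Allocates a device buffer on `device_ordinal` for every sub-shape of the
// device shape corresponding to `shape`, writes the tuple index tables, and
// returns the result in `*buffer`. The returned ScopedShapedBuffer owns the
// allocations until it is released.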
Status AllocateScopedShapedBuffer(
    xla::Backend* backend, int device_ordinal, const xla::Shape& shape,
    std::unique_ptr<xla::ScopedShapedBuffer>* buffer) {
  auto transfer_manager = backend->transfer_manager();
  auto allocator = backend->memory_allocator();
  TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal));

  // XLA may use a different representation on device than the one it uses on
  // the host, and it does not document any contract for the relationship
  // between the two. Currently the device shape is always a superset of the
  // host shape: any ShapeIndex that is valid in the host shape is also valid
  // in the device shape, but not vice versa. In particular, some host-side
  // types are rewritten to be tuples. We rely on this property when making
  // sub-buffers: we assume that if the client requests the host-shape
  // sub-buffer at index i, it corresponds to the device-shape sub-buffer at
  // the same index.
  xla::Shape on_device_shape = transfer_manager->HostShapeToDeviceShape(shape);
  VLOG(3) << "Allocating literal buffer: host_shape="
          << xla::ShapeUtil::HumanStringWithLayout(shape) << " device_shape="
          << xla::ShapeUtil::HumanStringWithLayout(on_device_shape);

  // The ScopedShapedBuffer frees the buffers that have so far been allocated if
  // it goes out of scope. That's useful if we return early as the result of an
  // error allocating one of the later buffers.
  *buffer = absl::make_unique<xla::ScopedShapedBuffer>(
      shape, on_device_shape, allocator, device_ordinal);
  for (auto& index_to_buffer : (*buffer)->buffers()) {
    xla::Shape subshape =
        xla::ShapeUtil::GetSubshape(on_device_shape, index_to_buffer.first);
    uint64 size = transfer_manager->GetByteSizeRequirement(subshape);
    TF_ASSIGN_OR_RETURN(
        xla::OwningDeviceMemory buffer,
        allocator->Allocate(device_ordinal, size, /*retry_on_failure=*/false));
    // Move the allocation into the ScopedShapedBuffer, which takes ownership
    // of it.
    index_to_buffer.second = buffer.Forget();
    VLOG(2) << "Allocated buffer at " << index_to_buffer.second.opaque()
            << " index " << index_to_buffer.first.ToString();
  }

  TF_RETURN_IF_ERROR(
      transfer_manager->WriteTupleIndexTables(stream.get(), *(buffer->get())));

  return Status::OK();
}

}  // namespace

XRTBufferAllocation::XRTBufferAllocation(const se::DeviceMemoryBase& allocation,
                                         int device_ordinal,
                                         xla::DeviceMemoryAllocator* allocator)
    : size_(allocation.size()),
      allocation_(allocation),
      device_ordinal_(device_ordinal),
      allocator_(allocator) {
  if (VLOG_IS_ON(2)) {
    auto stats =
        GetAllocStats()->ReportAlloc(device_ordinal_, allocation_.size());
    LOG(INFO) << "XRT Allocation Stats: device=" << device_ordinal_
              << " count=" << stats.count << " size=" << stats.size;
  }
}

XRTBufferAllocation::~XRTBufferAllocation() {
  if (VLOG_IS_ON(2)) {
    GetAllocStats()->ReportFree(device_ordinal_, allocation_.size());
  }
  // Deallocate() explicitly allows allocation_ to be null.
  Status s = allocator_->Deallocate(device_ordinal_, allocation_);
  // Nothing to do but check-fail here if the memory data structures are
  // corrupted.
  CHECK(s.ok());
  VLOG(2) << "Freed buffer at " << allocation_.opaque();
}

const se::DeviceMemoryBase& XRTBufferAllocation::allocation() {
  return allocation_;
}

void XRTBufferAllocation::DiscardAllocation() {
  // Replace the allocation with a null.
  allocation_ = se::DeviceMemoryBase();
}

XRTTupleAllocation::XRTTupleAllocation(int device_ordinal,
                                       xla::DeviceMemoryAllocator* allocator,
                                       const xla::Shape& on_host_shape,
                                       const xla::Shape& on_device_shape)
    : device_ordinal_(device_ordinal),
      allocator_(allocator),
      on_host_shape_(on_host_shape),
      on_device_shape_(on_device_shape),
      buffers_(&on_device_shape_) {}

XRTTupleAllocation::~XRTTupleAllocation() {
  for (auto& buffer : buffers_) {
    buffer.second->Unref();
  }
}

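// Allocates device memory for `literal` on `device_ordinal`, copies the
// literal's contents to the device, and wraps the result in a new
// XRTTupleAllocation returned through `*allocation`.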
/*static*/ Status XRTTupleAllocation::CreateAndTransfer(
    const xla::LiteralBase& literal, xla::Backend* backend, int device_ordinal,
    XRTTupleAllocation** allocation) {
  auto transfer_manager = backend->transfer_manager();
  auto allocator = backend->memory_allocator();

  std::unique_ptr<xla::ScopedShapedBuffer> scoped_buffer;
  TF_RETURN_IF_ERROR(AllocateScopedShapedBuffer(
      backend, device_ordinal, literal.shape(), &scoped_buffer));
  TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal));
  TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice(
      stream.get(), literal, *scoped_buffer));

  // By releasing the ScopedShapedBuffer we ensure that the underlying storage
  // won't be freed when the buffer goes out of scope at the end of this
  // call. To avoid a leak, there must be no error-case returns from here until
  // the end of the method.
  auto shaped_buffer = scoped_buffer->release();
  *allocation = new XRTTupleAllocation(device_ordinal, allocator,
                                       shaped_buffer.on_host_shape(),
                                       shaped_buffer.on_device_shape());
  (*allocation)
      ->InitializeFromShapedBuffer(shaped_buffer, allocator, device_ordinal);
  return Status::OK();
}

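// Wraps an existing ShapedBuffer, which already resides on the device, in a
// new XRTTupleAllocation without copying any data.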
/*static*/ Status XRTTupleAllocation::CreateFromBuffer(
    const xla::ShapedBuffer& shaped_buffer, xla::Backend* backend,
    int device_ordinal, XRTTupleAllocation** allocation) {
  auto allocator = backend->memory_allocator();

  *allocation = new XRTTupleAllocation(device_ordinal, allocator,
                                       shaped_buffer.on_host_shape(),
                                       shaped_buffer.on_device_shape());
  (*allocation)
      ->InitializeFromShapedBuffer(shaped_buffer, allocator, device_ordinal);
  return Status::OK();
}

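// Copies the allocation's contents from device memory into `*literal`,
// failing if any of the backing buffers has already been released.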
Status XRTTupleAllocation::ToLiteral(xla::Backend* backend, int device_ordinal,
                                     xla::MutableLiteralBase* literal) {
  auto transfer_manager = backend->transfer_manager();
  TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal));

  // Validate the allocation buffers, since a null buffer reaching
  // TransferLiteralFromDevice() triggers a CHECK failure.
  xla::ShapedBuffer shaped_buffer = ToShapedBuffer();
  for (auto& index_buffer : shaped_buffer.buffers()) {
    if (index_buffer.second.is_null()) {
      return errors::InvalidArgument("Literal buffer at index ",
                                     index_buffer.first.ToString(),
                                     " has been released");
    }
  }
  return transfer_manager->TransferLiteralFromDevice(stream.get(),
                                                     shaped_buffer, *literal);
}

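// Overwrites the allocation's device memory with the contents of `literal`,
// which must have exactly the allocation's on-host shape.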
Status XRTTupleAllocation::WriteLiteral(xla::Backend* backend,
                                        const xla::Literal& literal) {
  if (!xla::ShapeUtil::Equal(literal.shape(), on_host_shape())) {
    return errors::InvalidArgument(
        "New literal shape does not match the existing one: literal=",
        xla::ShapeUtil::HumanStringWithLayout(literal.shape()),
        " device=", xla::ShapeUtil::HumanStringWithLayout(on_host_shape()));
  }
  auto transfer_manager = backend->transfer_manager();
  TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal()));
  return transfer_manager->TransferLiteralToDevice(stream.get(), literal,
                                                   ToShapedBuffer());
}

void XRTTupleAllocation::DiscardAllocation(
    const xla::ShapeIndex& buffer_index) {
  buffers_.element(buffer_index)->DiscardAllocation();
}

const xla::Shape& XRTTupleAllocation::on_host_shape() { return on_host_shape_; }

const xla::Shape& XRTTupleAllocation::on_device_shape() {
  return on_device_shape_;
}

int XRTTupleAllocation::device_ordinal() { return device_ordinal_; }

const se::DeviceMemoryBase& XRTTupleAllocation::root_allocation() {
  return buffers_.element({})->allocation();
}

/*static*/ Status XRTTupleAllocation::Lookup(ResourceMgr* rm, int64 key,
                                             XRTTupleAllocation** allocation) {
  string key_string = absl::StrCat(key);
  TF_RETURN_IF_ERROR(rm->Lookup(kTupleContainer, key_string, allocation));
  return Status::OK();
}

/*static*/ Status XRTTupleAllocation::DeleteFromResourceManager(ResourceMgr* rm,
                                                                int64 key) {
  string key_string = absl::StrCat(key);
  return rm->Delete<XRTTupleAllocation>(kTupleContainer, key_string);
}

/* static */ Status XRTTupleAllocation::ReleaseAllAllocations(ResourceMgr* rm) {
  VLOG(1) << "Releasing all XRT held device memory";
  return rm->Cleanup(kTupleContainer);
}

// Helper typedef to make the ShapeTree ForEach* lambda signatures more
// readable. The lambdas take a parameter of type const T&, where T is the
// pointer type below.
typedef XRTBufferAllocation* XRTBufferAllocationPtr;

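// Creates a new XRTTupleAllocation for the sub-tuple of `parent` rooted at
// `subshape`. If `alias_parent_allocation` is true, the new allocation shares
// the parent's buffers (incrementing their refcounts); otherwise the buffers
// are moved out of the parent and replaced there with null placeholders.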
/*static*/ Status XRTTupleAllocation::MakeSubBuffer(
    XRTTupleAllocation* parent, const xla::ShapeIndex& subshape,
    XRTTupleAllocation** allocation, bool alias_parent_allocation) {
  TF_ASSIGN_OR_RETURN(
      const xla::Shape* host_sub_shape,
      xla::ShapeUtil::TryGetSubshape(parent->on_host_shape(), subshape));
  TF_ASSIGN_OR_RETURN(
      const xla::Shape* device_sub_shape,
      xla::ShapeUtil::TryGetSubshape(parent->on_device_shape(), subshape));

  *allocation =
      new XRTTupleAllocation(parent->device_ordinal(), parent->allocator_,
                             *host_sub_shape, *device_sub_shape);
  if (alias_parent_allocation) {
    // Copy the subtree of allocations from the parent allocation.
    (*allocation)->buffers_.CopySubtreeFrom(parent->buffers_, subshape, {});
    // Increment the refcount on each aliased buffer.
    (*allocation)
        ->buffers_.ForEachElement(
            [](const xla::ShapeIndex& index,
               const XRTBufferAllocationPtr& buffer) { buffer->Ref(); });
  } else {
    // Find the buffers in the parent allocation that match the subtree, and
    // move the parent allocation's buffers over to the new allocation.
    (*allocation)
        ->buffers_.ForEachMutableElement(
            [&](const xla::ShapeIndex& index, XRTBufferAllocationPtr* buffer) {
              // Extend the allocation's index to the parent's frame by adding
              // subshape as a prefix.
              xla::ShapeIndex parent_index = subshape;
              for (int i = 0; i < index.size(); ++i) {
                parent_index.push_back(index[i]);
              }
              *buffer = parent->buffers_.element(parent_index);
              *parent->buffers_.mutable_element(parent_index) =
                  new XRTBufferAllocation(se::DeviceMemoryBase(),
                                          parent->device_ordinal(),
                                          parent->allocator_);
            });
  }

  return Status::OK();
}

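// Computes the host and device shapes of the tuple that MakeTuple will build
// from `elements`: the spine comes from the shape of the elements tree, and
// each leaf is grafted from the corresponding input allocation's shapes. Also
// validates that every leaf has an allocation on the destination device and
// that internal nodes have none.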
/* static */ Status XRTTupleAllocation::ExpandTreeOfTuples(
    const xla::ShapeTree<ExpandedTupleInput>& elements, int device_ordinal,
    xla::DeviceMemoryAllocator* allocator, xla::Shape* host_shape,
    xla::Shape* device_shape) {
  // Initialize both host and device shape to be the 'spine' of the new tuple
  // shape, given by the shape of the tree of tuples.
  *host_shape = elements.shape();
  *device_shape = elements.shape();
  // Now go over the leaves of the tree of tuples, and 'graft' the host/device
  // shapes of the allocation at that leaf onto the expanded host/device shapes
  // at the leaf position.
  TF_RETURN_IF_ERROR(elements.ForEachElementWithStatus(
      [&](const xla::ShapeIndex& index, const ExpandedTupleInput& element) {
        if (elements.IsLeaf(index)) {
          if (element.allocation == nullptr) {
            return errors::InvalidArgument(
                "MakeTuple elements has a null leaf element at index ",
                index.ToString());
          }
          if (device_ordinal != element.allocation->device_ordinal() ||
              allocator != element.allocation->allocator_) {
            return errors::InvalidArgument(
                "MakeTuple elements must all be allocated on the same device "
                "as the destination.");
          }
          *xla::ShapeUtil::GetMutableSubshape(host_shape, index) =
              element.allocation->on_host_shape();
          *xla::ShapeUtil::GetMutableSubshape(device_shape, index) =
              element.allocation->on_device_shape();
        } else {
          if (element.allocation != nullptr) {
            return errors::InvalidArgument(
                "MakeTuple elements has a non-null internal node at index ",
                index.ToString());
          }
        }
        return Status::OK();
      }));
  return Status::OK();
}

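// Builds a new tuple allocation whose leaves are the input allocations in
// `elements`. New device memory is allocated only for the index tables of the
// internal tuple nodes; leaf buffers are either aliased from or moved out of
// the inputs, depending on each input's release_allocation_after_use flag.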
/*static*/ Status XRTTupleAllocation::MakeTuple(
    xla::Backend* backend, int device_ordinal,
    const xla::ShapeTree<ExpandedTupleInput>& elements,
    XRTTupleAllocation** allocation) {
  auto transfer_manager = backend->transfer_manager();
  auto allocator = backend->memory_allocator();
  TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal));

  xla::Shape host_shape;
  xla::Shape device_shape;
  TF_RETURN_IF_ERROR(ExpandTreeOfTuples(elements, device_ordinal, allocator,
                                        &host_shape, &device_shape));

  // Whether each input is aliased or has its buffers moved into the new tuple
  // is decided below, based on its release_allocation_after_use flag.
  // allocation_tmp is a local pointer that is copied to *allocation at the end
  // only if the method succeeds.
  auto allocation_tmp = new XRTTupleAllocation(device_ordinal, allocator,
                                               host_shape, device_shape);
  core::ScopedUnref allocation_unref(allocation_tmp);
  // First allocate device memory for the new tuple index tables, one at each
  // internal node of the elements tree. Do this in a separate pass into a
  // ScopedShapedBuffer so that it's easy to free the newly-allocated memory if
  // an allocation fails. Make sure the shape has a layout so that the code
  // that writes the index tables below has the layout information it needs.
  xla::Shape spine_shape = elements.shape();
  xla::LayoutUtil::SetToDefaultLayout(&spine_shape);
  auto new_tuple_buffers = absl::make_unique<xla::ScopedShapedBuffer>(
      spine_shape, spine_shape, allocator, device_ordinal);
  TF_RETURN_IF_ERROR(elements.ForEachElementWithStatus(
      [&](const xla::ShapeIndex& index, const ExpandedTupleInput& element) {
        if (!elements.IsLeaf(index)) {
          xla::Shape subshape =
              xla::ShapeUtil::GetSubshape(device_shape, index);
          uint64 size = transfer_manager->GetByteSizeRequirement(subshape);
          TF_ASSIGN_OR_RETURN(xla::OwningDeviceMemory buffer,
                              allocator->Allocate(device_ordinal, size,
                                                  /*retry_on_failure=*/false));
          VLOG(2) << "Allocated buffer at " << buffer.opaque() << " index "
                  << index.ToString();
          // Move the new buffer into new_tuple_buffers, which takes ownership
          // of it.
          new_tuple_buffers->set_buffer(std::move(buffer), index);
        }
        return Status::OK();
      }));
  // Release the ScopedShapedBuffer into a ShapedBuffer, which does not own the
  // newly-allocated index tables. At this point nothing owns the new index
  // tables, so next we will transfer ownership to the new allocation, taking
  // care not to return early on any errors in the meantime.
  xla::ShapedBuffer tuple_buffers = new_tuple_buffers->release();
  // Now fill in the remaining data structures. After this ForEachElement
  // completes:
  //   1) Every leaf element of tuple_buffers will be the root buffer of
  //      an existing allocation, and every internal element of tuple_buffers
  //      will be a newly-allocated index table. tuple_buffers does not own any
  //      of these.
  //   2) Every element of allocation_tmp->buffers_ will be a correctly
  //      constructed XRTBufferAllocation wrapping the necessary allocations.
  //      For buffers in existing allocations there will be a new reference
  //      owned by the new allocation, and for newly-allocated index tables
  //      there will be a single reference owned by the new allocation.
  elements.ForEachElement([&](const xla::ShapeIndex& index,
                              const ExpandedTupleInput& element) {
    if (elements.IsLeaf(index)) {
      allocation_tmp->buffers_.CopySubtreeFrom(element.allocation->buffers_, {},
                                               index);
      tuple_buffers.set_buffer(element.allocation->root_allocation(), index);
      if (element.release_allocation_after_use) {
        // Transfer the references from element's buffers to the new allocation
        // rather than incrementing the refcount. The caller should have
        // validated that release_allocation_after_use is false if
        // element.allocation appears in more than one leaf.
        element.allocation->buffers_.ForEachMutableElement(
            [&](const xla::ShapeIndex& index, XRTBufferAllocationPtr* buffer) {
              *buffer = new XRTBufferAllocation(
                  se::DeviceMemoryBase(), element.allocation->device_ordinal(),
                  element.allocation->allocator_);
            });
      } else {
        // Increment the refcount on each newly-aliased buffer.
        element.allocation->buffers_.ForEachElement(
            [](const xla::ShapeIndex& index,
               const XRTBufferAllocationPtr& buffer) { buffer->Ref(); });
      }
    } else {
      // This is an internal node of the tuple tree so take ownership of the
      // newly-created index table.
      *allocation_tmp->buffers_.mutable_element(index) =
          new XRTBufferAllocation(tuple_buffers.buffer(index), device_ordinal,
                                  allocator);
    }
  });
  // Because the internal nodes of tuple_buffers are exactly the new index
  // tables, WriteTupleIndexTables will write only the new index tables and not
  // rewrite the index tables for the existing allocations.
  TF_RETURN_IF_ERROR(
      transfer_manager->WriteTupleIndexTables(stream.get(), tuple_buffers));

  *allocation = allocation_tmp;
  // Get another reference since allocation_tmp will be Unrefed automatically on
  // exit.
  (*allocation)->Ref();
  return Status::OK();
}

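// Registers this allocation in the ResourceMgr under a freshly generated
// random key, which is returned through `*key` and can later be passed to
// Lookup().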
Status XRTTupleAllocation::Intern(ResourceMgr* rm, int64* key) {
  *key = get_uid();
  string key_string = absl::StrCat(*key);
  return rm->Create(kTupleContainer, key_string, this);
}

bool XRTTupleAllocation::IsExclusiveOwner() {
  for (const auto& buffer : buffers_) {
    if (!buffer.second->RefCountIsOne()) return false;
  }
  return true;
}

void XRTTupleAllocation::InitializeFromShapedBuffer(
    const xla::ShapedBuffer& shaped_buffer,
    xla::DeviceMemoryAllocator* allocator, int device_ordinal) {
  for (auto& buffer : buffers_) {
    // Make a reference-counted version of the allocated buffer.
    buffer.second = new XRTBufferAllocation(shaped_buffer.buffer(buffer.first),
                                            device_ordinal, allocator);
  }
}

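// Builds a ShapedBuffer view over the allocation's buffers. The returned
// ShapedBuffer does not own the underlying device memory.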
xla::ShapedBuffer XRTTupleAllocation::ToShapedBuffer() {
  xla::ShapedBuffer shaped_buffer(on_host_shape(), on_device_shape(),
                                  allocator_->platform(), device_ordinal_);
  for (const auto& buffer : buffers_) {
    shaped_buffer.set_buffer(buffer.second->allocation(), buffer.first);
  }
  return shaped_buffer;
}

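// Makes the buffer at `dest_index` in this allocation an alias of the buffer
// at `source_index` in `source`, adjusting refcounts accordingly.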
Status XRTTupleAllocation::AliasBufferFrom(const XRTTupleAllocation& source,
                                           const xla::ShapeIndex& source_index,
                                           const xla::ShapeIndex& dest_index) {
  XRTBufferAllocation* source_buffer = source.buffers_.element(source_index);
  XRTBufferAllocation* dest_buffer = buffers_.element(dest_index);
  // We allow the destination size to be zero, because there are cases where we
  // come back later and fill in null/uninitialized device buffers. In all
  // other cases, the size of the new buffer must match.
  if (source_buffer->size() != dest_buffer->size() &&
      dest_buffer->size() != 0) {
    return errors::InvalidArgument(
        "Source buffer at index ", source_index.ToString(),
        " does not match the size of destination buffer at index ",
        dest_index.ToString(), ": ", source_buffer->size(), " vs ",
        dest_buffer->size());
  }
  *buffers_.mutable_element(dest_index) = source_buffer;
  source_buffer->Ref();
  dest_buffer->Unref();
  return Status::OK();
}

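// Converts the allocation into a ShapeTree of MaybeOwningDeviceMemory. For
// indices where `release_checker` returns true, ownership of the device
// memory is transferred into the tree and the corresponding buffer here is
// discarded; all other buffers are lent out unowned.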
xla::ShapeTree<xla::MaybeOwningDeviceMemory>
XRTTupleAllocation::ToDeviceMemoryTree(
    const std::function<bool(const xla::ShapeIndex&)>& release_checker) {
  xla::ShapeTree<xla::MaybeOwningDeviceMemory> shaped_tree(on_device_shape());
  for (const auto& buffer : buffers_) {
    if (!release_checker(buffer.first)) {
      *shaped_tree.mutable_element(buffer.first) = buffer.second->allocation();
    } else {
      *shaped_tree.mutable_element(buffer.first) = xla::OwningDeviceMemory(
          buffer.second->allocation(), device_ordinal_, allocator_);
      DiscardAllocation(buffer.first);
    }
  }
  return shaped_tree;
}

}  // namespace tensorflow