/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/generic_transfer_manager.h"

#include <string>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/interpreter/platform_id.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace se = ::perftools::gputools;

namespace xla {

GenericTransferManager::GenericTransferManager(se::Platform::Id platform_id,
                                               size_t pointer_size)
    : platform_id_(platform_id), pointer_size_(pointer_size) {
  // We currently only support kHostPlatformId for CPU, kCudaPlatformId for
  // GPU and kInterpreterPlatformId for Interpreter. Before supporting other
  // platforms, we need to test this transfer manager on them.
  CHECK(platform_id_ == se::host::kHostPlatformId ||
        platform_id_ == se::interpreter::kInterpreterPlatformId ||
        platform_id_ == se::cuda::kCudaPlatformId);
}

se::Platform::Id GenericTransferManager::PlatformId() const {
  return platform_id_;
}

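// Writes the index table for a single (flat) tuple: the table is an array
// holding the opaque device pointer of each tuple element, copied into the
// given device region.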
Status GenericTransferManager::WriteSingleTupleIndexTable(
    perftools::gputools::StreamExecutor* executor,
    tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> elements,
    const Shape& shape, perftools::gputools::DeviceMemoryBase* region) {
  TF_RET_CHECK(elements.size() == ShapeUtil::TupleElementCount(shape));

  std::vector<const void*> element_pointers;
  for (const se::DeviceMemoryBase& element : elements) {
    element_pointers.push_back(element.opaque());
  }
  return TransferBufferToDevice(executor, GetByteSizeRequirement(shape),
                                element_pointers.data(), region);
}

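// Copies a ShapedBuffer's contents back to the host as a Literal. Each
// array-shaped leaf buffer is read from device memory into the corresponding
// piece of a freshly allocated literal.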
StatusOr<std::unique_ptr<Literal>>
GenericTransferManager::TransferLiteralFromDevice(
    se::StreamExecutor* executor, const ShapedBuffer& device_buffer) {
  VLOG(2) << "transferring literal from device ordinal "
          << executor->device_ordinal() << "; device buffer: " << device_buffer;
  TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal());

  // The on-host and on-device shape should always be the same for the generic
  // transfer manager.
  TF_RET_CHECK(ShapeUtil::Equal(device_buffer.on_device_shape(),
                                device_buffer.on_host_shape()));

  std::unique_ptr<Literal> literal =
      Literal::CreateFromShape(device_buffer.on_host_shape());

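  // Walk every subshape and copy the array-shaped (leaf) buffers; tuple
  // buffers hold only index tables, so there is nothing to read back for
  // them.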
  TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
      device_buffer.on_host_shape(),
      [&](const Shape& subshape, const ShapeIndex& index) -> Status {
        if (!ShapeUtil::IsTuple(subshape)) {
          TF_RETURN_IF_ERROR(TransferBufferFromDevice(
              executor,
              /*source=*/device_buffer.buffer(index),
              /*size=*/GetByteSizeRequirement(subshape),
              /*destination=*/
              literal->untyped_data(index)));
        }

        return Status::OK();
      }));
  return std::move(literal);
}

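// Copies a host Literal into an already-allocated ShapedBuffer on the device,
// writing the tuple index tables first and then each array subliteral.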
Status GenericTransferManager::TransferLiteralToDevice(
    se::StreamExecutor* executor, const Literal& literal,
    const ShapedBuffer& device_buffer) {
  const Shape& shape = literal.shape();
  VLOG(2) << "transferring literal shape to device: "
          << ShapeUtil::HumanString(shape)
          << "; device buffer: " << device_buffer;

  // The on-host and on-device shape should always be the same for the generic
  // transfer manager.
  TF_RET_CHECK(ShapeUtil::Equal(device_buffer.on_device_shape(),
                                device_buffer.on_host_shape()));

  TF_RET_CHECK(
      ShapeUtil::Compatible(literal.shape(), device_buffer.on_host_shape()));
  TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal());

  TF_RETURN_IF_ERROR(WriteTupleIndexTables(executor, device_buffer));

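  // Copy each array subliteral into its device buffer, relaying out the data
  // first when the host layout differs from the device layout.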
  return ShapeUtil::ForEachSubshapeWithStatus(
      device_buffer.on_host_shape(),
      [&](const Shape& device_subshape, const ShapeIndex& index) -> Status {
        se::DeviceMemoryBase device_memory = device_buffer.buffer(index);
        if (ShapeUtil::IsArray(device_subshape)) {
          TF_RET_CHECK(GetByteSizeRequirement(device_subshape) ==
                       device_memory.size());
          // Element is array-shaped: transfer array data to device buffer.
          const auto subliteral = LiteralView::Create(literal, index);
          std::unique_ptr<Literal> relayed_out_literal;
          const void* source;
          if (LayoutUtil::Equal(device_subshape.layout(),
                                subliteral.shape().layout())) {
            source = subliteral.untyped_data();
          } else {
            // Relayout data before transferring.
            relayed_out_literal = subliteral.Relayout(device_subshape.layout(),
                                                      /*shape_index=*/{});
            source = relayed_out_literal->untyped_data();
          }
          return TransferBufferToDevice(
              executor,
              /*size=*/GetByteSizeRequirement(device_subshape), source,
              &device_memory);
        }
        return Status::OK();
      });
}

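// Infeed, outfeed, and device reset are not supported by the generic
// implementation; platform-specific transfer managers override these methods
// where the underlying hardware supports them.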
Status GenericTransferManager::TransferLiteralToInfeed(
    se::StreamExecutor* executor, const Literal& literal) {
  return Unimplemented("Generic transfer to Infeed");
}

Status GenericTransferManager::TransferBufferToInfeed(
    perftools::gputools::StreamExecutor* executor, int64 size,
    const void* source) {
  return Unimplemented("Generic transfer to Infeed");
}

Status GenericTransferManager::TransferLiteralFromOutfeed(
    perftools::gputools::StreamExecutor* executor, const Shape& literal_shape,
    Literal* literal) {
  return Unimplemented(
      "Outfeed is not supported on this platform (b/30467474)");
}

Status GenericTransferManager::ResetDevices(
    tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
    /*executors*/) {
  return Unimplemented(
      "Device reset is not yet supported on this platform (b/30481585)");
}

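// Returns the number of bytes needed to hold `shape` on the device. Tuples
// are sized as arrays of element pointers, which is why the pointer size is
// threaded through to ShapeUtil::ByteSizeOf.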
int64 GenericTransferManager::GetByteSizeRequirement(const Shape& shape) const {
  return ShapeUtil::ByteSizeOf(shape, pointer_size_);
}

}  // namespace xla