/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_device.h"

#include <stdlib.h>
#include <unordered_set>

#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/xla_device_context.h"
#include "tensorflow/compiler/jit/xla_device_ops.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/renamed_device.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/stream_executor_util.h"

namespace se = ::perftools::gputools;

namespace tensorflow {

// Caches an XlaDeviceAllocator per <backend, device ordinal> pair. An
// XlaDeviceAllocator is created on demand and is associated with an
// XlaDevice. It outlives the device itself (for instance, the buffer
// backing a tensor holds a pointer to the allocator for book-keeping,
// and this buffer can outlast the device).
class XlaDeviceAllocatorState {
 public:
  // Creates or returns a cached XlaDeviceAllocator for a given
  // backend and device_ordinal.
  static XlaDeviceAllocator* GetOrCreateXlaDeviceAllocator(
      const xla::Backend* backend, int device_ordinal);

 private:
  // Returns the singleton instance of XlaDeviceAllocatorState.
  static XlaDeviceAllocatorState& Singleton();
  XlaDeviceAllocatorState();
  ~XlaDeviceAllocatorState();

  mutex allocator_mutex_;  // Guards the singleton allocator state.
  std::unordered_map<std::pair<const xla::Backend*, int>,
                     std::unique_ptr<XlaDeviceAllocator>,
                     hash<std::pair<const xla::Backend*, int>>>
      allocators_ GUARDED_BY(allocator_mutex_);

  TF_DISALLOW_COPY_AND_ASSIGN(XlaDeviceAllocatorState);
};

/* static */ XlaDeviceAllocatorState& XlaDeviceAllocatorState::Singleton() {
  static auto a = new XlaDeviceAllocatorState;
  return *a;
}

XlaDeviceAllocatorState::XlaDeviceAllocatorState() = default;
XlaDeviceAllocatorState::~XlaDeviceAllocatorState() = default;

XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
    const xla::Backend* backend, int device_ordinal) {
  XlaDeviceAllocatorState& state = Singleton();
  mutex_lock lock(state.allocator_mutex_);

  auto it = state.allocators_.find({backend, device_ordinal});
  if (it != state.allocators_.end()) {
    return it->second.get();
  }

  std::unique_ptr<XlaDeviceAllocator> alloc =
      xla::MakeUnique<XlaDeviceAllocator>(backend, device_ordinal);
  XlaDeviceAllocator* alloc_ptr = alloc.get();
  state.allocators_[{backend, device_ordinal}] = std::move(alloc);
  return alloc_ptr;
}

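// Builds an XlaDevice for the given platform/device-name pair. If
// `register_device_for_compilation` is true, also registers `device_name`
// with the XLA op registry so graphs placed on the device can be compiled
// for `jit_device_name`. The device advertises a fixed 16GiB memory limit
// in its attributes.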
/* static */ Status XlaDevice::Create(
    const string& platform_name, const string& device_name, int device_ordinal,
    const string& jit_device_name, const SessionOptions& options,
    const string& name_prefix, bool register_device_for_compilation,
    std::unique_ptr<XlaDevice>* device) {
  VLOG(1) << "XlaDevice::Create " << platform_name << " " << device_name << ":"
          << device_ordinal;

  if (register_device_for_compilation) {
    // This registration is a no-op if it has already been done for this
    // device_name/compilation_device_name pair.
    XlaOpRegistry::DeviceRegistration registration;
    registration.compilation_device_name = jit_device_name;
    registration.requires_compilation = true;
    registration.enable_jit_by_default = false;
    registration.compile_resource_ops = true;
    XlaOpRegistry::RegisterCompilationDevice(device_name, registration);
  }

  auto platform = se::MultiPlatformManager::PlatformWithName(platform_name);
  if (!platform.ok()) {
    return StreamExecutorUtil::ConvertStatus(platform.status());
  }

  const DeviceAttributes attrs = Device::BuildDeviceAttributes(
      strings::StrCat(name_prefix, "/device:", device_name, ":",
                      device_ordinal),
      DeviceType(device_name), Bytes(16ULL << 30), DeviceLocality(),
      strings::StrCat("device: ", device_name, " device"));

  device->reset(new XlaDevice(options, attrs, device_ordinal,
                              DeviceType(jit_device_name),
                              platform.ValueOrDie()));
  return Status::OK();
}

XlaDevice::Metadata::Metadata(int device_ordinal, se::Platform* platform,
                              const DeviceType& device_type)
    : device_ordinal_(device_ordinal),
      device_type_(device_type),
      platform_(platform) {}

int XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; }

se::Platform* XlaDevice::Metadata::platform() const { return platform_; }

xla::LocalClient* XlaDevice::Metadata::client() const {
  auto client = xla::ClientLibrary::GetOrCreateLocalClient(platform_);
  return client.ValueOrDie();
}

const DeviceType& XlaDevice::Metadata::jit_device_type() const {
  return device_type_;
}

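// Returns the Metadata of the XlaDevice underlying `ctx`'s device, or an
// Internal error if that device is not an XlaDevice.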
/* static */ Status XlaDevice::GetMetadata(OpKernelContext* ctx,
                                           const Metadata** metadata) {
  XlaDevice* xla_device =
      dynamic_cast<XlaDevice*>(ctx->device()->UnderlyingDevice());
  if (xla_device == nullptr) {
    return errors::Internal(
        "Cannot get XLA metadata from non-XLA device \"", ctx->device()->name(),
        "\". GetMetadata must only be called on an XLA device. Either an "
        "internal bug has been triggered, or an XLA-specific op has been "
        "placed on the wrong device.");
  }
  *metadata = &(xla_device->xla_metadata_);
  return Status::OK();
}

XlaDevice::XlaDevice(const SessionOptions& options,
                     const DeviceAttributes& attrs, int device_ordinal,
                     const DeviceType& jit_device_name, se::Platform* platform)
    : LocalDevice(options, attrs),
      xla_metadata_(device_ordinal, platform, jit_device_name),
      device_ordinal_(device_ordinal),
      jit_device_name_(jit_device_name),
      xla_allocator_(nullptr),
      platform_(platform) {}

XlaDevice::~XlaDevice() {}

xla::LocalClient* XlaDevice::client() const {
  // We lazily create the client because the platform commits to the
  // details of the host hardware when the client is created, so we
  // don't want to do it until we get a chance to hook the platform up
  // to a simulator.

  // For now GetOrCreateLocalClient always returns success when passed
  // a non-null platform. If that changes we may have to plumb in some
  // way to pass Status back.
  return xla::ClientLibrary::GetOrCreateLocalClient(platform_).ValueOrDie();
}

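// Host-resident allocations use the ordinary CPU allocator; device-resident
// allocations use an XlaDeviceAllocator, created lazily on first use via
// the process-wide allocator cache above.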
Allocator* XlaDevice::GetAllocator(AllocatorAttributes attr) {
  if (attr.on_host()) {
    return cpu_allocator();
  }

  if (xla_allocator_ == nullptr) {
    xla::Backend* backend = client()->mutable_backend();
    xla_allocator_ = XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
        backend, device_ordinal_);
  }
  return xla_allocator_;
}

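// Lazily borrows a compute stream from the backend on first use; the same
// stream is then reused for the lifetime of the device.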
xla::StatusOr<se::Stream*> XlaDevice::GetStream() {
  if (!stream_) {
    xla::Backend* backend = client()->mutable_backend();
    TF_ASSIGN_OR_RETURN(stream_, backend->BorrowStream(device_ordinal_));
  }
  return stream_.get();
}

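// All nodes in the graph share a single XlaDeviceContext. The context is
// created with one reference; each map entry takes an additional reference,
// and the final Unref drops the creation reference, so the map entries
// collectively own the context.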
Status XlaDevice::FillContextMap(const Graph* graph,
                                 DeviceContextMap* device_context_map) {
  VLOG(1) << "XlaDevice::FillContextMap";
  device_context_map->resize(graph->num_node_ids());
  TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
  auto ctx = new XlaDeviceContext(stream);
  for (Node* n : graph->nodes()) {
    VLOG(2) << n->id() << " : " << n->type_string() << " : " << n->name();
    ctx->Ref();
    (*device_context_map)[n->id()] = ctx;
  }
  ctx->Unref();
  return Status::OK();
}

void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
  VLOG(1) << "XlaDevice::Compute " << op_kernel->name() << ":"
          << op_kernel->type_string();
  // When TraceMe profiling is off (which is the default), the
  // following TraceMe constructor reduces to a conditional test of a
  // false value. Measurements show that its overhead is negligible.
  port::Tracing::TraceMe trace_me(op_kernel->name(), op_kernel->type_string(),
                                  op_kernel->IsExpensive());
  op_kernel->Compute(context);
}

void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
                             AsyncOpKernel::DoneCallback done) {
  VLOG(1) << "XlaDevice::ComputeAsync " << op_kernel->name() << ":"
          << op_kernel->type_string();
  port::Tracing::TraceMe trace_me(op_kernel->name(), op_kernel->type_string(),
                                  op_kernel->IsExpensive());
  op_kernel->ComputeAsync(context, done);
}

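// Parses the proto into a host tensor first. If the tensor is destined for
// the device, it is then copied there, blocking until the transfer
// completes.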
Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
                                      const AllocatorAttributes alloc_attrs,
                                      Tensor* tensor) {
  VLOG(1) << "XlaDevice::MakeTensorFromProto";

  Tensor parsed(tensor_proto.dtype());
  if (!parsed.FromProto(cpu_allocator(), tensor_proto)) {
    return errors::InvalidArgument("Cannot parse tensor from proto: ",
                                   tensor_proto.DebugString());
  }

  Status status;
  if (alloc_attrs.on_host()) {
    *tensor = parsed;
  } else {
    Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape());
    Notification n;
    TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
    XlaTransferManager manager(stream);
    manager.CopyCPUTensorToDevice(&parsed, this, &copy,
                                  [&n, &status](const Status& s) {
                                    status = s;
                                    n.Notify();
                                  });
    n.WaitForNotification();
    *tensor = copy;
  }
  VLOG(2) << "Allocated tensor at " << DMAHelper::base(tensor);
  return status;
}

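// Registers a placeholder (XlaDeviceDummyOp) kernel on `device` for every
// kernel the JIT device supports, so that ops can be placed on the XLA
// device; the dummy kernels themselves are not intended to do real work.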
XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
                                                   const char* jit_device) {
  XlaOpRegistry::RegisterCompilationKernels();
  XlaDeviceOpRegistrations* registrations = new XlaDeviceOpRegistrations;
  auto dummy_factory = [](OpKernelConstruction* context) -> OpKernel* {
    return new XlaDeviceDummyOp(context);
  };
  for (const KernelDef* jit_def : XlaOpRegistry::DeviceKernels(
           jit_device,
           /*include_compilation_only_kernels=*/false)) {
    KernelDef* def = new KernelDef(*jit_def);
    def->set_device_type(device);
    registrations->op_kernel_registrars.emplace_back(
        new kernel_factory::OpKernelRegistrar(def, "XlaDeviceDummyOp",
                                              dummy_factory));
  }
  return registrations;
}

}  // namespace tensorflow