/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/kernels/xla_launch_op.h"

#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/util/stream_executor_util.h"

namespace gpu = perftools::gputools;

namespace tensorflow {

// Adapter class that wraps a TensorFlow allocator as an XLA allocator.
// Assumes that the TensorFlow allocator permits asynchronous deallocation:
// see comment on `AllowsAsynchronousDeallocation()`.
class XlaAllocator : public xla::DeviceMemoryAllocator {
 public:
  XlaAllocator(const gpu::Platform* platform, OpKernelContext* op_context);
  ~XlaAllocator() override;
  xla::StatusOr<gpu::DeviceMemoryBase> Allocate(int device_ordinal, uint64 size,
                                                bool retry_on_failure) override;
  Status Deallocate(int device_ordinal, gpu::DeviceMemoryBase* mem) override;

  // Registers a Tensor (input or resource variable) with the allocator. If
  // the operation returns an alias to one of its inputs, the allocator
  // needs to be able to handle it.
  Status RegisterArgument(const Tensor* t);

  // Makes 'tensor' a wrapper around the data buffer backing 'buffer'. The
  // buffer is interpreted as having data type 'dtype' and shape 'shape'.
  Status MakeTensorFromBuffer(gpu::DeviceMemoryBase buffer, DataType dtype,
                              const TensorShape& shape, Tensor* tensor) const;

  // The TensorFlow BFC allocator used on GPU allows host-side deallocation
  // before GPU execution takes place. TensorFlow uses the ordering of the main
  // compute stream to enforce a happens-before relationship between a memory
  // allocation and code that reuses the same memory. If TensorFlow adds
  // support for multiple GPU streams or allocators with different ordering
  // requirements, this code may need to change.
  // (This attribute has no effect on CPU.)
  bool AllowsAsynchronousDeallocation() const override { return true; }

 private:
  OpKernelContext* const op_context_;

  // Map from pointer address to the owning Tensor; used by
  // MakeTensorFromBuffer. Also used to automatically release Tensors when the
  // allocator is freed.
  std::unordered_map<void*, Tensor> tensors_;
};
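
// Illustrative usage (a sketch of what Compute() below does, not a separate
// API): an XlaAllocator is constructed per kernel invocation and handed to
// XLA via xla::ExecutableRunOptions, so every temporary buffer XLA requests
// is backed by a Tensor tracked in `tensors_`:
//
//   XlaAllocator xla_allocator(client->platform(), ctx);
//   xla::ExecutableRunOptions run_options;
//   run_options.set_allocator(&xla_allocator);
//   auto run_result = executable->Run(arg_ptrs, run_options);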

XlaAllocator::XlaAllocator(const gpu::Platform* platform,
                           OpKernelContext* op_context)
    : xla::DeviceMemoryAllocator(platform), op_context_(op_context) {}

XlaAllocator::~XlaAllocator() = default;

xla::StatusOr<gpu::DeviceMemoryBase> XlaAllocator::Allocate(
    int device_ordinal, uint64 size, bool retry_on_failure) {
  AllocatorAttributes allocator_attrs;
  allocator_attrs.set_on_host(false);

  AllocationAttributes allocation_attrs;
  allocation_attrs.no_retry_on_failure = !retry_on_failure;

  Tensor t;
  Status status = op_context_->allocate_temp(
      DT_UINT8, TensorShape({static_cast<int64>(size)}), &t, allocator_attrs,
      allocation_attrs);
  if (!status.ok()) {
    VLOG(2) << "Allocation failed " << size;
    return status;
  }
  void* data =
      reinterpret_cast<void*>(const_cast<char*>(t.tensor_data().data()));
  tensors_[data] = t;
  return gpu::DeviceMemoryBase(data, size);
}
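
// The DeviceMemoryBase returned by Allocate() aliases the Tensor's backing
// buffer; the Tensor itself stays alive in `tensors_` until Deallocate() is
// called or the allocator is destroyed.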

Status XlaAllocator::RegisterArgument(const Tensor* t) {
  void* data =
      reinterpret_cast<void*>(const_cast<char*>(t->tensor_data().data()));
  tensors_[data] = *t;
  return Status::OK();
}

Status XlaAllocator::Deallocate(int device_ordinal,
                                gpu::DeviceMemoryBase* mem) {
  if (mem->opaque() != nullptr) {
    if (tensors_.erase(mem->opaque()) == 0) {
      return tensorflow::errors::InvalidArgument("Unknown tensor address");
    }
  }
  return Status::OK();
}

Status XlaAllocator::MakeTensorFromBuffer(gpu::DeviceMemoryBase buffer,
                                          DataType dtype,
                                          const TensorShape& shape,
                                          Tensor* out_tensor) const {
  void* ptr = const_cast<void*>(buffer.opaque());
  auto it = tensors_.find(ptr);
  if (it == tensors_.end()) {
    return errors::InvalidArgument("Unknown tensor address");
  }
  const Tensor& tensor = it->second;

  int64 output_size = DataTypeSize(dtype) * shape.num_elements();
  if (tensor.TotalBytes() == output_size) {
    out_tensor->UnsafeCopyFromInternal(tensor, dtype, shape);
  } else {
    Tensor slice = tensor.Slice(0, output_size);
    out_tensor->UnsafeCopyFromInternal(slice, dtype, shape);
  }
  return Status::OK();
}
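
// Both branches above alias the registered Tensor's buffer via
// UnsafeCopyFromInternal rather than copying the data; when the owning Tensor
// is larger than the requested output (e.g. a DT_UINT8 scratch buffer made by
// Allocate()), only a leading slice of it is wrapped.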

XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
    : OpKernel(ctx), device_type_(ctx->device_type()) {
  const NameAttrList* func;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("function", &func));
  function_ = *func;
  DataTypeVector constant_types;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("Tconstants", &constant_types));
  num_constant_args_ = constant_types.size();
  OP_REQUIRES_OK(ctx, ctx->GetAttr("Nresources", &num_resource_args_));
  if (device_type_ == DeviceType(DEVICE_CPU)) {
    platform_id_ = gpu::host::kHostPlatformId;
  } else if (device_type_ == DeviceType(DEVICE_GPU)) {
    platform_id_ = gpu::cuda::kCudaPlatformId;
  } else {
    platform_id_ = nullptr;
  }
}

Status XlaLocalLaunchOp::BuildCompilationCache(OpKernelContext* ctx,
                                               XlaCompilationCache** cache) {
  const XlaDevice::Metadata* metadata;
  Status s = XlaDevice::GetMetadata(ctx, &metadata);
  if (s.ok()) {
    *cache = new XlaCompilationCache(metadata->client(),
                                     metadata->jit_device_type());
    return Status::OK();
  }

  auto platform = gpu::MultiPlatformManager::PlatformWithId(platform_id_);
  if (!platform.ok()) {
    return StreamExecutorUtil::ConvertStatus(platform.status());
  }
  xla::LocalClientOptions client_options;
  client_options.set_platform(platform.ValueOrDie());
  client_options.set_intra_op_parallelism_threads(
      ctx->device()->tensorflow_cpu_worker_threads()->num_threads);
  auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options);
  if (!client.ok()) {
    return client.status();
  }
  const XlaOpRegistry::DeviceRegistration* registration;
  if (!XlaOpRegistry::GetCompilationDevice(device_type_.type(),
                                           &registration)) {
    return errors::InvalidArgument("No JIT device registered for ",
                                   device_type_.type());
  }
  *cache = new XlaCompilationCache(
      client.ValueOrDie(), DeviceType(registration->compilation_device_name));
  return Status::OK();
}
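
// BuildCompilationCache takes one of two paths: on XLA_* devices the client
// and JIT device type come from the XlaDevice metadata; otherwise a
// LocalClient is created for platform_id_ and the compilation device is
// looked up in the XlaOpRegistry.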

// Takes a snapshot of the values of the resource-variable arguments, whose
// handles are the last `num_variables` inputs to the kernel. Variables that
// cannot be looked up are left unset (`present == false`) in the returned
// vector.
std::vector<OptionalTensor> SnapshotResourceVariables(OpKernelContext* ctx,
                                                      int num_variables) {
  std::vector<OptionalTensor> snapshot(num_variables);
  int first_variable = ctx->num_inputs() - num_variables;
  for (int i = 0; i < num_variables; ++i) {
    Var* variable = nullptr;
    ResourceHandle handle = HandleFromInput(ctx, first_variable + i);
    if (LookupResource(ctx, handle, &variable).ok()) {
      tf_shared_lock lock(*variable->mu());
      snapshot[i].name = handle.name();
      snapshot[i].present = true;
      snapshot[i].value = *variable->tensor();
    }
  }
  return snapshot;
}
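
// The value Tensor is captured while holding the variable's shared lock, so
// the compiled computation sees a consistent snapshot even if another op
// assigns to the variable while the XLA executable runs.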

void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "XlaLocalLaunchOp::Compute "
          << Canonicalize(function_.name(), AttrSlice(&function_.attr()));
  // We store information about the JIT-compiled XLA computation
  // in the ResourceMgr.
  ResourceMgr* rm = ctx->resource_manager();
  OP_REQUIRES(ctx, rm, errors::Internal("No resource manager."));

  gpu::Stream* stream =
      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;

  XlaCompilationCache* cache;
  OP_REQUIRES_OK(ctx, rm->LookupOrCreate<XlaCompilationCache>(
                          rm->default_container(), "xla_cache", &cache,
                          [this, ctx](XlaCompilationCache** cache) {
                            return BuildCompilationCache(ctx, cache);
                          }));
  // Hold the reference to the JIT during evaluation. (We could probably
  // free it sooner because the ResourceMgr will retain a reference, but
  // this is more obviously correct.)
  core::ScopedUnref cache_ref(cache);

  // Get the platform_id_ for XLA_* devices.
  if (platform_id_ == nullptr) {
    const XlaDevice::Metadata* metadata;
    Status s = XlaDevice::GetMetadata(ctx, &metadata);
    if (s.ok()) {
      platform_id_ = metadata->platform()->id();
    }
  }

  std::vector<OptionalTensor> variables =
      SnapshotResourceVariables(ctx, num_resource_args_);

  xla::LocalClient* client = static_cast<xla::LocalClient*>(cache->client());

  // Builds an XLA allocator for the device.
  XlaAllocator xla_allocator(client->platform(), ctx);

  XlaCompiler::Options options;
  options.client = client;
  options.device_type = &cache->device_type();
  options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
  options.graph_def_version = ctx->function_library()->graph_def_version();
  options.allow_cpu_custom_calls = (platform_id_ == gpu::host::kHostPlatformId);
  options.device_allocator = &xla_allocator;

  const XlaCompiler::CompilationResult* kernel;
  xla::LocalExecutable* executable;

  OP_REQUIRES_OK(ctx, cache->Compile(options, function_, num_constant_args_,
                                     variables, ctx, &kernel, &executable,
                                     /*compile_options=*/nullptr));

  VLOG(1) << "Executing XLA Computation...";

  std::unique_ptr<xla::ShapedBuffer> output;
  // Build xla::ShapedBuffers that point directly to the Tensor buffers.
  std::vector<std::unique_ptr<xla::ShapedBuffer>> arg_buffers;
  arg_buffers.reserve(kernel->xla_input_shapes.size() + 1);
  arg_buffers.resize(kernel->xla_input_shapes.size());
  std::vector<xla::ShapedBuffer*> arg_ptrs(arg_buffers.size());

  const int first_variable_arg = ctx->num_inputs() - num_resource_args_;
  // Pass remaining parameters.
  const Tensor* t;
  for (int i = 0; i < kernel->xla_input_shapes.size(); ++i) {
    int arg_num = kernel->input_mapping[i];
    const xla::Shape& shape = kernel->xla_input_shapes[i];
    if (arg_num >= first_variable_arg) {
      t = &(variables[arg_num - first_variable_arg].value);
    } else {
      t = &(ctx->input(arg_num));
    }

    gpu::DeviceMemoryBase dmem = gpu::DeviceMemoryBase(
        const_cast<char*>(t->tensor_data().data()), t->tensor_data().size());

    const xla::Shape on_device_shape =
        client->backend().transfer_manager()->HostShapeToDeviceShape(shape);
    CHECK(xla::ShapeUtil::Equal(shape, on_device_shape))
        << "On-device shape "
        << xla::ShapeUtil::HumanStringWithLayout(on_device_shape)
        << " not the same as on-host shape "
        << xla::ShapeUtil::HumanStringWithLayout(shape);
    arg_buffers[i] = xla::MakeUnique<xla::ShapedBuffer>(
        /*on_host_shape=*/shape, /*on_device_shape=*/shape, client->platform(),
        client->default_device_ordinal());
    arg_buffers[i]->set_buffer(dmem, /*index=*/{});
    arg_ptrs[i] = arg_buffers[i].get();

    OP_REQUIRES_OK(ctx, xla_allocator.RegisterArgument(t));
  }
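
  // Note that the ShapedBuffers built above alias the existing Tensor buffers
  // rather than copying them, and each Tensor is registered with
  // xla_allocator so MakeTensorFromBuffer can later map an output buffer that
  // aliases an input back to its owning Tensor.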

  // Execute the computation.
  VLOG(2) << "Executing computation.";
  xla::ExecutableRunOptions run_options;
  run_options.set_stream(stream);
  run_options.set_allocator(&xla_allocator);
  run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
  Env* env = Env::Default();
  auto start_time = env->NowMicros();
  auto run_result = executable->Run(arg_ptrs, run_options);
  OP_REQUIRES(ctx, run_result.ok(), run_result.status());

  output = run_result.ConsumeValueOrDie()->release();
  auto elapsed = env->NowMicros() - start_time;
  VLOG(2) << "Elapsed time: " << elapsed << "us";

  // Computation output should always be a tuple.
  if (VLOG_IS_ON(2)) {
    VLOG(2) << "Result tuple shape: " << output->on_host_shape().DebugString();
  }
  CHECK_EQ(ctx->num_outputs(), kernel->outputs.size());

  // Copy XLA results to the OpOutputList.
  int output_num = 0;
  for (int i = 0; i < ctx->num_outputs(); ++i) {
    if (kernel->outputs[i].is_constant) {
      // Output is a constant.
      const Tensor& const_tensor = kernel->outputs[i].constant_value;
      const size_t total_bytes = const_tensor.TotalBytes();
      if (stream && total_bytes > 0) {
        // Copy host -> device. (Empty tensors don't have backing buffers.)
        VLOG(1) << "Constant output tensor on device";
        Tensor* output_tensor;
        TF_CHECK_OK(
            ctx->allocate_output(i, const_tensor.shape(), &output_tensor));

        const void* src_ptr = DMAHelper::base(&const_tensor);
        void* dst_ptr = DMAHelper::base(output_tensor);
        gpu::DeviceMemoryBase gpu_dst_ptr(dst_ptr, total_bytes);
        stream->ThenMemcpy(&gpu_dst_ptr, src_ptr, total_bytes);
      } else {
        // No copy required.
        ctx->set_output(i, const_tensor);
      }
    } else {
      const TensorShape& shape = kernel->outputs[i].shape;
      VLOG(2) << "Retval " << i << " shape " << shape.DebugString();

      gpu::DeviceMemoryBase buffer = output->buffer({output_num});
      Tensor output_tensor;
      // Looks up the owning Tensor by buffer address.
      OP_REQUIRES_OK(ctx, xla_allocator.MakeTensorFromBuffer(
                              buffer, ctx->expected_output_dtype(i), shape,
                              &output_tensor));
      ctx->set_output(i, output_tensor);
      ++output_num;
    }

    if (VLOG_IS_ON(3)) {
      VLOG(3) << ctx->mutable_output(i)->DebugString();
    }
  }

  // Apply variable updates, if any.
  VLOG(2) << "Applying variable updates";
  for (int i = 0; i < kernel->resource_updates.size(); ++i) {
    const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i];
    OP_REQUIRES(ctx,
                write.input_index >= 0 && write.input_index < ctx->num_inputs(),
                errors::Internal("Invalid input index for variable write."));

    gpu::DeviceMemoryBase buffer = output->buffer({output_num});

    Var* variable = nullptr;
    // TODO(b/35625933): tensorflow::Var should contain a PersistentTensor, not
    // a Tensor.
    OP_REQUIRES_OK(ctx, LookupOrCreateResource<Var>(
                            ctx, HandleFromInput(ctx, write.input_index),
                            &variable, [this, ctx, &write](Var** ptr) {
                              *ptr = new Var(write.type);
                              return Status::OK();
                            }));

    core::ScopedUnref s(variable);

    mutex_lock ml(*variable->mu());
    OP_REQUIRES(ctx, variable->tensor()->dtype() == write.type,
                errors::Internal("Mismatched type in variable write"));

    // Looks up the owning Tensor by buffer address.
    OP_REQUIRES_OK(
        ctx, xla_allocator.MakeTensorFromBuffer(buffer, write.type, write.shape,
                                                variable->tensor()));
    ++output_num;
  }

  VLOG(1) << "Done";
}

XlaLocalLaunchOp::~XlaLocalLaunchOp() {
  VLOG(1) << "XlaLocalLaunchOp destroyed";
}

REGISTER_KERNEL_BUILDER(Name("_XlaLaunch").Device(DEVICE_CPU),
                        XlaLocalLaunchOp);

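// In the GPU registration below, the "constants" arguments are pinned to host
// memory so that their values can be read on the host when the computation is
// compiled, and the "resources" arguments are DT_RESOURCE handles, which live
// on the host.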
REGISTER_KERNEL_BUILDER(Name("_XlaLaunch")
                            .Device(DEVICE_GPU)
                            .HostMemory("constants")
                            .HostMemory("resources"),
                        XlaLocalLaunchOp);

}  // namespace tensorflow