/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/kernels/xla_launch_op.h"

#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/util/stream_executor_util.h"

namespace gpu = perftools::gputools;

namespace tensorflow {
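
// This file implements XlaLocalLaunchOp, the kernel behind the _XlaLaunch op.
// It JIT-compiles a TensorFlow function to an XLA executable via the
// XlaCompilationCache, wraps the input Tensors as XLA device buffers, runs
// the executable, and hands the results (including resource-variable
// updates) back to the OpKernelContext.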

// Adapter class that wraps a TensorFlow allocator as an XLA allocator.
// Assumes that the TensorFlow allocator permits asynchronous deallocation:
// see comment on `AllowsAsynchronousDeallocation()`.
class XlaAllocator : public xla::DeviceMemoryAllocator {
 public:
  XlaAllocator(const gpu::Platform* platform, OpKernelContext* op_context);
  ~XlaAllocator() override;
  xla::StatusOr<gpu::DeviceMemoryBase> Allocate(int device_ordinal, uint64 size,
                                                bool retry_on_failure) override;
  Status Deallocate(int device_ordinal, gpu::DeviceMemoryBase* mem) override;

  // Registers a Tensor (input or resource variable) with the allocator. If
  // the computation returns an alias to one of its inputs, the allocator
  // needs to be able to resolve that buffer back to the owning Tensor.
  Status RegisterArgument(const Tensor* t);

  // Makes 'tensor' a wrapper around the data buffer at 'ptr'. The buffer is
  // interpreted as having data type 'dtype' and shape 'shape'.
  Status MakeTensorFromBuffer(gpu::DeviceMemoryBase buffer, DataType dtype,
                              const TensorShape& shape, Tensor* tensor) const;

  // The TensorFlow BFC allocator used on GPU allows host-side deallocation
  // before GPU execution takes place. TensorFlow uses the ordering of the main
  // compute stream to enforce a happens-before relationship between a memory
  // allocation and code that reuses the same memory. If TensorFlow adds
  // support for multiple GPU streams or allocators with different ordering
  // requirements, this code may need to change.
  // (This attribute has no effect on CPU.)
  bool AllowsAsynchronousDeallocation() const override { return true; }

 private:
  OpKernelContext* const op_context_;

  // Map from pointer address to the owning Tensor; used by
  // MakeTensorFromBuffer. Also used to automatically release Tensors when the
  // allocator is freed.
  std::unordered_map<void*, Tensor> tensors_;
};

XlaAllocator::XlaAllocator(const gpu::Platform* platform,
                           OpKernelContext* op_context)
    : xla::DeviceMemoryAllocator(platform), op_context_(op_context) {}

XlaAllocator::~XlaAllocator() = default;

xla::StatusOr<gpu::DeviceMemoryBase> XlaAllocator::Allocate(
    int device_ordinal, uint64 size, bool retry_on_failure) {
  AllocatorAttributes allocator_attrs;
  allocator_attrs.set_on_host(false);

  AllocationAttributes allocation_attrs;
  allocation_attrs.no_retry_on_failure = !retry_on_failure;

  // Back the XLA allocation with a DT_UINT8 temporary Tensor, so the buffer
  // participates in TensorFlow's normal memory accounting.
  Tensor t;
  Status status = op_context_->allocate_temp(
      DT_UINT8, TensorShape({static_cast<int64>(size)}), &t, allocator_attrs,
      allocation_attrs);
  if (!status.ok()) {
    VLOG(2) << "Allocation failed " << size;
    return status;
  }
  void* data =
      reinterpret_cast<void*>(const_cast<char*>(t.tensor_data().data()));
  tensors_[data] = t;
  return gpu::DeviceMemoryBase(data, size);
}

Status XlaAllocator::RegisterArgument(const Tensor* t) {
  void* data =
      reinterpret_cast<void*>(const_cast<char*>(t->tensor_data().data()));
  tensors_[data] = *t;
  return Status::OK();
}

Status XlaAllocator::Deallocate(int device_ordinal,
                                gpu::DeviceMemoryBase* mem) {
  if (mem->opaque() != nullptr) {
    if (tensors_.erase(mem->opaque()) == 0) {
      return tensorflow::errors::InvalidArgument("Unknown tensor address");
    }
  }
  return Status::OK();
}

Status XlaAllocator::MakeTensorFromBuffer(gpu::DeviceMemoryBase buffer,
                                          DataType dtype,
                                          const TensorShape& shape,
                                          Tensor* out_tensor) const {
  void* ptr = const_cast<void*>(buffer.opaque());
  auto it = tensors_.find(ptr);
  if (it == tensors_.end()) {
    return errors::InvalidArgument("Unknown tensor address");
  }
  const Tensor& tensor = it->second;

  int64 output_size = DataTypeSize(dtype) * shape.num_elements();
  if (tensor.TotalBytes() == output_size) {
    out_tensor->UnsafeCopyFromInternal(tensor, dtype, shape);
  } else {
    Tensor slice = tensor.Slice(0, output_size);
    out_tensor->UnsafeCopyFromInternal(slice, dtype, shape);
  }
  return Status::OK();
}

XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
    : OpKernel(ctx), device_type_(ctx->device_type()) {
  const NameAttrList* func;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("function", &func));
  function_ = *func;
  DataTypeVector constant_types;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("Tconstants", &constant_types));
  num_constant_args_ = constant_types.size();
  OP_REQUIRES_OK(ctx, ctx->GetAttr("Nresources", &num_resource_args_));
  if (device_type_ == DeviceType(DEVICE_CPU)) {
    platform_id_ = gpu::host::kHostPlatformId;
  } else if (device_type_ == DeviceType(DEVICE_GPU)) {
    platform_id_ = gpu::cuda::kCudaPlatformId;
  } else {
    platform_id_ = nullptr;
  }
}

Status XlaLocalLaunchOp::BuildCompilationCache(OpKernelContext* ctx,
                                               XlaCompilationCache** cache) {
  const XlaDevice::Metadata* metadata;
  Status s = XlaDevice::GetMetadata(ctx, &metadata);
  if (s.ok()) {
    *cache = new XlaCompilationCache(metadata->client(),
                                     metadata->jit_device_type());
    return Status::OK();
  }

  auto platform = gpu::MultiPlatformManager::PlatformWithId(platform_id_);
  if (!platform.ok()) {
    return StreamExecutorUtil::ConvertStatus(platform.status());
  }
  xla::LocalClientOptions client_options;
  client_options.set_platform(platform.ValueOrDie());
  client_options.set_intra_op_parallelism_threads(
      ctx->device()->tensorflow_cpu_worker_threads()->num_threads);
  auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options);
  if (!client.ok()) {
    return client.status();
  }
  const XlaOpRegistry::DeviceRegistration* registration;
  if (!XlaOpRegistry::GetCompilationDevice(device_type_.type(),
                                           &registration)) {
    return errors::InvalidArgument("No JIT device registered for ",
                                   device_type_.type());
  }
  *cache = new XlaCompilationCache(
      client.ValueOrDie(), DeviceType(registration->compilation_device_name));
  return Status::OK();
}
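
// Takes a snapshot of the values of the resource-variable arguments, which
// occupy the last `num_variables` inputs of `ctx`. Snapshotting under each
// variable's shared lock gives the launch a consistent view of the values
// even if another op writes a variable while the computation runs.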
std::vector<OptionalTensor> SnapshotResourceVariables(OpKernelContext* ctx,
                                                      int num_variables) {
  std::vector<OptionalTensor> snapshot(num_variables);
  int first_variable = ctx->num_inputs() - num_variables;
  for (int i = 0; i < num_variables; ++i) {
    Var* variable = nullptr;
    ResourceHandle handle = HandleFromInput(ctx, first_variable + i);
    if (LookupResource(ctx, handle, &variable).ok()) {
      // LookupResource() took a reference on `variable`; release it when done.
      core::ScopedUnref unref(variable);
      tf_shared_lock lock(*variable->mu());
      snapshot[i].name = handle.name();
      snapshot[i].present = true;
      snapshot[i].value = *variable->tensor();
    }
  }
  return snapshot;
}
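
// Compute() is the launch path: look up (or build) the per-device
// compilation cache, compile `function_` for the current argument shapes if
// it has not been compiled already, wrap the input Tensors as XLA buffers,
// run the resulting executable, and publish outputs and variable writes.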
void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "XlaLocalLaunchOp::Compute "
          << Canonicalize(function_.name(), AttrSlice(&function_.attr()));
  // We store information about the JIT-compiled XLA computation
  // in the ResourceMgr.
  ResourceMgr* rm = ctx->resource_manager();
  OP_REQUIRES(ctx, rm, errors::Internal("No resource manager."));

  gpu::Stream* stream =
      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;

  XlaCompilationCache* cache;
  OP_REQUIRES_OK(ctx, rm->LookupOrCreate<XlaCompilationCache>(
                          rm->default_container(), "xla_cache", &cache,
                          [this, ctx](XlaCompilationCache** cache) {
                            return BuildCompilationCache(ctx, cache);
                          }));
  // Hold the reference to the JIT during evaluation. (We could probably
  // free it sooner because the ResourceMgr will retain a reference, but
  // this is more obviously correct.)
  core::ScopedUnref cache_ref(cache);

  // Get the platform_id_ for XLA_* devices.
  if (platform_id_ == nullptr) {
    const XlaDevice::Metadata* metadata;
    Status s = XlaDevice::GetMetadata(ctx, &metadata);
    if (s.ok()) {
      platform_id_ = metadata->platform()->id();
    }
  }

  std::vector<OptionalTensor> variables =
      SnapshotResourceVariables(ctx, num_resource_args_);

  xla::LocalClient* client = static_cast<xla::LocalClient*>(cache->client());

  // Builds an XLA allocator for the device.
  XlaAllocator xla_allocator(client->platform(), ctx);

  XlaCompiler::Options options;
  options.client = client;
  options.device_type = &cache->device_type();
  options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
  options.graph_def_version = ctx->function_library()->graph_def_version();
  options.allow_cpu_custom_calls = (platform_id_ == gpu::host::kHostPlatformId);
  options.device_allocator = &xla_allocator;

  const XlaCompiler::CompilationResult* kernel;
  xla::LocalExecutable* executable;

  OP_REQUIRES_OK(ctx, cache->Compile(options, function_, num_constant_args_,
                                     variables, ctx, &kernel, &executable,
                                     /*compile_options=*/nullptr));

  VLOG(1) << "Executing XLA Computation...";

  std::unique_ptr<xla::ShapedBuffer> output;
  // Build xla::ShapedBuffers that point directly to the Tensor buffers.
  std::vector<std::unique_ptr<xla::ShapedBuffer>> arg_buffers;
  arg_buffers.reserve(kernel->xla_input_shapes.size() + 1);
  arg_buffers.resize(kernel->xla_input_shapes.size());
  std::vector<xla::ShapedBuffer*> arg_ptrs(arg_buffers.size());

  const int first_variable_arg = ctx->num_inputs() - num_resource_args_;
  // Pass the remaining parameters (non-constant inputs and variable
  // snapshots).
  const Tensor* t;
  for (int i = 0; i < kernel->xla_input_shapes.size(); ++i) {
    int arg_num = kernel->input_mapping[i];
    const xla::Shape& shape = kernel->xla_input_shapes[i];
    if (arg_num >= first_variable_arg) {
      t = &(variables[arg_num - first_variable_arg].value);
    } else {
      t = &(ctx->input(arg_num));
    }

    gpu::DeviceMemoryBase dmem = gpu::DeviceMemoryBase(
        const_cast<char*>(t->tensor_data().data()), t->tensor_data().size());

    const xla::Shape on_device_shape =
        client->backend().transfer_manager()->HostShapeToDeviceShape(shape);
    CHECK(xla::ShapeUtil::Equal(shape, on_device_shape))
        << "On-device shape "
        << xla::ShapeUtil::HumanStringWithLayout(on_device_shape)
        << " not the same as on-host shape "
        << xla::ShapeUtil::HumanStringWithLayout(shape);
    arg_buffers[i] = xla::MakeUnique<xla::ShapedBuffer>(
        /*on_host_shape=*/shape, /*on_device_shape=*/shape, client->platform(),
        client->default_device_ordinal());
    arg_buffers[i]->set_buffer(dmem, /*index=*/{});
    arg_ptrs[i] = arg_buffers[i].get();

    OP_REQUIRES_OK(ctx, xla_allocator.RegisterArgument(t));
  }
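
  // Each argument was registered with xla_allocator above, so even if the
  // compiled computation aliases an input buffer into its result,
  // MakeTensorFromBuffer below can map that buffer back to the owning Tensor.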

  // Execute the computation.
  VLOG(2) << "Executing computation.";
  xla::ExecutableRunOptions run_options;
  run_options.set_stream(stream);
  run_options.set_allocator(&xla_allocator);
  run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
  Env* env = Env::Default();
  auto start_time = env->NowMicros();
  auto run_result = executable->Run(arg_ptrs, run_options);
  OP_REQUIRES(ctx, run_result.ok(), run_result.status());

  output = run_result.ConsumeValueOrDie()->release();
  auto elapsed = env->NowMicros() - start_time;
  VLOG(2) << "Elapsed time: " << elapsed << "us";

  // Computation output should always be a tuple.
  if (VLOG_IS_ON(2)) {
    VLOG(2) << "Result tuple shape: " << output->on_host_shape().DebugString();
  }
  CHECK_EQ(ctx->num_outputs(), kernel->outputs.size());

  // Copy XLA results to the OpOutputList.
  int output_num = 0;
  for (int i = 0; i < ctx->num_outputs(); ++i) {
    if (kernel->outputs[i].is_constant) {
      // Output is a constant.
      const Tensor& const_tensor = kernel->outputs[i].constant_value;
      const size_t total_bytes = const_tensor.TotalBytes();
      if (stream && total_bytes > 0) {
        // Copy host -> device. (Empty tensors don't have backing buffers.)
        VLOG(1) << "Constant output tensor on device";
        Tensor* output_tensor;
        TF_CHECK_OK(
            ctx->allocate_output(i, const_tensor.shape(), &output_tensor));

        const void* src_ptr = DMAHelper::base(&const_tensor);
        void* dst_ptr = DMAHelper::base(output_tensor);
        gpu::DeviceMemoryBase gpu_dst_ptr(dst_ptr, total_bytes);
        stream->ThenMemcpy(&gpu_dst_ptr, src_ptr, total_bytes);
      } else {
        // No copy required.
        ctx->set_output(i, const_tensor);
      }
    } else {
      const TensorShape& shape = kernel->outputs[i].shape;
      VLOG(2) << "Retval " << i << " shape " << shape.DebugString();

      gpu::DeviceMemoryBase buffer = output->buffer({output_num});
      Tensor output_tensor;
      // Looks up the owning Tensor by buffer address.
      OP_REQUIRES_OK(ctx, xla_allocator.MakeTensorFromBuffer(
                              buffer, ctx->expected_output_dtype(i), shape,
                              &output_tensor));
      ctx->set_output(i, output_tensor);
      ++output_num;
    }

    if (VLOG_IS_ON(3)) {
      VLOG(3) << ctx->mutable_output(i)->DebugString();
    }
  }

  // Apply variable updates, if any.
  VLOG(2) << "Applying variable updates";
  for (int i = 0; i < kernel->resource_updates.size(); ++i) {
    const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i];
    OP_REQUIRES(ctx,
                write.input_index >= 0 && write.input_index < ctx->num_inputs(),
                errors::Internal("Invalid input index for variable write."));

    gpu::DeviceMemoryBase buffer = output->buffer({output_num});

    Var* variable = nullptr;
    // TODO(b/35625933): tensorflow::Var should contain a PersistentTensor,
    // not a Tensor.
    OP_REQUIRES_OK(ctx, LookupOrCreateResource<Var>(
                            ctx, HandleFromInput(ctx, write.input_index),
                            &variable, [this, ctx, &write](Var** ptr) {
                              *ptr = new Var(write.type);
                              return Status::OK();
                            }));

    core::ScopedUnref s(variable);

    mutex_lock ml(*variable->mu());
    OP_REQUIRES(ctx, variable->tensor()->dtype() == write.type,
                errors::Internal("Mismatched type in variable write"));

    // Looks up the owning Tensor by buffer address.
    OP_REQUIRES_OK(ctx,
                   xla_allocator.MakeTensorFromBuffer(
                       buffer, write.type, write.shape, variable->tensor()));
    ++output_num;
  }

  VLOG(1) << "Done";
}

XlaLocalLaunchOp::~XlaLocalLaunchOp() {
  VLOG(1) << "XlaLocalLaunchOp destroyed";
}

REGISTER_KERNEL_BUILDER(Name("_XlaLaunch").Device(DEVICE_CPU),
                        XlaLocalLaunchOp);

REGISTER_KERNEL_BUILDER(Name("_XlaLaunch")
                            .Device(DEVICE_GPU)
                            .HostMemory("constants")
                            .HostMemory("resources"),
                        XlaLocalLaunchOp);

}  // namespace tensorflow