1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/jit/kernels/xla_ops.h" 17 18 #include "absl/container/flat_hash_map.h" 19 #include "absl/memory/memory.h" 20 #include "tensorflow/compiler/jit/defs.h" 21 #include "tensorflow/compiler/jit/flags.h" 22 #include "tensorflow/compiler/tf2xla/shape_util.h" 23 #include "tensorflow/compiler/tf2xla/tf2xla_util.h" 24 #include "tensorflow/compiler/tf2xla/xla_compiler.h" 25 #include "tensorflow/compiler/tf2xla/xla_op_registry.h" 26 #include "tensorflow/compiler/xla/client/client_library.h" 27 #include "tensorflow/compiler/xla/client/local_client.h" 28 #include "tensorflow/compiler/xla/service/compiler.h" 29 #include "tensorflow/compiler/xla/status_macros.h" 30 #include "tensorflow/compiler/xla/statusor.h" 31 #include "tensorflow/core/common_runtime/dma_helper.h" 32 #include "tensorflow/core/common_runtime/function.h" 33 #include "tensorflow/core/framework/allocator.h" 34 #include "tensorflow/core/framework/node_def_util.h" 35 #include "tensorflow/core/framework/op.h" 36 #include "tensorflow/core/framework/op_kernel.h" 37 #include "tensorflow/core/framework/tensor.h" 38 #include "tensorflow/core/framework/types.h" 39 #include "tensorflow/core/lib/core/errors.h" 40 #include "tensorflow/core/lib/core/status.h" 41 #include "tensorflow/core/platform/env.h" 42 
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/util/stream_executor_util.h"

// OP_REQUIRES_OK_RETURN behaves like OP_REQUIRES_OK, except that on failure
// it returns RET from the enclosing function instead of returning void.
#define OP_REQUIRES_OK_RETURN(CTX, RET, ...)                \
  do {                                                      \
    ::tensorflow::Status _s(__VA_ARGS__);                   \
    if (!TF_PREDICT_TRUE(_s.ok())) {                        \
      (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
      return RET;                                           \
    }                                                       \
  } while (0)

namespace tensorflow {

namespace {

// Gathers the platform description (device type, platform id, optional
// XlaDevice metadata and the allocator to use) for the kernel being
// constructed through `ctx`.
//
// If the platform lookup fails, the failure is recorded on `ctx` and a
// default-constructed XlaPlatformInfo is returned.
XlaPlatformInfo PlatformInfoFromContext(OpKernelConstruction* ctx) {
  const DeviceType kind = ctx->device_type();

  se::Platform::Id id = nullptr;
  const XlaDevice::Metadata* metadata = nullptr;
  xla::DeviceMemoryAllocator* backend_allocator = nullptr;

  if (kind == DeviceType(DEVICE_CPU)) {
    id = se::host::kHostPlatformId;
  } else if (kind == DeviceType(DEVICE_GPU)) {
    auto* gpu_stream = ctx->device()->tensorflow_gpu_device_info()->stream;
    id = gpu_stream->parent()->platform()->id();
  } else if (XlaDevice::GetMetadata(ctx, &metadata).ok()) {
    // On an XlaDevice we go straight to the allocator of the underlying XLA
    // platform. The StreamExecutor allocator might be the more principled
    // choice, but XLA reports out-of-memory through a readable Status while
    // StreamExecutor does not.
    //
    // ctx->device()->GetAllocator() (the allocator wrapped by the
    // XlaAllocator created below) must not be used here: on an XlaDevice it
    // is a dummy allocator that hands out XlaTensor objects, whereas the
    // XlaCompiler needs a real allocator producing real buffers.
    id = metadata->platform()->id();
    backend_allocator = metadata->client()->backend().memory_allocator();
  }

  // No backend allocator was selected above: wrap the device's TensorFlow
  // allocator in an XlaAllocator owned by the returned XlaPlatformInfo.
  std::unique_ptr<XlaAllocator> owned_allocator;
  if (backend_allocator == nullptr) {
    xla::StatusOr<se::Platform*> platform_or =
        se::MultiPlatformManager::PlatformWithId(id);
    OP_REQUIRES_OK_RETURN(ctx, XlaPlatformInfo(), platform_or.status());

    owned_allocator = absl::make_unique<XlaAllocator>(
        platform_or.ValueOrDie(), ctx->device()->GetAllocator({}));
  }

  return XlaPlatformInfo(kind, id, metadata, std::move(owned_allocator),
                         backend_allocator);
}

// A closure describing how to run a compiled version of a TensorFlow function.
//
// It may seem unusual to stick the resource variable snapshots in this class.
// This is necessary: we need to use the snapshots observed by the compiler as
// the initial values for the resource variables (and cannot snapshot them again
// during execution) because otherwise we risk observing a different snapshot
// with shapes different from what we compiled for.
111 class XlaExecutableClosure { 112 public: 113 explicit XlaExecutableClosure( 114 xla::LocalClient* client, xla::LocalExecutable* executable, 115 const XlaCompiler::CompilationResult* compilation_result, 116 std::map<int, OptionalTensor> resource_var_snapshots, 117 int num_constant_args) 118 : client_(client), 119 executable_(executable), 120 compilation_result_(compilation_result), 121 resource_var_snapshots_(std::move(resource_var_snapshots)), 122 num_constant_args_(num_constant_args) {} 123 124 XlaExecutableClosure(XlaExecutableClosure&&) = default; 125 XlaExecutableClosure& operator=(XlaExecutableClosure&&) = default; 126 127 xla::LocalClient* client() const { return client_; } 128 xla::LocalExecutable* executable() const { return executable_; } 129 const XlaCompiler::CompilationResult* compilation_result() const { 130 return compilation_result_; 131 } 132 const std::map<int, OptionalTensor>& resource_var_snapshots() const { 133 return resource_var_snapshots_; 134 } 135 int num_constant_args() const { return num_constant_args_; } 136 137 private: 138 xla::LocalClient* client_; 139 xla::LocalExecutable* executable_; 140 const XlaCompiler::CompilationResult* compilation_result_; 141 std::map<int, OptionalTensor> resource_var_snapshots_; 142 int num_constant_args_; 143 144 TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosure); 145 }; 146 147 // This maintains a mapping from a globally unique ID to XlaExecutableClosure 148 // instances. 
149 class XlaExecutableClosureStore { 150 public: 151 XlaExecutableClosureStore() : key_counter_(0) {} 152 153 using KeyT = string; 154 155 KeyT Produce(XlaExecutableClosure result) { 156 mutex_lock l(mutex_); 157 KeyT key = absl::StrCat(key_counter_++); 158 bool insert_successful = closures_.emplace(key, std::move(result)).second; 159 DCHECK(insert_successful); 160 (void)insert_successful; 161 return key; 162 } 163 164 XlaExecutableClosure Consume(const KeyT& key) { 165 mutex_lock l(mutex_); 166 auto it = closures_.find(key); 167 DCHECK(it != closures_.end()); 168 XlaExecutableClosure value = std::move(it->second); 169 closures_.erase(it); 170 return value; 171 } 172 173 static XlaExecutableClosureStore* Global() { 174 static XlaExecutableClosureStore* instance = new XlaExecutableClosureStore; 175 return instance; 176 } 177 178 private: 179 mutex mutex_; 180 int64 key_counter_ GUARDED_BY(mutex_); 181 absl::flat_hash_map<KeyT, XlaExecutableClosure> closures_ GUARDED_BY(mutex_); 182 183 TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosureStore); 184 }; 185 186 } // namespace 187 188 XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx, 189 const std::vector<int>& constants, 190 const std::vector<int>& resources, 191 const NameAttrList& function) 192 : OpKernel(ctx), 193 constants_(constants), 194 resources_(resources), 195 function_(function), 196 platform_info_(PlatformInfoFromContext(ctx)) {} 197 198 static Status BuildCompilationCache(OpKernelContext* ctx, 199 const XlaPlatformInfo& platform_info, 200 XlaCompilationCache** cache) { 201 if (platform_info.xla_device_metadata()) { 202 *cache = new XlaCompilationCache( 203 platform_info.xla_device_metadata()->client(), 204 platform_info.xla_device_metadata()->jit_device_type()); 205 return Status::OK(); 206 } 207 208 auto platform = 209 se::MultiPlatformManager::PlatformWithId(platform_info.platform_id()); 210 if (!platform.ok()) { 211 return platform.status(); 212 } 213 214 xla::StatusOr<xla::Compiler*> 
compiler_for_platform = 215 xla::Compiler::GetForPlatform(platform.ValueOrDie()); 216 if (!compiler_for_platform.ok()) { 217 // In some rare cases (usually in unit tests with very small clusters) we 218 // may end up transforming an XLA cluster with at least one GPU operation 219 // (which would normally force the cluster to be compiled using XLA:GPU) 220 // into an XLA cluster with no GPU operations (i.e. containing only CPU 221 // operations). Such a cluster can fail compilation (in way that 222 // MarkForCompilation could not have detected) if the CPU JIT is not linked 223 // in. 224 // 225 // So bail out of _XlaCompile in this case, and let the executor handle the 226 // situation for us. 227 const Status& status = compiler_for_platform.status(); 228 if (status.code() == error::NOT_FOUND) { 229 return errors::Unimplemented("Could not find compiler for platform ", 230 platform.ValueOrDie()->Name(), ": ", 231 status.ToString()); 232 } 233 } 234 235 xla::LocalClientOptions client_options; 236 client_options.set_platform(platform.ValueOrDie()); 237 client_options.set_intra_op_parallelism_threads( 238 ctx->device()->tensorflow_cpu_worker_threads()->num_threads); 239 auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options); 240 if (!client.ok()) { 241 return client.status(); 242 } 243 const XlaOpRegistry::DeviceRegistration* registration; 244 if (!XlaOpRegistry::GetCompilationDevice(platform_info.device_type().type(), 245 ®istration)) { 246 return errors::InvalidArgument("No JIT device registered for ", 247 platform_info.device_type().type()); 248 } 249 *cache = new XlaCompilationCache( 250 client.ValueOrDie(), DeviceType(registration->compilation_device_name)); 251 return Status::OK(); 252 } 253 254 static Status CompileToLocalExecutable( 255 OpKernelContext* ctx, const NameAttrList& function, 256 const XlaPlatformInfo& platform_info, absl::Span<const int> resources, 257 absl::Span<const int> constants, bool lazy, xla::LocalClient** client, 258 
std::map<int, OptionalTensor>* variables, 259 const XlaCompiler::CompilationResult** kernel, 260 xla::LocalExecutable** executable) { 261 // We store information about the JIT-compiled XLA computation 262 // in the ResourceMgr. 263 ResourceMgr* rm = ctx->resource_manager(); 264 if (!rm) { 265 return errors::Internal("No resource manager."); 266 } 267 268 XlaCompilationCache* cache; 269 TF_RETURN_IF_ERROR(rm->LookupOrCreate<XlaCompilationCache>( 270 rm->default_container(), "xla_cache", &cache, 271 [&](XlaCompilationCache** cache) { 272 return BuildCompilationCache(ctx, platform_info, cache); 273 })); 274 // Hold the reference to the JIT during evaluation. (We could probably 275 // free it sooner because the ResourceMgr will retain a reference, but 276 // this is more obviously correct.) 277 core::ScopedUnref cache_ref(cache); 278 279 TF_RETURN_IF_ERROR(SnapshotResourceVariables(ctx, resources, variables)); 280 *client = static_cast<xla::LocalClient*>(cache->client()); 281 282 XlaCompiler::Options options; 283 options.client = *client; 284 if (ctx->op_device_context() != nullptr) { 285 options.device_ordinal = 286 ctx->op_device_context()->stream()->parent()->device_ordinal(); 287 } 288 options.device_type = cache->device_type(); 289 options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); 290 options.graph_def_version = ctx->function_library()->graph_def_version(); 291 options.allow_cpu_custom_calls = 292 (platform_info.platform_id() == se::host::kHostPlatformId); 293 options.device_allocator = platform_info.allocator(); 294 if (platform_info.xla_device_metadata()) { 295 options.shape_representation_fn = 296 platform_info.xla_device_metadata()->shape_representation_fn(); 297 } 298 299 std::map<int, Tensor> constant_args; 300 for (int i : constants) { 301 constant_args.insert({i, ctx->input(i)}); 302 } 303 XlaCompiler::CompileOptions compile_options; 304 compile_options.is_entry_computation = true; 305 // If we resolve constants we never emit them 
on the device, meaning that if 306 // they are needed by a following computation the host has to transfer 307 // them. Not resolving constants is expected to be faster than resolving 308 // constants. 309 compile_options.resolve_compile_time_constants = true; 310 // Optimization: where possible, have the computation return a naked array 311 // rather than a one-element tuple. 312 compile_options.always_return_tuple = false; 313 314 std::vector<XlaCompiler::Argument> args; 315 TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments( 316 constant_args, *variables, ctx, &args)); 317 return cache->Compile(options, function, args, compile_options, 318 lazy ? XlaCompilationCache::CompileMode::kLazy 319 : XlaCompilationCache::CompileMode::kStrict, 320 kernel, executable); 321 } 322 323 void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { 324 VLOG(1) << "XlaLocalLaunchOpBase::Compute " 325 << Canonicalize(function_.name(), AttrSlice(&function_.attr())); 326 327 xla::LocalClient* client; 328 const XlaCompiler::CompilationResult* kernel; 329 xla::LocalExecutable* executable; 330 std::map<int, OptionalTensor> variables; 331 332 { 333 Status s = CompileToLocalExecutable( 334 ctx, function_, platform_info_, resources_, constants_, /*lazy=*/false, 335 &client, &variables, &kernel, &executable); 336 if (!s.ok() && (platform_info_.device_type().type_string() == DEVICE_CPU || 337 platform_info_.device_type().type_string() == DEVICE_GPU)) { 338 // Suggest auto jit if the failure was with GPU or CPU. 339 errors::AppendToMessage(&s, 340 xla::status_macros::kPossibleAutoJitAlternative); 341 } 342 343 OP_REQUIRES_OK(ctx, s); 344 } 345 346 se::Stream* stream = 347 ctx->op_device_context() ? 
ctx->op_device_context()->stream() : nullptr; 348 349 VLOG(1) << "Executing XLA Computation..."; 350 351 XlaComputationLaunchContext launch_context( 352 client, platform_info_.allocator(), 353 /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(), 354 platform_info_.UseMultipleStreams()); 355 launch_context.PopulateInputs(ctx, kernel, variables, 356 /*missing_ctx_input_prefix=*/0); 357 358 // Execute the computation. 359 VLOG(2) << "Executing computation."; 360 xla::ExecutableRunOptions run_options; 361 run_options.set_stream(stream); 362 run_options.set_allocator(platform_info_.allocator()); 363 run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); 364 run_options.set_rng_seed(GetXLARandomSeed()); 365 Env* env = Env::Default(); 366 auto start_time = env->NowMicros(); 367 368 auto run_result = executable->Run(launch_context.arguments(), run_options); 369 OP_REQUIRES(ctx, run_result.ok(), run_result.status()); 370 371 auto elapsed = env->NowMicros() - start_time; 372 VLOG(2) << "Elapsed time: " << elapsed << "us"; 373 374 OP_REQUIRES_OK(ctx, launch_context.PopulateOutputs( 375 ctx, kernel, run_result.ConsumeValueOrDie(), 376 /*missing_ctx_input_prefix=*/0)); 377 VLOG(1) << "Done"; 378 } 379 380 namespace { 381 // Helper static functions to construct parameters for 382 // XlaLocalLaunchBase constructor from OpKernelConstruction. 
383 std::vector<int> ConstantsVector(OpKernelConstruction* ctx) { 384 DataTypeVector constant_types; 385 OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(), 386 ctx->GetAttr("Tconstants", &constant_types)); 387 std::vector<int> constants(constant_types.size()); 388 std::iota(constants.begin(), constants.end(), 0); 389 return constants; 390 } 391 392 std::vector<int> ResourcesVector(OpKernelConstruction* ctx) { 393 DataTypeVector constant_types; 394 OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(), 395 ctx->GetAttr("Tconstants", &constant_types)); 396 397 DataTypeVector arg_types; 398 OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(), 399 ctx->GetAttr("Targs", &arg_types)); 400 401 int num_resources; 402 OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(), 403 ctx->GetAttr("Nresources", &num_resources)); 404 405 std::vector<int> resources(num_resources); 406 std::iota(resources.begin(), resources.end(), 407 constant_types.size() + arg_types.size()); 408 return resources; 409 } 410 411 NameAttrList FunctionAttr(OpKernelConstruction* ctx) { 412 const NameAttrList* func; 413 OP_REQUIRES_OK_RETURN(ctx, NameAttrList(), ctx->GetAttr("function", &func)); 414 return *func; 415 } 416 417 bool MustCompileAttr(OpKernelConstruction* ctx) { 418 bool must_compile; 419 OP_REQUIRES_OK_RETURN(ctx, false, 420 ctx->GetAttr("must_compile", &must_compile)); 421 return must_compile; 422 } 423 } // namespace 424 425 XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx) 426 : XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx), 427 FunctionAttr(ctx)) {} 428 429 XlaLocalLaunchOp::~XlaLocalLaunchOp() { 430 VLOG(1) << "XlaLocalLaunchOp destroyed"; 431 } 432 433 XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx) 434 : OpKernel(ctx), 435 constants_(ConstantsVector(ctx)), 436 resources_(ResourcesVector(ctx)), 437 function_(FunctionAttr(ctx)), 438 platform_info_(PlatformInfoFromContext(ctx)), 439 must_compile_(MustCompileAttr(ctx)) {} 440 441 void XlaCompileOp::Compute(OpKernelContext* 
ctx) { 442 VLOG(3) << "XlaCompileOp " << def().name() 443 << (must_compile_ ? "(must-compile)" : ""); 444 xla::LocalClient* client; 445 const XlaCompiler::CompilationResult* kernel; 446 xla::LocalExecutable* executable; 447 std::map<int, OptionalTensor> variables; 448 449 bool cannot_compile_cluster; 450 { 451 mutex_lock guard(cannot_compile_cluster_mu_); 452 cannot_compile_cluster = cannot_compile_cluster_; 453 } 454 455 if (GetXlaOpsCommonFlags().tf_xla_always_defer_compilation || 456 cannot_compile_cluster) { 457 executable = nullptr; 458 } else { 459 Status status = CompileToLocalExecutable( 460 ctx, function_, platform_info_, resources_, constants_, 461 /*lazy=*/!must_compile_, &client, &variables, &kernel, &executable); 462 if (must_compile_ || status.code() != error::UNIMPLEMENTED) { 463 OP_REQUIRES_OK(ctx, status); 464 } 465 466 if (status.code() == error::UNIMPLEMENTED) { 467 LOG(WARNING) << "Compilation failed:" << status.ToString() 468 << ". Falling back to TF function call."; 469 executable = nullptr; 470 mutex_lock guard(cannot_compile_cluster_mu_); 471 cannot_compile_cluster_ = true; 472 } 473 } 474 475 AllocatorAttributes host_alloc_attrs; 476 host_alloc_attrs.set_gpu_compatible(true); 477 host_alloc_attrs.set_on_host(true); 478 Allocator* cpu_allocator = ctx->device()->GetAllocator(host_alloc_attrs); 479 480 if (!executable) { 481 DCHECK(!must_compile_); 482 Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({})); 483 484 Tensor compilation_successful(cpu_allocator, DT_BOOL, TensorShape({})); 485 compilation_successful.scalar<bool>()() = false; 486 ctx->set_output(0, Tensor(cpu_allocator, DT_STRING, TensorShape({}))); 487 ctx->set_output(1, compilation_successful); 488 return; 489 } 490 491 // Each execution of an XlaCompile op creates a new XlaExecutableClosure, even 492 // if it didn't have to compile the cluster because of a compilation-cache 493 // hit. This is because we at least need new snapshots of the resource 494 // variables. 
495 XlaExecutableClosureStore::KeyT key = 496 XlaExecutableClosureStore::Global()->Produce(XlaExecutableClosure( 497 client, executable, kernel, std::move(variables), constants_.size())); 498 499 Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({})); 500 compilation_key.flat<string>()(0) = key; 501 502 Tensor compilation_successful(cpu_allocator, DT_BOOL, TensorShape({})); 503 compilation_successful.flat<bool>()(0) = true; 504 505 ctx->set_output(0, compilation_key); 506 ctx->set_output(1, compilation_successful); 507 } 508 509 XlaRunOp::XlaRunOp(OpKernelConstruction* ctx) 510 : OpKernel(ctx), platform_info_(PlatformInfoFromContext(ctx)) {} 511 512 void XlaRunOp::Compute(OpKernelContext* ctx) { 513 VLOG(3) << "XlaRunOp " << def().name(); 514 Tensor key_tensor = ctx->input(ctx->num_inputs() - 1); 515 const XlaExecutableClosureStore::KeyT& key = key_tensor.flat<string>()(0); 516 517 XlaExecutableClosure closure = 518 XlaExecutableClosureStore::Global()->Consume(key); 519 520 XlaComputationLaunchContext launch_context( 521 closure.client(), platform_info_.allocator(), 522 /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(), 523 /*use_multiple_streams=*/platform_info_.UseMultipleStreams()); 524 525 // We're missing the must-be-constant inputs, tell `PopulateInputs` 526 // about this. We don't actually need these inputs because they've 527 // already been baked into the compiled kernel. 528 launch_context.PopulateInputs( 529 ctx, closure.compilation_result(), closure.resource_var_snapshots(), 530 /*missing_ctx_input_prefix=*/closure.num_constant_args()); 531 532 se::Stream* stream = 533 ctx->op_device_context() ? 
ctx->op_device_context()->stream() : nullptr; 534 xla::ExecutableRunOptions run_options; 535 run_options.set_stream(stream); 536 run_options.set_allocator(platform_info_.allocator()); 537 run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); 538 run_options.set_rng_seed(GetXLARandomSeed()); 539 Env* env = Env::Default(); 540 auto start_time = env->NowMicros(); 541 542 auto run_result = 543 closure.executable()->Run(launch_context.arguments(), run_options); 544 OP_REQUIRES(ctx, run_result.ok(), run_result.status()); 545 546 auto elapsed = env->NowMicros() - start_time; 547 VLOG(2) << "Elapsed time in computation: " << elapsed << "us"; 548 549 OP_REQUIRES_OK( 550 ctx, 551 launch_context.PopulateOutputs( 552 ctx, closure.compilation_result(), run_result.ConsumeValueOrDie(), 553 /*missing_ctx_input_prefix=*/closure.num_constant_args())); 554 } 555 556 REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp); 557 558 REGISTER_KERNEL_BUILDER(Name("XlaLaunch") 559 .Device(DEVICE_GPU) 560 .HostMemory("constants") 561 .HostMemory("resources"), 562 XlaLocalLaunchOp); 563 564 REGISTER_KERNEL_BUILDER(Name("_XlaCompile").Device(DEVICE_CPU), XlaCompileOp); 565 REGISTER_KERNEL_BUILDER(Name("_XlaCompile") 566 .Device(DEVICE_GPU) 567 .HostMemory("constants") 568 .HostMemory("key") 569 .HostMemory("compilation_successful") 570 .HostMemory("resources"), 571 XlaCompileOp); 572 573 REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_CPU), XlaRunOp); 574 REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_GPU), XlaRunOp); 575 576 } // namespace tensorflow 577