/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/kernels/xla_ops.h"

#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/util/stream_executor_util.h"

// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that
// in the error case, it returns RET instead of void.
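// This lets helpers with non-void return types (such as the attribute-reading
// helpers later in this file) report errors on the construction context and
// still return a value.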
#define OP_REQUIRES_OK_RETURN(CTX, RET, ...)                \
  do {                                                      \
    ::tensorflow::Status _s(__VA_ARGS__);                   \
    if (!TF_PREDICT_TRUE(_s.ok())) {                        \
      (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
      return RET;                                           \
    }                                                       \
  } while (0)

namespace tensorflow {

namespace {

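// Collects the platform-related state (platform id, XlaDevice metadata, and
// allocator) for the device this kernel is being constructed on. For CPU and
// GPU devices an XlaAllocator wrapping the device's own allocator is created;
// on an XlaDevice the XLA client's memory allocator is used instead (see the
// comment inside).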
XlaPlatformInfo PlatformInfoFromContext(OpKernelConstruction* ctx) {
  DeviceType device_type = ctx->device_type();
  se::Platform::Id platform_id = nullptr;
  const XlaDevice::Metadata* xla_device_metadata = nullptr;
  std::unique_ptr<XlaAllocator> xla_allocator;
  xla::DeviceMemoryAllocator* device_allocator = nullptr;

  if (ctx->device_type() == DeviceType(DEVICE_CPU)) {
    platform_id = se::host::kHostPlatformId;
  } else if (ctx->device_type() == DeviceType(DEVICE_GPU)) {
    platform_id = ctx->device()
                      ->tensorflow_gpu_device_info()
                      ->stream->parent()
                      ->platform()
                      ->id();
  } else if (XlaDevice::GetMetadata(ctx, &xla_device_metadata).ok()) {
    // If we are on an XlaDevice, use the underlying XLA platform's allocator
    // directly. We could use the StreamExecutor's allocator which may
    // theoretically be more correct, but XLA returns a nice OOM message in a
    // Status and StreamExecutor does not.
    //
    // Importantly we can't use ctx->device()->GetAllocator() as the allocator
    // (which xla_allocator above uses) as on an XlaDevice, this is a dummy
    // allocator that returns XlaTensor objects. The XlaCompiler needs a real
    // allocator to allocate real buffers.

    platform_id = xla_device_metadata->platform()->id();
    device_allocator =
        xla_device_metadata->client()->backend().memory_allocator();
  }

  if (!device_allocator) {
    xla::StatusOr<se::Platform*> maybe_platform =
        se::MultiPlatformManager::PlatformWithId(platform_id);
    OP_REQUIRES_OK_RETURN(ctx, XlaPlatformInfo(), maybe_platform.status());

    xla_allocator = absl::make_unique<XlaAllocator>(
        maybe_platform.ValueOrDie(), ctx->device()->GetAllocator({}));
  }

  return XlaPlatformInfo(device_type, platform_id, xla_device_metadata,
                         std::move(xla_allocator), device_allocator);
}

// A closure describing how to run a compiled version of a TensorFlow function.
//
// It may seem unusual to stick the resource variable snapshots in this class.
// This is necessary: we need to use the snapshots observed by the compiler as
// the initial values for the resource variables (and cannot snapshot them again
// during execution) because otherwise we risk observing a different snapshot
// with shapes different from what we compiled for.
class XlaExecutableClosure {
 public:
  explicit XlaExecutableClosure(
      xla::LocalClient* client, xla::LocalExecutable* executable,
      const XlaCompiler::CompilationResult* compilation_result,
      std::map<int, OptionalTensor> resource_var_snapshots,
      int num_constant_args)
      : client_(client),
        executable_(executable),
        compilation_result_(compilation_result),
        resource_var_snapshots_(std::move(resource_var_snapshots)),
        num_constant_args_(num_constant_args) {}

  XlaExecutableClosure(XlaExecutableClosure&&) = default;
  XlaExecutableClosure& operator=(XlaExecutableClosure&&) = default;

  xla::LocalClient* client() const { return client_; }
  xla::LocalExecutable* executable() const { return executable_; }
  const XlaCompiler::CompilationResult* compilation_result() const {
    return compilation_result_;
  }
  const std::map<int, OptionalTensor>& resource_var_snapshots() const {
    return resource_var_snapshots_;
  }
  int num_constant_args() const { return num_constant_args_; }

 private:
  xla::LocalClient* client_;
  xla::LocalExecutable* executable_;
  const XlaCompiler::CompilationResult* compilation_result_;
  std::map<int, OptionalTensor> resource_var_snapshots_;
  int num_constant_args_;

  TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosure);
};

// This maintains a mapping from a globally unique ID to XlaExecutableClosure
// instances.
class XlaExecutableClosureStore {
 public:
  XlaExecutableClosureStore() : key_counter_(0) {}

  using KeyT = string;

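  // Stores `result` and returns a process-unique key that can later be passed
  // to Consume() to retrieve it.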
  KeyT Produce(XlaExecutableClosure result) {
    mutex_lock l(mutex_);
    KeyT key = absl::StrCat(key_counter_++);
    bool insert_successful = closures_.emplace(key, std::move(result)).second;
    DCHECK(insert_successful);
    (void)insert_successful;
    return key;
  }

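  // Removes the closure stored under `key` and returns it. The key must have
  // been returned by Produce() and not consumed yet.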
  XlaExecutableClosure Consume(const KeyT& key) {
    mutex_lock l(mutex_);
    auto it = closures_.find(key);
    DCHECK(it != closures_.end());
    XlaExecutableClosure value = std::move(it->second);
    closures_.erase(it);
    return value;
  }

  static XlaExecutableClosureStore* Global() {
    static XlaExecutableClosureStore* instance = new XlaExecutableClosureStore;
    return instance;
  }

 private:
  mutex mutex_;
  int64 key_counter_ GUARDED_BY(mutex_);
  absl::flat_hash_map<KeyT, XlaExecutableClosure> closures_ GUARDED_BY(mutex_);

  TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosureStore);
};

}  // namespace

XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
                                       const std::vector<int>& constants,
                                       const std::vector<int>& resources,
                                       const NameAttrList& function)
    : OpKernel(ctx),
      constants_(constants),
      resources_(resources),
      function_(function),
      platform_info_(PlatformInfoFromContext(ctx)) {}

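// Creates an XlaCompilationCache in `*cache` for the device described by
// `platform_info`: on an XlaDevice the device's existing XLA client is reused,
// otherwise a LocalClient for the device's platform is fetched or created.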
static Status BuildCompilationCache(OpKernelContext* ctx,
                                    const XlaPlatformInfo& platform_info,
                                    XlaCompilationCache** cache) {
  if (platform_info.xla_device_metadata()) {
    *cache = new XlaCompilationCache(
        platform_info.xla_device_metadata()->client(),
        platform_info.xla_device_metadata()->jit_device_type());
    return Status::OK();
  }

  auto platform =
      se::MultiPlatformManager::PlatformWithId(platform_info.platform_id());
  if (!platform.ok()) {
    return platform.status();
  }

  xla::StatusOr<xla::Compiler*> compiler_for_platform =
      xla::Compiler::GetForPlatform(platform.ValueOrDie());
  if (!compiler_for_platform.ok()) {
    // In some rare cases (usually in unit tests with very small clusters) we
    // may end up transforming an XLA cluster with at least one GPU operation
    // (which would normally force the cluster to be compiled using XLA:GPU)
    // into an XLA cluster with no GPU operations (i.e. containing only CPU
    // operations).  Such a cluster can fail compilation (in a way that
    // MarkForCompilation could not have detected) if the CPU JIT is not linked
    // in.
    //
    // So bail out of _XlaCompile in this case, and let the executor handle the
    // situation for us.
    const Status& status = compiler_for_platform.status();
    if (status.code() == error::NOT_FOUND) {
      return errors::Unimplemented("Could not find compiler for platform ",
                                   platform.ValueOrDie()->Name(), ": ",
                                   status.ToString());
    }
  }

  xla::LocalClientOptions client_options;
  client_options.set_platform(platform.ValueOrDie());
  client_options.set_intra_op_parallelism_threads(
      ctx->device()->tensorflow_cpu_worker_threads()->num_threads);
  auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options);
  if (!client.ok()) {
    return client.status();
  }
  const XlaOpRegistry::DeviceRegistration* registration;
  if (!XlaOpRegistry::GetCompilationDevice(platform_info.device_type().type(),
                                           &registration)) {
    return errors::InvalidArgument("No JIT device registered for ",
                                   platform_info.device_type().type());
  }
  *cache = new XlaCompilationCache(
      client.ValueOrDie(), DeviceType(registration->compilation_device_name));
  return Status::OK();
}

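// Looks up (or creates) the per-device XlaCompilationCache in the ResourceMgr,
// snapshots the resource variables listed in `resources`, and compiles
// `function` into `*executable`, possibly deferring compilation when `lazy`
// is true.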
static Status CompileToLocalExecutable(
    OpKernelContext* ctx, const NameAttrList& function,
    const XlaPlatformInfo& platform_info, absl::Span<const int> resources,
    absl::Span<const int> constants, bool lazy, xla::LocalClient** client,
    std::map<int, OptionalTensor>* variables,
    const XlaCompiler::CompilationResult** kernel,
    xla::LocalExecutable** executable) {
  // We store information about the JIT-compiled XLA computation
  // in the ResourceMgr.
  ResourceMgr* rm = ctx->resource_manager();
  if (!rm) {
    return errors::Internal("No resource manager.");
  }

  XlaCompilationCache* cache;
  TF_RETURN_IF_ERROR(rm->LookupOrCreate<XlaCompilationCache>(
      rm->default_container(), "xla_cache", &cache,
      [&](XlaCompilationCache** cache) {
        return BuildCompilationCache(ctx, platform_info, cache);
      }));
  // Hold the reference to the JIT during evaluation. (We could probably
  // free it sooner because the ResourceMgr will retain a reference, but
  // this is more obviously correct.)
  core::ScopedUnref cache_ref(cache);

  TF_RETURN_IF_ERROR(SnapshotResourceVariables(ctx, resources, variables));
  *client = static_cast<xla::LocalClient*>(cache->client());

  XlaCompiler::Options options;
  options.client = *client;
  if (ctx->op_device_context() != nullptr) {
    options.device_ordinal =
        ctx->op_device_context()->stream()->parent()->device_ordinal();
  }
  options.device_type = cache->device_type();
  options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
  options.graph_def_version = ctx->function_library()->graph_def_version();
  options.allow_cpu_custom_calls =
      (platform_info.platform_id() == se::host::kHostPlatformId);
  options.device_allocator = platform_info.allocator();
  if (platform_info.xla_device_metadata()) {
    options.shape_representation_fn =
        platform_info.xla_device_metadata()->shape_representation_fn();
  }

  std::map<int, Tensor> constant_args;
  for (int i : constants) {
    constant_args.insert({i, ctx->input(i)});
  }
  XlaCompiler::CompileOptions compile_options;
  compile_options.is_entry_computation = true;
  // If we resolve constants we never emit them on the device, meaning that if
  // they are needed by a following computation the host has to transfer
  // them. Not resolving constants is expected to be faster than resolving
  // constants.
  compile_options.resolve_compile_time_constants = true;
  // Optimization: where possible, have the computation return a naked array
  // rather than a one-element tuple.
  compile_options.always_return_tuple = false;

  std::vector<XlaCompiler::Argument> args;
  TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments(
      constant_args, *variables, ctx, &args));
  return cache->Compile(options, function, args, compile_options,
                        lazy ? XlaCompilationCache::CompileMode::kLazy
                             : XlaCompilationCache::CompileMode::kStrict,
                        kernel, executable);
}

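// Compiles the wrapped function (strictly, i.e. /*lazy=*/false) and runs the
// resulting executable, populating the op's outputs from the XLA result.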
void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
  VLOG(1) << "XlaLocalLaunchOpBase::Compute "
          << Canonicalize(function_.name(), AttrSlice(&function_.attr()));

  xla::LocalClient* client;
  const XlaCompiler::CompilationResult* kernel;
  xla::LocalExecutable* executable;
  std::map<int, OptionalTensor> variables;

  {
    Status s = CompileToLocalExecutable(
        ctx, function_, platform_info_, resources_, constants_, /*lazy=*/false,
        &client, &variables, &kernel, &executable);
    if (!s.ok() && (platform_info_.device_type().type_string() == DEVICE_CPU ||
                    platform_info_.device_type().type_string() == DEVICE_GPU)) {
      // Suggest auto jit if the failure was with GPU or CPU.
      errors::AppendToMessage(&s,
                              xla::status_macros::kPossibleAutoJitAlternative);
    }

    OP_REQUIRES_OK(ctx, s);
  }

  se::Stream* stream =
      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;

  VLOG(1) << "Executing XLA Computation...";

  XlaComputationLaunchContext launch_context(
      client, platform_info_.allocator(),
      /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
      platform_info_.UseMultipleStreams());
  launch_context.PopulateInputs(ctx, kernel, variables,
                                /*missing_ctx_input_prefix=*/0);

  // Execute the computation.
  VLOG(2) << "Executing computation.";
  xla::ExecutableRunOptions run_options;
  run_options.set_stream(stream);
  run_options.set_allocator(platform_info_.allocator());
  run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
  run_options.set_rng_seed(GetXLARandomSeed());
  Env* env = Env::Default();
  auto start_time = env->NowMicros();

  auto run_result = executable->Run(launch_context.arguments(), run_options);
  OP_REQUIRES(ctx, run_result.ok(), run_result.status());

  auto elapsed = env->NowMicros() - start_time;
  VLOG(2) << "Elapsed time: " << elapsed << "us";

  OP_REQUIRES_OK(ctx, launch_context.PopulateOutputs(
                          ctx, kernel, run_result.ConsumeValueOrDie(),
                          /*missing_ctx_input_prefix=*/0));
  VLOG(1) << "Done";
}

namespace {
// Helper static functions to construct parameters for
// XlaLocalLaunchBase constructor from OpKernelConstruction.
std::vector<int> ConstantsVector(OpKernelConstruction* ctx) {
  DataTypeVector constant_types;
  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
                        ctx->GetAttr("Tconstants", &constant_types));
  std::vector<int> constants(constant_types.size());
  std::iota(constants.begin(), constants.end(), 0);
  return constants;
}

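// Resource arguments follow the compile-time constants and the regular
// arguments in the op's input list, so their indices start at
// Tconstants.size() + Targs.size().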
std::vector<int> ResourcesVector(OpKernelConstruction* ctx) {
  DataTypeVector constant_types;
  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
                        ctx->GetAttr("Tconstants", &constant_types));

  DataTypeVector arg_types;
  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
                        ctx->GetAttr("Targs", &arg_types));

  int num_resources;
  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
                        ctx->GetAttr("Nresources", &num_resources));

  std::vector<int> resources(num_resources);
  std::iota(resources.begin(), resources.end(),
            constant_types.size() + arg_types.size());
  return resources;
}

NameAttrList FunctionAttr(OpKernelConstruction* ctx) {
  const NameAttrList* func;
  OP_REQUIRES_OK_RETURN(ctx, NameAttrList(), ctx->GetAttr("function", &func));
  return *func;
}

bool MustCompileAttr(OpKernelConstruction* ctx) {
  bool must_compile;
  OP_REQUIRES_OK_RETURN(ctx, false,
                        ctx->GetAttr("must_compile", &must_compile));
  return must_compile;
}
}  // namespace

XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
    : XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx),
                         FunctionAttr(ctx)) {}

XlaLocalLaunchOp::~XlaLocalLaunchOp() {
  VLOG(1) << "XlaLocalLaunchOp destroyed";
}

XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx)
    : OpKernel(ctx),
      constants_(ConstantsVector(ctx)),
      resources_(ResourcesVector(ctx)),
      function_(FunctionAttr(ctx)),
      platform_info_(PlatformInfoFromContext(ctx)),
      must_compile_(MustCompileAttr(ctx)) {}

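// Tries to compile the cluster. On success, outputs a compilation key that the
// paired _XlaRun op consumes, plus a boolean output set to true. If compilation
// is deferred, or fails in a way we treat as non-fatal (error::UNIMPLEMENTED
// when must_compile_ is false), the boolean output is false and the runtime
// falls back to calling the original TF function.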
void XlaCompileOp::Compute(OpKernelContext* ctx) {
  VLOG(3) << "XlaCompileOp " << def().name()
          << (must_compile_ ? "(must-compile)" : "");
  xla::LocalClient* client;
  const XlaCompiler::CompilationResult* kernel;
  xla::LocalExecutable* executable;
  std::map<int, OptionalTensor> variables;

  bool cannot_compile_cluster;
  {
    mutex_lock guard(cannot_compile_cluster_mu_);
    cannot_compile_cluster = cannot_compile_cluster_;
  }

  if (GetXlaOpsCommonFlags().tf_xla_always_defer_compilation ||
      cannot_compile_cluster) {
    executable = nullptr;
  } else {
    Status status = CompileToLocalExecutable(
        ctx, function_, platform_info_, resources_, constants_,
        /*lazy=*/!must_compile_, &client, &variables, &kernel, &executable);
    if (must_compile_ || status.code() != error::UNIMPLEMENTED) {
      OP_REQUIRES_OK(ctx, status);
    }

    if (status.code() == error::UNIMPLEMENTED) {
      LOG(WARNING) << "Compilation failed: " << status.ToString()
                   << ".  Falling back to TF function call.";
      executable = nullptr;
      mutex_lock guard(cannot_compile_cluster_mu_);
      cannot_compile_cluster_ = true;
    }
  }

  AllocatorAttributes host_alloc_attrs;
  host_alloc_attrs.set_gpu_compatible(true);
  host_alloc_attrs.set_on_host(true);
  Allocator* cpu_allocator = ctx->device()->GetAllocator(host_alloc_attrs);

  if (!executable) {
    DCHECK(!must_compile_);
    Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({}));

    Tensor compilation_successful(cpu_allocator, DT_BOOL, TensorShape({}));
    compilation_successful.scalar<bool>()() = false;
    ctx->set_output(0, Tensor(cpu_allocator, DT_STRING, TensorShape({})));
    ctx->set_output(1, compilation_successful);
    return;
  }

  // Each execution of an XlaCompile op creates a new XlaExecutableClosure, even
  // if it didn't have to compile the cluster because of a compilation-cache
  // hit.  This is because we at least need new snapshots of the resource
  // variables.
  XlaExecutableClosureStore::KeyT key =
      XlaExecutableClosureStore::Global()->Produce(XlaExecutableClosure(
          client, executable, kernel, std::move(variables), constants_.size()));

  Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({}));
  compilation_key.flat<string>()(0) = key;

  Tensor compilation_successful(cpu_allocator, DT_BOOL, TensorShape({}));
  compilation_successful.flat<bool>()(0) = true;

  ctx->set_output(0, compilation_key);
  ctx->set_output(1, compilation_successful);
}

XlaRunOp::XlaRunOp(OpKernelConstruction* ctx)
    : OpKernel(ctx), platform_info_(PlatformInfoFromContext(ctx)) {}

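// Retrieves the XlaExecutableClosure produced by the paired _XlaCompile op
// (its key arrives as the last input) and runs the compiled executable,
// feeding it the previously-snapshotted resource variables.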
void XlaRunOp::Compute(OpKernelContext* ctx) {
  VLOG(3) << "XlaRunOp " << def().name();
  Tensor key_tensor = ctx->input(ctx->num_inputs() - 1);
  const XlaExecutableClosureStore::KeyT& key = key_tensor.flat<string>()(0);

  XlaExecutableClosure closure =
      XlaExecutableClosureStore::Global()->Consume(key);

  XlaComputationLaunchContext launch_context(
      closure.client(), platform_info_.allocator(),
      /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
      /*use_multiple_streams=*/platform_info_.UseMultipleStreams());

  // We're missing the must-be-constant inputs; tell `PopulateInputs`
  // about this.  We don't actually need these inputs because they've
  // already been baked into the compiled kernel.
  launch_context.PopulateInputs(
      ctx, closure.compilation_result(), closure.resource_var_snapshots(),
      /*missing_ctx_input_prefix=*/closure.num_constant_args());

  se::Stream* stream =
      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
  xla::ExecutableRunOptions run_options;
  run_options.set_stream(stream);
  run_options.set_allocator(platform_info_.allocator());
  run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
  run_options.set_rng_seed(GetXLARandomSeed());
  Env* env = Env::Default();
  auto start_time = env->NowMicros();

  auto run_result =
      closure.executable()->Run(launch_context.arguments(), run_options);
  OP_REQUIRES(ctx, run_result.ok(), run_result.status());

  auto elapsed = env->NowMicros() - start_time;
  VLOG(2) << "Elapsed time in computation: " << elapsed << "us";

  OP_REQUIRES_OK(
      ctx,
      launch_context.PopulateOutputs(
          ctx, closure.compilation_result(), run_result.ConsumeValueOrDie(),
          /*missing_ctx_input_prefix=*/closure.num_constant_args()));
}

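// Kernel registrations. XlaLaunch is the single-op compile-and-run path;
// _XlaCompile and _XlaRun split compilation and execution. On GPU, the
// constant and resource inputs and the compilation key/status tensors are
// pinned to host memory.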
REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp);

REGISTER_KERNEL_BUILDER(Name("XlaLaunch")
                            .Device(DEVICE_GPU)
                            .HostMemory("constants")
                            .HostMemory("resources"),
                        XlaLocalLaunchOp);

REGISTER_KERNEL_BUILDER(Name("_XlaCompile").Device(DEVICE_CPU), XlaCompileOp);
REGISTER_KERNEL_BUILDER(Name("_XlaCompile")
                            .Device(DEVICE_GPU)
                            .HostMemory("constants")
                            .HostMemory("key")
                            .HostMemory("compilation_successful")
                            .HostMemory("resources"),
                        XlaCompileOp);

REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_CPU), XlaRunOp);
REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_GPU), XlaRunOp);

}  // namespace tensorflow