      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/service.h"
     17 
     18 #include <memory>
     19 #include <string>
     20 #include <utility>
     21 #include <vector>
     22 
     23 #include "tensorflow/compiler/xla/execution_options_util.h"
     24 #include "tensorflow/compiler/xla/layout_util.h"
     25 #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
     26 #include "tensorflow/compiler/xla/ptr_util.h"
     27 #include "tensorflow/compiler/xla/service/compiler.h"
     28 #include "tensorflow/compiler/xla/service/computation_layout.h"
     29 #include "tensorflow/compiler/xla/service/device_memory_allocator.h"
     30 #include "tensorflow/compiler/xla/service/executable.h"
     31 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     32 #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
     33 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
     34 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     35 #include "tensorflow/compiler/xla/service/hlo_module.h"
     36 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
     37 #include "tensorflow/compiler/xla/service/hlo_proto_util.h"
     38 #include "tensorflow/compiler/xla/service/platform_util.h"
     39 #include "tensorflow/compiler/xla/service/session.pb.h"
     40 #include "tensorflow/compiler/xla/service/source_map_util.h"
     41 #include "tensorflow/compiler/xla/service/transfer_manager.h"
     42 #include "tensorflow/compiler/xla/shape_layout.h"
     43 #include "tensorflow/compiler/xla/shape_util.h"
     44 #include "tensorflow/compiler/xla/status_macros.h"
     45 #include "tensorflow/compiler/xla/types.h"
     46 #include "tensorflow/compiler/xla/util.h"
     47 #include "tensorflow/compiler/xla/xla_data.pb.h"
     48 #include "tensorflow/core/lib/gtl/cleanup.h"
     49 #include "tensorflow/core/lib/strings/strcat.h"
     50 #include "tensorflow/core/lib/strings/stringprintf.h"
     51 #include "tensorflow/core/platform/env.h"
     52 #include "tensorflow/core/platform/logging.h"
     53 #include "tensorflow/core/platform/protobuf.h"
     54 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
     55 #include "tensorflow/core/platform/types.h"
     56 
     57 namespace se = ::perftools::gputools;
     58 
     59 using ::tensorflow::strings::Printf;
     60 using ::tensorflow::strings::StrCat;
     61 using ::xla::source_map_util::InvalidParameterArgument;
     62 
     63 namespace xla {
     64 
     65 namespace {
     66 
     67 // Records the arguments used to invoke a computation in a SessionModule
     68 // proto.
     69 tensorflow::Status RecordArguments(
     70     const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
     71     se::StreamExecutor* executor, TransferManager* transfer_manager,
     72     SessionModule* module) {
     73   module->clear_arguments();
     74   for (const ShapedBuffer* argument : arguments) {
     75     TF_ASSIGN_OR_RETURN(
     76         std::unique_ptr<Literal> literal,
     77         transfer_manager->TransferLiteralFromDevice(executor, *argument));
     78     *module->add_arguments() = literal->ToProto();
     79   }
     80   return tensorflow::Status::OK();
     81 }
     82 
     83 // Records the result of a computation in a SessionModule proto.
     84 tensorflow::Status RecordResult(const ShapedBuffer& result,
     85                                 se::StreamExecutor* executor,
     86                                 TransferManager* transfer_manager,
     87                                 SessionModule* module) {
     88   module->clear_result();
     89   TF_ASSIGN_OR_RETURN(
     90       std::unique_ptr<Literal> literal,
     91       transfer_manager->TransferLiteralFromDevice(executor, result));
     92   *module->mutable_result() = literal->ToProto();
     93   return tensorflow::Status::OK();
     94 }
     95 
     96 }  // namespace
     97 
     98 ServiceOptions& ServiceOptions::set_platform(
     99     perftools::gputools::Platform* platform) {
    100   platform_ = platform;
    101   return *this;
    102 }
    103 
    104 perftools::gputools::Platform* ServiceOptions::platform() const {
    105   return platform_;
    106 }
    107 
    108 ServiceOptions& ServiceOptions::set_number_of_replicas(int number_of_replicas) {
    109   number_of_replicas_ = number_of_replicas;
    110   return *this;
    111 }
    112 
    113 int ServiceOptions::number_of_replicas() const { return number_of_replicas_; }
    114 
    115 ServiceOptions& ServiceOptions::set_intra_op_parallelism_threads(
    116     int num_threads) {
    117   intra_op_parallelism_threads_ = num_threads;
    118   return *this;
    119 }
    120 
    121 int ServiceOptions::intra_op_parallelism_threads() const {
    122   return intra_op_parallelism_threads_;
    123 }
    124 
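        // Creates a new XLA service for the given platform, or for the default
        // platform if `platform` is null. A minimal usage sketch (assuming a valid
        // perftools::gputools::Platform* is available as `platform`):
        //
        //   TF_ASSIGN_OR_RETURN(std::unique_ptr<Service> service,
        //                       Service::NewService(platform));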
    125 /* static */ StatusOr<std::unique_ptr<Service>> Service::NewService(
    126     perftools::gputools::Platform* platform) {
    127   ServiceOptions default_options;
    128   default_options.set_platform(platform);
    129   return NewService(default_options);
    130 }
    131 
    132 /* static */ StatusOr<std::unique_ptr<Service>> Service::NewService(
    133     const ServiceOptions& options) {
    134   perftools::gputools::Platform* platform = options.platform();
    135   std::unique_ptr<Backend> execute_backend;
    136   if (platform == nullptr) {
    137     TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
    138   }
    139   BackendOptions backend_options;
    140   backend_options.set_platform(platform);
    141   TF_ASSIGN_OR_RETURN(execute_backend, Backend::CreateBackend(backend_options));
    142 
    143   std::unique_ptr<Service> service(
    144       new Service(options, std::move(execute_backend)));
    145   return std::move(service);
    146 }
    147 
    148 Service::Service(const ServiceOptions& options,
    149                  std::unique_ptr<Backend> execute_backend)
    150     : options_(options),
    151       allocation_tracker_(execute_backend.get()),
    152       execute_backend_(std::move(execute_backend)) {
    153   CHECK_GT(options_.number_of_replicas(), 0);
    154   if (execute_backend_) {
    155     if (execute_backend_->device_count() > 0) {
    156       CHECK_GE(execute_backend_->device_count(), options_.number_of_replicas())
    157           << "Requested more replicas than there are devices.";
    158     }
    159     LOG(INFO) << Printf(
    160         "XLA service %p executing computations on platform %s. Devices:", this,
    161         execute_backend_->platform()->Name().c_str());
    162     for (int i = 0; i < execute_backend_->device_count(); ++i) {
    163       if (execute_backend_->device_ordinal_supported(i)) {
    164         se::StreamExecutor* executor =
    165             execute_backend_->stream_executor(i).ValueOrDie();
    166         const auto& description = executor->GetDeviceDescription();
    167         LOG(INFO) << Printf("  StreamExecutor device (%d): %s, %s", i,
    168                             description.name().c_str(),
    169                             description.platform_version().c_str());
    170       } else {
    171         LOG(INFO) << Printf("  StreamExecutor device (%d) not supported", i);
    172       }
    173     }
    174   } else {
    175     VLOG(1) << "XLA compile-only service constructed";
    176   }
    177 }
    178 
    179 tensorflow::Status Service::Computation(const ComputationRequest* arg,
    180                                         ComputationResponse* result) {
    181   if (arg->name().empty()) {
    182     return InvalidArgument("computation request needs a name");
    183   }
    184 
    185   *result->mutable_computation() =
    186       computation_tracker_.NewComputation(arg->name());
    187   VLOG(1) << Printf("Created new computation %s on service %p, name %s",
    188                     result->computation().ShortDebugString().c_str(), this,
    189                     arg->name().c_str());
    190   return tensorflow::Status::OK();
    191 }
    192 
    193 tensorflow::Status Service::CreateChannelHandle(
    194     const CreateChannelHandleRequest* arg,
    195     CreateChannelHandleResponse* result) {
    196   *result->mutable_channel() = channel_tracker_.NewChannel();
    197   return tensorflow::Status::OK();
    198 }
    199 
    200 tensorflow::Status Service::Unregister(const UnregisterRequest* arg,
    201                                        UnregisterResponse* result) {
    202   return allocation_tracker_.Unregister(arg->data());
    203 }
    204 
    205 // Deconstructs a previously-allocated global handle.
    206 tensorflow::Status Service::DeconstructTuple(const DeconstructTupleRequest* arg,
    207                                              DeconstructTupleResponse* result) {
    208   TF_ASSIGN_OR_RETURN(
    209       std::vector<GlobalDataHandle> elements,
    210       allocation_tracker_.DeconstructTuple(arg->tuple_handle()));
    211 
    212   for (auto& element : elements) {
    213     *result->add_element_handles() = element;
    214   }
    215   return tensorflow::Status::OK();
    216 }
    217 
    218 tensorflow::Status Service::ValidateResultShapeWithLayout(
    219     const Shape& shape_with_layout, const Shape& result_shape) const {
    220   if (!ShapeUtil::Compatible(shape_with_layout, result_shape)) {
    221     return InvalidArgument(
    222         "Shape used to set computation result layout %s is not compatible "
    223         "with result shape %s",
    224         ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str(),
    225         ShapeUtil::HumanString(result_shape).c_str());
    226   }
    227   if (!LayoutUtil::HasLayout(shape_with_layout)) {
    228     return InvalidArgument(
    229         "Shape used to set computation result layout %s does not have layout",
    230         ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str());
    231   }
    232   return ShapeUtil::ValidateShape(shape_with_layout);
    233 }
    234 
    235 StatusOr<std::vector<const ShapedBuffer*>> Service::ResolveAndValidateArguments(
    236     tensorflow::gtl::ArraySlice<const GlobalDataHandle*> arguments,
    237     int device_ordinal) {
    238   std::vector<const ShapedBuffer*> shaped_buffers;
    239   for (size_t i = 0; i < arguments.size(); ++i) {
    240     auto buffer_status = allocation_tracker_.Resolve(*arguments[i]);
    241     if (!buffer_status.ok()) {
    242       return Status(buffer_status.status().code(),
    243                     StrCat(buffer_status.status().error_message(), ", ",
    244                            "failed to resolve allocation for parameter ", i));
    245     }
    246     const ShapedBuffer* shaped_buffer = buffer_status.ValueOrDie();
    247 
    248     // Verify allocation is same platform and device as the execution.
    249     if (shaped_buffer->platform() != execute_backend_->platform() ||
    250         shaped_buffer->device_ordinal() != device_ordinal) {
    251       return InvalidArgument(
    252           "argument %lu is on device %s:%d but computation will be executed "
    253           "on device %s",
    254           i, shaped_buffer->platform()->Name().c_str(),
    255           shaped_buffer->device_ordinal(),
    256           execute_backend_->device_name(device_ordinal).c_str());
    257     }
    258 
    259     shaped_buffers.push_back(shaped_buffer);
    260   }
    261   return shaped_buffers;
    262 }
    263 
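        // Creates an HloModuleConfig for the given program shape and argument
        // shapes. Validates that each argument is compatible with the corresponding
        // parameter of the program shape, and copies any requested layouts (argument
        // layouts and, if present in execution_options, the result layout) into the
        // entry computation layout.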
    264 StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
    265     const ProgramShape& program_shape,
    266     tensorflow::gtl::ArraySlice<const Shape*> argument_shapes,
    267     const ExecutionOptions* execution_options,
    268     const UserComputation& user_computation) {
    269   auto config = MakeUnique<HloModuleConfig>(program_shape);
    270   auto* computation_layout = config->mutable_entry_computation_layout();
    271 
    272   if (program_shape.parameters_size() != argument_shapes.size()) {
    273     return InvalidArgument("computation takes %d parameters, but %zu given",
    274                            program_shape.parameters_size(),
    275                            argument_shapes.size());
    276   }
    277   for (int i = 0; i < argument_shapes.size(); ++i) {
    278     // Verify that the shape of each argument matches the corresponding
    279     // parameter shape in the ProgramShape.
    280     if (!ShapeUtil::Compatible(*argument_shapes[i],
    281                                program_shape.parameters(i))) {
    282       return InvalidParameterArgument(
    283           *user_computation.ParameterMetadata(i).value(),
    284           "Argument does not match shape of computation parameter %d: want %s, "
    285           "got %s",
    286           i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(),
    287           ShapeUtil::HumanString(*argument_shapes[i]).c_str());
    288     }
    289     TF_RETURN_IF_ERROR(
    290         computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
    291             *argument_shapes[i]));
    292   }
    293   if (execution_options != nullptr &&
    294       execution_options->has_shape_with_output_layout()) {
    295     const auto& shape_with_output_layout =
    296         execution_options->shape_with_output_layout();
    297     TF_RETURN_IF_ERROR(ValidateResultShapeWithLayout(shape_with_output_layout,
    298                                                      program_shape.result()));
    299     TF_RETURN_IF_ERROR(
    300         computation_layout->mutable_result_layout()->CopyLayoutFromShape(
    301             shape_with_output_layout));
    302   } else {
    303     computation_layout->mutable_result_layout()->Clear();
    304   }
    305 
    306   config->set_replica_count(options_.number_of_replicas());
    307   if (execution_options != nullptr) {
    308     config->set_seed(execution_options->seed());
    309     config->set_debug_options(execution_options->debug_options());
    310     config->enable_hlo_profiling(
    311         execution_options->debug_options().xla_hlo_profile());
    312   } else {
    313     config->set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
    314   }
    315 
    316   if (execute_backend_ != nullptr &&
    317       execute_backend_->eigen_intra_op_thread_pool() != nullptr) {
    318     config->set_intra_op_parallelism_threads(
    319         execute_backend_->eigen_intra_op_thread_pool()->NumThreads());
    320   }
    321   return std::move(config);
    322 }
    323 
    324 StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
    325     const ProgramShape& program_shape,
    326     tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
    327     const ExecutionOptions& execution_options,
    328     const UserComputation& user_computation) {
    329   std::vector<const Shape*> argument_shapes;
    330   for (const auto* arg : arguments) {
    331     argument_shapes.push_back(&arg->on_host_shape());
    332   }
    333   return CreateModuleConfig(program_shape, argument_shapes, &execution_options,
    334                             user_computation);
    335 }
    336 
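        // Compiles the given versioned computations, one per module config, into
        // executables on the given backend. If the relevant debug flags are set,
        // snapshots the computations as SessionModule protos and attaches them to
        // the resulting executables for later dumping.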
    337 StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
    338     std::vector<VersionedComputationHandle> versioned_handles,
    339     std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
    340     Backend* backend,
    341     std::vector<std::vector<perftools::gputools::StreamExecutor*>> executors,
    342     DeviceMemoryAllocator* device_allocator) {
    343   VLOG(1) << Printf("BuildExecutable on service %p", this);
    344 
    345   // Dump computation proto state if flag is set.
    346   std::vector<std::unique_ptr<SessionModule>> session_modules;
    347   for (int64 i = 0; i < versioned_handles.size(); ++i) {
    348     const string& directory_path =
    349         module_configs[i]->debug_options().xla_dump_computations_to();
    350     const string& other_directory_path =
    351         module_configs[i]->debug_options().xla_dump_executions_to();
    352     if (directory_path.empty() && other_directory_path.empty()) {
    353       continue;
    354     }
    355     TF_ASSIGN_OR_RETURN(
    356         std::unique_ptr<SessionModule> session_module,
    357         computation_tracker_.SnapshotComputation(versioned_handles[i].handle));
    358     if (!directory_path.empty()) {
    359       string filename = Printf("computation_%lld__%s__version_%lld",
    360                                versioned_handles[i].handle.handle(),
    361                                session_module->entry().name().c_str(),
    362                                versioned_handles[i].version);
    363       TF_RETURN_IF_ERROR(Executable::DumpToDirectory(directory_path, filename,
    364                                                      *session_module));
    365       session_modules.push_back(std::move(session_module));
    366     }
    367   }
    368 
    369   VLOG(1) << "Computation handles:";
    370   for (const VersionedComputationHandle& versioned_handle : versioned_handles) {
    371     VLOG(1) << versioned_handle;
    372   }
    373 
    374   CHECK_EQ(versioned_handles.size(), module_configs.size());
    375   std::vector<std::unique_ptr<HloModule>> modules;
    376   for (int64 i = 0; i < versioned_handles.size(); ++i) {
    377     const VersionedComputationHandle& versioned_handle = versioned_handles[i];
    378     const HloModuleConfig& config = *module_configs[i];
    379     TF_ASSIGN_OR_RETURN(auto module,
    380                         computation_tracker_.BuildHloModule(
    381                             versioned_handle, config,
    382                             /*include_unreachable_instructions=*/true));
    383     modules.push_back(std::move(module));
    384   }
    385 
    386   TF_ASSIGN_OR_RETURN(
    387       std::vector<std::unique_ptr<Executable>> executables,
    388       backend->compiler()->Compile(std::move(modules), std::move(executors),
    389                                    device_allocator));
    390 
    391   for (size_t i = 0; i < versioned_handles.size(); ++i) {
    392     if (!module_configs[i]->debug_options().xla_dump_executions_to().empty()) {
    393       executables[i]->set_session_module(std::move(session_modules[i]));
    394     }
    395   }
    396 
    397   return std::move(executables);
    398 }
    399 
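        // Compiles a single versioned computation into an executable: builds the
        // HloModule, runs the HLO passes, and invokes the backend compiler. Dumps
        // the computation proto if the corresponding debug flags are set.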
    400 StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
    401     const VersionedComputationHandle& versioned_handle,
    402     std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
    403     se::StreamExecutor* executor, DeviceMemoryAllocator* device_allocator) {
    404   VLOG(1) << Printf("BuildExecutable on service %p with handle %s", this,
    405                     versioned_handle.ToString().c_str());
    406 
    407   // Dump computation proto state if flag is set.
    408   std::unique_ptr<SessionModule> session_module;
    409   const string& directory_path =
    410       module_config->debug_options().xla_dump_computations_to();
    411   const string& other_directory_path =
    412       module_config->debug_options().xla_dump_executions_to();
    413   if (!directory_path.empty() || !other_directory_path.empty()) {
    414     TF_ASSIGN_OR_RETURN(
    415         session_module,
    416         computation_tracker_.SnapshotComputation(versioned_handle.handle));
    417     if (!directory_path.empty()) {
    418       string filename = Printf("computation_%lld__%s__version_%lld",
    419                                versioned_handle.handle.handle(),
    420                                session_module->entry().name().c_str(),
    421                                versioned_handle.version);
    422       TF_RETURN_IF_ERROR(Executable::DumpToDirectory(directory_path, filename,
    423                                                      *session_module));
    424     }
    425   }
    426 
    427   TF_ASSIGN_OR_RETURN(
    428       std::unique_ptr<HloModule> module,
    429       computation_tracker_.BuildHloModule(versioned_handle, *module_config,
    430                                           /*include_unreachable_instructions=*/
    431                                           true));
    432 
    433   TF_RETURN_IF_ERROR(MaybeDumpHloModule(*module));
    434 
    435   TF_ASSIGN_OR_RETURN(
    436       module, backend->compiler()->RunHloPasses(std::move(module), executor,
    437                                                 device_allocator));
    438 
    439   TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
    440                       backend->compiler()->RunBackend(
    441                           std::move(module), executor, device_allocator));
    442 
    443   if (!other_directory_path.empty()) {
    444     executable->set_session_module(std::move(session_module));
    445   }
    446 
    447   return std::move(executable);
    448 }
    449 
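        // Returns the executable for the given computation from the compilation
        // cache if present; otherwise builds it, records the compile time in the
        // profile (when provided), and inserts it into the cache.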
    450 StatusOr<std::shared_ptr<Executable>> Service::BuildAndCacheExecutable(
    451     const VersionedComputationHandle& versioned_handle,
    452     std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
    453     perftools::gputools::StreamExecutor* executor, ExecutionProfile* profile,
    454     DeviceMemoryAllocator* device_allocator) {
    455   std::shared_ptr<Executable> executable =
    456       compilation_cache_.LookUp(versioned_handle, *module_config);
    457 
    458   if (executable != nullptr) {
    459     // Executable found in the compilation cache.
    460     if (profile != nullptr) {
    461       profile->set_compilation_cache_hit(true);
    462     }
    463     return executable;
    464   }
    465 
    466   uint64 start_micros =
    467       // Avoid reading the clock if we don't want timing info
    468       (profile != nullptr) ? tensorflow::Env::Default()->NowMicros() : 0;
    469 
    470   // Take a copy of the module config, as compilation introduces layouts where
    471   // layouts were optional before.
    472   HloModuleConfig original_module_config = *module_config;
    473   TF_ASSIGN_OR_RETURN(
    474       std::unique_ptr<Executable> executable_unique_ptr,
    475       BuildExecutable(versioned_handle, std::move(module_config), backend,
    476                       executor, device_allocator));
    477 
    478   if (profile != nullptr) {
    479     uint64 end_micros = tensorflow::Env::Default()->NowMicros();
    480     uint64 milliseconds = (end_micros - start_micros) / 1000;
    481     profile->set_compilation_cache_hit(false);
    482     profile->set_compile_time_ms(milliseconds);
    483   }
    484 
    485   // Insert executable into the cache.
    486   return compilation_cache_.Insert(std::move(executable_unique_ptr),
    487                                    original_module_config);
    488 }
    489 
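        // Asynchronously launches each executable on the streams of its assigned
        // replicas, registers the result of replica 0 of each computation with the
        // allocation tracker, then blocks until all streams complete. Collects HLO
        // profiles and timing information when profiling is requested.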
    490 StatusOr<std::vector<GlobalDataHandle>>
    491 Service::ExecuteParallelAndRegisterResult(
    492     tensorflow::gtl::ArraySlice<Executable*> executables,
    493     tensorflow::gtl::ArraySlice<std::vector<const ShapedBuffer*>> arguments,
    494     Backend* backend, tensorflow::gtl::ArraySlice<DeviceHandle> device_handles,
    495     tensorflow::gtl::ArraySlice<string> result_tags,
    496     ExecutionProfile* profile) {
    497   // Streams on which the computations are launched, so we can wait on them
    498   // to complete.
    499   std::vector<Pool<se::Stream>::SmartPtr> streams;
    500   std::vector<std::unique_ptr<perftools::gputools::Timer>> timers;
    501 
    502   // Global data handles for the computation results, one for each computation.
    503   std::vector<GlobalDataHandle> result_handles;
    504 
    505   // Map from computation index to the stream its profiled execution runs on;
    506   // populated only for computations that are being profiled.
    507   std::map<int64, se::Stream*> index_to_profiled_streams;
    508 
    509   TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
    510                       backend->computation_placer()->AssignDevices(
    511                           options_.number_of_replicas(), executables.size()));
    512 
    513   for (int64 i = 0; i < executables.size(); i++) {
    514     // Stream executors for the replicas of the current computation.
    515     TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handles[i]));
    516     for (int64 replica = 0; replica < replicas.size(); ++replica) {
    517       TF_ASSIGN_OR_RETURN(Pool<se::Stream>::SmartPtr stream,
    518                           backend->BorrowStream(replicas[replica]));
    519       streams.push_back(std::move(stream));
    520 
    521       if (replica == 0 && profile != nullptr) {
    522         timers.emplace_back(
    523             new perftools::gputools::Timer(streams.back()->parent()));
    524         streams.back()
    525             ->InitTimer(timers.back().get())
    526             .ThenStartTimer(timers.back().get());
    527         CHECK(timers.front() != nullptr);
    528       }
    529 
    530       if (replica == 0 &&
    531           executables[i]->module_config().debug_options().xla_hlo_profile() &&
    532           executables[i]->hlo_profiling_enabled()) {
    533         index_to_profiled_streams[i] = streams.back().get();
    534       }
    535 
    536       // Set up run options.
    537       ExecutableRunOptions options;
    538       options.set_stream(streams.back().get());
    539       options.set_allocator(backend->memory_allocator());
    540       options.set_inter_op_thread_pool(backend->inter_op_thread_pool());
    541       options.set_intra_op_thread_pool(
    542           backend->eigen_intra_op_thread_pool_device());
    543       options.set_device_assignment(&device_assignment);
    544       ServiceExecutableRunOptions run_options(options,
    545                                               backend->StreamBorrower());
    546 
    547       // Asynchronously launch the computation.
    548       TF_ASSIGN_OR_RETURN(
    549           std::unique_ptr<ShapedBuffer> result,
    550           executables[i]->ExecuteAsyncOnStream(&run_options, arguments[i]));
    551 
    552       if (replica == 0 && profile != nullptr) {
    553         streams.back()->ThenStopTimer(timers.back().get());
    554       }
    555 
    556       // All replicas share the same device address for the result allocation,
    557       // so only one of the replicas needs to register the result handle.
    558       if (replica == 0) {
    559         TF_ASSIGN_OR_RETURN(
    560             GlobalDataHandle handle,
    561             allocation_tracker_.Register(std::move(result), result_tags[i]));
    562         result_handles.push_back(handle);
    563       }
    564     }
    565   }
    566 
    567   // Wait for all executions to complete.
    568   for (int64 i = 0; i < streams.size(); ++i) {
    569     Status block_status = streams[i]->BlockHostUntilDone();
    570     if (!block_status.ok()) {
    571       return InternalError("failed to complete execution for stream %lld: %s",
    572                            i, block_status.error_message().c_str());
    573     }
    574   }
    575 
    576   // For every stream that had profiling enabled, obtain and debug-dump the HLO
    577   // profile.
    578   for (auto& index_to_profiled_stream : index_to_profiled_streams) {
    579     int64 device = index_to_profiled_stream.first;
    580     se::Stream* stream = index_to_profiled_stream.second;
    581     Executable* executable = executables[device];
    582     const HloModule& module = executable->module();
    583     HloExecutionProfile hlo_profile(&executable->hlo_profile_printer_data(),
    584                                     &executable->hlo_profile_index_map());
    585     TF_RETURN_IF_ERROR(
    586         executable->PopulateExecutionProfile(&hlo_profile, stream->parent()));
    587     XLA_LOG_LINES(
    588         tensorflow::INFO,
    589         hlo_profile.ToString(streams[0]->parent()->GetDeviceDescription()));
    590     hlo_graph_dumper::MaybeDumpHloModule(module, "Service::Execute",
    591                                          &hlo_profile);
    592   }
    593 
    594   if (profile != nullptr) {
    595     CHECK(!timers.empty());
    596     std::vector<uint64> timer_nanoseconds;
    597     timer_nanoseconds.reserve(timers.size());
    598     for (auto& timer : timers) {
    599       timer_nanoseconds.push_back(timer->Nanoseconds());
    600     }
    601     uint64 nanoseconds =
    602         *std::max_element(timer_nanoseconds.begin(), timer_nanoseconds.end());
    603 
    604     // Merge in run-time profile information from execution_profile on the
    605     // zeroth device.
    606     profile->MergeFrom(executables[0]->execution_profile());
    607 
    608     // Overall execution time (in nanoseconds) from the executor timer.
    609     profile->set_compute_and_transfer_time_ns(nanoseconds);
    610 
    611     // TODO(b/28123297): On GPU we end up including transfer time in
    612     // the compute time this way. Instead, we should get the correct
    613     // value by measuring it. Setting the field here at least lets
    614     // benchmarks provide *some* value for GPU computations.
    615     //
    616     // TODO(b/28447609): The value in compute_and_transfer_time_ns is actually
    617     // the compute time without the transfer time, so this way we get the
    618     // correct compute time. We should instead have the correct value for
    619     // compute_and_transfer_time and set compute_time to the compute time.
    620     if (profile->compute_time_ns() == 0) {
    621       profile->set_compute_time_ns(profile->compute_and_transfer_time_ns());
    622     }
    623   }
    624 
    625   return result_handles;
    626 }
    627 
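        // Runs a single executable on all replicas of the single-computation device
        // assignment and registers the resulting allocation under the given tag.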
    628 StatusOr<GlobalDataHandle> Service::ExecuteAndRegisterResult(
    629     Executable* executable,
    630     const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
    631     Backend* backend, perftools::gputools::StreamExecutor* executor,
    632     const string& result_tag, ExecutionProfile* profile) {
    633   // Set up streams.
    634   std::vector<Pool<se::Stream>::SmartPtr> streams;
    635 
    636   TF_ASSIGN_OR_RETURN(auto replicas,
    637                       Replicas(*backend, SingleComputationDeviceHandle()));
    638   TF_RET_CHECK(!replicas.empty());
    639   for (se::StreamExecutor* executor : replicas) {
    640     TF_ASSIGN_OR_RETURN(Pool<se::Stream>::SmartPtr stream,
    641                         backend->BorrowStream(executor));
    642     streams.push_back(std::move(stream));
    643   }
    644 
    645   TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
    646                       backend->computation_placer()->AssignDevices(
    647                           options_.number_of_replicas(),
    648                           /*computation_count=*/1));
    649 
    650   // Set up run options.
    651   std::vector<ServiceExecutableRunOptions> run_options;
    652   for (const Pool<se::Stream>::SmartPtr& stream : streams) {
    653     ExecutableRunOptions options;
    654     options.set_stream(stream.get());
    655     options.set_device_ordinal(stream->parent()->device_ordinal());
    656     options.set_allocator(backend->memory_allocator());
    657     options.set_inter_op_thread_pool(backend->inter_op_thread_pool());
    658     options.set_intra_op_thread_pool(
    659         backend->eigen_intra_op_thread_pool_device());
    660     options.set_device_assignment(&device_assignment);
    661     run_options.emplace_back(options, backend->StreamBorrower(),
    662                              backend->inter_op_thread_pool());
    663   }
    664 
    665   std::unique_ptr<ShapedBuffer> result;
    666   if (options_.number_of_replicas() == 1) {
    667     TF_ASSIGN_OR_RETURN(result, executable->ExecuteOnStreamWrapper(
    668                                     &run_options[0], profile, arguments));
    669   } else {
    670     // TODO(b/69985541): Support profiling also on this path.
    671     std::vector<tensorflow::gtl::ArraySlice<const ShapedBuffer*>>
    672         repeated_arguments(options_.number_of_replicas(), arguments);
    673 
    674     TF_ASSIGN_OR_RETURN(auto results, executable->ExecuteOnStreams(
    675                                           run_options, repeated_arguments));
    676     TF_RET_CHECK(!results.empty());
    677     result = std::move(results[0]);
    678   }
    679   return allocation_tracker_.Register(std::move(result), result_tag);
    680 }
    681 
    682 tensorflow::Status Service::SetReturnValue(const SetReturnValueRequest* arg,
    683                                            SetReturnValueResponse* results) {
    684   TF_ASSIGN_OR_RETURN(UserComputation * computation,
    685                       computation_tracker_.Resolve(arg->computation()));
    686   return computation->SetReturnValue(arg->operand());
    687 }
    688 
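        // Handles an ExecuteParallel RPC: builds one executable per request, runs
        // them concurrently on the requested devices, and returns one response per
        // computation.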
    689 tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
    690                                             ExecuteParallelResponse* result) {
    691   VLOG(1) << "running execute-parallel request: " << arg->ShortDebugString();
    692 
    693   std::vector<std::vector<const ShapedBuffer*>> all_arguments;
    694   std::vector<std::vector<perftools::gputools::StreamExecutor*>> all_executors;
    695   std::vector<VersionedComputationHandle> versioned_handles;
    696   std::vector<std::unique_ptr<HloModuleConfig>> module_configs;
    697   std::vector<string> computation_names;
    698   std::vector<DeviceHandle> device_handles;
    699 
    700   int num_requested_devices =
    701       std::accumulate(arg->requests().begin(), arg->requests().end(), 0,
    702                       [](int a, const ExecuteRequest& r) -> int {
    703                         return a + r.execution_options().device_handles_size();
    704                       });
    705   if (num_requested_devices * options_.number_of_replicas() >
    706       execute_backend_->device_count()) {
    707     return FailedPrecondition(
    708         "there are not enough stream executors to execute %d computations",
    709         num_requested_devices);
    710   }
    711 
    712   for (int64 i = 0; i < arg->requests_size(); ++i) {
    713     // Get the stream executors for the i'th computation, one per requested
    714     // device handle. Each is the executor of that device's zeroth replica.
    715     const ExecutionOptions& execution_options =
    716         arg->requests(i).execution_options();
    717     if (execution_options.device_handles().empty()) {
    718       return FailedPrecondition(
    719           "device handles must be given to execute parallel computations");
    720     }
    721     std::vector<perftools::gputools::StreamExecutor*> executors;
    722     for (const auto& device_handle : execution_options.device_handles()) {
    723       TF_ASSIGN_OR_RETURN(auto replicas,
    724                           Replicas(*execute_backend_, device_handle));
    725       se::StreamExecutor* executor = replicas[0];
    726       CHECK(executor != nullptr);
    727       executors.push_back(executor);
    728     }
    729 
    730     // Resolve the UserComputation object associated with the requested
    731     // computation and compute the program shape.
    732     const ExecuteRequest& request = arg->requests(i);
    733     TF_ASSIGN_OR_RETURN(UserComputation * user_computation,
    734                         computation_tracker_.Resolve(request.computation()));
    735     VersionedComputationHandle versioned_handle =
    736         user_computation->GetVersionedHandle();
    737     if (user_computation->request_count(versioned_handle.version) == 0) {
    738       return InvalidArgument("computations may not be empty");
    739     }
    740 
    741     TF_ASSIGN_OR_RETURN(
    742         std::shared_ptr<const ProgramShape> program_shape,
    743         user_computation->ComputeProgramShape(versioned_handle.version));
    744 
    745     // Resolve the allocations for the arguments of the computation, and build
    746     // a vector of ShapedBuffer pointers for them from the allocations.
    747     // In the case of partitioned computations, assume all arguments go on the
    748     // zeroth core.
    749     TF_ASSIGN_OR_RETURN(
    750         std::vector<const ShapedBuffer*> arguments,
    751         ResolveAndValidateArguments(request.arguments(),
    752                                     executors[0]->device_ordinal()));
    753 
    754     // Create an HloModuleConfig object for the computation, given the shape of
    755     // the program and the argument allocations.
    756     TF_ASSIGN_OR_RETURN(
    757         std::unique_ptr<HloModuleConfig> module_config,
    758         CreateModuleConfig(*program_shape, arguments,
    759                            request.execution_options(), *user_computation));
    760     VLOG(3) << "ExecuteParallel created HloModuleConfig computation layout: "
    761             << module_config->entry_computation_layout().ToString();
    762 
    763     // Add to the vectors used to build and execute the computations after the loop.
    764     all_arguments.push_back(arguments);
    765     all_arguments.insert(all_arguments.end(), executors.size() - 1, {});
    766     versioned_handles.push_back(versioned_handle);
    767     module_configs.push_back(std::move(module_config));
    768     computation_names.insert(computation_names.end(), executors.size(),
    769                              user_computation->name());
    770     all_executors.push_back(executors);
    771     device_handles.insert(device_handles.end(),
    772                           execution_options.device_handles().begin(),
    773                           execution_options.device_handles().end());
    774   }
    775 
    776   // Build the user computations into HloModules and compile to generate the
    777   // executables.
    778   //
    779   // TODO(jlebar): There's currently no way to pass a device allocator to
    780   // ExecuteParallel, so we have to pass a null device_allocator below.
    781   TF_ASSIGN_OR_RETURN(
    782       std::vector<std::unique_ptr<Executable>> executables,
    783       BuildExecutables(versioned_handles, std::move(module_configs),
    784                        execute_backend_.get(), all_executors,
    785                        /*device_allocator=*/nullptr));
    786   std::vector<Executable*> executable_ptrs;
    787   executable_ptrs.reserve(executables.size());
    788   for (const auto& executable : executables) {
    789     executable_ptrs.push_back(executable.get());
    790   }
    791 
    792   // Execute the generated executables in parallel and return the device
    793   // handles for each computation's output.
    794   ExecutionProfile profile;
    795   TF_ASSIGN_OR_RETURN(
    796       std::vector<GlobalDataHandle> outputs,
    797       ExecuteParallelAndRegisterResult(executable_ptrs, all_arguments,
    798                                        execute_backend_.get(), device_handles,
    799                                        computation_names, &profile));
    800   for (const GlobalDataHandle& output : outputs) {
    801     ExecuteResponse response;
    802     *response.mutable_output() = output;
    803     *response.mutable_profile() = profile;
    804     *result->add_responses() = response;
    805   }
    806 
    807   VLOG(1) << "successfully completed 'execute-parallel' request";
    808   return tensorflow::Status::OK();
    809 }
    810 
    811 tensorflow::Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg,
    812                                              GetDeviceHandlesResponse* result) {
    813   const int64 available_device_count = execute_backend_->device_count();
    814   const int64 replica_count = options_.number_of_replicas();
    815   if (replica_count <= 0) {
    816     return FailedPrecondition("Replica count must be a positive integer");
    817   }
    818   if (available_device_count < arg->device_count() * replica_count) {
    819     return ResourceExhausted(
    820         "Requested device count (%lld) exceeds the number of available devices "
    821         "on the target (%lld)",
    822         arg->device_count(), available_device_count);
    823   }
    824 
    825   for (int64 i = 0; i < arg->device_count(); ++i) {
    826     DeviceHandle device_handle;
    827     device_handle.set_handle(i);
    828     device_handle.set_device_count(arg->device_count());
    829     *result->add_device_handles() = device_handle;
    830   }
    831 
    832   return tensorflow::Status::OK();
    833 }
    834 
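        // Handles an Execute RPC: resolves the computation and its arguments,
        // compiles it (or fetches it from the compilation cache), runs it on the
        // default device, and registers the result. Delegates to ExecuteParallel
        // when more than one device handle is requested.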
    835 tensorflow::Status Service::Execute(const ExecuteRequest* arg,
    836                                     ExecuteResponse* result) {
    837   VLOG(1) << "running execute request: " << arg->ShortDebugString();
    838 
    839   TF_ASSIGN_OR_RETURN(UserComputation * user_computation,
    840                       computation_tracker_.Resolve(arg->computation()));
    841 
    842   VersionedComputationHandle versioned_handle =
    843       user_computation->GetVersionedHandle();
    844 
    845   if (user_computation->request_count(versioned_handle.version) == 0) {
    846     return InvalidArgument("computations may not be empty");
    847   }
    848 
    849   // If we received multiple device handles, we must partition the module.
    850   if (arg->execution_options().device_handles_size() > 1) {
    851     ExecuteParallelRequest parallel_arg;
    852     *parallel_arg.add_requests() = *arg;
    853     ExecuteParallelResponse parallel_result;
    854     TF_RETURN_IF_ERROR(ExecuteParallel(&parallel_arg, &parallel_result));
    855     TF_RET_CHECK(parallel_result.responses_size() > 0);
    856     *result = parallel_result.responses(0);
    857     return Status::OK();
    858   }
    859 
    860   TF_ASSIGN_OR_RETURN(
    861       std::shared_ptr<const ProgramShape> program_shape,
    862       user_computation->ComputeProgramShape(versioned_handle.version));
    863 
    864   TF_ASSIGN_OR_RETURN(
    865       std::vector<const ShapedBuffer*> arguments,
    866       ResolveAndValidateArguments(arg->arguments(),
    867                                   execute_backend_->default_device_ordinal()));
    868 
    869   TF_ASSIGN_OR_RETURN(
    870       std::unique_ptr<HloModuleConfig> module_config,
    871       CreateModuleConfig(*program_shape, arguments, arg->execution_options(),
    872                          *user_computation));
    873 
    874   VLOG(3) << "Execute created HloModuleConfig computation layout: "
    875           << module_config->entry_computation_layout().ToString();
    876 
    877   TF_ASSIGN_OR_RETURN(
    878       std::shared_ptr<Executable> executable,
    879       BuildAndCacheExecutable(versioned_handle, std::move(module_config),
    880                               execute_backend_.get(),
    881                               execute_backend_->default_stream_executor(),
    882                               result->mutable_profile()));
    883 
    884   if (executable->dumping()) {
    885     executable->session_module()->set_execution_platform(
    886         execute_backend_->platform()->Name());
    887     TF_RETURN_IF_ERROR(RecordArguments(
    888         arguments, execute_backend_->default_stream_executor(),
    889         execute_backend_->transfer_manager(), executable->session_module()));
    890   }
    891 
    892   TF_ASSIGN_OR_RETURN(
    893       *result->mutable_output(),
    894       ExecuteAndRegisterResult(
    895           executable.get(), arguments, execute_backend_.get(),
    896           execute_backend_->default_stream_executor(),
    897           "result of " + user_computation->name(), result->mutable_profile()));
    898 
    899   if (executable->dumping()) {
    900     TF_ASSIGN_OR_RETURN(const ShapedBuffer* result_buffer,
    901                         allocation_tracker_.Resolve(result->output()));
    902     TF_RETURN_IF_ERROR(RecordResult(
    903         *result_buffer, execute_backend_->default_stream_executor(),
    904         execute_backend_->transfer_manager(), executable->session_module()));
    905     TF_RETURN_IF_ERROR(executable->DumpSessionModule());
    906   }
    907 
    908   VLOG(1) << "successfully completed 'execute' request";
    909   return tensorflow::Status::OK();
    910 }
    911 
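        // Handles an ExecuteAsync RPC: like Execute, but launches the computation
        // asynchronously on borrowed streams and returns an execution handle that
        // can later be waited on with WaitForExecution.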
    912 tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg,
    913                                          ExecuteAsyncResponse* result) {
    914   VLOG(1) << "running execute-async request: " << arg->ShortDebugString();
    915 
    916   TF_ASSIGN_OR_RETURN(UserComputation * user_computation,
    917                       computation_tracker_.Resolve(arg->computation()));
    918 
    919   VersionedComputationHandle versioned_handle =
    920       user_computation->GetVersionedHandle();
    921   if (user_computation->request_count(versioned_handle.version) == 0) {
    922     return InvalidArgument("computations may not be empty");
    923   }
    924 
    925   TF_ASSIGN_OR_RETURN(
    926       std::shared_ptr<const ProgramShape> program_shape,
    927       user_computation->ComputeProgramShape(versioned_handle.version));
    928 
    929   TF_ASSIGN_OR_RETURN(
    930       std::vector<const ShapedBuffer*> arguments,
    931       ResolveAndValidateArguments(arg->arguments(),
    932                                   execute_backend_->default_device_ordinal()));
    933 
    934   TF_ASSIGN_OR_RETURN(
    935       std::unique_ptr<HloModuleConfig> module_config,
    936       CreateModuleConfig(*program_shape, arguments, arg->execution_options(),
    937                          *user_computation));
    938 
    939   VLOG(3) << "ExecuteAsync created HloModuleConfig computation layout: "
    940           << module_config->entry_computation_layout().ToString();
    941 
    942   ExecutionProfile profile;
    943 
    944   TF_ASSIGN_OR_RETURN(
    945       std::shared_ptr<Executable> executable,
    946       BuildAndCacheExecutable(
    947           versioned_handle, std::move(module_config), execute_backend_.get(),
    948           execute_backend_->default_stream_executor(), &profile));
    949 
    950   TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_,
    951                                               SingleComputationDeviceHandle()));
    952   TF_RET_CHECK(!replicas.empty());
    953 
    954   // Set up streams.
    955   std::vector<Pool<se::Stream>::SmartPtr> streams;
    956 
    957   for (se::StreamExecutor* executor : replicas) {
    958     TF_ASSIGN_OR_RETURN(Pool<se::Stream>::SmartPtr stream,
    959                         execute_backend_->BorrowStream(executor));
    960     streams.push_back(std::move(stream));
    961   }
    962 
    963   std::unique_ptr<ShapedBuffer> result_buffer;
    964   for (const Pool<se::Stream>::SmartPtr& stream : streams) {
    965     ExecutableRunOptions options;
    966     options.set_stream(stream.get());
    967     options.set_allocator(execute_backend_->memory_allocator());
    968     options.set_inter_op_thread_pool(execute_backend_->inter_op_thread_pool());
    969     options.set_intra_op_thread_pool(
    970         execute_backend_->eigen_intra_op_thread_pool_device());
    971 
    972     ServiceExecutableRunOptions service_options(
    973         options, execute_backend_->StreamBorrower());
    974 
    975     TF_ASSIGN_OR_RETURN(
    976         std::unique_ptr<ShapedBuffer> this_result_buffer,
    977         executable->ExecuteAsyncOnStream(&service_options, arguments));
    978 
    979     // Take the first result.
    980     if (result_buffer == nullptr) {
    981       result_buffer = std::move(this_result_buffer);
    982     }
    983   }
    984 
    985   TF_ASSIGN_OR_RETURN(
    986       GlobalDataHandle output,
    987       allocation_tracker_.Register(std::move(result_buffer),
    988                                    "result of " + user_computation->name()));
    989 
    990   *result->mutable_execution() = execution_tracker_.Register(
    991       execute_backend_.get(), std::move(streams), profile, output);
    992   streams.clear();
    993 
    994   VLOG(1) << "successfully completed 'execute-async' request";
    995   return tensorflow::Status::OK();
    996 }
    997 
    998 tensorflow::Status Service::WaitForExecution(const WaitForExecutionRequest* arg,
    999                                              WaitForExecutionResponse* result) {
   1000   TF_ASSIGN_OR_RETURN(const auto execution,
   1001                       execution_tracker_.Resolve(arg->execution()));
   1002 
   1003   TF_RETURN_IF_ERROR(execution->BlockUntilDone());
   1004 
   1005   *result->mutable_output() = execution->result();
   1006   *result->mutable_profile() = execution->profile();
   1007 
   1008   TF_RETURN_IF_ERROR(execution_tracker_.Unregister(arg->execution()));
   1009   VLOG(1) << "successfully completed 'wait-for-execution' request";
   1010   return tensorflow::Status::OK();
   1011 }
   1012 
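        // Transfers the contents of a device allocation back to the client as a
        // literal, relaying it out to the requested layout if one was given.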
   1013 tensorflow::Status Service::TransferToClient(const TransferToClientRequest* arg,
   1014                                              TransferToClientResponse* result) {
   1015   TF_ASSIGN_OR_RETURN(const ShapedBuffer* shaped_buffer,
   1016                       allocation_tracker_.Resolve(arg->data()));
   1017 
   1018   const Shape* return_shape;
   1019   if (arg->has_shape_with_layout()) {
   1020     if (!LayoutUtil::HasLayout(arg->shape_with_layout())) {
   1021       return InvalidArgument("shape_with_layout must have layout if present.");
   1022     }
   1023     return_shape = &arg->shape_with_layout();
   1024   } else {
   1025     return_shape = &shaped_buffer->on_host_shape();
   1026   }
   1027 
   1028   TF_ASSIGN_OR_RETURN(
   1029       se::StreamExecutor * executor,
   1030       execute_backend_->stream_executor(shaped_buffer->device_ordinal()));
   1031 
   1032   TF_ASSIGN_OR_RETURN(
   1033       std::unique_ptr<Literal> result_literal,
   1034       execute_backend_->transfer_manager()->TransferLiteralFromDevice(
   1035           executor, *shaped_buffer));
   1036 
   1037   if (LayoutUtil::LayoutsInShapesEqual(*return_shape,
   1038                                        result_literal->shape())) {
   1039     *result->mutable_literal() = result_literal->ToProto();
   1040   } else {
   1041     *result->mutable_literal() =
   1042         result_literal->Relayout(*return_shape)->ToProto();
   1043   }
   1044   return tensorflow::Status::OK();
   1045 }
   1046 
   1047 namespace {
   1048 
   1049 // Creates a clone of the given shaped buffer with the given device ordinal. The
   1050 // shape and DeviceMemoryBase values of the clone are identical to the original's.
   1051 std::unique_ptr<ShapedBuffer> CloneShapedBufferOnDevice(
   1052     const ShapedBuffer& shaped_buffer, int device_ordinal) {
   1053   auto clone = MakeUnique<ShapedBuffer>(
   1054       shaped_buffer.on_host_shape(), shaped_buffer.on_device_shape(),
   1055       shaped_buffer.platform(), device_ordinal);
   1056   clone->buffers() = shaped_buffer.buffers();
   1057   return clone;
   1058 }
   1059 
   1060 }  // namespace
   1061 
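        // Transfers a literal from the client into newly allocated device memory.
        // Allocation is performed once on the first replica; transfers to the other
        // replicas use clones of that ShapedBuffer with the replica's device ordinal.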
   1062 tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg,
   1063                                              TransferToServerResponse* result) {
   1064   TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
   1065                       Literal::CreateFromProto(arg->literal()));
   1066   const Shape& shape = literal->shape();
   1067 
   1068   std::vector<se::StreamExecutor*> replicas;
   1069   if (arg->has_device_handle()) {
   1070     TF_ASSIGN_OR_RETURN(replicas,
   1071                         Replicas(*execute_backend_, arg->device_handle()));
   1072   } else {
   1073     TF_ASSIGN_OR_RETURN(
   1074         replicas, Replicas(*execute_backend_, SingleComputationDeviceHandle()));
   1075   }
   1076 
   1077   // All memory allocation is done on the first replica. The allocations in all
   1078   // other replicas mirror the first's.
   1079   int master_device_ordinal = replicas[0]->device_ordinal();
   1080   TF_ASSIGN_OR_RETURN(
   1081       std::unique_ptr<ShapedBuffer> shaped_buffer,
   1082       execute_backend_->transfer_manager()->AllocateShapedBuffer(
   1083           shape, execute_backend_->memory_allocator(), master_device_ordinal));
   1084 
   1085   // Transfer the data to the replicas.
   1086   for (se::StreamExecutor* executor : replicas) {
   1087     if (executor->device_ordinal() == master_device_ordinal) {
   1088       TF_RETURN_IF_ERROR(
   1089           execute_backend_->transfer_manager()->TransferLiteralToDevice(
   1090               executor, *literal, *shaped_buffer));
   1091     } else {
   1092       // The replica is not the master. Create a cloned shaped buffer with
   1093       // the replica's device ordinal. This is required because
   1094       // TransferLiteralToDevice verifies that the device ordinal of the shaped
   1095       // buffer matches that of the executor.
   1096       std::unique_ptr<ShapedBuffer> clone =
   1097           CloneShapedBufferOnDevice(*shaped_buffer, executor->device_ordinal());
   1098       TF_RETURN_IF_ERROR(
   1099           execute_backend_->transfer_manager()->TransferLiteralToDevice(
   1100               executor, *literal, *clone));
   1101     }
   1102   }
   1103   TF_ASSIGN_OR_RETURN(
   1104       *result->mutable_data(),
   1105       allocation_tracker_.Register(std::move(shaped_buffer),
   1106                                    StrCat("TransferToServer literal of shape ",
   1107                                           ShapeUtil::HumanString(shape))));
   1108 
   1109   return tensorflow::Status::OK();
   1110 }
   1111 
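        // Transfers a literal from the client to the infeed of the device
        // corresponding to the given replica id (and optional device handle).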
   1112 tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
   1113                                              TransferToInfeedResponse* result) {
   1114   const int64 replica_count = options_.number_of_replicas();
   1115   if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) {
   1116     return FailedPrecondition(
   1117         "%s",
   1118         StrCat("The replica_id=", arg->replica_id(),
   1119                " on TransferToInfeedRequest not in range [0, replica_count=",
   1120                replica_count, ").")
   1121             .c_str());
   1122   }
   1123 
   1124   se::StreamExecutor* executor;
   1125   if (arg->has_device_handle()) {
   1126     TF_ASSIGN_OR_RETURN(auto replicas,
   1127                         Replicas(*execute_backend_, arg->device_handle()));
   1128     executor = replicas[arg->replica_id()];
   1129   } else {
   1130     TF_ASSIGN_OR_RETURN(
   1131         auto replicas,
   1132         Replicas(*execute_backend_, SingleComputationDeviceHandle()));
   1133     executor = replicas[arg->replica_id()];
   1134   }
   1135 
   1136   TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
   1137                       Literal::CreateFromProto(arg->literal()));
   1138   return execute_backend_->transfer_manager()->TransferLiteralToInfeed(
   1139       executor, *literal);
   1140 }
   1141 
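        // Transfers a literal from the outfeed of the device corresponding to the
        // given replica id (and optional device handle) back to the client.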
   1142 tensorflow::Status Service::TransferFromOutfeed(
   1143     const TransferFromOutfeedRequest* arg,
   1144     TransferFromOutfeedResponse* result) {
   1145   const int64 replica_count = options_.number_of_replicas();
   1146   if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) {
   1147     return FailedPrecondition(
    1148         "The replica_id=%lld on TransferFromOutfeedRequest is not in range [0, "
   1149         "%lld)",
   1150         arg->replica_id(), replica_count);
   1151   }
   1152 
   1153   se::StreamExecutor* executor;
   1154   if (arg->has_device_handle()) {
   1155     TF_ASSIGN_OR_RETURN(auto replicas,
   1156                         Replicas(*execute_backend_, arg->device_handle()));
   1157     executor = replicas[arg->replica_id()];
   1158   } else {
   1159     TF_ASSIGN_OR_RETURN(
   1160         auto replicas,
   1161         Replicas(*execute_backend_, SingleComputationDeviceHandle()));
   1162     executor = replicas[arg->replica_id()];
   1163   }
   1164 
   1165   Literal literal;
   1166   TF_RETURN_IF_ERROR(
   1167       execute_backend_->transfer_manager()->TransferLiteralFromOutfeed(
   1168           executor, arg->shape_with_layout(), &literal));
   1169   *result->mutable_literal() = literal.ToProto();
   1170   return tensorflow::Status::OK();
   1171 }
   1172 
   1173 tensorflow::Status Service::ResetDevice(const ResetDeviceRequest* arg,
   1174                                         ResetDeviceResponse* result) {
   1175   return execute_backend_->ResetDevices();
   1176 }
   1177 
   1178 tensorflow::Status Service::IsConstant(const IsConstantRequest* arg,
   1179                                        IsConstantResponse* result) {
   1180   TF_ASSIGN_OR_RETURN(UserComputation * user_computation,
   1181                       computation_tracker_.Resolve(arg->computation()));
   1182 
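             // Note: the handle is versioned at the point where the operand was added,
             // presumably so that instructions added to the computation afterwards do
             // not affect the answer.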
   1183   VersionedComputationHandle versioned_handle =
   1184       user_computation->GetVersionedHandleAtOperation(arg->operand());
   1185 
   1186   if (user_computation->request_count(versioned_handle.version) == 0) {
   1187     return InvalidArgument("computations may not be empty");
   1188   }
   1189 
   1190   TF_ASSIGN_OR_RETURN(
   1191       bool is_constant,
   1192       user_computation->IsConstant(arg->operand(), arg->num_parameters()));
   1193 
   1194   result->set_is_constant(is_constant);
   1195   return tensorflow::Status::OK();
   1196 }
   1197 
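         // Evaluates the requested operand with the HLO evaluator, provided it does
         // not depend on any parameter; the result is returned as a literal, relaid
         // out if an explicit output layout was requested.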
   1198 tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg,
   1199                                             ComputeConstantResponse* result) {
   1200   TF_ASSIGN_OR_RETURN(UserComputation * user_computation,
   1201                       computation_tracker_.Resolve(arg->computation()));
   1202 
   1203   VersionedComputationHandle versioned_handle =
   1204       user_computation->GetVersionedHandleAtOperation(arg->operand());
   1205 
   1206   if (user_computation->request_count(versioned_handle.version) == 0) {
   1207     return InvalidArgument("computations may not be empty");
   1208   }
   1209 
   1210   TF_ASSIGN_OR_RETURN(
   1211       bool is_constant,
   1212       user_computation->IsConstant(arg->operand(), arg->parameters_size()));
   1213   if (!is_constant) {
   1214     StatusOr<const OperationRequest*> op_request_status =
   1215         user_computation->LookUpRequestForErrorReporting(arg->operand());
   1216     string op_request_string = "<unknown operation>";
   1217     if (op_request_status.ok()) {
   1218       op_request_string = op_request_status.ValueOrDie()->ShortDebugString();
   1219     }
   1220     return InvalidArgument(
   1221         "Operand to ComputeConstant depends on a parameter.\n\n"
   1222         "  op requested for constant evaluation: %s\n\n"
   1223         "This is an internal error that typically happens when the XLA user "
   1224         "(e.g. TensorFlow) is attempting to determine a value that must be a "
   1225         "compile-time constant (e.g. an array dimension) but it is not capable "
   1226         "of being evaluated at XLA compile time.\n\n"
   1227         "Please file a usability bug with the framework being used (e.g. "
   1228         "TensorFlow).",
   1229         op_request_string.c_str());
   1230   }
   1231 
   1232   // We can't use ComputeProgramShape because it checks that all parameter
   1233   // instructions are present and contiguous. Instead construct ProgramShape
   1234   // directly.
   1235   ProgramShape program_shape;
   1236   TF_ASSIGN_OR_RETURN(*program_shape.mutable_result(),
   1237                       user_computation->GetShape(arg->operand()));
   1238 
   1239   TF_DCHECK_OK(ShapeUtil::ValidateShape(program_shape.result()));
   1240 
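             // Evaluate under conservative options: fast math is turned off and
             // implicit broadcasts are made explicit (presumably so the evaluator sees
             // canonical HLO and constant folding stays deterministic).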
   1241   ExecutionOptions execution_options = xla::CreateDefaultExecutionOptions();
   1242   execution_options.mutable_debug_options()->set_xla_enable_fast_math(false);
   1243   execution_options.mutable_debug_options()
   1244       ->set_xla_eliminate_hlo_implicit_broadcast(true);
   1245   *execution_options.mutable_shape_with_output_layout() =
   1246       program_shape.result();
   1247 
   1248   Shape shape_with_output_layout(program_shape.result());
   1249   if (arg->has_output_layout()) {
   1250     TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape(
   1251         arg->output_layout(), execution_options.shape_with_output_layout()));
   1252     *execution_options.mutable_shape_with_output_layout()->mutable_layout() =
   1253         arg->output_layout();
   1254   }
   1255 
   1256   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModuleConfig> module_config,
   1257                       CreateModuleConfig(program_shape, {}, execution_options,
   1258                                          *user_computation));
   1259 
   1260   // Exclude dead parameter instructions for the purpose of computing constants.
   1261   TF_ASSIGN_OR_RETURN(
   1262       std::unique_ptr<HloModule> module,
   1263       computation_tracker_.BuildHloModule(versioned_handle, *module_config,
   1264                                           /*include_unreachable_instructions=*/
   1265                                           false));
   1266 
   1267   std::vector<std::unique_ptr<Literal>> parameters(arg->parameters_size());
   1268   for (int64 i = 0; i < arg->parameters_size(); ++i) {
   1269     TF_ASSIGN_OR_RETURN(parameters[i],
   1270                         Literal::CreateFromProto(arg->parameters(i)));
   1271   }
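             // Interpret the module on the deserialized parameter literals using the
             // HLO evaluator.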
   1272   HloEvaluator evaluator;
   1273   TF_ASSIGN_OR_RETURN(
   1274       auto result_literal,
   1275       evaluator.Evaluate<std::unique_ptr<Literal>>(*module, parameters));
   1276 
    1277   // Since the shape_with_output_layout option in ExecutionOptions has no
    1278   // effect on the evaluator's result, relayout explicitly here.
   1279   if (arg->has_output_layout()) {
   1280     result_literal = result_literal->Relayout(arg->output_layout());
   1281   }
   1282   *result->mutable_literal() = result_literal->ToProto();
   1283 
   1284   return tensorflow::Status::OK();
   1285 }
   1286 
   1287 tensorflow::Status Service::GetShape(const GetShapeRequest* arg,
   1288                                      GetShapeResponse* result) {
   1289   TF_ASSIGN_OR_RETURN(const ShapedBuffer* buffer,
   1290                       allocation_tracker_.Resolve(arg->data()));
   1291   *result->mutable_shape() = buffer->on_host_shape();
   1292   return tensorflow::Status::OK();
   1293 }
   1294 
   1295 tensorflow::Status Service::GetComputationShape(
   1296     const GetComputationShapeRequest* arg,
   1297     GetComputationShapeResponse* result) {
   1298   TF_ASSIGN_OR_RETURN(UserComputation * computation,
   1299                       computation_tracker_.Resolve(arg->computation()));
   1300 
   1301   VersionedComputationHandle versioned_handle =
   1302       computation->GetVersionedHandle();
   1303 
   1304   TF_ASSIGN_OR_RETURN(auto program_shape, computation->ComputeProgramShape(
   1305                                               versioned_handle.version));
   1306   *result->mutable_program_shape() = *program_shape;
   1307   return tensorflow::Status::OK();
   1308 }
   1309 
   1310 tensorflow::Status Service::GetLocalShape(const GetLocalShapeRequest* arg,
   1311                                           GetLocalShapeResponse* result) {
   1312   TF_ASSIGN_OR_RETURN(UserComputation * computation,
   1313                       computation_tracker_.Resolve(arg->computation()));
   1314 
   1315   TF_ASSIGN_OR_RETURN(*result->mutable_shape(),
   1316                       computation->GetShape(arg->operand()));
   1317   return tensorflow::Status::OK();
   1318 }
   1319 
   1320 tensorflow::Status Service::GetComputationStats(
   1321     const ComputationStatsRequest* arg, ComputationStatsResponse* result) {
   1322   TF_ASSIGN_OR_RETURN(UserComputation * user_computation,
   1323                       computation_tracker_.Resolve(arg->computation()));
   1324 
   1325   VersionedComputationHandle versioned_handle =
   1326       user_computation->GetVersionedHandle();
   1327 
   1328   HloModuleConfig config;
   1329   config.set_debug_options(arg->debug_options());
   1330   TF_ASSIGN_OR_RETURN(
   1331       std::unique_ptr<HloModule> module,
   1332       computation_tracker_.BuildHloModule(versioned_handle, config));
   1333 
   1334   hlo_graph_dumper::MaybeDumpHloModule(*module,
   1335                                        "computation statistics subject");
   1336 
   1337   // Run HLO analysis to get the computation statistics.
   1338   HloCostAnalysis analysis(
   1339       execute_backend_->compiler()->ShapeSizeBytesFunction());
   1340 
   1341   TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&analysis));
   1342 
   1343   ComputationStats stats;
   1344   stats.set_flop_count(analysis.flop_count());
   1345   stats.set_transcendental_count(analysis.transcendental_count());
   1346   *result->mutable_stats() = stats;
   1347   return tensorflow::Status::OK();
   1348 }
   1349 
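         // Shared implementation for the simple Add*Instruction handlers: resolves
         // the target UserComputation and forwards it to the supplied `adder`
         // callback, storing the returned ComputationDataHandle in the response.
         // A sketch of a hypothetical call site (not code from this file):
         //   AddInstruction(arg, result, [&](UserComputation* c) {
         //     return c->AddUnaryInstruction(arg->unary_op_request());
         //   });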
   1350 template <typename RequestT, typename ResponseT>
   1351 tensorflow::Status Service::AddInstruction(
   1352     const RequestT* arg, ResponseT* result,
   1353     const std::function<StatusOr<ComputationDataHandle>(UserComputation*)>&
   1354         adder) {
   1355   TF_ASSIGN_OR_RETURN(UserComputation * computation,
   1356                       computation_tracker_.Resolve(arg->computation()));
   1357 
   1358   TF_ASSIGN_OR_RETURN(*result->mutable_output(), adder(computation));
   1359   return tensorflow::Status::OK();
   1360 }
   1361 
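         // Dispatches an OpRequest to the corresponding UserComputation::Add*
         // builder based on which field of the op oneof is set. Trace and Send
         // return early since they do not produce a data handle to record here.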
   1362 tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) {
   1363   TF_ASSIGN_OR_RETURN(UserComputation * computation,
   1364                       computation_tracker_.Resolve(arg->computation()));
   1365   StatusOr<ComputationDataHandle> handle_status;
   1366 
   1367   switch (arg->op_case()) {
   1368     case OpRequest::kBatchNormTrainingRequest:
   1369       handle_status = computation->AddBatchNormTrainingInstruction(
   1370           arg->batch_norm_training_request());
   1371       break;
   1372     case OpRequest::kBatchNormInferenceRequest:
   1373       handle_status = computation->AddBatchNormInferenceInstruction(
   1374           arg->batch_norm_inference_request());
   1375       break;
   1376     case OpRequest::kBatchNormGradRequest:
   1377       handle_status = computation->AddBatchNormGradInstruction(
   1378           arg->batch_norm_grad_request());
   1379       break;
   1380     case OpRequest::kBinaryOpRequest:
   1381       handle_status =
   1382           computation->AddBinaryInstruction(arg->binary_op_request());
   1383       break;
   1384     case OpRequest::kBroadcastRequest:
   1385       handle_status =
   1386           computation->AddBroadcastInstruction(arg->broadcast_request());
   1387       break;
   1388     case OpRequest::kCallRequest: {
   1389       TF_ASSIGN_OR_RETURN(
   1390           UserComputation * to_apply,
   1391           computation_tracker_.Resolve(arg->call_request().to_apply()));
   1392       handle_status =
   1393           computation->AddCallInstruction(arg->call_request(), *to_apply);
   1394       break;
   1395     }
   1396     case OpRequest::kConcatenateRequest:
   1397       handle_status =
   1398           computation->AddConcatenateInstruction(arg->concatenate_request());
   1399       break;
   1400     case OpRequest::kConditionalRequest: {
   1401       TF_ASSIGN_OR_RETURN(UserComputation * true_computation,
   1402                           computation_tracker_.Resolve(
   1403                               arg->conditional_request().true_computation()));
   1404       TF_ASSIGN_OR_RETURN(UserComputation * false_computation,
   1405                           computation_tracker_.Resolve(
   1406                               arg->conditional_request().false_computation()));
   1407       handle_status = computation->AddConditionalInstruction(
   1408           arg->conditional_request(), *true_computation, *false_computation);
   1409       break;
   1410     }
   1411     case OpRequest::kConstantRequest:
   1412       handle_status =
   1413           computation->AddConstantInstruction(arg->constant_request());
   1414       break;
   1415     case OpRequest::kConvertRequest:
   1416       handle_status =
   1417           computation->AddConvertInstruction(arg->convert_request());
   1418       break;
   1419     case OpRequest::kBitcastConvertRequest:
   1420       handle_status = computation->AddBitcastConvertInstruction(
   1421           arg->bitcast_convert_request());
   1422       break;
   1423     case OpRequest::kConvolveRequest:
   1424       handle_status =
   1425           computation->AddConvolveInstruction(arg->convolve_request());
   1426       break;
   1427     case OpRequest::kCrossReplicaSumRequest:
   1428       handle_status = computation->AddCrossReplicaSumInstruction(
   1429           arg->cross_replica_sum_request());
   1430       break;
   1431     case OpRequest::kCustomCallRequest:
   1432       handle_status =
   1433           computation->AddCustomCallInstruction(arg->custom_call_request());
   1434       break;
   1435     case OpRequest::kDotRequest:
   1436       handle_status = computation->AddDotInstruction(arg->dot_request());
   1437       break;
   1438     case OpRequest::kDynamicSliceRequest:
   1439       handle_status =
   1440           computation->AddDynamicSliceInstruction(arg->dynamic_slice_request());
   1441       break;
   1442     case OpRequest::kDynamicUpdateSliceRequest:
   1443       handle_status = computation->AddDynamicUpdateSliceInstruction(
   1444           arg->dynamic_update_slice_request());
   1445       break;
   1446     case OpRequest::kFftRequest:
   1447       handle_status = computation->AddFftInstruction(arg->fft_request());
   1448       break;
   1449     case OpRequest::kGatherRequest:
   1450       handle_status = computation->AddGatherInstruction(arg->gather_request());
   1451       break;
   1452     case OpRequest::kGetTupleElementRequest:
   1453       handle_status = computation->AddGetTupleElementInstruction(
   1454           arg->get_tuple_element_request());
   1455       break;
   1456     case OpRequest::kInfeedRequest:
   1457       handle_status = computation->AddInfeedInstruction(arg->infeed_request());
   1458       break;
   1459     case OpRequest::kOutfeedRequest:
   1460       handle_status =
   1461           computation->AddOutfeedInstruction(arg->outfeed_request());
   1462       break;
   1463     case OpRequest::kHostComputeRequest:
   1464       handle_status =
   1465           computation->AddHostComputeInstruction(arg->host_compute_request());
   1466       break;
   1467     case OpRequest::kMapRequest: {
   1468       TF_ASSIGN_OR_RETURN(
   1469           UserComputation * to_apply,
   1470           computation_tracker_.Resolve(arg->map_request().to_apply()));
   1471       handle_status =
   1472           computation->AddMapInstruction(arg->map_request(), *to_apply);
   1473       break;
   1474     }
   1475     case OpRequest::kPadRequest:
   1476       handle_status = computation->AddPadInstruction(arg->pad_request());
   1477       break;
   1478     case OpRequest::kParameterRequest:
   1479       handle_status =
   1480           computation->AddParameterInstruction(arg->parameter_request());
   1481       break;
   1482     case OpRequest::kReduceRequest: {
   1483       TF_ASSIGN_OR_RETURN(
   1484           UserComputation * to_apply,
   1485           computation_tracker_.Resolve(arg->reduce_request().to_apply()));
   1486       handle_status =
   1487           computation->AddReduceInstruction(arg->reduce_request(), *to_apply);
   1488       break;
   1489     }
   1490     case OpRequest::kReducePrecisionRequest: {
   1491       handle_status = computation->AddReducePrecisionInstruction(
   1492           arg->reduce_precision_request());
   1493       break;
   1494     }
   1495     case OpRequest::kReduceWindowRequest: {
   1496       TF_ASSIGN_OR_RETURN(UserComputation * to_apply,
   1497                           computation_tracker_.Resolve(
   1498                               arg->reduce_window_request().to_apply()));
   1499       handle_status = computation->AddReduceWindowInstruction(
   1500           arg->reduce_window_request(), *to_apply);
   1501       break;
   1502     }
   1503     case OpRequest::kReshapeRequest:
   1504       handle_status =
   1505           computation->AddReshapeInstruction(arg->reshape_request());
   1506       break;
   1507     case OpRequest::kReverseRequest:
   1508       handle_status =
   1509           computation->AddReverseInstruction(arg->reverse_request());
   1510       break;
   1511     case OpRequest::kRngRequest:
   1512       handle_status = computation->AddRngInstruction(arg->rng_request());
   1513       break;
   1514     case OpRequest::kSelectAndScatterRequest: {
   1515       TF_ASSIGN_OR_RETURN(UserComputation * select,
   1516                           computation_tracker_.Resolve(
   1517                               arg->select_and_scatter_request().select()));
   1518       TF_ASSIGN_OR_RETURN(UserComputation * scatter,
   1519                           computation_tracker_.Resolve(
   1520                               arg->select_and_scatter_request().scatter()));
   1521       handle_status = computation->AddSelectAndScatterInstruction(
   1522           arg->select_and_scatter_request(), *select, *scatter);
   1523       break;
   1524     }
   1525     case OpRequest::kSliceRequest:
   1526       handle_status = computation->AddSliceInstruction(arg->slice_request());
   1527       break;
   1528     case OpRequest::kTernaryOpRequest:
   1529       handle_status =
   1530           computation->AddTernaryInstruction(arg->ternary_op_request());
   1531       break;
   1532     case OpRequest::kTraceRequest:
   1533       return computation->AddTraceInstruction(arg->trace_request());
   1534     case OpRequest::kTransposeRequest:
   1535       handle_status =
   1536           computation->AddTransposeInstruction(arg->transpose_request());
   1537       break;
   1538     case OpRequest::kUnaryOpRequest:
   1539       handle_status = computation->AddUnaryInstruction(arg->unary_op_request());
   1540       break;
   1541     case OpRequest::kVariadicOpRequest:
   1542       handle_status =
   1543           computation->AddVariadicInstruction(arg->variadic_op_request());
   1544       break;
   1545     case OpRequest::kWhileRequest: {
   1546       TF_ASSIGN_OR_RETURN(
   1547           UserComputation * condition,
   1548           computation_tracker_.Resolve(arg->while_request().condition()));
   1549       TF_ASSIGN_OR_RETURN(
   1550           UserComputation * body,
   1551           computation_tracker_.Resolve(arg->while_request().body()));
   1552       handle_status = computation->AddWhileInstruction(arg->while_request(),
   1553                                                        *condition, *body);
   1554       break;
   1555     }
   1556     case OpRequest::kSendRequest: {
   1557       TF_RETURN_IF_ERROR(
   1558           channel_tracker_.RegisterSend(arg->send_request().channel_handle()));
   1559       TF_RETURN_IF_ERROR(computation->AddSendInstruction(arg->send_request()));
   1560       return tensorflow::Status::OK();
   1561     }
   1562     case OpRequest::kRecvRequest: {
   1563       TF_RETURN_IF_ERROR(
   1564           channel_tracker_.RegisterRecv(arg->recv_request().channel_handle()));
   1565       handle_status = computation->AddRecvInstruction(arg->recv_request());
   1566       break;
   1567     }
   1568     case OpRequest::OP_NOT_SET:
   1569       return InvalidArgument("XLA service received OpRequest with OP_NOT_SET");
   1570     default:
   1571       return InvalidArgument("Unsupported operation in XLA service");
   1572   }
   1573   TF_ASSIGN_OR_RETURN(*result->mutable_output(), handle_status);
   1574 
   1575   // We set the debug metadata here, because we slice off part of the OpRequest
   1576   // proto in the above switch statement.
   1577   TF_ASSIGN_OR_RETURN(ComputationDataHandle handle, handle_status);
   1578   TF_RETURN_IF_ERROR(computation->SetOpMetadata(handle, arg->metadata()));
   1579   if (arg->has_sharding()) {
   1580     TF_RETURN_IF_ERROR(computation->SetOpSharding(handle, arg->sharding()));
   1581   }
   1582   return tensorflow::Status::OK();
   1583 }
   1584 
   1585 tensorflow::Status Service::SnapshotComputation(
   1586     const SnapshotComputationRequest* arg,
   1587     SnapshotComputationResponse* result) {
   1588   TF_ASSIGN_OR_RETURN(
   1589       std::unique_ptr<SessionModule> module,
   1590       computation_tracker_.SnapshotComputation(arg->computation()));
   1591 
   1592   result->set_allocated_module(module.release());
   1593 
   1594   return tensorflow::Status::OK();
   1595 }
   1596 
   1597 tensorflow::Status Service::LoadComputationSnapshot(
   1598     const LoadComputationSnapshotRequest* arg,
   1599     LoadComputationSnapshotResponse* result) {
   1600   TF_ASSIGN_OR_RETURN(*result->mutable_computation(),
   1601                       computation_tracker_.LoadSessionModule(arg->module()));
   1602   return tensorflow::Status::OK();
   1603 }
   1604 
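         // The device handle used when a request does not specify one: device 0,
         // treated as a single-device computation.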
   1605 DeviceHandle Service::SingleComputationDeviceHandle() const {
   1606   DeviceHandle device_handle;
   1607   device_handle.set_handle(0);
   1608   device_handle.set_device_count(1);
   1609   return device_handle;
   1610 }
   1611 
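         // Returns one StreamExecutor per replica of the given device handle, using
         // the backend's computation placer to map each (replica, handle) pair to a
         // device ordinal.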
   1612 StatusOr<std::vector<perftools::gputools::StreamExecutor*>> Service::Replicas(
   1613     const Backend& backend, const DeviceHandle& device_handle) const {
   1614   std::vector<perftools::gputools::StreamExecutor*> replicas;
   1615   for (int replica = 0; replica < options_.number_of_replicas(); ++replica) {
   1616     // From the computation placer, find out the device ids of the replicas for
   1617     // the given device handle.
   1618     TF_ASSIGN_OR_RETURN(
   1619         int device_ordinal,
   1620         backend.computation_placer()->DeviceId(replica, device_handle.handle(),
   1621                                                options_.number_of_replicas(),
   1622                                                device_handle.device_count()));
   1623     TF_ASSIGN_OR_RETURN(auto executor, backend.stream_executor(device_ordinal));
   1624     replicas.push_back(executor);
   1625   }
   1626   return replicas;
   1627 }
   1628 
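         // Dumps the module's unoptimized HLO proto to the directory named by the
         // xla_dump_unoptimized_hlo_proto_to debug option, if that option is set.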
   1629 Status Service::MaybeDumpHloModule(const HloModule& module) const {
   1630   const string xla_dump_unoptimized_hlo_proto_to =
   1631       module.config().debug_options().xla_dump_unoptimized_hlo_proto_to();
   1632   if (xla_dump_unoptimized_hlo_proto_to.empty()) {
   1633     return Status::OK();
   1634   }
   1635   HloProto proto = MakeHloProto(module);
   1636   return protobuf_util::DumpProtoToDirectory(
   1637       proto, xla_dump_unoptimized_hlo_proto_to, module.name());
   1638 }
   1639 
   1640 }  // namespace xla
   1641