Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #define EIGEN_USE_THREADS
     16 
     17 #include "tensorflow/compiler/xla/service/hlo_runner.h"
     18 
     19 #include <string>
     20 #include <utility>
     21 
     22 #include "absl/memory/memory.h"
     23 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     24 #include "tensorflow/compiler/xla/layout_util.h"
     25 #include "tensorflow/compiler/xla/service/hlo_module_group.h"
     26 #include "tensorflow/compiler/xla/service/hlo_parser.h"
     27 #include "tensorflow/compiler/xla/service/transfer_manager.h"
     28 #include "tensorflow/compiler/xla/shape_util.h"
     29 #include "tensorflow/core/common_runtime/eigen_thread_pool.h"
     30 #include "tensorflow/core/lib/core/blocking_counter.h"
     31 #include "tensorflow/core/platform/logging.h"
     32 #include "tensorflow/core/platform/types.h"
     33 
     34 namespace xla {
     35 
     36 /*static*/ StatusOr<std::unique_ptr<HloModule>>
     37 HloRunner::CreateModuleFromString(const absl::string_view hlo_string,
     38                                   const DebugOptions& debug_options) {
     39   HloModuleConfig config;
     40   config.set_debug_options(debug_options);
     41   return ParseHloString(hlo_string, config);
     42 }
     43 
     44 namespace {
     45 
     46 // Creates an HloModule from the given proto.
     47 StatusOr<std::unique_ptr<HloModule>> HloProtoToModule(
     48     const HloProto& proto, const DebugOptions& debug_options) {
     49   TF_ASSIGN_OR_RETURN(HloModuleConfig config,
     50                       HloModule::CreateModuleConfigFromProto(proto.hlo_module(),
     51                                                              debug_options));
     52   TF_ASSIGN_OR_RETURN(auto module,
     53                       HloModule::CreateFromProto(proto.hlo_module(), config));
     54   return std::move(module);
     55 }
     56 
     57 }  // namespace
     58 
     59 /*static*/ StatusOr<std::unique_ptr<HloModule>>
     60 HloRunner::ReadModuleFromBinaryProtoFile(const std::string& filename,
     61                                          const DebugOptions& debug_options) {
     62   HloProto proto;
     63   TF_RETURN_IF_ERROR(tensorflow::ReadBinaryProto(tensorflow::Env::Default(),
     64                                                  filename, &proto));
     65   return HloProtoToModule(proto, debug_options);
     66 }
     67 
     68 /*static*/ StatusOr<std::unique_ptr<HloModule>>
     69 HloRunner::ReadModuleFromTextProtoFile(const std::string& filename,
     70                                        const DebugOptions& debug_options) {
     71   HloProto proto;
     72   TF_RETURN_IF_ERROR(
     73       tensorflow::ReadTextProto(tensorflow::Env::Default(), filename, &proto));
     74   return HloProtoToModule(proto, debug_options);
     75 }
     76 
     77 /*static*/ StatusOr<std::unique_ptr<HloModule>>
     78 HloRunner::ReadModuleFromHloTextFile(const std::string& filename,
     79                                      const DebugOptions& debug_options) {
     80   string hlo_string;
     81   TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(tensorflow::Env::Default(),
     82                                                   filename, &hlo_string));
     83   HloModuleConfig config;
     84   config.set_debug_options(debug_options);
     85   return ParseHloString(hlo_string, config);
     86 }
     87 
     88 HloRunner::HloRunner(se::Platform* platform) {
     89   BackendOptions backend_options;
     90   backend_options.set_platform(platform);
     91   backend_ = Backend::CreateBackend(backend_options).ConsumeValueOrDie();
     92   VLOG(1) << "Created HloRunner for platform: " << platform->Name();
     93 }
     94 
     95 HloRunner::~HloRunner() {}
     96 
     97 StatusOr<ScopedShapedBuffer> HloRunner::TransferLiteralToDevice(
     98     const Literal& literal) {
     99   TF_ASSIGN_OR_RETURN(ScopedShapedBuffer buffer,
    100                       backend().transfer_manager()->AllocateScopedShapedBuffer(
    101                           literal.shape(), backend().memory_allocator(),
    102                           backend().default_device_ordinal()));
    103   TF_ASSIGN_OR_RETURN(
    104       auto stream, backend().BorrowStream(backend().default_stream_executor()));
    105   TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice(
    106       stream.get(), literal, buffer));
    107   return std::move(buffer);
    108 }
    109 
    110 StatusOr<std::vector<ScopedShapedBuffer>> HloRunner::TransferLiteralsToDevice(
    111     const absl::Span<const Literal* const> literals) {
    112   std::vector<ScopedShapedBuffer> buffers;
    113   for (const Literal* literal : literals) {
    114     CHECK(literal != nullptr);
    115     TF_ASSIGN_OR_RETURN(ScopedShapedBuffer buffer,
    116                         TransferLiteralToDevice(*literal));
    117     buffers.push_back(std::move(buffer));
    118   }
    119   return std::move(buffers);
    120 }
    121 
    122 StatusOr<std::vector<ScopedShapedBuffer>> HloRunner::TransferLiteralsToDevice(
    123     const absl::Span<const Literal> literals) {
    124   std::vector<const Literal*> literal_pointers;
    125   literal_pointers.reserve(literals.size());
    126   for (const auto& literal : literals) {
    127     literal_pointers.push_back(&literal);
    128   }
    129   return TransferLiteralsToDevice(literal_pointers);
    130 }
    131 
    132 StatusOr<Literal> HloRunner::TransferLiteralFromDevice(
    133     const ShapedBuffer& buffer) {
    134   TF_ASSIGN_OR_RETURN(
    135       auto stream, backend().BorrowStream(backend().default_stream_executor()));
    136   return backend().transfer_manager()->TransferLiteralFromDevice(stream.get(),
    137                                                                  buffer);
    138 }
    139 
    140 StatusOr<Literal> HloRunner::Execute(
    141     std::unique_ptr<HloModule> module,
    142     const absl::Span<const Literal* const> arguments, bool run_hlo_passes,
    143     ExecutionProfile* profile) {
    144   TF_ASSIGN_OR_RETURN(std::vector<ScopedShapedBuffer> argument_buffers,
    145                       TransferLiteralsToDevice(arguments));
    146   TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
    147                       ExecuteWithDeviceBuffers(
    148                           /*module=*/std::move(module),
    149                           /*arguments=*/argument_buffers,
    150                           /*run_hlo_passes=*/run_hlo_passes,
    151                           /*profile=*/profile));
    152   return TransferLiteralFromDevice(result);
    153 }
    154 
    155 StatusOr<Literal> HloRunner::Execute(std::unique_ptr<HloModule> module,
    156                                      const absl::Span<const Literal> arguments,
    157                                      bool run_hlo_passes,
    158                                      ExecutionProfile* profile) {
    159   // Construct a vector of plain pointers for the arguments.
    160   std::vector<const Literal*> argument_pointers;
    161   argument_pointers.reserve(arguments.size());
    162   for (const auto& argument : arguments) {
    163     argument_pointers.push_back(&argument);
    164   }
    165   return Execute(
    166       /*module=*/std::move(module),
    167       /*arguments=*/argument_pointers,
    168       /*run_hlo_passes=*/run_hlo_passes,
    169       /*profile=*/profile);
    170 }
    171 
    172 StatusOr<Literal> HloRunner::Execute(
    173     std::unique_ptr<Executable> executable,
    174     const absl::Span<const Literal* const> arguments,
    175     ExecutionProfile* profile) {
    176   TF_ASSIGN_OR_RETURN(std::vector<ScopedShapedBuffer> argument_buffers,
    177                       TransferLiteralsToDevice(arguments));
    178   TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
    179                       ExecuteWithDeviceBuffers(
    180                           /*executable=*/executable.get(),
    181                           /*arguments=*/argument_buffers,
    182                           /*profile=*/profile));
    183   return TransferLiteralFromDevice(result);
    184 }
    185 
    186 StatusOr<Literal> HloRunner::Execute(std::unique_ptr<Executable> executable,
    187                                      const absl::Span<const Literal> arguments,
    188                                      ExecutionProfile* profile) {
    189   // Construct a vector of plain pointers for the arguments.
    190   std::vector<const Literal*> argument_pointers;
    191   argument_pointers.reserve(arguments.size());
    192   for (const auto& argument : arguments) {
    193     argument_pointers.push_back(&argument);
    194   }
    195   return Execute(
    196       /*module=*/std::move(executable),
    197       /*arguments=*/argument_pointers,
    198       /*profile=*/profile);
    199 }
    200 
    201 StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
    202     std::unique_ptr<HloModule> module,
    203     const absl::Span<const ShapedBuffer* const> arguments, bool run_hlo_passes,
    204     ExecutionProfile* profile) {
    205   // Get service run options.
    206   se::Stream stream(backend().default_stream_executor());
    207   stream.Init();
    208   ServiceExecutableRunOptions service_run_options =
    209       GetServiceRunOptionsForDevice(backend().default_device_ordinal(), &stream,
    210                                     nullptr);
    211 
    212   TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
    213                       CreateExecutable(std::move(module), run_hlo_passes));
    214   TF_ASSIGN_OR_RETURN(
    215       ScopedShapedBuffer retval,
    216       executable->ExecuteOnStreamWrapper(&service_run_options,
    217                                          /*profile=*/profile, arguments));
    218   TF_RETURN_IF_ERROR(stream.BlockHostUntilDone());
    219   return std::move(retval);
    220 }
    221 
    222 StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
    223     std::unique_ptr<HloModule> module,
    224     const absl::Span<const ScopedShapedBuffer> arguments, bool run_hlo_passes,
    225     ExecutionProfile* profile) {
    226   std::vector<const ShapedBuffer*> argument_pointers;
    227   argument_pointers.reserve(arguments.size());
    228   for (const auto& argument : arguments) {
    229     argument_pointers.push_back(&argument);
    230   }
    231   return ExecuteWithDeviceBuffers(
    232       /*module=*/std::move(module),
    233       /*arguments=*/argument_pointers,
    234       /*run_hlo_passes=*/run_hlo_passes,
    235       /*profile=*/profile);
    236 }
    237 
    238 StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
    239     Executable* executable,
    240     const absl::Span<const ShapedBuffer* const> arguments,
    241     ExecutionProfile* profile) {
    242   // Get service run options.
    243   se::Stream stream(backend().default_stream_executor());
    244   stream.Init();
    245   ServiceExecutableRunOptions service_run_options =
    246       GetServiceRunOptionsForDevice(backend().default_device_ordinal(), &stream,
    247                                     nullptr);
    248 
    249   TF_ASSIGN_OR_RETURN(
    250       ScopedShapedBuffer retval,
    251       executable->ExecuteOnStreamWrapper(&service_run_options,
    252                                          /*profile=*/profile, arguments));
    253   TF_RETURN_IF_ERROR(stream.BlockHostUntilDone());
    254   return std::move(retval);
    255 }
    256 
    257 StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
    258     Executable* executable,
    259     const absl::Span<const ScopedShapedBuffer> arguments,
    260     ExecutionProfile* profile) {
    261   std::vector<const ShapedBuffer*> argument_pointers;
    262   argument_pointers.reserve(arguments.size());
    263   for (const auto& argument : arguments) {
    264     argument_pointers.push_back(&argument);
    265   }
    266   return ExecuteWithDeviceBuffers(
    267       /*executable=*/std::move(executable),
    268       /*arguments=*/argument_pointers,
    269       /*profile=*/profile);
    270 }
    271 
    272 StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
    273     std::unique_ptr<HloModule> module, const ReplicatedExecuteOptions& options,
    274     DeviceAssignment* device_assignment, bool use_threads) {
    275   TF_ASSIGN_OR_RETURN(
    276       std::unique_ptr<Executable> executable,
    277       CreateExecutable(std::move(module), options.run_hlo_passes));
    278   std::vector<std::unique_ptr<se::Stream>> streams;
    279   std::vector<ServiceExecutableRunOptions> service_run_options;
    280 
    281   std::vector<ScopedShapedBuffer> argument_buffers;
    282   // This reserve() call is necessary for correctness, because
    283   // argument_buffer_ptrs contains pointers into the elements of
    284   // argument_buffers.
    285   argument_buffers.reserve(options.num_replicas * options.arguments.size());
    286 
    287   // Plus one so we can safely get &argument_buffer_ptrs[0] in case there are
    288   // no arguments.
    289   std::vector<const ShapedBuffer*> argument_buffer_ptrs(
    290       options.num_replicas * options.arguments.size() + 1);
    291   std::vector<absl::Span<const ShapedBuffer* const>> argument_buffer_slices;
    292   int64 index = 0;
    293   for (int64 i = 0; i < options.num_replicas; ++i) {
    294     int64 device = (*device_assignment)(i, 0);
    295     TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
    296                         backend().stream_executor(device));
    297     streams.push_back(absl::make_unique<se::Stream>(executor));
    298     streams.back()->Init();
    299     service_run_options.emplace_back(GetServiceRunOptionsForDevice(
    300         device, streams.back().get(), device_assignment));
    301 
    302     // Copy arguments to device.
    303     for (const Literal* argument : options.arguments) {
    304       TF_ASSIGN_OR_RETURN(
    305           ScopedShapedBuffer argument_buffer,
    306           backend().transfer_manager()->AllocateScopedShapedBuffer(
    307               argument->shape(), backend().memory_allocator(), device));
    308       TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice(
    309           streams.back().get(), *argument, argument_buffer));
    310       argument_buffers.push_back(std::move(argument_buffer));
    311       argument_buffer_ptrs[index++] = &argument_buffers.back();
    312     }
    313     argument_buffer_slices.emplace_back(
    314         &argument_buffer_ptrs[index - options.arguments.size()],
    315         options.arguments.size());
    316   }
    317 
    318   std::unique_ptr<tensorflow::thread::ThreadPool> pool;
    319   int64 num_threads = (options.infeed != nullptr) ? options.num_replicas : 0;
    320   if (ShapeUtil::IsInitialized(options.outfeed_shape)) {
    321     num_threads += options.num_replicas;
    322   }
    323   if (num_threads > 0) {
    324     pool = absl::make_unique<tensorflow::thread::ThreadPool>(
    325         tensorflow::Env::Default(), "infeed_outfeed",
    326         /*num_threads=*/num_threads);
    327   }
    328   if (options.infeed != nullptr) {
    329     for (int64 i = 0; i < options.num_replicas; ++i) {
    330       int64 device = (*device_assignment)(i, 0);
    331       pool->Schedule([this, device, &options]() {
    332         se::StreamExecutor* executor =
    333             backend().stream_executor(device).ValueOrDie();
    334         VLOG(1) << "Starting infeed on device " << device;
    335         for (int64 step = 1;
    336              options.infeed_steps < 0 || step <= options.infeed_steps; ++step) {
    337           TF_CHECK_OK(backend().transfer_manager()->TransferLiteralToInfeed(
    338               executor, *options.infeed));
    339           if (step % 100 == 0) {
    340             VLOG(1) << "Infeed step " << step;
    341           }
    342         }
    343       });
    344     }
    345   }
    346   if (ShapeUtil::IsInitialized(options.outfeed_shape)) {
    347     for (int64 i = 0; i < options.num_replicas; ++i) {
    348       int64 device = (*device_assignment)(i, 0);
    349       pool->Schedule([this, device, &options]() {
    350         se::StreamExecutor* executor =
    351             backend().stream_executor(device).ValueOrDie();
    352         VLOG(1) << "Starting outfeed on device " << device;
    353         for (int64 step = 1;
    354              options.infeed_steps < 0 || step <= options.infeed_steps; ++step) {
    355           Literal literal;
    356           TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromOutfeed(
    357               executor, options.outfeed_shape, &literal));
    358           if (options.outfeed_values != nullptr) {
    359             options.outfeed_values->push_back(std::move(literal));
    360           }
    361           if (step % 100 == 0) {
    362             VLOG(1) << "Outfeed step " << step;
    363           }
    364         }
    365       });
    366     }
    367   }
    368 
    369   LOG(INFO) << "Replicated execution started";
    370   std::vector<ScopedShapedBuffer> results;
    371   if (!use_threads) {
    372     TF_ASSIGN_OR_RETURN(results,
    373                         executable->ExecuteOnStreams(service_run_options,
    374                                                      argument_buffer_slices));
    375   } else {
    376     tensorflow::mutex mutex;
    377     std::vector<StatusOr<ScopedShapedBuffer>> thread_results(
    378         options.num_replicas);
    379     {
    380       LOG(INFO) << "Creating thread pool for " << options.num_replicas
    381                 << " replicas";
    382       tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(),
    383                                           "replicas", options.num_replicas);
    384       for (int64 i = 0; i < options.num_replicas; ++i) {
    385         pool.Schedule([&, i] {
    386           auto result = executable->ExecuteOnStream(
    387               &service_run_options[i], argument_buffer_slices[i], nullptr);
    388           tensorflow::mutex_lock lock(mutex);
    389           thread_results[i] = std::move(result);
    390         });
    391       }
    392 
    393       // Note: the thread pool destructor guarantees it completes all work
    394       // before we leave this scope.
    395     }
    396     for (auto& thread_result : thread_results) {
    397       if (!thread_result.ok()) {
    398         return thread_result.status();
    399       }
    400       results.push_back(std::move(thread_result).ValueOrDie());
    401     }
    402   }
    403   LOG(INFO) << "Replicated execution terminated";
    404 
    405   std::vector<Literal> exec_results;
    406   for (int64 i = 0; i < options.num_replicas; ++i) {
    407     TF_RETURN_IF_ERROR(streams[i]->BlockHostUntilDone());
    408     TF_ASSIGN_OR_RETURN(Literal literal,
    409                         backend().transfer_manager()->TransferLiteralFromDevice(
    410                             streams[i].get(), results[i]));
    411     exec_results.push_back(std::move(literal));
    412   }
    413   return std::move(exec_results);
    414 }
    415 
    416 StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
    417     std::unique_ptr<HloModule> module, const ReplicatedExecuteOptions& options,
    418     bool use_threads) {
    419   TF_ASSIGN_OR_RETURN(
    420       DeviceAssignment device_assignment,
    421       backend().computation_placer()->AssignDevices(options.num_replicas, 1));
    422   return ExecuteReplicated(std::move(module), options, &device_assignment,
    423                            use_threads);
    424 }
    425 
    426 StatusOr<std::unique_ptr<Executable>> HloRunner::CreateExecutable(
    427     std::unique_ptr<HloModule> module, bool run_hlo_passes) {
    428   if (run_hlo_passes) {
    429     auto module_group = absl::make_unique<HloModuleGroup>(std::move(module));
    430     TF_ASSIGN_OR_RETURN(
    431         auto executables,
    432         backend().compiler()->Compile(std::move(module_group),
    433                                       {{backend().default_stream_executor()}},
    434                                       backend().memory_allocator()));
    435     return std::move(executables[0]);
    436   }
    437   return backend().compiler()->RunBackend(std::move(module),
    438                                           backend().default_stream_executor(),
    439                                           backend().memory_allocator());
    440 }
    441 
    442 ServiceExecutableRunOptions HloRunner::GetServiceRunOptionsForDevice(
    443     int64 device, se::Stream* stream, DeviceAssignment* device_assignment) {
    444   ExecutableRunOptions run_options;
    445   run_options.set_device_ordinal(device);
    446   run_options.set_stream(stream);
    447   run_options.set_allocator(backend().memory_allocator());
    448   run_options.set_intra_op_thread_pool(
    449       backend().eigen_intra_op_thread_pool_device());
    450   if (device_assignment != nullptr) {
    451     run_options.set_device_assignment(device_assignment);
    452   }
    453   return ServiceExecutableRunOptions(run_options, backend().StreamBorrower());
    454 }
    455 
    456 Backend& HloRunner::backend() {
    457   if (!backend_) {
    458     backend_ = Backend::CreateDefaultBackend().ConsumeValueOrDie();
    459     VLOG(1) << "Executing on platform " << backend().platform()->Name();
    460   }
    461   return *backend_;
    462 }
    463 
    464 const Backend& HloRunner::backend() const {
    465   return const_cast<HloRunner*>(this)->backend();
    466 }
    467 
    468 }  // namespace xla
    469