/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS

#include "tensorflow/compiler/xla/service/hlo_runner.h"

#include <set>
#include <string>
#include <utility>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace se = ::perftools::gputools;

namespace xla {

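// Parses the given HLO text into an HloModule, applying `debug_options` to
// the module configuration.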
/*static*/ StatusOr<std::unique_ptr<HloModule>>
HloRunner::CreateModuleFromString(const tensorflow::StringPiece hlo_string,
                                  const DebugOptions& debug_options) {
  HloModuleConfig config;
  config.set_debug_options(debug_options);
  return tools::Parse(hlo_string, config);
}

namespace {

// Creates an HloModule from the given proto.
StatusOr<std::unique_ptr<HloModule>> HloProtoToModule(
    const HloProto& proto, const DebugOptions& debug_options) {
  TF_ASSIGN_OR_RETURN(
      HloModuleConfig config,
      HloModule::CreateModuleConfigFromProto(proto.hlo_module()));
  config.set_debug_options(debug_options);
  TF_ASSIGN_OR_RETURN(auto module,
                      HloModule::CreateFromProto(proto.hlo_module(), config));
  return std::move(module);
}

}  // namespace

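// Reads a binary-format HloProto from `filename` and converts it into an
// HloModule.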
/*static*/ StatusOr<std::unique_ptr<HloModule>>
HloRunner::ReadModuleFromBinaryProtoFile(const std::string& filename,
                                         const DebugOptions& debug_options) {
  HloProto proto;
  TF_RETURN_IF_ERROR(tensorflow::ReadBinaryProto(tensorflow::Env::Default(),
                                                 filename, &proto));
  return HloProtoToModule(proto, debug_options);
}

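// Reads a text-format HloProto from `filename` and converts it into an
// HloModule.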
/*static*/ StatusOr<std::unique_ptr<HloModule>>
HloRunner::ReadModuleFromTextProtoFile(const std::string& filename,
                                       const DebugOptions& debug_options) {
  HloProto proto;
  TF_RETURN_IF_ERROR(
      tensorflow::ReadTextProto(tensorflow::Env::Default(), filename, &proto));
  return HloProtoToModule(proto, debug_options);
}

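// Reads HLO text from `filename` and parses it into an HloModule.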
/*static*/ StatusOr<std::unique_ptr<HloModule>>
HloRunner::ReadModuleFromHloTextFile(const std::string& filename,
                                     const DebugOptions& debug_options) {
  string hlo_string;
  TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(tensorflow::Env::Default(),
                                                  filename, &hlo_string));
  HloModuleConfig config;
  config.set_debug_options(debug_options);
  return tools::Parse(hlo_string, config);
}

// Define this in the .cc file to avoid having to include Eigen or forward
// declare these types in the header.
struct HloRunner::EigenThreadPoolWrapper {
  std::unique_ptr<EigenThreadPoolWrapper> pool;
  std::unique_ptr<Eigen::ThreadPoolDevice> device;
};

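// The default constructor leaves backend_ unset; the backend is created
// lazily on the first call to backend().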
HloRunner::HloRunner() {}

HloRunner::HloRunner(se::Platform* platform) {
  BackendOptions backend_options;
  backend_options.set_platform(platform);
  backend_ = Backend::CreateBackend(backend_options).ConsumeValueOrDie();
  VLOG(1) << "Created HloRunner for platform: " << platform->Name();
}

HloRunner::~HloRunner() {}

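// Compiles the module (optionally running HLO passes first), transfers the
// argument literals to the device, runs the resulting executable on the
// default stream, and transfers the result back to the host as a Literal.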
StatusOr<std::unique_ptr<Literal>> HloRunner::ExecuteInternal(
    std::unique_ptr<HloModule> module,
    const tensorflow::gtl::ArraySlice<Literal*> arguments,
    bool run_hlo_passes) {
  if (run_hlo_passes) {
    TF_ASSIGN_OR_RETURN(
        module, backend().compiler()->RunHloPasses(
                    std::move(module), backend().default_stream_executor(),
                    /*device_allocator=*/nullptr));
  }
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<Executable> executable,
      backend().compiler()->RunBackend(std::move(module),
                                       backend().default_stream_executor(),
                                       /*device_allocator=*/nullptr));

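  // Create and initialize a stream on the default executor for running the
  // compiled executable.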
  se::Stream stream(backend().default_stream_executor());
  stream.Init();

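  // Configure the run options with the device ordinal, stream, allocator, and
  // thread pools that the executable will use.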
  ExecutableRunOptions run_options;
  run_options.set_device_ordinal(backend().default_device_ordinal());
  run_options.set_stream(&stream);
  run_options.set_allocator(backend().memory_allocator());
  run_options.set_inter_op_thread_pool(backend().inter_op_thread_pool());
  run_options.set_intra_op_thread_pool(
      backend().eigen_intra_op_thread_pool_device());

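  // Bundle the run options with a stream borrower and the inter-op thread
  // pool for the service-level execution path.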
  ServiceExecutableRunOptions service_run_options(
      run_options, backend().StreamBorrower(),
      backend().inter_op_thread_pool());

  // Copy arguments to device.
  std::vector<std::unique_ptr<ScopedShapedBuffer>> argument_buffers;
  std::vector<ShapedBuffer*> argument_buffer_ptrs;
  for (Literal* argument : arguments) {
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<ScopedShapedBuffer> argument_buffer,
        backend().transfer_manager()->AllocateScopedShapedBuffer(
            argument->shape(), run_options.allocator(),
            run_options.device_ordinal()));
    TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice(
        stream.parent(), *argument, *argument_buffer));
    argument_buffers.push_back(std::move(argument_buffer));
    argument_buffer_ptrs.push_back(argument_buffers.back().get());
  }

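  // Execute the compiled module on the stream, passing the device buffers
  // populated above as arguments.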
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<ShapedBuffer> result,
      executable->ExecuteOnStream(&service_run_options, argument_buffer_ptrs,
                                  /*hlo_execution_profile=*/nullptr));

  // Create a ScopedShapedBuffer of the result to manage deallocation. This
  // will deallocate all the device memory when it goes out of scope.
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<ScopedShapedBuffer> scoped_result,
      ScopedShapedBuffer::MakeScoped(result.get(), run_options.allocator()));

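  // Copy the result back to the host and log the outcome at VLOG level 4.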
  auto result_literal = backend().transfer_manager()->TransferLiteralFromDevice(
      stream.parent(), *scoped_result);
  if (result_literal.ok()) {
    VLOG(4) << "Executed binary and got result: "
            << result_literal.ValueOrDie()->ToString();
  } else {
    VLOG(4) << "Executed binary and got status: "
            << result_literal.status().ToString();
  }
  return result_literal;
}

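// Returns the backend, creating the default backend on first use if no
// platform-specific backend was set up by the constructor.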
Backend& HloRunner::backend() {
  if (!backend_) {
    backend_ = Backend::CreateDefaultBackend().ConsumeValueOrDie();
    VLOG(1) << "executing on platform " << backend().platform()->Name();
  }
  return *backend_;
}

}  // namespace xla