Home | History | Annotate | Download | only in tools
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // Usage: replay_computation some_binary_snapshot_proto*
     17 //
     18 // Replays computations and shows the results on the command line.
     19 //
     20 // some_binary_snapshot_proto is obtained by serializing the SessionModule from
     21 // ServiceInterface::SnapshotComputation to disk.
     22 //
     23 // Computations that require arguments can be replayed using fake data by
     24 // passing --use_fake_data on the command line.  If the real data is available
     25 // in the proto and --use_fake_data is false, the real data is used.
     26 //
     27 // The output format is:
     28 //
     29 // file_path: computation_name :: type:literal_str
     30 
#include <stdio.h>

#include <atomic>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/lib/testing.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/session.pb.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/command_line_flags.h"
     57 
     58 namespace xla {
     59 namespace tools {
     60 namespace {
     61 
// Command-line opts to this tool.  See main() for descriptions of these
// fields.
struct Options {
  // Textual shape of fake data to repeatedly transfer to infeed (parsed by
  // ShapeUtil::ParseShapeString); empty means no infeed thread is started.
  string fake_infeed_shape;
  bool use_fake_data = false;  // Fabricate arguments instead of using the
                               // literals recorded in the snapshot.
  bool print_result = true;    // Transfer the result back and print it.
  int num_runs = 1;            // Number of times to execute each computation.
};
     70 
     71 // Invokes the given computation passing arbitrary data for every (unbound)
     72 // parameter if use_fake_data, Otherwise use recorded data if available.
     73 //
     74 // Similarly, infeeds fake data of shape fake_infeed_shape if it is provided;
     75 // otherwise, no infeed is performed.
     76 StatusOr<std::unique_ptr<Literal>> ReplayComputation(
     77     const SessionModule& module, Client* client, const Options& opts) {
     78   TF_ASSIGN_OR_RETURN(Computation computation, client->LoadSnapshot(module));
     79 
     80   std::vector<std::unique_ptr<GlobalData>> arguments;
     81   if (opts.use_fake_data) {
     82     arguments = MakeFakeArgumentsOrDie(computation, client);
     83   } else {  // use recorded data if available
     84     for (const auto& proto : module.arguments()) {
     85       TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Literal> literal,
     86                           Literal::CreateFromProto(proto));
     87       TF_ASSIGN_OR_RETURN(std::unique_ptr<GlobalData> data,
     88                           client->TransferToServer(*literal));
     89       arguments.push_back(std::move(data));
     90     }
     91   }
     92 
     93   // We only instantiate the thread pool if the user has requested that a
     94   // concurrent infeed occur via the fake_infeed_shape.
     95   tensorflow::gtl::optional<tensorflow::thread::ThreadPool> pool;
     96 
     97   if (!opts.fake_infeed_shape.empty()) {
     98     pool.emplace(tensorflow::Env::Default(), "infeed",
     99                  /*num_threads=*/1);
    100     pool->Schedule([opts, client]() {
    101       StatusOr<Shape> shape_status =
    102           ShapeUtil::ParseShapeString(opts.fake_infeed_shape);
    103       TF_CHECK_OK(shape_status.status());
    104       Shape shape = std::move(shape_status).ValueOrDie();
    105       StatusOr<std::unique_ptr<Literal>> data_status = MakeFakeLiteral(shape);
    106       TF_CHECK_OK(data_status.status());
    107       std::unique_ptr<Literal> data = std::move(data_status).ValueOrDie();
    108       while (true) {
    109         TF_CHECK_OK(client->TransferToInfeed(*data));
    110       }
    111     });
    112   }
    113 
    114   std::vector<GlobalData*> execute_arguments;
    115   execute_arguments.reserve(arguments.size());
    116   for (auto& argument : arguments) {
    117     execute_arguments.push_back(argument.get());
    118   }
    119 
    120   // Run the computation num_runs times, and return the result from the last
    121   // execution.
    122   std::unique_ptr<Literal> result;
    123   for (int i = 0; i < opts.num_runs; ++i) {
    124     ExecutionProfile profile;
    125     if (opts.print_result) {
    126       TF_ASSIGN_OR_RETURN(result, client->ExecuteAndTransfer(
    127                                       computation, execute_arguments,
    128                                       /*execution_options=*/nullptr, &profile));
    129     } else {
    130       // If we're not printing the result, execute the computation but don't
    131       // bother retrieving the result.  This can be a significant speedup.
    132       TF_RETURN_IF_ERROR(client
    133                              ->Execute(computation, execute_arguments,
    134                                        /*execution_options=*/nullptr, &profile)
    135                              .status());
    136     }
    137     LOG(INFO) << "Execution took "
    138               << static_cast<double>(profile.compute_time_ns()) / 1e9 << "s";
    139   }
    140 
    141   return std::move(result);
    142 }
    143 
    144 int RealMain(tensorflow::gtl::ArraySlice<char*> args, const Options& opts) {
    145   Client* client = ClientLibrary::LocalClientOrDie();
    146   tensorflow::Env* env = tensorflow::Env::Default();
    147   int exit_status = EXIT_SUCCESS;
    148   for (char* arg : args) {
    149     SessionModule module;
    150     TF_CHECK_OK(tensorflow::ReadBinaryProto(env, arg, &module));
    151     StatusOr<std::unique_ptr<Literal>> result_status =
    152         ReplayComputation(module, client, opts);
    153     if (!result_status.ok()) {
    154       fprintf(stderr, "%s: error: %s\n", arg,
    155               result_status.status().ToString().c_str());
    156       exit_status = EXIT_FAILURE;
    157       continue;
    158     }
    159 
    160     std::unique_ptr<Literal> result = result_status.ConsumeValueOrDie();
    161     if (result != nullptr) {
    162       fprintf(stdout, "%s: %s :: %s:%s\n", arg, module.entry().name().c_str(),
    163               ShapeUtil::HumanString(result->shape()).c_str(),
    164               result->ToString().c_str());
    165       if (module.has_result()) {
    166         std::unique_ptr<Literal> literal =
    167             Literal::CreateFromProto(module.result()).ConsumeValueOrDie();
    168         fprintf(stdout, "was %s:%s\n",
    169                 ShapeUtil::HumanString(module.result().shape()).c_str(),
    170                 literal->ToString().c_str());
    171       }
    172     }
    173   }
    174 
    175   ClientLibrary::DestroyLocalInstances();
    176   return exit_status;
    177 }
    178 
    179 }  // namespace
    180 }  // namespace tools
    181 }  // namespace xla
    182 
    183 int main(int argc, char** argv) {
    184   xla::tools::Options opts;
    185   const std::vector<tensorflow::Flag> flag_list = {
    186       tensorflow::Flag("use_fake_data", &opts.use_fake_data,
    187                        "Replay computation using fake data"),
    188       tensorflow::Flag("print_result", &opts.print_result,
    189                        "Print the result of the computation to stdout"),
    190       tensorflow::Flag("num_runs", &opts.num_runs,
    191                        "Number of times to run each computation"),
    192       tensorflow::Flag("fake_infeed_shape", &opts.fake_infeed_shape,
    193                        "Shape of fake data to construct for (infinite) infeed"),
    194   };
    195   xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
    196   bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
    197   tensorflow::port::InitMain(argv[0], &argc, &argv);
    198   if (argc < 2 || !parse_ok) {
    199     LOG(QFATAL) << usage;
    200   }
    201 
    202   tensorflow::gtl::ArraySlice<char*> args(argv, argc);
    203   args.pop_front();  // Pop off the binary name, argv[0]
    204   return xla::tools::RealMain(args, opts);
    205 }
    206