1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // Usage: replay_computation some_binary_snapshot_proto* 17 // 18 // Replays computations and shows the results on the command line. 19 // 20 // some_binary_snapshot_proto is obtained by serializing the SessionModule from 21 // ServiceInterface::SnapshotComputation to disk. 22 // 23 // Computations that require arguments can be replayed using fake data by 24 // passing --use_fake_data on the command line. If the real data is available 25 // in the proto and --use_fake_data is false, the real data is used. 
//
// The output format is:
//
// file_path: computation_name :: type:literal_str

#include <stdio.h>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/lib/testing.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/session.pb.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/command_line_flags.h"

namespace xla {
namespace tools {
namespace {

// Command-line opts to this tool. See main() for descriptions of these
// fields.
struct Options {
  // Shape string (e.g. "f32[2,3]") of fake data to endlessly infeed while the
  // computation runs; empty means no infeed is performed.
  string fake_infeed_shape;
  // When true, synthesize fake arguments instead of using any literals
  // recorded in the snapshot proto.
  bool use_fake_data = false;
  // When true, transfer the final result back and print it to stdout.
  bool print_result = true;
  // Number of times to execute the computation; only the last result is kept.
  int num_runs = 1;
};

// Invokes the given computation, passing arbitrary data for every (unbound)
// parameter if use_fake_data is set; otherwise uses the recorded data, if
// available.
//
// Similarly, infeeds fake data of shape fake_infeed_shape if it is provided;
// otherwise, no infeed is performed.
// Replays the computation recorded in `module` against `client`.
//
// Arguments: when opts.use_fake_data is true, fake argument literals are
// synthesized for every computation parameter; otherwise every argument
// literal recorded in the snapshot is transferred to the server and used.
//
// Infeed: when opts.fake_infeed_shape is nonempty, a single-threaded pool is
// spun up whose worker repeatedly transfers a fake literal of that shape to
// the infeed for as long as the computation runs.
//
// Returns the literal produced by the last of opts.num_runs executions, or a
// null unique_ptr when opts.print_result is false (results are never
// transferred back in that mode).
StatusOr<std::unique_ptr<Literal>> ReplayComputation(
    const SessionModule& module, Client* client, const Options& opts) {
  // Deserialize the snapshot into an executable Computation handle.
  TF_ASSIGN_OR_RETURN(Computation computation, client->LoadSnapshot(module));

  std::vector<std::unique_ptr<GlobalData>> arguments;
  if (opts.use_fake_data) {
    arguments = MakeFakeArgumentsOrDie(computation, client);
  } else {  // use recorded data if available
    for (const auto& proto : module.arguments()) {
      TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Literal> literal,
                          Literal::CreateFromProto(proto));
      // Upload each recorded argument; GlobalData keeps it alive server-side.
      TF_ASSIGN_OR_RETURN(std::unique_ptr<GlobalData> data,
                          client->TransferToServer(*literal));
      arguments.push_back(std::move(data));
    }
  }

  // We only instantiate the thread pool if the user has requested that a
  // concurrent infeed occur via the fake_infeed_shape.
  tensorflow::gtl::optional<tensorflow::thread::ThreadPool> pool;

  if (!opts.fake_infeed_shape.empty()) {
    pool.emplace(tensorflow::Env::Default(), "infeed",
                 /*num_threads=*/1);
    // NOTE(review): this worker never exits its loop; presumably process
    // teardown at exit reaps it — confirm the ThreadPool destructor does not
    // block waiting for it.
    pool->Schedule([opts, client]() {
      // Parse the shape and build one fake literal up front, then feed it
      // forever. Any failure here is fatal (TF_CHECK_OK).
      StatusOr<Shape> shape_status =
          ShapeUtil::ParseShapeString(opts.fake_infeed_shape);
      TF_CHECK_OK(shape_status.status());
      Shape shape = std::move(shape_status).ValueOrDie();
      StatusOr<std::unique_ptr<Literal>> data_status = MakeFakeLiteral(shape);
      TF_CHECK_OK(data_status.status());
      std::unique_ptr<Literal> data = std::move(data_status).ValueOrDie();
      while (true) {
        TF_CHECK_OK(client->TransferToInfeed(*data));
      }
    });
  }

  // Flatten the owned arguments into the raw-pointer view Execute() expects;
  // `arguments` retains ownership for the duration of the runs.
  std::vector<GlobalData*> execute_arguments;
  execute_arguments.reserve(arguments.size());
  for (auto& argument : arguments) {
    execute_arguments.push_back(argument.get());
  }

  // Run the computation num_runs times, and return the result from the last
  // execution.
  std::unique_ptr<Literal> result;
  for (int i = 0; i < opts.num_runs; ++i) {
    ExecutionProfile profile;
    if (opts.print_result) {
      TF_ASSIGN_OR_RETURN(result, client->ExecuteAndTransfer(
                                      computation, execute_arguments,
                                      /*execution_options=*/nullptr, &profile));
    } else {
      // If we're not printing the result, execute the computation but don't
      // bother retrieving the result. This can be a significant speedup.
      TF_RETURN_IF_ERROR(client
                             ->Execute(computation, execute_arguments,
                                       /*execution_options=*/nullptr, &profile)
                             .status());
    }
    // compute_time_ns is reported by the service; convert ns -> seconds.
    LOG(INFO) << "Execution took "
              << static_cast<double>(profile.compute_time_ns()) / 1e9 << "s";
  }

  return std::move(result);
}

// Replays every snapshot file named in `args`, printing each result (and,
// when the snapshot recorded one, the original result for comparison) to
// stdout. Failures are reported on stderr but do not stop processing of the
// remaining files. Returns EXIT_SUCCESS iff every replay succeeded.
int RealMain(tensorflow::gtl::ArraySlice<char*> args, const Options& opts) {
  Client* client = ClientLibrary::LocalClientOrDie();
  tensorflow::Env* env = tensorflow::Env::Default();
  int exit_status = EXIT_SUCCESS;
  for (char* arg : args) {
    SessionModule module;
    // A file we cannot even read/parse is fatal (TF_CHECK_OK), unlike a
    // replay failure below, which is reported and skipped.
    TF_CHECK_OK(tensorflow::ReadBinaryProto(env, arg, &module));
    StatusOr<std::unique_ptr<Literal>> result_status =
        ReplayComputation(module, client, opts);
    if (!result_status.ok()) {
      fprintf(stderr, "%s: error: %s\n", arg,
              result_status.status().ToString().c_str());
      exit_status = EXIT_FAILURE;
      continue;
    }

    std::unique_ptr<Literal> result = result_status.ConsumeValueOrDie();
    // result is null when opts.print_result was false; nothing to print then.
    if (result != nullptr) {
      fprintf(stdout, "%s: %s :: %s:%s\n", arg, module.entry().name().c_str(),
              ShapeUtil::HumanString(result->shape()).c_str(),
              result->ToString().c_str());
      if (module.has_result()) {
        // The snapshot recorded the result seen at capture time; print it so
        // the user can compare it against the replayed value above.
        std::unique_ptr<Literal> literal =
            Literal::CreateFromProto(module.result()).ConsumeValueOrDie();
        fprintf(stdout, "was %s:%s\n",
                ShapeUtil::HumanString(module.result().shape()).c_str(),
                literal->ToString().c_str());
      }
    }
  }

  ClientLibrary::DestroyLocalInstances();
  return exit_status;
}

179 } // namespace 180 } // namespace tools 181 } // namespace xla 182 183 int main(int argc, char** argv) { 184 xla::tools::Options opts; 185 const std::vector<tensorflow::Flag> flag_list = { 186 tensorflow::Flag("use_fake_data", &opts.use_fake_data, 187 "Replay computation using fake data"), 188 tensorflow::Flag("print_result", &opts.print_result, 189 "Print the result of the computation to stdout"), 190 tensorflow::Flag("num_runs", &opts.num_runs, 191 "Number of times to run each computation"), 192 tensorflow::Flag("fake_infeed_shape", &opts.fake_infeed_shape, 193 "Shape of fake data to construct for (infinite) infeed"), 194 }; 195 xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); 196 bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); 197 tensorflow::port::InitMain(argv[0], &argc, &argv); 198 if (argc < 2 || !parse_ok) { 199 LOG(QFATAL) << usage; 200 } 201 202 tensorflow::gtl::ArraySlice<char*> args(argv, argc); 203 args.pop_front(); // Pop off the binary name, argv[0] 204 return xla::tools::RealMain(args, opts); 205 } 206