1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // Usage: replay_computation some_binary_snapshot_proto* 17 // 18 // Replays computations and shows the results on the command line. 19 // 20 // some_binary_snapshot_proto is obtained by serializing the HloSnapshot from 21 // ServiceInterface::SnapshotComputation to disk. 22 // 23 // Computations that require arguments can be replayed using fake data by 24 // passing --use_fake_data on the command line. If the real data is available 25 // in the proto and --use_fake_data is false, the real data is used. 26 // 27 // Input can be a binary HloSnapshot proto, a binary HloProto proto, or a 28 // textual HLO string. 29 // 30 // The output format is: 31 // 32 // file_path: computation_name :: type:literal_str 33 // 34 // Note: If you pass multiple modules, they will be compiled in parallel but run 35 // in series. 36 37 #include <stdio.h> 38 #include <memory> 39 #include <string> 40 #include <utility> 41 #include <vector> 42 43 #include "absl/types/span.h" 44 #include "tensorflow/compiler/xla/client/client.h" 45 #include "tensorflow/compiler/xla/client/client_library.h" 46 #include "tensorflow/compiler/xla/client/global_data.h" 47 #include "tensorflow/compiler/xla/client/lib/testing.h" 48 #include "tensorflow/compiler/xla/client/local_client.h" 49 #include "tensorflow/compiler/xla/client/xla_computation.h" 50 #include "tensorflow/compiler/xla/debug_options_flags.h" 51 #include "tensorflow/compiler/xla/execution_options_util.h" 52 #include "tensorflow/compiler/xla/literal.h" 53 #include "tensorflow/compiler/xla/service/gpu/infeed_manager.h" 54 #include "tensorflow/compiler/xla/service/gpu/outfeed_manager.h" 55 #include "tensorflow/compiler/xla/service/hlo.pb.h" 56 #include "tensorflow/compiler/xla/service/hlo_parser.h" 57 #include "tensorflow/compiler/xla/shape_util.h" 58 #include "tensorflow/compiler/xla/status_macros.h" 59 #include "tensorflow/compiler/xla/statusor.h" 60 #include "tensorflow/compiler/xla/tests/test_utils.h" 61 #include "tensorflow/compiler/xla/types.h" 62 #include "tensorflow/compiler/xla/xla_data.pb.h" 63 #include "tensorflow/core/lib/core/threadpool.h" 64 #include "tensorflow/core/platform/env.h" 65 #include "tensorflow/core/platform/init_main.h" 66 #include "tensorflow/core/platform/logging.h" 67 #include "tensorflow/core/util/command_line_flags.h" 68 69 namespace xla { 70 namespace tools { 71 namespace { 72 73 // Command-line opts to this tool. See main() for descriptions of these 74 // fields. 75 struct Options { 76 string fake_infeed_shape; 77 string fake_outfeed_shape; 78 79 // generate_fake_infeed == true is a safe default: If the model has 0 or 1 80 // infeeds, then it will work like normal. If the model has more than one 81 // infeed, it will be an error, but that wouldn't have worked anyway if you 82 // hadn't passed generate_fake_infeed. 83 // 84 // Same for generate_fake_outfeed. 85 bool generate_fake_infeed = true; 86 bool generate_fake_outfeed = true; 87 88 bool use_fake_data = false; 89 bool print_result = true; 90 int num_runs = 1; 91 }; 92 93 StatusOr<std::unique_ptr<LocalExecutable>> CompileExecutable( 94 const HloSnapshot& module, LocalClient* client) { 95 XlaComputation computation(module.hlo().hlo_module()); 96 std::vector<Shape> argument_layouts; 97 argument_layouts.reserve( 98 computation.proto().host_program_shape().parameters_size()); 99 std::vector<const Shape*> argument_layout_ptrs; 100 for (const ShapeProto& param : 101 computation.proto().host_program_shape().parameters()) { 102 argument_layouts.push_back(Shape(param)); 103 argument_layout_ptrs.push_back(&argument_layouts.back()); 104 } 105 ExecutableBuildOptions exec_build_options; 106 *exec_build_options.mutable_debug_options() = GetDebugOptionsFromFlags(); 107 return client->Compile(computation, argument_layout_ptrs, exec_build_options); 108 } 109 110 absl::optional<Shape> GetXfeedShape(bool is_infeed, 111 const HloModuleProto& module, 112 const Options& opts) { 113 std::vector<HloInstructionProto> xfeed_instrs; 114 for (const auto& comp : module.computations()) { 115 for (const auto& instruction : comp.instructions()) { 116 if (instruction.opcode() == HloOpcodeString(is_infeed 117 ? HloOpcode::kInfeed 118 : HloOpcode::kOutfeed)) { 119 xfeed_instrs.push_back(instruction); 120 } 121 } 122 } 123 124 auto log_xfeed_instrs = [&] { 125 for (const auto& infeed : xfeed_instrs) { 126 LOG(ERROR) << " " << ShapeUtil::HumanString(Shape(infeed.shape())) << " " 127 << infeed.name(); 128 } 129 }; 130 131 auto find_instruction_from_id_or_die = [&](int64 id) { 132 for (const auto& comp : module.computations()) { 133 for (const auto& instruction : comp.instructions()) { 134 if (instruction.id() == id) { 135 return instruction; 136 } 137 } 138 } 139 LOG(FATAL) << "No instruction with id " << id; 140 }; 141 142 absl::optional<Shape> xfeed_shape; 143 string xfeed_name = is_infeed ? "infeed" : "outfeed"; 144 string fake_xfeed_shape = 145 is_infeed ? opts.fake_infeed_shape : opts.fake_outfeed_shape; 146 bool generate_fake_xfeed = 147 is_infeed ? opts.generate_fake_infeed : opts.generate_fake_outfeed; 148 if (!fake_xfeed_shape.empty()) { 149 xfeed_shape = std::move(ParseShape(fake_xfeed_shape)).ValueOrDie(); 150 } else if (generate_fake_xfeed) { 151 CHECK_LT(xfeed_instrs.size(), 2) 152 << "--generate_fake_" << xfeed_name 153 << " only works if the model has 0 or 1 " << xfeed_name << " ops."; 154 if (xfeed_instrs.empty()) { 155 LOG(INFO) << "Not generating fake " << xfeed_name 156 << " shape; model has no " << xfeed_name << "s."; 157 } else if (xfeed_instrs.size() == 1) { 158 // kInfeed instructions should have a shape (buffer, token). kOutfeed 159 // instructions should have operand 0 of shape `buffer`. We want to xfeed 160 // just `buffer`. 161 xfeed_shape = is_infeed 162 ? Shape(xfeed_instrs.front().shape()).tuple_shapes(0) 163 : Shape(find_instruction_from_id_or_die( 164 xfeed_instrs.front().operand_ids(0)) 165 .shape()); 166 LOG(INFO) << "Generating fake " << xfeed_name << " with inferred shape: " 167 << ShapeUtil::HumanString(*xfeed_shape); 168 } else { 169 LOG(ERROR) << "--generate_fake_" << xfeed_name 170 << " only works if the model has 0 or 1 " << xfeed_name 171 << " ops, but this model has " << xfeed_instrs.size() 172 << " of them:"; 173 log_xfeed_instrs(); 174 LOG(FATAL) << "Can't run model with --generate_fake_infeed."; 175 } 176 } else if (!xfeed_instrs.empty()) { 177 LOG(ERROR) << "Model contains " << xfeed_instrs.size() << " " << xfeed_name 178 << " instruction(s), but neither --generate_fake_" << xfeed_name 179 << " nor --fake_" << xfeed_name 180 << "_shape was specified. Execution will likely hang."; 181 log_xfeed_instrs(); 182 } 183 184 return xfeed_shape; 185 } 186 187 // Invokes the given computation passing arbitrary data for every (unbound) 188 // parameter if use_fake_data, Otherwise use recorded data if available. 189 // 190 // Similarly, infeeds fake data of shape fake_infeed_shape if it is provided. 191 // If generate_fake_infeed is true, the required infeed shape is derived from 192 // the computation and then used to provide a fake infeed shape. 193 // 194 // If neither generate_fake_infeed is true nor a fake_infeed_shape is provided, 195 // no infeed is performed. 196 StatusOr<Literal> ReplayComputation(const HloSnapshot& module, 197 LocalExecutable* executable, 198 LocalClient* client, const Options& opts) { 199 XlaComputation computation(module.hlo().hlo_module()); 200 201 // Build the `argument_ptrs` vector, which contains ShapedBuffer*s to our 202 // arguments. This is a bit involved, because we may have to convert from 203 // GlobalData to ShapedBuffer*, and we have to manage the lifetime of all our 204 // objects. 205 std::vector<ScopedShapedBuffer> scoped_shaped_buffer_arguments; 206 std::vector<std::unique_ptr<GlobalData>> global_data_arguments; 207 std::vector<const ShapedBuffer*> argument_ptrs; 208 if (opts.use_fake_data) { 209 // Run fake computations with debug options ignoring XLA_FLAGS. Users very 210 // likely want XLA_FLAGS only to apply to the "real" computation being run, 211 // not to the fake computations we use for generating arguments. 212 auto debug_opts = DefaultDebugOptionsIgnoringFlags(); 213 global_data_arguments = 214 MakeFakeArgumentsOrDie(computation, client, &debug_opts); 215 for (const auto& data : global_data_arguments) { 216 argument_ptrs.push_back( 217 client->GlobalDataToShapedBuffer(data->handle(), /*device_ordinal=*/0) 218 .ValueOrDie()); 219 } 220 } else { // use recorded data if available 221 for (const auto& proto : module.arguments()) { 222 TF_ASSIGN_OR_RETURN(Literal literal, Literal::CreateFromProto(proto)); 223 TF_ASSIGN_OR_RETURN( 224 ScopedShapedBuffer data, 225 client->LiteralToShapedBuffer(literal, /*device_ordinal=*/0)); 226 scoped_shaped_buffer_arguments.push_back(std::move(data)); 227 } 228 for (const auto& argument : scoped_shaped_buffer_arguments) { 229 argument_ptrs.push_back(&argument); 230 } 231 } 232 233 if (absl::optional<Shape> infeed_shape = GetXfeedShape( 234 /*is_infeed=*/true, computation.proto(), opts)) { 235 auto infeed_data = std::make_shared<Literal>( 236 std::move(MakeFakeLiteral(*infeed_shape)).ValueOrDie()); 237 xla::gpu::GetOrCreateInfeedManager() 238 ->RegisterBeforeGetNextDestinationCallback([infeed_data, client] { 239 TF_CHECK_OK(client->TransferToInfeed(*infeed_data)); 240 }); 241 } 242 243 absl::optional<tensorflow::thread::ThreadPool> outfeed_thread_pool; 244 if (absl::optional<Shape> outfeed_shape = GetXfeedShape( 245 /*is_infeed=*/false, computation.proto(), opts)) { 246 // For each an outfeed that runs, enqueue a task that will consume it. We 247 // need a thread pool because the act of running an outfeed blocks on there 248 // being a destination available, and the act of making a destination 249 // available blocks on there being outfeed data available. 250 outfeed_thread_pool.emplace(tensorflow::Env::Default(), "infeed", 251 /*num_threads=*/1); 252 auto consume_outfeed = [client, outfeed_shape] { 253 TF_CHECK_OK( 254 client->TransferFromOutfeedLocal(*outfeed_shape, /*device_ordinal=*/0) 255 .status()); 256 VLOG(1) << "Received outfeed data of shape " 257 << ShapeUtil::HumanStringWithLayout(*outfeed_shape); 258 }; 259 xla::gpu::GetOrCreateOutfeedManager() 260 ->RegisterBeforeGetNextDestinationCallback( 261 [consume_outfeed, &outfeed_thread_pool] { 262 outfeed_thread_pool->Schedule(consume_outfeed); 263 }); 264 } 265 266 // Do not attempt to run the executable if num_runs is less than 1. 267 if (opts.num_runs < 1) { 268 return Cancelled("Cancelled after compilation since --num_runs < 1."); 269 } 270 271 // Run the computation num_runs times, and return the result from the last 272 // execution. 273 const bool xla_hlo_profile = GetDebugOptionsFromFlags().xla_hlo_profile(); 274 StreamExecutorMemoryAllocator allocator( 275 client->platform(), 276 {client->platform()->ExecutorForDevice(0).ValueOrDie()}); 277 absl::optional<ScopedShapedBuffer> final_result; 278 for (int i = 0; i < opts.num_runs; ++i) { 279 // If xla_hlo_profile is enabled, print a noisy message before the last run, 280 // making it easier to separate this profile from the others in the logspam. 281 bool is_final_result = i == opts.num_runs - 1; 282 if (xla_hlo_profile && is_final_result) { 283 LOG(INFO) << "\n\n***** Final run below ******"; 284 } 285 ExecutionProfile profile; 286 ExecutableRunOptions run_options; 287 run_options.set_execution_profile(&profile); 288 run_options.set_allocator(&allocator); 289 290 TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, 291 executable->Run(argument_ptrs, run_options)); 292 LOG(INFO) << "Done executing in " 293 << static_cast<double>(profile.compute_time_ns()) / 1e9 294 << "s: " << module.hlo().hlo_module().name(); 295 296 // Save the result if this is for the final iteration. Otherwise discard 297 // the result before rerunning the computation, so as to free up the 298 // relevant memory. 299 if (is_final_result) { 300 final_result = std::move(result); 301 } 302 } 303 304 TF_ASSIGN_OR_RETURN(Literal result_literal, 305 client->ShapedBufferToLiteral(*final_result)); 306 return result_literal; 307 } 308 309 StatusOr<HloSnapshot> ParseInputFile(const string& filename, 310 const Options& opts) { 311 tensorflow::Env* env = tensorflow::Env::Default(); 312 HloSnapshot snapshot; 313 auto s = tensorflow::ReadBinaryProto(env, filename, &snapshot); 314 if (s.ok()) { 315 return snapshot; 316 } 317 if (s.code() == tensorflow::error::NOT_FOUND) { 318 return s; 319 } 320 CHECK(opts.use_fake_data) 321 << "Without --use_fake_data, you must pass an HloSnapshot -- HloProto " 322 "and textual HLO don't carry real data."; 323 fprintf(stderr, "%s: is not HloSnapshot. Trying HloProto.\n", 324 filename.c_str()); 325 326 if (tensorflow::ReadBinaryProto(env, filename, snapshot.mutable_hlo()).ok()) { 327 return snapshot; 328 } 329 fprintf(stderr, "%s: is not HloProto. Trying HLO text.\n", filename.c_str()); 330 string contents; 331 TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(env, filename, &contents)); 332 HloModuleConfig config; 333 config.set_debug_options(GetDebugOptionsFromFlags()); 334 StatusOr<std::unique_ptr<HloModule>> module = 335 ParseHloString(contents, config); 336 if (module.ok()) { 337 *snapshot.mutable_hlo()->mutable_hlo_module() = 338 module.ValueOrDie()->ToProto(); 339 return snapshot; 340 } 341 fprintf(stderr, "%s: is not HLO text. Nothing left to try.\n", 342 filename.c_str()); 343 return InvalidArgument("Could not parse %s.", filename); 344 } 345 346 int RealMain(absl::Span<char* const> args, const Options& opts) { 347 LocalClient* client = ClientLibrary::LocalClientOrDie(); 348 int exit_status = EXIT_SUCCESS; 349 350 std::vector<HloSnapshot> snapshots; 351 for (char* arg : args) { 352 StatusOr<HloSnapshot> maybe_snapshot = ParseInputFile(arg, opts); 353 if (maybe_snapshot.ok()) { 354 snapshots.push_back(std::move(maybe_snapshot).ValueOrDie()); 355 } else { 356 LOG(ERROR) << "Can't handle file " << arg << ": " 357 << maybe_snapshot.status(); 358 } 359 } 360 361 // Compile all the modules in parallel. 362 LOG(INFO) << "Compiling " << snapshots.size() << " modules in parallel."; 363 std::vector<StatusOr<std::unique_ptr<LocalExecutable>>> executables; 364 { 365 // ThreadPool CHECK-fails if we give it 0 threads. 366 tensorflow::thread::ThreadPool thread_pool( 367 tensorflow::Env::Default(), tensorflow::ThreadOptions(), 368 "compile_modules", std::max(size_t{1}, snapshots.size()), 369 /*low_latency_hint=*/false); 370 executables.resize(snapshots.size()); 371 for (int64 i = 0; i < snapshots.size(); ++i) { 372 thread_pool.Schedule([&snapshots, &executables, client, i] { 373 executables[i] = CompileExecutable(snapshots[i], client); 374 }); 375 } 376 } 377 LOG(INFO) << "Done compiling; now running the modules."; 378 379 for (int64 i = 0; i < executables.size(); ++i) { 380 if (!executables[i].ok()) { 381 LOG(ERROR) << "Compilation failed: " << executables[i].status(); 382 exit_status = EXIT_FAILURE; 383 continue; 384 } 385 LocalExecutable* executable = executables[i].ValueOrDie().get(); 386 LOG(ERROR) << "Running iteration " << i; 387 StatusOr<Literal> result_status = 388 ReplayComputation(snapshots[i], executable, client, opts); 389 LOG(ERROR) << "iteration complete."; 390 if (!result_status.ok()) { 391 fprintf(stderr, "%s: error: %s\n", args[i], 392 result_status.status().ToString().c_str()); 393 exit_status = EXIT_FAILURE; 394 continue; 395 } 396 397 if (opts.print_result) { 398 Literal result = std::move(result_status).ValueOrDie(); 399 fprintf(stdout, "%s: %s :: %s:%s\n", args[i], 400 executable->executable()->module().name().c_str(), 401 ShapeUtil::HumanString(result.shape()).c_str(), 402 result.ToString().c_str()); 403 auto& snapshot = snapshots[i]; 404 if (snapshot.has_result()) { 405 Literal literal = 406 Literal::CreateFromProto(snapshot.result()).ConsumeValueOrDie(); 407 fprintf( 408 stdout, "was %s:%s\n", 409 ShapeUtil::HumanString(Shape(snapshot.result().shape())).c_str(), 410 literal.ToString().c_str()); 411 } 412 } 413 } 414 415 ClientLibrary::DestroyLocalInstances(); 416 return exit_status; 417 } 418 419 } // namespace 420 } // namespace tools 421 } // namespace xla 422 423 int main(int argc, char** argv) { 424 xla::tools::Options opts; 425 const std::vector<tensorflow::Flag> flag_list = { 426 tensorflow::Flag("use_fake_data", &opts.use_fake_data, 427 "Replay computation using fake data"), 428 tensorflow::Flag("print_result", &opts.print_result, 429 "Print the result of the computation to stdout"), 430 tensorflow::Flag("num_runs", &opts.num_runs, 431 "Number of times to run each computation"), 432 tensorflow::Flag("fake_infeed_shape", &opts.fake_infeed_shape, 433 "Shape of fake data to construct for (infinite) infeed"), 434 tensorflow::Flag("fake_outfeed_shape", &opts.fake_outfeed_shape, 435 "Shape of fake data to outfeed from computation"), 436 tensorflow::Flag("generate_fake_infeed", &opts.generate_fake_infeed, 437 "Whether a fake infeed shape should be derived " 438 "from the computation"), 439 tensorflow::Flag("generate_fake_outfeed", &opts.generate_fake_outfeed, 440 "Whether a fake outfeed shape should be derived " 441 "from the computation"), 442 }; 443 xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); 444 bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); 445 tensorflow::port::InitMain(argv[0], &argc, &argv); 446 if (argc < 2 || !parse_ok) { 447 LOG(QFATAL) << usage; 448 } 449 450 absl::Span<char* const> args(argv, argc); 451 args.remove_prefix(1); // Pop off the binary name, argv[0] 452 return xla::tools::RealMain(args, opts); 453 } 454