1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" 17 18 #include <stddef.h> 19 #include <string.h> 20 #include <map> 21 #include <mutex> // NOLINT(build/c++11): only using std::call_once, not mutex. 22 #include <string> 23 #include <unordered_map> 24 #include <utility> 25 #include <vector> 26 27 // IWYU pragma: no_include "llvm/Config/Disassemblers.def.inc" 28 // IWYU pragma: no_include "llvm/Config/Targets.def.inc" 29 #include "llvm/ADT/StringRef.h" 30 #include "llvm/ADT/Triple.h" 31 #include "llvm/IR/Function.h" 32 #include "llvm/IR/LLVMContext.h" 33 #include "llvm/IR/Module.h" 34 #include "llvm/IR/Verifier.h" 35 #include "llvm/Object/ObjectFile.h" 36 #include "llvm/Support/CommandLine.h" 37 #include "llvm/Support/TargetRegistry.h" 38 #include "llvm/Support/TargetSelect.h" 39 #include "llvm/Target/TargetMachine.h" 40 #include "llvm/Target/TargetOptions.h" 41 #include "tensorflow/compiler/xla/literal_util.h" 42 #include "tensorflow/compiler/xla/map_util.h" 43 #include "tensorflow/compiler/xla/protobuf_util.h" 44 #include "tensorflow/compiler/xla/ptr_util.h" 45 #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" 46 #include "tensorflow/compiler/xla/service/batchnorm_expander.h" 47 #include "tensorflow/compiler/xla/service/buffer_assignment.h" 48 #include "tensorflow/compiler/xla/service/buffer_liveness.h" 49 #include "tensorflow/compiler/xla/service/call_inliner.h" 50 #include "tensorflow/compiler/xla/service/cpu/compiler_functor.h" 51 #include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h" 52 #include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h" 53 #include "tensorflow/compiler/xla/service/cpu/cpu_executable.h" 54 #include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h" 55 #include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h" 56 #include "tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h" 57 #include "tensorflow/compiler/xla/service/cpu/cpu_options.h" 58 #include "tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h" 59 #include "tensorflow/compiler/xla/service/cpu/disassembler.h" 60 #include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" 61 #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" 62 #include "tensorflow/compiler/xla/service/cpu/ir_emitter.h" 63 #include "tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h" 64 #include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h" 65 #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" 66 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" 67 #include "tensorflow/compiler/xla/service/dot_decomposer.h" 68 #include "tensorflow/compiler/xla/service/flatten_call_graph.h" 69 #include "tensorflow/compiler/xla/service/hlo.pb.h" 70 #include "tensorflow/compiler/xla/service/hlo_computation.h" 71 #include "tensorflow/compiler/xla/service/hlo_constant_folding.h" 72 #include "tensorflow/compiler/xla/service/hlo_cse.h" 73 #include "tensorflow/compiler/xla/service/hlo_dce.h" 74 #include "tensorflow/compiler/xla/service/hlo_element_type_converter.h" 75 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 76 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 77 #include "tensorflow/compiler/xla/service/hlo_ordering.h" 78 #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" 79 #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" 80 #include "tensorflow/compiler/xla/service/hlo_proto_util.h" 81 #include "tensorflow/compiler/xla/service/hlo_scheduling.h" 82 #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" 83 #include "tensorflow/compiler/xla/service/hlo_verifier.h" 84 #include "tensorflow/compiler/xla/service/inliner.h" 85 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" 86 #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" 87 #include "tensorflow/compiler/xla/service/reshape_mover.h" 88 #include "tensorflow/compiler/xla/service/transpose_folding.h" 89 #include "tensorflow/compiler/xla/service/tuple_simplifier.h" 90 #include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h" 91 #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" 92 #include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h" 93 #include "tensorflow/compiler/xla/status_macros.h" 94 #include "tensorflow/compiler/xla/statusor.h" 95 #include "tensorflow/compiler/xla/types.h" 96 #include "tensorflow/compiler/xla/util.h" 97 #include "tensorflow/compiler/xla/xla_data.pb.h" 98 #include "tensorflow/core/lib/strings/str_util.h" 99 #include "tensorflow/core/lib/strings/strcat.h" 100 101 namespace se = ::perftools::gputools; 102 103 namespace xla { 104 namespace cpu { 105 106 CpuAotCompilationOptions::CpuAotCompilationOptions( 107 string triple, string cpu_name, string features, string entry_point_name, 108 RelocationModel relocation_model) 109 : triple_(std::move(triple)), 110 cpu_name_(std::move(cpu_name)), 111 features_(std::move(features)), 112 entry_point_name_(std::move(entry_point_name)), 113 relocation_model_(relocation_model) {} 114 115 CpuAotCompilationOptions::~CpuAotCompilationOptions() = default; 116 117 se::Platform::Id CpuAotCompilationOptions::PlatformId() const { 118 return se::host::kHostPlatformId; 119 } 120 121 CpuAotCompilationResult::CpuAotCompilationResult( 122 ObjectFileData object_file_data, BufferSizes buffer_sizes, 123 int64 result_buffer_index) 124 : object_file_data_(std::move(object_file_data)), 125 buffer_sizes_(std::move(buffer_sizes)), 126 result_buffer_index_(result_buffer_index) {} 127 128 CpuAotCompilationResult::~CpuAotCompilationResult() = default; 129 130 CpuCompiler::CpuCompiler() { 131 // Initialize LLVM the first time the CpuCompiler is initialized. 132 static bool llvm_initialized = []() { 133 InitializeLLVMTarget(); 134 return true; 135 }(); 136 (void)llvm_initialized; 137 } 138 139 /* static */ void CpuCompiler::InitializeLLVMTarget() { 140 // Initialize LLVM's MC layer for the native target. 141 llvm::InitializeNativeTarget(); 142 llvm::InitializeNativeTargetAsmPrinter(); 143 LLVMInitializeX86Target(); 144 LLVMInitializeX86TargetInfo(); 145 LLVMInitializeX86TargetMC(); 146 LLVMInitializeX86AsmPrinter(); 147 LLVMInitializeX86Disassembler(); 148 LLVMInitializeARMTarget(); 149 LLVMInitializeARMTargetInfo(); 150 LLVMInitializeARMTargetMC(); 151 LLVMInitializeARMAsmPrinter(); 152 LLVMInitializeARMDisassembler(); 153 LLVMInitializeAArch64Target(); 154 LLVMInitializeAArch64TargetInfo(); 155 LLVMInitializeAArch64TargetMC(); 156 LLVMInitializeAArch64AsmPrinter(); 157 LLVMInitializeAArch64Disassembler(); 158 } 159 160 namespace { 161 162 // LLVM makes certain options configurable only through its command-line 163 // options; it provide the ParseCommandLineOptions function that lets us set 164 // flags at runtime. However, since these flags are global we want to avoid 165 // multiple invocations of the LLVM compilation pipeline with a different set of 166 // flags. Therefore, we only pass command-line flags to LLVM once, before the 167 // first module is compiled. 168 std::once_flag llvm_command_line_options_initialized; 169 170 // This visitor records which HLO instructions should have profiling information 171 // recorded. 172 class CollectProfileCandidates : public DfsHloVisitorWithDefault { 173 public: 174 static StatusOr<std::unordered_map<const HloInstruction*, int64>> 175 GetCandidatesForComputation( 176 HloComputation* computation, 177 const std::unordered_map<const HloInstruction*, int64>& 178 assigned_indices) { 179 std::unordered_map<const HloInstruction*, int64> hlo_to_profile_idx; 180 CollectProfileCandidates profile_candidates_for_computation( 181 &hlo_to_profile_idx, assigned_indices); 182 TF_RETURN_IF_ERROR( 183 computation->Accept(&profile_candidates_for_computation)); 184 return hlo_to_profile_idx; 185 } 186 187 private: 188 CollectProfileCandidates( 189 std::unordered_map<const HloInstruction*, int64>* hlo_to_profile_idx, 190 const std::unordered_map<const HloInstruction*, int64>& assigned_indices) 191 : hlo_to_profile_idx_(hlo_to_profile_idx), 192 assigned_indices_(assigned_indices) {} 193 194 Status DefaultAction(HloInstruction* hlo_instruction) override { 195 hlo_to_profile_idx_->insert( 196 {hlo_instruction, FindOrDie(assigned_indices_, hlo_instruction)}); 197 return Status::OK(); 198 } 199 200 Status HandleCall(HloInstruction* call) override { 201 TF_RETURN_IF_ERROR(DefaultAction(call)); 202 CollectProfileCandidates candidates_for_call(hlo_to_profile_idx_, 203 assigned_indices_); 204 TF_RETURN_IF_ERROR(call->to_apply()->Accept(&candidates_for_call)); 205 return Status::OK(); 206 } 207 208 // Skip constants, there is nothing to profile. 209 Status HandleConstant(HloInstruction*) override { return Status::OK(); } 210 // Skip parameters, they are a simple load. 211 Status HandleParameter(HloInstruction*) override { return Status::OK(); } 212 // It is important to recurse for "while" or else we risk overly coarse 213 // profiling information. 214 Status HandleWhile(HloInstruction* xla_while) override { 215 TF_RETURN_IF_ERROR(DefaultAction(xla_while)); 216 217 CollectProfileCandidates candidates_for_condition(hlo_to_profile_idx_, 218 assigned_indices_); 219 TF_RETURN_IF_ERROR( 220 xla_while->while_condition()->Accept(&candidates_for_condition)); 221 222 CollectProfileCandidates candidates_for_body(hlo_to_profile_idx_, 223 assigned_indices_); 224 TF_RETURN_IF_ERROR(xla_while->while_body()->Accept(&candidates_for_body)); 225 226 return Status::OK(); 227 } 228 229 std::unordered_map<const HloInstruction*, int64>* hlo_to_profile_idx_; 230 const std::unordered_map<const HloInstruction*, int64>& assigned_indices_; 231 }; 232 } // namespace 233 234 Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { 235 // Optimization pipeline. 236 HloPassPipeline pipeline("CPU"); 237 pipeline.AddInvariantChecker<HloVerifier>(); 238 pipeline.AddPass<CpuHloSupportChecker>(); 239 240 ReducePrecisionInsertion::AddPasses( 241 &pipeline, module->config().debug_options(), 242 ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION); 243 244 // TODO(b/35786417): Re-enable inliner pass after fixing the bug and deciding 245 // where we will take this pass in future. 246 // pipeline.AddPass<Inliner>(); 247 248 // TODO(b/65775800): Fix wrong output bug in Call and remove the CallInliner 249 // pass. 250 pipeline.AddPass<CallInliner>(); 251 pipeline.AddPass<DotDecomposer>(); 252 pipeline.AddPass<ConvCanonicalization>(); 253 { 254 auto& pass = 255 pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification"); 256 pass.AddInvariantChecker<HloVerifier>(); 257 258 pass.AddPass<BatchNormExpander>( 259 /*rewrite_training_op=*/true, 260 /*rewrite_inference_op=*/true, 261 /*rewrite_grad_op=*/true, 262 /*use_fusion=*/false); 263 pass.AddPass<AlgebraicSimplifier>( 264 /*is_layout_sensitive=*/false, 265 [](const Shape&, const Shape&) { return false; }, 266 /*enable_dot_strength_reduction=*/false); 267 268 // BatchNormExpander can create zero-sized ops, so zero-sized HLO 269 // elimination has to come after that pass. 270 pipeline.AddPass<ZeroSizedHloElimination>(); 271 272 pass.AddPass<WhileLoopInvariantCodeMotion>(); 273 pass.AddPass<TupleSimplifier>(); 274 pass.AddPass<WhileLoopSimplifier>(); 275 pass.AddPass<HloDCE>(); 276 pass.AddPass<ReshapeMover>(); 277 pass.AddPass<HloConstantFolding>(); 278 } 279 pipeline.AddPass<TransposeFolding>( 280 [](const HloInstruction& dot, 281 const TransposeFolding::OperandIndices& candidate_operands) { 282 return PotentiallyImplementedAsEigenDot(dot) 283 ? candidate_operands 284 : TransposeFolding::OperandIndices{}; 285 }, 286 TransposeFolding::NeverFoldTranspose); 287 pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/false); 288 pipeline.AddPass<CpuInstructionFusion>(); 289 290 ReducePrecisionInsertion::AddPasses( 291 &pipeline, module->config().debug_options(), 292 ReducePrecisionInsertion::PassTiming::AFTER_FUSION); 293 294 pipeline.AddPass<CpuLayoutAssignment>( 295 module->mutable_entry_computation_layout()); 296 // The LayoutAssignment pass may leave behind kCopy instructions which are 297 // duplicate or NOPs, so remove them with algebraic simplification and CSE. 298 pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>( 299 /*is_layout_sensitive=*/true, 300 [](const Shape&, const Shape&) { return true; }, 301 /*enable_dot_strength_reduction=*/false); 302 pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/true); 303 pipeline.AddPass<HloElementTypeConverter>(BF16, F32); 304 // Outline ops in the entry computation into calls to subcomputations. 305 const int max_parallelism = 306 module->config().intra_op_parallelism_threads() > 0 307 ? module->config().intra_op_parallelism_threads() 308 : tensorflow::port::NumSchedulableCPUs(); 309 if (options::CpuParallelBackendRequested(module->config())) { 310 pipeline.AddPass<ParallelizationPreparation>(max_parallelism, 311 ShapeSizeBytesFunction()); 312 } else if (!is_aot_compile) { 313 // Run ParallelTaskAssigner to assign parallel tasks to HLOs in module. 314 // Note this is not run for AOT because it would bring in thread pool 315 // and thread synchronization dependencies which would likely increase 316 // binary size (and most AOT applications are single-threaded). 317 // TODO(29630486) Support multi-threaded AOT. 318 pipeline.AddPass<ParallelTaskAssigner>(max_parallelism, 319 ShapeSizeBytesFunction()); 320 } 321 // Copy insertion should be performed immediately before IR emission to avoid 322 // inserting unnecessary copies (later pass adds an instruction which 323 // materializes the value) or missing a necessary copy (later pass removes an 324 // instruction which materializes a value). DCE must be run immediately before 325 // (and sometime after) copy insertion, to avoid dead code from interfering 326 // with the rewrites. 327 pipeline.AddPass<HloDCE>(); 328 pipeline.AddPass<FlattenCallGraph>(); 329 pipeline.AddPass<CpuCopyInsertion>(); 330 if (options::CpuParallelBackendRequested(module->config())) { 331 // Re-run the outlining, in case any copies were inserted into the entry 332 // computation. 333 pipeline.AddPass<ParallelizationPreparation>(max_parallelism, 334 ShapeSizeBytesFunction()); 335 pipeline.AddPass<CpuCopyInsertion>(); 336 } 337 pipeline.AddPass<HloDCE>(); 338 return pipeline.Run(module).status(); 339 } 340 341 namespace { 342 343 // Align buffers to 16-byte boundaries. 344 constexpr int64 kMemoryAlignment = 16; 345 auto memory_alignment = [](LogicalBuffer::Color) { return kMemoryAlignment; }; 346 347 llvm::TargetOptions CompilerTargetOptions( 348 const HloModuleConfig& module_config) { 349 llvm::TargetOptions target_options; 350 llvm_ir::SetTargetOptions( 351 /*fast_math_enabled=*/module_config.debug_options() 352 .xla_enable_fast_math(), 353 &target_options); 354 return target_options; 355 } 356 357 llvm::CodeGenOpt::Level CodeGenOptLevel(const HloModuleConfig& module_config) { 358 VLOG(2) << "backend_optimization_level: " 359 << module_config.debug_options().xla_backend_optimization_level(); 360 switch (module_config.debug_options().xla_backend_optimization_level()) { 361 case 1: 362 return llvm::CodeGenOpt::Less; 363 case 2: 364 return llvm::CodeGenOpt::Default; 365 case 3: 366 return llvm::CodeGenOpt::Aggressive; 367 default: 368 return llvm::CodeGenOpt::None; 369 } 370 } 371 372 Status InitializeModuleHooks( 373 const HloModule& hlo_module, 374 const LLVMCompiler::ModuleHook& user_pre_optimization_hook, 375 const LLVMCompiler::ModuleHook& user_post_optimization_hook, 376 LLVMCompiler::ModuleHook* pre_optimization_ir_hook, 377 LLVMCompiler::ModuleHook* post_optimization_ir_hook) { 378 const string& ir_dump_directory = 379 hlo_module.config().debug_options().xla_dump_ir_to(); 380 if (ir_dump_directory.empty()) { 381 *pre_optimization_ir_hook = user_pre_optimization_hook; 382 *post_optimization_ir_hook = user_post_optimization_hook; 383 return Status::OK(); 384 } 385 386 const string& hlo_module_name = hlo_module.name(); 387 388 // Create the IR hooks. If applicable, each IR hook does the following: 389 // 390 // * Calls the user supplied module hook. 391 // * Writes out the IR to a file in the output directory designated by 392 // --xla_dump_ir_to 393 394 *pre_optimization_ir_hook = 395 [user_pre_optimization_hook, ir_dump_directory, 396 hlo_module_name](const llvm::Module& llvm_module) { 397 if (user_pre_optimization_hook) { 398 TF_RETURN_IF_ERROR(user_pre_optimization_hook(llvm_module)); 399 } 400 return llvm_ir::DumpIRToDirectory(/*directory_name=*/ir_dump_directory, 401 /*hlo_module_name=*/hlo_module_name, 402 llvm_module, 403 /*optimized=*/false); 404 }; 405 406 *post_optimization_ir_hook = 407 [user_post_optimization_hook, ir_dump_directory, 408 hlo_module_name](const llvm::Module& llvm_module) { 409 if (user_post_optimization_hook) { 410 TF_RETURN_IF_ERROR(user_post_optimization_hook(llvm_module)); 411 } 412 return llvm_ir::DumpIRToDirectory(/*directory_name=*/ir_dump_directory, 413 /*hlo_module_name=*/hlo_module_name, 414 llvm_module, 415 /*optimized=*/true); 416 }; 417 418 return Status::OK(); 419 } 420 421 Status VerifyLlvmModule(const llvm::Module& llvm_module) { 422 XLA_SCOPED_LOGGING_TIMER("CpuCompiler - Running LLVM verifier"); 423 424 std::string err; 425 llvm::raw_string_ostream err_stream(err); 426 427 // verifyModule() returns true if the module is broken. 428 TF_RET_CHECK(!llvm::verifyModule(llvm_module, &err_stream)) 429 << "Invalid LLVM IR before optimizations:\n" 430 << err_stream.str() 431 << "\nThis probably indicates a bug in the HLO -> LLVM IR lowering. " 432 "Rerun with --xla_dump_ir_to to get the IR. "; 433 return Status::OK(); 434 } 435 436 } // namespace 437 438 StatusOr<std::unique_ptr<HloModule>> CpuCompiler::RunHloPasses( 439 std::unique_ptr<HloModule> module, 440 perftools::gputools::StreamExecutor* /*stream_exec*/, 441 DeviceMemoryAllocator* /*device_allocator*/) { 442 VLOG(2) << "Before optimization:"; 443 XLA_VLOG_LINES(2, module->ToString()); 444 445 TF_RETURN_IF_ERROR(RunHloPasses(module.get(), /*is_aot_compile=*/false)); 446 447 VLOG(2) << "After optimization:"; 448 XLA_VLOG_LINES(2, module->ToString()); 449 return std::move(module); 450 } 451 452 StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend( 453 std::unique_ptr<HloModule> module, 454 perftools::gputools::StreamExecutor* stream_exec, 455 DeviceMemoryAllocator* /*device_allocator*/) { 456 const string timer_message = 457 "Compiling [" + module->name() + "] for CPU using JIT"; 458 XLA_SCOPED_LOGGING_TIMER(timer_message); 459 460 VLOG(1) << "Compiling: " << module->name(); 461 TF_RET_CHECK(stream_exec != nullptr); 462 std::call_once(llvm_command_line_options_initialized, 463 &llvm_ir::InitializeLLVMCommandLineOptions, module->config()); 464 465 ModuleHook pre_optimization_ir_hook; 466 ModuleHook post_optimization_ir_hook; 467 TF_RETURN_IF_ERROR(InitializeModuleHooks( 468 *module, user_pre_optimization_hook_, user_post_optimization_hook_, 469 &pre_optimization_ir_hook, &post_optimization_ir_hook)); 470 471 // Compile must be thread-safe so create a new LLVM context for the module. 472 auto llvm_context = xla::MakeUnique<llvm::LLVMContext>(); 473 auto llvm_module = 474 xla::MakeUnique<llvm::Module>("__compute_module", *llvm_context); 475 476 auto jit = xla::MakeUnique<SimpleOrcJIT>( 477 CompilerTargetOptions(module->config()), 478 CodeGenOptLevel(module->config()), 479 options::OptimizeForSizeRequested(module->config()), 480 module->config().debug_options().xla_enable_fast_math(), 481 module->config().debug_options().xla_llvm_disable_expensive_passes(), 482 pre_optimization_ir_hook, post_optimization_ir_hook); 483 llvm_module->setDataLayout(jit->data_layout()); 484 llvm_module->setTargetTriple(jit->target_triple().getTriple()); 485 486 HloComputation* entry_computation = module->entry_computation(); 487 std::unordered_map<const HloInstruction*, int64> instruction_to_profile_idx; 488 std::unordered_map<const HloComputation*, int64> computation_to_profile_idx; 489 std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map; 490 std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data; 491 if (module->config().hlo_profiling_enabled()) { 492 hlo_profile_index_map = MakeUnique<HloProfileIndexMap>(*module); 493 494 TF_ASSIGN_OR_RETURN( 495 instruction_to_profile_idx, 496 CollectProfileCandidates::GetCandidatesForComputation( 497 entry_computation, 498 hlo_profile_index_map->instruction_to_profile_idx())); 499 500 auto shape_size_bytes = [](const Shape& shape) { 501 // On the cpu, opaques are pointers. 502 if (ShapeUtil::IsOpaque(shape)) { 503 return static_cast<int64>(sizeof(void*)); 504 } 505 return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); 506 }; 507 508 HloCostAnalysis cost_analysis(shape_size_bytes); 509 TF_RETURN_IF_ERROR(entry_computation->Accept(&cost_analysis)); 510 hlo_profile_printer_data = 511 CreateHloProfilePrinterData(*hlo_profile_index_map, cost_analysis); 512 computation_to_profile_idx = 513 hlo_profile_index_map->computation_to_profile_idx(); 514 } 515 516 std::unique_ptr<Executable> cpu_executable; 517 518 // Cache these flags here since we'll want to access them after the module's 519 // ownership is std::moved. 520 const bool embed_ir_in_executable = 521 module->config().debug_options().xla_embed_ir_in_executable(); 522 const string xla_dump_optimized_hlo_proto_to = 523 module->config().debug_options().xla_dump_optimized_hlo_proto_to(); 524 525 if (options::CpuParallelBackendRequested(module->config())) { 526 VLOG(1) << "Using parallel cpu backend"; 527 528 // Run buffer analysis on the HLO graph. This analysis figures out which 529 // temporary buffers are required to run the computation. 530 // DependencyHloOrdering is used for the parallel emitter because the order 531 // of HLO instruction execution is not known ahead of time. 532 // DependencyHloOrdering is the most conservative partial order and only 533 // uses data dependencies for determining order. 534 TF_ASSIGN_OR_RETURN( 535 std::unique_ptr<BufferAssignment> assignment, 536 BufferAssigner::Run( 537 module.get(), xla::MakeUnique<DependencyHloOrdering>(module.get()), 538 BufferSizeBytesFunction(), memory_alignment)); 539 // BufferAssignment::ToString() includes a header, so no need for us to 540 // print one ourselves. 541 XLA_VLOG_LINES(2, assignment->ToString()); 542 543 if (!xla_dump_optimized_hlo_proto_to.empty()) { 544 HloProto proto = MakeHloProto(*module, *assignment); 545 TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory( 546 proto, xla_dump_optimized_hlo_proto_to, module->name())); 547 } 548 549 // If we are using the parallel CPU backend, we need to create map from 550 // HloInstruction to the corresponding generated function name. 551 std::map<HloComputation*, HloInstruction*> parallel_computations; 552 std::unordered_map<const HloInstruction*, std::unique_ptr<unsigned char[]>> 553 aligned_constants; 554 for (auto instruction : entry_computation->MakeInstructionPostOrder()) { 555 // Parameters and constants don't get their own computation. 556 if (instruction->opcode() == HloOpcode::kParameter) { 557 continue; 558 } 559 if (instruction->opcode() == HloOpcode::kConstant) { 560 // Copy the constant out of the ProtocolBuffer so that we can give it a 561 // higher alignment. 562 const void* data = instruction->literal().untyped_data(); 563 int64 size = CpuExecutable::ShapeSizeBytes(instruction->shape()); 564 auto iter = aligned_constants.emplace( 565 instruction, xla::MakeUnique<unsigned char[]>(size)); 566 CHECK_EQ(iter.second, true); 567 unsigned char* aligned_data = iter.first->second.get(); 568 memcpy(aligned_data, data, size); 569 continue; 570 } 571 // The parallel preparation should have ensured that the top-level 572 // computation consists solely of Call instructions. 573 TF_RET_CHECK(instruction->opcode() == HloOpcode::kCall) 574 << module->ToString(); 575 HloComputation* to_apply = instruction->to_apply(); 576 parallel_computations.emplace(to_apply, instruction); 577 } 578 579 IrEmitter ir_emitter(*module, *assignment, llvm_module.get(), 580 std::move(instruction_to_profile_idx), 581 std::move(computation_to_profile_idx), 582 jit->target_machine(), jit->external_constant_pool()); 583 584 std::unique_ptr<HloInstructionMap<string>> function_names( 585 new HloInstructionMap<string>()); 586 for (auto embedded_computation : 587 entry_computation->MakeEmbeddedComputationsList()) { 588 if (embedded_computation->IsFusionComputation()) { 589 continue; 590 } 591 auto parallel_computation_iter = 592 parallel_computations.find(embedded_computation); 593 // All parallel computations are considered to be an entry computation for 594 // IR generation purposes. 595 bool computation_is_parallel = 596 parallel_computation_iter != parallel_computations.end(); 597 TF_ASSIGN_OR_RETURN( 598 llvm::Function * ir_function, 599 ir_emitter.EmitComputation( 600 embedded_computation, embedded_computation->name(), 601 /*is_top_level_computation=*/computation_is_parallel, 602 /*instruction_order=*/nullptr)); 603 // If this computation is parallel, remember it in the function name map. 604 // This way we know what function to execute when we try to run code for 605 // the Call instruction. 606 if (computation_is_parallel) { 607 HloInstruction* call_instruction = parallel_computation_iter->second; 608 InsertOrDie(function_names.get(), call_instruction, 609 llvm_ir::AsString(ir_function->getName())); 610 } 611 } 612 613 string ir_module_string; 614 if (embed_ir_in_executable) { 615 ir_module_string = llvm_ir::DumpModuleToString(*llvm_module); 616 } 617 TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module)); 618 619 // JIT compile the LLVM IR module to in-memory machine code. 620 jit->AddModule(std::move(llvm_module)); 621 cpu_executable.reset(new ParallelCpuExecutable( 622 std::move(jit), std::move(assignment), std::move(module), 623 std::move(function_names), std::move(aligned_constants), 624 std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map))); 625 626 if (embed_ir_in_executable) { 627 static_cast<CpuExecutable&>(*cpu_executable) 628 .set_ir_module_string(ir_module_string); 629 } 630 } else { 631 VLOG(1) << "Using sequential cpu backend"; 632 633 // Select an order for emitting the HLO instructions for each 634 // computation. Using this sequence enables tighter buffer liveness analysis 635 // and reduced memory usage (as compared to using DependencyHloOrdering). 636 TF_ASSIGN_OR_RETURN( 637 SequentialHloOrdering::HloModuleSequence module_sequence, 638 CreateMemoryMinimizingSequence(*module, BufferSizeBytesFunction())); 639 640 // Run buffer analysis on the HLO graph. This analysis figures out which 641 // temporary buffers are required to run the computation. 642 TF_ASSIGN_OR_RETURN( 643 std::unique_ptr<BufferAssignment> assignment, 644 BufferAssigner::Run(module.get(), 645 xla::MakeUnique<SequentialHloOrdering>( 646 module.get(), module_sequence), 647 BufferSizeBytesFunction(), memory_alignment)); 648 // BufferAssignment::ToString() includes a header, so no need for us to 649 // print one ourselves. 650 XLA_VLOG_LINES(2, assignment->ToString()); 651 652 if (!xla_dump_optimized_hlo_proto_to.empty()) { 653 HloProto proto = MakeHloProto(*module, *assignment); 654 TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory( 655 proto, xla_dump_optimized_hlo_proto_to, module->name())); 656 } 657 658 // Each computation is a single function. Emit all embedded computations 659 // before the entry computation. The order of computations returned from 660 // GetEmbeddedComputations guarantees that a called computation occurs 661 // before a caller computation. 662 663 IrEmitter ir_emitter(*module, *assignment, llvm_module.get(), 664 std::move(instruction_to_profile_idx), 665 std::move(computation_to_profile_idx), 666 jit->target_machine(), jit->external_constant_pool()); 667 668 for (auto embedded_computation : 669 entry_computation->MakeEmbeddedComputationsList()) { 670 if (embedded_computation->IsFusionComputation()) { 671 continue; 672 } 673 TF_RETURN_IF_ERROR( 674 ir_emitter 675 .EmitComputation(embedded_computation, 676 embedded_computation->name(), 677 /*is_top_level_computation=*/false, 678 &module_sequence.at(embedded_computation)) 679 .status()); 680 } 681 string function_name_prefix = entry_computation->name().empty() 682 ? "__compute" 683 : entry_computation->name(); 684 TF_ASSIGN_OR_RETURN( 685 llvm::Function * entry_function, 686 ir_emitter.EmitComputation(entry_computation, function_name_prefix, 687 /*is_top_level_computation=*/true, 688 &module_sequence.at(entry_computation))); 689 690 string function_name = llvm_ir::AsString(entry_function->getName()); 691 string ir_module_string; 692 if (embed_ir_in_executable) { 693 ir_module_string = llvm_ir::DumpModuleToString(*llvm_module); 694 } 695 TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module)); 696 697 XLA_VLOG_LINES(2, "LLVM IR:\n" + llvm_ir::DumpModuleToString(*llvm_module)); 698 699 // JIT compile the LLVM IR module to in-memory machine code. 700 jit->AddModule(std::move(llvm_module)); 701 cpu_executable.reset(new CpuExecutable( 702 std::move(jit), std::move(assignment), std::move(module), function_name, 703 std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map))); 704 705 if (embed_ir_in_executable) { 706 static_cast<CpuExecutable&>(*cpu_executable) 707 .set_ir_module_string(ir_module_string); 708 } 709 } 710 711 VLOG(1) << "Compilation finished"; 712 return std::move(cpu_executable); 713 } 714 715 StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>> 716 CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules, 717 const AotCompilationOptions& aot_options) { 718 TF_RET_CHECK(!modules.empty()); 719 std::call_once(llvm_command_line_options_initialized, 720 &llvm_ir::InitializeLLVMCommandLineOptions, 721 modules[0]->config()); 722 723 // We can pass just one llvm::TargetOptions when we compile the LLVM module, 724 // so we bail if the configs have conflicting flags. At the moment, the only 725 // flag that needs to be consistent is fast-math. 726 const bool fast_math_enabled = 727 modules[0]->config().debug_options().xla_enable_fast_math(); 728 for (const auto& module : modules) { 729 if (module->config().debug_options().xla_enable_fast_math() != 730 fast_math_enabled) { 731 return InvalidArgument( 732 "All HLO module configs must have the same value for " 733 "xla_enable_fast_math."); 734 } 735 } 736 737 if (aot_options.PlatformId() != se::host::kHostPlatformId) { 738 return InvalidArgument("Incompatible AOT compilation platform"); 739 } 740 const CpuAotCompilationOptions& options = 741 static_cast<const CpuAotCompilationOptions&>(aot_options); 742 llvm::StringRef target_triple = llvm_ir::AsStringRef(options.triple()); 743 llvm::Triple triple(llvm::Triple::normalize(target_triple)); 744 std::string error; 745 const llvm::Target* target = 746 llvm::TargetRegistry::lookupTarget(triple.getTriple(), error); 747 if (target == nullptr) { 748 return InternalError("TargetRegistry::lookupTarget failed: %s", 749 error.c_str()); 750 } 751 752 llvm::Reloc::Model reloc_model = llvm::Reloc::Static; 753 llvm::PICLevel::Level pic_level = llvm::PICLevel::NotPIC; 754 llvm::PIELevel::Level pie_level = llvm::PIELevel::Default; 755 switch (options.relocation_model()) { 756 case CpuAotCompilationOptions::RelocationModel::Static: 757 reloc_model = llvm::Reloc::Static; 758 pic_level = llvm::PICLevel::NotPIC; 759 pie_level = llvm::PIELevel::Default; 760 break; 761 case CpuAotCompilationOptions::RelocationModel::SmallPic: 762 reloc_model = llvm::Reloc::PIC_; 763 pic_level = llvm::PICLevel::SmallPIC; 764 pie_level = llvm::PIELevel::Default; 765 break; 766 case CpuAotCompilationOptions::RelocationModel::BigPic: 767 reloc_model = llvm::Reloc::PIC_; 768 pic_level = llvm::PICLevel::BigPIC; 769 pie_level = llvm::PIELevel::Default; 770 break; 771 case CpuAotCompilationOptions::RelocationModel::SmallPie: 772 reloc_model = llvm::Reloc::PIC_; 773 pic_level = llvm::PICLevel::SmallPIC; 774 pie_level = llvm::PIELevel::Small; 775 break; 776 case CpuAotCompilationOptions::RelocationModel::BigPie: 777 reloc_model = llvm::Reloc::PIC_; 778 pic_level = llvm::PICLevel::BigPIC; 779 pie_level = llvm::PIELevel::Large; 780 break; 781 } 782 llvm::StringRef cpu_name = llvm_ir::AsStringRef(options.cpu_name()); 783 llvm::StringRef features = llvm_ir::AsStringRef(options.features()); 784 llvm::CodeGenOpt::Level opt_level = CodeGenOptLevel(modules[0]->config()); 785 std::unique_ptr<llvm::TargetMachine> target_machine = WrapUnique( 786 target->createTargetMachine(triple.getTriple(), cpu_name, features, 787 CompilerTargetOptions(modules[0]->config()), 788 reloc_model, llvm::None, opt_level)); 789 790 // Compile must be thread-safe so create a new LLVM context for the module. 791 llvm::LLVMContext llvm_context; 792 llvm::Module llvm_module("__compute_module", llvm_context); 793 llvm_module.setDataLayout(target_machine->createDataLayout()); 794 llvm_module.setTargetTriple(triple.getTriple()); 795 if (pic_level != llvm::PICLevel::NotPIC) { 796 llvm_module.setPICLevel(pic_level); 797 } 798 if (pie_level != llvm::PIELevel::Default) { 799 llvm_module.setPIELevel(pie_level); 800 } 801 802 std::vector<std::unique_ptr<AotCompilationResult>> results; 803 for (size_t i = 0; i < modules.size(); ++i) { 804 HloModule* module = modules[i].get(); 805 VLOG(1) << "Compiling ahead-of-time: " << module->name(); 806 807 VLOG(2) << "Before optimization:"; 808 XLA_VLOG_LINES(2, module->ToString()); 809 810 TF_RETURN_IF_ERROR(RunHloPasses(module, /*is_aot_compile=*/true)); 811 812 VLOG(2) << "After optimization:"; 813 XLA_VLOG_LINES(2, module->ToString()); 814 815 TF_ASSIGN_OR_RETURN( 816 SequentialHloOrdering::HloModuleSequence module_sequence, 817 CreateMemoryMinimizingSequence(*module, BufferSizeBytesFunction())); 818 819 // Run buffer analysis on the HLO graph. This analysis figures out which 820 // temporary buffers are required to run the computation. 821 TF_ASSIGN_OR_RETURN( 822 std::unique_ptr<BufferAssignment> assignment, 823 BufferAssigner::Run( 824 module, 825 xla::MakeUnique<SequentialHloOrdering>(module, module_sequence), 826 BufferSizeBytesFunction(), memory_alignment)); 827 // BufferAssignment::ToString() includes a header, so no need for us to 828 // print one ourselves. 829 XLA_VLOG_LINES(2, assignment->ToString()); 830 831 const string xla_dump_optimized_hlo_proto_to = 832 module->config().debug_options().xla_dump_optimized_hlo_proto_to(); 833 if (!xla_dump_optimized_hlo_proto_to.empty()) { 834 HloProto proto = MakeHloProto(*module, *assignment); 835 TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory( 836 proto, xla_dump_optimized_hlo_proto_to, module->name())); 837 } 838 839 IrEmitter ir_emitter(*module, *assignment, &llvm_module, 840 /*instruction_to_profile_idx=*/ 841 std::unordered_map<const HloInstruction*, int64>{}, 842 /*computation_to_profile_idx=*/ 843 std::unordered_map<const HloComputation*, int64>{}, 844 target_machine.get(), 845 /*external_constant_pool=*/nullptr); 846 HloComputation* computation = module->entry_computation(); 847 for (auto embedded_computation : 848 computation->MakeEmbeddedComputationsList()) { 849 if (embedded_computation->IsFusionComputation()) { 850 continue; 851 } 852 TF_RETURN_IF_ERROR( 853 ir_emitter 854 .EmitComputation(embedded_computation, 855 embedded_computation->name(), 856 /*is_top_level_computation=*/false, 857 &module_sequence.at(embedded_computation)) 858 .status()); 859 } 860 const string& entry_point_name = options.entry_point_name(); 861 TF_ASSIGN_OR_RETURN( 862 llvm::Function * entry_function, 863 ir_emitter.EmitComputation(computation, entry_point_name, 864 /*is_top_level_computation=*/true, 865 &module_sequence.at(computation))); 866 867 CHECK(entry_function->getName() == llvm_ir::AsStringRef(entry_point_name)); 868 869 ModuleHook pre_optimization_ir_dump_hook; 870 ModuleHook post_optimization_ir_dump_hook; 871 TF_RETURN_IF_ERROR(InitializeModuleHooks( 872 *module, user_pre_optimization_hook_, user_post_optimization_hook_, 873 &pre_optimization_ir_dump_hook, &post_optimization_ir_dump_hook)); 874 875 // Run the LLVM verifier over the unoptimized LLVM IR. If it fails, run the 876 // pre-optimization IR dump hook before returning. 877 { 878 Status verify_status = VerifyLlvmModule(llvm_module); 879 if (!verify_status.ok() && pre_optimization_ir_dump_hook) { 880 pre_optimization_ir_dump_hook(llvm_module).IgnoreError(); 881 } 882 TF_RETURN_IF_ERROR(verify_status); 883 } 884 885 Disassembler disassembler(*target_machine); 886 CompilerFunctor compiler_functor( 887 target_machine.get(), &disassembler, opt_level, 888 options::OptimizeForSizeRequested(module->config()), 889 module->config().debug_options().xla_enable_fast_math(), 890 module->config().debug_options().xla_llvm_disable_expensive_passes(), 891 pre_optimization_ir_dump_hook, post_optimization_ir_dump_hook); 892 llvm::object::OwningBinary<llvm::object::ObjectFile> object_file = 893 compiler_functor(llvm_module); 894 llvm::StringRef object_file_data_ref = object_file.getBinary()->getData(); 895 ObjectFileData object_file_data(object_file_data_ref.begin(), 896 object_file_data_ref.end()); 897 898 BufferSizes buffer_sizes; 899 for (const BufferAllocation& allocation : assignment->Allocations()) { 900 // Callers don't need to allocate temporary buffers for parameters. 901 if (allocation.is_entry_computation_parameter()) { 902 buffer_sizes.push_back(-1); 903 continue; 904 } 905 // Callers don't need to allocate anything for thread-local temporary 906 // buffers. They are lowered to allocas. 907 if (allocation.is_thread_local()) { 908 buffer_sizes.push_back(-1); 909 continue; 910 } 911 buffer_sizes.push_back(allocation.size()); 912 } 913 914 TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, 915 assignment->GetUniqueTopLevelOutputSlice()); 916 917 results.emplace_back(MakeUnique<CpuAotCompilationResult>( 918 std::move(object_file_data), std::move(buffer_sizes), 919 result_slice.index())); 920 } 921 922 VLOG(1) << "Compilation finished"; 923 return std::move(results); 924 } 925 926 se::Platform::Id CpuCompiler::PlatformId() const { 927 return se::host::kHostPlatformId; 928 } 929 930 HloCostAnalysis::ShapeSizeFunction CpuCompiler::ShapeSizeBytesFunction() const { 931 return CpuExecutable::ShapeSizeBytes; 932 } 933 934 } // namespace cpu 935 } // namespace xla 936 937 static bool InitModule() { 938 xla::Compiler::RegisterCompilerFactory(se::host::kHostPlatformId, []() { 939 return xla::MakeUnique<xla::cpu::CpuCompiler>(); 940 }); 941 return true; 942 } 943 static bool module_initialized = InitModule(); 944