/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h"

#include <stdlib.h>
#include <atomic>
#include <functional>
#include <mutex>  // NOLINT(build/c++11): only using std::call_once, not mutex.
#include <utility>

#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/DiagnosticPrinter.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/batchnorm_expander.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/dot_decomposer.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h"
#include "tensorflow/compiler/xla/service/gpu/fusion_merger.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
#include "tensorflow/compiler/xla/service/gpu/pad_insertion.h"
#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
#include "tensorflow/compiler/xla/service/hlo_cse.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_element_type_converter.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
"tensorflow/compiler/xla/service/hlo_pass_fix.h" 65 #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" 66 #include "tensorflow/compiler/xla/service/hlo_proto_util.h" 67 #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" 68 #include "tensorflow/compiler/xla/service/hlo_verifier.h" 69 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" 70 #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" 71 #include "tensorflow/compiler/xla/service/reshape_mover.h" 72 #include "tensorflow/compiler/xla/service/transpose_folding.h" 73 #include "tensorflow/compiler/xla/service/tuple_simplifier.h" 74 #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" 75 #include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h" 76 #include "tensorflow/compiler/xla/status_macros.h" 77 #include "tensorflow/compiler/xla/types.h" 78 #include "tensorflow/compiler/xla/util.h" 79 #include "tensorflow/core/lib/core/status.h" 80 #include "tensorflow/core/lib/gtl/cleanup.h" 81 #include "tensorflow/core/lib/io/path.h" 82 #include "tensorflow/core/lib/strings/strcat.h" 83 #include "tensorflow/core/platform/cuda_libdevice_path.h" 84 #include "tensorflow/core/platform/env.h" 85 #include "tensorflow/core/platform/logging.h" 86 #include "tensorflow/core/platform/regexp.h" 87 #include "tensorflow/core/platform/stream_executor_no_cuda.h" 88 #include "tensorflow/core/platform/subprocess.h" 89 #include "tensorflow/core/platform/tracing.h" 90 #include "tensorflow/stream_executor/cuda/cuda_diagnostics.h" 91 92 namespace se = ::perftools::gputools; 93 94 namespace xla { 95 namespace gpu { 96 97 /* static */ const char* GpuCompiler::kTargetTriple = "nvptx64-nvidia-cuda"; 98 /* static */ const char* GpuCompiler::kDataLayout = 99 "e-i64:64-i128:128-v16:16-v32:32-n16:32:64"; 100 101 namespace { 102 103 using tensorflow::port::Tracing; 104 105 // Returns the directory containing nvvm libdevice files. config_cuda_data_dir 106 // should be equal to config().debug_options().xla_gpu_cuda_data_dir() of the 107 // HloModule being compiled. 108 string GetLibdeviceDir(const string& config_cuda_data_dir) { 109 std::vector<string> potential_libdevice_dirs; 110 if (!config_cuda_data_dir.empty()) { 111 potential_libdevice_dirs.push_back(config_cuda_data_dir); 112 } 113 potential_libdevice_dirs.push_back(tensorflow::LibdeviceRoot()); 114 115 // Tries all potential libdevice directories in the order they are inserted. 116 // Returns the first directory that exists in the file system. 117 for (const string& potential_libdevice_dir : potential_libdevice_dirs) { 118 if (tensorflow::Env::Default()->IsDirectory(potential_libdevice_dir).ok()) { 119 VLOG(2) << "Found libdevice dir " << potential_libdevice_dir; 120 return potential_libdevice_dir; 121 } 122 VLOG(2) << "Unable to find potential libdevice dir " 123 << potential_libdevice_dir; 124 } 125 126 // Last resort: maybe in the current folder. 127 return "."; 128 } 129 130 // Runs optimization passes on the given HLO module. 
tensorflow::Status OptimizeHloModule(HloModule* hlo_module,
                                     se::StreamExecutor* stream_exec,
                                     DeviceMemoryAllocator* device_allocator) {
  {
    HloPassPipeline pipeline("optimization");
    pipeline.AddInvariantChecker<HloVerifier>();
    pipeline.AddPass<GpuHloSupportChecker>();
    ReducePrecisionInsertion::AddPasses(
        &pipeline, hlo_module->config().debug_options(),
        ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION);

    // TODO(b/64094172): make Call work on GPU instead of inlining.
    pipeline.AddPass<CallInliner>();
    // Convert BF16 operations to F32 operations so that the GPU backend can
    // support BF16 operations without directly implementing a BF16 lowering
    // for most ops.
    pipeline.AddPass<HloElementTypeConverter>(BF16, F32);
    pipeline.AddPass<DotDecomposer>();

    {
      auto& pass =
          pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification");
      pass.AddInvariantChecker<HloVerifier>();

      // If cudnn batchnorms are enabled, rewrite batchnorm HLOs to cudnn calls
      // where possible. Not every batchnorm op can be implemented as a call to
      // cudnn, so decompose any remaining batchnorm ops into a soup of HLOs.
      if (hlo_module->config().debug_options().xla_gpu_use_cudnn_batchnorm()) {
        pass.AddPass<CudnnBatchNormRewriter>();
      }
      pass.AddPass<BatchNormExpander>(
          /*rewrite_training_op=*/true,
          /*rewrite_inference_op=*/true,
          /*rewrite_grad_op=*/true,
          /*use_fusion=*/false);

      // BatchNormExpander can create zero-sized ops, so zero-sized HLO
      // elimination has to come after that pass.
      pipeline.AddPass<ZeroSizedHloElimination>();

      pass.AddPass<AlgebraicSimplifier>(
          /*is_layout_sensitive=*/false,
          [](const Shape&, const Shape&) { return false; });
      pass.AddPass<TupleSimplifier>();
      pass.AddPass<WhileLoopSimplifier>();
      pass.AddPass<HloDCE>();
      pass.AddPass<ReshapeMover>();
      pass.AddPass<HloConstantFolding>();
    }

    pipeline.AddPass<TransposeFolding>(
        [](const HloInstruction& dot,
           const TransposeFolding::OperandIndices& candidate_operands) {
          return ImplementedAsGemm(dot) ? candidate_operands
                                        : TransposeFolding::OperandIndices{};
        },
        TransposeFolding::NeverFoldTranspose);
    pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/false);
    pipeline.AddPass<HloDCE>();
    TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
  }

  {
    // Convert convolutions into CustomCalls to cudnn, then canonicalize them
    // (PadInsertion).
    HloPassPipeline pipeline("conv_canonicalization");
    pipeline.AddInvariantChecker<HloVerifier>();
    pipeline.AddPass<CudnnConvolutionRewriter>();
    pipeline.AddPass<PadInsertion>();

    // Choose the fastest algorithm for each conv.
    //
    // In theory doing this here is way too early: It needs to happen after
    // layout assignment, because the layout of the inputs/outputs affects the
    // speed of the conv.  But currently we only allow one input/output layout
    // when calling cudnn, so there's no ambiguity.
    //
    // We pick the algorithm at this early stage so we can generate better HLO.
    // After CudnnConvolutionRewriter, our convolutions are CustomCalls which
    // return a tuple (conv_result, scratch_memory), and each conv uses 0 bytes
    // of scratch:
    //
    //   customcall = (f32[...], f32[0])
    //   return gte(customcall, 0)
    //
    // The algorithm picker then chooses the best algorithm, and potentially
    // increases the scratch space.  It replaces customcall with new_tuple,
    // giving us the following:
    //
    //   new_customcall = (f32[...], f32[N])
    //   new_tuple = tuple(gte(new_customcall, 0), constant f32[0])
    //   return gte(new_tuple, 0)
    //
    // The new tuple and gte instructions can then be simplified away, because
    // nobody is expected to use the scratch value.
    //
    // However, if we were to run CudnnConvolutionAlgorithmPicker after layout
    // assignment, fusion would already have run, and the gte(customcall, 0)
    // would probably already be folded into a fusion node.  We can't simplify
    // across HloComputation boundaries, so in this case we wouldn't be able to
    // simplify away the new_tuple bits.
    //
    // We'll need to revisit this if we ever allow multiple layouts for the
    // inputs/outputs of a cudnn convolution.
    pipeline.AddPass<CudnnConvolutionAlgorithmPicker>(stream_exec,
                                                      device_allocator);
    // Clean up new_tuple described above.
    pipeline.AddPass<TupleSimplifier>();
    pipeline.AddPass<HloDCE>();

    TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
  }

  {
    HloPassFix<HloPassPipeline> fusion("fusion");
    fusion.AddInvariantChecker<HloVerifier>();
    fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/false);
    fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/true);
    fusion.AddPass<FusionMerger>();
    TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status());

    HloPassPipeline reduce_pipeline("reduce-precision");
    reduce_pipeline.AddInvariantChecker<HloVerifier>();
    ReducePrecisionInsertion::AddPasses(
        &reduce_pipeline, hlo_module->config().debug_options(),
        ReducePrecisionInsertion::PassTiming::AFTER_FUSION);
    StatusOr<bool> reduce_result = reduce_pipeline.Run(hlo_module);
    TF_RETURN_IF_ERROR(reduce_result.status());

    if (reduce_result.ValueOrDie()) {
      // Do another fusion pass, with the expectation that we may be able to
      // fuse the new ReducePrecision operations.
      TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status());
    }
  }
  return tensorflow::Status::OK();
}

// Modifies the given HLO module so that it will be accepted by IrEmitter.
// Unlike optimization passes, these passes are necessary for correctness.
tensorflow::Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) {
  // In some cases, we have to place the result of an instruction in a
  // temporary buffer.  For instance, the buffer that holds an external
  // parameter is assumed immutable at this point, and should not be reused
  // for output (b/27180329).  Therefore, in that case, we set the output to
  // be a copy of the parameter.
  HloPassPipeline pipeline("GPU-ir-emit-prepare");
  pipeline.AddInvariantChecker<HloVerifier>();

  pipeline.AddPass<GpuLayoutAssignment>(
      hlo_module->mutable_entry_computation_layout());

  // The LayoutAssignment pass may leave behind kCopy instructions which are
  // duplicate or NOPs, so remove them with algebraic simplification and CSE.
  pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(
      /*is_layout_sensitive=*/true,
      [](const Shape&, const Shape&) { return true; });
  pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/true);
  // Copy insertion should be performed immediately before IR emission to
  // avoid inserting unnecessary copies (later pass adds an instruction which
  // materializes the value) or missing a necessary copy (later pass removes
  // an instruction which materializes a value).  DCE must be run immediately
  // before (and sometimes after) copy insertion, to avoid dead code from
  // interfering with the rewrites.
  pipeline.AddPass<HloDCE>();
  pipeline.AddPass<FlattenCallGraph>();
  pipeline.AddPass<GpuCopyInsertion>();
  return pipeline.Run(hlo_module).status();
}

// Prints a warning if the ptxas at ptxas_path has known bugs.
//
// Only prints a warning the first time it's called for a particular value of
// ptxas_path.
void WarnIfBadPtxasVersion(const string& ptxas_path) {
  static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
  static std::unordered_set<string>* seen_ptxas_paths GUARDED_BY(mu) =
      new std::unordered_set<string>();

  tensorflow::mutex_lock lock(mu);
  if (!seen_ptxas_paths->insert(ptxas_path).second) {
    // Already checked this ptxas binary, nothing to do.
    return;
  }

  tensorflow::SubProcess ptxas;
  ptxas.SetProgram(ptxas_path, {ptxas_path, "--version"});
  ptxas.SetChannelAction(tensorflow::CHAN_STDOUT, tensorflow::ACTION_PIPE);
  if (!ptxas.Start()) {
    LOG(WARNING) << "Couldn't invoke " << ptxas_path << " --version";
    return;
  }

  string out;
  int exit_code = ptxas.Communicate(/*stdin_input=*/nullptr, &out,
                                    /*stderr_output=*/nullptr);
  if (exit_code != 0) {
    LOG(WARNING) << "Running " << ptxas_path << " --version returned "
                 << exit_code;
    return;
  }

  int64 vmaj, vmin, vdot;
  string vmaj_str, vmin_str, vdot_str;
  if (!RE2::PartialMatch(out, R"(\bV(\d+)\.(\d+)\.(\d+)\b)", &vmaj_str,
                         &vmin_str, &vdot_str) ||
      !tensorflow::strings::safe_strto64(vmaj_str, &vmaj) ||
      !tensorflow::strings::safe_strto64(vmin_str, &vmin) ||
      !tensorflow::strings::safe_strto64(vdot_str, &vdot)) {
    LOG(WARNING) << "Couldn't parse ptxas version in output of " << ptxas_path
                 << " --version:\n"
                 << out;
    return;
  }

  // ptxas 9.0 before 9.0.276 and ptxas 9.1 before 9.1.121 miscompile some
  // address calculations with large offsets (e.g. "load ptr + large_constant"),
  // b/70245379.
  if ((vmaj == 9 && vmin == 0 && vdot < 276) ||
      (vmaj == 9 && vmin == 1 && vdot < 121)) {
    LOG(WARNING) << "*** WARNING *** You are using ptxas " << vmaj << "."
                 << vmin << "." << vdot
                 << ", which is in range [9.0.0, 9.0.276) + [9.1.0, 9.1.121). "
                    "These versions are known to miscompile XLA code, leading "
                    "to incorrect results or invalid-address errors.";
  }
}

// Prints a warning if the ptx->sass JIT in the driver has known bugs.
//
// Using such a driver is only a problem if we fail to use ptxas to compile
// our ptx and have to use the driver instead, so you should only call this
// function if we're going to use the driver JIT.
//
// Only prints a warning the first time it's called.
void WarnIfBadDriverJITVersion() {
  static std::once_flag run_once;
  std::call_once(run_once, [] {
    auto version_or_status = se::cuda::Diagnostician::FindKernelDriverVersion();
    if (!version_or_status.ok()) {
      LOG(WARNING) << "Couldn't read CUDA driver version.";
      return;
    }
    se::cuda::DriverVersion version = version_or_status.ValueOrDie();

    // The following versions of the driver JIT miscompile some address
    // calculations with large offsets (e.g. "load ptr + large_constant"),
    // b/70245379:
    //
    //  - 384.x before 384.108
    //  - 387.x before 387.40
    //  - 390.x before 390.10.
    auto vmaj = std::get<0>(version);
    auto vmin = std::get<1>(version);
    if ((vmaj == 384 && vmin < 108) ||  //
        (vmaj == 387 && vmin < 40) ||   //
        (vmaj == 390 && vmin < 10)) {
      LOG(WARNING)
          << "*** WARNING *** Invoking the PTX->SASS JIT from driver version "
          << se::cuda::DriverVersionToString(version)
          << ", which is in range [384.0.0, 384.108.0) + [387.0.0, 387.40.0) + "
             "[390.0.0, 390.10.0). These versions are known to miscompile XLA "
             "code, leading to incorrect results or invalid-address errors.";
    }
  });
}

// Compiles the given PTX string using ptxas and returns the resulting machine
// code (i.e. a cubin) as a byte array.
StatusOr<std::vector<uint8>> CompilePtx(const string& ptx, int cc_major,
                                        int cc_minor) {
  Tracing::TraceMe annotation("Compile PTX", /*is_expensive=*/true);
  const string ptxas_path =
      tensorflow::io::JoinPath(tensorflow::CudaRoot(), "bin", "ptxas");
  VLOG(2) << "Using ptxas at " << ptxas_path;
  auto env = tensorflow::Env::Default();
  TF_RETURN_IF_ERROR(env->FileExists(ptxas_path));

  WarnIfBadPtxasVersion(ptxas_path);

  // Write ptx into a temporary file.
  string ptx_path;
  if (!env->LocalTempFilename(&ptx_path)) {
    return InternalError("couldn't get temp PTX file name");
  }
  auto ptx_cleaner = tensorflow::gtl::MakeCleanup([&ptx_path] {
    TF_CHECK_OK(tensorflow::Env::Default()->DeleteFile(ptx_path));
  });

  TF_RETURN_IF_ERROR(tensorflow::WriteStringToFile(env, ptx_path, ptx));
  VLOG(2) << "ptx written to: " << ptx_path;

  // Invoke ptxas and collect its output.
  string cubin_path;
  if (!env->LocalTempFilename(&cubin_path)) {
    return InternalError("couldn't get temp CUBIN file name");
  }
  auto cubin_cleaner = tensorflow::gtl::MakeCleanup([&cubin_path] {
    // The CUBIN file may never be created, so the failure to delete it should
    // not produce a TF error.
    tensorflow::Env::Default()->DeleteFile(cubin_path).IgnoreError();
  });
  tensorflow::SubProcess ptxas_info_dumper;
  std::vector<string> ptxas_args = {
      ptxas_path, ptx_path, "-o", cubin_path,
      tensorflow::strings::StrCat("-arch=sm_", cc_major, cc_minor)};
  if (VLOG_IS_ON(2)) {
    ptxas_args.push_back("-v");
  }
  ptxas_info_dumper.SetProgram(ptxas_path, ptxas_args);
  ptxas_info_dumper.SetChannelAction(tensorflow::CHAN_STDERR,
                                     tensorflow::ACTION_PIPE);
  if (!ptxas_info_dumper.Start()) {
    return InternalError("Failed to launch ptxas");
  }
  string stderr_output;
  int exit_status = ptxas_info_dumper.Communicate(
      /*stdin_input=*/nullptr, /*stdout_output=*/nullptr, &stderr_output);
  XLA_LOG_LINES(tensorflow::INFO, stderr_output);
  if (exit_status != 0) {
    return InternalError("ptxas exited with non-zero error code %d",
                         exit_status);
  }

  // Read in the result of compilation and return it as a byte vector.
  string cubin;
  TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(tensorflow::Env::Default(),
                                                  cubin_path, &cubin));
  std::vector<uint8> cubin_vector(cubin.begin(), cubin.end());
  return cubin_vector;
}

}  // namespace

GpuCompiler::GpuCompiler()
    : pointer_size_(llvm::DataLayout(kDataLayout)
                        .getPointerSize(0 /* default address space */)) {}

StatusOr<std::unique_ptr<HloModule>> GpuCompiler::RunHloPasses(
    std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
    DeviceMemoryAllocator* device_allocator) {
  XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunHloPasses");
  Tracing::TraceMe annotation("HLO Transforms", module->name(),
                              /*is_expensive=*/true);
  TF_RETURN_IF_ERROR(
      OptimizeHloModule(module.get(), stream_exec, device_allocator));
  return std::move(module);
}

StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
    std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
    DeviceMemoryAllocator* device_allocator) {
  XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend");

  TF_RET_CHECK(stream_exec != nullptr);

  TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting(module.get()));

  llvm::LLVMContext llvm_context;
  std::string buffer;
  llvm::raw_string_ostream error(buffer);
  llvm::DiagnosticPrinterRawOStream printer(error);
  auto DiagnosticHandler = [](const llvm::DiagnosticInfo& diag_info,
                              void* Context) {
    auto printer = static_cast<llvm::DiagnosticPrinterRawOStream*>(Context);
    diag_info.print(*printer);
  };
  llvm_context.setDiagnosticHandlerCallBack(DiagnosticHandler, &printer);

  llvm::Module llvm_module(module->name().c_str(), llvm_context);
  // Set the target triple and the data layout.
  llvm_module.setTargetTriple(kTargetTriple);
  llvm_module.setDataLayout(kDataLayout);

  // Determine the HLO schedule, which is an ordering of HLO instructions.
  // This is used by buffer assignment to enable buffer reuse, and the same
  // ordering must also be used to determine the thunk launch schedule.
  std::unique_ptr<StreamAssignment> stream_assignment = AssignStreams(*module);
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<HloSchedule> hlo_schedule,
      HloSchedule::Build(*module, *stream_assignment, pointer_size_));

  // Run buffer analysis on the HLO graph. This analysis figures out which
  // temporary buffers are required to run the computation.
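  // The ordering consumed from the HLO schedule lets the assigner reuse
  // buffers whose live ranges do not overlap, and every buffer color is
  // aligned to kCudaMallocAlignBytes via the color_alignment callback below.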
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<BufferAssignment> buffer_assignment,
      BufferAssigner::Run(module.get(), hlo_schedule->ConsumeHloOrdering(),
                          BufferSizeBytesFunction(),
                          /*color_alignment=*/[](LogicalBuffer::Color) {
                            return kCudaMallocAlignBytes;
                          }));
  // BufferAssignment::Stats::ToString() and BufferAssignment::ToString()
  // include headers, so no need for us to print them ourselves.
  XLA_VLOG_LINES(1, buffer_assignment->GetStats().ToString());
  XLA_VLOG_LINES(2, buffer_assignment->ToString());
  XLA_VLOG_LINES(2, module->ToString());
  const string xla_dump_optimized_hlo_proto_to =
      module->config().debug_options().xla_dump_optimized_hlo_proto_to();
  if (!xla_dump_optimized_hlo_proto_to.empty()) {
    HloProto proto = MakeHloProto(*module, *buffer_assignment);
    TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory(
        proto, xla_dump_optimized_hlo_proto_to, module->name()));
  }

  IrEmitterContext ir_emitter_context(module.get(), buffer_assignment.get(),
                                      &stream_exec->GetDeviceDescription(),
                                      &llvm_module);

  HloComputation* entry_computation = module->entry_computation();
  IrEmitterUnnested ir_emitter(module->config(), entry_computation,
                               &ir_emitter_context);
  {
    XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - IR emission");
    TF_RETURN_IF_ERROR(
        entry_computation->root_instruction()->Accept(&ir_emitter));
  }

  if (user_pre_optimization_hook_) {
    TF_CHECK_OK(user_pre_optimization_hook_(llvm_module));
  }
  string ir_module_string_before_opt;
  const bool embed_ir_in_executable =
      module->config().debug_options().xla_embed_ir_in_executable();
  if (VLOG_IS_ON(2) || embed_ir_in_executable) {
    ir_module_string_before_opt = llvm_ir::DumpModuleToString(llvm_module);
    VLOG(2) << "LLVM module before optimizations:";
    XLA_VLOG_LINES(2, ir_module_string_before_opt);
  }

  const string& ir_dump_directory =
      module->config().debug_options().xla_dump_ir_to();

  if (!ir_dump_directory.empty()) {
    TF_RETURN_IF_ERROR(llvm_ir::DumpIRToDirectory(
        /*directory_name=*/ir_dump_directory,
        /*hlo_module_name=*/module->name(), llvm_module,
        /*optimized=*/false));
  }

  {
    XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - Running LLVM verifier");

    std::string err;
    llvm::raw_string_ostream err_stream(err);

    // verifyModule() returns true if the module is broken.
    TF_RET_CHECK(!llvm::verifyModule(llvm_module, &err_stream))
        << "Invalid LLVM IR before optimizations:\n"
        << err_stream.str()
        << "\nThis probably indicates a bug in the HLO -> LLVM IR lowering. "
           "Rerun with --xla_dump_ir_to to get the IR. ";
  }

  string libdevice_dir;
  {
    tensorflow::mutex_lock lock(mutex_);

    // Find the directory containing libdevice.  To avoid searching for it
    // every time, we have a one-element cache, keyed on the module's config's
    // cuda_data_dir.
    const auto& config_cuda_data_dir =
        module->config().debug_options().xla_gpu_cuda_data_dir();
    if (cached_libdevice_dir_.empty() ||
        cached_cuda_data_dir_ != config_cuda_data_dir) {
      cached_cuda_data_dir_ = config_cuda_data_dir;
      cached_libdevice_dir_ = GetLibdeviceDir(config_cuda_data_dir);
    }
    libdevice_dir = cached_libdevice_dir_;
  }
  int cc_major, cc_minor;
  if (!stream_exec->GetDeviceDescription().cuda_compute_capability(&cc_major,
                                                                   &cc_minor)) {
    LOG(WARNING)
        << "Couldn't get compute capability for device; assuming sm_20.";
    cc_major = 2;
    cc_minor = 0;
  }

  string ptx;
  {
    XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - CompileToPtx");
    TF_ASSIGN_OR_RETURN(ptx, CompileToPtx(&llvm_module, {cc_major, cc_minor},
                                          module->config(), libdevice_dir));
  }

  if (!ir_dump_directory.empty()) {
    TF_RETURN_IF_ERROR(llvm_ir::DumpIRToDirectory(
        /*directory_name=*/ir_dump_directory,
        /*hlo_module_name=*/module->name(), llvm_module,
        /*optimized=*/true));
  }

  if (user_post_optimization_hook_) {
    TF_CHECK_OK(user_post_optimization_hook_(llvm_module));
  }
  VLOG(2) << "LLVM module after optimizations:";
  XLA_VLOG_LINES(2, llvm_ir::DumpModuleToString(llvm_module));
  VLOG(2) << "PTX:";
  XLA_VLOG_LINES(2, ptx);

  // Write PTX to the IR dump directory, if IR dumping was requested.
  if (!ir_dump_directory.empty()) {
    const string ptx_outfile = tensorflow::io::JoinPath(
        ir_dump_directory, tensorflow::strings::StrCat(module->name(), ".ptx"));
    auto status = [&] {
      auto* env = tensorflow::Env::Default();
      TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(ir_dump_directory));
      TF_RETURN_IF_ERROR(tensorflow::WriteStringToFile(env, ptx_outfile, ptx));
      return Status::OK();
    }();
    if (!status.ok()) {
      LOG(WARNING) << "Couldn't dump PTX for module " << module->name()
                   << " to " << ptx_outfile << ": " << status;
    }
  }

  const std::vector<uint8> cubin =
      CompilePtxOrGetCachedResult(ptx, cc_major, cc_minor);

  auto thunk_schedule = MakeUnique<ThunkSchedule>(
      ir_emitter.ConsumeThunkSequence(), std::move(stream_assignment),
      hlo_schedule->ThunkLaunchOrder());
  VLOG(2) << "Printing the thunk schedule...";
  XLA_VLOG_LINES(2, thunk_schedule->ToString());

  std::unique_ptr<HloProfileIndexMap> profile_index_map;
  std::unique_ptr<HloProfilePrinterData> profile_printer;

  if (module->config().hlo_profiling_enabled()) {
    HloCostAnalysis cost_analysis(ShapeSizeBytesFunction());
    TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&cost_analysis));
    profile_index_map = MakeUnique<HloProfileIndexMap>(*module);
    profile_printer =
        CreateHloProfilePrinterData(*profile_index_map, cost_analysis);
  }

  auto* gpu_executable = new GpuExecutable(
      ptx, cubin, {cc_major, cc_minor}, std::move(thunk_schedule),
      std::move(module), std::move(buffer_assignment),
      std::move(profile_printer), std::move(profile_index_map));
  if (embed_ir_in_executable) {
    DCHECK_NE("", ir_module_string_before_opt);
    gpu_executable->set_ir_module_string(ir_module_string_before_opt);
  }
  return std::unique_ptr<Executable>(gpu_executable);
}

std::vector<uint8> GpuCompiler::CompilePtxOrGetCachedResult(const string& ptx,
                                                            int cc_major,
                                                            int cc_minor) {
  XLA_SCOPED_LOGGING_TIMER("GpuCompiler::CompilePtxOrGetCachedResult");
  Tracing::TraceMe annotation("PTX->CUBIN", /*is_expensive=*/true);

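  // Locking scheme: mutex_ only guards the lookup/insertion into
  // compilation_cache_ below; each cache entry's own mutex then serializes
  // the (potentially slow) ptxas invocation for that key, so callers with
  // different <ptx, compute capability> keys can compile concurrently.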
  bool inserted;
  decltype(compilation_cache_.begin()) iter;
  // Pointers into compilation_cache_ where the ptx and (optional) cubin are
  // stored.
  const string* cache_ptx = nullptr;
  CompilationCacheValue* cache_value = nullptr;

  {
    tensorflow::mutex_lock lock(mutex_);
    std::tie(iter, inserted) = compilation_cache_.emplace(
        std::piecewise_construct,
        std::forward_as_tuple(ptx, cc_major, cc_minor),
        std::forward_as_tuple());
    cache_ptx = &iter->first.ptx;
    cache_value = &iter->second;
  }

  // Compile the ptx if it wasn't in the cache before we called this function.
  // Other threads asking for the same compilation key will block on
  // cache_value->mutex_ until compilation is done.
  {
    tensorflow::mutex_lock lock(cache_value->mutex_);
    if (inserted) {
      CHECK(!cache_value->compilation_done);
      if (!ptx.empty()) {
        StatusOr<std::vector<uint8>> maybe_cubin =
            CompilePtx(*cache_ptx, cc_major, cc_minor);
        if (maybe_cubin.ok()) {
          cache_value->cubin_data = std::move(maybe_cubin).ValueOrDie();
          VLOG(2) << "Compiled PTX size:" << ptx.size()
                  << " CUBIN size: " << cache_value->cubin_data.size();
        } else {
          bool log_warning = true;
          if (maybe_cubin.status().code() ==
              tensorflow::error::Code::NOT_FOUND) {
            // Missing ptxas is expected in some environments where CUDA SDK
            // binaries are not available. We don't want to spam logs with
            // identical warnings in this case.

            // TODO(zhengxq): we should implement a LOG_FIRST_N and LOG_EVERY_N
            // for more general usage.
            static std::atomic<bool> warning_done(false);
            log_warning = !warning_done.exchange(true);
          }
          if (log_warning) {
            LOG(WARNING)
                << "Failed to compile ptx to cubin.  Will attempt to let "
                   "GPU driver compile the ptx. "
                << maybe_cubin.status();
          }

          // We're going to use the driver to JIT our PTX->SASS, so warn if
          // the JIT in the driver has known bugs.
          WarnIfBadDriverJITVersion();
        }
      }
      cache_value->compilation_done = true;
      cache_value->compilation_done_cv_.notify_all();
    } else {
      while (!cache_value->compilation_done) {
        cache_value->compilation_done_cv_.wait(lock);
      }
    }
  }

  CHECK(cache_value != nullptr);
  CHECK(cache_value->compilation_done);
  return cache_value->cubin_data;
}

StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
GpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> module,
                                const AotCompilationOptions& options) {
  return Unimplemented("not yet implemented: GpuCompiler::CompileAheadOfTime");
}

se::Platform::Id GpuCompiler::PlatformId() const {
  return se::cuda::kCudaPlatformId;
}

}  // namespace gpu
}  // namespace xla

static bool InitModule() {
  xla::Compiler::RegisterCompilerFactory(se::cuda::kCudaPlatformId, []() {
    return xla::MakeUnique<xla::gpu::GpuCompiler>();
  });
  return true;
}
static bool module_initialized = InitModule();