1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/cpu/compiler_functor.h" 17 18 #include <algorithm> 19 #include <iterator> 20 #include <memory> 21 #include <string> 22 #include <utility> 23 #include <vector> 24 25 #include "llvm/ADT/StringRef.h" 26 #include "llvm/Analysis/TargetLibraryInfo.h" 27 #include "llvm/Analysis/TargetTransformInfo.h" 28 #include "llvm/ExecutionEngine/ObjectMemoryBuffer.h" 29 #include "llvm/IR/LegacyPassManager.h" 30 #include "llvm/IR/Verifier.h" 31 #include "llvm/MC/MCContext.h" 32 #include "llvm/Object/ObjectFile.h" 33 #include "llvm/Support/raw_ostream.h" 34 #include "llvm/Target/TargetMachine.h" 35 #include "llvm/Transforms/IPO.h" 36 #include "llvm/Transforms/IPO/AlwaysInliner.h" 37 #include "llvm/Transforms/IPO/PassManagerBuilder.h" 38 #include "tensorflow/compiler/xla/ptr_util.h" 39 #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" 40 #include "tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.h" 41 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" 42 #include "tensorflow/compiler/xla/statusor.h" 43 #include "tensorflow/compiler/xla/types.h" 44 #include "tensorflow/compiler/xla/util.h" 45 #include "tensorflow/core/platform/logging.h" 46 47 namespace xla { 48 namespace cpu { 49 50 /* Create filtered versions of the LLVM Pass Managers to filter out some 51 of the expensive passes. 52 Profiling: 53 learning/brain/google/xla/benchmarks:inception_cpu_benchmark 54 learning/brain/google/xla/benchmarks:cifarnet 55 pointed to LICM and IndVarSimplify as the hottest passes. 56 LICM is known to exhibit O(n^2) time in the number of instructions. 57 IndVarSimplify is slow due to SCEV. If loops are emitted in canonical form, 58 this pass is not necessary. 59 Disabling these as a starting point. 60 */ 61 // TODO(b/64227304) Creating a custom pass pipeline will replace this. 62 63 namespace { 64 class FilteredFunctionPassManager : public llvm::legacy::FunctionPassManager { 65 public: 66 FilteredFunctionPassManager(llvm::Module* m, bool disable_expensive_passes) 67 : llvm::legacy::FunctionPassManager(m), 68 disable_expensive_passes_(disable_expensive_passes) {} 69 void add(llvm::Pass* p) override { 70 llvm::legacy::FunctionPassManager::add(p); 71 } 72 73 private: 74 bool disable_expensive_passes_; 75 }; 76 77 class FilteredPassManager : public llvm::legacy::PassManager { 78 public: 79 explicit FilteredPassManager(bool disable_expensive_passes) 80 : disable_expensive_passes_(disable_expensive_passes) {} 81 void add(llvm::Pass* p) override { 82 if (disable_expensive_passes_) { 83 llvm::StringRef PassName = p->getPassName(); 84 if (PassName.contains("Unroll loops")) { 85 return; 86 } 87 } 88 llvm::legacy::PassManager::add(p); 89 } 90 91 private: 92 bool disable_expensive_passes_; 93 }; 94 } // anonymous namespace 95 96 llvm::object::OwningBinary<llvm::object::ObjectFile> CompilerFunctor:: 97 operator()(llvm::Module& module) const { 98 FilteredPassManager module_passes(disable_expensive_passes_); 99 FilteredFunctionPassManager function_passes(&module, 100 disable_expensive_passes_); 101 102 VLOG(2) << "IR before optimizations"; 103 XLA_VLOG_LINES(2, llvm_ir::DumpModuleToString(module)); 104 105 if (pre_optimization_hook_) { 106 TF_CHECK_OK(pre_optimization_hook_(module)); 107 } 108 109 // Add the appropriate TargetLibraryInfo and TargetTransformInfo. 110 AddTargetInfoPasses(&module_passes); 111 112 // Build up optimization pipeline. 113 if (optimize_for_size_) { 114 // Optimizing for size turns on -O2 level optimizations. 115 // 116 // TODO(b/64153864): Although the code generator supports size_level = 2 to 117 // turn on more aggressive code size optimizations than size_level = 1, we 118 // pass size_level = 1 because in many cases a size_level of 2 does 119 // worse. Investigate why. 120 AddOptimizationPasses(&module_passes, &function_passes, /*opt_level=*/2, 121 /*size_level=*/1); 122 } else { 123 AddOptimizationPasses(&module_passes, &function_passes, 124 /*opt_level=*/opt_level_, /*size_level=*/0); 125 } 126 127 // Run optimization passes on module. 128 function_passes.doInitialization(); 129 130 CHECK(!llvm::verifyModule(module, &llvm::dbgs())); 131 132 for (auto func = module.begin(); func != module.end(); ++func) { 133 function_passes.run(*func); 134 } 135 function_passes.doFinalization(); 136 module_passes.run(module); 137 138 CHECK(!llvm::verifyModule(module, &llvm::dbgs())); 139 140 runtime::RewriteIRRuntimeFunctions(&module, enable_fast_math_); 141 142 // Buffer for holding machine code prior to constructing the ObjectFile. 143 llvm::SmallVector<char, 0> stream_buffer; 144 llvm::raw_svector_ostream ostream(stream_buffer); 145 146 VLOG(2) << "IR after optimizations"; 147 XLA_VLOG_LINES(2, llvm_ir::DumpModuleToString(module)); 148 149 if (post_optimization_hook_) { 150 TF_CHECK_OK(post_optimization_hook_(module)); 151 } 152 153 // Generate code. 154 llvm::MCContext* mc_context; 155 llvm::legacy::PassManager codegen_passes; 156 target_machine_->addPassesToEmitMC(codegen_passes, mc_context, ostream); 157 codegen_passes.run(module); 158 159 // Construct ObjectFile from machine code buffer. 160 std::unique_ptr<llvm::MemoryBuffer> memory_buffer( 161 new llvm::ObjectMemoryBuffer(std::move(stream_buffer))); 162 llvm::Expected<std::unique_ptr<llvm::object::ObjectFile>> 163 object_file_or_error = llvm::object::ObjectFile::createObjectFile( 164 memory_buffer->getMemBufferRef()); 165 CHECK(object_file_or_error); 166 167 std::unique_ptr<llvm::object::ObjectFile> object_file = 168 std::move(object_file_or_error.get()); 169 if (VLOG_IS_ON(2)) { 170 StatusOr<DisassemblerResult> disassembly_status = 171 disassembler_->DisassembleObjectFile(*object_file); 172 if (disassembly_status.ok()) { 173 auto result = disassembly_status.ValueOrDie(); 174 XLA_VLOG_LINES(2, result.text); 175 VLOG(2) << "compiled code size: " << result.code_size_bytes << " bytes"; 176 } 177 } 178 179 return llvm::object::OwningBinary<llvm::object::ObjectFile>( 180 std::move(object_file), std::move(memory_buffer)); 181 } 182 183 static std::vector<llvm::VecDesc> VectorFunctionsForTargetLibraryInfoImpl() { 184 std::vector<llvm::VecDesc> result = { 185 {"tanhf", runtime::kTanhV4F32SymbolName, 4}, 186 {"llvm.tanh.f32", runtime::kTanhV4F32SymbolName, 4}, 187 188 {"tanhf", runtime::kTanhV8F32SymbolName, 8}, 189 {"llvm.tanh.f32", runtime::kTanhV8F32SymbolName, 8}, 190 191 {"expf", runtime::kExpV4F32SymbolName, 4}, 192 {"llvm.exp.f32", runtime::kExpV4F32SymbolName, 4}, 193 194 {"expf", runtime::kExpV8F32SymbolName, 8}, 195 {"llvm.exp.f32", runtime::kExpV8F32SymbolName, 8}, 196 197 {"logf", runtime::kLogV4F32SymbolName, 4}, 198 {"llvm.log.f32", runtime::kLogV4F32SymbolName, 4}, 199 200 {"logf", runtime::kLogV8F32SymbolName, 8}, 201 {"llvm.log.f32", runtime::kLogV8F32SymbolName, 8}, 202 }; 203 return result; 204 } 205 206 void CompilerFunctor::AddTargetInfoPasses( 207 llvm::legacy::PassManagerBase* passes) const { 208 llvm::Triple target_triple(target_machine_->getTargetTriple()); 209 auto target_library_info_impl = 210 MakeUnique<llvm::TargetLibraryInfoImpl>(target_triple); 211 target_library_info_impl->addVectorizableFunctions( 212 VectorFunctionsForTargetLibraryInfoImpl()); 213 passes->add( 214 new llvm::TargetLibraryInfoWrapperPass(*target_library_info_impl)); 215 passes->add(createTargetTransformInfoWrapperPass( 216 target_machine_->getTargetIRAnalysis())); 217 } 218 219 void CompilerFunctor::AddOptimizationPasses( 220 llvm::legacy::PassManagerBase* module_passes, 221 llvm::legacy::FunctionPassManager* function_passes, unsigned opt_level, 222 unsigned size_level) const { 223 llvm::PassManagerBuilder builder; 224 builder.OptLevel = opt_level; 225 builder.SizeLevel = size_level; 226 227 if (opt_level > 1) { 228 builder.Inliner = llvm::createFunctionInliningPass(); 229 } else { 230 // Only inline functions marked with "alwaysinline". 231 builder.Inliner = llvm::createAlwaysInlinerLegacyPass(); 232 } 233 234 builder.DisableUnitAtATime = false; 235 builder.DisableUnrollLoops = opt_level == 0; 236 builder.LoopVectorize = opt_level > 0 && size_level == 0; 237 builder.SLPVectorize = opt_level > 1 && size_level == 0; 238 239 builder.populateFunctionPassManager(*function_passes); 240 builder.populateModulePassManager(*module_passes); 241 } 242 243 } // namespace cpu 244 } // namespace xla 245