Home | History | Annotate | Download | only in cpu
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/cpu/compiler_functor.h"
     17 
     18 #include <algorithm>
     19 #include <iterator>
     20 #include <memory>
     21 #include <string>
     22 #include <utility>
     23 #include <vector>
     24 
     25 #include "llvm/ADT/StringRef.h"
     26 #include "llvm/Analysis/TargetLibraryInfo.h"
     27 #include "llvm/Analysis/TargetTransformInfo.h"
     28 #include "llvm/ExecutionEngine/ObjectMemoryBuffer.h"
     29 #include "llvm/IR/LegacyPassManager.h"
     30 #include "llvm/IR/Verifier.h"
     31 #include "llvm/MC/MCContext.h"
     32 #include "llvm/Object/ObjectFile.h"
     33 #include "llvm/Support/raw_ostream.h"
     34 #include "llvm/Target/TargetMachine.h"
     35 #include "llvm/Transforms/IPO.h"
     36 #include "llvm/Transforms/IPO/AlwaysInliner.h"
     37 #include "llvm/Transforms/IPO/PassManagerBuilder.h"
     38 #include "tensorflow/compiler/xla/ptr_util.h"
     39 #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
     40 #include "tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.h"
     41 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
     42 #include "tensorflow/compiler/xla/statusor.h"
     43 #include "tensorflow/compiler/xla/types.h"
     44 #include "tensorflow/compiler/xla/util.h"
     45 #include "tensorflow/core/platform/logging.h"
     46 
     47 namespace xla {
     48 namespace cpu {
     49 
/* Create filtered versions of the LLVM pass managers to filter out some
of the expensive passes.
Profiling:
   learning/brain/google/xla/benchmarks:inception_cpu_benchmark
   learning/brain/google/xla/benchmarks:cifarnet
pointed to LICM and IndVarSimplify as the hottest passes.
LICM is known to exhibit O(n^2) time in the number of instructions.
IndVarSimplify is slow due to SCEV. If loops are emitted in canonical form,
this pass is not necessary.
Disabling these as a starting point.
NOTE: the filtering implemented below currently drops only the loop-unrolling
pass (LLVM pass name "Unroll loops"), and only when disable_expensive_passes
is true; LICM and IndVarSimplify are not filtered.
*/
     61 // TODO(b/64227304) Creating a custom pass pipeline will replace this.
     62 
     63 namespace {
     64 class FilteredFunctionPassManager : public llvm::legacy::FunctionPassManager {
     65  public:
     66   FilteredFunctionPassManager(llvm::Module* m, bool disable_expensive_passes)
     67       : llvm::legacy::FunctionPassManager(m),
     68         disable_expensive_passes_(disable_expensive_passes) {}
     69   void add(llvm::Pass* p) override {
     70     llvm::legacy::FunctionPassManager::add(p);
     71   }
     72 
     73  private:
     74   bool disable_expensive_passes_;
     75 };
     76 
     77 class FilteredPassManager : public llvm::legacy::PassManager {
     78  public:
     79   explicit FilteredPassManager(bool disable_expensive_passes)
     80       : disable_expensive_passes_(disable_expensive_passes) {}
     81   void add(llvm::Pass* p) override {
     82     if (disable_expensive_passes_) {
     83       llvm::StringRef PassName = p->getPassName();
     84       if (PassName.contains("Unroll loops")) {
     85         return;
     86       }
     87     }
     88     llvm::legacy::PassManager::add(p);
     89   }
     90 
     91  private:
     92   bool disable_expensive_passes_;
     93 };
     94 }  // anonymous namespace
     95 
     96 llvm::object::OwningBinary<llvm::object::ObjectFile> CompilerFunctor::
     97 operator()(llvm::Module& module) const {
     98   FilteredPassManager module_passes(disable_expensive_passes_);
     99   FilteredFunctionPassManager function_passes(&module,
    100                                               disable_expensive_passes_);
    101 
    102   VLOG(2) << "IR before optimizations";
    103   XLA_VLOG_LINES(2, llvm_ir::DumpModuleToString(module));
    104 
    105   if (pre_optimization_hook_) {
    106     TF_CHECK_OK(pre_optimization_hook_(module));
    107   }
    108 
    109   // Add the appropriate TargetLibraryInfo and TargetTransformInfo.
    110   AddTargetInfoPasses(&module_passes);
    111 
    112   // Build up optimization pipeline.
    113   if (optimize_for_size_) {
    114     // Optimizing for size turns on -O2 level optimizations.
    115     //
    116     // TODO(b/64153864): Although the code generator supports size_level = 2 to
    117     // turn on more aggressive code size optimizations than size_level = 1, we
    118     // pass size_level = 1 because in many cases a size_level of 2 does
    119     // worse. Investigate why.
    120     AddOptimizationPasses(&module_passes, &function_passes, /*opt_level=*/2,
    121                           /*size_level=*/1);
    122   } else {
    123     AddOptimizationPasses(&module_passes, &function_passes,
    124                           /*opt_level=*/opt_level_, /*size_level=*/0);
    125   }
    126 
    127   // Run optimization passes on module.
    128   function_passes.doInitialization();
    129 
    130   CHECK(!llvm::verifyModule(module, &llvm::dbgs()));
    131 
    132   for (auto func = module.begin(); func != module.end(); ++func) {
    133     function_passes.run(*func);
    134   }
    135   function_passes.doFinalization();
    136   module_passes.run(module);
    137 
    138   CHECK(!llvm::verifyModule(module, &llvm::dbgs()));
    139 
    140   runtime::RewriteIRRuntimeFunctions(&module, enable_fast_math_);
    141 
    142   // Buffer for holding machine code prior to constructing the ObjectFile.
    143   llvm::SmallVector<char, 0> stream_buffer;
    144   llvm::raw_svector_ostream ostream(stream_buffer);
    145 
    146   VLOG(2) << "IR after optimizations";
    147   XLA_VLOG_LINES(2, llvm_ir::DumpModuleToString(module));
    148 
    149   if (post_optimization_hook_) {
    150     TF_CHECK_OK(post_optimization_hook_(module));
    151   }
    152 
    153   // Generate code.
    154   llvm::MCContext* mc_context;
    155   llvm::legacy::PassManager codegen_passes;
    156   target_machine_->addPassesToEmitMC(codegen_passes, mc_context, ostream);
    157   codegen_passes.run(module);
    158 
    159   // Construct ObjectFile from machine code buffer.
    160   std::unique_ptr<llvm::MemoryBuffer> memory_buffer(
    161       new llvm::ObjectMemoryBuffer(std::move(stream_buffer)));
    162   llvm::Expected<std::unique_ptr<llvm::object::ObjectFile>>
    163       object_file_or_error = llvm::object::ObjectFile::createObjectFile(
    164           memory_buffer->getMemBufferRef());
    165   CHECK(object_file_or_error);
    166 
    167   std::unique_ptr<llvm::object::ObjectFile> object_file =
    168       std::move(object_file_or_error.get());
    169   if (VLOG_IS_ON(2)) {
    170     StatusOr<DisassemblerResult> disassembly_status =
    171         disassembler_->DisassembleObjectFile(*object_file);
    172     if (disassembly_status.ok()) {
    173       auto result = disassembly_status.ValueOrDie();
    174       XLA_VLOG_LINES(2, result.text);
    175       VLOG(2) << "compiled code size: " << result.code_size_bytes << " bytes";
    176     }
    177   }
    178 
    179   return llvm::object::OwningBinary<llvm::object::ObjectFile>(
    180       std::move(object_file), std::move(memory_buffer));
    181 }
    182 
    183 static std::vector<llvm::VecDesc> VectorFunctionsForTargetLibraryInfoImpl() {
    184   std::vector<llvm::VecDesc> result = {
    185       {"tanhf", runtime::kTanhV4F32SymbolName, 4},
    186       {"llvm.tanh.f32", runtime::kTanhV4F32SymbolName, 4},
    187 
    188       {"tanhf", runtime::kTanhV8F32SymbolName, 8},
    189       {"llvm.tanh.f32", runtime::kTanhV8F32SymbolName, 8},
    190 
    191       {"expf", runtime::kExpV4F32SymbolName, 4},
    192       {"llvm.exp.f32", runtime::kExpV4F32SymbolName, 4},
    193 
    194       {"expf", runtime::kExpV8F32SymbolName, 8},
    195       {"llvm.exp.f32", runtime::kExpV8F32SymbolName, 8},
    196 
    197       {"logf", runtime::kLogV4F32SymbolName, 4},
    198       {"llvm.log.f32", runtime::kLogV4F32SymbolName, 4},
    199 
    200       {"logf", runtime::kLogV8F32SymbolName, 8},
    201       {"llvm.log.f32", runtime::kLogV8F32SymbolName, 8},
    202   };
    203   return result;
    204 }
    205 
    206 void CompilerFunctor::AddTargetInfoPasses(
    207     llvm::legacy::PassManagerBase* passes) const {
    208   llvm::Triple target_triple(target_machine_->getTargetTriple());
    209   auto target_library_info_impl =
    210       MakeUnique<llvm::TargetLibraryInfoImpl>(target_triple);
    211   target_library_info_impl->addVectorizableFunctions(
    212       VectorFunctionsForTargetLibraryInfoImpl());
    213   passes->add(
    214       new llvm::TargetLibraryInfoWrapperPass(*target_library_info_impl));
    215   passes->add(createTargetTransformInfoWrapperPass(
    216       target_machine_->getTargetIRAnalysis()));
    217 }
    218 
    219 void CompilerFunctor::AddOptimizationPasses(
    220     llvm::legacy::PassManagerBase* module_passes,
    221     llvm::legacy::FunctionPassManager* function_passes, unsigned opt_level,
    222     unsigned size_level) const {
    223   llvm::PassManagerBuilder builder;
    224   builder.OptLevel = opt_level;
    225   builder.SizeLevel = size_level;
    226 
    227   if (opt_level > 1) {
    228     builder.Inliner = llvm::createFunctionInliningPass();
    229   } else {
    230     // Only inline functions marked with "alwaysinline".
    231     builder.Inliner = llvm::createAlwaysInlinerLegacyPass();
    232   }
    233 
    234   builder.DisableUnitAtATime = false;
    235   builder.DisableUnrollLoops = opt_level == 0;
    236   builder.LoopVectorize = opt_level > 0 && size_level == 0;
    237   builder.SLPVectorize = opt_level > 1 && size_level == 0;
    238 
    239   builder.populateFunctionPassManager(*function_passes);
    240   builder.populateModulePassManager(*module_passes);
    241 }
    242 
    243 }  // namespace cpu
    244 }  // namespace xla
    245