/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_SIMPLE_ORC_JIT_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_SIMPLE_ORC_JIT_H_

#include <memory>
#include <string>
#include <vector>

#include "llvm/ADT/Triple.h"
#include "llvm/ExecutionEngine/Orc/Core.h"
#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
#include "llvm/ExecutionEngine/Orc/SymbolStringPool.h"
#include "llvm/IR/Module.h"
#include "llvm/Target/TargetMachine.h"
#include "tensorflow/compiler/xla/service/cpu/compiler_functor.h"
#include "tensorflow/compiler/xla/service/cpu/disassembler.h"
#include "tensorflow/compiler/xla/service/cpu/external_constant_pool.h"
#include "tensorflow/compiler/xla/types.h"

namespace xla {
namespace cpu {

// Simplified LLVM JIT based on the new Orc API.
//
// This class wraps Orc's functionality into a single interface that only
// exposes what we need for XLA.
//
// Supports JIT-ing multiple modules but without cross-module linking.
// Implements eager compilation - the module is lowered to binary as soon as
// it's added to the JIT.
class SimpleOrcJIT {
 public:
  using ObjLayerT = llvm::orc::RTDyldObjectLinkingLayer;
  using CompileFtor =
      std::function<llvm::object::OwningBinary<llvm::object::ObjectFile>(
          llvm::Module&)>;
  using CompileLayerT = llvm::orc::IRCompileLayer<ObjLayerT, CompileFtor>;
  using VModuleKeyT = llvm::orc::VModuleKey;

  // Create a new JIT, targeting the host architecture.
  // The |target_options| parameter allows customization of certain code
  // generation properties of the TargetMachine (whether or not floating-point
  // math can be reassociated, etc.).
  // The |opt_level| parameter controls the optimization level of the code
  // generator.
  // The |optimize_for_size| parameter specifies that the code generator should
  // optimize to reduce code size, potentially at the cost of performance.
  // The |enable_fast_math| parameter enables fast-math flags in the generated
  // code.
  // The |disable_expensive_passes| parameter disables certain expensive
  // optimization passes.
  // The |pre_optimization_hook| is invoked on the module before any IR-level
  // optimizations are applied.
  // The |post_optimization_hook| is invoked on the module after all IR-level
  // optimizations are applied.
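  //
  // A minimal construction sketch (the argument values below are illustrative
  // only; callers normally derive them from the compiler options):
  //
  //   SimpleOrcJIT jit(llvm::TargetOptions(), llvm::CodeGenOpt::Default,
  //                    /*optimize_for_size=*/false,
  //                    /*enable_fast_math=*/true,
  //                    /*disable_expensive_passes=*/false,
  //                    /*pre_optimization_hook=*/nullptr,
  //                    /*post_optimization_hook=*/nullptr);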
  SimpleOrcJIT(const llvm::TargetOptions& target_options,
               llvm::CodeGenOpt::Level opt_level, bool optimize_for_size,
               bool enable_fast_math, bool disable_expensive_passes,
               LLVMCompiler::ModuleHook pre_optimization_hook,
               LLVMCompiler::ModuleHook post_optimization_hook);

  // Data layout this JIT was created with.
  const llvm::DataLayout& data_layout() const { return data_layout_; }

  // Target triple (host) this JIT was created with.
  const llvm::Triple& target_triple() const {
    return target_machine_->getTargetTriple();
  }

  // Add a module to the JIT. Returns an opaque key that can be used to later
  // remove this module.
  VModuleKeyT AddModule(std::unique_ptr<llvm::Module> module);

  // Remove a module from the JIT and free the memory associated with it.
  void RemoveModule(VModuleKeyT key);
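
  // A sketch of the module lifecycle, assuming a SimpleOrcJIT `jit` and a
  // caller-owned std::unique_ptr<llvm::Module> `module`; compilation happens
  // eagerly inside AddModule:
  //
  //   SimpleOrcJIT::VModuleKeyT key = jit.AddModule(std::move(module));
  //   // ... look up and run compiled symbols ...
  //   jit.RemoveModule(key);  // Frees the code compiled from `module`.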

  // Get the runtime address of the compiled symbol with the given name.
  // Returns a null JITSymbol if the symbol cannot be found.
  llvm::JITSymbol FindCompiledSymbol(const std::string& name);
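
  // A lookup sketch, assuming the added module defined a function named
  // "entry_function" with signature void() (both the name and the signature
  // are hypothetical):
  //
  //   llvm::JITSymbol symbol = jit.FindCompiledSymbol("entry_function");
  //   if (symbol) {
  //     llvm::Expected<llvm::JITTargetAddress> address = symbol.getAddress();
  //     if (address) {
  //       auto* entry = reinterpret_cast<void (*)()>(*address);
  //       entry();
  //     } else {
  //       llvm::consumeError(address.takeError());  // Report lookup failure.
  //     }
  //   }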

  llvm::TargetMachine* target_machine() const { return target_machine_.get(); }

  ExternalConstantPool* external_constant_pool() {
    return &external_constant_pool_;
  }

 private:
  llvm::JITSymbol ResolveRuntimeSymbol(const std::string& name);

  std::vector<VModuleKeyT> module_keys_;
  std::unique_ptr<llvm::TargetMachine> target_machine_;
  const Disassembler disassembler_;
  const llvm::DataLayout data_layout_;
  llvm::orc::SymbolStringPool string_pool_;
  llvm::orc::ExecutionSession execution_session_;
  std::shared_ptr<llvm::orc::SymbolResolver> symbol_resolver_;
  ObjLayerT object_layer_;
  CompileLayerT compile_layer_;
  ExternalConstantPool external_constant_pool_;
};

}  // namespace cpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_SIMPLE_ORC_JIT_H_