Home | History | Annotate | Download | only in cpu
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
     17 
     18 #include <stdint.h>
     19 #include <algorithm>
     20 #include <list>
     21 #include <utility>
     22 
     23 #include "absl/memory/memory.h"
     24 #include "llvm/ExecutionEngine/ExecutionEngine.h"
     25 #include "llvm/ExecutionEngine/JITSymbol.h"
     26 #include "llvm/ExecutionEngine/SectionMemoryManager.h"
     27 #include "llvm/IR/Mangler.h"
     28 #include "llvm/Support/CodeGen.h"
     29 #include "llvm/Support/Host.h"
     30 #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
     31 #include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
     32 #include "tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h"
     33 #include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h"
     34 #include "tensorflow/compiler/xla/service/cpu/runtime_conv2d_mkl.h"
     35 #include "tensorflow/compiler/xla/service/cpu/runtime_fft.h"
     36 #include "tensorflow/compiler/xla/service/cpu/runtime_fork_join.h"
     37 #include "tensorflow/compiler/xla/service/cpu/runtime_fp16.h"
     38 #include "tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h"
     39 #include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h"
     40 #include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h"
     41 #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h"
     42 #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h"
     43 #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
     44 #include "tensorflow/compiler/xla/service/cpu/windows_compatibility.h"
     45 #include "tensorflow/compiler/xla/types.h"
     46 #include "tensorflow/core/platform/logging.h"
     47 
     48 namespace xla {
     49 namespace cpu {
     50 namespace {
     51 
     52 llvm::SmallVector<std::string, 0> DetectMachineAttributes() {
     53   llvm::SmallVector<std::string, 0> result;
     54   llvm::StringMap<bool> host_features;
     55   if (llvm::sys::getHostCPUFeatures(host_features)) {
     56     for (auto& feature : host_features) {
     57       if (feature.second) {
     58         llvm::StringRef feature_name = feature.first();
     59         // Skip avx512 for now, it isn't quite ready in LLVM.
     60         if (feature_name.startswith("avx512")) {
     61           continue;
     62         }
     63         result.push_back(feature_name);
     64       }
     65     }
     66   }
     67   return result;
     68 }
     69 
     70 llvm::StringRef GetHostCpuName() {
     71   auto cpu_name = llvm::sys::getHostCPUName();
     72   // Skip avx512 for now, it isn't quite ready in LLVM.
     73   cpu_name.consume_back("-avx512");
     74   return cpu_name;
     75 }
     76 }  // namespace
     77 
     78 /*static*/ std::unique_ptr<llvm::TargetMachine>
     79 SimpleOrcJIT::InferTargetMachineForJIT(
     80     const llvm::TargetOptions& target_options,
     81     llvm::CodeGenOpt::Level opt_level) {
     82   std::unique_ptr<llvm::TargetMachine> target_machine(
     83       llvm::EngineBuilder()
     84           .setTargetOptions(target_options)
     85           .setOptLevel(opt_level)
     86           .selectTarget(
     87               /*TargetTriple=*/llvm::Triple(), /*MArch=*/"",
     88               /*MCPU=*/GetHostCpuName(),
     89               /*MAttrs=*/DetectMachineAttributes()));
     90   CHECK(target_machine != nullptr);
     91   return target_machine;
     92 }
     93 
// Builds the ORC JIT pipeline for the host:
//   * target_machine_ / data_layout_: the TargetMachine returned by
//     InferTargetMachineForJIT and the DataLayout it creates.
//   * symbol_resolver_: a legacy lookup resolver on execution_session_.
//     Its lookups go through ResolveRuntimeSymbol, which consults
//     CustomCallTargetRegistry. A lookup error aborts via cantFail.
//   * object_layer_: for every module key, creates a fresh
//     SectionMemoryManager that allocates through
//     orc_jit_memory_mapper::GetInstance(); all keys share symbol_resolver_.
//     Finalize/free events go to NotifyObjectFinalized / NotifyObjectFreed,
//     which report them to the GDB JIT listener.
//   * compile_layer_: sits on top of object_layer_ and uses a CompilerFunctor
//     built from target_machine_ and the optimization settings. The three
//     hooks are moved into that functor.
//   * gdb_jit_event_listener_: LLVM's GDB registration listener.
//
// NOTE(review): member initializers run in the order the members are
// *declared* in simple_orc_jit.h (not visible here), not the order written
// below. data_layout_ reads target_machine_, and compile_layer_ uses both
// object_layer_ and target_machine_. The header presumably declares them in
// this order; confirm before reordering members.
SimpleOrcJIT::SimpleOrcJIT(
    const llvm::TargetOptions& target_options,
    llvm::CodeGenOpt::Level opt_level, bool optimize_for_size,
    bool enable_fast_math, bool disable_expensive_passes,
    LLVMCompiler::ModuleHook pre_optimization_hook,
    LLVMCompiler::ModuleHook post_optimization_hook,
    std::function<void(const llvm::object::ObjectFile&)> post_codegen_hook)
    : target_machine_(InferTargetMachineForJIT(target_options, opt_level)),
      data_layout_(target_machine_->createDataLayout()),
      symbol_resolver_(llvm::orc::createLegacyLookupResolver(
          execution_session_,
          // Every external symbol lookup is answered by the runtime registry.
          [this](const std::string& name) -> llvm::JITSymbol {
            return this->ResolveRuntimeSymbol(name);
          },
          // Error reporter: any lookup error is treated as fatal.
          [](llvm::Error Err) {
            cantFail(std::move(Err), "lookupFlags failed");
          })),
      object_layer_(
          execution_session_,
          // Resource getter: called per module key; each key gets its own
          // memory manager but the same resolver.
          [this](llvm::orc::VModuleKey) {
            llvm::orc::LegacyRTDyldObjectLinkingLayer::Resources result;
            result.MemMgr = std::make_shared<llvm::SectionMemoryManager>(
                orc_jit_memory_mapper::GetInstance());
            result.Resolver = symbol_resolver_;
            return result;
          },
          /*NotifyLoaded=*/
          llvm::orc::LegacyRTDyldObjectLinkingLayer::NotifyLoadedFtor(),
          /*NotifyFinalized=*/
          [this](VModuleKeyT, const llvm::object::ObjectFile& object,
                 const llvm::RuntimeDyld::LoadedObjectInfo& object_info) {
            this->NotifyObjectFinalized(object, object_info);
          },
          /*NotifyFreed=*/
          [this](VModuleKeyT, const llvm::object::ObjectFile& object) {
            this->NotifyObjectFreed(object);
          }),
      compile_layer_(
          object_layer_,
          CompilerFunctor(target_machine_.get(), opt_level, optimize_for_size,
                          enable_fast_math, disable_expensive_passes,
                          std::move(pre_optimization_hook),
                          std::move(post_optimization_hook),
                          std::move(post_codegen_hook))),
      gdb_jit_event_listener_(
          llvm::JITEventListener::createGDBRegistrationListener()) {
  VLOG(1) << "CPU target: " << target_machine_->getTargetCPU().str()
          << " features: " << target_machine_->getTargetFeatureString().str();
}
    143 
    144 llvm::JITSymbol SimpleOrcJIT::ResolveRuntimeSymbol(const std::string& name) {
    145   void* func_addr = nullptr;
    146   if (name.size() > 1 && name.front() == data_layout_.getGlobalPrefix()) {
    147     // On Mac OS X, 'name' may have a leading underscore prefix, even though the
    148     // registered name may not.
    149     std::string stripped_name(name.begin() + 1, name.end());
    150     func_addr = CustomCallTargetRegistry::Global()->Lookup(stripped_name);
    151   } else {
    152     func_addr = CustomCallTargetRegistry::Global()->Lookup(name);
    153   }
    154 
    155   if (func_addr == nullptr) {
    156     LOG(ERROR) << "Unable to resolve runtime symbol: " << name;
    157     return nullptr;
    158   }
    159   llvm::JITEvaluatedSymbol symbol_info(reinterpret_cast<uint64_t>(func_addr),
    160                                        llvm::JITSymbolFlags::None);
    161   return symbol_info;
    162 }
    163 
    164 void SimpleOrcJIT::NotifyObjectFinalized(
    165     const llvm::object::ObjectFile& object,
    166     const llvm::RuntimeDyld::LoadedObjectInfo& object_info) {
    167   uint64_t key = static_cast<uint64_t>(
    168       reinterpret_cast<uintptr_t>(object.getData().data()));
    169   gdb_jit_event_listener_->notifyObjectLoaded(key, object, object_info);
    170 }
    171 
    172 void SimpleOrcJIT::NotifyObjectFreed(const llvm::object::ObjectFile& object) {
    173   uint64_t key = static_cast<uint64_t>(
    174       reinterpret_cast<uintptr_t>(object.getData().data()));
    175   gdb_jit_event_listener_->notifyFreeingObject(key);
    176 }
    177 
    178 SimpleOrcJIT::VModuleKeyT SimpleOrcJIT::AddModule(
    179     std::unique_ptr<llvm::Module> module) {
    180   auto key = execution_session_.allocateVModule();
    181   cantFail(compile_layer_.addModule(key, std::move(module)));
    182   module_keys_.push_back(key);
    183   return key;
    184 }
    185 
    186 void SimpleOrcJIT::RemoveModule(SimpleOrcJIT::VModuleKeyT key) {
    187   module_keys_.erase(std::remove(module_keys_.begin(), module_keys_.end(), key),
    188                      module_keys_.end());
    189   cantFail(compile_layer_.removeModule(key));
    190 }
    191 
    192 llvm::JITSymbol SimpleOrcJIT::FindCompiledSymbol(const std::string& name) {
    193   // Resolve symbol from last module to first, allowing later redefinitions of
    194   // symbols shadow earlier ones.
    195   for (auto& key :
    196        llvm::make_range(module_keys_.rbegin(), module_keys_.rend())) {
    197     if (auto symbol =
    198             compile_layer_.findSymbolIn(key, name,
    199                                         /*ExportedSymbolsOnly=*/true)) {
    200       return symbol;
    201     }
    202   }
    203 
    204   return nullptr;
    205 }
    206 
namespace {
// Registers with the global CustomCallTargetRegistry the native functions that
// JIT-compiled code may call by name, so that ResolveRuntimeSymbol can find
// them. The groups are:
//   * XLA CPU runtime entry points (__xla_cpu_runtime_*);
//   * the __gnu_f2h_ieee / __gnu_h2f_ieee half-precision helpers;
//   * libm functions, each in its float ("<name>f") and double ("<name>") form;
//   * memcpy / memmove / memset;
//   * Apple-only helpers (sincos variants, __bzero, memset_pattern16);
//   * __msan_unpoison when MEMORY_SANITIZER is defined.
//
// Always returns true. The return value only exists so the registration can
// run as the initializer of the namespace-scope variable below.
bool RegisterKnownJITSymbols() {
  CustomCallTargetRegistry* registry = CustomCallTargetRegistry::Global();

// Registers __xla_cpu_runtime_<base_name> under the name held in
// xla::cpu::runtime::k<base_name>SymbolName. It then CHECKs that this constant
// is exactly "__xla_cpu_runtime_<base_name>", which catches any mismatch
// between the declared symbol name and the actual function name.
#define REGISTER_CPU_RUNTIME_SYMBOL(base_name)                               \
  do {                                                                       \
    auto* function_address =                                                 \
        reinterpret_cast<void*>(__xla_cpu_runtime_##base_name);              \
    registry->Register(xla::cpu::runtime::k##base_name##SymbolName,          \
                       function_address);                                    \
    CHECK_EQ(absl::string_view(xla::cpu::runtime::k##base_name##SymbolName), \
             "__xla_cpu_runtime_" #base_name);                               \
  } while (false)

  REGISTER_CPU_RUNTIME_SYMBOL(AcquireInfeedBufferForDequeue);
  REGISTER_CPU_RUNTIME_SYMBOL(AcquireOutfeedBufferForPopulation);
  REGISTER_CPU_RUNTIME_SYMBOL(MKLConvF32);
  REGISTER_CPU_RUNTIME_SYMBOL(EigenConvF16);
  REGISTER_CPU_RUNTIME_SYMBOL(EigenConvF32);
  REGISTER_CPU_RUNTIME_SYMBOL(EigenFft);
  REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF16);
  REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF32);
  REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF64);
  REGISTER_CPU_RUNTIME_SYMBOL(MKLMatMulF32);
  REGISTER_CPU_RUNTIME_SYMBOL(MKLMatMulF64);
  REGISTER_CPU_RUNTIME_SYMBOL(MKLSingleThreadedMatMulF32);
  REGISTER_CPU_RUNTIME_SYMBOL(MKLSingleThreadedMatMulF64);
  REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF16);
  REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF32);
  REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedFft);
  REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF16);
  REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF32);
  REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF64);
  REGISTER_CPU_RUNTIME_SYMBOL(ParallelForkJoin);
  REGISTER_CPU_RUNTIME_SYMBOL(ReleaseInfeedBufferAfterDequeue);
  REGISTER_CPU_RUNTIME_SYMBOL(ReleaseOutfeedBufferAfterPopulation);
  REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSort);

  // Half-precision conversion helpers. NOTE(review): presumably declared via
  // runtime_fp16.h; confirm.
  registry->Register("__gnu_f2h_ieee", reinterpret_cast<void*>(__gnu_f2h_ieee));
  registry->Register("__gnu_h2f_ieee", reinterpret_cast<void*>(__gnu_h2f_ieee));

#undef REGISTER_CPU_RUNTIME_SYMBOL

// Registers both the f32 (float, "<name>f") and f64 (double, "<name>")
// versions of a libm symbol. The double versions are overloaded on some
// systems (e.g. on Mac), so `double_sig` supplies the exact function-pointer
// type used to pick the intended overload with an explicit cast.
#define REGISTER_LIBM_SYMBOL(name, double_sig)                          \
  do {                                                                  \
    registry->Register(#name "f", reinterpret_cast<void*>(name##f));    \
    registry->Register(                                                 \
        #name, reinterpret_cast<void*>(static_cast<double_sig>(name))); \
  } while (false)

  REGISTER_LIBM_SYMBOL(acos, double (*)(double));
  REGISTER_LIBM_SYMBOL(acosh, double (*)(double));
  REGISTER_LIBM_SYMBOL(asin, double (*)(double));
  REGISTER_LIBM_SYMBOL(asinh, double (*)(double));
  REGISTER_LIBM_SYMBOL(atan, double (*)(double));
  REGISTER_LIBM_SYMBOL(atan2, double (*)(double, double));
  REGISTER_LIBM_SYMBOL(atanh, double (*)(double));
  REGISTER_LIBM_SYMBOL(cbrt, double (*)(double));
  REGISTER_LIBM_SYMBOL(ceil, double (*)(double));
  REGISTER_LIBM_SYMBOL(copysign, double (*)(double, double));
  REGISTER_LIBM_SYMBOL(cos, double (*)(double));
  REGISTER_LIBM_SYMBOL(cosh, double (*)(double));
  REGISTER_LIBM_SYMBOL(erf, double (*)(double));
  REGISTER_LIBM_SYMBOL(erfc, double (*)(double));
  REGISTER_LIBM_SYMBOL(exp, double (*)(double));
  REGISTER_LIBM_SYMBOL(exp2, double (*)(double));
  REGISTER_LIBM_SYMBOL(expm1, double (*)(double));
  REGISTER_LIBM_SYMBOL(fabs, double (*)(double));
  REGISTER_LIBM_SYMBOL(fdim, double (*)(double, double));
  REGISTER_LIBM_SYMBOL(floor, double (*)(double));
  REGISTER_LIBM_SYMBOL(fma, double (*)(double, double, double));
  REGISTER_LIBM_SYMBOL(fmax, double (*)(double, double));
  REGISTER_LIBM_SYMBOL(fmin, double (*)(double, double));
  REGISTER_LIBM_SYMBOL(fmod, double (*)(double, double));
  REGISTER_LIBM_SYMBOL(frexp, double (*)(double, int*));
  REGISTER_LIBM_SYMBOL(hypot, double (*)(double, double));
  REGISTER_LIBM_SYMBOL(ilogb, int (*)(double));
  REGISTER_LIBM_SYMBOL(ldexp, double (*)(double, int));
  REGISTER_LIBM_SYMBOL(lgamma, double (*)(double));
  REGISTER_LIBM_SYMBOL(llrint, long long (*)(double));   // NOLINT(runtime/int)
  REGISTER_LIBM_SYMBOL(llround, long long (*)(double));  // NOLINT(runtime/int)
  REGISTER_LIBM_SYMBOL(log, double (*)(double));
  REGISTER_LIBM_SYMBOL(log10, double (*)(double));
  REGISTER_LIBM_SYMBOL(log1p, double (*)(double));
  REGISTER_LIBM_SYMBOL(log2, double (*)(double));
  REGISTER_LIBM_SYMBOL(logb, double (*)(double));
  REGISTER_LIBM_SYMBOL(lrint, long (*)(double));   // NOLINT(runtime/int)
  REGISTER_LIBM_SYMBOL(lround, long (*)(double));  // NOLINT(runtime/int)
  REGISTER_LIBM_SYMBOL(modf, double (*)(double, double*));
  REGISTER_LIBM_SYMBOL(nan, double (*)(const char*));
  REGISTER_LIBM_SYMBOL(nearbyint, double (*)(double));
  REGISTER_LIBM_SYMBOL(nextafter, double (*)(double, double));
  REGISTER_LIBM_SYMBOL(nexttoward, double (*)(double, long double));
  REGISTER_LIBM_SYMBOL(pow, double (*)(double, double));
  REGISTER_LIBM_SYMBOL(remainder, double (*)(double, double));
  REGISTER_LIBM_SYMBOL(remquo, double (*)(double, double, int*));
  REGISTER_LIBM_SYMBOL(rint, double (*)(double));
  REGISTER_LIBM_SYMBOL(round, double (*)(double));
  REGISTER_LIBM_SYMBOL(scalbln,
                       double (*)(double, long));  // NOLINT(runtime/int)
  REGISTER_LIBM_SYMBOL(scalbn, double (*)(double, int));
  REGISTER_LIBM_SYMBOL(sin, double (*)(double));
#ifdef __APPLE__
  // On Apple platforms sincos is available as __sincos / __sincosf, and the
  // __sincos_stret / __sincosf_stret variants are registered as well.
  REGISTER_LIBM_SYMBOL(__sincos, void (*)(double, double*, double*));
  registry->Register("__sincosf_stret",
                     reinterpret_cast<void*>(__sincosf_stret));
  registry->Register("__sincos_stret", reinterpret_cast<void*>(__sincos_stret));
#else
  REGISTER_LIBM_SYMBOL(sincos, void (*)(double, double*, double*));
#endif
  REGISTER_LIBM_SYMBOL(sinh, double (*)(double));
  REGISTER_LIBM_SYMBOL(sqrt, double (*)(double));
  REGISTER_LIBM_SYMBOL(tan, double (*)(double));
  REGISTER_LIBM_SYMBOL(tanh, double (*)(double));
  REGISTER_LIBM_SYMBOL(tgamma, double (*)(double));
  REGISTER_LIBM_SYMBOL(trunc, double (*)(double));

#undef REGISTER_LIBM_SYMBOL

  // Raw memory routines. NOTE(review): presumably needed because generated
  // code can contain direct calls to these (e.g. lowered memory intrinsics);
  // confirm.
  registry->Register("memcpy", reinterpret_cast<void*>(memcpy));
  registry->Register("memmove", reinterpret_cast<void*>(memmove));
  registry->Register("memset", reinterpret_cast<void*>(memset));

#ifdef __APPLE__
  // Apple-only helpers. Note that the name "__bzero" is bound to bzero.
  registry->Register("__bzero", reinterpret_cast<void*>(bzero));
  registry->Register("memset_pattern16",
                     reinterpret_cast<void*>(memset_pattern16));
#endif

#ifdef MEMORY_SANITIZER
  // Only in builds that define MEMORY_SANITIZER.
  registry->Register("__msan_unpoison",
                     reinterpret_cast<void*>(__msan_unpoison));
#endif

  return true;
}

// Runs RegisterKnownJITSymbols() once, while this translation unit's
// namespace-scope variables are initialized. The stored value is never read.
bool unused = RegisterKnownJITSymbols();
}  // namespace
    351 
    352 }  // namespace cpu
    353 }  // namespace xla
    354