1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" 17 18 #include <stdint.h> 19 #include <algorithm> 20 #include <list> 21 #include <utility> 22 23 #include "absl/memory/memory.h" 24 #include "llvm/ExecutionEngine/ExecutionEngine.h" 25 #include "llvm/ExecutionEngine/JITSymbol.h" 26 #include "llvm/ExecutionEngine/SectionMemoryManager.h" 27 #include "llvm/IR/Mangler.h" 28 #include "llvm/Support/CodeGen.h" 29 #include "llvm/Support/Host.h" 30 #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" 31 #include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" 32 #include "tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h" 33 #include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h" 34 #include "tensorflow/compiler/xla/service/cpu/runtime_conv2d_mkl.h" 35 #include "tensorflow/compiler/xla/service/cpu/runtime_fft.h" 36 #include "tensorflow/compiler/xla/service/cpu/runtime_fork_join.h" 37 #include "tensorflow/compiler/xla/service/cpu/runtime_fp16.h" 38 #include "tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h" 39 #include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" 40 #include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h" 41 #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h" 42 #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h" 43 #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" 44 #include "tensorflow/compiler/xla/service/cpu/windows_compatibility.h" 45 #include "tensorflow/compiler/xla/types.h" 46 #include "tensorflow/core/platform/logging.h" 47 48 namespace xla { 49 namespace cpu { 50 namespace { 51 52 llvm::SmallVector<std::string, 0> DetectMachineAttributes() { 53 llvm::SmallVector<std::string, 0> result; 54 llvm::StringMap<bool> host_features; 55 if (llvm::sys::getHostCPUFeatures(host_features)) { 56 for (auto& feature : host_features) { 57 if (feature.second) { 58 llvm::StringRef feature_name = feature.first(); 59 // Skip avx512 for now, it isn't quite ready in LLVM. 60 if (feature_name.startswith("avx512")) { 61 continue; 62 } 63 result.push_back(feature_name); 64 } 65 } 66 } 67 return result; 68 } 69 70 llvm::StringRef GetHostCpuName() { 71 auto cpu_name = llvm::sys::getHostCPUName(); 72 // Skip avx512 for now, it isn't quite ready in LLVM. 73 cpu_name.consume_back("-avx512"); 74 return cpu_name; 75 } 76 } // namespace 77 78 /*static*/ std::unique_ptr<llvm::TargetMachine> 79 SimpleOrcJIT::InferTargetMachineForJIT( 80 const llvm::TargetOptions& target_options, 81 llvm::CodeGenOpt::Level opt_level) { 82 std::unique_ptr<llvm::TargetMachine> target_machine( 83 llvm::EngineBuilder() 84 .setTargetOptions(target_options) 85 .setOptLevel(opt_level) 86 .selectTarget( 87 /*TargetTriple=*/llvm::Triple(), /*MArch=*/"", 88 /*MCPU=*/GetHostCpuName(), 89 /*MAttrs=*/DetectMachineAttributes())); 90 CHECK(target_machine != nullptr); 91 return target_machine; 92 } 93 94 SimpleOrcJIT::SimpleOrcJIT( 95 const llvm::TargetOptions& target_options, 96 llvm::CodeGenOpt::Level opt_level, bool optimize_for_size, 97 bool enable_fast_math, bool disable_expensive_passes, 98 LLVMCompiler::ModuleHook pre_optimization_hook, 99 LLVMCompiler::ModuleHook post_optimization_hook, 100 std::function<void(const llvm::object::ObjectFile&)> post_codegen_hook) 101 : target_machine_(InferTargetMachineForJIT(target_options, opt_level)), 102 data_layout_(target_machine_->createDataLayout()), 103 symbol_resolver_(llvm::orc::createLegacyLookupResolver( 104 execution_session_, 105 [this](const std::string& name) -> llvm::JITSymbol { 106 return this->ResolveRuntimeSymbol(name); 107 }, 108 [](llvm::Error Err) { 109 cantFail(std::move(Err), "lookupFlags failed"); 110 })), 111 object_layer_( 112 execution_session_, 113 [this](llvm::orc::VModuleKey) { 114 llvm::orc::LegacyRTDyldObjectLinkingLayer::Resources result; 115 result.MemMgr = std::make_shared<llvm::SectionMemoryManager>( 116 orc_jit_memory_mapper::GetInstance()); 117 result.Resolver = symbol_resolver_; 118 return result; 119 }, 120 /*NotifyLoaded=*/ 121 llvm::orc::LegacyRTDyldObjectLinkingLayer::NotifyLoadedFtor(), 122 /*NotifyFinalized=*/ 123 [this](VModuleKeyT, const llvm::object::ObjectFile& object, 124 const llvm::RuntimeDyld::LoadedObjectInfo& object_info) { 125 this->NotifyObjectFinalized(object, object_info); 126 }, 127 /*NotifyFreed=*/ 128 [this](VModuleKeyT, const llvm::object::ObjectFile& object) { 129 this->NotifyObjectFreed(object); 130 }), 131 compile_layer_( 132 object_layer_, 133 CompilerFunctor(target_machine_.get(), opt_level, optimize_for_size, 134 enable_fast_math, disable_expensive_passes, 135 std::move(pre_optimization_hook), 136 std::move(post_optimization_hook), 137 std::move(post_codegen_hook))), 138 gdb_jit_event_listener_( 139 llvm::JITEventListener::createGDBRegistrationListener()) { 140 VLOG(1) << "CPU target: " << target_machine_->getTargetCPU().str() 141 << " features: " << target_machine_->getTargetFeatureString().str(); 142 } 143 144 llvm::JITSymbol SimpleOrcJIT::ResolveRuntimeSymbol(const std::string& name) { 145 void* func_addr = nullptr; 146 if (name.size() > 1 && name.front() == data_layout_.getGlobalPrefix()) { 147 // On Mac OS X, 'name' may have a leading underscore prefix, even though the 148 // registered name may not. 149 std::string stripped_name(name.begin() + 1, name.end()); 150 func_addr = CustomCallTargetRegistry::Global()->Lookup(stripped_name); 151 } else { 152 func_addr = CustomCallTargetRegistry::Global()->Lookup(name); 153 } 154 155 if (func_addr == nullptr) { 156 LOG(ERROR) << "Unable to resolve runtime symbol: " << name; 157 return nullptr; 158 } 159 llvm::JITEvaluatedSymbol symbol_info(reinterpret_cast<uint64_t>(func_addr), 160 llvm::JITSymbolFlags::None); 161 return symbol_info; 162 } 163 164 void SimpleOrcJIT::NotifyObjectFinalized( 165 const llvm::object::ObjectFile& object, 166 const llvm::RuntimeDyld::LoadedObjectInfo& object_info) { 167 uint64_t key = static_cast<uint64_t>( 168 reinterpret_cast<uintptr_t>(object.getData().data())); 169 gdb_jit_event_listener_->notifyObjectLoaded(key, object, object_info); 170 } 171 172 void SimpleOrcJIT::NotifyObjectFreed(const llvm::object::ObjectFile& object) { 173 uint64_t key = static_cast<uint64_t>( 174 reinterpret_cast<uintptr_t>(object.getData().data())); 175 gdb_jit_event_listener_->notifyFreeingObject(key); 176 } 177 178 SimpleOrcJIT::VModuleKeyT SimpleOrcJIT::AddModule( 179 std::unique_ptr<llvm::Module> module) { 180 auto key = execution_session_.allocateVModule(); 181 cantFail(compile_layer_.addModule(key, std::move(module))); 182 module_keys_.push_back(key); 183 return key; 184 } 185 186 void SimpleOrcJIT::RemoveModule(SimpleOrcJIT::VModuleKeyT key) { 187 module_keys_.erase(std::remove(module_keys_.begin(), module_keys_.end(), key), 188 module_keys_.end()); 189 cantFail(compile_layer_.removeModule(key)); 190 } 191 192 llvm::JITSymbol SimpleOrcJIT::FindCompiledSymbol(const std::string& name) { 193 // Resolve symbol from last module to first, allowing later redefinitions of 194 // symbols shadow earlier ones. 195 for (auto& key : 196 llvm::make_range(module_keys_.rbegin(), module_keys_.rend())) { 197 if (auto symbol = 198 compile_layer_.findSymbolIn(key, name, 199 /*ExportedSymbolsOnly=*/true)) { 200 return symbol; 201 } 202 } 203 204 return nullptr; 205 } 206 207 namespace { 208 // Register some known symbols with the CustomCallTargetRegistry. 209 bool RegisterKnownJITSymbols() { 210 CustomCallTargetRegistry* registry = CustomCallTargetRegistry::Global(); 211 212 #define REGISTER_CPU_RUNTIME_SYMBOL(base_name) \ 213 do { \ 214 auto* function_address = \ 215 reinterpret_cast<void*>(__xla_cpu_runtime_##base_name); \ 216 registry->Register(xla::cpu::runtime::k##base_name##SymbolName, \ 217 function_address); \ 218 CHECK_EQ(absl::string_view(xla::cpu::runtime::k##base_name##SymbolName), \ 219 "__xla_cpu_runtime_" #base_name); \ 220 } while (false) 221 222 REGISTER_CPU_RUNTIME_SYMBOL(AcquireInfeedBufferForDequeue); 223 REGISTER_CPU_RUNTIME_SYMBOL(AcquireOutfeedBufferForPopulation); 224 REGISTER_CPU_RUNTIME_SYMBOL(MKLConvF32); 225 REGISTER_CPU_RUNTIME_SYMBOL(EigenConvF16); 226 REGISTER_CPU_RUNTIME_SYMBOL(EigenConvF32); 227 REGISTER_CPU_RUNTIME_SYMBOL(EigenFft); 228 REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF16); 229 REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF32); 230 REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF64); 231 REGISTER_CPU_RUNTIME_SYMBOL(MKLMatMulF32); 232 REGISTER_CPU_RUNTIME_SYMBOL(MKLMatMulF64); 233 REGISTER_CPU_RUNTIME_SYMBOL(MKLSingleThreadedMatMulF32); 234 REGISTER_CPU_RUNTIME_SYMBOL(MKLSingleThreadedMatMulF64); 235 REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF16); 236 REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF32); 237 REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedFft); 238 REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF16); 239 REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF32); 240 REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF64); 241 REGISTER_CPU_RUNTIME_SYMBOL(ParallelForkJoin); 242 REGISTER_CPU_RUNTIME_SYMBOL(ReleaseInfeedBufferAfterDequeue); 243 REGISTER_CPU_RUNTIME_SYMBOL(ReleaseOutfeedBufferAfterPopulation); 244 REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSort); 245 246 registry->Register("__gnu_f2h_ieee", reinterpret_cast<void*>(__gnu_f2h_ieee)); 247 registry->Register("__gnu_h2f_ieee", reinterpret_cast<void*>(__gnu_h2f_ieee)); 248 249 #undef REGISTER_CPU_RUNTIME_SYMBOL 250 251 // Register both the f32 (float) and f64 (double) versions of a libm symbol. 252 // Unfortunately the double versions are overloaded on some systems, e.g. 253 // Mac so we need an explicit cast. This requires passing the function signature 254 // for that case. 255 #define REGISTER_LIBM_SYMBOL(name, double_sig) \ 256 do { \ 257 registry->Register(#name "f", reinterpret_cast<void*>(name##f)); \ 258 registry->Register( \ 259 #name, reinterpret_cast<void*>(static_cast<double_sig>(name))); \ 260 } while (false) 261 262 REGISTER_LIBM_SYMBOL(acos, double (*)(double)); 263 REGISTER_LIBM_SYMBOL(acosh, double (*)(double)); 264 REGISTER_LIBM_SYMBOL(asin, double (*)(double)); 265 REGISTER_LIBM_SYMBOL(asinh, double (*)(double)); 266 REGISTER_LIBM_SYMBOL(atan, double (*)(double)); 267 REGISTER_LIBM_SYMBOL(atan2, double (*)(double, double)); 268 REGISTER_LIBM_SYMBOL(atanh, double (*)(double)); 269 REGISTER_LIBM_SYMBOL(cbrt, double (*)(double)); 270 REGISTER_LIBM_SYMBOL(ceil, double (*)(double)); 271 REGISTER_LIBM_SYMBOL(copysign, double (*)(double, double)); 272 REGISTER_LIBM_SYMBOL(cos, double (*)(double)); 273 REGISTER_LIBM_SYMBOL(cosh, double (*)(double)); 274 REGISTER_LIBM_SYMBOL(erf, double (*)(double)); 275 REGISTER_LIBM_SYMBOL(erfc, double (*)(double)); 276 REGISTER_LIBM_SYMBOL(exp, double (*)(double)); 277 REGISTER_LIBM_SYMBOL(exp2, double (*)(double)); 278 REGISTER_LIBM_SYMBOL(expm1, double (*)(double)); 279 REGISTER_LIBM_SYMBOL(fabs, double (*)(double)); 280 REGISTER_LIBM_SYMBOL(fdim, double (*)(double, double)); 281 REGISTER_LIBM_SYMBOL(floor, double (*)(double)); 282 REGISTER_LIBM_SYMBOL(fma, double (*)(double, double, double)); 283 REGISTER_LIBM_SYMBOL(fmax, double (*)(double, double)); 284 REGISTER_LIBM_SYMBOL(fmin, double (*)(double, double)); 285 REGISTER_LIBM_SYMBOL(fmod, double (*)(double, double)); 286 REGISTER_LIBM_SYMBOL(frexp, double (*)(double, int*)); 287 REGISTER_LIBM_SYMBOL(hypot, double (*)(double, double)); 288 REGISTER_LIBM_SYMBOL(ilogb, int (*)(double)); 289 REGISTER_LIBM_SYMBOL(ldexp, double (*)(double, int)); 290 REGISTER_LIBM_SYMBOL(lgamma, double (*)(double)); 291 REGISTER_LIBM_SYMBOL(llrint, long long (*)(double)); // NOLINT(runtime/int) 292 REGISTER_LIBM_SYMBOL(llround, long long (*)(double)); // NOLINT(runtime/int) 293 REGISTER_LIBM_SYMBOL(log, double (*)(double)); 294 REGISTER_LIBM_SYMBOL(log10, double (*)(double)); 295 REGISTER_LIBM_SYMBOL(log1p, double (*)(double)); 296 REGISTER_LIBM_SYMBOL(log2, double (*)(double)); 297 REGISTER_LIBM_SYMBOL(logb, double (*)(double)); 298 REGISTER_LIBM_SYMBOL(lrint, long (*)(double)); // NOLINT(runtime/int) 299 REGISTER_LIBM_SYMBOL(lround, long (*)(double)); // NOLINT(runtime/int) 300 REGISTER_LIBM_SYMBOL(modf, double (*)(double, double*)); 301 REGISTER_LIBM_SYMBOL(nan, double (*)(const char*)); 302 REGISTER_LIBM_SYMBOL(nearbyint, double (*)(double)); 303 REGISTER_LIBM_SYMBOL(nextafter, double (*)(double, double)); 304 REGISTER_LIBM_SYMBOL(nexttoward, double (*)(double, long double)); 305 REGISTER_LIBM_SYMBOL(pow, double (*)(double, double)); 306 REGISTER_LIBM_SYMBOL(remainder, double (*)(double, double)); 307 REGISTER_LIBM_SYMBOL(remquo, double (*)(double, double, int*)); 308 REGISTER_LIBM_SYMBOL(rint, double (*)(double)); 309 REGISTER_LIBM_SYMBOL(round, double (*)(double)); 310 REGISTER_LIBM_SYMBOL(scalbln, 311 double (*)(double, long)); // NOLINT(runtime/int) 312 REGISTER_LIBM_SYMBOL(scalbn, double (*)(double, int)); 313 REGISTER_LIBM_SYMBOL(sin, double (*)(double)); 314 #ifdef __APPLE__ 315 REGISTER_LIBM_SYMBOL(__sincos, void (*)(double, double*, double*)); 316 registry->Register("__sincosf_stret", 317 reinterpret_cast<void*>(__sincosf_stret)); 318 registry->Register("__sincos_stret", reinterpret_cast<void*>(__sincos_stret)); 319 #else 320 REGISTER_LIBM_SYMBOL(sincos, void (*)(double, double*, double*)); 321 #endif 322 REGISTER_LIBM_SYMBOL(sinh, double (*)(double)); 323 REGISTER_LIBM_SYMBOL(sqrt, double (*)(double)); 324 REGISTER_LIBM_SYMBOL(tan, double (*)(double)); 325 REGISTER_LIBM_SYMBOL(tanh, double (*)(double)); 326 REGISTER_LIBM_SYMBOL(tgamma, double (*)(double)); 327 REGISTER_LIBM_SYMBOL(trunc, double (*)(double)); 328 329 #undef REGISTER_LIBM_SYMBOL 330 331 registry->Register("memcpy", reinterpret_cast<void*>(memcpy)); 332 registry->Register("memmove", reinterpret_cast<void*>(memmove)); 333 registry->Register("memset", reinterpret_cast<void*>(memset)); 334 335 #ifdef __APPLE__ 336 registry->Register("__bzero", reinterpret_cast<void*>(bzero)); 337 registry->Register("memset_pattern16", 338 reinterpret_cast<void*>(memset_pattern16)); 339 #endif 340 341 #ifdef MEMORY_SANITIZER 342 registry->Register("__msan_unpoison", 343 reinterpret_cast<void*>(__msan_unpoison)); 344 #endif 345 346 return true; 347 } 348 349 bool unused = RegisterKnownJITSymbols(); 350 } // namespace 351 352 } // namespace cpu 353 } // namespace xla 354