1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.h" 17 18 #include "llvm/IR/Function.h" 19 #include "llvm/IR/IRBuilder.h" 20 #include "llvm/IR/Intrinsics.h" 21 #include "llvm/IR/Verifier.h" 22 #include "llvm/Transforms/Utils/Cloning.h" 23 #include "tensorflow/compiler/xla/service/cpu/vector_support_library.h" 24 #include "tensorflow/compiler/xla/service/llvm_ir/math_ops.h" 25 #include "tensorflow/core/platform/logging.h" 26 27 namespace xla { 28 namespace cpu { 29 namespace runtime { 30 31 const char* const kTanhV4F32SymbolName = "__xla_cpu_runtime_TanhV4F32"; 32 const char* const kTanhV8F32SymbolName = "__xla_cpu_runtime_TanhV8F32"; 33 const char* const kExpV4F32SymbolName = "__xla_cpu_runtime_ExpV4F32"; 34 const char* const kExpV8F32SymbolName = "__xla_cpu_runtime_ExpV8F32"; 35 const char* const kLogV4F32SymbolName = "__xla_cpu_runtime_LogV4F32AVX"; 36 const char* const kLogV8F32SymbolName = "__xla_cpu_runtime_LogV8F32AVX"; 37 38 namespace { 39 40 // Replaces calls to the function `fn_name` with the code generated by 41 // fn_body_generator. 42 // 43 // We assume that fn_name accepts either a scalar f32 or a vector of 44 // vector_width f32s, and that fn_body_generator generates a function body with 45 // the same inputs/outputs as fn_name. 46 void RewriteCalls( 47 llvm::Module* module, const char* fn_name, 48 std::function<llvm::Value*(llvm::IRBuilder<>* b, llvm::Value* input, 49 int32 vector_width)> 50 fn_body_generator, 51 int32 vector_width, bool enable_fast_math) { 52 llvm::Function* fn = module->getFunction(fn_name); 53 if (fn == nullptr) { 54 // If the function declaration is not present in the module, there can't be 55 // any calls to resolve. Don't emit the function in this case. 56 return; 57 } 58 59 // Our task is to generate a function body for `fn`, but we can't generate a 60 // function body for an LLVM intrinsic. So if fn is an intrinsic, replace it 61 // with a new function. 62 if (fn->isIntrinsic()) { 63 llvm::Function* new_fn = llvm::Function::Create( 64 fn->getFunctionType(), llvm::GlobalValue::InternalLinkage, 65 llvm::Twine("xla_impl.") + fn_name, module); 66 fn->replaceAllUsesWith(new_fn); 67 fn->eraseFromParent(); 68 fn = new_fn; 69 } 70 71 llvm::LLVMContext* context = &module->getContext(); 72 73 llvm::BasicBlock* fn_body = llvm::BasicBlock::Create(*context, "body", fn); 74 llvm::IRBuilder<> b(fn_body); 75 llvm::FastMathFlags fast_math_flags; 76 fast_math_flags.setFast(enable_fast_math); 77 b.setFastMathFlags(fast_math_flags); 78 79 llvm::Value* input = &*fn->arg_begin(); 80 81 // Upcast to vector type if input is a scalar. 82 if (vector_width == 1) { 83 llvm::Type* v1_type = llvm::VectorType::get(input->getType(), 1); 84 input = b.CreateInsertElement(llvm::UndefValue::get(v1_type), input, 85 uint64_t{0}); 86 } 87 88 // Generate the vectorized code. 89 CHECK_EQ(vector_width, input->getType()->getVectorNumElements()); 90 llvm::Value* result = fn_body_generator(&b, input, vector_width); 91 92 // Downcast result to scalar type if necessary. 93 if (vector_width == 1) { 94 result = b.CreateExtractElement(result, uint64_t{0}); 95 } 96 b.CreateRet(result); 97 DCHECK(!llvm::verifyFunction(*fn)); 98 99 // Force-inline `fn` into all of its callers and then delete `fn`. 100 // 101 // TODO(b/73081976): Should we avoid inlining these in some cases? 102 std::vector<llvm::CallInst*> calls_to_inline; 103 for (auto* user : fn->users()) { 104 calls_to_inline.push_back(llvm::cast<llvm::CallInst>(user)); 105 } 106 for (auto* call_to_inline : calls_to_inline) { 107 llvm::InlineFunctionInfo inline_function_info; 108 CHECK(llvm::InlineFunction(call_to_inline, inline_function_info)); 109 } 110 fn->eraseFromParent(); 111 } 112 113 llvm::Value* GenerateVF32Tanh(llvm::IRBuilder<>* b, llvm::Value* input, 114 int32 /*vector_width*/) { 115 return llvm_ir::EmitFastTanh(b, input); 116 } 117 118 llvm::Value* GenerateVF32Exp(llvm::IRBuilder<>* b, llvm::Value* input, 119 int32 vector_width) { 120 VectorSupportLibrary vsl(F32, vector_width, b, "exp_f32"); 121 122 // This implements the same polynomial approximation as implemented in Eigen3. 123 124 const llvm::APFloat half = GetIeeeF32(0.5); 125 const llvm::APFloat one = GetIeeeF32(1.0); 126 127 const llvm::APFloat exp_hi = GetIeeeF32(88.3762626647950); 128 const llvm::APFloat exp_lo = GetIeeeF32(-88.3762626647949); 129 130 const llvm::APFloat cephes_LOG2EF = GetIeeeF32(1.44269504088896341); 131 const llvm::APFloat cephes_exp_C1 = GetIeeeF32(0.693359375); 132 const llvm::APFloat cephes_exp_C2 = GetIeeeF32(-2.12194440e-4); 133 134 const llvm::APFloat cephes_exp_p0 = GetIeeeF32(1.9875691500E-4); 135 const llvm::APFloat cephes_exp_p1 = GetIeeeF32(1.3981999507E-3); 136 const llvm::APFloat cephes_exp_p2 = GetIeeeF32(8.3334519073E-3); 137 const llvm::APFloat cephes_exp_p3 = GetIeeeF32(4.1665795894E-2); 138 const llvm::APFloat cephes_exp_p4 = GetIeeeF32(1.6666665459E-1); 139 const llvm::APFloat cephes_exp_p5 = GetIeeeF32(5.0000001201E-1); 140 141 llvm::Value* input_clamped = 142 vsl.Clamp(input, /*low=*/exp_lo, /*high=*/exp_hi); 143 llvm::Value* fx = vsl.Floor(vsl.MulAdd(input_clamped, cephes_LOG2EF, half)); 144 llvm::Value* tmp = vsl.Mul(cephes_exp_C1, fx); 145 llvm::Value* z = vsl.Mul(cephes_exp_C2, fx); 146 llvm::Value* x = vsl.Sub(input_clamped, tmp); 147 x = vsl.Sub(x, z); 148 z = vsl.Mul(x, x); 149 150 llvm::Value* y = vsl.MulAdd(x, cephes_exp_p0, cephes_exp_p1); 151 y = vsl.MulAdd(y, x, cephes_exp_p2); 152 y = vsl.MulAdd(y, x, cephes_exp_p3); 153 y = vsl.MulAdd(y, x, cephes_exp_p4); 154 y = vsl.MulAdd(y, x, cephes_exp_p5); 155 y = vsl.MulAdd(y, z, x); 156 y = vsl.Add(one, y); 157 158 // VectorSupportLibrary (intentionally) can't juggle more than one type at a 159 // time so drop down to IRBuilder for this bit. 160 llvm::Value* vector_constant_0x7f = 161 b->CreateVectorSplat(vector_width, b->getInt32(0x7f)); 162 llvm::Value* vector_constant_23 = 163 b->CreateVectorSplat(vector_width, b->getInt32(23)); 164 llvm::Type* i32_vector_type = 165 llvm::VectorType::get(b->getInt32Ty(), vector_width); 166 // fx is clamped so we don't have to worry about it being out of range for 167 // i32. 168 llvm::Value* emm0 = b->CreateFPToSI(fx, i32_vector_type); 169 emm0 = b->CreateAdd(emm0, vector_constant_0x7f); 170 emm0 = b->CreateShl(emm0, vector_constant_23); 171 llvm::Value* emm0_f32 = b->CreateBitCast(emm0, vsl.vector_type()); 172 173 return vsl.Max(vsl.Mul(y, emm0_f32), input); 174 } 175 176 llvm::Value* GenerateVF32Log(llvm::IRBuilder<>* b, llvm::Value* input, 177 int32 vector_width) { 178 VectorSupportLibrary vsl(F32, vector_width, b, "log_f32"); 179 180 const llvm::APFloat half = GetIeeeF32(0.5); 181 const llvm::APFloat one = GetIeeeF32(1.0); 182 183 // This implements the same polynomial approximation as implemented in Eigen3. 184 // Returns NaN for x < 0, -INF for x = 0 185 const llvm::APFloat cephes_SQRTHF = GetIeeeF32(0.707106781186547524); 186 const llvm::APFloat cephes_log_p0 = GetIeeeF32(7.0376836292E-2); 187 const llvm::APFloat cephes_log_p1 = GetIeeeF32(-1.1514610310E-1); 188 const llvm::APFloat cephes_log_p2 = GetIeeeF32(1.1676998740E-1); 189 const llvm::APFloat cephes_log_p3 = GetIeeeF32(-1.2420140846E-1); 190 const llvm::APFloat cephes_log_p4 = GetIeeeF32(+1.4249322787E-1); 191 const llvm::APFloat cephes_log_p5 = GetIeeeF32(-1.6668057665E-1); 192 const llvm::APFloat cephes_log_p6 = GetIeeeF32(+2.0000714765E-1); 193 const llvm::APFloat cephes_log_p7 = GetIeeeF32(-2.4999993993E-1); 194 const llvm::APFloat cephes_log_p8 = GetIeeeF32(+3.3333331174E-1); 195 const llvm::APFloat cephes_log_q1 = GetIeeeF32(-2.12194440e-4); 196 const llvm::APFloat cephes_log_q2 = GetIeeeF32(0.693359375); 197 198 // The smallest non denormalized float number. 199 const llvm::APFloat min_norm_pos = GetIeeeF32FromBitwiseRep(0x00800000); 200 const llvm::APFloat minus_inf = GetIeeeF32FromBitwiseRep(0xff800000); 201 const llvm::APFloat pos_inf = GetIeeeF32FromBitwiseRep(0x7f800000); 202 const llvm::APFloat inv_mant_mask = GetIeeeF32FromBitwiseRep(~0x7f800000); 203 204 // invalid_mask is set if x is negative or NaN (and therefore output 205 // must be NaN). 206 llvm::Value* invalid_mask = vsl.FCmpULEMask(input, vsl.GetZeroVector()); 207 llvm::Value* is_zero_mask = vsl.FCmpEQMask(input, vsl.GetZeroVector()); 208 llvm::Value* is_pos_inf_mask = vsl.FCmpEQMask(input, pos_inf); 209 210 // Cut off denormalized stuff. 211 llvm::Value* tmp0 = vsl.Max(min_norm_pos, input); 212 213 // VectorSupportLibrary (intentionally) can't juggle more than one type at a 214 // time so drop down to IRBuilder for this bit. 215 llvm::Value* vector_constant_0x7f = 216 b->CreateVectorSplat(vector_width, b->getInt32(0x7f)); 217 llvm::Value* vector_constant_23 = 218 b->CreateVectorSplat(vector_width, b->getInt32(23)); 219 llvm::Type* i32_vector_type = 220 llvm::VectorType::get(b->getInt32Ty(), vector_width); 221 222 llvm::Value* emm0 = b->CreateLShr(b->CreateBitCast(tmp0, i32_vector_type), 223 vector_constant_23); 224 225 // Keep only the fractional part. 226 tmp0 = vsl.FloatAnd(tmp0, inv_mant_mask); 227 tmp0 = vsl.FloatOr(tmp0, half); 228 229 emm0 = b->CreateSub(emm0, vector_constant_0x7f); 230 llvm::Value* e = vsl.Add(one, b->CreateSIToFP(emm0, vsl.vector_type())); 231 232 // part2: 233 // if( x < SQRTHF ) { 234 // e -= 1; 235 // x = x + x - 1.0; 236 // } else { x = x - 1.0; } 237 llvm::Value* mask = vsl.FCmpOLTMask(tmp0, cephes_SQRTHF); 238 llvm::Value* tmp1 = vsl.FloatAnd(tmp0, mask); 239 tmp0 = vsl.Sub(tmp0, one); 240 e = vsl.Sub(e, vsl.FloatAnd(mask, one)); 241 tmp0 = vsl.Add(tmp0, tmp1); 242 243 llvm::Value* x2 = vsl.Mul(tmp0, tmp0); 244 llvm::Value* x3 = vsl.Mul(x2, tmp0); 245 246 llvm::Value *y, *y1, *y2; 247 y = vsl.MulAdd(tmp0, cephes_log_p0, cephes_log_p1); 248 y1 = vsl.MulAdd(tmp0, cephes_log_p3, cephes_log_p4); 249 y2 = vsl.MulAdd(tmp0, cephes_log_p6, cephes_log_p7); 250 y = vsl.MulAdd(y, tmp0, cephes_log_p2); 251 y1 = vsl.MulAdd(y1, tmp0, cephes_log_p5); 252 y2 = vsl.MulAdd(y2, tmp0, cephes_log_p8); 253 y = vsl.MulAdd(y, x3, y1); 254 y = vsl.MulAdd(y, x3, y2); 255 y = vsl.Mul(y, x3); 256 257 y1 = vsl.Mul(cephes_log_q1, e); 258 llvm::Value* tmp2 = vsl.Mul(half, x2); 259 y = vsl.Add(y, y1); 260 tmp0 = vsl.Sub(tmp0, tmp2); 261 y2 = vsl.Mul(cephes_log_q2, e); 262 tmp0 = vsl.Add(tmp0, y); 263 tmp0 = vsl.Add(tmp0, y2); 264 265 // Contains +/-inf where +/-inf is the correct answer, otherwise 0. 266 llvm::Value* result_inf = vsl.FloatOr(vsl.FloatAnd(is_zero_mask, minus_inf), 267 vsl.FloatAnd(is_pos_inf_mask, pos_inf)); 268 269 // Contains a finite result or nan. This is the correct answer only if both 270 // result_minus_inf and result_pos_inf are both 0. 271 // 272 // (This implementation works because 0xffffffff is a nan.) 273 llvm::Value* result_finite_or_nan = vsl.FloatOr(tmp0, invalid_mask); 274 275 // Combine the above into a final result. 276 return vsl.FloatOr(result_inf, 277 vsl.FloatAndNot(vsl.FloatOr(is_zero_mask, is_pos_inf_mask), 278 result_finite_or_nan)); 279 } 280 } // namespace 281 282 void RewriteIRRuntimeFunctions(llvm::Module* module, bool enable_fast_math) { 283 // Curry some params to RewriteCalls. 284 auto rewrite_calls = 285 std::bind(RewriteCalls, module, std::placeholders::_1, 286 std::placeholders::_2, std::placeholders::_3, enable_fast_math); 287 288 rewrite_calls("tanhf", GenerateVF32Tanh, /*vector_width=*/1); 289 rewrite_calls("llvm.tanh.f32", GenerateVF32Tanh, /*vector_width=*/1); 290 rewrite_calls(kTanhV4F32SymbolName, GenerateVF32Tanh, /*vector_width=*/4); 291 rewrite_calls(kTanhV8F32SymbolName, GenerateVF32Tanh, /*vector_width=*/8); 292 293 rewrite_calls("expf", GenerateVF32Exp, /*vector_width=*/1); 294 rewrite_calls("llvm.exp.f32", GenerateVF32Exp, /*vector_width=*/1); 295 rewrite_calls(kExpV4F32SymbolName, GenerateVF32Exp, /*vector_width=*/4); 296 rewrite_calls(kExpV8F32SymbolName, GenerateVF32Exp, /*vector_width=*/8); 297 298 rewrite_calls("logf", GenerateVF32Log, /*vector_width=*/1); 299 rewrite_calls("llvm.log.f32", GenerateVF32Log, /*vector_width=*/1); 300 rewrite_calls(kLogV4F32SymbolName, GenerateVF32Log, /*vector_width=*/4); 301 rewrite_calls(kLogV8F32SymbolName, GenerateVF32Log, /*vector_width=*/8); 302 } 303 304 } // namespace runtime 305 } // namespace cpu 306 } // namespace xla 307