1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.h" 17 18 #include "llvm/IR/Function.h" 19 #include "llvm/IR/IRBuilder.h" 20 #include "llvm/IR/Intrinsics.h" 21 #include "llvm/IR/Verifier.h" 22 #include "llvm/Transforms/Utils/Cloning.h" 23 #include "tensorflow/compiler/xla/service/cpu/vector_support_library.h" 24 #include "tensorflow/core/lib/core/casts.h" 25 #include "tensorflow/core/platform/logging.h" 26 27 namespace xla { 28 namespace cpu { 29 namespace runtime { 30 31 const char* const kTanhV4F32SymbolName = "__xla_cpu_runtime_TanhV4F32"; 32 const char* const kTanhV8F32SymbolName = "__xla_cpu_runtime_TanhV8F32"; 33 const char* const kExpV4F32SymbolName = "__xla_cpu_runtime_ExpV4F32"; 34 const char* const kExpV8F32SymbolName = "__xla_cpu_runtime_ExpV8F32"; 35 const char* const kLogV4F32SymbolName = "__xla_cpu_runtime_LogV4F32AVX"; 36 const char* const kLogV8F32SymbolName = "__xla_cpu_runtime_LogV8F32AVX"; 37 38 namespace { 39 llvm::Function* EmitVectorF32TanhIfNeeded(llvm::Module* module, 40 llvm::StringRef function_name, 41 int vector_width, 42 bool enable_fast_math) { 43 llvm::Function* vector_tanh_function = module->getFunction(function_name); 44 if (vector_tanh_function == nullptr) { 45 // If the function declaration is not present in the module, there can't be 46 // any calls to resolve. Don't emit the function in this case. 47 return nullptr; 48 } 49 50 llvm::LLVMContext* context = &module->getContext(); 51 52 llvm::BasicBlock* vector_tanh_body = 53 llvm::BasicBlock::Create(*context, "body", vector_tanh_function); 54 55 llvm::IRBuilder<> ir_builder(vector_tanh_body); 56 llvm::FastMathFlags fast_math_flags; 57 fast_math_flags.setFast(); 58 ir_builder.setFastMathFlags(fast_math_flags); 59 60 VectorSupportLibrary vsl(F32, vector_width, &ir_builder, "tanh_f32"); 61 62 llvm::Value* input = &*vector_tanh_function->arg_begin(); 63 CHECK_EQ(input->getType(), vsl.vector_type()); 64 65 // This implements the same rational interpolant as implemented in Eigen3. 66 llvm::Value* input_clamped = 67 vsl.Clamp(input, /*low=*/GetIeeeF32(-9.0), /*high=*/GetIeeeF32(9.0)); 68 69 std::array<float, 7> numerator_coeffs{ 70 -2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f, 71 5.12229709037114e-08f, 1.48572235717979e-05f, 6.37261928875436e-04f, 72 4.89352455891786e-03f}; 73 74 std::array<float, 4> denominator_coeffs{ 75 1.19825839466702e-06f, 1.18534705686654e-04f, 2.26843463243900e-03f, 76 4.89352518554385e-03f}; 77 78 llvm::Value* input_squared = vsl.Mul(input_clamped, input_clamped); 79 llvm::Value* numerator = vsl.SplatFloat(GetIeeeF32(numerator_coeffs[0])); 80 for (int i = 1; i < numerator_coeffs.size(); i++) { 81 numerator = 82 vsl.MulAdd(input_squared, numerator, GetIeeeF32(numerator_coeffs[i])); 83 } 84 85 numerator = vsl.Mul(input_clamped, numerator); 86 87 llvm::Value* denominator = vsl.SplatFloat(GetIeeeF32(denominator_coeffs[0])); 88 for (int i = 1; i < denominator_coeffs.size(); i++) { 89 denominator = vsl.MulAdd(input_squared, denominator, 90 GetIeeeF32(denominator_coeffs[i])); 91 } 92 93 llvm::Value* result = vsl.Div(numerator, denominator); 94 ir_builder.CreateRet(result); 95 96 DCHECK(!llvm::verifyFunction(*vector_tanh_function)); 97 return vector_tanh_function; 98 } 99 100 llvm::Function* EmitVectorF32ExpIfNeeded(llvm::Module* module, 101 llvm::StringRef function_name, 102 int vector_width, 103 bool enable_fast_math) { 104 llvm::Function* vector_exp_function = module->getFunction(function_name); 105 if (vector_exp_function == nullptr) { 106 // If the function declaration is not present in the module, there can't be 107 // any calls to resolve. Don't emit the function in this case. 108 return nullptr; 109 } 110 111 llvm::LLVMContext* context = &module->getContext(); 112 113 llvm::BasicBlock* vector_exp_body = 114 llvm::BasicBlock::Create(*context, "body", vector_exp_function); 115 116 llvm::IRBuilder<> ir_builder(vector_exp_body); 117 llvm::FastMathFlags fast_math_flags; 118 fast_math_flags.setFast(); 119 ir_builder.setFastMathFlags(fast_math_flags); 120 121 VectorSupportLibrary vsl(F32, vector_width, &ir_builder, "exp_f32"); 122 123 // This implements the same polynomial approximation as implemented in Eigen3. 124 125 const llvm::APFloat half = GetIeeeF32(0.5); 126 const llvm::APFloat one = GetIeeeF32(1.0); 127 128 const llvm::APFloat exp_hi = GetIeeeF32(88.3762626647950); 129 const llvm::APFloat exp_lo = GetIeeeF32(-88.3762626647949); 130 131 const llvm::APFloat cephes_LOG2EF = GetIeeeF32(1.44269504088896341); 132 const llvm::APFloat cephes_exp_C1 = GetIeeeF32(0.693359375); 133 const llvm::APFloat cephes_exp_C2 = GetIeeeF32(-2.12194440e-4); 134 135 const llvm::APFloat cephes_exp_p0 = GetIeeeF32(1.9875691500E-4); 136 const llvm::APFloat cephes_exp_p1 = GetIeeeF32(1.3981999507E-3); 137 const llvm::APFloat cephes_exp_p2 = GetIeeeF32(8.3334519073E-3); 138 const llvm::APFloat cephes_exp_p3 = GetIeeeF32(4.1665795894E-2); 139 const llvm::APFloat cephes_exp_p4 = GetIeeeF32(1.6666665459E-1); 140 const llvm::APFloat cephes_exp_p5 = GetIeeeF32(5.0000001201E-1); 141 142 llvm::Value* input = &*vector_exp_function->arg_begin(); 143 llvm::Value* input_clamped = 144 vsl.Clamp(input, /*low=*/exp_lo, /*high=*/exp_hi); 145 llvm::Value* fx = vsl.Floor(vsl.MulAdd(input_clamped, cephes_LOG2EF, half)); 146 llvm::Value* tmp = vsl.Mul(cephes_exp_C1, fx); 147 llvm::Value* z = vsl.Mul(cephes_exp_C2, fx); 148 llvm::Value* x = vsl.Sub(input_clamped, tmp); 149 x = vsl.Sub(x, z); 150 z = vsl.Mul(x, x); 151 152 llvm::Value* y = vsl.MulAdd(x, cephes_exp_p0, cephes_exp_p1); 153 y = vsl.MulAdd(y, x, cephes_exp_p2); 154 y = vsl.MulAdd(y, x, cephes_exp_p3); 155 y = vsl.MulAdd(y, x, cephes_exp_p4); 156 y = vsl.MulAdd(y, x, cephes_exp_p5); 157 y = vsl.MulAdd(y, z, x); 158 y = vsl.Add(one, y); 159 160 // VectorSupportLibrary (intentionally) can't juggle more than one type at a 161 // time so drop down to IRBuilder for this bit. 162 llvm::Value* vector_constant_0x7f = 163 ir_builder.CreateVectorSplat(vector_width, ir_builder.getInt32(0x7f)); 164 llvm::Value* vector_constant_23 = 165 ir_builder.CreateVectorSplat(vector_width, ir_builder.getInt32(23)); 166 llvm::Type* i32_vector_type = 167 llvm::VectorType::get(ir_builder.getInt32Ty(), vector_width); 168 // fx is clamped so we don't have to worry about it being out of range for 169 // i32. 170 llvm::Value* emm0 = ir_builder.CreateFPToSI(fx, i32_vector_type); 171 emm0 = ir_builder.CreateAdd(emm0, vector_constant_0x7f); 172 emm0 = ir_builder.CreateShl(emm0, vector_constant_23); 173 llvm::Value* emm0_f32 = ir_builder.CreateBitCast(emm0, vsl.vector_type()); 174 175 llvm::Value* result = vsl.Max(vsl.Mul(y, emm0_f32), input); 176 177 ir_builder.CreateRet(result); 178 179 DCHECK(!llvm::verifyFunction(*vector_exp_function)); 180 return vector_exp_function; 181 } 182 183 llvm::Function* EmitVectorF32LogIfNeeded(llvm::Module* module, 184 llvm::StringRef function_name, 185 int vector_width, 186 bool enable_fast_math) { 187 llvm::Function* vector_log_function = module->getFunction(function_name); 188 if (vector_log_function == nullptr) { 189 // If the function declaration is not present in the module, there can't be 190 // any calls to resolve. Don't emit the function in this case. 191 return nullptr; 192 } 193 194 llvm::LLVMContext* context = &module->getContext(); 195 196 llvm::BasicBlock* vector_log_body = 197 llvm::BasicBlock::Create(*context, "body", vector_log_function); 198 199 llvm::IRBuilder<> ir_builder(vector_log_body); 200 llvm::FastMathFlags fast_math_flags; 201 fast_math_flags.setFast(); 202 ir_builder.setFastMathFlags(fast_math_flags); 203 204 llvm::Value* input = &*vector_log_function->arg_begin(); 205 VectorSupportLibrary vsl(F32, vector_width, &ir_builder, "log_f32"); 206 207 const llvm::APFloat half = GetIeeeF32(0.5); 208 const llvm::APFloat one = GetIeeeF32(1.0); 209 210 // This implements the same polynomial approximation as implemented in Eigen3. 211 // Returns NaN for x < 0, -INF for x = 0 212 const llvm::APFloat cephes_SQRTHF = GetIeeeF32(0.707106781186547524); 213 const llvm::APFloat cephes_log_p0 = GetIeeeF32(7.0376836292E-2); 214 const llvm::APFloat cephes_log_p1 = GetIeeeF32(-1.1514610310E-1); 215 const llvm::APFloat cephes_log_p2 = GetIeeeF32(1.1676998740E-1); 216 const llvm::APFloat cephes_log_p3 = GetIeeeF32(-1.2420140846E-1); 217 const llvm::APFloat cephes_log_p4 = GetIeeeF32(+1.4249322787E-1); 218 const llvm::APFloat cephes_log_p5 = GetIeeeF32(-1.6668057665E-1); 219 const llvm::APFloat cephes_log_p6 = GetIeeeF32(+2.0000714765E-1); 220 const llvm::APFloat cephes_log_p7 = GetIeeeF32(-2.4999993993E-1); 221 const llvm::APFloat cephes_log_p8 = GetIeeeF32(+3.3333331174E-1); 222 const llvm::APFloat cephes_log_q1 = GetIeeeF32(-2.12194440e-4); 223 const llvm::APFloat cephes_log_q2 = GetIeeeF32(0.693359375); 224 225 // The smallest non denormalized float number. 226 const llvm::APFloat min_norm_pos = GetIeeeF32FromBitwiseRep(0x00800000); 227 const llvm::APFloat minus_inf = GetIeeeF32FromBitwiseRep(0xff800000); 228 const llvm::APFloat inv_mant_mask = GetIeeeF32FromBitwiseRep(~0x7f800000); 229 230 // invalid_mask is set if x is negative or NaN (and therefore output 231 // must be NaN). 232 llvm::Value* invalid_mask = vsl.FCmpULEMask(input, vsl.GetZeroVector()); 233 llvm::Value* iszero_mask = vsl.FCmpEQMask(input, vsl.GetZeroVector()); 234 235 // Cut off denormalized stuff. 236 input = vsl.Max(min_norm_pos, input); 237 238 // VectorSupportLibrary (intentionally) can't juggle more than one type at a 239 // time so drop down to IRBuilder for this bit. 240 llvm::Value* vector_constant_0x7f = 241 ir_builder.CreateVectorSplat(vector_width, ir_builder.getInt32(0x7f)); 242 llvm::Value* vector_constant_23 = 243 ir_builder.CreateVectorSplat(vector_width, ir_builder.getInt32(23)); 244 llvm::Type* i32_vector_type = 245 llvm::VectorType::get(ir_builder.getInt32Ty(), vector_width); 246 247 llvm::Value* emm0 = ir_builder.CreateLShr( 248 ir_builder.CreateBitCast(input, i32_vector_type), vector_constant_23); 249 250 // Keep only the fractional part. 251 input = vsl.FloatAnd(input, inv_mant_mask); 252 input = vsl.FloatOr(input, half); 253 254 emm0 = ir_builder.CreateSub(emm0, vector_constant_0x7f); 255 llvm::Value* e = 256 vsl.Add(one, ir_builder.CreateSIToFP(emm0, vsl.vector_type())); 257 258 // part2: 259 // if( x < SQRTHF ) { 260 // e -= 1; 261 // x = x + x - 1.0; 262 // } else { x = x - 1.0; } 263 llvm::Value* mask = vsl.FCmpOLTMask(input, cephes_SQRTHF); 264 llvm::Value* tmp = vsl.FloatAnd(input, mask); 265 input = vsl.Sub(input, one); 266 e = vsl.Sub(e, vsl.FloatAnd(mask, one)); 267 input = vsl.Add(input, tmp); 268 269 llvm::Value* x2 = vsl.Mul(input, input); 270 llvm::Value* x3 = vsl.Mul(x2, input); 271 272 llvm::Value *y, *y1, *y2; 273 y = vsl.MulAdd(input, cephes_log_p0, cephes_log_p1); 274 y1 = vsl.MulAdd(input, cephes_log_p3, cephes_log_p4); 275 y2 = vsl.MulAdd(input, cephes_log_p6, cephes_log_p7); 276 y = vsl.MulAdd(y, input, cephes_log_p2); 277 y1 = vsl.MulAdd(y1, input, cephes_log_p5); 278 y2 = vsl.MulAdd(y2, input, cephes_log_p8); 279 y = vsl.MulAdd(y, x3, y1); 280 y = vsl.MulAdd(y, x3, y2); 281 y = vsl.Mul(y, x3); 282 283 y1 = vsl.Mul(cephes_log_q1, e); 284 tmp = vsl.Mul(half, x2); 285 y = vsl.Add(y, y1); 286 input = vsl.Sub(input, tmp); 287 y2 = vsl.Mul(cephes_log_q2, e); 288 input = vsl.Add(input, y); 289 input = vsl.Add(input, y2); 290 291 // Negative arg will be NAN, 0 will be -INF. 292 llvm::Value* or_lhs = 293 vsl.FloatAndNot(iszero_mask, vsl.FloatOr(input, invalid_mask)); 294 llvm::Value* or_rhs = vsl.FloatAnd(iszero_mask, minus_inf); 295 llvm::Value* result = vsl.FloatOr(or_lhs, or_rhs); 296 297 ir_builder.CreateRet(result); 298 299 DCHECK(!llvm::verifyFunction(*vector_log_function)); 300 return vector_log_function; 301 } 302 } // namespace 303 304 void RewriteIRRuntimeFunctions(llvm::Module* module, bool enable_fast_math) { 305 auto* tanh_v4f32 = 306 EmitVectorF32TanhIfNeeded(module, kTanhV4F32SymbolName, 307 /*vector_width=*/4, enable_fast_math); 308 auto* tanh_v8f32 = 309 EmitVectorF32TanhIfNeeded(module, kTanhV8F32SymbolName, 310 /*vector_width=*/8, enable_fast_math); 311 312 auto* exp_v4f32 = 313 EmitVectorF32ExpIfNeeded(module, kExpV4F32SymbolName, 314 /*vector_width=*/4, enable_fast_math); 315 auto* exp_v8f32 = 316 EmitVectorF32ExpIfNeeded(module, kExpV8F32SymbolName, 317 /*vector_width=*/8, enable_fast_math); 318 319 auto* log_v4f32 = 320 EmitVectorF32LogIfNeeded(module, kLogV4F32SymbolName, 321 /*vector_width=*/4, enable_fast_math); 322 auto* log_v8f32 = 323 EmitVectorF32LogIfNeeded(module, kLogV8F32SymbolName, 324 /*vector_width=*/8, enable_fast_math); 325 326 // Gather all the call sites, force inline them and then delete the vector 327 // function bodies. 328 // 329 // TODO(b/73081976): Should we avoid inlining these intrinsics in some cases? 330 331 std::vector<llvm::CallInst*> calls_to_inline; 332 for (auto* function : 333 {tanh_v4f32, tanh_v8f32, exp_v4f32, exp_v8f32, log_v4f32, log_v8f32}) { 334 if (function != nullptr) { 335 for (auto* user : function->users()) { 336 calls_to_inline.push_back(llvm::cast<llvm::CallInst>(user)); 337 } 338 } 339 } 340 341 for (auto* call_to_inline : calls_to_inline) { 342 llvm::InlineFunctionInfo inline_function_info; 343 CHECK(llvm::InlineFunction(call_to_inline, inline_function_info)); 344 } 345 346 for (auto* function : 347 {tanh_v4f32, tanh_v8f32, exp_v4f32, exp_v8f32, log_v4f32, log_v8f32}) { 348 if (function != nullptr) { 349 function->eraseFromParent(); 350 } 351 } 352 } 353 354 } // namespace runtime 355 } // namespace cpu 356 } // namespace xla 357